In [1]:
import os
import random

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import pandas as pd
from einops import rearrange as re
from opt_einsum import contract as einsum
%load_ext line_profiler

In [2]:
from torch.nn import (
    Module,
    ModuleList,
    Sequential,
    Parameter,
    Linear, 
    Dropout,
    LayerNorm,
)

[attn all you need](https://arxiv.org/abs/1706.03762?context=cs)
[formal alg transformer](https://arxiv.org/abs/2207.09238)

[AI Explained 3D Viz](https://youtu.be/-9vVhYEXeyQ?t=456)

[Bloem impl](https://github.com/pbloem/former/blob/master/former/modules.py)
[harvard annotated](http://nlp.seas.harvard.edu/annotated-transformer/#encoder-and-decoder-stacks)  
[torch mha f](https://github.com/pytorch/pytorch/blob/dcf51885618e7d1d9aa6e628f3354f67ad82b446/torch/nn/functional.py#L4917)   
[einops examples](http://einops.rocks/pytorch-examples.html)  

[R-Drop](https://arxiv.org/pdf/2106.14448v2.pdf)

In [3]:
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)
random.seed(0)
np.random.seed(0)

In [4]:
torch.set_printoptions(precision=4)

In [56]:
def heat(x):
    df = pd.DataFrame(x.detach().numpy())
    return df.style.background_gradient(cmap='Blues')  # .format('{:.0f}')

In [6]:
class Transformer(nn.Module):
    def __init__(self, d_model, n, vocab):
        super().__init__()
        self.encoder = self.stack(EncoderBlock, 6)
        self.decoder = self.stack(DecoderBlock, 6)
        self.head = Sequential(
            Linear(d_model, vocab.shape[0]),
            Softmax(dim=-1)
        )
       
    def forward(src, tgt, src_mask, tgt_mask):
        ctx = self.encoder(src, src_mask)
        return self.decoder(tgt, ctx, tgt_mask, ctx_mask)
   
    def stack(Layer, n):
        return Sequential([copy.deepcopy(l) for l in Layer()])
    
class EncoderBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = MultiHeadAttention()
        self.ff = FeedForward()
        
    def forward(self, x, mask=None): # x -> src?
        return self.ff(self.mha(x, mask))
    
class DecoderBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.masked_mha = MultiHeadAttention()
        self.mha = MultiHeadAttention()
        self.ff = FeedForward()
        
    def forward(self, x, ctx=None, mask=None, ctx_mask=None):
        x = self.masked_mha(x, mask)
        return self.ff(self.mha(x, ctx, ctx, ctx_mask))
    
class FeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self._ff(x))
    
    def _ff(self, x):
        return self.fc2(F.relu(self.fc1(x)))
    
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads: int = 8, emb_dim: int = 512):
        super().__init__()
        assert emb_dim % n_heads == 0  # d_h = d_q = d_k = d_v = emb_dim // n_heads
        self.proj_qkv = [Linear(emb_dim, emb_dim, bias=False) for _ in range(3)]
        self.proj_o = Linear(emb_dim, emb_dim, bias=False)
        self.norm = nn.LayerNorm(emb_dim)
        self.h, self.d = n_heads, emb_dim
        self.attn = torch.empty((1000 * 8, 256, 256))
        self.out = torch.empty((1000 * 8, 256, 64))

    def forward(self, q, k=None, v=None, mask=None):
        """ q, k, v: (batch, seq_len, emb_dim) mask: (seq_len, seq_len) """
        if k is None and v is None:
            k = v = q
        if mask is None:
            mask = torch.zeros((q.shape[-2]))
        return self.norm(q + self._mha(q, k, v, mask))
    
    def _mha(self, q, k, v, mask):
        q, k, v = (proj(x) for x, proj in zip((q, k, v), self.proj_qkv))
        q, k, v = (re(x, "b l (h d) -> b h l d", h=self.h) for x in (q, k, v))
        attn = einsum("...ij,...kj->...ik", q, k)
        attn = mask + torch.einsum("...ij,...kj->...ik", q, k)
        attn = F.softmax(attn / q.shape[-1] ** (1/2), dim=-1)
        out = einsum("...ij,...jk->...ik", attn, v)
        out = re(out, "b h n d -> b n (h d)")
        out = self.proj_o(out)
        return out

    def _mha_infer(self, q, k, v, mask):
        with torch.no_grad():
            q, k, v = (proj(x) for x, proj in zip((q, k, v), self.proj_qkv))
            q, k, v = (re(x, "b l (h d_h) -> (b h) l d_h", h=self.h) for x in (q, k, v))
            torch.bmm(q, re(k, "bh l d_h -> bh d_h l"), out=self.attn)
            self.attn += mask
            self.attn /= q.shape[-1] ** (1/2)
            self.attn = F.softmax(self.attn, dim=-1)
            torch.bmm(self.attn, v, out=self.out)
            return self.proj_o(re(self.out, "(b h) l d_h -> b l (h d_h)", h=self.h))

# Compare

In [39]:
emb = torch.rand((1_000, 256, 512))
mask = torch.triu(torch.full((256, 256), -torch.inf), diagonal=1)
mha = MultiHeadAttention()
mha.eval();

#### _mha vs _mha_infer

In [12]:
%%timeit
_ = mha._mha(emb, emb, emb, mask=mask)

3.02 s ± 33.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
%%timeit
_ = mha._mha_infer(emb, emb, emb, mask=mask)

2.34 s ± 8.96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
%lprun -u 0.001 -f mha._mha mha(emb, emb, emb, mask=mask)

Timer unit: 0.001 s

Total time: 3.10783 s
File: /var/folders/5y/b092b3m96yb8nglxy9dzqbnr0000gn/T/ipykernel_97173/1043572812.py
Function: _mha at line 70

Line #      Hits         Time  Per Hit   % Time  Line Contents
    70                                               def _mha(self, q, k, v, mask):
    71         1        750.6    750.6     24.2          q, k, v = (proj(x) for x, proj in zip((q, k, v), self.proj_qkv))
    72         1          0.5      0.5      0.0          q, k, v = (re(x, "b l (h d) -> b h l d", h=self.h) for x in (q, k, v))
    73         1        406.1    406.1     13.1          attn = einsum("...ij,...kj->...ik", q, k)
    74         1        859.3    859.3     27.6          attn = mask + torch.einsum("...ij,...kj->...ik", q, k)
    75         1        621.7    621.7     20.0          attn = F.softmax(attn / q.shape[-1] ** (1/2), dim=-1)
    76         1        215.5    215.5      6.9          out = einsum("...ij,...jk->...ik", attn, v)
    77         1         

In [9]:
%lprun -u 0.001 -f mha._mha_infer mha._mha_infer(emb, emb, emb, mask=mask)

Timer unit: 0.001 s

Total time: 2.4406 s
File: /var/folders/5y/b092b3m96yb8nglxy9dzqbnr0000gn/T/ipykernel_97173/1043572812.py
Function: _mha_infer at line 81

Line #      Hits         Time  Per Hit   % Time  Line Contents
    81                                               def _mha_infer(self, q, k, v, mask):
    82         1          0.0      0.0      0.0          with torch.no_grad():
    83         1        730.9    730.9     29.9              q, k, v = (proj(x) for x, proj in zip((q, k, v), self.proj_qkv))
    84         1        236.3    236.3      9.7              q, k, v = (re(x, "b l (h d_h) -> (b h) l d_h", h=self.h) for x in (q, k, v))
    85         1        305.5    305.5     12.5              torch.bmm(q, re(k, "bh l d_h -> bh d_h l"), out=self.attn)
    86         1        149.5    149.5      6.1              self.attn += mask
    87         1        147.1    147.1      6.0              self.attn /= q.shape[-1] ** (1/2)
    88         1        385.8    385.8     15.8   

#### Pytorch [MultiheadAttention.forward](https://github.com/pytorch/pytorch/blob/bbe8d019f280478dc3b143f6988e3e5668499f28/torch/nn/modules/activation.py#L1010) local



In [50]:
th_emb = re(emb, 'b l d -> l b d')
bsz, tgt_len, embed_dim = emb.shape
num_heads, head_dim = 8, 64
in_proj_weight = nn.Parameter(torch.vstack([l.weight for l in mha.proj_qkv]))
out_proj_weight = mha.proj_o.weight
th_proj_o = Linear(512, 512, bias=False)
th_proj_o.weight = out_proj_weight

def th_mha():
    th_q, th_k, th_v = F._in_projection_packed(th_emb, th_emb, th_emb, in_proj_weight, None)
    th_q = th_q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    th_k = th_k.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    th_v = th_v.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    th_out, th_attn = F._scaled_dot_product_attention(th_q, th_k, th_v, mask, 0.0)
    th_out = th_out.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    th_out = th_proj_o(th_out)
    th_out = th_out.view(tgt_len, bsz, th_out.size(1))
    th_out = mha.norm(th_emb + th_out)
    return re(th_out, 'l b d -> b l d')

#### vs. Pytorch F.multi_head_attn

In [41]:
%%timeit
mha_out = mha(emb, mask=mask)

3.26 s ± 55.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [42]:
%%timeit
th_out = th_mha()

3.03 s ± 53.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [48]:
%lprun -u 0.001 -f mha.forward mha_out = mha(emb, mask=mask)

Timer unit: 0.001 s

Total time: 3.26755 s
File: /var/folders/5y/b092b3m96yb8nglxy9dzqbnr0000gn/T/ipykernel_97173/1043572812.py
Function: forward at line 62

Line #      Hits         Time  Per Hit   % Time  Line Contents
    62                                               def forward(self, q, k=None, v=None, mask=None):
    63                                                   """ q, k, v: (batch, seq_len, emb_dim) mask: (seq_len, seq_len) """
    64         1          0.0      0.0      0.0          if k is None and v is None:
    65         1          0.0      0.0      0.0              k = v = q
    66         1          0.0      0.0      0.0          if mask is None:
    67                                                       mask = torch.zeros((q.shape[-2]))
    68         1       3267.5   3267.5    100.0          return self.norm(q + self._mha(q, k, v, mask))

In [51]:
%lprun -u 0.001 -f th_mha th_out = th_mha()

Timer unit: 0.001 s

Total time: 3.18376 s
File: /var/folders/5y/b092b3m96yb8nglxy9dzqbnr0000gn/T/ipykernel_97173/486783080.py
Function: th_mha at line 9

Line #      Hits         Time  Per Hit   % Time  Line Contents
     9                                           def th_mha():
    10         1        754.5    754.5     23.7      th_q, th_k, th_v = F._in_projection_packed(th_emb, th_emb, th_emb, in_proj_weight, None)
    11         1         52.4     52.4      1.6      th_q = th_q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    12         1         53.5     53.5      1.7      th_k = th_k.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    13         1        149.7    149.7      4.7      th_v = th_v.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    14         1       1340.6   1340.6     42.1      th_out, th_attn = F._scaled_dot_product_attention(th_q, th_k, th_v, mask, 0.0)
    15         1        106.9    106.9  

In [55]:
torch.allclose(mha_out, th_out, atol=1e-6, equal_nan=True)

True