In [1]:
import os
import random
import copy
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import (
    Module,
    ModuleList,
    Sequential,
    Parameter,
    Linear, 
    Dropout,
    LayerNorm,
    Softmax,
)
import numpy as np
import pandas as pd
from einops import rearrange as re
from opt_einsum import contract as einsum
%load_ext line_profiler

In [2]:
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)
random.seed(0)
np.random.seed(0)

In [3]:
torch.set_printoptions(precision=4)

In [4]:
def heat(x):
    df = pd.DataFrame(x.detach().numpy())
    return df.style.background_gradient(cmap='Blues')  # .format('{:.0f}')

[Attention is All You Need](https://arxiv.org/abs/1706.03762?context=cs)  
[Formal Algorithms for Transformers](https://arxiv.org/abs/2207.09238)  
[Transformer Language Model Mathematical Definition](https://www.apronus.com/math/transformer-language-model-definition)  
[AI Explained - 3D viz of transformer structure](https://youtu.be/-9vVhYEXeyQ?t=456)  
[Bloem Transformer Implementation source](https://github.com/pbloem/former/blob/master/former/modules.py)  
[The Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/#encoder-and-decoder-stacks)  
[Pytorch multi_head_attention_forward source](https://github.com/pytorch/pytorch/blob/dcf51885618e7d1d9aa6e628f3354f67ad82b446/torch/nn/functional.py#L4917)   
[Writing a better code with pytorch and einops](http://einops.rocks/pytorch-examples.html)  
[R-Drop: Regularized Dropout for Neural Networks](https://arxiv.org/pdf/2106.14448v2.pdf)

In [37]:
class Transformer(nn.Module):
    def __init__(self, d_model=512, n_blocks=6, vocab=30_000):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_model)
        self.encoder = Stack(EncoderBlock, n_blocks)
        self.decoder = Stack(DecoderBlock, n_blocks)
        self.head = Sequential(Linear(d_model, vocab), Softmax(dim=-1))
       
    def forward(self, src, tgt, src_mask, tgt_mask):
        ctx = self.encoder(src, mask=src_mask)
        return self.head(self.decoder(tgt, ctx=ctx, mask=tgt_mask, ctx_mask=tgt_mask))

In [38]:
class Stack(nn.Sequential):
    def __init__(self, Layer, n):
        stack = [copy.deepcopy(Layer()) for l in range(n)]
        super().__init__(*stack)
        
    def forward(self, x, *args, **kwargs):
        for module in self:
            x = module(x, *args, **kwargs)
        return x

In [39]:
class EncoderBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = MultiHeadAttention()
        self.ff = FeedForward()
        
    def forward(self, x, *, mask=None):
        return self.ff(self.mha(x, mask=mask))

In [40]:
class DecoderBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.masked_mha = MultiHeadAttention()
        self.mha = MultiHeadAttention()
        self.ff = FeedForward()
        
    def forward(self, x, *, ctx=None, mask=None, ctx_mask=None):
        x = self.masked_mha(x, mask=mask)
        return self.ff(self.mha(x, ctx, ctx, ctx_mask))

In [41]:
class FeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self._ff(x))
    
    def _ff(self, x):
        return self.fc2(F.relu(self.fc1(x)))

In [42]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads: int = 8, emb_dim: int = 512):
        super().__init__()
        assert emb_dim % n_heads == 0  # d_h = d_q = d_k = d_v = emb_dim // n_heads
        self.proj_qkv = [Linear(emb_dim, emb_dim, bias=False) for _ in range(3)]
        self.proj_o = Linear(emb_dim, emb_dim, bias=False)
        self.norm = nn.LayerNorm(emb_dim)
        self.h, self.d = n_heads, emb_dim

    def forward(self, q, k=None, v=None, mask=None):
        """ q, k, v: (batch, seq_len, emb_dim) mask: (seq_len, seq_len) """
        if k is None and v is None:
            k = v = q
        if mask is None:
            mask = torch.zeros((q.shape[-2]))
        return self.norm(q + self._mha(q, k, v, mask))
    
    def _mha(self, q, k, v, mask):
        q, k, v = (proj(x) for x, proj in zip((q, k, v), self.proj_qkv))
        q, k, v = (re(x, "b l (h d) -> b h l d", h=self.h) for x in (q, k, v))
        attn = einsum("...ij,...kj->...ik", q, k)
        attn = mask + torch.einsum("...ij,...kj->...ik", q, k)
        attn = F.softmax(attn / q.shape[-1] ** (1/2), dim=-1)
        out = einsum("...ij,...jk->...ik", attn, v)
        out = re(out, "b h n d -> b n (h d)")
        out = self.proj_o(out)
        return out

In [11]:
class MHAInference(MultiHeadAttention):
    def __init__(self, n_heads: int = 8, emb_dim: int = 512):
        super().__init__(n_heads, emb_dim)
        self.attn = torch.empty((1000 * 8, 256, 256))
        self.out = torch.empty((1000 * 8, 256, 64))

    def _mha(self, q, k, v, mask):
        with torch.no_grad():
            q, k, v = (proj(x) for x, proj in zip((q, k, v), self.proj_qkv))
            q, k, v = (re(x, "b l (h d_h) -> (b h) l d_h", h=self.h) for x in (q, k, v))
            torch.bmm(q, re(k, "bh l d_h -> bh d_h l"), out=self.attn)
            self.attn += mask
            self.attn /= q.shape[-1] ** (1/2)
            self.attn = F.softmax(self.attn, dim=-1)
            torch.bmm(self.attn, v, out=self.out)
            return self.proj_o(re(self.out, "(b h) l d_h -> b l (h d_h)", h=self.h))

# Layer Checks

In [35]:
b, l, d = 1, 256, 512
sh = (b, l, d)
v = 30_000

In [23]:
emb = torch.rand(sh)
mask = torch.triu(torch.full((l, l), -torch.inf), diagonal=1)

In [24]:
%time assert MultiHeadAttention()(emb).shape == sh

CPU times: user 54 ms, sys: 23.5 ms, total: 77.5 ms
Wall time: 21.6 ms


In [25]:
%time assert FeedForward()(emb).shape == sh

CPU times: user 67.8 ms, sys: 7.96 ms, total: 75.8 ms
Wall time: 20.4 ms


In [26]:
%time assert EncoderBlock()(emb).shape == sh

CPU times: user 123 ms, sys: 30.3 ms, total: 154 ms
Wall time: 36.3 ms


In [27]:
%time assert DecoderBlock()(emb).shape == sh

CPU times: user 172 ms, sys: 32.4 ms, total: 204 ms
Wall time: 51.6 ms


In [28]:
%time assert MultiHeadAttention()(emb).shape == sh

CPU times: user 58.5 ms, sys: 9.98 ms, total: 68.5 ms
Wall time: 14.4 ms


In [36]:
%time assert Transformer()(emb, emb, mask, mask).shape == (b, l, v)

CPU times: user 3.49 s, sys: 229 ms, total: 3.72 s
Wall time: 486 ms


# Compare

In [8]:
emb = torch.rand((1_000, 256, 512))
mask = torch.triu(torch.full((256, 256), -torch.inf), diagonal=1)
mha = MultiHeadAttention()
mha.eval();

#### _mha vs _mha_infer

In [9]:
%%timeit
_ = mha._mha(emb, emb, emb, mask=mask)

2.83 s ± 69.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
mha_infer = MHAInference()
mha_infer.eval();

In [16]:
%%timeit
_ = mha_infer(emb, emb, emb, mask=mask)

2.32 s ± 27.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [17]:
%lprun -u 0.001 -f mha._mha mha(emb, emb, emb, mask=mask)

Timer unit: 0.001 s

Total time: 2.88578 s
File: /var/folders/5y/b092b3m96yb8nglxy9dzqbnr0000gn/T/ipykernel_51150/1451311884.py
Function: _mha at line 68

Line #      Hits         Time  Per Hit   % Time  Line Contents
    68                                               def _mha(self, q, k, v, mask):
    69         1        633.7    633.7     22.0          q, k, v = (proj(x) for x, proj in zip((q, k, v), self.proj_qkv))
    70         1          0.1      0.1      0.0          q, k, v = (re(x, "b l (h d) -> b h l d", h=self.h) for x in (q, k, v))
    71         1        355.6    355.6     12.3          attn = einsum("...ij,...kj->...ik", q, k)
    72         1        840.5    840.5     29.1          attn = mask + torch.einsum("...ij,...kj->...ik", q, k)
    73         1        633.2    633.2     21.9          attn = F.softmax(attn / q.shape[-1] ** (1/2), dim=-1)
    74         1        200.7    200.7      7.0          out = einsum("...ij,...jk->...ik", attn, v)
    75         1         

In [18]:
%lprun -u 0.001 -f mha_infer._mha mha_infer(emb, emb, emb, mask=mask)

Timer unit: 0.001 s

Total time: 0 s
File: /var/folders/5y/b092b3m96yb8nglxy9dzqbnr0000gn/T/ipykernel_51150/1451311884.py
Function: _mha at line 68

Line #      Hits         Time  Per Hit   % Time  Line Contents
    68                                               def _mha(self, q, k, v, mask):
    69                                                   q, k, v = (proj(x) for x, proj in zip((q, k, v), self.proj_qkv))
    70                                                   q, k, v = (re(x, "b l (h d) -> b h l d", h=self.h) for x in (q, k, v))
    71                                                   attn = einsum("...ij,...kj->...ik", q, k)
    72                                                   attn = mask + torch.einsum("...ij,...kj->...ik", q, k)
    73                                                   attn = F.softmax(attn / q.shape[-1] ** (1/2), dim=-1)
    74                                                   out = einsum("...ij,...jk->...ik", attn, v)
    75                         

#### Pytorch [MultiheadAttention.forward](https://github.com/pytorch/pytorch/blob/bbe8d019f280478dc3b143f6988e3e5668499f28/torch/nn/modules/activation.py#L1010) local



In [19]:
th_emb = re(emb, 'b l d -> l b d')
bsz, tgt_len, embed_dim = emb.shape
num_heads, head_dim = 8, 64
in_proj_weight = nn.Parameter(torch.vstack([l.weight for l in mha.proj_qkv]))
out_proj_weight = mha.proj_o.weight
th_proj_o = Linear(512, 512, bias=False)
th_proj_o.weight = out_proj_weight

def th_mha():
    th_q, th_k, th_v = F._in_projection_packed(th_emb, th_emb, th_emb, in_proj_weight, None)
    th_q = th_q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    th_k = th_k.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    th_v = th_v.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    th_out, th_attn = F._scaled_dot_product_attention(th_q, th_k, th_v, mask, 0.0)
    th_out = th_out.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    th_out = th_proj_o(th_out)
    th_out = th_out.view(tgt_len, bsz, th_out.size(1))
    th_out = mha.norm(th_emb + th_out)
    return re(th_out, 'l b d -> b l d')

#### vs. Pytorch F.multi_head_attn

In [20]:
%%timeit
mha_out = mha(emb, mask=mask)

3.13 s ± 92.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
%%timeit
th_out = th_mha()

2.67 s ± 43.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%lprun -u 0.001 -f th_mha th_out = th_mha()

Timer unit: 0.001 s

Total time: 2.59382 s
File: /var/folders/5y/b092b3m96yb8nglxy9dzqbnr0000gn/T/ipykernel_51150/486783080.py
Function: th_mha at line 9

Line #      Hits         Time  Per Hit   % Time  Line Contents
     9                                           def th_mha():
    10         1        821.6    821.6     31.7      th_q, th_k, th_v = F._in_projection_packed(th_emb, th_emb, th_emb, in_proj_weight, None)
    11         1         52.8     52.8      2.0      th_q = th_q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    12         1         55.9     55.9      2.2      th_k = th_k.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    13         1        126.6    126.6      4.9      th_v = th_v.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    14         1        989.2    989.2     38.1      th_out, th_attn = F._scaled_dot_product_attention(th_q, th_k, th_v, mask, 0.0)
    15         1         87.4     87.4  

In [24]:
mha_out = mha(emb, mask=mask)

In [25]:
torch.allclose(mha_out, th_out, atol=1e-6, equal_nan=True)

True