In [1]:
import torch, torch.nn as nn, torch.nn.functional as F
from einops import rearrange

In [2]:
def sdp_attn(q, k, v, is_causal, n_embd, n_heads):
    """Run multi-head scaled dot-product attention on already-projected inputs.

    Args:
        q: queries laid out as (batch, t_q, n_embd).
        k: keys laid out as (batch, t_kv, n_embd).
        v: values laid out as (batch, t_kv, n_embd).
        is_causal: forwarded unchanged to F.scaled_dot_product_attention.
        n_embd: total embedding width across all heads.
        n_heads: number of heads; each one gets n_embd // n_heads channels.

    Returns:
        Attention output with the heads merged back, (batch, t_q, n_embd).
    """
    head_dim = n_embd // n_heads
    split_heads = 'b t (nh hc) -> b nh t hc'
    # Move the head axis in front of time so every head attends independently.
    q, k, v = (
        rearrange(t, split_heads, nh=n_heads, hc=head_dim) for t in (q, k, v)
    )
    mixed = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    return rearrange(mixed, 'b nh t hc -> b t (nh hc)')

In [3]:
class SelfAttn(nn.Module):
    """Multi-head self-attention with a fused, bias-free QKV projection."""

    def __init__(self, cfg, is_causal):
        """Build the projections.

        Args:
            cfg: mapping that provides 'n_embd' and 'n_heads'.
            is_causal: stored and handed to sdp_attn on every forward call.
        """
        super().__init__()
        width = cfg['n_embd']
        self.n_embd, self.n_heads = width, cfg['n_heads']
        self.is_causal = is_causal
        # A single matmul produces queries, keys and values side by side.
        self.QKV = nn.Linear(width, 3 * width, bias=False)
        self.out_proj = nn.Linear(width, width, bias=False)

    def forward(self, x):
        """Attend over x (last dimension n_embd) and return the same shape."""
        fused = self.QKV(x)
        q, k, v = fused.split(self.n_embd, dim=-1)
        mixed = sdp_attn(q, k, v, self.is_causal, self.n_embd, self.n_heads)
        return self.out_proj(mixed)


class CrossAttn(nn.Module):
    """Multi-head cross-attention.

    Queries are projected from ``x``; keys and values are projected from
    ``x_enc`` with a single fused, bias-free linear layer.
    """

    def __init__(self, cfg):
        """Build the projections.

        Args:
            cfg: mapping that provides 'n_embd' and 'n_heads'.
        """
        super().__init__()
        self.n_embd = cfg['n_embd']
        self.n_heads = cfg['n_heads']
        self.KV = nn.Linear(self.n_embd, 2*self.n_embd, bias=False)
        self.Q = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.out_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, x_enc):
        """Let every position of ``x`` attend over all positions of ``x_enc``.

        Args:
            x: query-side sequence, last dimension n_embd.
            x_enc: key/value-side sequence, last dimension n_embd; its length
                may differ from that of ``x``.

        Returns:
            Tensor shaped like ``x``.
        """
        k, v = torch.split(self.KV(x_enc), self.n_embd, dim=-1)
        q = self.Q(x)
        # No causal mask here: x_enc is a separate, fully known sequence, so
        # query i must be able to read all of it. A causal mask would restrict
        # query i to the first i+1 positions of x_enc.
        attn = sdp_attn(q, k, v, False, self.n_embd, self.n_heads)
        return self.out_proj(attn)

In [4]:
# Tiny configuration for shape smoke tests; `n_embd` and `cfg` are read again by later cells.
n_embd = 6
n_heads = 2
cfg = { 'n_embd': n_embd, 'n_heads': n_heads }
sa = SelfAttn(cfg, is_causal=True)
ca = CrossAttn(cfg)

In [5]:
# x is the query-side sequence (5 steps), y the key/value-side sequence (3 steps).
x = torch.randn(1, 5, n_embd)
y = torch.randn(1, 3, n_embd)
# Each module returns one n_embd-wide vector per query step.
sa(y).shape, ca(x, y).shape

(torch.Size([1, 3, 6]), torch.Size([1, 5, 6]))

In [6]:
class MLP(nn.Module):
    """Feed-forward block: widen 4x, ReLU, project back; no biases."""

    def __init__(self, cfg):
        """Build the two linear layers from cfg['n_embd']."""
        super().__init__()
        width = cfg['n_embd']
        hidden = 4 * width
        expand = nn.Linear(width, hidden, bias=False)
        contract = nn.Linear(hidden, width, bias=False)
        self.layers = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Apply the stack to the last dimension of x."""
        out = self.layers(x)
        return out

In [7]:
class EncoderBlock(nn.Module):
    """Pre-norm block: non-causal self-attention, then MLP, each as a residual branch."""

    def __init__(self, cfg):
        """Build the sub-layers; cfg is passed through to SelfAttn and MLP."""
        super().__init__()
        width = cfg['n_embd']
        self.ln0 = nn.LayerNorm(width)
        self.sa = SelfAttn(cfg, is_causal=False)
        self.ln1 = nn.LayerNorm(width)
        self.mlp = MLP(cfg)

    def forward(self, x):
        """Return x after both residual updates; the shape is unchanged."""
        attended = x + self.sa(self.ln0(x))
        return attended + self.mlp(self.ln1(attended))


class DecoderBlock(nn.Module):
    """Pre-norm decoder block.

    Three residual branches in order: causal self-attention over ``x``,
    cross-attention from ``x`` to ``x_enc``, and an MLP.
    """

    def __init__(self, cfg):
        """Build the sub-layers.

        Args:
            cfg: mapping with 'n_embd' and 'n_heads'; passed through to
                SelfAttn, CrossAttn and MLP.
        """
        super().__init__()
        n_embd = cfg['n_embd']
        self.ln0 = nn.LayerNorm(n_embd)
        # Causal: position t may only look at positions <= t. Without the
        # mask each output could read the tokens it is supposed to predict.
        self.sa = SelfAttn(cfg, is_causal=True)

        self.ln1 = nn.LayerNorm(n_embd)
        self.ca = CrossAttn(cfg)

        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(cfg)

    def forward(self, x, x_enc):
        """Apply the three residual branches.

        Args:
            x: decoder-side sequence, last dimension n_embd.
            x_enc: sequence the cross-attention reads from; it is passed to
                CrossAttn as-is (no LayerNorm is applied to it here).

        Returns:
            Tensor shaped like ``x``.
        """
        x = x + self.sa(self.ln0(x))
        x = x + self.ca(self.ln1(x), x_enc)
        x = x + self.mlp(self.ln2(x))
        return x

In [8]:
# Built from the small attention `cfg` defined earlier (n_embd=6, n_heads=2).
enc_blk = EncoderBlock(cfg)
dec_blk = DecoderBlock(cfg)

In [9]:
# Fresh random inputs: x has 5 steps (decoder side), y has 3 steps (encoder side).
x = torch.randn(1, 5, n_embd)
y = torch.randn(1, 3, n_embd)

In [10]:
# Both blocks preserve the shape of their first argument.
enc_blk(y).shape, dec_blk(x, y).shape

(torch.Size([1, 3, 6]), torch.Size([1, 5, 6]))

In [51]:
def patchify(x, patch_size, pad_val=None):
    """Cut an image batch into non-overlapping patches and flatten each one.

    Args:
        x: tensor of shape (B, C, H, W).
        patch_size: ``(pw, ph)`` -- patch width first, then patch height.
        pad_val: fill value used to pad H and W up to the next multiple of
            the patch size. The padding is split between both sides, with the
            odd pixel going to the right/bottom. ``None`` disables padding;
            H and W must then already be exact multiples.

    Returns:
        Tensor of shape (B, n_patches, ph * pw * C). Patches are listed
        row-major and each patch is flattened in (row, column, channel) order.

    Raises:
        AssertionError: if ``pad_val`` is None and H or W is not a multiple
            of the patch size.
    """
    _, _, H, W = x.shape
    pw, ph = patch_size
    # Pixels still missing to reach the next multiple of the patch size.
    # (The plain remainder H % ph is the overshoot past the previous multiple;
    # padding by it only lands on a multiple when it equals ph / 2.)
    dy, dx = (-H) % ph, (-W) % pw
    if pad_val is not None:
        pad_left, pad_top = dx // 2, dy // 2
        pad_right, pad_bot = dx - pad_left, dy - pad_top
        # F.pad orders the pads last dimension first: (left, right, top, bottom).
        x = F.pad(x, (pad_left, pad_right, pad_top, pad_bot), value=pad_val)
    else:
        assert dx == 0 and dy == 0, (
            f"image size {(H, W)} is not a multiple of patch (h, w) {(ph, pw)}; "
            "pass pad_val to pad"
        )

    H, W = x.shape[-2], x.shape[-1]
    # split into patches, then flatten them
    x = rearrange(x, 'b c (hs h) (ws w) -> b (hs ws) (h w c)', hs=H//ph, ws=W//pw)
    return x

In [54]:
B, C, H, W = 1, 1, 4, 6
# Ramp of consecutive integers so every 2x2 patch can be checked by eye in the printout.
x = torch.arange(B*C*H*W).view(B,C,H,W)
print(x)
# x is rebound to the patch tensor: one row per patch, row-major over the image.
x = patchify(x, (2, 2))
print(x.shape)
print(x)

tensor([[[[ 0,  1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10, 11],
          [12, 13, 14, 15, 16, 17],
          [18, 19, 20, 21, 22, 23]]]])
torch.Size([1, 6, 4])
tensor([[[ 0,  1,  6,  7],
         [ 2,  3,  8,  9],
         [ 4,  5, 10, 11],
         [12, 13, 18, 19],
         [14, 15, 20, 21],
         [16, 17, 22, 23]]])


In [55]:
class Encoder(nn.Module):
    """Patch-based transformer encoder.

    The input is cut into patches, each patch is linearly projected to
    n_embd, a learned position embedding is added, and the result runs
    through a stack of EncoderBlocks followed by a final LayerNorm.
    """

    def __init__(self, cfg):
        """Build the encoder.

        Args:
            cfg: mapping with 'channels', 'patch_size' (pw, ph), 'n_ctx',
                'n_embd', 'n_blocks' and optionally 'pad_value'; it is also
                passed to every EncoderBlock.
        """
        super().__init__()
        channels = cfg['channels']
        patch_w, patch_h = cfg['patch_size']
        n_ctx, width, depth = cfg['n_ctx'], cfg['n_embd'], cfg['n_blocks']

        self.patch_size = (patch_w, patch_h)
        # None means "do not pad" (see patchify).
        self.pad_value = cfg.get('pad_value', None)

        # One learned position vector per patch index, at most n_ctx of them.
        self.embd = nn.Embedding(n_ctx, width)
        self.project_patch = nn.Linear(patch_w * patch_h * channels, width, bias=False)
        stack = [EncoderBlock(cfg) for _ in range(depth)]
        self.blocks = nn.ModuleList(stack)
        self.ln_out = nn.LayerNorm(width)

    def forward(self, x):
        """Encode x and return one n_embd-wide vector per patch."""
        patches = patchify(x, self.patch_size, self.pad_value)
        tokens = self.project_patch(patches)

        _, n_tokens, _ = tokens.shape
        # Position indices are created on the device the parameters live on.
        where = next(self.parameters()).device
        hidden = tokens + self.embd(torch.arange(n_tokens, device=where))

        for blk in self.blocks:
            hidden = blk(hidden)
        return self.ln_out(hidden)

In [59]:
# Toy encoder: n_ctx is the size of the position-embedding table, so an input
# may produce at most 16 patches. No 'pad_value' key, so inputs are not padded.
enc_cfg = {
    'n_blocks': 2,
    'n_ctx': 16,
    'channels': 1,
    'patch_size': (2, 2),
    'n_embd': 6,
    'n_heads': 2
}
encoder = Encoder(enc_cfg)
# Total number of parameters.
sum(p.numel() for p in encoder.parameters())

1044

In [58]:
# A 2x4 image with 2x2 patches gives 2 tokens.
# NOTE(review): the execution counter shows this cell (In [58]) ran before the
# config cell above it (In [59]); re-run top to bottom to confirm the output.
x = torch.randn(1, enc_cfg['channels'], 2, 4)
encoder(x).shape

torch.Size([1, 2, 6])

In [60]:
class TiedLinear(nn.Module):
    """Bias-free linear map whose weight is supplied by the caller instead of created here."""

    def __init__(self, w):
        """Keep a reference to ``w`` (out_features x in_features); no copy is made."""
        super().__init__()
        # Stored by reference so the weight stays shared with whoever owns it.
        self.w = w

    def forward(self, x):
        """Return x projected with the stored weight (F.linear without bias)."""
        weight = self.w
        return F.linear(x, weight)
        

class Decoder(nn.Module):
    """Token decoder: token + position embeddings, a stack of DecoderBlocks,
    a final LayerNorm and an output head that reuses the token-embedding matrix.
    """

    def __init__(self, cfg):
        """Build the decoder.

        Args:
            cfg: mapping with 'n_ctx', 'n_vocab', 'n_blocks' and 'n_embd';
                it is also passed to every DecoderBlock.
        """
        super().__init__()
        n_ctx, n_vocab = cfg['n_ctx'], cfg['n_vocab']
        depth, width = cfg['n_blocks'], cfg['n_embd']

        self.embd_tok = nn.Embedding(n_vocab, width)
        self.embd_pos = nn.Embedding(n_ctx, width)

        stack = [DecoderBlock(cfg) for _ in range(depth)]
        self.blocks = nn.ModuleList(stack)
        self.ln_out = nn.LayerNorm(width)
        # Weight tying: the head projects with the same matrix that embeds tokens.
        self.head = TiedLinear(self.embd_tok.weight)

    def forward(self, x, x_enc):
        """Map token ids to per-position scores over the vocabulary.

        Args:
            x: integer token ids of shape (B, T).
            x_enc: sequence handed to every block's cross-attention.

        Returns:
            Tensor of shape (B, T, n_vocab).
        """
        _, n_tokens = x.shape
        # Position indices are created on the device the parameters live on.
        where = next(self.parameters()).device
        positions = torch.arange(n_tokens, device=where)

        hidden = self.embd_tok(x) + self.embd_pos(positions)
        for blk in self.blocks:
            hidden = blk(hidden, x_enc)

        return self.head(self.ln_out(hidden))

In [61]:
# Rebinds `cfg` (previously the small attention config) to toy decoder settings.
cfg = {
    'n_ctx': 16,
    'n_vocab': 7,
    'n_blocks': 2,
    'n_embd': 6,
    'n_heads': 2,
}
decoder = Decoder(cfg)
# parameters() yields the shared embedding/head weight only once.
sum(p.numel() for p in decoder.parameters())

1374

In [62]:
# x: 5 random token ids; y: a random stand-in for the encoder output (3 steps, n_embd wide).
x = torch.randint(cfg['n_vocab'], size=(1,5))
y = torch.randn(1, 3, cfg['n_embd'])
# One score per vocabulary entry at each of the 5 positions.
decoder(x, y).shape

torch.Size([1, 5, 7])

In [63]:
class DoubleConv(nn.Module):
    """Two 3x3 conv -> BatchNorm -> ReLU units; the second conv uses stride 2."""

    def __init__(self, ch_in, ch_out, ch_mid=None):
        """Build the stack.

        Args:
            ch_in: input channels.
            ch_out: output channels.
            ch_mid: channels between the two convolutions; defaults to ch_out.
        """
        super().__init__()
        ch_mid = ch_out if ch_mid is None else ch_mid

        def unit(c_from, c_to, stride):
            # Bias is omitted because BatchNorm follows immediately.
            conv = nn.Conv2d(c_from, c_to, kernel_size=3, stride=stride,
                             padding=1, bias=False)
            return [conv, nn.BatchNorm2d(c_to), nn.ReLU()]

        self.layers = nn.Sequential(
            *unit(ch_in, ch_mid, 1),
            *unit(ch_mid, ch_out, 2),
        )

    def forward(self, x):
        """Run x through both units."""
        out = self.layers(x)
        return out


class CNN(nn.Module):
    """Convolutional backbone: BatchNorm on the raw input, a stem conv, a chain
    of DoubleConv stages that double the channel count, and a 3x3 output conv.
    """

    def __init__(self, cfg):
        """Build the backbone.

        Args:
            cfg: mapping with 'ch_in', 'init_filters', 'ch_out' and 'n_layers'.
        """
        super().__init__()
        ch_in = cfg['ch_in']
        base, ch_out = cfg['init_filters'], cfg['ch_out']
        depth = cfg['n_layers']

        # Channel count after the stem and after each DoubleConv stage.
        widths = [2**k * base for k in range(depth + 1)]

        # BatchNorm on the input replaces manual input normalization.
        self.bn_in = nn.BatchNorm2d(ch_in)
        self.conv_in = nn.Conv2d(ch_in, base, 3, 1, 1, bias=False)
        self.bn0 = nn.BatchNorm2d(base)

        stages = [DoubleConv(a, b) for a, b in zip(widths, widths[1:])]
        self.conv_layers = nn.ModuleList(stages)

        # The only convolution with a bias: no BatchNorm follows it.
        self.conv_out = nn.Conv2d(widths[-1], ch_out, 3, 1, 1)

    def forward(self, x):
        """Return the ch_out-channel feature map for x."""
        feats = self.conv_in(self.bn_in(x))
        feats = F.relu(self.bn0(feats))
        for stage in self.conv_layers:
            feats = stage(feats)
        return self.conv_out(feats)

In [65]:
# Toy backbone: 3 DoubleConv stages starting from 4 filters (4 -> 8 -> 16 -> 32 channels).
cnn_cfg = {
    'ch_in': 1,
    'n_layers': 3,
    'init_filters': 4,
    'ch_out': 16
}
cnn = CNN(cnn_cfg)
# Total number of parameters.
sum(p.numel() for p in cnn.parameters())

23038

In [66]:
# 32x64 input; each of the 3 stages contains one stride-2 conv, giving a 4x8 map.
x = torch.randn(1, 1, 32, 64)
# x is rebound to the feature map; the next cell patchifies it.
x = cnn(x)
x.shape

torch.Size([1, 16, 4, 8])

In [67]:
# x is the CNN feature map from the previous cell: 2x4 patches of 2*2*16 values each.
patchify(x, (2, 2)).shape

torch.Size([1, 8, 64])

In [70]:
class OCRNet(nn.Module):
    """Full model: CNN backbone -> patch Encoder -> token Decoder."""

    def __init__(self, cfg):
        """Build the three stages from cfg['backbone'], cfg['encoder'] and cfg['decoder']."""
        super().__init__()
        backbone_cfg = cfg['backbone']
        self.backbone = CNN(backbone_cfg)
        encoder_cfg = cfg['encoder']
        self.encoder = Encoder(encoder_cfg)
        decoder_cfg = cfg['decoder']
        self.decoder = Decoder(decoder_cfg)

    def forward(self, x_im, x_tok):
        """Decode x_tok conditioned on the encoded image x_im.

        Args:
            x_im: image batch fed to the backbone.
            x_tok: token ids fed to the decoder.

        Returns:
            Whatever the decoder returns for (x_tok, encoded image).
        """
        encoded = self.encoder(self.backbone(x_im))
        return self.decoder(x_tok, encoded)

In [72]:
# Rebinds `cfg` once more, now to the nested config of the full model.
# Values that have to agree across sections:
#   - backbone 'ch_out' == encoder 'channels' (the encoder patchifies the CNN output)
#   - encoder 'n_embd' == decoder 'n_embd' (cross-attention projects the encoder
#     output with decoder-width linear layers)
cfg = {
    'backbone': {
        'ch_in': 1,
        'n_layers': 3,
        'init_filters': 8,
        'ch_out': 16
    },
    'encoder': {
        'n_blocks': 4,
        'n_ctx': 16,
        'channels': 16,
        'patch_size': (2, 2),
        'n_embd': 64,
        'n_heads': 4
    },
    'decoder': {
        'n_ctx': 16,
        'n_vocab': 28,
        'n_blocks': 4,
        'n_embd': 64,
        'n_heads': 4,
    }
}
net = OCRNet(cfg)
# Total number of parameters.
sum(p.numel() for p in net.parameters())

551850

In [74]:
# 32x128 image: three stride-2 stages give a 4x16 map, i.e. 2x8 = 16 patches of 2x2,
# which exactly fills the encoder's n_ctx of 16.
x_im = torch.randn(1, 1, 32, 128)
# 3 random token ids drawn from the 28-entry vocabulary.
x_tok = torch.randint(28, (1, 3,))

In [75]:
# Full forward pass; x is rebound to the decoder output.
x = net(x_im, x_tok)

In [76]:
# One score per vocabulary entry (28) for each of the 3 input tokens.
x.shape

torch.Size([1, 3, 28])