In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [2]:
class PositionalEmbedding1D(nn.Module):
    """Adds (optionally learned) positional embeddings to the inputs."""

    def __init__(self, seq_len, dim):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim))
    
    def forward(self, x):
        """Input has shape `(batch_size, seq_len, emb_dim)`"""
        # (1, 3, 14, 14) + (1, 3, 196) ? wtf
        # ohhh flatten before embedding kek
        return x + self.pos_embedding

In [3]:
class MLP(nn.Module):
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.l1 = nn.Linear(dim, ff_dim)
        self.l2 = nn.Linear(ff_dim, dim)
        
    def forward(self, x):
        out = self.l2(F.gelu(self.l1(x)))
        return out

In [4]:
class MHSA(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.project = nn.Linear(dim, dim)
        
    def forward(self, x): # input shape [b, s, d]
        # split into q, k, v : (query, key, values)
        # expand by h (num heads... hence multi attention)
        
        q, k, v = self.project(x), self.project(x), self.project(x)
        # convert q, k, v -> [b, h, s, w] where h = n_heads
        q = q.view(q.shape[0], self.n_heads, q.shape[1], -1)
        k = k.view(k.shape[0], self.n_heads, k.shape[1], -1)
        v = v.view(v.shape[0], self.n_heads, v.shape[1], -1)
        
        # scaled dot product attention on q, k (queries, keys) then matmul with values
        # matmul + scale
        # [b, h, s, w] @ [b, h, w, s] -> [b, h, s, s]
        k = k.transpose(-2, -1) # swap last two dimensions
        p = torch.matmul(q, k)
        p = p / np.sqrt(k.size(-1)) # where s is the dimension of k

        p = F.softmax(p, dim = -1) # softmax across last dimension
        
        out = torch.matmul(p, v) # [b, h, s, s] @ [b, h, s, w] -> [b, h, s, w]
        out = out.view(out.shape[0], out.shape[2], -1) # [b, s, d]
        return out

In [5]:
# dont need linear and dont need dropout i think
# add after if performance is shit

class Block(nn.Module): # inputs are B, S, D
    def __init__(self, dim, n_heads, ff_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.MHSA = MHSA(dim, n_heads)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.MLP = MLP(dim, ff_dim)
        
    def forward(self, x):
        h = self.MHSA(self.norm1(x))
        x = x + h
        h = self.MLP(self.norm2(x))
        x = x + h
        return x

In [6]:
class Transformer(nn.Module):
    def __init__(self, n_layers, dim, n_heads, ff_dim):
        super().__init__()
        self.block = nn.ModuleList([Block(dim, n_heads, ff_dim) for _ in range(n_layers)])
    
    def forward(self, x):
        for block in self.block:
            x = block(x)
        return x

In [7]:
class ViT(nn.Module):
    def __init__(self, in_channels, dim, fh, fw, n_layers, n_heads, ff_dim, num_classes):
        super().__init__()
        self.fw = fw
        self.fh = fh
        self.patch_encoding = nn.Conv2d(in_channels, dim, kernel_size=(self.fh, self.fw), stride=(self.fh, self.fw))
        # [B, D, FH, FW]
        '''
        with image 1, 3, 224, 224  and patch encoding of 16x16 we have
        ((224 - 16) / 16) + 1 = 14
        seq len is 14 * 14 bro im trippin bullets
        out = (1, 3, 14, 14)
        '''
        self.positional_embedding = PositionalEmbedding1D(14 * 14, dim) # inputs are seq len, dim
        # [B, D, FH, FW]
        
        # flatten into [B, S, D]
        self.Transformer = Transformer(n_layers, dim, n_heads, ff_dim)
        
        # if this doenst work then do it manually
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.mlp_head = nn.Linear(dim, num_classes)
        
    def forward(self, x):
        x = self.patch_encoding(x)
        x = x.view(x.shape[0], -1, x.shape[1]) # b, s, d
        x = self.positional_embedding(x) 
        x = self.Transformer(x)
        x = self.norm(x)
        x = x[:, -1, :] # b, s, d -> b, d
        x = self.mlp_head(x)
        return x

In [8]:
#hyperparameters
fw = fh = 16
dim = 768
ff_dim = 3072
n_heads = 12
n_layers = 12
in_channels = 3
num_classes = 10


In [9]:
model = ViT(in_channels, dim, fh, fw, n_layers, n_heads, ff_dim, num_classes)

In [10]:
# test input of (1, 3, 224, 224) like in research paper
x = torch.rand(1, 3, 224, 224)

In [11]:
out = model(x)

In [14]:
out

tensor([[ 0.1031,  0.1460, -0.1724,  1.0007, -0.5453,  0.1851,  1.3547, -0.3695,
         -1.0684,  0.5839]], grad_fn=<AddmmBackward0>)

In [16]:
torch.argmax(out)

tensor(6)