In [2]:
import torch
import torch.nn as nn
from einops import rearrange, repeat

## **Image Embeddings**

In [4]:
class Embedding(nn.Module):
    """Split images into non-overlapping patches and embed them as a token sequence.

    Each patch is flattened and linearly projected to ``embedding_dim``; a learnable
    CLS token is prepended and a learnable positional embedding is added, giving an
    output of shape ``[b, num_patches + 1, embedding_dim]``.

    Args:
        image_height: Height ``h`` of the input images in pixels.
        image_width: Width ``w`` of the input images in pixels.
        image_channels: Number of input channels ``c``.
        patch_size: Side length ``p`` of a square patch, or a
            ``(patch_height, patch_width)`` tuple for rectangular patches.
        embedding_dim: Size ``D`` of each output token embedding.

    Raises:
        ValueError: If the image height/width is not divisible by the patch height/width.
    """

    def __init__(self, image_height, image_width, image_channels, patch_size, embedding_dim):
        super().__init__()

        # e.g. image [3, 224, 224] --> 14 * 14 patches of [3, 16, 16] --> each projected to embedding_dim
        if isinstance(patch_size, (tuple, list)):
            self.patch_height, self.patch_width = patch_size
        else:
            self.patch_height, self.patch_width = patch_size, patch_size

        # forward() tiles the image exactly; fail here with a clear message instead of
        # an opaque einops shape error (and a mis-sized pos_embedding) later on.
        if image_height % self.patch_height or image_width % self.patch_width:
            raise ValueError(
                f"Image size ({image_height}, {image_width}) must be divisible by "
                f"patch size ({self.patch_height}, {self.patch_width})"
            )

        num_patches = (image_height // self.patch_height) * (image_width // self.patch_width)
        patch_dim = image_channels * self.patch_height * self.patch_width                   # c * ph * pw

        # Linear layer that converts each flattened patch vector into a token embedding.
        self.patch_to_embed = nn.Linear(patch_dim, embedding_dim)

        # Trainable parameters: one positional embedding per patch plus one for the CLS token,
        # and the CLS token itself (a single vector, expanded per batch in forward()).
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embedding_dim))
        self.cls_token = nn.Parameter(torch.randn(embedding_dim))

    def forward(self, X):
        """Embed a batch of images.

        Args:
            X: Image tensor of shape ``[b, c, h, w]`` whose ``h``/``w`` match the sizes
                given at construction (the positional embedding has a fixed length).

        Returns:
            Tensor of shape ``[b, num_patches + 1, embedding_dim]``; position 0 along
            dim 1 is the CLS token.
        """
        batch_size = X.shape[0]

        # [b, c, h, w] --> [b, num_patches, patch_dim]; nh and nw are inferred from X's shape.
        patches = rearrange(X, 'b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)',
                            ph=self.patch_height, pw=self.patch_width)
        tokens = self.patch_to_embed(patches)

        # Prepend one copy of the CLS token for every image in the batch.
        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b=batch_size)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        return tokens + self.pos_embedding

## **Attention Mechanism**

In [5]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Tokens of size ``D`` are projected to per-head queries, keys and values of size
    ``hd``. Attention is computed independently within each of the ``nh`` heads. The
    heads are concatenated back to ``nh * hd`` and projected back to ``D`` so the result
    can be added to the residual stream.

    Args:
        num_heads: Number of attention heads ``nh``.
        head_dim: Size ``hd`` of each head's query/key/value vectors.
        embedding_dim: Size ``D`` of the input and output token embeddings.
    """

    def __init__(self, num_heads, head_dim, embedding_dim):
        super().__init__()

        self.num_heads = num_heads                                      # nh
        self.head_dim = head_dim                                        # hd
        self.embedding_dim = embedding_dim                              # D
        self.attention_dim = self.num_heads * self.head_dim             # nh * hd

        # A single linear layer acting as Wq, Wk and Wv together:
        # [B, N, D] x [D, 3 * nh * hd] --> [B, N, 3 * nh * hd]
        self.wq_wk_wv_projection = nn.Linear(self.embedding_dim, self.attention_dim * 3, bias=False)

        # Maps the concatenated heads [B, N, nh * hd] back to [B, N, D].
        self.output_proj = nn.Linear(self.attention_dim, self.embedding_dim)

    def forward(self, X):
        """Apply self-attention across the token dimension.

        Args:
            X: Tensor of shape ``[B, N, D]``.

        Returns:
            Tensor of shape ``[B, N, D]``.
        """
        # [B, N, D] --> [B, N, 3 * nh * hd] --> 3 x [B, N, nh * hd]
        Q, K, V = self.wq_wk_wv_projection(X).split(self.attention_dim, dim=-1)

        # [B, N, nh * hd] --> [B, nh, N, hd]
        Q, K, V = (
            rearrange(t, 'b n (nh hd) -> b nh n hd', nh=self.num_heads, hd=self.head_dim)
            for t in (Q, K, V)
        )

        # softmax(Q K^T / sqrt(hd)) V, computed independently for every head
        att = torch.matmul(Q, K.transpose(-2, -1)) * (self.head_dim ** -0.5)
        att = torch.softmax(att, dim=-1)
        output = torch.matmul(att, V)

        # Concatenate the heads, then project back to the embedding size so the
        # caller's residual addition works even when nh * hd != D.
        output = rearrange(output, 'b nh n hd -> b n (nh hd)')
        return self.output_proj(output)

## **Transformer Layer**

In [8]:
class TransformerLayer(nn.Module):
    """One pre-norm Transformer encoder layer: self-attention and an MLP, each with a residual.

    Computes ``out = X + Attention(LN_att(X))`` followed by
    ``out = out + MLP(LN_ff(out))``.

    Args:
        embedding_dim: Size ``D`` of the token embeddings.
        num_heads: Number of attention heads.
        head_dim: Size of each attention head.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embedding_dim`` (default 4).
    """

    def __init__(self, embedding_dim, num_heads, head_dim, mlp_ratio=4):
        super().__init__()

        ff_embedding_dim = int(embedding_dim * mlp_ratio)

        # Each sub-block gets its own LayerNorm (and thus its own affine parameters).
        self.normalize = nn.LayerNorm(embedding_dim)            # applied before attention
        self.ff_normalize = nn.LayerNorm(embedding_dim)         # applied before the MLP
        self.attention_block = Attention(num_heads, head_dim, embedding_dim)
        self.feed_forward_block = nn.Sequential(
            nn.Linear(embedding_dim, ff_embedding_dim),
            nn.GELU(),                                          # the torch.nn class is GELU, not GeLU
            nn.Linear(ff_embedding_dim, embedding_dim),
        )

    def forward(self, X):
        """Apply the layer to a token sequence.

        Args:
            X: Tensor of shape ``[B, N, embedding_dim]``.

        Returns:
            Tensor of the same shape as ``X``.
        """
        out = X + self.attention_block(self.normalize(X))
        return out + self.feed_forward_block(self.ff_normalize(out))

## **Transformer**

In [9]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Patch-embeds the images, runs the token sequence through ``num_layers``
    Transformer layers, normalizes, and classifies from the CLS token.

    Args:
        num_layers: Number of stacked Transformer layers.
        num_classes: Number of output classes.
        image_height: Input image height in pixels.
        image_width: Input image width in pixels.
        image_channels: Number of input channels.
        patch_size: Patch side length (passed through to ``Embedding``).
        embedding_dim: Token embedding size ``D``.
        num_heads: Attention heads per layer (default 8).
        head_dim: Size of each attention head. Defaults to ``embedding_dim // num_heads``.

    Raises:
        ValueError: If ``head_dim`` is omitted and ``embedding_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_layers, num_classes, image_height, image_width, image_channels, patch_size, embedding_dim,
                 num_heads=8, head_dim=None):
        super().__init__()

        if head_dim is None:
            if embedding_dim % num_heads:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads}) "
                    f"when head_dim is not given"
                )
            head_dim = embedding_dim // num_heads

        # Embedding layer responsible for patchification, CLS token and positional embeddings
        self.patch_embedding_layer = Embedding(image_height, image_width, image_channels, patch_size, embedding_dim)

        # The stacked Transformer layers
        self.layers = nn.ModuleList(
            [TransformerLayer(embedding_dim, num_heads, head_dim) for _ in range(num_layers)]
        )

        # Final normalization applied to all tokens
        self.norm = nn.LayerNorm(embedding_dim)

        # Classification head applied to the CLS token; outputs unnormalized logits (no softmax).
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, X):
        """Classify a batch of images.

        Args:
            X: Image tensor of shape ``[b, c, h, w]``.

        Returns:
            Logits of shape ``[b, num_classes]``.
        """
        out = self.patch_embedding_layer(X)

        for layer in self.layers:
            out = layer(out)

        out = self.norm(out)

        # Position 0 holds the CLS token prepended by the embedding layer.
        return self.fc(out[:, 0])