## Dependencies

In [194]:
import torch
import torch.nn as nn
# import torch.nn.functional as F
import pytorch_lightning
import wandb

## Model

![Vision Transformer](vit.png)

In [195]:
class Encoder(nn.Module):
    """Stack of Transformer encoder blocks followed by a final LayerNorm.

    Thin wrapper around ``nn.TransformerEncoder``. The encoder layers are
    built with PyTorch's default sequence-first layout, so ``forward``
    expects tokens shaped ``(sequence, batch, d_model)``.

    Args:
        embedding_dim: Feature width of the tokens passed to ``forward``.
        d_model: Width of the encoder layers. The layers consume the tokens
            directly (there is no input projection), so this must equal
            ``embedding_dim``.
        nhead: Number of attention heads in every layer.
        num_layers: Number of stacked encoder layers.

    Raises:
        ValueError: If ``embedding_dim`` and ``d_model`` differ.
    """

    def __init__(self, embedding_dim, d_model=256, nhead=4, num_layers=2):
        super().__init__()

        # Fail at construction time instead of with an opaque shape error on
        # the first forward pass.
        if embedding_dim != d_model:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must equal d_model ({d_model}): "
                "the encoder layers operate directly on the input tokens"
            )

        # One encoder layer is a single encoder block (right-hand side of the
        # ViT figure above).
        self.encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_layers,
            # The final norm sees the layers' output, whose width is d_model.
            norm=nn.LayerNorm(normalized_shape=d_model),
        )

    def forward(self, x):
        """Encode ``x`` and return a tensor of the same shape."""
        return self.encoder(x)


In [247]:
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: the image is cut into square patches, each patch is flattened
    and linearly projected to ``embedding_dim``, a learnable class token is
    prepended, learnable positional embeddings are added, and the token
    sequence is run through a Transformer encoder. The encoded class token is
    then passed to a small MLP head that produces the class logits.

    Args:
        image_size: Height and width of the (square) input images.
        channels: Number of image channels.
        patch_size: Height and width of each square patch.
        stride: Step between neighbouring patches. Equal to ``patch_size``
            for non-overlapping patches; smaller values give overlapping ones.
        embedding_dim: Width of the token embeddings / Transformer model.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of encoder layers.
        num_classes: Number of output classes.
        fc_dim: Hidden width of the MLP classification head.

    Raises:
        ValueError: If ``patch_size`` or ``stride`` is not positive, or if
            ``patch_size`` exceeds ``image_size``.
    """

    def __init__(self, image_size, channels, patch_size, stride, embedding_dim, nhead, num_layers, num_classes, fc_dim=256):
        super().__init__()

        if patch_size <= 0 or stride <= 0:
            raise ValueError(
                f"patch_size and stride must be positive, got patch_size={patch_size}, stride={stride}"
            )
        if patch_size > image_size:
            raise ValueError(
                f"patch_size ({patch_size}) cannot exceed image_size ({image_size})"
            )

        self.patch_size = patch_size
        self.stride = stride
        # Number of sliding-window positions along one side. For
        # stride == patch_size this equals image_size // patch_size.
        patches_per_side = (image_size - patch_size) // stride + 1
        self.num_patches = patches_per_side ** 2
        self.patch_dim = channels * patch_size ** 2

        # Positional embedding (one slot per patch plus one for the class
        # token), patch projection layer and the learnable class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embedding_dim))
        self.patch_projection = nn.Linear(in_features=self.patch_dim, out_features=embedding_dim, bias=False)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        # Transformer encoder over the token sequence. batch_first=True is
        # required because forward() lays tokens out as (batch, tokens, dim);
        # with the default sequence-first layout attention would run across
        # the batch axis instead of across the patches.
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=embedding_dim, nhead=nhead, batch_first=True
            ),
            num_layers=num_layers,
            norm=nn.LayerNorm(normalized_shape=embedding_dim)
        )

        # Placeholder applied to the extracted class token.
        self.to_cls_token = nn.Identity()
        # MLP head that maps the encoded class token to class logits.
        self.fc = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=fc_dim),
            nn.ReLU(),
            nn.Linear(in_features=fc_dim, out_features=num_classes)
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        batch_size = x.shape[0]

        # Split into patches using the configured stride, then project each
        # flattened patch to the embedding width.
        x = self.patchify(x, self.patch_size, self.stride)
        x = self.patch_projection(x)

        # Prepend the class token to every sequence in the batch.
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        # Add the positional embedding (broadcast over the batch).
        x = x + self.pos_embedding

        # Encode, keep only the class token, and feed it to the MLP head.
        x = self.transformer(x)
        x = self.to_cls_token(x[:, 0])
        return self.fc(x)

    def patchify(self, images, patch_size, stride):
        """Cut images into flattened square patches.

        Args:
            images: Tensor of shape ``(batch, channels, height, width)``.
            patch_size: Side length of each square window.
            stride: Step between consecutive windows along both axes.

        Returns:
            Tensor of shape ``(batch, rows * cols, patch_size * patch_size * channels)``
            where ``rows`` and ``cols`` are the window counts along the
            height and width axes.
        """
        # Sliding windows over height (dim 2) and width (dim 3); each unfold
        # appends the window axis at the end:
        # (batch, channels, rows, cols, patch_h, patch_w)
        patches = images.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
        # Move channels last: (batch, rows, cols, patch_h, patch_w, channels)
        patches = patches.permute(0, 2, 3, 4, 5, 1)
        bs, rows, cols, h, w, ch = patches.shape

        # Merge the patch grid into one sequence axis and flatten each patch.
        # reshape copies when needed, as the permuted tensor is not contiguous.
        return patches.reshape(bs, rows * cols, h * w * ch)
        
    
        
        

In [248]:
# Small ViT used for the shape check in the following cells.
vit = VisionTransformer(
    # Input geometry: 256x256 RGB images cut into 16x16 patches.
    image_size=256,
    patch_size=16,
    stride=16,
    channels=3,
    # Transformer encoder size.
    embedding_dim=512,
    num_layers=2,
    nhead=8,
    # Classification head.
    num_classes=100,
)

In [249]:
# Random batch of 10 three-channel 256x256 images (channels-first) pushed
# through the model once to check the output shape.
x = torch.rand((10, 3, 256, 256))
outputs = vit(x)

In [250]:
# Display the shape of the forward-pass result from the previous cell.
outputs.shape

torch.Size([10, 100])