# Vision Transformer Architecture

## Preprocessing
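1. Patch making: split each image into non-overlapping square patches
2. Patch embedding: flatten each patch and map it to a D-dimensional token (a small convolution followed by a linear layer + ReLU)
3. Prepending a class token and adding positional embeddings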

In [159]:
# Model hyper-parameters:
#   D - token embedding width (used by the patch projection, encoder blocks and head)
#   L - presumably the encoder depth (4 encoder blocks are stacked further down)
#   C - number of output classes (width of the classification head)
D, L, C = 64, 4, 10

In [160]:
import torch
import numpy as np
import torch.nn as nn

In [161]:
# Toy dataset: 4 images, 3 channels, 10x10 pixels -> (B, C, H, W).
data = torch.randn((4, 3, 10, 10))
batch_size, num_channels, height, width = data.shape
patch_size = 2  # side length P of each square patch
if height % patch_size or width % patch_size:
    # unfold would silently drop the leftover rows/columns otherwise.
    raise ValueError(
        f"image size {height}x{width} is not divisible by patch size {patch_size}"
    )

# Cut every image into non-overlapping P x P patches in raster order.
# A plain reshape of (H, W) would slice the row-major pixel stream into
# 1x4 strips (some wrapping across rows) instead of square patches.
#   unfold over H -> (B, C, H/P, W, P)
#   unfold over W -> (B, C, H/P, W/P, P, P)
# where element [..., i, j, r, s] == data[..., i*P + r, j*P + s].
patches = (
    data.unfold(2, patch_size, patch_size)
    .unfold(3, patch_size, patch_size)
    .reshape(batch_size, num_channels, -1, patch_size, patch_size)
)  # (B, C, N, P, P): here 25 patches each of 2x2

# Flatten every patch into its P*P pixel values -> (B, C, N, P*P).
pre_embed = patches.flatten(3)

In [162]:
# Per-patch feature extraction: a 3x3 convolution over the
# (num_patches, patch_pixels) grid of `pre_embed`, lifting the image
# channels to `conv_channels` feature maps (padding=1 keeps the grid size).
# NOTE(review): the kernel slides along the flattened patch index and the
# flattened pixel index, so it mixes patches that are adjacent in that index
# (not necessarily spatial neighbours) -- confirm this is intended.
conv_channels = 8
detailed_preembed = nn.Conv2d(
    in_channels=pre_embed.shape[1],
    out_channels=conv_channels,
    kernel_size=3,
    stride=1,
    padding=1,
)

# Each patch token concatenates `conv_channels` maps of `patch_len` values,
# so the projection input width is conv_channels * patch_len (8 * 4 = 32 here)
# instead of a hard-coded constant.
patch_len = pre_embed.shape[-1]
embed = nn.Sequential(
    nn.Linear(conv_channels * patch_len, D),
    nn.ReLU(),
)

# (B, 3, N, P*P) -conv-> (B, 8, N, P*P) -transpose-> (B, N, 8, P*P)
# -flatten-> (B, N, 8*P*P) -embed-> (B, N, D)
embed_vec = embed(detailed_preembed(pre_embed).transpose(1, 2).flatten(2))

In [163]:
# Prepend the class token and add positional embeddings.
# ViT uses ONE class token shared by every image, so it is stored once as
# (1, 1, D) and broadcast over the batch. Drawing a separate random token per
# image, with a hard-coded batch size, would give each image a different
# "class" query. nn.Parameter marks both tensors as learnable; in a full model
# they would be registered on a module.
cls_token = nn.Parameter(torch.randn(1, 1, D))
token_emb = torch.cat(
    [cls_token.expand(embed_vec.shape[0], -1, -1), embed_vec], dim=1
)  # (B, N+1, D); class token at position 0
pos_emb = nn.Parameter(torch.randn(embed_vec.shape[1] + 1, D))  # one vector per token position

input_emb = token_emb + pos_emb  # pos_emb broadcasts over the batch -> (B, N+1, D)

## Transformer Encoder
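1. Normalise (LayerNorm)
2. Multi-head self-attention, added back to the block input (residual connection)
3. Normalise (LayerNorm)
4. MLP, added back to its input (residual connection)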

In [164]:
from typing import Optional

class TinyTransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block built on ``nn.MultiheadAttention``.

    Computes ``x = x + Dropout(MHA(LN1(x)))`` and then ``x = x + MLP(LN2(x))``.
    Inputs and outputs are ``(B, N, D)`` tensors (``batch_first=True``).

    Args:
        dim: token width D; ``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        attn_dropout: dropout probability passed to ``nn.MultiheadAttention``.
        dropout: dropout on the attention output and inside the MLP.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Construction order (norm1, attn, drop1, norm2, mlp) determines RNG use
        # during parameter init and the state_dict layout; keep it stable.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=attn_dropout,
            batch_first=True,
        )
        self.drop1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(dim)
        mlp_width = int(dim * mlp_ratio)
        # Layer positions fix the state_dict keys mlp.0.* and mlp.3.*.
        mlp_layers = [
            nn.Linear(dim, mlp_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_width, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the attention sub-layer, then the MLP sub-layer, each with a residual."""
        attended = self._self_attend(self.norm1(x), key_padding_mask, attn_mask)
        x = x + self.drop1(attended)
        return x + self.mlp(self.norm2(x))

    def _self_attend(
        self,
        h: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor],
        attn_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attend ``h`` to itself; attention weights are not computed/returned."""
        out, _unused_weights = self.attn(
            h, h, h,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=False,
        )
        return out


In [165]:
# Stack L pre-norm encoder blocks of width D; each block maps (B, N, D) -> (B, N, D).
# The depth comes from the L hyper-parameter instead of repeating the block by hand.
encoder_list = nn.Sequential(*(TinyTransformerEncoderBlock(D) for _ in range(L)))

In [166]:
# Run the token embeddings (class token + patches, with positions) through the encoder.
# `input_batch` was never defined; the embeddings built above are `input_emb`.
context = encoder_list(input_emb)  # (B, N+1, D)

## Post-processing
1. Take the class-token output (position 0)
2. Normalise (LayerNorm)
3. MLP head producing C class scores

In [167]:
import torch.nn.functional as F

In [176]:
# The class token was prepended at position 0 of the token axis; take its
# contextualised representation for every image -> (B, D).
context_class = context.select(1, 0)

In [179]:
# Classification head: normalise the class-token representation, then project
# it to C class scores.
head_norm = nn.LayerNorm(D)
mlp_head = nn.Linear(D, C)
logits = mlp_head(head_norm(context_class))  # raw (unnormalised) class scores, (B, C)
probs = F.softmax(logits, dim=-1)  # class probabilities, (B, C)
# Predicted class index per image. Softmax is monotonic, so taking argmax of
# the raw logits gives the same answer as argmax of the probabilities.
class_logit = logits.argmax(-1)

In [180]:
# Display the predicted class index (argmax over the C classes) for each image in the batch.
class_logit

tensor([4, 3, 2, 6])