In [2]:
import torch 
from torch import nn 

### ViT Architecture from the 16x16 words paper

![vit_architecture](ViT_architecture.png)
 

In [6]:
# ---- Optimiser settings ----
LEARNING_RATE = 1e-4
ADAM_WEIGHT_DECAY = 0  # the paper uses 0.1; we keep 0 (the default value)
ADAM_BETAS = (0.9, 0.999)  # again from the paper

# ---- Dataset geometry (MNIST) ----
NUM_CLASSES = 10  # ten digit classes
IMAGE_SIZE = 28  # MNIST images are 28 x 28 pixels, i.e. (H, W) = (28, 28)
IN_CHANNELS = 1  # grayscale; an RGB dataset would use 3
PATCH_SIZE = 4  # side length, in pixels, of one square patch

# ---- Transformer settings ----
NUM_HEADS = 8  # attention heads inside every encoder layer
NUM_ENCODER = 4  # encoder layers stacked on top of each other (the figure shows only one)
DROPOUT = 0.001
HIDDEN_DIM = 768  # hidden dimension of the MLP head for classification
ACTIVATION = "gelu"  # same activation as the paper

# Width of every token fed to the transformer. We pick P*P*C, the number of
# values in one flattened patch (4*4*1 = 16) -- this is NOT the number of patches.
EMBED_DIM = (PATCH_SIZE**2) * IN_CHANNELS  # 16

# Sequence length before the [CLS] token is prepended. Paper: N = HW / P^2
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # 49

# Fail early on inconsistent settings instead of deep inside a forward pass.
assert IMAGE_SIZE % PATCH_SIZE == 0, "IMAGE_SIZE must be a multiple of PATCH_SIZE"
assert EMBED_DIM % NUM_HEADS == 0, "EMBED_DIM must be divisible by NUM_HEADS"

# "cuda" is the device string PyTorch recognises for NVIDIA GPUs.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Quick Check in on [CLS] Token and Positional Embeddings 

In ViT: 
- split image into patches and turn them into embeddings (1 patch = 1 embedding vector)
- model pretends that the patches are a sequence of tokens, just like words in NLP models like BERT

2 Important points : 
1) prepend a special [CLS] token (like "classification token") at the start of the sequence.
2) add positional embeddings to every token (patches and the [CLS] token) so the model knows the order.

# Example:
1) After patch embedding, suppose we have: 
- 100 patches, where each patch is an embedding vector of size D (say, 768)
- so our sequence has shape: (100,768)

2) Create [CLS] token 
- Create a new learnable vector (randomly initialized) of size D, called the [CLS] token -> just another vector like a patch but it doesn't come from the image

3) Prepend the [CLS] token
- now sequence becomes: (1+100,768)
- where First position: [CLS] token & Next positions: patch tokens

4) Add Positional Embeddings 
- Transformers have no sense of order natively, so you add (element-wise) a positional embedding vector to each token
- The [CLS] token gets a positional embedding for position 0.
-  Patch tokens get positional embeddings for positions 1, 2, 3, ..., 100.
- Now the model knows which patch is where.

In [7]:
# Creating CLS Tokens and merging with Positional Embeddings 

class PatchEmbedding(nn.Module):
    """Turn a batch of images into the token sequence a ViT encoder consumes.

    Steps performed in ``forward``:
      1. cut the image into non-overlapping ``patch_size`` x ``patch_size``
         patches and project each one to ``embedding_dim`` values
         (a ``Conv2d`` whose kernel size equals its stride),
      2. prepend one learnable [CLS] token per sample,
      3. add a learnable positional embedding to every token,
      4. apply dropout.

    Args:
        embedding_dim: width of every token vector.
        patch_size: side length of a square patch, in pixels.
        num_patches: number of patches per image; the positional embedding
            holds ``num_patches + 1`` rows (the extra one is for [CLS]).
        dropout: dropout probability applied to the final token sequence.
        in_channels: number of channels of the input images.

    Input:  ``(batch_size, in_channels, height, width)``
    Output: ``(batch_size, num_patches + 1, embedding_dim)``
    """

    def __init__(self, embedding_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()

        # ---- Patching ----
        # Conv2d with kernel_size == stride visits each patch exactly once and
        # maps it to `embedding_dim` features:
        #   (batch, in_channels, H, W) -> (batch, embedding_dim, H/P, W/P)
        # Flatten(start_dim=2) merges the two spatial axes (not learnable):
        #   -> (batch, embedding_dim, num_patches)
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embedding_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(start_dim=2),
        )

        # ---- CLS token ----
        # A single learnable vector, shape (1, 1, embedding_dim). The middle
        # axis is the token axis: there is exactly ONE [CLS] token per sample,
        # regardless of how many channels the image has.
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embedding_dim)), requires_grad=True)

        # ---- Positional embedding ----
        # One learnable vector per token position; +1 for the [CLS] token.
        self.position_embedding = nn.Parameter(
            torch.randn(size=(1, num_patches + 1, embedding_dim)), requires_grad=True
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Broadcast the single [CLS] token to one copy per sample in the batch;
        # -1 keeps the remaining dimensions as they are.
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        # (batch, embedding_dim, num_patches) -> (batch, num_patches, embedding_dim),
        # e.g. (32, 16, 49) -> (32, 49, 16), so tokens run along dim=1.
        x = self.patcher(x).permute(0, 2, 1)

        # Put [CLS] at position 0, in front of the patch tokens.
        x = torch.cat([cls_token, x], dim=1)

        # Transformers have no built-in notion of order: add the positional
        # embedding element-wise (broadcast over the batch axis).
        x = self.position_embedding + x

        x = self.dropout(x)
        return x


#always test model after you define it    
# Smoke test: push a dummy batch through the embedding block and check the shape.
model = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)

# Dummy batch of 512 images, created on the same device as the model so the
# forward pass also works when a GPU is selected.
x = torch.randn(512, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE, device=device)

tokens = model(x)
# Expect (512, 50, 16): batch size 512, 49 patches + 1 CLS token, embedding dimension 16.
assert tokens.shape == (512, NUM_PATCHES + 1, EMBED_DIM)
print(tokens.shape)

torch.Size([512, 50, 16])


Note: nn.Sequential is a convenience container in PyTorch.
It lets you stack layers together in order, without writing a full forward() method manually.

Instead of writing:
 ```python
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x
```
You can do the same thing with nn.Sequential:

```python

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
)
```

Notes on the layers:

-  nn.Conv2d is a layer in PyTorch (torch.nn) used for applying a 2D convolution over an input, typically an image.
    It slides filters (kernels) over a 2D input (like an image) and computes feature maps.
- nn.Flatten() reshapes a tensor by flattening part of its dimensions into a single one.
     Turns multi-dimensional data (like 2D or 3D feature maps) into a 1D vector per sample, usually before feeding it into fully connected (Linear) layers.

Notes on cls_token:
- shape should be: 1 × 1 × embedding_dim (for MNIST, in_channels = 1, so 1 × in_channels × embedding_dim happens to give the same shape)
    - first 1 -> a batch-size-like dimension; in `forward` we expand it to the real batch size
    - second 1 -> the token dimension: we have exactly one CLS token per sample, so it must not grow with the number of image channels
    - embedding_dim -> matches the size of each patch embedding vector, because the CLS token is concatenated to the patch sequence before it is fed into the transformer
    
    
- why have it? 
    - CLS token = Learnable summary of the whole input.
    - It acts as a summary token: after going through the transformer layers, the model will read the CLS token to decide the final class label.
    - Think of it like a "learnable prompt" — the model writes its summary into it during training.
    

Tracking the shapes:

You start with an image of shape:
(batch_size, in_channels, height, width) = (32, 1, 28, 28)


then we apply patcher 

```
self.patcher = nn.Sequential(
    nn.Conv2d(in_channels, embedding_dim, patch_size, stride=patch_size),
    nn.Flatten(start_dim=2)
)

```
where nn.Conv2d divides the image into patches of size 4×4 pixels (kernel_size = stride = patch_size) and projects each patch to embedding_dim features, so we get: 
(batch_size, embedding_dim, height//patch_size, width//patch_size) = (32, 16, 7, 7)
     where: 
     - 28 // 4 = 7 patches along height
     - 28 // 4 = 7 patches along width
     - 16 filters = 16 features per patch
    
Then apply nn.Flatten(start_dim=2), which flattens everything from dimension 2 onward:
- Flatten (7,7) together into 49, so after flattening:
- (batch_size, embedding_dim, num_patches) = (32, 16, 49)

Quick notes on below:
- We implemented our encoder block using PyTorch's built-in layers, but we're meant to code it out explicitly -> I will do this after doing one full iteration

In [None]:
class Vit(nn.Module):
    """Vision Transformer classifier: patch embedding -> encoder stack -> MLP head.

    Args:
        num_patches: number of patches each image is split into.
        num_classes: number of output classes (size of the logits vector).
        patch_size: side length of a square patch, in pixels.
        embed_dim: width of every token; also ``d_model`` of the encoder.
        num_encoders: number of stacked encoder layers.
        num_heads: attention heads per encoder layer.
        hidden_dim: accepted for interface compatibility but currently unused
            by this class. NOTE(review): decide whether it should size the
            encoder feed-forward layer or a hidden layer in the head.
        dropout: dropout probability, shared by the embedding block and the
            encoder layers.
        activation: activation used inside the encoder layers (e.g. "gelu").
        in_channels: number of channels of the input images.

    Input:  ``(batch_size, in_channels, height, width)``
    Output: ``(batch_size, num_classes)`` class logits.
    """

    def __init__(self, num_patches, num_classes, patch_size, embed_dim, num_encoders, num_heads, hidden_dim, dropout, activation, in_channels):
        super().__init__()

        # Splits the image into patches and produces the token sequence
        # (with the [CLS] token in front) that the encoder consumes.
        self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)

        # ---- ENCODER ----
        # PyTorch's built-in encoder layer. batch_first=True because our
        # tensors are laid out as (batch, tokens, embed_dim).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )

        # The layer above is a single encoder block; stack `num_encoders` of them.
        self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)

        # ---- MLP HEAD ----
        # Only the [CLS] token is classified: normalise it, then map it to
        # one logit per class.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes),
        )

    def forward(self, x):
        x = self.embeddings_block(x)  # (batch, num_patches + 1, embed_dim)
        x = self.encoder_blocks(x)  # same shape
        # Token 0 is the [CLS] token; classify it and return the logits.
        x = self.mlp_head(x[:, 0, :])  # (batch, num_classes)
        return x
        

So the flow is:
Transformer Encoder outputs (batch_size, num_patches + 1, embed_dim).

You select the first token (CLS):
```
cls_token = output[:, 0, :]
```
→ shape (batch_size, embed_dim)

Then apply:

LayerNorm (normalize each embed_dim vector independently)

Linear (map to num_classes)