In [None]:
#Importing the required libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import math

In [None]:
import sys #to solve the IndexError

1. ViT core components

A mechanism that calculates and applies a weighted sum of value vectors based on the scaled dot-product similarity between query and key vectors, performed simultaneously across multiple representation subspaces (heads).

In [None]:
#Implementing the multi-head self-attention mechanism
class MultiHeadAttention(nn.Module): #nn.Module is inherited for automatic parameter tracking
  """Multi-head scaled dot-product self-attention.

  Args:
    embed_dim: size of the token embedding (C). Must be divisible by num_heads.
    num_heads: number of parallel attention heads (H).

  Raises:
    ValueError: if num_heads is not positive or does not divide embed_dim.
  """

  def __init__(self, embed_dim, num_heads):
    super().__init__() #sets up module's internal state for parameter and submodule registration

    #fail early: otherwise the reshape in forward() fails later with an opaque shape error
    if num_heads <= 0 or embed_dim % num_heads != 0:
      raise ValueError(
          f"embed_dim ({embed_dim}) must be divisible by a positive num_heads ({num_heads})"
      )

    self.num_heads = num_heads #stores the number of parallel attention heads
    self.head_dim = embed_dim // num_heads #dimension of each individual head (num_heads * head_dim = embed_dim)
    self.scale = self.head_dim ** -0.5 #scaling factor used in scaled dot-product attention: 1/sqrt(D_k).
    #This prevents dot products from growing too large, which stabilizes the softmax function and thus the training process

    #linear layers for Query, Key, Value (fused into one projection) & output projection
    self.qkv = nn.Linear(embed_dim, embed_dim*3, bias=False)
    self.proj = nn.Linear(embed_dim, embed_dim)

  def forward(self, x):
    """Apply self-attention to x of shape (B, N, C) and return a tensor of the same shape."""
    #input shape: (batch_size, num_tokens, embed_dim)
    B, N, C = x.shape #unpacks the input shape

    #1. generate Q, K, V
    #fused projection: (B, N, 3*C) -> reshape to (B, N, 3, H, D_h) -> permute to (3, B, H, N, D_h)
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)

    # qkv[0]: Q, qkv[1]: K, qkv[2]: V (all have shape: B, num_heads, num_tokens, head_dim)
    q, k, v = qkv[0], qkv[1], qkv[2]

    #2. compute attention score: (Q @ K^T) * scale
    # (B, H, N, D_h) @ (B, H, D_h, N) -> (B, H, N, N)
    attn = (q @ k.transpose(-2, -1)) * self.scale

    #3. softmax to get attention weights
    attn = attn.softmax(dim=-1) #applies softmax along the last dimension (the key tokens).
    #this converts the logits into attention weights (probabilities) that sum to 1 for each query token.

    #4. apply attention to V: (Attn @ V)
    # (B, H, N, N) @ (B, H, N, D_h) -> (B, H, N, D_h)
    #.transpose(1, 2) brings the tokens back to the second position (B, N, H, D_h); reshape merges the heads into C
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)

    #5. output projection
    x = self.proj(x)
    return x #final output tensor of shape (B, N, C), the result of the multi-head self-attention

A two-layer position-wise feed-forward network that first expands the feature dimension (e.g., from $1\times$ to $4\times$ the embedding dimension), then contracts it back to the original size, with a non-linearity (GELU) in between.

In [None]:
#position-wise feed-forward network: expand -> GELU -> dropout -> contract -> dropout
class MLP(nn.Module): #Multi-Layer Perceptron
  """Two-layer feed-forward sublayer of a transformer block.

  Args:
    embed_dim: input and output feature size.
    hidden_dim: size of the expanded hidden layer.
    dropout_rate: dropout probability applied after the activation and after the contraction.
  """

  def __init__(self, embed_dim, hidden_dim, dropout_rate=0.1):
    super().__init__()
    #the layers are kept in a list first; their order defines the indices inside self.net
    layers = [
        nn.Linear(embed_dim, hidden_dim),  #expansion to the hidden dimension
        nn.GELU(),                         #Gaussian Error Linear Unit non-linearity on the expanded features
        nn.Dropout(dropout_rate),
        nn.Linear(hidden_dim, embed_dim),  #contraction back to the embedding dimension
        nn.Dropout(dropout_rate),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Run x through the feed-forward stack; the last dimension stays embed_dim."""
    out = self.net(x)
    return out

The fundamental repeating unit of a Transformer encoder, consisting of a Multi-Head Attention sublayer and an MLP sublayer, each wrapped with a residual connection and Layer Normalization.

In [None]:
#one transformer encoder block: attention sublayer followed by an MLP sublayer
class TransformerBlock(nn.Module):
  """Encoder block with two residual sublayers.

  Each sublayer normalizes its input with LayerNorm first, and the sublayer
  output is then added back onto the un-normalized input.

  Args:
    embed_dim: token embedding size.
    num_heads: number of attention heads passed to MultiHeadAttention.
    mlp_dim: hidden size passed to the MLP.
    dropout_rate: dropout probability passed to the MLP.
  """

  def __init__(self, embed_dim, num_heads, mlp_dim, dropout_rate=0.1):
    super().__init__()
    #attention sublayer and its normalization
    self.norm1 = nn.LayerNorm(embed_dim)
    self.attn = MultiHeadAttention(embed_dim, num_heads)
    #feed-forward sublayer and its normalization
    self.norm2 = nn.LayerNorm(embed_dim)
    self.mlp = MLP(embed_dim, mlp_dim, dropout_rate)

  def forward(self, x):
    """Return a tensor with the same shape as x."""
    #1. attention on the normalized tokens; the skip connection helps gradient flow in deep networks
    attended = self.attn(self.norm1(x))
    x = x + attended
    #2. MLP on the normalized tokens, again added back through a skip connection
    transformed = self.mlp(self.norm2(x))
    return x + transformed

2. Patch Embedding and full ViT model

In [None]:
class VisionTransformer(nn.Module):
  """Vision Transformer (ViT) image classifier.

  The image is cut into non-overlapping patches by a strided convolution, a
  learnable CLS token is prepended, learnable positional embeddings are added,
  the sequence runs through `depth` transformer blocks, and the normalized CLS
  token is mapped to class logits by a linear head.

  Args:
    img_size: side length of the (square) input image.
    patch_size: side length of each square patch.
    num_classes: number of output logits produced by the classification head.
    embed_dim: token embedding size.
    depth: number of transformer blocks.
    num_heads: attention heads per block.
    mlp_ratio: MLP hidden size as a multiple of embed_dim.
    dropout_rate: dropout probability (after the positional embedding and inside each block's MLP).
  """

  def __init__(self, img_size=32, patch_size=4, num_classes=100, embed_dim=192, depth=4, num_heads=3, mlp_ratio=4, dropout_rate=0.1):
    super().__init__()

    num_patches = (img_size // patch_size) ** 2 #total number of patches created from the image
    # E.g., a 32 x 32 image with 4 x 4 patches yields (32/4)^2 = 64 patches.
    in_channels = 3 #RGB

    #hidden size of the MLP inside every transformer block
    mlp_dim = int(embed_dim * mlp_ratio)

    #1. patch embedding layer
    #a conv2d with stride=patch_size, kernel_size=patch_size projects each patch to one embed_dim vector
    self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    #2. class token (CLS)
    #a learnable parameter that is prepended to the patch sequence
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    #3. positional encoding
    #one learnable vector per patch position + one for the CLS token
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim))

    self.pos_dropout = nn.Dropout(p=dropout_rate)

    #4. transformer encoder
    self.transformer_blocks = nn.Sequential(*[
        TransformerBlock(embed_dim, num_heads, mlp_dim, dropout_rate)
        for _ in range(depth)
    ])

    #5. final normalization & head
    self.norm = nn.LayerNorm(embed_dim)
    #the head size follows the num_classes argument (it used to be hard-coded to 100, ignoring the argument)
    self.head = nn.Linear(embed_dim, num_classes)

    self._init_weights()

  def _init_weights(self): #helper method to initialize the weights of the layers and learnable parameters
    #simple initialization scheme
    nn.init.trunc_normal_(self.pos_embed, std=.02) #initializes the positional encoding with a truncated normal distribution
    nn.init.trunc_normal_(self.cls_token, std=.02) #initializes the CLS token with a truncated normal distribution
    #recursively visits every submodule; only Linear and LayerNorm are re-initialized,
    #other modules (e.g. the Conv2d patch embedding) keep their default initialization
    self.apply(self._init_vit_weights)

  def _init_vit_weights(self, m):
    #linear layer's weights are initialized using a truncated normal distribution, and biases are set to 0.
    if isinstance(m, nn.Linear):
      nn.init.trunc_normal_(m.weight, std=.02)
      if m.bias is not None:
        nn.init.constant_(m.bias, 0)
    #layer normalization layer's bias is set to 0 & weight is set to 1.
    elif isinstance(m, nn.LayerNorm):
      nn.init.constant_(m.bias, 0)
      nn.init.constant_(m.weight, 1.0)

  def forward(self, x):
    """Map a batch of images (B, 3, img_size, img_size) to logits of shape (B, num_classes)."""
    B = x.shape[0] #batch size

    #1. Patch embedding
    # e.g. (B, 3, 32, 32) -> (B, Embed_dim, 8, 8) (since 32/4=8)
    x = self.patch_embed(x)
    #flatten patches: (B, Embed_dim, H, W) -> (B, H*W, Embed_dim)
    x = x.flatten(2).transpose(1, 2)

    #2. prepend CLS token (B, 1, Embed_dim)
    cls_token = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_token, x), dim=1)

    #3. add positional encoding (broadcast over the batch dimension)
    x = x + self.pos_embed
    x = self.pos_dropout(x)

    #4. pass through the transformer blocks
    x = self.transformer_blocks(x)

    #5. extract the CLS token
    #take only the first token for classification
    x = self.norm(x[:, 0])

    #6. classification head
    x = self.head(x)
    return x #final logits for classification

Truncated normal initialization is a method of initializing the parameters of a neural network layer by drawing values from a Gaussian (normal) distribution while strictly restricting the values to lie within a defined range. The key idea is to prevent the occurrence of extreme outliers that can destabilize the early stages of training. In this notebook it is applied to the linear-layer weights, the positional embedding, and the CLS token, while the biases are simply set to zero.

Training process

In [None]:
#setting the configurations
IMG_SIZE = 32        #side length of the input images
PATCH_SIZE = 4       #side length of each image patch
NUM_CLASSES = 100    #number of target classes
EMBED_DIM = 192      #token embedding size
DEPTH = 4            #number of transformer blocks
NUM_HEADS = 3        #attention heads per block
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
NUM_EPOCHS = 1
#these configurations are only used because of hardware limitations

#fix the random seed so that weight initialization, shuffling and dropout are repeatable on a fresh run
SEED = 42
torch.manual_seed(SEED)

In [None]:
#use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
#data preprocessing and augmentation (training only)
transform_train = transforms.Compose([
    #zero-pads the border by 4 pixels and then randomly crops an IMG_SIZE x IMG_SIZE region;
    #IMG_SIZE is used instead of a literal 32 so the crop always matches the model's input size
    transforms.RandomCrop(IMG_SIZE, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), #converts the image into a tensor and also scales the values into the range [0.0, 1.0]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #channel-wise normalization with the provided mean & standard deviation
])

In [None]:
#evaluation pipeline: no augmentation, only tensor conversion and the same channel-wise normalization
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform_test = transforms.Compose(test_steps)

In [None]:
#CIFAR-100: both splits share the same root folder and are downloaded if missing
cifar_kwargs = dict(root='./data', download=True)
train_datasets = datasets.CIFAR100(train=True, transform=transform_train, **cifar_kwargs)
test_datasets = datasets.CIFAR100(train=False, transform=transform_test, **cifar_kwargs)

In [None]:
#batching: only the training split is shuffled; both loaders share batch size and worker count
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4)
train_loader = DataLoader(train_datasets, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_datasets, shuffle=False, **loader_kwargs)

In [None]:
#build the ViT from the configuration constants and move it to the selected device
vit_config = dict(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    embed_dim=EMBED_DIM,
    depth=DEPTH,
    num_heads=NUM_HEADS,
)
model = VisionTransformer(**vit_config).to(device)

--- ViT INIT CHECK: Final Head Output Features: 100 (Expected 100) ---


In [None]:
#count only the parameters that receive gradients, reported in millions
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {trainable_params/1e6:.2f}M")

Model parameters: 1.82M


In [None]:
#cross-entropy over the class logits, averaged over the samples of each batch
criterion = nn.CrossEntropyLoss(reduction="mean")

In [None]:
#AdamW is a variant of Adam (Adaptive Moment Estimation) that handles weight decay properly;
#it updates every parameter of the model with the configured learning rate
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
#one-shot flag: the training loop prints its shape-debug block only while this is False, then sets it to True
checked_output_shape = False

In [None]:
#training loop: one optimization pass over the training set per epoch, followed by evaluation on the test set
for epoch in range(NUM_EPOCHS):
  model.train() #enables dropout
  running_loss = 0.0
  num_train_samples = 0

  for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)

    optimizer.zero_grad()
    outputs = model(inputs)

    # --- DEBUGGING CHECK ---
    #printed once only: the flag is flipped after the first batch
    if not checked_output_shape:
      print(f"--- DEBUG SHAPE CHECK ---")
      print(f"Model Output Shape (logits): {outputs.shape}")
      print(f"Expected Class Dimension: {NUM_CLASSES}")
      print(f"Labels Shape: {labels.shape} | Max Label Index: {labels.max().item()}")
      print(f"-------------------------")
      checked_output_shape = True
    # --- END DEBUGGING CHECK ---

    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    #the criterion returns the mean over the batch, so weight it by the batch size;
    #otherwise a smaller final batch would count as much as a full one in the epoch average
    batch_count = labels.size(0)
    running_loss += loss.item() * batch_count
    num_train_samples += batch_count

  avg_train_loss = running_loss / num_train_samples #mean loss per training sample

  #evaluation
  model.eval() #disables dropout
  correct = 0
  total = 0

  with torch.no_grad(): #no gradients are needed for evaluation
    for inputs, labels in test_loader:
      inputs, labels = inputs.to(device), labels.to(device)
      outputs = model(inputs)
      predicted = outputs.argmax(dim=1) #index of the largest logit = predicted class
      total += labels.size(0)
      correct += (labels==predicted).sum().item()

  accuracy = 100 * correct/total
  print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {avg_train_loss:.4f} | Test Accuracy: {accuracy:.2f}%")



Epoch 1/1 | Train Loss: 3.5911 | Test Accuracy: 16.41%
