1. Dataset MNIST - load, train/test split, batching
2. Patch Embeddings -> input flattened, then linearly projected
3. Transformer Encoder
4. MLP Head
5. Final Classification
6. All these in one class - ViT
7. A Training loop
8. Validation
9. Randomized visualizations
10. No decoder - so no generation, only classification

In [1]:
# import necessary libraries and dataset
import pandas as pd
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn

In [35]:
# Preprocessing applied to every image: convert it to a tensor, then
# normalize the single channel with mean 0.1307 and std 0.3081 so the data
# is roughly standardized (mean 0, std 1).
transformation_operation = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [36]:
# Load MNIST (downloading into ./data when missing). The `train=False` split
# serves as the validation set; both splits share the same preprocessing.
training_data = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transformation_operation,
)
val_data = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transformation_operation,
)

In [37]:
# Hyperparameters and data-shape constants shared by the cells below.

# --- data ---
input_size = 28   # one input = one MNIST image, 28x28 pixels
num_channels = 1  # MNIST is grayscale
num_classes = 10  # digits 0-9
batch_size = 64

# --- patching ---
# The paper uses 16x16 patches; with 28x28 inputs a patch size of 4 yields a
# 7x7 grid of patches (a larger patch size would give fewer patches).
patch_size = 4
# (h / p) * (w / p), with h == w here
num_patches = (input_size // patch_size) * (input_size // patch_size)

# --- model ---
embedding_dim = 64
attention_heads = 4
transformer_blocks = 4  # the whole encoder block is repeated 4 times
mlp_hidden_nodes = 128  # the paper uses 3072; 128 is enough for a small dataset

# --- optimization ---
learning_rate = 1e-3
epochs = 5

In [38]:
# Wrap both datasets in mini-batch loaders of `batch_size` samples.
# Training data is shuffled; validation data keeps its original order.
import torch.utils.data as data_loader

train_loader = data_loader.DataLoader(
    dataset=training_data,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = data_loader.DataLoader(
    dataset=val_data,
    batch_size=batch_size,
    shuffle=False,
)

1. Patch embedding
2. Transformer encoder
3. MLP head
 
The Transformer class incorporating all of the above parts

In [48]:
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size`` is
    applied, so every output position is a learned projection of exactly one
    non-overlapping patch. A GELU is then applied to the projected patches.

    Reads the module-level ``num_channels``, ``embedding_dim`` and
    ``patch_size`` when constructed.

    NOTE(review): the GELU makes this embedding non-linear, which differs
    from a plain linear patch projection -- confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(
            num_channels,
            embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed the patches of ``x``.

        Args:
            x: image batch shaped (batch, channels, height, width).

        Returns:
            Tensor shaped (batch, num_patches, embedding_dim).
        """
        # (batch, embedding_dim, n_h, n_w): one embedding vector per patch.
        x = self.patch_embed(x)
        # Functional GELU avoids building a new nn.GELU module on every call.
        x = nn.functional.gelu(x)
        # Collapse the two spatial axes into a single patch axis:
        # (batch, embedding_dim, num_patches).
        x = x.flatten(2)
        # Put the sequence axis before the feature axis, the layout used by
        # the batch_first attention layers: (batch, num_patches, embedding_dim).
        return x.transpose(1, 2)

In [49]:
# Sanity check: push one training batch through a standalone patch-embedding
# convolution and print the tensor shape after each step.
data_point = next(iter(train_loader))
images, labels = data_point
patch_embed = nn.Conv2d(
    in_channels=num_channels,
    out_channels=embedding_dim,
    kernel_size=patch_size,
    stride=patch_size,
)
print(f"shape of data point: {images.shape}")
patch_embed_output = patch_embed(images)
print(f"shape of patch_embed_output: {patch_embed_output.shape}")
# Flatten from the 3rd dimension onward: the two spatial axes collapse into
# one patch axis (7, 7 -> 49 for 28x28 images with 4x4 patches).
patch_embed_output_flatten = patch_embed_output.flatten(start_dim=2)
print(patch_embed_output_flatten.shape)

shape of data point: torch.Size([64, 1, 28, 28])
shape of patch_embed_output: torch.Size([64, 64, 7, 7])
torch.Size([64, 64, 49])


In [50]:
# Part 2 Transformer Encoder
class TransformerEncoder(nn.Module):
    """One encoder block: self-attention and an MLP, each with a skip path.

    Layer normalization is applied *before* each sub-layer (pre-norm), and
    the sub-layer output is added back onto its un-normalized input. Sizes
    come from the module-level ``embedding_dim``, ``attention_heads`` and
    ``mlp_hidden_nodes``.

    NOTE(review): the last stage adds ``layer_norm3(x)`` onto ``x`` with no
    sub-layer in between -- this looks experimental; confirm it is intended.
    """

    def __init__(self):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)
        self.layer_norm3 = nn.LayerNorm(embedding_dim)
        # batch_first=True lets the block consume (batch, seq_len, embed_dim)
        # tensors directly, with no transposes around the attention call.
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=attention_heads,
            batch_first=True,
        )
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, mlp_hidden_nodes),
            nn.GELU(),
            nn.Linear(mlp_hidden_nodes, embedding_dim),
        )

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        # Stage 1: self-attention (query = key = value) on the normalized
        # input; the attention weights returned alongside are discarded.
        normed = self.layer_norm1(x)
        attn_out, _ = self.multihead_attn(normed, normed, normed)
        x = attn_out + x

        # Stage 2: position-wise MLP on the normalized tokens, plus skip.
        x = self.mlp(self.layer_norm2(x)) + x

        # Stage 3: normalized tokens added back onto themselves.
        return self.layer_norm3(x) + x

In [51]:
# Part 3: MLP Head for classification
# only class token output is taken for classification
class MLPHead(nn.Module):
    """Classification head: layer norm followed by a linear map to logits.

    Maps features of size ``embedding_dim`` to ``num_classes`` scores (both
    module-level values). The head does no token selection itself; it
    transforms whatever tensor the caller hands it.
    """

    def __init__(self):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.mlp_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return class logits for the (already pooled) features ``x``."""
        return self.mlp_head(self.layer_norm1(x))

In [52]:
class VisionTransformer(nn.Module):
    """Full classifier: patch embedding, encoder stack and MLP head.

    A learnable class token is prepended to the patch sequence and a
    learnable positional embedding is added before the encoder stack.

    NOTE(review): the head receives the *sum* of the class token (index 0)
    and the last patch token, not the class token alone -- confirm this is
    intended.
    """

    def __init__(self):
        super().__init__()
        self.patch_embed = PatchEmbedding()
        # Trainable class token: one token of size embedding_dim, shared by
        # every sample in a batch.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        # Trainable positional embedding; the extra slot (+1) is for the
        # class token.
        self.positional_embed = nn.Parameter(
            torch.randn(1, num_patches + 1, embedding_dim)
        )
        # nn.Sequential takes modules as separate arguments, so the list of
        # encoder blocks is unpacked.
        encoder_stack = [TransformerEncoder() for _ in range(transformer_blocks)]
        self.transformer_blocks = nn.Sequential(*encoder_stack)
        self.mlp_head = MLPHead()

    def forward(self, x):
        """Return class logits of shape (batch_size, num_classes) for ``x``."""
        patches = self.patch_embed(x)  # (batch_size, num_patches, embedding_dim)

        # Broadcast the single class token across the batch and put it in
        # front of the patches: (batch_size, num_patches + 1, embedding_dim).
        cls_tokens = self.cls_token.expand(patches.size(0), -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)

        tokens = tokens + self.positional_embed
        tokens = self.transformer_blocks(tokens)

        # Pool: first token (class token) plus the last token in the sequence.
        pooled = tokens[:, 0] + tokens[:, -1]
        return self.mlp_head(pooled)

In [53]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Model, optimizer and loss used by the training and validation cells.
model = VisionTransformer().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [54]:
# Training: one full pass over train_loader per epoch, tracking the running
# loss and accuracy and logging every 100th batch.
for epoch in range(epochs):
    model.train()
    total_loss = 0
    correct_epoch = 0
    total_epoch = 0
    print(f"epoch = {epoch+1}")
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step, then run
        # forward -> loss -> backward -> parameter update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Metrics: summed batch losses, plus correct/seen sample counts.
        total_loss += loss.item()
        predicted = torch.max(outputs.data, 1).indices
        total_epoch += labels.size(0)
        correct_epoch += (predicted == labels).sum().item()

        if batch_idx % 100 == 99:
            print(f"Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

    # Epoch summary: mean of the per-batch losses and accuracy in percent.
    epoch_loss = total_loss / len(train_loader)
    epoch_accuracy = 100 * correct_epoch / total_epoch
    print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_accuracy:.2f}%")

epoch = 1
Batch [100/938], Loss: 0.7178
Batch [200/938], Loss: 0.3613
Batch [300/938], Loss: 0.3925
Batch [400/938], Loss: 0.2322
Batch [500/938], Loss: 0.1776
Batch [600/938], Loss: 0.2690
Batch [700/938], Loss: 0.1258
Batch [800/938], Loss: 0.1674
Batch [900/938], Loss: 0.2656
Epoch [1/5], Training Loss: 0.3761, Training Accuracy: 88.02%
epoch = 2
Batch [100/938], Loss: 0.1227
Batch [200/938], Loss: 0.1184
Batch [300/938], Loss: 0.1080
Batch [400/938], Loss: 0.1202
Batch [500/938], Loss: 0.1212
Batch [600/938], Loss: 0.2278
Batch [700/938], Loss: 0.0587
Batch [800/938], Loss: 0.1510
Batch [900/938], Loss: 0.0964
Epoch [2/5], Training Loss: 0.1246, Training Accuracy: 96.19%
epoch = 3
Batch [100/938], Loss: 0.1593
Batch [200/938], Loss: 0.0757
Batch [300/938], Loss: 0.1388
Batch [400/938], Loss: 0.0879
Batch [500/938], Loss: 0.0973
Batch [600/938], Loss: 0.0684
Batch [700/938], Loss: 0.0548
Batch [800/938], Loss: 0.0941
Batch [900/938], Loss: 0.0900
Epoch [3/5], Training Loss: 0.0946, 

In [None]:
# Validation/Evaluation Function

def evaluate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` and print/return loss and accuracy.

    The model is switched to eval mode (and left there) and the whole pass
    runs with gradient tracking disabled.

    Args:
        model: module called as ``model(images)`` to produce class scores.
        val_loader: sized iterable yielding ``(images, labels)`` batches.
        criterion: callable ``criterion(outputs, labels)`` returning a
            scalar loss tensor.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(val_loss, val_accuracy)``: the mean of the per-batch losses and
        the percentage of correctly classified samples.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()

            # Predicted class = index of the highest score along dim 1.
            predicted = torch.max(outputs.data, 1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    val_loss = loss_sum / len(val_loader)
    val_accuracy = 100 * n_correct / n_seen

    print("\n✨ Validation Results ✨")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    return val_loss, val_accuracy

# Run validation once training has finished; as the cell's final expression,
# the returned (loss, accuracy) tuple is also shown as the cell output.
evaluate_model(
    model=model,
    val_loader=val_loader,
    criterion=criterion,
    device=device,
)


✨ Validation Results ✨
Validation Loss: 0.0661, Validation Accuracy: 97.89%


(0.06610508098943266, 97.89)