# Vision Transformer Code Explanation Tutorial

This tutorial breaks down the key components and functions from `runner.py` and `network.py`, focusing on practical understanding rather than class structures. You'll learn how each piece works and see demonstrations of their functionality.

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import sys, os
from pathlib import Path
sys.path.append(str(Path(os.getcwd()).parent))
from network import SingleHeadAttention, MultiHeadAttention, TransformerBlock, PatchEmbedding, VisionTransformer

### ViT Block
As before, the core building block is the Transformer block, which pairs a multi-head attention module with a feedforward network. Each block normalizes its input, applies attention, and adds the result back to the input. It then normalizes again, applies the feedforward network, and adds a second residual. Unlike our NLP transformers, ViT blocks typically apply normalization *before* each sub-module (pre-norm) rather than after it.

**Block Architecture:**
```
x ─┬─→ LayerNorm1 → MultiHeadAttention ─→ + ─┬─→ LayerNorm2 → FeedForward ─→ + → output
   │                                      ↑  │                               ↑
   └────────────── residual ──────────────┘  └────────── residual ───────────┘
```
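
To make the ordering concrete, here is a minimal sketch contrasting pre-norm with the post-norm ordering of the original Transformer. The `sublayer` is a plain `nn.Linear` standing in for attention or the feedforward network, and the names `pre_norm_step` and `post_norm_step` are illustrative only; they do not come from `network.py`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 8
norm = nn.LayerNorm(hidden_size)
# Stand-in for attention or the feed-forward net (an assumption for illustration)
sublayer = nn.Linear(hidden_size, hidden_size)

def pre_norm_step(x):
    # Pre-norm: normalize, transform, then add the residual
    return x + sublayer(norm(x))

def post_norm_step(x):
    # Post-norm: transform, add the residual, then normalize
    return norm(x + sublayer(x))

x = torch.randn(2, 5, hidden_size)  # [batch, tokens, hidden]
pre_out = pre_norm_step(x)
post_out = post_norm_step(x)
print(pre_out.shape, post_out.shape)  # both keep the input shape
```

In the pre-norm form the residual path carries `x` through unchanged, whereas in the post-norm form every output passes through a `LayerNorm`.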


In [22]:
class TransformerBlock(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(TransformerBlock, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
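        # Separate LayerNorms for the attention and feed-forward sub-layers
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.attention = MultiHeadAttention(hidden_size, hidden_size, hidden_size, num_heads)
        # Feed-forward network; input and output widths match so the residual add works
        self.net = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, hidden_size))

    def forward(self, x):
        # Pre-norm: normalize, apply the sub-layer, then add the residual
        x = x + self.attention(self.norm1(x))
        x = x + self.net(self.norm2(x))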
        return x
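
Because the block returns a tensor of the same shape it receives, blocks can be stacked freely. The sketch below checks this without `network.py`: `SelfAttentionStandIn` is a hypothetical wrapper around PyTorch's built-in `nn.MultiheadAttention` that mimics the single-tensor call used above, and `PreNormBlock` mirrors the structure of `TransformerBlock`. The real `MultiHeadAttention` may differ internally.

```python
import torch
import torch.nn as nn

class SelfAttentionStandIn(nn.Module):
    """Hypothetical stand-in for network.MultiHeadAttention (one tensor in, one out)."""
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x)  # self-attention: query = key = value
        return out

class PreNormBlock(nn.Module):
    """Same layout as the TransformerBlock above, using the stand-in attention."""
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.attention = SelfAttentionStandIn(hidden_size, num_heads)
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, hidden_size)
        )

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.net(self.norm2(x))
        return x

x = torch.randn(2, 65, 128)  # [batch, tokens, hidden_size]
block = PreNormBlock(hidden_size=128, num_heads=4)
out = block(x)
print(out.shape)  # torch.Size([2, 65, 128])

# Shape preservation is what lets several blocks be chained
stack = nn.Sequential(*[PreNormBlock(128, 4) for _ in range(3)])
stacked_out = stack(x)
print(stacked_out.shape)  # torch.Size([2, 65, 128])
```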

### Patch Embedding
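What is a "token" in an image? We split the image into non-overlapping "patches" and treat each patch as one token. In practice this is done with a convolution whose kernel size and stride both equal the patch size, which downsamples the image into a grid of patch embeddings.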

**Key Functions:**
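- **Image Patching**: Divides the image into non-overlapping 4×4-pixel patches (64 patches for a 32×32 image)
- **Linear Projection**: Maps each patch to a `hidden_size`-dimensional vector (128 in this tutorial)
- **CLS Token**: A learnable extra token whose representation is used for classification
- **Position Embeddings**: Learnable vectors that encode each token's position in the sequence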
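
The first two items can be done in a single step: a convolution with `kernel_size == stride == patch_size` gives the same result as cutting out each patch and applying one shared linear layer. The sketch below checks that equivalence on random data. It uses `F.unfold` to extract the patches, and its variable names are illustrative rather than taken from `network.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
channels, img_size, patch_size, hidden_size = 3, 32, 4, 128
conv = nn.Conv2d(channels, hidden_size, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, channels, img_size, img_size)

# Path 1: strided convolution, then flatten the 8x8 grid into a token sequence
conv_tokens = conv(x).flatten(2).transpose(1, 2)  # [2, 64, 128]

# Path 2: extract the patches explicitly, then reuse the conv weights as a linear map
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # [2, 3*4*4, 64]
patches = patches.transpose(1, 2)                                 # [2, 64, 48]
weight = conv.weight.reshape(hidden_size, -1)                     # [128, 48]
linear_tokens = patches @ weight.T + conv.bias                    # [2, 64, 128]

num_patches = (img_size // patch_size) ** 2
print(num_patches, tuple(conv_tokens.shape))
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-4))
```

Each 4×4 patch holds 3 × 4 × 4 = 48 raw values, so the projection maps 48 numbers to 128.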

In [29]:
class PatchEmbedding(nn.Module):
    def __init__(self, channels, img_size=32, patch_size=4, hidden_size=128):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.channels = channels
        # Each patch is patch_size x patch_size pixels
        # Total number of patches is (img_size/patch_size)^2
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(channels, hidden_size, kernel_size=patch_size, stride=patch_size) 
        # Learnable CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        # Positional embedding for patches + CLS token
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, hidden_size))
        
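    def forward(self, x):
        # x: [B, C, H, W] -> [B, num_patches, hidden_size]
        # Note: cls_token and pos_embed are defined above but not applied in this forward pass
        x = self.projection(x)            # [B, hidden_size, H/patch_size, W/patch_size]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, hidden_size]
        return x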
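
The `forward` above returns only the patch tokens; `cls_token` and `pos_embed` are declared but not used in it. One common way to apply them is sketched below, with a random tensor standing in for the patch tokens. This is an assumption about typical ViT usage and may not match what `network.py` actually does.

```python
import torch
import torch.nn as nn

batch_size, num_patches, hidden_size = 2, 64, 128
cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))

# Stand-in for the output of PatchEmbedding.forward (assumed shape [B, 64, 128])
patch_tokens = torch.randn(batch_size, num_patches, hidden_size)

# Copy the single CLS token across the batch and prepend it to the sequence
cls = cls_token.expand(batch_size, -1, -1)      # [2, 1, 128]
tokens = torch.cat([cls, patch_tokens], dim=1)  # [2, 65, 128]

# Add one learnable position vector per token (broadcast over the batch)
tokens = tokens + pos_embed
print(tokens.shape)  # torch.Size([2, 65, 128])
```

This is why `pos_embed` is allocated with `num_patches + 1` positions: the sequence grows from 64 to 65 tokens once the CLS token is prepended.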


In [None]:
import torchvision
import torchvision.transforms as transforms

# Load CIFAR10 dataset and get one image
transform = transforms.ToTensor()
cifar10 = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)
image, label = cifar10[0]  # Get first image
image = image.unsqueeze(0)  # Add batch dimension [1, 3, 32, 32]

patch_embed = PatchEmbedding(
    channels=3,
    img_size=32,
    patch_size=4,
    hidden_size=128
)

# Get patches from the image
patches = patch_embed(image)

print(f"Original image shape: {image.shape}")  # [1, 3, 32, 32]
print(f"Patches shape: {patches.shape}")  # [1, 64, 128] - 64 patch tokens, each a 128-dimensional embedding
print(f"Label: {cifar10.classes[label]}")

Original image shape: torch.Size([1, 3, 32, 32])
Patches shape: torch.Size([1, 64, 128])
Label: cat
