# Vision Transformer (ViT) Exercise
Welcome to this hands-on exercise on Vision Transformers! You'll build a ViT from scratch to classify images in the CIFAR-10 dataset, learning how the revolutionary Transformer architecture works for computer vision tasks.
## Learning Objectives

* Understand the Transformer architecture and self-attention mechanism
* Learn how to adapt Transformers for computer vision (Vision Transformers)
* Implement patch embedding and positional encoding for images
* Build multi-head attention and feed-forward components
* Master the complete ViT pipeline from patches to predictions
* Compare ViT performance with CNNs

## Why Transformers for Vision?
Traditional CNNs rely on local receptive fields, whereas Transformers offer:

* Global Attention: Every patch can attend to every other patch
* Weak Inductive Bias: Few built-in spatial assumptions (no enforced locality or translation equivariance), so spatial relationships are learned from data
* Scalability: Performance improves with larger datasets
* Unified Architecture: Same architecture for NLP and Vision

## Architecture: Image → Patches → Embeddings → Transformer → Classification


## Introduction to Vision Transformers
### Key Concepts:

* Patch Embedding: Split image into patches and embed them
* Positional Encoding: Add spatial information to patches
* Self-Attention: Let patches "talk" to each other
* Transformer Encoder: Stack of attention + feed-forward layers
* Classification Token: Special [CLS] token for final prediction

### CIFAR-10 Dataset:

* 60,000 32x32 color images
* 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
* 50,000 training + 10,000 test images

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import math
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

import warnings
warnings.filterwarnings('ignore')

# Set device and random seeds
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed the global PyTorch and NumPy RNGs so later random draws
# (weight initialisation, np.random.choice sampling) start from a fixed state.
torch.manual_seed(42)
np.random.seed(42)

# CIFAR-10 class names
# Indexed by the integer label: CIFAR10_CLASSES[label] is the display name.
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

NameError: name 'transforms' is not defined

## Dataset Loading and Exploration

In [None]:
def load_cifar10_data(root='./data'):
    """Download CIFAR-10 and build the train / validation / test datasets.

    The official training split is divided 90/10 into a training subset and a
    validation subset. Training images get light augmentation (horizontal
    flip, small rotation); validation and test images are only converted to
    tensors and normalised.

    Args:
        root: Directory the dataset is downloaded to and read back from.
            A single location is used for every ``CIFAR10(...)`` call in this
            function; previously the download went to ``/content/`` while the
            validation copy was read from ``./data`` with ``download=False``,
            which fails unless both paths happen to hold the data.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``. The first two
        are the ``Subset`` objects produced by ``random_split``; the third is
        the CIFAR-10 test split.
    """
    # ImageNet channel statistics, shared by the train and eval pipelines.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform_train = transforms.Compose([
        # Data augmentation is applied to training images only.
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    full_train = CIFAR10(root=root, train=True, download=True, transform=transform_train)
    test_dataset = CIFAR10(root=root, train=False, download=True, transform=transform_test)

    # Split training set into train and validation (90% / 10%).
    train_size = int(0.9 * len(full_train))
    val_size = len(full_train) - train_size
    train_dataset, val_dataset = random_split(full_train, [train_size, val_size])

    # Both subsets currently wrap the augmented dataset. Point the validation
    # subset at a second view of the same files that uses the test transform,
    # so validation images are not augmented. The subset indices stay valid
    # because both views read the same training split from the same root.
    val_dataset.dataset = CIFAR10(root=root, train=True, download=False, transform=transform_test)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    return train_dataset, val_dataset, test_dataset

def visualize_cifar10_samples(dataset, num_samples=20):
    """Plot a grid of randomly chosen, un-normalised CIFAR-10 training images.

    Args:
        dataset: Not read by this function; kept so existing calls keep
            working. Images are re-loaded from ``./data`` with only
            ``ToTensor`` applied so colours are displayed without
            normalisation (the files must already exist there).
        num_samples: Number of images to draw without replacement. The grid
            grows in rows of five to fit; the default of 20 gives the same
            4x5 layout as before.
    """
    # Create a temporary dataset without normalization for visualization
    viz_transform = transforms.Compose([transforms.ToTensor()])
    viz_dataset = CIFAR10(root='./data', train=True, download=False, transform=viz_transform)

    # Get random samples
    indices = np.random.choice(len(viz_dataset), num_samples, replace=False)

    # Size the grid from num_samples instead of assuming exactly 20 images;
    # a fixed 4x5 grid raised IndexError for larger requests.
    cols = 5
    rows = max(1, math.ceil(num_samples / cols))
    _, axes = plt.subplots(rows, cols, figsize=(12, 2.5 * rows), squeeze=False)
    axes = axes.ravel()

    for ax, idx in zip(axes, indices):
        image, label = viz_dataset[idx]

        # Convert tensor (C, H, W) to numpy (H, W, C) for plotting
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.set_title(f'{CIFAR10_CLASSES[label]}')

    # Hide ticks and frames on every cell, including unused trailing ones.
    for ax in axes:
        ax.axis('off')

    plt.suptitle('CIFAR-10 Sample Images')
    plt.tight_layout()
    plt.show()

def analyze_class_distribution(dataset):
    """Count, plot and print how many images each CIFAR-10 class has.

    Args:
        dataset: Either a plain dataset yielding ``(image, label)`` pairs or a
            subset-style wrapper exposing ``.dataset`` and ``.indices``.

    Labels are read from the underlying dataset's ``targets`` list when it has
    one, which avoids decoding and transforming every image just to count
    labels. This assumes ``targets`` holds the same integer labels that
    ``__getitem__`` returns (i.e. no ``target_transform`` is in use). Without
    ``targets`` the samples are iterated one by one, as before.
    """
    class_counts = [0] * len(CIFAR10_CLASSES)

    # For subset datasets, we need to access the underlying dataset
    if hasattr(dataset, 'dataset') and hasattr(dataset, 'indices'):
        base_dataset = dataset.dataset
        targets = getattr(base_dataset, 'targets', None)
        if targets is not None:
            labels = (int(targets[idx]) for idx in dataset.indices)
        else:
            labels = (base_dataset[idx][1] for idx in dataset.indices)
    else:
        targets = getattr(dataset, 'targets', None)
        if targets is not None:
            labels = (int(target) for target in targets)
        else:
            labels = (label for _, label in dataset)

    for label in labels:
        class_counts[label] += 1

    # Plot distribution
    plt.figure(figsize=(12, 6))
    bars = plt.bar(CIFAR10_CLASSES, class_counts, color='skyblue', alpha=0.7)
    plt.title('CIFAR-10 Class Distribution')
    plt.xlabel('Classes')
    plt.ylabel('Number of Images')
    plt.xticks(rotation=45)

    # Add value labels slightly above each bar
    for bar, count in zip(bars, class_counts):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 10,
                str(count), ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

    print("Class distribution:")
    for i, (class_name, count) in enumerate(zip(CIFAR10_CLASSES, class_counts)):
        print(f"{i}: {class_name}: {count} images")

# Load and explore the data
# These three module-level names are reused by later cells.
train_dataset, val_dataset, test_dataset = load_cifar10_data()

# Visualize samples
# Shows a grid of training images drawn at random.
visualize_cifar10_samples(train_dataset)

# Analyze class distribution
# Counts labels in the training subset only (not validation or test).
analyze_class_distribution(train_dataset)

## Patch Embedding - Converting Images to Sequences

In [None]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch embedding vectors.

    A strided convolution whose kernel and stride both equal ``patch_size``
    cuts the image into non-overlapping patches and linearly projects each
    one to ``embed_dim`` channels in a single operation.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        embed_dim: Size of the vector produced for every patch.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=192):
        super().__init__()

        self.img_size, self.patch_size = img_size, patch_size
        self.in_channels, self.embed_dim = in_channels, embed_dim

        # Patches per side, squared, gives the sequence length.
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        # kernel_size == stride means every output position sees exactly one patch.
        self.projection = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Map ``(batch, channels, H, W)`` images to ``(batch, patches, embed_dim)``."""
        # (batch, embed_dim, patches_h, patches_w)
        feature_map = self.projection(x)

        # Merge the two spatial axes: (batch, embed_dim, patches_h * patches_w)
        tokens = feature_map.flatten(start_dim=2)

        # Put the sequence axis before the channel axis.
        return tokens.permute(0, 2, 1)

# Test patch embedding
# Smoke test: run random data through the layer and print the resulting shapes.
patch_embed = PatchEmbedding(img_size=32, patch_size=4, in_channels=3, embed_dim=192)

# Test with a batch of images
test_input = torch.randn(8, 3, 32, 32)  # Batch of 8 CIFAR-10 images
patch_output = patch_embed(test_input)

# With 32x32 inputs and 4x4 patches there are (32 // 4) ** 2 = 64 patches,
# so the output is expected to be (8, 64, 192).
print(f"Input shape: {test_input.shape}")
print(f"Patch embedding output shape: {patch_output.shape}")
print(f"Number of patches: {patch_embed.num_patches}")
print(f"Each patch is embedded to {patch_embed.embed_dim} dimensions")

# Visualize patch splitting
def visualize_patch_splitting():
    """Show one CIFAR-10 image, its patch grid, and its first 16 patches.

    Reads the first training image from ``./data`` (the files must already
    exist there) with only ``ToTensor`` applied, then draws three panels:
    the original image, the image overlaid with 4x4 patch boundaries, and a
    4x4 montage of the first 16 patches in row-major order.
    """
    # Get a sample image
    viz_transform = transforms.Compose([transforms.ToTensor()])
    viz_dataset = CIFAR10(root='./data', train=True, download=False, transform=viz_transform)
    sample_image, label = viz_dataset[0]

    patch_size = 4
    # Take the image size from the tensor instead of repeating the literal 32.
    _, height, width = sample_image.shape
    image_hwc = sample_image.permute(1, 2, 0)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image
    axes[0].imshow(image_hwc)
    axes[0].set_title(f'Original Image: {CIFAR10_CLASSES[label]}')
    axes[0].axis('off')

    # Image with patch grid
    axes[1].imshow(image_hwc)

    # Draw patch boundaries. imshow centres pixel k on coordinate k, so the
    # edge between pixels k-1 and k lies at k - 0.5; drawing at k would put
    # the line through the middle of a pixel rather than between patches.
    for row in range(0, height, patch_size):
        axes[1].axhline(y=row - 0.5, color='red', linewidth=1)
    for col in range(0, width, patch_size):
        axes[1].axvline(x=col - 0.5, color='red', linewidth=1)

    axes[1].set_title(f'Image with {patch_size}x{patch_size} Patch Grid')
    axes[1].axis('off')

    # Individual patches, collected row by row
    patches = [
        sample_image[:, row:row + patch_size, col:col + patch_size]
        for row in range(0, height, patch_size)
        for col in range(0, width, patch_size)
    ]

    # Show first 16 patches
    patch_grid = torch.stack(patches[:16])
    grid_image = torchvision.utils.make_grid(patch_grid, nrow=4, padding=1)

    axes[2].imshow(grid_image.permute(1, 2, 0))
    axes[2].set_title('First 16 Patches (4x4 grid)')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

visualize_patch_splitting()

## Positional Encoding - Adding Spatial Information

In [None]:
class PositionalEncoding(nn.Module):
    """Prepend a learnable [CLS] token and add learnable position embeddings.

    Args:
        num_patches: Number of patch tokens per image (the sequence handled
            internally is one longer because of the classification token).
        embed_dim: Size of each token vector.
        dropout: Dropout probability applied after the embeddings are added.
    """

    def __init__(self, num_patches, embed_dim, dropout=0.1):
        super().__init__()

        self.num_patches = num_patches
        self.embed_dim = embed_dim

        # One learnable position vector per patch, plus one for the cls token.
        sequence_length = num_patches + 1
        self.pos_embedding = nn.Parameter(torch.randn(1, sequence_length, embed_dim))

        # Learnable classification token, shared by every image in the batch.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``(batch, num_patches, embed_dim)`` to ``(batch, num_patches + 1, embed_dim)``."""
        batch_size, _, _ = x.shape

        # Broadcast the single cls token across the batch without copying it.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # The cls token goes first, followed by the patch tokens.
        tokens = torch.cat((cls_tokens, x), dim=1)

        # Position embeddings broadcast over the batch axis; then regularise.
        return self.dropout(tokens + self.pos_embedding)

# Alternative: Fixed sinusoidal positional encoding
def get_sinusoidal_positional_encoding(num_patches, embed_dim):
    """Create fixed sinusoidal positional encodings.

    This is the scheme from "Attention is All You Need": even channels hold
    ``sin(pos * w_k)`` and odd channels hold ``cos(pos * w_k)``, with
    ``w_k = 10000 ** (-2k / embed_dim)``.

    Args:
        num_patches: Number of positions to encode (one row per position).
        embed_dim: Encoding width. Odd values are supported; the final sine
            channel then has no cosine partner.

    Returns:
        Float tensor of shape ``(1, num_patches, embed_dim)``.
    """
    position = torch.arange(num_patches, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, embed_dim, 2, dtype=torch.float32)
        * -(math.log(10000.0) / embed_dim)
    )
    angles = position * div_term  # (num_patches, ceil(embed_dim / 2))

    pos_encoding = torch.zeros(num_patches, embed_dim)
    pos_encoding[:, 0::2] = torch.sin(angles)
    # There are only floor(embed_dim / 2) odd channels. Trimming the cosine
    # block avoids the shape mismatch the untrimmed assignment raised for
    # odd embed_dim.
    pos_encoding[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

    return pos_encoding.unsqueeze(0)  # Add batch dimension

# Test positional encoding
# num_patches, embed_dim and pos_encoder are module-level names that
# visualize_positional_embeddings() below reads as globals.
num_patches = 64  # 8x8 patches for 32x32 image with 4x4 patches
embed_dim = 192

pos_encoder = PositionalEncoding(num_patches, embed_dim)

# Test input (batch_size=8, num_patches=64, embed_dim=192)
test_patches = torch.randn(8, num_patches, embed_dim)
encoded_patches = pos_encoder(test_patches)

# The output has one more token than the input: the prepended cls token.
print(f"Input patches shape: {test_patches.shape}")
print(f"After positional encoding: {encoded_patches.shape}")
print("Note: sequence length increased by 1 due to classification token")

# Visualize positional embeddings
def visualize_positional_embeddings():
    """Plot the first eight channels of the position embeddings as heat maps.

    Reads the module-level ``pos_encoder``, ``num_patches`` and ``embed_dim``.
    The cls-token row is dropped and the remaining rows are laid out on the
    square patch grid, one panel per embedding channel.
    """
    # Side length of the patch grid (8 for 64 patches).
    side = int(math.sqrt(num_patches))

    # Skip index 0, which belongs to the cls token.
    patch_positions = pos_encoder.pos_embedding.detach()[0, 1:, :]
    grid = patch_positions.view(side, side, embed_dim)

    _, axes = plt.subplots(2, 4, figsize=(16, 8))

    # The 2x4 figure has exactly eight panels, one per plotted channel.
    for dim, ax in enumerate(axes.ravel()):
        heatmap = ax.imshow(grid[:, :, dim], cmap='viridis')
        ax.set_title(f'Position Embedding Dim {dim}')
        ax.axis('off')
        plt.colorbar(heatmap, ax=ax)

    plt.suptitle('Learned Positional Embeddings (First 8 Dimensions)')
    plt.tight_layout()
    plt.show()

# Uncomment after training to visualize learned embeddings
# visualize_positional_embeddings()

SyntaxError: invalid syntax (ipython-input-9-1062803426.py, line 7)

## Multi-Head Self-Attention - The Heart of Transformers

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention split across several heads.

    Args:
        embed_dim: Token width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # 1 / sqrt(head_dim) keeps the dot products in a softmax-friendly range.
        self.scale = self.head_dim ** -0.5

        # A single bias-free projection produces queries, keys and values together.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq, embed_dim)``.

        Returns:
            ``(out, attn_weights)`` where ``out`` has the input's shape and
            ``attn_weights`` is ``(batch, num_heads, seq, seq)`` after
            softmax and dropout.
        """
        batch_size, seq_length, embed_dim = x.shape

        # (batch, seq, 3 * embed) -> (3, batch, heads, seq, head_dim)
        packed = self.qkv(x).reshape(
            batch_size, seq_length, 3, self.num_heads, self.head_dim
        )
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(dim=0)

        # Pairwise similarity between every query and every key.
        scores = (q @ k.transpose(-2, -1)) * self.scale

        # Normalise over the key axis, then regularise.
        attn_weights = self.dropout(scores.softmax(dim=-1))

        # Weighted sum of values, then merge the heads back into embed_dim.
        context = attn_weights @ v
        merged = context.transpose(1, 2).reshape(batch_size, seq_length, embed_dim)

        return self.proj(merged), attn_weights

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        embed_dim: Input and output width.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()

        expand = nn.Linear(embed_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, embed_dim)

        # Layer positions inside the Sequential are part of the state_dict
        # key names, so the order below is kept fixed.
        self.net = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)

class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Layout: ``x + Attention(LayerNorm(x))`` followed by
    ``x + MLP(LayerNorm(x))``.

    Args:
        embed_dim: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        dropout: Dropout probability passed to the attention and MLP parts.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()

        # Attention sub-layer
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)

        # Feed-forward sub-layer
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = FeedForward(embed_dim, int(embed_dim * mlp_ratio), dropout)

    def forward(self, x):
        """Return ``(tokens, attn_weights)`` for an input of shape ``(batch, seq, embed_dim)``."""
        # Attention sub-layer with its residual connection.
        attended, attn_weights = self.attn(self.norm1(x))
        hidden = x + attended

        # Feed-forward sub-layer with its residual connection.
        hidden = hidden + self.mlp(self.norm2(hidden))

        return hidden, attn_weights

# Test transformer components
# Smoke test on random data; embed_dim is re-bound here at module level.
embed_dim = 192
num_heads = 8
seq_length = 65  # 64 patches + 1 cls token

# Test multi-head attention
attention = MultiHeadSelfAttention(embed_dim, num_heads)
test_input = torch.randn(4, seq_length, embed_dim)
attn_output, attn_weights = attention(test_input)

# The output keeps the input shape; the weights are (batch, heads, seq, seq).
print(f"Input shape: {test_input.shape}")
print(f"Attention output shape: {attn_output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")

# Test transformer block
transformer_block = TransformerBlock(embed_dim, num_heads)
block_output, block_attn_weights = transformer_block(test_input)

print(f"Transformer block output shape: {block_output.shape}")

## Complete Vision Transformer Model

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier: patches -> encoder stack -> [CLS] head.

    Args:
        img_size: Side length of the square input images.
        patch_size: Side length of each square patch.
        in_channels: Number of image channels.
        num_classes: Number of output classes.
        embed_dim: Token width used throughout the encoder.
        depth: Number of Transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability used in every sub-module.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10,
                 embed_dim=192, depth=12, num_heads=8, mlp_ratio=4, dropout=0.1):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.depth = depth

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Positional encoding (includes cls token)
        self.pos_embed = PositionalEncoding(num_patches, embed_dim, dropout)

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # Final layer norm
        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize weights (visits every sub-module recursively)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear and LayerNorm sub-modules.

        Linear weights get a truncated normal (std 0.02) and zero bias;
        LayerNorm gets unit weight and zero bias. Other modules and bare
        parameters are left at their construction-time values.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch of shape ``(batch, in_channels, img_size, img_size)``.

        Returns:
            ``(logits, attn_weights_list)``: logits of shape
            ``(batch, num_classes)`` and a list holding the attention weights
            returned by each block, in block order.
        """
        # Patch embedding
        x = self.patch_embed(x)

        # Add positional encoding (includes cls token)
        x = self.pos_embed(x)

        # Pass through transformer blocks, keeping each block's attention map
        attn_weights_list = []
        for block in self.blocks:
            x, attn_weights = block(x)
            attn_weights_list.append(attn_weights)

        # Apply final layer norm
        x = self.norm(x)

        # Extract cls token (first token) and classify
        cls_token = x[:, 0]
        logits = self.head(cls_token)

        return logits, attn_weights_list

    def get_attention_maps(self, x, layer_idx=-1):
        """Return the attention weights of one block for visualization.

        The forward pass runs in eval mode and without gradients. The
        module's previous train/eval mode is restored afterwards, so calling
        this in the middle of training no longer leaves dropout disabled.

        Args:
            x: Image batch, as accepted by ``forward``.
            layer_idx: Index into the per-block list (default: last block).

        Returns:
            The attention weights produced by the selected block.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                _, attn_weights_list = self(x)
        finally:
            # Restore whichever mode the caller had set.
            self.train(was_training)

        # Return attention weights from specified layer
        return attn_weights_list[layer_idx]

# Create different ViT configurations
def create_vit_tiny():
    """Build the tiny CIFAR-10 ViT: width 192, 6 blocks, 8 heads."""
    config = dict(
        img_size=32, patch_size=4, in_channels=3, num_classes=10,
        embed_dim=192, depth=6, num_heads=8, mlp_ratio=4, dropout=0.1,
    )
    return VisionTransformer(**config)

def create_vit_small():
    """Build the small CIFAR-10 ViT: width 384, 12 blocks, 8 heads."""
    config = dict(
        img_size=32, patch_size=4, in_channels=3, num_classes=10,
        embed_dim=384, depth=12, num_heads=8, mlp_ratio=4, dropout=0.1,
    )
    return VisionTransformer(**config)

def create_vit_base():
    """Build the base CIFAR-10 ViT: width 768, 12 blocks, 12 heads."""
    config = dict(
        img_size=32, patch_size=4, in_channels=3, num_classes=10,
        embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, dropout=0.1,
    )
    return VisionTransformer(**config)

# Test model creation
# `model` is a module-level name; it is bound to the tiny configuration here.
model = create_vit_tiny()
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test forward pass
# Two random 3x32x32 inputs; eval() is not called, so the module is still in
# its default training mode for this pass.
test_input = torch.randn(2, 3, 32, 32)
logits, attn_weights = model(test_input)

# attn_weights is the per-block list returned by VisionTransformer.forward.
print(f"Input shape: {test_input.shape}")
print(f"Output logits shape: {logits.shape}")
print(f"Number of attention layers: {len(attn_weights)}")
print(f"Attention weights shape (last layer): {attn_weights[-1].shape}")

# Model summary
def count_parameters(model):
    """Return the total number of elements in parameters that require gradients."""
    total = 0
    for param in model.parameters():
        # Frozen parameters are not trainable and are left out of the count.
        if param.requires_grad:
            total += param.numel()
    return total

print(f"\nModel Summary:")
print(f"Trainable parameters: {count_parameters(model):,}")

# Compare model sizes
# All three configurations are instantiated here only to count parameters.
models = {
    'ViT-Tiny': create_vit_tiny(),
    'ViT-Small': create_vit_small(),
    'ViT-Base': create_vit_base()
}

print(f"\nModel Size Comparison:")
# The loop variable is deliberately not called `model`: reusing that name
# rebound the module-level `model` (the ViT-Tiny built earlier) to whichever
# entry came last, so later cells silently received ViT-Base.
for name, candidate in models.items():
    params = count_parameters(candidate)
    print(f"{name}: {params:,} parameters")

## Training Setup and Data Loaders

In [None]:
def create_data_loaders(train_dataset, val_dataset, test_dataset, batch_size=128):
    """Wrap the three dataset splits in ``DataLoader`` objects.

    Args:
        train_dataset: Training split; batches are reshuffled every epoch.
        val_dataset: Validation split; iterated in fixed order.
        test_dataset: Test split; iterated in fixed order.
        batch_size: Samples per batch, used for all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    # Settings common to every split: two worker processes and pinned memory.
    shared = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

    # Only the training data is shuffled.
    train_loader = DataLoader(train_dataset, shuffle=True, **shared)
    val_loader, test_loader = (
        DataLoader(split, shuffle=False, **shared)
        for split in (val_dataset, test_dataset)
    )

    return train_loader, val_loader, test_loader

def train_vit(model, train_loader, val_loader, num_epochs=100,
              learning_rate=3e-4, weight_decay=0.1):
    """Train the Vision Transformer"""

    # TODO: Setup training components
    # ViTs typically use AdamW optimizer with cosine learning rate schedule
    # YOUR CODE HERE:

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Cosine annealing scheduler
    scheduler = optim.lr_

## Made by Abdullah Jan