# 👁️ Vision Transformer (ViT)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/main/transformer_architectures/04_vision_transformer/demo.ipynb)

![Architecture](architecture.png)

### Key Innovation
- **Image patches as tokens**: Split the image into 16×16 patches (this demo uses 4×4 patches on 32×32 CIFAR-10 images)
- **[CLS] token**: For classification
- **Pure transformer**: No convolutions needed!

In [None]:
!pip install torch torchvision matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Seed PyTorch's global RNG so repeated runs start from the same state.
torch.manual_seed(42)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## ViT Architecture

In [None]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects each non-overlapping patch to an ``embed_dim``-sized vector.

    Args:
        img_size: Side length of the (square) input image, in pixels.
        patch_size: Side length of each square patch, in pixels.
        in_channels: Number of channels in the input image.
        embed_dim: Size of the embedding produced for every patch.
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side

        # Stride == kernel size, so every patch is projected exactly once.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map images (B, C, H, W) to patch tokens (B, n_patches, embed_dim)."""
        feature_map = self.proj(x)         # (B, embed_dim, H/patch, W/patch)
        tokens = feature_map.flatten(2)    # (B, embed_dim, n_patches)
        return tokens.transpose(1, 2)      # (B, n_patches, embed_dim)

class ViTAttention(nn.Module):
    """Multi-head self-attention that also returns its attention weights.

    Args:
        embed_dim: Size of the last dimension of the input tokens.
        n_heads: Number of attention heads; must evenly divide ``embed_dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``embed_dim``. Without this check the mismatch only surfaces
            later as an opaque reshape error inside ``forward``.
    """
    def __init__(self, embed_dim, n_heads, dropout=0.1):
        super().__init__()
        if n_heads <= 0 or embed_dim % n_heads != 0:
            raise ValueError(
                f'embed_dim ({embed_dim}) must be a positive multiple of '
                f'n_heads ({n_heads})'
            )
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        # Scale dot products by 1/sqrt(head_dim) before the softmax.
        self.scale = self.head_dim ** -0.5

        # One linear layer produces Q, K and V for all heads at once.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to tokens of shape (B, N, C).

        Returns:
            A tuple ``(out, attn)`` where ``out`` has shape (B, N, C) and
            ``attn`` has shape (B, n_heads, N, N). ``attn`` is the softmax
            output after dropout, so in training mode its rows need not sum
            to one.
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        Q, K, V = qkv.unbind(0)

        attn = (Q @ K.transpose(-2, -1)) * self.scale
        attn = self.dropout(attn.softmax(dim=-1))

        # Merge the heads back together: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ V).transpose(1, 2).reshape(B, N, C)
        return self.proj(x), attn

class ViTBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Each sub-layer (self-attention, then an MLP) is preceded by LayerNorm and
    wrapped in a residual connection.

    Args:
        embed_dim: Token embedding size.
        n_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        dropout: Dropout probability used in the attention and the MLP.
    """
    def __init__(self, embed_dim, n_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        hidden_dim = embed_dim * mlp_ratio

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = ViTAttention(embed_dim, n_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return the block output together with this block's attention weights."""
        attended, weights = self.attn(self.norm1(x))
        after_attn = x + attended
        after_mlp = after_attn + self.mlp(self.norm2(after_attn))
        return after_mlp, weights

class ViT(nn.Module):
    """Vision Transformer classifier.

    Images are split into patch tokens, a learnable [CLS] token is prepended,
    learnable position embeddings are added, and the sequence runs through a
    stack of encoder blocks. The normalised [CLS] output feeds a linear head.

    Args:
        img_size: Side length of the input image.
        patch_size: Side length of each patch.
        in_channels: Number of image channels.
        n_classes: Number of output classes.
        embed_dim: Token embedding size.
        n_heads: Attention heads per block.
        n_layers: Number of encoder blocks.
        dropout: Dropout probability used throughout the model.
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, n_classes=10,
                 embed_dim=128, n_heads=4, n_layers=4, dropout=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        seq_len = self.patch_embed.n_patches + 1  # one extra slot for [CLS]

        # Learnable [CLS] token and position embeddings (values set in _init_weights).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.dropout = nn.Dropout(dropout)

        # Encoder stack
        self.blocks = nn.ModuleList(
            ViTBlock(embed_dim, n_heads, dropout=dropout) for _ in range(n_layers)
        )

        # Final norm + classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialise the position embeddings and [CLS] token (truncated normal, std 0.02)."""
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)

    def forward(self, x):
        """Classify a batch of images.

        Returns:
            A tuple ``(logits, attn_weights)``: the class logits computed from
            the [CLS] token, and a list holding the attention weights returned
            by each block, in order.
        """
        batch_size = x.shape[0]

        # Patch tokens with the [CLS] token prepended at position 0.
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)

        # Position information, then dropout.
        tokens = self.dropout(tokens + self.pos_embed)

        per_layer_attn = []
        for layer in self.blocks:
            tokens, layer_attn = layer(tokens)
            per_layer_attn.append(layer_attn)

        # Only the [CLS] position is used for classification.
        cls_repr = self.norm(tokens)[:, 0]
        return self.head(cls_repr), per_layer_attn

# Build a small ViT on the selected device and report its parameter count.
model = ViT(
    img_size=32,
    patch_size=4,
    n_classes=10,
    embed_dim=64,
    n_heads=4,
    n_layers=3,
)
model = model.to(device)
print(f'ViT Parameters: {sum(p.numel() for p in model.parameters()):,}')

## Visualize Patch Embedding

In [None]:
# Visualize how image is split into patches
def visualize_patches(img, patch_size=4):
    """Show an image, its patch grid, and its first patches as a sequence.

    Three panels are drawn side by side: the original image, the image with
    patch boundaries overlaid, and thumbnails of up to the first 16 patches
    in row-major order.

    Args:
        img: Image tensor indexed as (channels, height, width); it is permuted
            to (height, width, channels) for display.
        patch_size: Side length, in pixels, of each square patch.
    """
    channels, h, w = img.shape
    # Count rows and columns separately so non-square images are handled.
    rows, cols = h // patch_size, w // patch_size
    n_patches = rows * cols
    hwc = img.permute(1, 2, 0)

    _, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image
    axes[0].imshow(hwc)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    # Image with patch grid
    axes[1].imshow(hwc)
    for y in range(0, h, patch_size):
        axes[1].axhline(y=y, color='r', linewidth=1)
    for x in range(0, w, patch_size):
        axes[1].axvline(x=x, color='r', linewidth=1)
    axes[1].set_title(f'Patches ({rows}×{cols} = {n_patches} patches)')
    axes[1].axis('off')

    # Patches as sequence:
    # (C, rows, cols, p, p) -> (C, n_patches, p, p), patches in row-major order.
    patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    patches = patches.contiguous().view(channels, -1, patch_size, patch_size)

    axes[2].axis('off')
    axes[2].set_title('First 16 Patches as Sequence')

    # Draw the thumbnails as insets of the third panel. Adding them on a
    # separate 3x20 subplot grid would overlap the other panels and is not
    # compatible with the 1x3 grid when tight_layout() runs.
    per_row = 8
    cell = 1.0 / per_row
    for i in range(min(16, n_patches)):
        r, c = divmod(i, per_row)
        thumb = axes[2].inset_axes([c * cell, 0.5 - r * cell, cell * 0.9, cell * 0.9])
        thumb.imshow(patches[:, i].permute(1, 2, 0))
        thumb.axis('off')

    plt.tight_layout()
    plt.show()

# Random 3-channel 32x32 image, used only to demonstrate the patch split.
sample_img = torch.rand((3, 32, 32))
visualize_patches(img=sample_img, patch_size=4)

## Training on CIFAR-10

In [None]:
# CIFAR-10 preprocessing: convert to tensors, then normalise every channel
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# Training split (downloaded to ./data on first use), shuffled each epoch.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2,
)

# Test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2,
)

# Human-readable class names, indexed by label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
print(f'Training samples: {len(trainset)}')
print(f'Test samples: {len(testset)}')

In [None]:
# Fresh model, optimiser and loss for the CIFAR-10 run.
model = ViT(
    img_size=32, patch_size=4, n_classes=10,
    embed_dim=64, n_heads=4, n_layers=4,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

n_epochs = 5
losses = []       # mean training loss of each completed 200-step window
accuracies = []   # test accuracy (%) after each epoch

print('Training ViT on CIFAR-10...')
for epoch in range(n_epochs):
    # ---- train for one epoch ----
    model.train()
    running_loss = 0.0

    for i, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs, _ = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Only report once a full window of 200 steps has accumulated.
        if (i + 1) % 200 != 0:
            continue
        print(f'Epoch {epoch+1}, Step {i+1}: Loss = {running_loss/200:.4f}')
        losses.append(running_loss/200)
        running_loss = 0.0

    # ---- evaluate on the test set ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs, _ = model(images)
            predicted = outputs.max(dim=1).indices
            total += labels.shape[0]
            correct += (predicted == labels).sum().item()

    acc = 100 * correct / total
    accuracies.append(acc)
    print(f'Epoch {epoch+1} Test Accuracy: {acc:.2f}%')

# Plot the training loss (one point per 200-step window) next to the
# per-epoch test accuracy.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(losses)
# The multiplication sign was mis-encoded in this label and rendered as garbage.
axes[0].set_xlabel('Step (×200)')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].grid(True, alpha=0.3)

axes[1].plot(accuracies, 'o-')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Test Accuracy')
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Visualize attention for the first image of one test batch.
model.eval()
images, labels = next(iter(testloader))
img = images[0:1].to(device)  # slice (not index) to keep the batch dimension

with torch.no_grad():
    _, attn_weights = model(img)

# Attention from [CLS] (query 0) to every patch token (keys 1:), taken from
# the last layer, first image, head 0. The side of the patch grid comes from
# the model itself rather than a hard-coded 8, so other img_size/patch_size
# combinations reshape correctly too.
grid_side = model.patch_embed.img_size // model.patch_embed.patch_size
attn = attn_weights[-1][0, 0, 0, 1:].reshape(grid_side, grid_side).cpu()

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Original image; `* 0.5 + 0.5` undoes the mean-0.5 / std-0.5 normalisation.
axes[0].imshow(images[0].permute(1, 2, 0) * 0.5 + 0.5)
axes[0].set_title(f'Input Image: {classes[labels[0]]}')
axes[0].axis('off')

# Attention map
axes[1].imshow(attn, cmap='hot')
axes[1].set_title('Attention from [CLS] Token')
axes[1].axis('off')

plt.tight_layout()
plt.show()

print('\nüéØ Key Takeaways:')
print('1. ViT treats image patches as tokens')
print('2. [CLS] token aggregates information for classification')
print('3. No convolutions - pure transformer')
print('4. Needs large data or pretraining for best results')