In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import math
import json
import os

# ==================== MODEL COMPONENTS ====================

class GELU(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


class PatchEmbeddings(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.patch_size = cfg["patch"]
        self.img_size = cfg["image_size"]
        self.hidden_size = cfg["dim"]
        self.no_patch = (self.img_size // self.patch_size) ** 2
        self.layer = nn.Conv2d(3, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size)
    
    def forward(self, x):
        x = self.layer(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.patch_embed = PatchEmbeddings(config)
        self.cls_token = nn.Parameter(torch.randn(1, 1, config["dim"]))
        self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.no_patch + 1, config["dim"]))

    def forward(self, x):
        x = self.patch_embed(x)
        batch_size, _, _ = x.shape
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_embed + x
        return x


class SelfAttention(nn.Module):
    def __init__(self, inp_dim, cfg):
        super().__init__()
        self.patch_size = cfg["patch"]
        self.img_size = cfg["image_size"]
        self.context_length = (self.img_size // self.patch_size) ** 2 + 1
        self.inp_dim = inp_dim
        self.w_q = nn.Linear(inp_dim, inp_dim)
        self.w_k = nn.Linear(inp_dim, inp_dim)
        self.w_v = nn.Linear(inp_dim, inp_dim)
        self.dropout = nn.Dropout(0.1)
        self.attn_weights = None  # Store attention weights
    
    def forward(self, x):
        b, p, d = x.shape
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        attn = q @ k.transpose(1, 2)
        dim_k = k.shape[-1]
        attn = attn / dim_k ** 0.5
        attn_scores = torch.softmax(attn, dim=-1)
        self.attn_weights = attn_scores  # Save for visualization
        attn_scores = self.dropout(attn_scores)
        return attn_scores @ v


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, inp_dim, cfg):
        super().__init__()
        self.head_dim = inp_dim // n_heads
        self.n_heads = n_heads
        self.heads = nn.ModuleList([SelfAttention(self.head_dim, cfg) for _ in range(n_heads)])
    
    def forward(self, x):
        # Split input for each head
        batch_size, seq_len, dim = x.shape
        x_split = x.view(batch_size, seq_len, self.n_heads, self.head_dim)
        x_split = x_split.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
        
        outputs = []
        for i, head in enumerate(self.heads):
            head_input = x_split[:, i, :, :].contiguous()
            outputs.append(head(head_input))
        
        x = torch.cat(outputs, dim=-1)
        return x
    
    def get_attention_weights(self):
        """Get attention weights from all heads"""
        return [head.attn_weights for head in self.heads if head.attn_weights is not None]


class MLP(nn.Module):
    def __init__(self, inp_dim, drop_rate):
        super().__init__()
        self.layer1 = nn.Linear(inp_dim, 4 * inp_dim)
        self.layer2 = nn.Linear(4 * inp_dim, inp_dim)
        self.dropout = nn.Dropout(drop_rate)
        self.activation = GELU()
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        x = self.dropout(x)
        return x


class Transformer(nn.Module):
    def __init__(self, dim, n_heads, cfg):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(n_heads, dim, cfg)
        self.mlp = MLP(dim, 0.1)
    
    def forward(self, x):
        residue = x
        x = self.norm1(x)
        x = self.attn(x)
        x = residue + x
        residue = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = residue + x
        return x


class Encoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.ModuleList([
            Transformer(cfg["dim"], cfg["n_heads"], cfg) for _ in range(cfg["layers"])
        ])
        self.embedding = Embeddings(cfg)
    
    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        return x
    
    def get_attention_weights(self):
        """Get attention weights from all layers"""
        all_attentions = []
        for layer in self.layers:
            all_attentions.append(layer.attn.get_attention_weights())
        return all_attentions


class ViT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config["image_size"]
        self.hidden_size = config["dim"]
        self.num_classes = config["num_classes"]
        self.encoder = Encoder(config)
        self.out = nn.Linear(config["dim"], self.num_classes)
    
    def forward(self, x):
        x = self.encoder(x)
        cls = x[:, 0, :]
        logits = self.out(cls)
        return logits
    
    def get_attention_weights(self):
        """Get all attention weights for visualization"""
        return self.encoder.get_attention_weights()


# ==================== DATA PREPARATION ====================

def prepare_data(batch_size=32, num_workers=2):
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop((64, 64), scale=(0.8, 1.0)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=train_transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=test_transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return trainloader, testloader, classes


# ==================== UTILITIES ====================

def save_experiment(experiment_name, config, model, train_losses, test_losses, accuracies, base_dir="experiments"):
    outdir = os.path.join(base_dir, experiment_name)
    os.makedirs(outdir, exist_ok=True)

    with open(os.path.join(outdir, 'config.json'), 'w') as f:
        json.dump(config, f, sort_keys=True, indent=4)

    with open(os.path.join(outdir, 'metrics.json'), 'w') as f:
        data = {
            'train_losses': train_losses,
            'test_losses': test_losses,
            'accuracies': accuracies,
        }
        json.dump(data, f, sort_keys=True, indent=4)

    torch.save(model.state_dict(), os.path.join(outdir, 'model_final.pt'))


# ==================== TRAINER ====================

class Trainer:
    def __init__(self, model, optimizer, loss_fn, exp_name, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.exp_name = exp_name
        self.device = device

    def train(self, trainloader, testloader, epochs):
        train_losses, test_losses, accuracies = [], [], []
        
        for epoch in range(epochs):
            train_loss = self.train_epoch(trainloader)
            accuracy, test_loss = self.evaluate(testloader)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            print(f"Epoch: {epoch+1}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")
        
        save_experiment(self.exp_name, self.model.config, self.model, train_losses, test_losses, accuracies)
        return train_losses, test_losses, accuracies

    def train_epoch(self, trainloader):
        self.model.train()
        total_loss = 0
        for images, labels in trainloader:
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            logits = self.model(images)
            loss = self.loss_fn(logits, labels)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item() * len(images)
        return total_loss / len(trainloader.dataset)

    @torch.no_grad()
    def evaluate(self, testloader):
        self.model.eval()
        total_loss = 0
        correct = 0
        for images, labels in testloader:
            images, labels = images.to(self.device), labels.to(self.device)
            logits = self.model(images)
            loss = self.loss_fn(logits, labels)
            total_loss += loss.item() * len(images)
            predictions = torch.argmax(logits, dim=1)
            correct += torch.sum(predictions == labels).item()
        
        accuracy = correct / len(testloader.dataset)
        avg_loss = total_loss / len(testloader.dataset)
        return accuracy, avg_loss


# ==================== ATTENTION VISUALIZATION ====================

def visualize_attention(model, image, label, classes, device, layer_idx=0, head_idx=0):
    """
    Visualize attention weights for a single image
    
    Args:
        model: Trained ViT model
        image: Input image tensor (C, H, W)
        label: True label
        classes: List of class names
        device: Device to run on
        layer_idx: Which transformer layer to visualize
        head_idx: Which attention head to visualize
    """
    model.eval()
    
    # Add batch dimension
    image_batch = image.unsqueeze(0).to(device)
    
    # Forward pass
    with torch.no_grad():
        logits = model(image_batch)
        prediction = torch.argmax(logits, dim=1).item()
        
        # Get attention weights
        all_attentions = model.get_attention_weights()
    
    # Get specific layer and head attention
    if layer_idx < len(all_attentions) and head_idx < len(all_attentions[layer_idx]):
        attn_weights = all_attentions[layer_idx][head_idx][0].cpu().numpy()  # [seq_len, seq_len]
    else:
        print(f"Invalid layer_idx or head_idx")
        return
    
    # Prepare image for display
    img_display = image.permute(1, 2, 0).cpu().numpy()
    img_display = (img_display * 0.5) + 0.5  # Denormalize
    img_display = np.clip(img_display, 0, 1)
    
    # Create visualization
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Original image
    axes[0].imshow(img_display)
    axes[0].set_title(f'Original Image\nTrue: {classes[label]}\nPred: {classes[prediction]}')
    axes[0].axis('off')
    
    # CLS token attention (how CLS attends to patches)
    cls_attn = attn_weights[0, 1:]  # Exclude CLS to CLS
    num_patches = int(np.sqrt(len(cls_attn)))
    cls_attn_map = cls_attn.reshape(num_patches, num_patches)
    
    im1 = axes[1].imshow(cls_attn_map, cmap='viridis')
    axes[1].set_title(f'CLS Token Attention\nLayer {layer_idx+1}, Head {head_idx+1}')
    axes[1].axis('off')
    plt.colorbar(im1, ax=axes[1], fraction=0.046)
    
    # Mean attention across all patches
    mean_attn = attn_weights[1:, 1:].mean(axis=0)  # Average over query positions
    mean_attn_map = mean_attn.reshape(num_patches, num_patches)
    
    im2 = axes[2].imshow(mean_attn_map, cmap='viridis')
    axes[2].set_title(f'Mean Patch Attention\nLayer {layer_idx+1}, Head {head_idx+1}')
    axes[2].axis('off')
    plt.colorbar(im2, ax=axes[2], fraction=0.046)
    
    plt.tight_layout()
    plt.savefig(f'attention_visualization_L{layer_idx}_H{head_idx}.png', dpi=150, bbox_inches='tight')
    plt.show()


def visualize_multiple_heads(model, image, label, classes, device, layer_idx=0, num_heads=4):
    """
    Visualize attention from multiple heads in a grid
    """
    model.eval()
    image_batch = image.unsqueeze(0).to(device)
    
    with torch.no_grad():
        logits = model(image_batch)
        prediction = torch.argmax(logits, dim=1).item()
        all_attentions = model.get_attention_weights()
    
    # Prepare image
    img_display = image.permute(1, 2, 0).cpu().numpy()
    img_display = (img_display * 0.5) + 0.5
    img_display = np.clip(img_display, 0, 1)
    
    # Create grid
    fig, axes = plt.subplots(2, num_heads // 2 + 1, figsize=(20, 8))
    axes = axes.flatten()
    
    # Show original image
    axes[0].imshow(img_display)
    axes[0].set_title(f'Original\nTrue: {classes[label]}\nPred: {classes[prediction]}', fontsize=10)
    axes[0].axis('off')
    
    # Show attention from each head
    for head_idx in range(min(num_heads, len(all_attentions[layer_idx]))):
        attn_weights = all_attentions[layer_idx][head_idx][0].cpu().numpy()
        cls_attn = attn_weights[0, 1:]
        num_patches = int(np.sqrt(len(cls_attn)))
        cls_attn_map = cls_attn.reshape(num_patches, num_patches)
        
        im = axes[head_idx + 1].imshow(cls_attn_map, cmap='viridis')
        axes[head_idx + 1].set_title(f'Head {head_idx+1}', fontsize=10)
        axes[head_idx + 1].axis('off')
        plt.colorbar(im, ax=axes[head_idx + 1], fraction=0.046, pad=0.04)
    
    # Hide unused subplots
    for idx in range(num_heads + 1, len(axes)):
        axes[idx].axis('off')
    
    plt.suptitle(f'Attention Heads - Layer {layer_idx+1}', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.savefig(f'attention_all_heads_L{layer_idx}.png', dpi=150, bbox_inches='tight')
    plt.show()


# ==================== MAIN ====================

def main():
    # Configuration
    config = {
        "patch": 4,
        "dim": 128,
        "n_heads": 16,
        "layers": 16,
        "image_size": 64,
        "num_classes": 10,
    }
    
    # Training parameters
    exp_name = 'vit-cifar10'
    batch_size = 32
    epochs = 10
    lr = 1e-2
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    print(f"Using device: {device}")
    print(f"Configuration: {config}")
    
    # Prepare data
    trainloader, testloader, classes = prepare_data(batch_size=batch_size)
    
    # Create model
    model = ViT(config)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    loss_fn = nn.CrossEntropyLoss()
    
    # Train
    trainer = Trainer(model, optimizer, loss_fn, exp_name, device=device)
    print("\nStarting training...")
    train_losses, test_losses, accuracies = trainer.train(trainloader, testloader, epochs)
    
    # Visualize attention on test samples
    print("\nGenerating attention visualizations...")
    model.eval()
    
    # Get a few test images
    test_iter = iter(testloader)
    images, labels = next(test_iter)
    
    # Visualize first 3 images
    for i in range(min(3, len(images))):
        print(f"\nVisualizing sample {i+1}")
        visualize_attention(model, images[i], labels[i].item(), classes, device, layer_idx=0, head_idx=0)
        visualize_multiple_heads(model, images[i], labels[i].item(), classes, device, layer_idx=0, num_heads=4)
    
    print("\nTraining complete! Attention visualizations saved.")


if __name__ == '__main__':
    main()

Using device: cuda
Configuration: {'patch': 4, 'dim': 128, 'n_heads': 16, 'layers': 16, 'image_size': 64, 'num_classes': 10}

Starting training...
Epoch: 1, Train loss: 2.2953, Test loss: 1.6340, Accuracy: 0.3778
Epoch: 2, Train loss: 1.5619, Test loss: 1.4235, Accuracy: 0.4718
Epoch: 3, Train loss: 21.5374, Test loss: 12.7163, Accuracy: 0.1532
