# MLA-ViT Comparative Evaluation on CIFAR-100

This notebook evaluates and compares three Vision Transformer variants on the CIFAR-100 dataset:

- **ViT-MHA**: A standard Vision Transformer using Multi-Head Attention.
- **ViT-MLA**: A Vision Transformer using Multi-Head Latent Attention for reduced attention overhead.
- **ViT-MLA+RoPE**: The MLA variant enhanced with Rotary Positional Embedding for improved positional encoding.

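All models are trained for the same number of epochs (the `EPOCHS` constant, set to 50 below) with identical configurations to ensure a fair comparison. Per-epoch training and test accuracy and the total wall-clock time of each run are reported for comparison.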


In [None]:
# Import required libraries for modeling, training, and visualization
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt

from models.vit_mha import ViT_MHA
from models.vit_mla import ViT_MLA
from models.vit_mla_rope import ViT_MLA_RoPE


In [None]:
# Experiment-wide settings shared by every model variant.
BATCH_SIZE = 128  # samples per mini-batch for both data loaders
EPOCHS = 50       # passes over the training set; also the cosine schedule length

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")


In [None]:
# Load CIFAR-100.
#
# Random augmentation (crop + horizontal flip) is applied to the training
# split only. The test split gets just tensor conversion and the same
# normalization, so evaluation is deterministic and is performed on the
# unmodified test images.
_normalize = transforms.Normalize([0.5]*3, [0.5]*3)

# `transform` keeps its original name and contents: it is the training pipeline.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

train_set = torchvision.datasets.CIFAR100(root='dataset', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR100(root='dataset', train=False, download=True, transform=test_transform)

# drop_last=True keeps every training batch at exactly BATCH_SIZE samples.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# The test loader must NOT drop the final partial batch; otherwise accuracy
# would be computed on a subset of the test split.
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Define a function to train and evaluate a given model architecture
def _run_epoch(model, loader, criterion=None, optimizer=None):
    """Make one pass over ``loader`` and return top-1 accuracy in percent.

    When ``optimizer`` is given, the model is put in training mode and its
    parameters are updated after every batch using ``criterion``. Otherwise
    the model is put in eval mode and gradient tracking is disabled.

    Returns ``0.0`` when the loader yields no samples instead of raising
    ``ZeroDivisionError``.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    correct, total = 0, 0
    with torch.set_grad_enabled(training):
        for imgs, labels in loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            if training:
                optimizer.zero_grad()
            outputs = model(imgs)
            if training:
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            # Predicted class = index of the largest logit along dim 1.
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total if total else 0.0


def train_and_evaluate(model, model_name):
    """Train ``model`` for ``EPOCHS`` epochs, evaluating after every epoch.

    Uses the module-level ``train_loader``, ``test_loader``, ``DEVICE`` and
    ``EPOCHS``. Optimization is Adam (lr=3e-4) with a cosine-annealed
    learning rate stepped once per epoch.

    Args:
        model: The network to train; it is moved to ``DEVICE``.
        model_name: Label used as a prefix in the progress output.

    Returns:
        A ``(train_accs, test_accs)`` tuple of per-epoch accuracies in
        percent, one entry per epoch.

    Note:
        The printed "best" accuracy is the maximum over epochs of the test
        accuracy, i.e. it is selected on the test split itself. The reported
        time is wall-clock time for the whole loop (training and evaluation).
    """
    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-5)

    train_accs, test_accs = [], []
    start_time = time.time()

    for epoch in range(EPOCHS):
        train_accs.append(_run_epoch(model, train_loader, criterion, optimizer))
        acc = _run_epoch(model, test_loader)
        test_accs.append(acc)
        scheduler.step()

        print(f"[{model_name}] Epoch {epoch+1}: Train Acc: {train_accs[-1]:.2f}%, Test Acc: {acc:.2f}%")

    end_time = time.time()
    # Report best and last-epoch accuracy separately: the original summary
    # printed the best value under the label "Final".
    best_acc = max(test_accs, default=0.0)
    final_acc = test_accs[-1] if test_accs else 0.0
    print(
        f"[{model_name}] Best Test Accuracy: {best_acc:.2f}% | "
        f"Final Test Accuracy: {final_acc:.2f}% | "
        f"Total Time: {end_time - start_time:.1f}s"
    )
    return train_accs, test_accs


In [None]:
# Registry of the variants under comparison: display name -> model class.
# Every class is constructed with its default arguments and trained by the
# same routine, so the runs differ only in architecture.
model_classes = dict([
    ("ViT-MHA", ViT_MHA),
    ("ViT-MLA", ViT_MLA),
    ("ViT-MLA+RoPE", ViT_MLA_RoPE),
])

# Per-model (train_accs, test_accs) curves, keyed by display name.
results = dict()

for name in model_classes:
    cls = model_classes[name]
    print(f"\nRunning {name}...")
    model = cls()
    results.update({name: train_and_evaluate(model, name)})


In [None]:
# Plot one test-accuracy curve per model on a shared set of axes.
fig, ax = plt.subplots(figsize=(12, 6))

for name, curves in results.items():
    # Each entry is (train_accs, test_accs); only the test curve is drawn.
    test_accs = curves[1]
    ax.plot(test_accs, label=f'{name} Test Acc')

ax.set_title("Test Accuracy over Epochs")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy (%)")
ax.legend()
ax.grid(True)
plt.show()
