# Q1 — Fine-tune a pretrained ViT on CIFAR-10 (Colab-ready)

This notebook fine-tunes a pretrained Vision Transformer (`vit_b_16`) from `torchvision` on CIFAR-10.

Notes:
- Resize CIFAR images from 32×32 to the 224×224 input size expected by the pretrained ViT.
- Use standard augmentations, AdamW, and a cosine LR scheduler.
- Save best checkpoint `best_vit_cifar10.pth`.

In [None]:
# Install dependencies (run first cell in Colab)
!pip install -q timm einops albumentations==1.3.1 torchmetrics

# NOTE: Colab usually has a suitable torch+cuda build preinstalled. If not, you may need to install torch compatible with the runtime GPU.

In [None]:
import os, time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import CIFAR10
import torchvision
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics.classification import MulticlassAccuracy

# Seed the random number generators so that shuffling, augmentation and the
# initialisation of the new classification head are repeatable across
# Restart & Run All. torch.manual_seed also seeds the CUDA generators.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)


## Data: CIFAR-10
- Resize to 224x224 (ViT pretraining size)
- Basic augmentations for training

In [None]:
# Channel statistics used to normalize the inputs. The backbone used in this
# notebook is initialised from ImageNet weights, so images are normalized with
# the ImageNet mean/std it was pretrained with rather than CIFAR-10 statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMAGE_SIZE = 224  # side length (pixels) fed to the ViT

# Training pipeline: light augmentation, then tensor conversion + normalization.
# ToPILImage comes first because the dataset wrapper hands over raw arrays.
train_transform = T.Compose([
    T.ToPILImage(),
    T.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.1, 0.1, 0.1, 0.02),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: deterministic resize only, same normalization as training.
val_transform = T.Compose([
    T.ToPILImage(),
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

class TransformCIFAR10(CIFAR10):
    """CIFAR-10 dataset that applies a caller-supplied transform to the raw sample.

    The parent class is constructed with ``transform=None``; the transform given
    here is stored on ``alb_transform`` and applied in ``__getitem__`` to the
    entry read directly from ``self.data``.

    Args:
        root: Directory the dataset is stored in / downloaded to.
        train: Forwarded to ``CIFAR10`` to select the training or test split.
        transform: Callable applied to each raw ``self.data[idx]`` entry.
        download: Forwarded to ``CIFAR10``.
    """

    def __init__(self, root, train, transform, download=False):
        # The parent gets no transform: transformation is handled in __getitem__.
        super().__init__(root=root, train=train, transform=None, download=download)
        self.alb_transform = transform

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for sample ``idx``; label is a Python int."""
        raw_image = self.data[idx]
        target = int(self.targets[idx])
        return self.alb_transform(raw_image), target

def get_loaders(batch_size=64, num_workers=2):
    """Build the CIFAR-10 training and validation ``DataLoader`` objects.

    Both splits are stored under ``./data`` and downloaded if missing. The
    training loader shuffles and uses ``train_transform``; the validation
    loader keeps dataset order and uses ``val_transform``.

    Args:
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes per loader.

    Returns:
        ``(train_loader, val_loader)``.
    """
    # Settings that are identical for both loaders.
    shared_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_set = TransformCIFAR10(root='./data', train=True, transform=train_transform, download=True)
    val_set = TransformCIFAR10(root='./data', train=False, transform=val_transform, download=True)

    return (
        DataLoader(train_set, shuffle=True, **shared_kwargs),
        DataLoader(val_set, shuffle=False, **shared_kwargs),
    )

# Cap the worker count at the number of CPUs the runtime exposes: asking for
# more workers than cores makes DataLoader warn and can slow loading down.
# os.cpu_count() may return None, hence the `or 1`.
NUM_WORKERS = min(4, os.cpu_count() or 1)

train_loader, val_loader = get_loaders(batch_size=128, num_workers=NUM_WORKERS)
print('Train batches:', len(train_loader), 'Val batches:', len(val_loader))


## Model: load pretrained ViT and adapt head for 10 classes
- We use `torchvision.models.vit_b_16` with pretrained weights (if available in your torchvision).

In [None]:
from torchvision.models import vit_b_16, ViT_B_16_Weights

NUM_CLASSES = 10  # CIFAR-10

# Referencing the weights enum never touches the network; the checkpoint is only
# fetched inside vit_b_16(), so that call is the one that needs the fallback.
weights = ViT_B_16_Weights.IMAGENET1K_V1
try:
    model = vit_b_16(weights=weights)
except (OSError, RuntimeError) as err:
    # Covers failed/blocked downloads (URLError is an OSError) and corrupt files.
    # Training then starts from scratch, which is expected to reach much lower
    # accuracy, so make the fallback loud.
    print(f'WARNING: could not load pretrained weights ({err!r}); using random initialization.')
    weights = None
    model = vit_b_16(weights=None)

print('Using weights:', weights)

# Replace the ImageNet classification head with a fresh NUM_CLASSES-way linear layer.
model.heads.head = nn.Linear(model.heads.head.in_features, NUM_CLASSES)
# Move once, after the new head exists, so every parameter ends up on `device`.
model = model.to(device)
print(model)


## Training utilities

In [None]:
# Number of passes over the training set; also the length of the cosine schedule.
EPOCHS = 20

# Loss for single-label classification over the logits produced by the model.
criterion = nn.CrossEntropyLoss()

# All model parameters are optimized (backbone and new head alike).
optimizer = AdamW(params=model.parameters(), lr=3e-4, weight_decay=0.05)

# Cosine decay of the learning rate over EPOCHS scheduler steps.
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=EPOCHS)

def evaluate(model, loader):
    """Compute the mean loss and accuracy of ``model`` over ``loader``.

    Uses the notebook-level ``device`` and ``criterion``. The model is put in
    eval mode and no gradients are tracked.

    Args:
        model: Network mapping an image batch to per-class logits.
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats. ``mean_loss`` is the
        per-sample average of ``criterion``; ``accuracy`` is the fraction of
        samples classified correctly.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    # torchmetrics defaults to average='macro' (mean of per-class accuracies),
    # which only matches overall accuracy on class-balanced data; ask for
    # 'micro' explicitly so the number is always correct / total.
    acc = MulticlassAccuracy(num_classes=10, average='micro').to(device)
    total_loss = 0.0
    n = 0
    with torch.no_grad():
        for imgs, labels in loader:
            # non_blocking only has an effect for pinned host memory; harmless otherwise.
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = model(imgs)
            loss = criterion(logits, labels)
            # criterion returns a batch mean; weight by batch size so a smaller
            # final batch does not skew the epoch average.
            total_loss += loss.item() * imgs.size(0)
            preds = torch.argmax(logits, dim=1)
            acc.update(preds, labels)
            n += imgs.size(0)
    if n == 0:
        raise ValueError('evaluate() received a loader with no samples')
    return total_loss / n, acc.compute().item()

def train_one_epoch(model, loader, optimizer):
    """Run one optimization pass over ``loader`` and report loss and accuracy.

    Uses the notebook-level ``device`` and ``criterion``. Gradients are clipped
    to a global norm of 1.0 before every optimizer step.

    Args:
        model: Network mapping an image batch to per-class logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer holding ``model``'s parameters.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both measured on the
        training batches while the model is in train mode.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    # torchmetrics defaults to average='macro' (mean of per-class accuracies),
    # which only matches overall accuracy on class-balanced data; ask for
    # 'micro' explicitly so the number is always correct / total.
    acc = MulticlassAccuracy(num_classes=10, average='micro').to(device)
    running_loss = 0.0
    n = 0
    for imgs, labels in loader:
        # non_blocking only has an effect for pinned host memory; harmless otherwise.
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # criterion returns a batch mean; weight by batch size so a smaller
        # final batch does not skew the epoch average.
        running_loss += loss.item() * imgs.size(0)
        preds = torch.argmax(logits, dim=1)
        acc.update(preds, labels)
        n += imgs.size(0)
    if n == 0:
        raise ValueError('train_one_epoch() received a loader with no samples')
    return running_loss / n, acc.compute().item()


In [None]:
# Fine-tuning loop (short example). For better results raise EPOCHS to 50-100
# in the cell that defines it, since the cosine schedule is built from it too.
best_acc = 0.0
for epoch in range(1, EPOCHS+1):
    t0 = time.time()

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_acc = evaluate(model, val_loader)
    scheduler.step()  # one scheduler step per epoch

    print(f"Epoch {epoch:03d} | train_loss {train_loss:.4f} acc {train_acc:.4f} | val_loss {val_loss:.4f} acc {val_acc:.4f} | time {time.time()-t0:.1f}s")

    # Keep only the checkpoint with the highest validation accuracy seen so far.
    is_best = val_acc > best_acc
    if is_best:
        best_acc = val_acc
        checkpoint = {'model_state': model.state_dict(), 'epoch': epoch, 'best_acc': best_acc}
        torch.save(checkpoint, 'best_vit_cifar10.pth')

print('Best val acc:', best_acc)


## Notes
- Increase `EPOCHS` and consider techniques such as MixUp, CutMix, and longer training for higher accuracy.
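- Passing `weights=None` to `vit_b_16` gives a randomly initialized ViT, which trains far less well on CIFAR-10 than the pretrained model. To use the ImageNet-pretrained `ViT_B_16` weights, torchvision must be able to download them the first time they are requested (they are cached afterwards), so make sure the runtime has internet access.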