# CIFAR-10 CNN Training

This notebook trains and compares CNN architectures on CIFAR-10.
All implementation lives in `src/`; this file is only the experiment driver.

| Cell group | Purpose |
|---|---|
| 0 – Environment setup | Colab/Drive mount, repo clone, pip install |
| 1 – Imports & config  | Load modules, set hyper-params |
| 2 – Data              | Download / inspect CIFAR-10 |
| 3 – Single experiment | Train one chosen architecture |
| 4 – Comparison        | Train all architectures & compare |
| 5 – Evaluation        | Test-set accuracy + per-class breakdown |

## 0 · Environment Setup (Colab)

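> **Skip this section if you are running locally.** The cell below only takes action when it detects a Colab runtime; elsewhere it just prints the working directory.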
>
> When connecting from VS Code to a Colab runtime:
> 1. Open the Command Palette → *"Jupyter: Specify Jupyter Server for Connections"*
> 2. Paste the Colab runtime URL (copy from *Connect → Copy link*).
> 3. Make sure you have the **Jupyter** and **Python** VS Code extensions installed.

In [None]:
import os
import subprocess
import sys

# Colab is detected by checking whether the `google.colab` package is already
# loaded in this kernel; everything inside the `if` is skipped otherwise.
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    # --- Mount Google Drive (optional: for persisting checkpoints) ---
    from google.colab import drive
    drive.mount("/content/drive")

    # --- Clone your repository ---
    REPO_URL = "https://github.com/YOUR_USERNAME/PartInterviewTasks.git"  # <-- update this
    REPO_DIR = "/content/PartInterviewTasks"

    # Commands are passed as argument lists (no shell string interpolation).
    if not os.path.exists(REPO_DIR):
        # check=True: a failed clone raises CalledProcessError here instead of
        # surfacing later as a confusing error from os.chdir.
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
    else:
        # Updating an existing checkout is best-effort: report a failure but
        # keep going with whatever revision is already on disk.
        pull = subprocess.run(["git", "-C", REPO_DIR, "pull"])
        if pull.returncode != 0:
            print(
                f"WARNING: 'git pull' failed in {REPO_DIR} "
                f"(exit code {pull.returncode}); continuing with the existing checkout"
            )

    os.chdir(REPO_DIR)
    # Make `src.*` importable; the guard avoids stacking duplicate entries when
    # this cell is re-run.
    if REPO_DIR not in sys.path:
        sys.path.insert(0, REPO_DIR)

    # --- Install dependencies ---
    # Use the kernel's own interpreter so packages land in the environment the
    # notebook actually imports from; a failed install stops the cell.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "torch", "torchvision", "matplotlib"],
        check=True,
    )

print("Working directory:", os.getcwd())
print("In Colab:", IN_COLAB)

## 1 · Imports & Configuration

In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR

from src.config  import ExperimentConfig, DataConfig, TrainConfig
from src.dataset import get_dataloaders, CLASSES
from src.models  import get_model, count_parameters
from src.trainer import train, evaluate, load_checkpoint
from src.utils   import set_seed, get_device, plot_history, compare_histories

# ---------- Experiment configuration ----------
# Single source of hyper-parameters; the cells below read everything from `cfg`.
cfg = ExperimentConfig(
    seed=42,
    model_name="SimpleCNN",  # change to "DeepCNN" or "ResNetCIFAR"
    data=DataConfig(
        batch_size=128,
        val_split=0.1,
        augment=True,
        num_workers=2,
    ),
    train=TrainConfig(
        epochs=50,
        learning_rate=1e-3,   # passed to AdamW as `lr`
        weight_decay=5e-4,    # passed to AdamW as `weight_decay`
        use_one_cycle=True,   # single-model cell: OneCycleLR if True, else CosineAnnealingLR
        max_lr=1e-2,          # passed to OneCycleLR as `max_lr`
        use_amp=True,         # forwarded to train(use_amp=...)
        checkpoint_dir="./checkpoints",  # where "<arch>_best.pt" files are looked up later
    ),
)

# Seed first, then pick the device that models are moved to below.
set_seed(cfg.seed)
device = get_device()

## 2 · Data

In [None]:
# Build the train / validation / test loaders from the data section of `cfg`.
loader_kwargs = dict(
    data_dir=cfg.data.data_dir,
    batch_size=cfg.data.batch_size,
    val_split=cfg.data.val_split,
    augment=cfg.data.augment,
    num_workers=cfg.data.num_workers,
)
train_loader, val_loader, test_loader = get_dataloaders(**loader_kwargs)

# One summary line per split: number of batches and size of the underlying dataset.
# The labels are pre-padded so the colons line up.
for split_label, split_loader in (
    ("Train", train_loader),
    ("Val  ", val_loader),
    ("Test ", test_loader),
):
    print(f"{split_label} batches : {len(split_loader)}  ({len(split_loader.dataset)} samples)")
print(f"Classes : {CLASSES}")

In [None]:
# Quick sanity-check: visualise a mini-batch
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid

# Per-channel statistics used to undo normalisation for display.
# NOTE(review): assumed to match the normalisation applied inside src.dataset — confirm.
CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465])
CIFAR_STD = np.array([0.2470, 0.2435, 0.2616])

# Pull one mini-batch from the training loader.
imgs, labels = next(iter(train_loader))

# First 16 images: channels-first tensor -> channels-last array, so the
# 3-element mean/std broadcast over the trailing channel axis.
imgs_show = np.transpose(imgs[:16].cpu().numpy(), (0, 2, 3, 1))
# Un-normalise, then clamp into the [0, 1] range that imshow expects for floats.
imgs_show = np.clip(imgs_show * CIFAR_STD + CIFAR_MEAN, 0, 1)

fig, axes = plt.subplots(2, 8, figsize=(14, 4))
# zip stops at the shortest input, so at most 16 panels are filled.
for panel, picture, target in zip(axes.flat, imgs_show, labels):
    panel.imshow(picture)
    panel.set_title(CLASSES[target], fontsize=7)
    panel.set_axis_off()
plt.suptitle("Sample training images (after augmentation, un-normalised)", fontsize=10)
plt.tight_layout()
plt.show()

## 3 · Train a Single Model

In [None]:
# Instantiate the architecture selected in `cfg` and report its size.
model = get_model(cfg.model_name).to(device)
print(f"Model         : {cfg.model_name}")
print(f"Parameters    : {count_parameters(model):,}")

optimizer = AdamW(
    model.parameters(),
    lr=cfg.train.learning_rate,
    weight_decay=cfg.train.weight_decay,
)
# Training loss uses label smoothing of 0.1.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Scheduler choice is driven by the config flag.
if not cfg.train.use_one_cycle:
    scheduler = CosineAnnealingLR(optimizer, T_max=cfg.train.epochs)
else:
    # OneCycleLR is sized in optimizer steps: epochs * batches per epoch.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=cfg.train.max_lr,
        epochs=cfg.train.epochs,
        steps_per_epoch=len(train_loader),
    )

# Run training; the returned history is plotted in the next cell.
history = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    epochs=cfg.train.epochs,
    checkpoint_dir=cfg.train.checkpoint_dir,
    model_name=cfg.model_name,
    use_amp=cfg.train.use_amp,
)

In [None]:
# Visualise the history returned by train() for the single model trained above,
# titled with the configured architecture name (plot_history comes from src.utils).
plot_history(history, title=cfg.model_name)

## 4 · Compare All Architectures

In [None]:
# Architectures to train back-to-back with identical hyper-parameters.
ARCHITECTURES = ["SimpleCNN", "DeepCNN", "ResNetCIFAR"]
# Maps architecture name -> history returned by train().
all_histories: dict = {}

banner = "=" * 60

for arch in ARCHITECTURES:
    print(f"\n{banner}")
    print(f"  Training: {arch}")
    print(f"{banner}")

    # Re-seed before every run so each architecture starts from the same seed.
    set_seed(cfg.seed)
    m = get_model(arch).to(device)
    print(f"  Parameters: {count_parameters(m):,}")

    opt = AdamW(m.parameters(), lr=cfg.train.learning_rate, weight_decay=cfg.train.weight_decay)

    # Honour cfg.train.use_one_cycle exactly like the single-model cell does, so
    # the comparison uses the same scheduler as the single experiment.
    if cfg.train.use_one_cycle:
        sched = OneCycleLR(
            opt,
            max_lr=cfg.train.max_lr,
            epochs=cfg.train.epochs,
            steps_per_epoch=len(train_loader),
        )
    else:
        sched = CosineAnnealingLR(opt, T_max=cfg.train.epochs)

    crit = nn.CrossEntropyLoss(label_smoothing=0.1)

    hist = train(
        model=m,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=crit,
        optimizer=opt,
        scheduler=sched,
        device=device,
        epochs=cfg.train.epochs,
        checkpoint_dir=cfg.train.checkpoint_dir,
        model_name=arch,
        use_amp=cfg.train.use_amp,
    )
    all_histories[arch] = hist

# Side-by-side comparison of the collected histories on two validation metrics.
compare_histories(all_histories, metric="val_acc")
compare_histories(all_histories, metric="val_loss")

## 5 · Test-Set Evaluation

In [None]:
import os

# Plain cross-entropy for reporting (no label smoothing, unlike the training loss).
criterion_eval = nn.CrossEntropyLoss()

# Table header.
print("Architecture".ljust(16), "Test Acc".rjust(10))
print("-" * 28)

for arch in ARCHITECTURES:
    # Look for "<arch>_best.pt" inside the configured checkpoint directory.
    weights_file = os.path.join(cfg.train.checkpoint_dir, f"{arch}_best.pt")

    if os.path.exists(weights_file):
        # Fresh model instance, then restore the saved weights into it.
        m = get_model(arch).to(device)
        load_checkpoint(m, weights_file, device)
        # evaluate() returns a pair; only the second element (accuracy) is reported.
        _, test_acc = evaluate(m, test_loader, criterion_eval, device)
        print(f"{arch:<16} {test_acc:>9.2f}%")
    else:
        # Architectures that were never trained are reported and skipped.
        print(f"{arch:<16}  checkpoint not found, skipping")

In [None]:
# Per-class accuracy for the best checkpoint
import torch

BEST_ARCH = "ResNetCIFAR"  # change to whichever model won

# Fail early with a clear message if this architecture has no saved checkpoint
# (the evaluation loop in the previous cell skips such architectures).
best_ckpt_path = os.path.join(cfg.train.checkpoint_dir, f"{BEST_ARCH}_best.pt")
if not os.path.exists(best_ckpt_path):
    raise FileNotFoundError(
        f"No checkpoint for {BEST_ARCH} at {best_ckpt_path}; train it first or change BEST_ARCH"
    )

m = get_model(BEST_ARCH).to(device)
load_checkpoint(m, best_ckpt_path, device)
m.eval()

# Size the counters from CLASSES instead of hard-coding 10.
num_classes = len(CLASSES)
class_correct = torch.zeros(num_classes, dtype=torch.long)
class_total = torch.zeros(num_classes, dtype=torch.long)

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        preds = m(inputs).argmax(1)

        # Move both vectors to the CPU once per batch, then count with bincount
        # instead of looping over classes with two .item() calls per class.
        targets_cpu = targets.cpu()
        hits = preds.cpu() == targets_cpu
        class_total += torch.bincount(targets_cpu, minlength=num_classes)
        class_correct += torch.bincount(targets_cpu[hits], minlength=num_classes)

print(f"Per-class accuracy — {BEST_ARCH}")
print("-" * 35)
for name, n_correct, n_seen in zip(CLASSES, class_correct.tolist(), class_total.tolist()):
    if n_seen == 0:
        # A class with no test samples has no defined accuracy; avoid 0/0.
        print(f"  {name:<12} n/a")
        continue
    acc = 100.0 * n_correct / n_seen
    print(f"  {name:<12} {acc:.1f}%")