# CIFAR-10 CNN Training

This notebook is the **experiment driver**. All implementation lives in `src/`.

| Section | What happens here |
|---|---|
| 0 – Colab setup | clone or update the repo, pip install requirements |
| 1 – Config & imports | choose model, set hyper-params |
| 2 – Data | load CIFAR-10, quick sanity check |
| 3 – Train | train one architecture |
| 4 – Compare | train all architectures back-to-back |
| 5 – Evaluate | test-set accuracy + per-class breakdown |

## 0 · Colab Setup
> **Skip if running locally** — outside Colab the cell below only prints the working directory.
> Connect VS Code to Colab: Command Palette → *Jupyter: Specify Jupyter Server* → paste the Colab runtime URL.

In [4]:
import sys, os
import subprocess

# True only when the Colab runtime module is already loaded in this kernel.
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:

    REPO_URL = "https://github.com/behzadsabeti/PartInterviewTasks.git"  # TODO: update
    REPO_DIR = "/content/PartInterviewTasks"

    # check=True aborts the cell when git fails. os.system only returned the
    # exit status, so a failed clone/pull left a missing or stale checkout
    # that surfaced later as confusing import errors.
    if not os.path.exists(REPO_DIR):
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
    else:
        subprocess.run(["git", "-C", REPO_DIR, "pull"], check=True)

    os.chdir(REPO_DIR)
    # Make `src` importable; skip the insert on re-runs so sys.path does not
    # accumulate duplicate entries.
    if REPO_DIR not in sys.path:
        sys.path.insert(0, REPO_DIR)
    # Install through the kernel's own interpreter so the packages land in the
    # environment this notebook imports from. The path is relative to REPO_DIR
    # because of the chdir above.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"],
        check=True,
    )

print("cwd:", os.getcwd(), "| colab:", IN_COLAB)

cwd: /content/PartInterviewTasks | colab: True


## 1 · Config & Imports

In [5]:
import torch
import torch.nn as nn

from src.config  import ExperimentConfig, DataConfig, TrainConfig
from src.dataset import get_dataloaders, CLASSES
from src.models  import get_model, count_parameters
from src.trainer import train, evaluate, load_checkpoint
from src.utils   import set_seed, get_device, plot_history, compare_histories

# ---- Experiment knobs: edit the values below ----
# Fields not listed here keep whatever defaults the config classes define.
data_cfg  = DataConfig(batch_size=128, val_split=0.1, augment=True)
train_cfg = TrainConfig(epochs=50, learning_rate=1e-3, use_amp=True)

cfg = ExperimentConfig(
    seed=42,
    model_name="SimpleCNN",  # "SimpleCNN" | "DeepCNN" | "ResNetCIFAR"
    data=data_cfg,
    train=train_cfg,
)

# Seed first, then resolve the device used by every later cell.
set_seed(cfg.seed)
device = get_device()

ImportError: cannot import name 'ExperimentConfig' from 'src.config' (/content/PartInterviewTasks/src/config.py)

## 2 · Data

In [None]:
# Build the train/val/test loaders from the data section of the experiment config.
loader_kwargs = {
    "data_dir": cfg.data.data_dir,
    "batch_size": cfg.data.batch_size,
    "val_split": cfg.data.val_split,
    "augment": cfg.data.augment,
    "num_workers": cfg.data.num_workers,
}
train_loader, val_loader, test_loader = get_dataloaders(**loader_kwargs)

# Sanity check: number of samples behind each loader.
n_train, n_val, n_test = (
    len(loader.dataset) for loader in (train_loader, val_loader, test_loader)
)
print(f"Train: {n_train}  Val: {n_val}  Test: {n_test}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from src.dataset import MEAN, STD

# --- 16 sample images ---
# Take one training batch, move channels last for imshow, then scale by STD and
# shift by MEAN (presumably inverting the loader's normalisation; verify in
# src/dataset) and clip into the displayable [0, 1] range.
batch_imgs, batch_labels = next(iter(train_loader))
std_arr, mean_arr = np.array(STD), np.array(MEAN)
channels_last = batch_imgs[:16].permute(0, 2, 3, 1).numpy()
preview = (channels_last * std_arr + mean_arr).clip(0, 1)

fig, axes = plt.subplots(2, 8, figsize=(14, 4))
# zip stops after the 16 preview images even though batch_labels covers the whole batch.
for panel, picture, label in zip(axes.flat, preview, batch_labels):
    panel.imshow(picture)
    panel.set_title(CLASSES[label], fontsize=8)
    panel.axis("off")
fig.suptitle("Sample training images (augmented)", fontsize=10)
fig.tight_layout()
plt.show()

# --- Class distribution ---
from collections import Counter

# Assumes the training dataset is a Subset-like wrapper exposing `.dataset.targets`
# and `.indices` (TODO confirm against get_dataloaders).
train_subset = train_loader.dataset
counts = Counter(train_subset.dataset.targets[idx] for idx in train_subset.indices)

class_ids = range(10)
dist_fig, dist_ax = plt.subplots(figsize=(8, 3))
dist_ax.bar([CLASSES[i] for i in class_ids], [counts[i] for i in class_ids])
dist_ax.set_title("Training-set class distribution")
plt.setp(dist_ax.get_xticklabels(), rotation=30, ha="right")
dist_fig.tight_layout()
plt.show()

## 3 · Train One Model

In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

# Build the configured architecture together with its loss, optimizer and LR schedule.
hp = cfg.train  # shorthand for the training hyper-parameters

model = get_model(cfg.model_name).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = AdamW(
    model.parameters(),
    lr=hp.learning_rate,
    weight_decay=hp.weight_decay,
)
# The one-cycle schedule spans epochs x batches-per-epoch steps in total.
scheduler = OneCycleLR(
    optimizer,
    max_lr=hp.max_lr,
    epochs=hp.epochs,
    steps_per_epoch=len(train_loader),
)

n_params = count_parameters(model)
print(f"Model : {cfg.model_name}  |  {n_params:,} trainable params")

# train() returns the history object consumed by plot_history in the next cell.
history = train(
    model, train_loader, val_loader, criterion, optimizer, scheduler,
    device, hp.epochs, hp.checkpoint_dir, cfg.model_name, hp.use_amp,
)

In [None]:
# Plot the history of the run above, titled with the architecture name.
run_title = cfg.model_name
plot_history(history, title=run_title)

## 4 · Compare All Architectures

In [None]:
# Train every architecture with the same seed, loaders and hyper-parameters,
# storing the history returned by train() under the architecture's name.
all_histories = {}

for arch in ["SimpleCNN", "DeepCNN", "ResNetCIFAR"]:
    # Seed before construction so each architecture starts from the same RNG
    # state (assumes set_seed resets every relevant generator; confirm in src/utils).
    set_seed(cfg.seed)
    m = get_model(arch).to(device)
    # Count parameters on the instance that is actually trained instead of
    # constructing a second, throwaway model just for the banner.
    print(f"\n{'='*52}\n  {arch}  ({count_parameters(m):,} params)\n{'='*52}")
    opt  = AdamW(m.parameters(), lr=cfg.train.learning_rate, weight_decay=cfg.train.weight_decay)
    # The one-cycle schedule spans epochs x batches-per-epoch steps in total.
    sch  = OneCycleLR(opt, max_lr=cfg.train.max_lr,
                      epochs=cfg.train.epochs, steps_per_epoch=len(train_loader))
    crit = nn.CrossEntropyLoss(label_smoothing=0.1)
    # The architecture name is passed where the single-model cell passes
    # cfg.model_name; the evaluation cells later load "<arch>_best.pt".
    all_histories[arch] = train(
        m, train_loader, val_loader, crit, opt, sch,
        device, cfg.train.epochs, cfg.train.checkpoint_dir, arch, cfg.train.use_amp,
    )

# Overlay the validation curves of all architectures, one figure per metric.
compare_histories(all_histories, metric="val_acc")
compare_histories(all_histories, metric="val_loss")

## 5 · Test-Set Evaluation

In [None]:
# Test loss uses plain cross-entropy (no label smoothing, unlike training).
criterion_test = nn.CrossEntropyLoss()

table_header = f"{'Model':<16} {'Params':>10}  {'Test Acc':>10}"
print(table_header)
print("-" * 40)
for arch in ("SimpleCNN", "DeepCNN", "ResNetCIFAR"):
    # Load the "<arch>_best.pt" checkpoint into a fresh model of that architecture.
    ckpt_path = f"{cfg.train.checkpoint_dir}/{arch}_best.pt"
    m = get_model(arch).to(device)
    load_checkpoint(m, ckpt_path, device)
    # evaluate() yields a pair; only the second element (accuracy) is reported.
    _, acc = evaluate(m, test_loader, criterion_test, device)
    n_params = count_parameters(m)
    print(f"{arch:<16} {n_params:>10,}  {acc:>9.2f}%")

In [None]:
import torch

# Pick the architecture whose best per-epoch validation accuracy is highest.
BEST_ARCH = max(all_histories, key=lambda k: max(all_histories[k]["val_acc"]))
print(f"Best architecture by val acc: {BEST_ARCH}\n")

m = get_model(BEST_ARCH).to(device)
load_checkpoint(m, f"{cfg.train.checkpoint_dir}/{BEST_ARCH}_best.pt", device)
m.eval()

# Derive the class count from the label names rather than hard-coding 10.
num_classes = len(CLASSES)
correct = torch.zeros(num_classes, dtype=torch.long)
total   = torch.zeros(num_classes, dtype=torch.long)
with torch.no_grad():
    for x, y in test_loader:
        preds = m(x.to(device)).argmax(1).cpu()
        # bincount tallies every class in one pass: all targets for `total`,
        # and only the targets whose prediction matched for `correct`.
        # Assumes y is a 1-D CPU integer tensor of class indices; confirm in src/dataset.
        total   += torch.bincount(y, minlength=num_classes)
        correct += torch.bincount(y[preds == y], minlength=num_classes)

# Per-class accuracy in percent; a class absent from the test set yields NaN.
accs = (100 * correct.float() / total.float()).numpy()
plt.figure(figsize=(8, 4))
# Bar colour follows the accuracy on a red-yellow-green colormap.
bars = plt.bar(CLASSES, accs, color=plt.cm.RdYlGn(accs / 100))
plt.axhline(accs.mean(), color="navy", linestyle="--", label=f"mean {accs.mean():.1f}%")
# Write each bar's value just above it.
for bar, a in zip(bars, accs):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
             f"{a:.1f}", ha="center", va="bottom", fontsize=8)
plt.ylabel("Accuracy (%)"); plt.title(f"Per-class Test Accuracy — {BEST_ARCH}")
plt.xticks(rotation=30, ha="right"); plt.legend(); plt.ylim(0, 105)
plt.tight_layout(); plt.show()