# CIFAR-10 CNN Training

This notebook is the **experiment driver**. All implementation lives in `src/`.

| Section | What happens here |
|---|---|
| 0 – Colab setup | mount Drive, clone repo, pip install |
| 1 – Config & imports | choose model, set hyper-params |
| 2 – Data | load CIFAR-10, quick sanity check |
| 3 – Train | train one architecture |
| 4 – Compare | train all architectures back-to-back |
| 5 – Evaluate | test-set accuracy + per-class breakdown |

## 0 · Colab Setup
> **Skip if running locally.**
> Connect VS Code to Colab: Command Palette → *Jupyter: Specify Jupyter Server* → paste the Colab runtime URL.

In [None]:
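import os
import subprocess
import sys

IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    from google.colab import drive
    drive.mount("/content/drive")

    REPO_URL = "https://github.com/behzadsabeti/PartInterviewTasks.git"  # TODO: update
    REPO_DIR = "/content/PartInterviewTasks"

    # Clone on first run, otherwise pull; check=True stops the cell if git fails
    if not os.path.exists(REPO_DIR):
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
    else:
        subprocess.run(["git", "-C", REPO_DIR, "pull"], check=True)

    os.chdir(REPO_DIR)
    sys.path.insert(0, REPO_DIR)  # make `src/` importable
    # Install into the interpreter that runs this kernel, failing loudly on errors
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"], check=True)

print("cwd:", os.getcwd(), "| colab:", IN_COLAB)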

## 1 · Config & Imports

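In [None]:
import os
import sys

import torch
import torch.nn as nn

# Make `src/` importable when the notebook is opened from a subfolder of the repo
# (assumption: the repo root is either the current directory or its parent).
if not os.path.isdir("src") and os.path.isdir(os.path.join("..", "src")):
    sys.path.insert(0, os.path.abspath(".."))

from src.config  import ExperimentConfig, DataConfig, TrainConfig
from src.dataset import get_dataloaders, CLASSES
from src.models  import get_model, count_parameters
from src.trainer import train, evaluate, load_checkpoint
from src.utils   import set_seed, get_device, plot_history, compare_histories

# ---- Edit these to change the experiment ----
cfg = ExperimentConfig(
    seed       = 42,
    model_name = "SimpleCNN",          # "SimpleCNN" | "DeepCNN" | "ResNetCIFAR"
    data  = DataConfig(batch_size=128, val_split=0.1, augment=True),
    # OneCycleLR (section 3) overwrites the optimiser's lr, so the schedule is driven by
    # TrainConfig.max_lr; learning_rate only seeds AdamW before the scheduler takes over.
    train = TrainConfig(epochs=50, learning_rate=1e-3, use_amp=True),
)

set_seed(cfg.seed)
device = get_device()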

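Later cells read more config fields than this cell sets (`data_dir`, `num_workers`, `weight_decay`, `max_lr`, `checkpoint_dir`), so `src/config.py` must supply defaults for them. `src/config.py` is not shown here, so the following is only a hypothetical sketch of the shape those attribute accesses imply, using dataclasses. Defaults that do not appear in this notebook are placeholders, not the repo's values.

```python
from dataclasses import dataclass, field


@dataclass
class DataConfig:
    data_dir: str = "./data"              # placeholder default
    batch_size: int = 128
    val_split: float = 0.1
    augment: bool = True
    num_workers: int = 2                  # placeholder default


@dataclass
class TrainConfig:
    epochs: int = 50
    learning_rate: float = 1e-3
    max_lr: float = 1e-2                  # placeholder default
    weight_decay: float = 5e-4            # placeholder default
    use_amp: bool = True
    checkpoint_dir: str = "checkpoints"   # placeholder default


@dataclass
class ExperimentConfig:
    seed: int = 42
    model_name: str = "SimpleCNN"
    # default_factory gives every ExperimentConfig its own nested config objects
    data: DataConfig = field(default_factory=DataConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
```

Nesting the sub-configs keeps overrides local: `TrainConfig(epochs=5)` changes one field and leaves the rest at their defaults.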

## 2 · Data

In [None]:
train_loader, val_loader, test_loader = get_dataloaders(
    data_dir    = cfg.data.data_dir,
    batch_size  = cfg.data.batch_size,
    val_split   = cfg.data.val_split,
    augment     = cfg.data.augment,
    num_workers = cfg.data.num_workers,
)
print(f"Train: {len(train_loader.dataset)}  Val: {len(val_loader.dataset)}  Test: {len(test_loader.dataset)}")
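CIFAR-10 ships 50,000 training and 10,000 test images, so if `get_dataloaders` carves the validation set out of the training set (an assumption, since `src/dataset.py` is not shown), `val_split=0.1` should print 45,000 / 5,000 / 10,000. One common way to make such a split reproducible is `torch.utils.data.random_split` with a seeded generator; `train_val_split` below is a hypothetical helper, not the repo's code.

```python
import torch
from torch.utils.data import Dataset, random_split


def train_val_split(dataset: Dataset, val_split: float, seed: int = 42):
    """Split `dataset` into (train, val) subsets; the seeded generator makes the split reproducible."""
    n_val = int(round(len(dataset) * val_split))
    n_train = len(dataset) - n_val
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val], generator=generator)
```

Both subsets share the parent dataset, including its `transform`, so a split built this way would also augment validation images. A common workaround is two dataset objects (augmented and plain) indexed by the same split.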

In [None]:
# TODO: visualise a mini-batch (un-normalise, show 16 images with class labels)
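A minimal sketch for the visualisation TODO above, assuming the loaders yield `(images, labels)` batches normalised with `transforms.Normalize(mean, std)`. The mean/std below are commonly quoted CIFAR-10 statistics and are an assumption; swap in whatever `src/dataset.py` uses. `show_batch` is a hypothetical helper name.

```python
import matplotlib.pyplot as plt
import torch

# Assumption: commonly quoted CIFAR-10 per-channel statistics. Replace them with the
# mean/std that src/dataset.py actually passes to transforms.Normalize.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def show_batch(images, labels, class_names, mean=CIFAR10_MEAN, std=CIFAR10_STD):
    """Un-normalise up to 16 images from a (B, 3, H, W) batch and plot them in a 4x4 grid."""
    mean_t = torch.tensor(mean).view(3, 1, 1)
    std_t = torch.tensor(std).view(3, 1, 1)
    fig, axes = plt.subplots(4, 4, figsize=(6, 6))
    for ax in axes.flat:
        ax.axis("off")  # also hides empty cells when the batch has fewer than 16 images
    for ax, img, label in zip(axes.flat, images[:16], labels[:16]):
        img = (img.detach().cpu() * std_t + mean_t).clamp(0, 1)  # invert Normalize(mean, std)
        ax.imshow(img.permute(1, 2, 0).numpy())                   # CHW -> HWC for imshow
        ax.set_title(class_names[int(label)], fontsize=8)
    fig.tight_layout()
    return fig
```

Usage: `images, labels = next(iter(train_loader))` followed by `show_batch(images, labels, CLASSES)`. Batches from `train_loader` will show whatever augmentation is enabled.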

## 3 · Train One Model

In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

model     = get_model(cfg.model_name).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = AdamW(model.parameters(), lr=cfg.train.learning_rate, weight_decay=cfg.train.weight_decay)
scheduler = OneCycleLR(optimizer, max_lr=cfg.train.max_lr,
                       epochs=cfg.train.epochs, steps_per_epoch=len(train_loader))

print(f"{cfg.model_name}  |  {count_parameters(model):,} params")

history = train(
    model, train_loader, val_loader, criterion, optimizer, scheduler,
    device, cfg.train.epochs, cfg.train.checkpoint_dir, cfg.model_name, cfg.train.use_amp,
)
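The optimiser above receives both `learning_rate` and, through `OneCycleLR`, `max_lr`. Per the PyTorch docs, `OneCycleLR` overwrites the optimiser's lr: it starts at `max_lr / div_factor` (25 by default), warms up to `max_lr` over the first `pct_start` (0.3 by default) of the steps, then anneals to `initial_lr / final_div_factor` (1e4 by default). It is also built from `steps_per_epoch=len(train_loader)`, so it is presumably stepped once per batch inside `train()`. The standalone sketch below traces that schedule with a dummy parameter; `one_cycle_lr_trace` is a hypothetical helper.

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR


def one_cycle_lr_trace(max_lr, epochs, steps_per_epoch, base_lr=1e-3):
    """Return the learning rate OneCycleLR applies at every optimiser step."""
    # A single dummy parameter is enough; no real model or data is involved.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = AdamW([param], lr=base_lr)  # base_lr is overwritten by the scheduler
    scheduler = OneCycleLR(optimizer, max_lr=max_lr,
                           epochs=epochs, steps_per_epoch=steps_per_epoch)
    lrs = []
    for _ in range(epochs * steps_per_epoch):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()   # no-op here (no gradients), but keeps the call order PyTorch expects
        scheduler.step()   # OneCycleLR advances once per batch, not once per epoch
    return lrs


lrs = one_cycle_lr_trace(max_lr=1e-2, epochs=10, steps_per_epoch=10)
print(f"start {lrs[0]:.2e}  peak {max(lrs):.2e}  end {lrs[-1]:.2e}")
```

Plotting `lrs` shows the warm-up and cosine decay. Stepping the scheduler once per epoch instead would leave training stuck near the warm-up value.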

In [None]:
plot_history(history, title=cfg.model_name)

## 4 · Compare All Architectures

In [None]:
all_histories = {}

for arch in ["SimpleCNN", "DeepCNN", "ResNetCIFAR"]:
    # TODO: build model, optimizer, scheduler for each arch
    # TODO: call train() and store result in all_histories[arch]
    pass

compare_histories(all_histories, metric="val_acc")
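One way to fill the loop's TODOs is a small factory that rebuilds the section 3 objects for each architecture. Every model needs its own optimiser and scheduler, because both hold state tied to that model's parameters. `make_training_objects` is a hypothetical helper; it only repeats the loss, optimiser and scheduler choices already made in section 3.

```python
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR


def make_training_objects(model, lr, weight_decay, max_lr, epochs, steps_per_epoch,
                          label_smoothing=0.1):
    """Build a fresh (criterion, optimizer, scheduler) triple for one model, as in section 3."""
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = OneCycleLR(optimizer, max_lr=max_lr,
                           epochs=epochs, steps_per_epoch=steps_per_epoch)
    return criterion, optimizer, scheduler
```

Inside the loop, a plausible order is: `set_seed(cfg.seed)` (so every architecture sees the same initialisation seed and shuffling), `model = get_model(arch).to(device)`, this helper with `steps_per_epoch=len(train_loader)`, then the same `train(...)` call as section 3 with `arch` in place of `cfg.model_name`, presumably so each run writes its own checkpoint, storing the result in `all_histories[arch]`.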

## 5 · Test-Set Evaluation

In [None]:
# TODO: load best checkpoint for each architecture and report test accuracy
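`src.trainer.evaluate` may already compute this, but its signature is not shown in this notebook. The standalone sketch below collects predictions and computes top-1 accuracy with plain PyTorch. `predict` and `top1_accuracy` are hypothetical helper names, and the loader is assumed to yield `(images, labels)` batches.

```python
import torch


@torch.no_grad()
def predict(model, loader, device):
    """Run `model` over `loader` and return (predicted labels, true labels) as CPU tensors."""
    model.eval()  # disable dropout and use BatchNorm running statistics
    preds, targets = [], []
    for images, labels in loader:
        logits = model(images.to(device, non_blocking=True))
        preds.append(logits.argmax(dim=1).cpu())
        targets.append(labels.cpu())
    return torch.cat(preds), torch.cat(targets)


def top1_accuracy(preds, targets):
    """Fraction of samples whose predicted label matches the true label."""
    return (preds == targets).float().mean().item()
```

Usage, after restoring each architecture's best weights (for example with the repo's `load_checkpoint`): `preds, targets = predict(model, test_loader, device)`, then `top1_accuracy(preds, targets)`.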

In [None]:
# TODO: per-class accuracy breakdown for the best model
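For the per-class breakdown, one minimal approach counts, for each true class, how many of its samples were predicted correctly. It works on 1-D label tensors such as those returned by a prediction pass over `test_loader`. `per_class_accuracy` is a hypothetical helper, and `CLASSES` is assumed to list the class names in label order.

```python
import torch


def per_class_accuracy(preds, targets, class_names):
    """Return {class name: accuracy} from 1-D integer tensors of predicted and true labels."""
    num_classes = len(class_names)
    # Samples per true class, and how many of those were predicted correctly
    totals = torch.bincount(targets, minlength=num_classes)
    correct = torch.bincount(targets[preds == targets], minlength=num_classes)
    return {
        name: (correct[i] / totals[i]).item() if totals[i] > 0 else float("nan")
        for i, name in enumerate(class_names)
    }
```

Usage: `per_class_accuracy(preds, targets, CLASSES)`. Sorting the result by value quickly surfaces the classes the best model confuses most.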