# CIFAR-10 Baseline (Colab)

This notebook clones the project repo, installs its dependencies, optionally runs the CIFAR-10 trainer, visualizes the saved training metrics, and inspects a few test-set predictions.

In [None]:
# Setup: clone repo and install requirements
!git clone https://github.com/moe-project-uu/mixture-of-experts-project.git || true
%cd mixture-of-experts-project
%pip install -r requirements.txt

In [None]:
# Optional: run training here (skip this cell if you've already trained).
# Later cells read checkpoints/cifar10_metrics.pt and checkpoints/cifar10_cnn.pt;
# presumably this script is what writes them — confirm against scripts/train_cifar10.py.
# The script is invoked with no arguments. To change the number of epochs, edit
# EPOCHS inside the script (assumed to be defined there — TODO confirm), or add an
# argparse flag to the script and pass it on this command line.
!python scripts/train_cifar10.py

In [None]:
# Read the saved metrics file and unpack the four series used by the plots below
import torch

metrics = torch.load("checkpoints/cifar10_metrics.pt", map_location="cpu")
train_losses, train_accs, val_losses, val_accs = (
    metrics[name]
    for name in ("train_losses", "train_accs", "val_losses", "val_accs")
)
# Cell output: how many train / val loss entries were recorded
(len(train_losses), len(val_losses))

In [None]:
# Plot curves (explicit Figure/Axes interface, one figure per metric)
import matplotlib.pyplot as plt

# Figure 1: training vs. validation loss
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(train_losses, label="train loss")
loss_ax.plot(val_losses, label="val loss")
loss_ax.set_xlabel("epoch")
loss_ax.set_ylabel("loss")
loss_ax.legend()
plt.show()

# Figure 2: training vs. validation accuracy; values are multiplied by 100
# so the y-axis reads in percent
acc_fig, acc_ax = plt.subplots()
acc_ax.plot([a * 100 for a in train_accs], label="train acc %")
acc_ax.plot([a * 100 for a in val_accs], label="val acc %")
acc_ax.set_xlabel("epoch")
acc_ax.set_ylabel("accuracy %")
acc_ax.legend()
plt.show()

In [None]:
# Inspect a few predictions: restore the trained model, draw one batch from the
# CIFAR-10 test set, and compute the predicted class for each image.
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from scripts.train_cifar10 import cnn_classifier

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed for the shuffled sampler so the inspected batch is the same on every
# Restart & Run All; change it to look at a different batch.
SAMPLE_SEED = 0

# Load model
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# produced yourself. If the checkpoint holds nothing but tensors/plain
# containers, consider passing weights_only=True — confirm against what the
# training script saves.
model = cnn_classifier(in_channels=3, out_channels=10)
ckpt = torch.load("checkpoints/cifar10_cnn.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval().to(device)

# CIFAR-10 test set with same normalization used in training
# (these constants must stay in sync with the ones undone by unnormalize() below)
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),
])
testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
classes = testset.classes

# shuffle=True draws a random batch; the seeded generator makes that draw reproducible.
loader = DataLoader(
    testset,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(SAMPLE_SEED),
)

# Take a single batch and run it through the model without tracking gradients.
images, labels = next(iter(loader))
images, labels = images.to(device), labels.to(device)
with torch.no_grad():
    logits = model(images)
# Predicted class index = position of the largest logit along dim 1.
preds = logits.argmax(dim=1)

# Show a small grid
import matplotlib.pyplot as plt
import numpy as np

def unnormalize(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)):
    """Undo per-channel normalization and return a displayable NumPy array.

    Computes ``img * std + mean`` per channel and clips the result to [0, 1].

    Args:
        img: Tensor with the channel axis first (e.g. ``(C, H, W)``). It may
            live on any device and may require grad.
        mean: Per-channel means that were subtracted during normalization.
            Defaults to the values used by this notebook's test transform.
        std: Per-channel standard deviations that were divided out.
            Defaults to the values used by this notebook's test transform.

    Returns:
        NumPy array with the same shape as ``img``, values clipped to [0, 1].
    """
    # Shape the statistics as (C, 1, 1) so they broadcast over the spatial axes.
    mean = np.asarray(mean, dtype=np.float64).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float64).reshape(-1, 1, 1)
    # detach() first: Tensor.numpy() raises on tensors that still track gradients.
    img = img.detach().cpu().numpy()
    return np.clip((img * std) + mean, 0, 1)

# 3x4 grid: one panel per image for the first 12 items of the batch,
# titled with the predicted and true class names.
fig, axes = plt.subplots(3, 4, figsize=(10, 6))
for i, ax in enumerate(axes.flat):
    # Move the channel axis last, which is the layout imshow expects.
    ax.imshow(unnormalize(images[i]).transpose(1, 2, 0))
    ax.set_title(f"pred: {classes[preds[i]]}\ntrue: {classes[labels[i]]}")
    ax.axis('off')
fig.tight_layout()
plt.show()