# CIFAR-10 Baseline (Colab)

This notebook clones the project repository, installs its dependencies, optionally runs the training script, and visualizes the saved metrics.

In [None]:


# Setup: clone repo and install requirements
%cd /content  # go back to root
!rm -rf mixture-of-experts-project
!git clone https://github.com/moe-project-uu/mixture-of-experts-project.git || true
%cd mixture-of-experts-project
%pip install -r requirements.txt

# 2) Use Colab’s preinstalled torch to avoid CUDA mismatches
#    (so don't reinstall torch/torchvision). Install your package + other deps.
%pip install -U pip
%pip install -e .


In [None]:
# run training script
!python scripts/train_cifar10.py --FF_layer Dense --epochs 1


In [None]:
# --- Load saved metrics / checkpoint and plot loss and accuracy curves ---
import os

import matplotlib.pyplot as plt
import torch

# Choose the layer and run tag you used during training
FF_LAYER = "Dense"              # e.g. "Dense", "SoftMoE", "SparseMoE"
RUN_TAG  = "W512-S42"           # matches run_tag from training script

base_dir = os.path.join("checkpoints", FF_LAYER, RUN_TAG)
metrics_path = os.path.join(base_dir, "metrics.pt")
ckpt_path    = os.path.join(base_dir, "model.pt")
summary_path = os.path.join(base_dir, "summary.json")  # path only; not read in this cell

# --- Load metrics (fail early with a hint if the run directory is wrong) ---
if not os.path.exists(metrics_path):
    raise FileNotFoundError(
        f"No metrics file at {metrics_path!r}. Run the training cell first and "
        "check that FF_LAYER / RUN_TAG match the run you trained."
    )
metrics = torch.load(metrics_path, map_location="cpu")

# Each series falls back to an empty list so a missing key just skips that curve.
train_losses = metrics.get("train_losses", [])
train_accs   = metrics.get("train_accs", [])
val_losses   = metrics.get("val_losses", [])
val_accs     = metrics.get("val_accs", [])
test_losses  = metrics.get("test_losses", [])
test_accs    = metrics.get("test_accs", [])

# --- Load checkpoint (optional) for the final test accuracy ---
ckpt = torch.load(ckpt_path, map_location="cpu") if os.path.exists(ckpt_path) else {}
final_train_acc = (train_accs[-1] * 100) if len(train_accs) else float("nan")
# Prefer the checkpoint's "test_acc"; otherwise use the last logged test accuracy.
# NOTE(review): the fallback is scaled by 100 but the checkpoint value is used
# as-is -- confirm the training script stores "test_acc" as a percentage.
final_test_acc  = ckpt.get("test_acc", (test_accs[-1] * 100) if len(test_accs) else float("nan"))

# --- Plot loss curves ---
loss_fig, loss_ax = plt.subplots(figsize=(8, 5))
loss_ax.plot(train_losses, label="train loss")
if val_losses:
    loss_ax.plot(val_losses, label="val loss")
if test_losses:
    loss_ax.plot(test_losses, label="test loss")
loss_ax.set_xlabel("epoch")
loss_ax.set_ylabel("loss")
loss_ax.legend()
loss_ax.set_title(f"{FF_LAYER} | Loss Curve")
plt.show()

# --- Plot accuracy curves (stored values are multiplied by 100 for display) ---
acc_fig, acc_ax = plt.subplots(figsize=(8, 5))
acc_ax.plot([x * 100 for x in train_accs], label="train acc %")
if val_accs:
    acc_ax.plot([x * 100 for x in val_accs], label="val acc %")
if test_accs:
    acc_ax.plot([x * 100 for x in test_accs], label="test acc %")
acc_ax.set_xlabel("epoch")
acc_ax.set_ylabel("accuracy (%)")
acc_ax.legend()
acc_ax.set_title(f"{FF_LAYER} | Accuracy Curve")

# --- Add final results as a text box (bottom-right corner, axes coordinates) ---
textstr = f"Final Train Acc: {final_train_acc:.2f}%\nFinal Test Acc: {final_test_acc:.2f}%"
acc_ax.text(
    0.95, 0.05, textstr,
    transform=acc_ax.transAxes,
    fontsize=10, va="bottom", ha="right",
    bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5)
)

plt.show()
