# CIFAR-10 Baseline (Colab)

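This notebook clones the project repository, installs its dependencies, and visualizes the metrics saved by a training run. The final two cells additionally plot the `hist` / `args` objects produced by a training run executed in the same kernel session.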

In [None]:


# Setup: clone repo and install requirements
%cd /content  # go back to root
!rm -rf mixture-of-experts-project
!git clone https://github.com/moe-project-uu/mixture-of-experts-project.git || true
%cd mixture-of-experts-project
%pip install -r requirements.txt

# 2) Use Colab’s preinstalled torch to avoid CUDA mismatches
#    (so don't reinstall torch/torchvision). Install your package + other deps.
%pip install -U pip
%pip install -e .


In [None]:
# --- Load metrics and model checkpoint ---
import os, torch, matplotlib.pyplot as plt

# Choose the layer and run tag you used during training
FF_LAYER = "Dense"              # e.g. "Dense", "SoftMoE", "SparseMoE"
RUN_TAG  = "W512-S42"           # matches run_tag from training script

base_dir = os.path.join("checkpoints", FF_LAYER, RUN_TAG)
metrics_path = os.path.join(base_dir, "metrics.pt")
ckpt_path    = os.path.join(base_dir, "model.pt")
summary_path = os.path.join(base_dir, "summary.json")  # path only; this cell does not read it

# --- Load per-epoch metrics; fail early with a hint if FF_LAYER/RUN_TAG point at no run ---
if not os.path.exists(metrics_path):
    raise FileNotFoundError(
        f"Metrics file not found: {metrics_path}. "
        "Check FF_LAYER / RUN_TAG above and that the training run saved its checkpoints."
    )
metrics = torch.load(metrics_path, map_location="cpu")
train_losses = metrics.get("train_losses", [])
train_accs   = metrics.get("train_accs", [])
val_losses   = metrics.get("val_losses", [])
val_accs     = metrics.get("val_accs", [])
test_losses  = metrics.get("test_losses", [])
test_accs    = metrics.get("test_accs", [])

# --- Final accuracies (in percent) for the text box ---
# The checkpoint is optional. Per-epoch accuracies are multiplied by 100 below,
# i.e. they are treated as fractions.
ckpt = torch.load(ckpt_path, map_location="cpu") if os.path.exists(ckpt_path) else {}
final_train_acc = (train_accs[-1] * 100) if len(train_accs) else float("nan")
# NOTE(review): ckpt["test_acc"] is used unscaled, i.e. assumed to already be a percentage —
# confirm against the training script (the per-epoch test_accs above are scaled by 100).
final_test_acc  = ckpt.get("test_acc", (test_accs[-1] * 100) if len(test_accs) else float("nan"))

# --- Loss curves ---
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(train_losses, label="train loss")
if val_losses:
    ax.plot(val_losses, label="val loss")
if test_losses:
    ax.plot(test_losses, label="test loss")
ax.set(xlabel="epoch", ylabel="loss", title=f"{FF_LAYER} ({RUN_TAG}) | Loss Curve")
ax.legend()
plt.show()

# --- Accuracy curves (percent), with the final results in a text box ---
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot([x * 100 for x in train_accs], label="train acc %")
if val_accs:
    ax.plot([x * 100 for x in val_accs], label="val acc %")
if test_accs:
    ax.plot([x * 100 for x in test_accs], label="test acc %")
ax.set(xlabel="epoch", ylabel="accuracy (%)", title=f"{FF_LAYER} ({RUN_TAG}) | Accuracy Curve")
ax.legend()

textstr = f"Final Train Acc: {final_train_acc:.2f}%\nFinal Test Acc: {final_test_acc:.2f}%"
ax.text(
    0.95, 0.05, textstr,
    transform=ax.transAxes,  # axes-relative coordinates: bottom-right corner
    fontsize=10, va="bottom", ha="right",
    bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
)
plt.show()


In [None]:
# CELL3: Loss/Accuracy curves from `hist`
import numpy as np
import matplotlib.pyplot as plt

# `hist` / `args` only exist if a training run was executed in this kernel. Otherwise fall
# back to the metrics and layer name loaded from disk by the previous cell, so the notebook
# still survives "Restart & Run All".
if "hist" in globals():
    curves = {key: hist.get(key, []) for key in ("train_loss", "train_acc", "val_loss", "val_acc")}
else:
    curves = {
        "train_loss": train_losses,
        "train_acc": train_accs,
        "val_loss": val_losses,
        "val_acc": val_accs,
    }
layer_name = args.FF_layer if "args" in globals() else FF_LAYER

train_loss = np.array(curves["train_loss"])
train_acc  = np.array(curves["train_acc"])
val_loss   = np.array(curves["val_loss"])
val_acc    = np.array(curves["val_acc"])

# Loss: validation curve is drawn only when it was recorded.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(train_loss, label="train")
if val_loss.size:
    ax.plot(val_loss, label="val")
ax.set(xlabel="epoch", ylabel="loss", title=f"{layer_name} — Loss")
ax.legend()
plt.show()

# Accuracy: plotted in the units stored in the history (no rescaling here).
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(train_acc, label="train")
if val_acc.size:
    ax.plot(val_acc, label="val")
ax.set(xlabel="epoch", ylabel="accuracy", title=f"{layer_name} — Accuracy")
ax.legend()
plt.show()


In [None]:
# CELL4: SoftMoE — Expert utilization over epochs + per-expert bars at a few epochs
import numpy as np
import matplotlib.pyplot as plt

# 0-based epoch indices for the per-expert bar charts (clamped to the last recorded epoch).
BAR_EPOCHS = [0, 50, 100]


def plot_expert_utilization(util, bar_epochs=BAR_EPOCHS):
    """Plot per-expert mean probabilities across epochs, plus bar charts at selected epochs.

    Args:
        util: 2-D array indexed as util[epoch, expert].
        bar_epochs: 0-based epoch indices for the bar charts. Indices past the last recorded
            epoch are clamped to it, and duplicates produced by clamping are dropped, so a
            short run still shows its final epoch.
    """
    num_epochs, num_experts = util.shape

    # (A) line plot: one curve per expert across epochs
    fig, ax = plt.subplots(figsize=(7, 4))
    for i in range(num_experts):
        ax.plot(util[:, i], label=f"expert {i}")
    ax.set(xlabel="epoch", ylabel="mean p_i", title="Expert Utilization over epochs")
    ax.legend(ncol=2)
    plt.show()

    # (B) one-bar-per-expert "histograms" at the requested epochs
    max_idx = num_epochs - 1
    picked = sorted({min(e, max_idx) for e in bar_epochs})
    expert_idx = np.arange(num_experts)
    for e in picked:
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.bar(expert_idx, util[e])
        ax.set_xticks(expert_idx)
        ax.set_xticklabels([f"E{i}" for i in expert_idx])
        ax.set_ylim(0, 1)
        ax.set(ylabel="mean p_i", xlabel="expert",
               title=f"Expert mean probabilities at epoch index {e} (0-based)")
        plt.show()


# `hist` / `args` only exist if a training run was executed in this kernel; otherwise use the
# layer name and metrics loaded from disk earlier in the notebook.
layer_name = args.FF_layer if "args" in globals() else FF_LAYER
# NOTE(review): assumes the saved metrics dict uses the same "util_per_epoch" key as `hist`
# — confirm against the training script.
run_hist = hist if "hist" in globals() else metrics
util_list = run_hist.get("util_per_epoch", [])

if layer_name != "SoftMoE":
    print("Utilization plots are only available for SoftMoE. "
          "Re-run with args.FF_layer='SoftMoE' (or FF_LAYER='SoftMoE' when loading from disk).")
elif len(util_list) == 0:
    print("No utilization recorded — ensure you trained with SoftMoE.")
else:
    # Stack the per-epoch vectors into one array indexed as [epoch, expert].
    plot_expert_utilization(np.stack(util_list, axis=0))
