# CIFAR-10 Baseline (Colab)

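This notebook clones the repo, installs its dependencies, trains a CIFAR-10 model whose feed-forward layer is either `SoftMoE` or `Dense`, and plots the training history returned by the trainer: loss/accuracy curves and, for SoftMoE, expert utilization and gating entropy.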

In [None]:


# Setup: clone repo and install requirements
%cd /content  # go back to root
!rm -rf mixture-of-experts-project
!git clone https://github.com/moe-project-uu/mixture-of-experts-project.git || true
%cd mixture-of-experts-project
%pip install -r requirements.txt

# 2) Use Colab’s preinstalled torch to avoid CUDA mismatches
#    (so don't reinstall torch/torchvision). Install your package + other deps.
%pip install -U pip
%pip install -e .


In [None]:
# CELL2: run training and capture history
import os
import sys

proj = "/content/mixture-of-experts-project"

# Make `scripts/` (repo root) and the `src/` package importable in this kernel.
# Presumably needed because the editable install above is not picked up by an
# already-running kernel. The membership check keeps re-runs of this cell from
# stacking duplicate entries; insertion order matches the original, giving
# sys.path = [src, repo_root, ...].
for _path in (proj, os.path.join(proj, "src")):
    if _path not in sys.path:
        sys.path.insert(0, _path)

from scripts.train_cifar10 import parser, main

# Start from the training script's CLI defaults, then override for this run.
args = parser.parse_args(args=[])
args.FF_layer = "SoftMoE"   # or "Dense"
args.epochs = 100
args.num_experts = 4        # ignored for Dense

# NOTE(review): later cells assume `hist` is a dict of per-epoch sequences
# (train_loss, train_acc, val_loss, val_acc, and for SoftMoE util_per_epoch /
# entropy_per_epoch) — confirm against scripts/train_cifar10.main.
hist = main(args)


In [None]:
# CELL3: Loss/Accuracy curves from `hist`
import numpy as np
import matplotlib.pyplot as plt


def plot_train_val(hist, metric, ylabel, title):
    """Plot the train and val curves of one metric from the training history.

    Args:
        hist: history mapping from the trainer; must contain the keys
            ``f"train_{metric}"`` and ``f"val_{metric}"`` (one value per epoch,
            x axis is the 0-based epoch index).
        metric: key suffix, e.g. ``"loss"`` or ``"acc"``.
        ylabel: y-axis label.
        title: figure title.

    Returns:
        The matplotlib ``Axes`` the two curves were drawn on.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(np.asarray(hist[f"train_{metric}"]), label="train")
    ax.plot(np.asarray(hist[f"val_{metric}"]), label="val")
    ax.set(xlabel="epoch", ylabel=ylabel, title=title)
    ax.legend()
    plt.show()
    return ax


plot_train_val(hist, "loss", "loss", f"{args.FF_layer} — Loss");
plot_train_val(hist, "acc", "accuracy", f"{args.FF_layer} — Accuracy");


In [None]:
# CELL3b: numeric summary of the run in `hist` (best val-acc epoch + final
# metrics), complementing the loss/accuracy curves.
import numpy as np

summary_keys = ("train_loss", "train_acc", "val_loss", "val_acc")
val_acc_curve = np.asarray(hist["val_acc"])

if val_acc_curve.size == 0:
    print("No per-epoch metrics recorded in `hist`.")
else:
    # Index of the epoch with the highest validation accuracy (0-based, the
    # same indexing as the x axis of the curves and the CELL4 snapshots).
    best_idx = int(np.argmax(val_acc_curve))
    print(
        f"{args.FF_layer}: best val_acc = {float(val_acc_curve[best_idx]):.4f} "
        f"at epoch index {best_idx} (0-based) of {val_acc_curve.size}"
    )
    final_metrics = {k: float(np.asarray(hist[k])[-1]) for k in summary_keys}
    print("final epoch: " + ", ".join(f"{k} = {v:.4f}" for k, v in final_metrics.items()))


In [None]:
# CELL4: SoftMoE — Expert utilization over epochs + per-expert bars at a few epochs
import numpy as np
import matplotlib.pyplot as plt

# 0-based epoch indices to snapshot as bar charts. Indices past the last
# recorded epoch are clamped to it, so with 100 epochs (indices 0..99) the
# "100" entry shows the final epoch instead of being silently dropped.
UTIL_SNAPSHOT_EPOCHS = (0, 50, 100)

if args.FF_layer != "SoftMoE":
    print("Utilization plots are only available for SoftMoE. Re-run with args.FF_layer='SoftMoE'.")
else:
    util_list = hist.get("util_per_epoch", [])
    assert len(util_list) > 0, "No utilization recorded — ensure you trained with SoftMoE."

    # Each entry is expected to be a per-expert vector, so stacking gives
    # shape (num_epochs, num_experts); the 2-D unpack below enforces that.
    util = np.stack(util_list, axis=0)
    n_epochs, n_experts = util.shape
    expert_ticks = np.arange(n_experts)

    # (A) line plot: one curve per expert across epochs
    fig, ax = plt.subplots(figsize=(7, 4))
    for i in range(n_experts):
        ax.plot(util[:, i], label=f"expert {i}")
    ax.set(xlabel="epoch", ylabel="mean p_i", title="Expert Utilization over epochs")
    ax.legend(ncol=2)
    plt.show()

    # (B) one-bar-per-expert "histograms" at the snapshot epochs, clamped to
    # the available range and de-duplicated (short runs may clamp several
    # requested epochs onto the same last index).
    max_idx = n_epochs - 1
    picked = sorted({min(e, max_idx) for e in UTIL_SNAPSHOT_EPOCHS})

    for e in picked:
        vals = util[e]  # length = num_experts
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.bar(expert_ticks, vals)
        ax.set_xticks(expert_ticks)
        ax.set_xticklabels([f"E{i}" for i in range(n_experts)])
        ax.set(
            ylim=(0, 1),
            ylabel="mean p_i",
            xlabel="expert",
            title=f"Expert mean probabilities at epoch index {e} (0-based)",
        )
        plt.show()


In [None]:
# CELL5: SoftMoE — Gating entropy over epochs
import numpy as np
import matplotlib.pyplot as plt

# Gating entropy is only recorded for the SoftMoE feed-forward layer; for any
# other layer type just explain why nothing is plotted.
if args.FF_layer == "SoftMoE":
    # Per-epoch entropy series from the trainer (empty array if the key is absent).
    H = np.array(hist.get("entropy_per_epoch", []))
    if H.size > 0:
        plt.figure(figsize=(6, 4))
        plt.plot(H)
        plt.xlabel("epoch")
        plt.ylabel("entropy  H = -Σ p log p")
        plt.title("Gating Entropy over epochs")
        plt.show()
    else:
        print("No entropy recorded — make sure SoftMoE ran.")
else:
    print("Entropy plot is only available for SoftMoE. Re-run with args.FF_layer='SoftMoE'.")
