# CIFAR-10 Baseline (Colab)

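This notebook clones the repository, installs its dependencies, optionally runs the CIFAR-10 trainer, and visualizes the saved training metrics (loss and accuracy, plus expert utilization and gating entropy for SoftMoE runs).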

In [None]:


# Setup: clone repo and install requirements
%cd /content  # go back to root
!git clone https://github.com/moe-project-uu/mixture-of-experts-project.git || true
%cd mixture-of-experts-project
%pip install -r requirements.txt

# 2) Use Colab’s preinstalled torch to avoid CUDA mismatches
#    (so don't reinstall torch/torchvision). Install your package + other deps.
%pip install -U pip
%pip install -e .


In [None]:
# CELL2: Train via CLI so argparse is used
%cd /content/mixture-of-experts-project

# config settings
FF_LAYER = "SoftMoE"   # "Dense" or "SoftMoE"
EPOCHS   = 100
NUM_EXPERTS = 4        # ignored for Dense

if FF_LAYER == "Dense":
    !python scripts/train_cifar10.py --FF_layer Dense --epochs {EPOCHS}
else:
    !python scripts/train_cifar10.py --FF_layer SoftMoE --epochs {EPOCHS} --num_experts {NUM_EXPERTS}


In [None]:
# CELL3: Load metrics.pt and plot loss/accuracy
import os, torch, numpy as np
import matplotlib.pyplot as plt

# Run configuration for locating the checkpoint; must match CELL2.
FF_LAYER = "SoftMoE"
EPOCHS = 100
NUM_EXPERTS = 4  # only used if not Dense

# Dense runs are tagged by epoch count only; other runs also carry the expert count.
if FF_LAYER == "Dense":
    run_tag = f"E{EPOCHS}"
else:
    run_tag = f"E{EPOCHS}-X{NUM_EXPERTS}"

ckpt_dir = os.path.join("/content/mixture-of-experts-project", "checkpoints", FF_LAYER, run_tag)
metrics_path = os.path.join(ckpt_dir, "metrics.pt")
assert os.path.exists(metrics_path), f"metrics.pt not found at {metrics_path}"

# Load the history dict saved by the trainer (later cells read `hist` too).
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True;
# confirm metrics.pt only holds types that mode accepts.
hist = torch.load(metrics_path, map_location="cpu")

# Pull the four per-epoch curves out of the history as numpy arrays.
train_loss, train_acc, val_loss, val_acc = (
    np.array(hist[key]) for key in ("train_loss", "train_acc", "val_loss", "val_acc")
)

# One figure per metric: loss first, then accuracy, each with train and val curves.
for y_label, title_suffix, train_curve, val_curve in (
    ("loss", "Loss", train_loss, val_loss),
    ("accuracy", "Accuracy", train_acc, val_acc),
):
    plt.figure(figsize=(6, 4))
    plt.plot(train_curve, label="train")
    plt.plot(val_curve, label="val")
    plt.xlabel("epoch")
    plt.ylabel(y_label)
    plt.title(f"{FF_LAYER} — {title_suffix}")
    plt.legend()
    plt.show()


In [None]:
# CELL4: SoftMoE — Expert Utilization (lines + per-epoch bars)
import numpy as np
import matplotlib.pyplot as plt

# Plots SoftMoE expert utilization from the `hist` dict loaded in CELL3:
# one line plot across all epochs, then bar charts at a few snapshot epochs.
if FF_LAYER != "SoftMoE":
    print("Utilization plots are only available for SoftMoE.")
else:
    util_list = hist.get("util_per_epoch", [])
    if len(util_list) == 0:
        print("No utilization recorded. Make sure you trained with SoftMoE.")
    else:
        # Stack the per-epoch entries along a new leading axis.
        # Assumes each entry is a 1-D vector with one value per expert, giving
        # shape (num_epochs, num_experts) — TODO confirm against the trainer.
        util = np.stack(util_list, axis=0)

        # (A) Line plot: evolution over epochs
        plt.figure(figsize=(7,4))
        for i in range(util.shape[1]):
            plt.plot(util[:, i], label=f"expert {i}")
        plt.xlabel("epoch"); plt.ylabel("mean p_i")
        plt.title("Expert Utilization over epochs")
        plt.legend(ncol=2); plt.show()

        # (B) Bars at epoch indices 0, 50, 100 (0-based), clamped to the run.
        # A run of N epochs only has indices 0..N-1, so with the default
        # 100-epoch run index 100 does not exist. Clamp to the last recorded
        # epoch so the final snapshot is still shown instead of being dropped;
        # the set removes duplicates when several requests clamp to one index.
        requested = [0, 50, 100]
        max_idx = util.shape[0] - 1
        picked = sorted({min(e, max_idx) for e in requested})

        for e in picked:
            vals = util[e]  # length = num_experts
            plt.figure(figsize=(6,4))
            plt.bar(np.arange(len(vals)), vals)
            plt.xticks(np.arange(len(vals)), [f"E{i}" for i in range(len(vals))])
            plt.ylim(0, 1)
            plt.ylabel("mean p_i"); plt.xlabel("expert")
            plt.title(f"Expert mean probabilities at epoch index {e} (0-based)")
            plt.show()


In [None]:
# CELL5: SoftMoE — Gating Entropy over epochs
import numpy as np
import matplotlib.pyplot as plt

# Plot the per-epoch gating entropy stored in `hist` (SoftMoE runs only).
if FF_LAYER == "SoftMoE":
    # A missing key falls back to an empty list, which yields an empty array.
    H = np.array(hist.get("entropy_per_epoch", []))
    if H.size:
        plt.figure(figsize=(6, 4))
        plt.plot(H)
        plt.xlabel("epoch")
        plt.ylabel("entropy  H = -Σ p log p")
        plt.title("Gating Entropy over epochs")
        plt.show()
    else:
        print("No entropy recorded. Make sure you trained with SoftMoE.")
else:
    print("Entropy plot is only available for SoftMoE.")
