- What does expert utilization look like as we go across the layers?
- How does expert utilization change as we train on more tokens?

In [7]:
import torch
import numpy as np
from math import ceil
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pathlib import Path
from collections import defaultdict
from typing import DefaultDict

# Roughly 1M tokens from the test subset got selected.
# { checkpoint step -> { layer number -> { expert id -> active counts } } }
record: DefaultDict[int, DefaultDict[int, DefaultDict[int, int]]] = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
model_dump = Path("/tmp/slurm-31207/actives")

# Expert ids are placed row-major on a GRID_SIDE x GRID_SIDE heatmap, so ids
# must stay below GRID_SIDE ** 2. Layers are drawn PLOT_COLS per figure row.
GRID_SIDE = 8
PLOT_COLS = 4
# Only the first MAX_PROB_DUMPS files (in sorted path order) of each layer are read.
MAX_PROB_DUMPS = 64

# The output directory is the same for every checkpoint, so create it once.
# NOTE: the mode is written in octal; the previous decimal literal 500 was
# actually 0o764, which is very unlikely to have been the intended permission.
dmp = Path("figures/expert-utilization-across-layers/flame-moe-721m")
dmp.mkdir(mode=0o755, parents=True, exist_ok=True)

for ckpt_dump in tqdm(sorted(model_dump.iterdir()), desc="Checkpoint"):
    ckpt_step = int(ckpt_dump.name)
    layers = record[ckpt_step]

    # Gathering
    for layer_dump in tqdm(sorted(ckpt_dump.iterdir()), desc="Layer", leave=False):
        layer_number = int(layer_dump.name)
        for probs_dump in tqdm(sorted(layer_dump.iterdir())[:MAX_PROB_DUMPS], desc="Probs", leave=False):
            probs = torch.load(probs_dump, map_location="cpu")
            # Column 1 of nonzero() holds the expert index of every non-zero
            # entry (probs is expected to be 2-D with experts on the second
            # dimension). Aggregate per file instead of incrementing the
            # record once per non-zero element.
            expert_ids, hits = probs.nonzero()[:, 1].unique(return_counts=True)
            for expert_id, hit in zip(expert_ids.tolist(), hits.tolist()):
                layers[layer_number][expert_id] += hit

    # Plotting
    if not layers:
        # Nothing was recorded for this checkpoint; plt.subplots() rejects
        # a grid with zero rows, so there is nothing to draw.
        continue

    layer_numbers = sorted(layers)
    nrows = ceil(len(layer_numbers) / PLOT_COLS)
    fig, axs = plt.subplots(nrows, PLOT_COLS, figsize=(16, 16))
    axs = axs.flatten()

    for ax, layer_number in zip(axs, layer_numbers):
        heatmap_data = np.zeros((GRID_SIDE, GRID_SIDE), dtype=int)
        for expert_id, active_counts in layers[layer_number].items():
            # divmod yields (row, col) for the row-major layout.
            heatmap_data[divmod(expert_id, GRID_SIDE)] = active_counts

        ax.imshow(heatmap_data, cmap="viridis")
        ax.set_title(f"Layer {layer_number}")
        ax.set_xticks([])
        ax.set_yticks([])

        # Overlay the raw count on every cell.
        for (y, x), val in np.ndenumerate(heatmap_data):
            ax.text(x, y, f"{val}", ha='center', va='center', fontsize=6, color="white")

    # Hide only the trailing subplots that have no layer assigned. The cutoff
    # is the number of layers in this checkpoint, not the number of
    # checkpoints seen so far.
    for ax in axs[len(layer_numbers):]:
        ax.axis('off')

    fig.suptitle(f"Expert Utilization Heatmaps Across Layers (Step {ckpt_step})", fontsize=20)
    fig.tight_layout()

    # Operate on this figure explicitly rather than on pyplot's current figure.
    fig.savefig(dmp / f"{ckpt_step}.png", dpi=300)
    plt.close(fig)


Checkpoint:   0%|          | 0/6 [00:00<?, ?it/s]

Layer:   0%|          | 0/11 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Layer:   0%|          | 0/11 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Layer:   0%|          | 0/11 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Layer:   0%|          | 0/11 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Layer:   0%|          | 0/11 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Layer:   0%|          | 0/11 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]

Probs:   0%|          | 0/64 [00:00<?, ?it/s]