# Mixture of Experts

In this notebook, you'll build and explore the core mechanics of mixture-of-experts (MoE) layers using toy-scale PyTorch models. No pretrained models or GPUs needed—everything runs in seconds on CPU.

**What you'll do:**
- Build a router from scratch (a single linear layer + softmax + top-k) and observe how different input vectors route to different experts
- Build a complete MoE layer with 4 expert FFNs and a router, run a forward pass on a batch of 8 tokens, and compare active parameters to total parameters
- Simulate expert routing on real tokenized sentences and visualize per-token expert assignments—looking for emergent specialization patterns
- Design an experiment that trains a toy MoE model with and without an auxiliary load-balancing loss, tracking expert utilization to observe router collapse

**For each exercise, PREDICT the output before running the cell.** Wrong predictions are more valuable than correct ones—they reveal gaps in your mental model.

In [None]:
# Setup—self-contained for Google Colab
# No extra pip installs needed—torch and matplotlib are in Colab by default.

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# Reproducible results
torch.manual_seed(42)
np.random.seed(42)

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print('Setup complete.')

---

## Exercise 1: Build a Router from Scratch (Guided)

The router is the new component that MoE adds to the transformer block. From the lesson, you know it is a single linear layer followed by softmax and top-k selection—the same dot-product + softmax pattern as attention, but selecting experts instead of tokens.

In this exercise, you'll build a router for **4 experts** on a toy `d_model=64` hidden state. You'll:
1. Create the router weight matrix `W_router` of shape `(num_experts, d_model)`
2. Compute router logits via a matrix multiply: `router_logits = W_router @ hidden_state` for a single hidden state (the code below uses the equivalent batched form, `inputs @ W_router.T`)
3. Apply softmax to get a probability distribution over experts
4. Select the top-2 experts
5. Visualize how 5 different input vectors route to different experts

**Before running, predict:**
- Will all 5 input vectors route to the same 2 experts, or different ones?
- How confident will the router be? Will the top-2 probabilities be close (e.g., 0.30 and 0.28) or will one dominate (e.g., 0.80 and 0.10)?
- The router is initialized with random weights. Does it still produce *different* routing decisions for different inputs? Why?

In [None]:
# --- The entire router in ~10 lines ---
# Compare to attention: same dot-product + softmax pattern.

d_model = 64
num_experts = 4
k = 2  # top-k: how many experts to activate per token

# The router is ONE linear layer. No bias.
# Each row of W_router is a learned "expert embedding."
# The dot product measures relevance of the token to each expert.
W_router = torch.randn(num_experts, d_model) * 0.1  # small init

# Generate 5 different input vectors (simulating 5 different token hidden states)
inputs = torch.randn(5, d_model)

# Step 1: Router logits — one score per expert, for each input
# Shape: (5, num_experts) = (5, 4)
router_logits = inputs @ W_router.T  # dot product: (5, 64) @ (64, 4) = (5, 4)
print(f'Router logits shape: {router_logits.shape}')
print(f'Router logits (raw scores):\n{router_logits}')
print()

# Step 2: Softmax — convert to probability distribution over experts
# Same softmax you've used dozens of times for attention weights.
router_probs = F.softmax(router_logits, dim=-1)
print(f'Router probabilities (sum to 1 per input):\n{router_probs}')
print(f'Sum per input: {router_probs.sum(dim=-1)}')
print()

# Step 3: Top-k selection — pick the top-2 experts per input
top_k_probs, top_k_indices = torch.topk(router_probs, k=k, dim=-1)
print(f'Top-{k} expert indices per input:\n{top_k_indices}')
print(f'Top-{k} probabilities per input:\n{top_k_probs}')
print()

# Normalize the top-k probabilities so they sum to 1
# (we only use k experts, so renormalize to get valid weights)
top_k_weights = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
print(f'Normalized top-{k} weights (sum to 1):\n{top_k_weights}')

In [None]:
# --- Visualize: router probabilities for all 5 inputs ---
# Bar chart showing the probability distribution over 4 experts for each input.
# The top-2 selected experts are highlighted.

fig, axes = plt.subplots(1, 5, figsize=(14, 3.5), sharey=True)

expert_colors = ['#60a5fa', '#f59e0b', '#34d399', '#a78bfa']  # blue, amber, green, violet

for i, ax in enumerate(axes):
    probs = router_probs[i].detach().numpy()
    selected = set(top_k_indices[i].tolist())

    bars = ax.bar(
        range(num_experts), probs,
        color=[expert_colors[j] if j in selected else '#334155' for j in range(num_experts)],
        edgecolor='white', linewidth=0.5,
    )

    # Annotate probabilities
    for j, (bar, p) in enumerate(zip(bars, probs)):
        label = f'{p:.2f}'
        if j in selected:
            label += ' *'
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                label, ha='center', va='bottom', fontsize=8, color='white')

    ax.set_title(f'Input {i+1}', fontsize=10, fontweight='bold')
    ax.set_xticks(range(num_experts))
    ax.set_xticklabels([f'E{j}' for j in range(num_experts)], fontsize=9)
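    # Leave headroom for the labels; grow the axis if any probability exceeds 0.55
    ax.set_ylim(0, max(0.55, router_probs.max().item() + 0.08))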
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

axes[0].set_ylabel('Router Probability', fontsize=10)
fig.suptitle('Router Probabilities for 5 Different Inputs (top-2 selected = colored, * marked)',
             fontsize=12, fontweight='bold', y=1.04)
plt.tight_layout()
plt.show()

# Summary
print('\nRouting summary:')
for i in range(5):
    experts = top_k_indices[i].tolist()
    weights = top_k_weights[i].tolist()
    print(f'  Input {i+1} -> Expert {experts[0]} (weight {weights[0]:.2f}), '
          f'Expert {experts[1]} (weight {weights[1]:.2f})')

# Count how many unique expert pairs we see
pairs = [tuple(sorted(top_k_indices[i].tolist())) for i in range(5)]
unique_pairs = set(pairs)
print(f'\nUnique expert pairs: {len(unique_pairs)} out of 5 inputs')
print(f'Expert pairs: {pairs}')

**What just happened:** The router—a single matrix multiply + softmax—naturally produces different routing decisions for different inputs. Even with random initialization, different input vectors have different dot products with the expert embeddings (rows of `W_router`), leading to different top-k selections.

This is the same mechanism as attention scores: `Q @ K^T` produces different attention patterns for different query tokens because the dot products depend on the input. Here, `input @ W_router^T` produces different expert scores for different hidden states.

**Key observation:** With random initialization, the probabilities stay fairly flat. The logits have a standard deviation of only about 0.8 (0.1 × √64), so each distribution typically sits closer to uniform (0.25 each) than to one-hot. After training, the router would learn sharper distributions, confidently routing tokens to the most relevant experts. But even before training, the mechanism works: different inputs route differently.
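
The link between logit scale and router confidence can be checked with a few lines of plain Python. This is a minimal sketch: the `base_logits` values and the scale factors are made up for illustration and are not taken from the cells above.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical router logits for one token over 4 experts.
base_logits = [0.8, -0.3, 0.1, -0.6]

# Small scale ~ untrained router; large scale ~ a confident, trained router.
for scale in (0.1, 1.0, 5.0):
    probs = softmax([scale * v for v in base_logits])
    print(f"scale={scale:>4}: " + ", ".join(f"{p:.3f}" for p in probs))

flat = softmax([0.1 * v for v in base_logits])
sharp = softmax([5.0 * v for v in base_logits])
```

The ranking of experts is the same at every scale, so top-k selection does not change. Only the confidence (how far the distribution is from uniform) grows with the magnitude of the logits.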

---

## Exercise 2: MoE Forward Pass on a Toy Model (Supported)

Now build a complete MoE layer: 4 expert FFNs (each with `d_model=64`, `d_ff=256`) plus the router from Exercise 1. You'll run a forward pass on a batch of 8 tokens and compare the output shape, active parameters, and total parameters to a single dense FFN.

The first expert FFN and the router are provided. You'll build the remaining experts and implement the weighted combination of expert outputs.

**Before running, predict:**
- What will the output shape be? (Same as input? Different?)
- How many total parameters will the MoE layer have vs a single dense FFN?
- How many parameters are *active* per token (with top-2 routing)?
- Will the MoE output be identical to a dense FFN output? Why or why not?

In [None]:
# --- A single expert FFN (provided) ---
# Same structure as the FFN in a transformer block:
# Linear(d_model -> d_ff) -> GELU -> Linear(d_ff -> d_model)

class ExpertFFN(nn.Module):
    """A single expert: standard two-layer FFN with GELU."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.gelu(self.w1(x)))


# --- The MoE layer (you complete this) ---

class MoELayer(nn.Module):
    """Mixture of Experts layer: N expert FFNs + a router."""

    def __init__(self, d_model: int, d_ff: int, num_experts: int, top_k: int):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Router: single linear layer (no bias), maps d_model -> num_experts
        self.router = nn.Linear(d_model, num_experts, bias=False)

        # TODO: Create a list of expert FFNs using nn.ModuleList.
        # Each expert is an ExpertFFN(d_model, d_ff).
        # You need num_experts of them.
        # Hint: nn.ModuleList([ExpertFFN(...) for _ in range(...)])
        self.experts = None  # Replace this line

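    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        x: shape (batch_size, d_model)
        Returns: (output, router_probs, top_k_indices) with shapes
        (batch_size, d_model), (batch_size, num_experts), (batch_size, top_k)
        """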
        batch_size = x.shape[0]

        # Step 1: Compute router probabilities
        router_logits = self.router(x)  # (batch_size, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)  # (batch_size, num_experts)

        # Step 2: Select top-k experts
        top_k_probs, top_k_indices = torch.topk(router_probs, k=self.top_k, dim=-1)
        # Normalize top-k weights to sum to 1
        top_k_weights = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Step 3: Run selected experts and combine outputs
        # TODO: For each token in the batch, run the top-k selected experts
        # and compute the weighted sum of their outputs.
        #
        # The result should be:
        #   output[i] = sum(top_k_weights[i, j] * expert_j(x[i]) for j in selected)
        #
        # Approach:
        #   1. Initialize output as zeros: torch.zeros_like(x)
        #   2. Loop over each token i in range(batch_size)
        #   3. For each of the top_k selected experts (j in range(self.top_k)):
        #      - Get the expert index: expert_idx = top_k_indices[i, j].item()
        #      - Get the weight: weight = top_k_weights[i, j]
        #      - Run the expert: expert_output = self.experts[expert_idx](x[i:i+1])
        #      - Add weighted output: output[i:i+1] += weight * expert_output
        #
        # Note: x[i:i+1] keeps the batch dimension (shape [1, d_model]) instead
        # of x[i] which would be (d_model,). The expert expects a batch dim.

        output = None  # Replace this — follow the approach above

        return output, router_probs, top_k_indices


print('Classes defined. Fill in the TODOs and run the next cell.')

<details>
<summary>Solution</summary>

The key insight is that each expert is an independent FFN, and we only run the top-k selected ones per token. The weighted combination mirrors attention's weighted sum of values—but over expert outputs instead of token values.

```python
# In __init__:
self.experts = nn.ModuleList([ExpertFFN(d_model, d_ff) for _ in range(num_experts)])

# In forward, Step 3:
output = torch.zeros_like(x)
for i in range(batch_size):
    for j in range(self.top_k):
        expert_idx = top_k_indices[i, j].item()
        weight = top_k_weights[i, j]
        expert_output = self.experts[expert_idx](x[i:i+1])
        output[i:i+1] += weight * expert_output
```

This is a simple loop implementation. Production MoE uses batched expert execution for efficiency (grouping all tokens assigned to the same expert), but the loop makes the logic clear: for each token, run only the selected experts, weight their outputs, and sum.
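
One way the batched variant might look is sketched below, assuming the same softmax, top-k, and renormalization steps as the loop. The function names `moe_forward_loop` and `moe_forward_batched` are hypothetical, and the experts are small `nn.Sequential` stand-ins rather than the notebook's `ExpertFFN`. The sketch loops over experts instead of tokens, gathers every token that selected a given expert, and runs that expert once on the whole group.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def _route(x, router, top_k):
    """Shared routing step: softmax, top-k, renormalize."""
    probs = F.softmax(router(x), dim=-1)
    top_p, top_i = torch.topk(probs, k=top_k, dim=-1)
    weights = top_p / top_p.sum(dim=-1, keepdim=True)
    return weights, top_i


def moe_forward_loop(x, router, experts, top_k):
    """Reference: per-token loop, one expert call per (token, slot)."""
    weights, top_i = _route(x, router, top_k)
    output = torch.zeros_like(x)
    for i in range(x.shape[0]):
        for j in range(top_k):
            expert = experts[top_i[i, j].item()]
            output[i:i+1] += weights[i, j] * expert(x[i:i+1])
    return output


def moe_forward_batched(x, router, experts, top_k):
    """Grouped variant: one expert call per expert."""
    weights, top_i = _route(x, router, top_k)
    output = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        # Rows (tokens) and columns (top-k slots) where expert e was selected
        token_idx, slot_idx = torch.where(top_i == e)
        if token_idx.numel() == 0:
            continue  # no token chose this expert in this batch
        expert_out = expert(x[token_idx])
        gate = weights[token_idx, slot_idx].unsqueeze(-1)
        output.index_add_(0, token_idx, gate * expert_out)
    return output


torch.manual_seed(0)
d_model, d_ff, num_experts, top_k = 16, 32, 4, 2
router = nn.Linear(d_model, num_experts, bias=False)
experts = nn.ModuleList([
    nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
    for _ in range(num_experts)
])
x = torch.randn(8, d_model)

with torch.no_grad():
    loop_out = moe_forward_loop(x, router, experts, top_k)
    batched_out = moe_forward_batched(x, router, experts, top_k)

print('Outputs match:', torch.allclose(loop_out, batched_out, atol=1e-5))
```

The loop version makes `batch_size * top_k` expert calls, while the grouped version makes at most `num_experts` calls regardless of batch size. Real systems add further machinery (capacity limits, expert parallelism) that this sketch leaves out.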

</details>

### Helper: Complete MoELayer for Remaining Exercises

**Run the cell below** to get a working `MoELayer` and `ExpertFFN` for Exercises 3 and 4. This ensures the remaining exercises work correctly regardless of whether your Exercise 2 implementation has bugs. If you completed Exercise 2 successfully, this cell just redefines the same classes.

In [None]:
# --- Complete ExpertFFN and MoELayer (reference implementation) ---
# Run this cell to ensure a working MoELayer is available for Exercises 3 and 4.

class ExpertFFN(nn.Module):
    """A single expert: standard two-layer FFN with GELU."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.gelu(self.w1(x)))


class MoELayer(nn.Module):
    """Mixture of Experts layer: N expert FFNs + a router."""

    def __init__(self, d_model: int, d_ff: int, num_experts: int, top_k: int):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList([ExpertFFN(d_model, d_ff) for _ in range(num_experts)])

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = x.shape[0]
        router_logits = self.router(x)
        router_probs = F.softmax(router_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(router_probs, k=self.top_k, dim=-1)
        top_k_weights = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(x)
        for i in range(batch_size):
            for j in range(self.top_k):
                expert_idx = top_k_indices[i, j].item()
                weight = top_k_weights[i, j]
                expert_output = self.experts[expert_idx](x[i:i+1])
                output[i:i+1] += weight * expert_output

        return output, router_probs, top_k_indices


print('ExpertFFN and MoELayer defined. Ready for remaining exercises.')

In [None]:
# --- Run the MoE forward pass and compare to a dense FFN ---

d_model = 64
d_ff = 256
num_experts = 4
top_k = 2
batch_size = 8

torch.manual_seed(42)

# Create the MoE layer
moe = MoELayer(d_model, d_ff, num_experts, top_k)

# Create a single dense FFN for comparison
dense_ffn = ExpertFFN(d_model, d_ff)

# Input: batch of 8 tokens, each with d_model=64 features
x = torch.randn(batch_size, d_model)

# Forward pass through MoE
moe_output, moe_probs, moe_indices = moe(x)

# Forward pass through dense FFN
dense_output = dense_ffn(x)

# --- Compare shapes ---
print('=== Shape Comparison ===')
print(f'Input shape:       {x.shape}')
print(f'MoE output shape:  {moe_output.shape}')
print(f'Dense output shape: {dense_output.shape}')
print(f'Same shape? {moe_output.shape == dense_output.shape}')
print()

# --- Compare parameter counts ---
def count_parameters(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

moe_total_params = count_parameters(moe)
dense_params = count_parameters(dense_ffn)
router_params = count_parameters(moe.router)
single_expert_params = count_parameters(moe.experts[0])

# Active params per token: router + top_k experts
active_params_per_token = router_params + top_k * single_expert_params

print('=== Parameter Comparison ===')
print(f'Dense FFN parameters:          {dense_params:>8,}')
print(f'Single expert parameters:      {single_expert_params:>8,}')
print(f'Router parameters:             {router_params:>8,}')
print(f'MoE total parameters:          {moe_total_params:>8,}')
print(f'MoE active params per token:   {active_params_per_token:>8,} (router + {top_k} experts)')
print()
print(f'MoE total / dense:  {moe_total_params / dense_params:.1f}x more total params')
print(f'MoE active / dense: {active_params_per_token / dense_params:.2f}x active params per token')
print()

# --- Routing decisions ---
print('=== Routing Decisions ===')
for i in range(batch_size):
    experts = moe_indices[i].tolist()
    probs = moe_probs[i].detach().numpy()
    print(f'Token {i}: Expert {experts[0]} & Expert {experts[1]}  '
          f'(probs: [{", ".join(f"{p:.3f}" for p in probs)}])')

# Count expert usage
expert_usage = torch.zeros(num_experts)
for i in range(batch_size):
    for j in range(top_k):
        expert_usage[moe_indices[i, j]] += 1

print(f'\nExpert usage across {batch_size} tokens (top-{top_k}):')
for e in range(num_experts):
    bar = '#' * int(expert_usage[e].item())
    print(f'  Expert {e}: {int(expert_usage[e].item()):>2} tokens  {bar}')

**What just happened:** The MoE layer produces the exact same output shape as a dense FFN—`(batch_size, d_model)`. This is critical: the MoE layer is a *drop-in replacement* for the FFN sub-layer. The rest of the transformer block (attention, residual stream, layer norm) is completely unchanged.

The parameter comparison makes the decoupling concrete:
- **Total parameters**: MoE has ~4x more than the dense FFN (4 experts, each exactly the size of the dense FFN, plus the router)
- **Active parameters per token**: ~2x the dense FFN (only 2 of 4 experts fire, plus the tiny router)

Scale this up to Mixtral's dimensions, with 8 experts and top-2 routing, and the model stores ~47B total parameters while activating only ~13B per token. More stored knowledge, at roughly the per-token compute of a ~13B dense model.
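
The toy numbers can also be derived by hand. The sketch below assumes the layer shapes used in this exercise: two `nn.Linear` layers with biases per expert and a bias-free router. The helper names are made up for illustration.

```python
def ffn_params(d_model: int, d_ff: int) -> int:
    """Weights + biases of Linear(d_model -> d_ff) and Linear(d_ff -> d_model)."""
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)


def router_params(d_model: int, num_experts: int) -> int:
    """Bias-free linear router: one d_model-sized row per expert."""
    return d_model * num_experts


d_model, d_ff, num_experts, top_k = 64, 256, 4, 2

dense = ffn_params(d_model, d_ff)
router = router_params(d_model, num_experts)
moe_total = num_experts * dense + router
moe_active = top_k * dense + router

print(f'Dense FFN:  {dense:,}')
print(f'MoE total:  {moe_total:,}  ({moe_total / dense:.2f}x dense)')
print(f'MoE active: {moe_active:,}  ({moe_active / dense:.2f}x dense)')
```

Total parameters grow with `num_experts`, while active parameters grow only with `top_k`. That is the decoupling the comparison illustrates.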

---

## Exercise 3: Visualize Expert Routing on Real Text (Supported)

In this exercise, you'll simulate how a router assigns experts to real tokens. We'll tokenize sentences with a simple word-level tokenizer, embed each token with a small embedding layer (a proxy for real hidden states), and route it through a router. Both the embedding and the router start from random initialization and are trained briefly on a proxy task.

The goal: see per-token routing in action and look for patterns. Do function words cluster? Do domain words cluster? Is the routing different across sentences?

The first sentence ("The mitochondria is the powerhouse of the cell") is fully set up. You extend the pattern to additional sentences.

**Before running, predict:**
- Will identical tokens (e.g., "the" appearing multiple times) always route to the same expert?
- Will tokens from the same sentence all route to the same expert?
- After a few training steps on synthetic data, will you see any grouping patterns emerge?

In [None]:
# --- Tokenizer setup ---
# We use a simple word-level tokenizer for clarity.
# Production models use BPE (byte-pair encoding), but word-level makes
# the routing patterns easier to interpret.

sentences = [
    'The mitochondria is the powerhouse of the cell',
    'The stock market crashed yesterday after the announcement',
    'def forward self x return self linear x',
    'Le chat est sur le tapis',
    'The gradient flows backward through the computation graph',
]

# Build a simple vocabulary from all tokens
all_tokens = []
for sent in sentences:
    all_tokens.extend(sent.lower().split())
vocab = sorted(set(all_tokens))
token_to_id = {t: i for i, t in enumerate(vocab)}
id_to_token = {i: t for t, i in token_to_id.items()}

print(f'Vocabulary size: {len(vocab)}')
print(f'Tokens: {vocab}')

In [None]:
# --- Build a tiny model: embedding + router ---
# We train the router for a few steps on a simple proxy task so that
# routing patterns can emerge. Without training, the random router
# produces near-uniform distributions (as you saw in Exercise 1).

d_model = 32
num_experts = 8
top_k = 2

torch.manual_seed(123)

embedding = nn.Embedding(len(vocab), d_model)
router = nn.Linear(d_model, num_experts, bias=False)

# Simple proxy task: train the router to differentiate token types.
# We create soft target distributions that push function words toward
# some experts and content words toward others. This simulates what
# emerges naturally during real MoE training.

function_words = {'the', 'is', 'of', 'a', 'an', 'and', 'or', 'in', 'on', 'at',
                  'le', 'est', 'sur', 'after', 'self', 'return'}

def get_target_distribution(token: str) -> torch.Tensor:
    """Create a soft target distribution that pushes different token types
    toward different experts. This simulates emergent specialization."""
    dist = torch.ones(num_experts) * 0.05  # small uniform baseline
    if token in function_words:
        dist[0] += 0.5  # function words prefer Expert 0
        dist[1] += 0.3  # and Expert 1
    elif any(c.isdigit() for c in token):
        dist[4] += 0.6  # numbers prefer Expert 4
        dist[5] += 0.2
    else:
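        # Content words get distributed across experts 2-7 based on a
        # deterministic checksum of the token (simulating learned specialization).
        # Python's built-in hash() is salted per process for strings, so it would
        # assign different experts after every kernel restart.
        h = sum(ord(c) for c in token) % 6
        dist[2 + h] += 0.5
        dist[2 + (h + 1) % 6] += 0.3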
    return dist / dist.sum()

# Train the router for a few steps
optimizer = torch.optim.Adam(list(embedding.parameters()) + list(router.parameters()), lr=0.01)

for step in range(200):
    total_loss = 0.0
    for token_str, token_id in token_to_id.items():
        x = embedding(torch.tensor([token_id]))
        logits = router(x)  # (1, num_experts)
        target = get_target_distribution(token_str).unsqueeze(0)  # (1, num_experts)
        loss = F.kl_div(F.log_softmax(logits, dim=-1), target, reduction='batchmean')
        total_loss += loss.item()
        loss.backward()

    optimizer.step()
    optimizer.zero_grad()

    if step % 50 == 0:
        print(f'Step {step:>3}: avg loss = {total_loss / len(token_to_id):.4f}')

print('\nRouter training complete.')

In [None]:
# --- Visualize routing for the first sentence (fully set up) ---

expert_colors_8 = [
    '#60a5fa',  # E0 - blue (function words)
    '#38bdf8',  # E1 - light blue (function words)
    '#f59e0b',  # E2 - amber
    '#34d399',  # E3 - emerald
    '#a78bfa',  # E4 - violet
    '#fb923c',  # E5 - orange
    '#f87171',  # E6 - rose
    '#e879f9',  # E7 - fuchsia
]


def visualize_routing(sentence: str, ax: plt.Axes) -> list[int]:
    """Route each token in the sentence and visualize expert assignments."""
    tokens = sentence.lower().split()
    token_ids = torch.tensor([token_to_id[t] for t in tokens])

    with torch.no_grad():
        hidden = embedding(token_ids)  # (seq_len, d_model)
        logits = router(hidden)  # (seq_len, num_experts)
        probs = F.softmax(logits, dim=-1)
        _, top_indices = torch.topk(probs, k=1, dim=-1)  # top-1 for visualization clarity

    expert_ids = top_indices.squeeze(-1).tolist()

    # Draw colored boxes for each token
    x_pos = 0
    for i, (token, expert_id) in enumerate(zip(tokens, expert_ids)):
        box_width = len(token) * 0.12 + 0.15
        color = expert_colors_8[expert_id]
        rect = mpatches.FancyBboxPatch(
            (x_pos, 0.2), box_width, 0.5,
            boxstyle='round,pad=0.05',
            facecolor=color, edgecolor='white', linewidth=0.8, alpha=0.7
        )
        ax.add_patch(rect)
        ax.text(x_pos + box_width / 2, 0.45, token,
                ha='center', va='center', fontsize=9, color='white', fontweight='bold')
        ax.text(x_pos + box_width / 2, 0.05, f'E{expert_id}',
                ha='center', va='center', fontsize=7, color=color)
        x_pos += box_width + 0.05

    ax.set_xlim(-0.1, x_pos + 0.1)
    ax.set_ylim(-0.1, 0.85)
    ax.set_aspect('equal')
    ax.axis('off')

    return expert_ids


# Plot the first sentence
fig, ax = plt.subplots(figsize=(14, 1.5))
expert_ids = visualize_routing(sentences[0], ax)
ax.set_title(f'"{sentences[0]}"', fontsize=11, fontweight='bold', pad=10)
plt.tight_layout()
plt.show()

# Analyze
tokens_0 = sentences[0].lower().split()
print('Token -> Expert mapping:')
for t, e in zip(tokens_0, expert_ids):
    category = 'function' if t in function_words else 'content'
    print(f'  "{t}" -> Expert {e}  ({category} word)')

In [None]:
# --- TODO: Visualize routing for ALL 5 sentences ---
# Create a figure with 5 subplots (one per sentence) stacked vertically.
# Use the visualize_routing function from above.
#
# After plotting, analyze the routing patterns:
#   1. Do function words consistently route to the same expert(s)?
#   2. Do content words from different domains route differently?
#   3. Does "the" always go to the same expert regardless of sentence?
#
# Hints:
#   fig, axes = plt.subplots(len(sentences), 1, figsize=(14, len(sentences) * 1.8))
#   for i, (sent, ax) in enumerate(zip(sentences, axes)):
#       expert_ids = visualize_routing(sent, ax)
#       ax.set_title(f'Sentence {i+1}: "{sent}"', ...)
#
# Then print a cross-sentence analysis:
#   - For each unique token that appears in multiple sentences,
#     check if it routes to the same expert.
#   - Count how many tokens per expert across all sentences.

# YOUR CODE HERE (15-30 lines)


<details>
<summary>Solution</summary>

The key insight is that identical tokens should route consistently (same embedding produces same router scores), while different tokens route based on their learned embeddings. Function words cluster, content words spread across experts.

```python
fig, axes = plt.subplots(len(sentences), 1, figsize=(14, len(sentences) * 1.8))

all_routing = {}  # token -> list of expert ids across sentences
expert_counts = torch.zeros(num_experts)

for i, (sent, ax) in enumerate(zip(sentences, axes)):
    expert_ids = visualize_routing(sent, ax)
    ax.set_title(f'Sentence {i+1}: "{sent}"', fontsize=10, fontweight='bold', pad=8)

    # Track routing per token
    for token, eid in zip(sent.lower().split(), expert_ids):
        if token not in all_routing:
            all_routing[token] = []
        all_routing[token].append(eid)
        expert_counts[eid] += 1

plt.suptitle('Per-Token Expert Routing Across 5 Sentences', fontsize=13, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# Cross-sentence analysis
print('\n=== Cross-Sentence Routing Consistency ===')
for token, expert_ids in sorted(all_routing.items()):
    if len(expert_ids) > 1:
        consistent = len(set(expert_ids)) == 1
        category = 'function' if token in function_words else 'content'
        print(f'  "{token}" ({category}): Expert(s) {expert_ids} '
              f'{"-- consistent" if consistent else "-- VARIES"}')

print(f'\n=== Expert Utilization (across all tokens) ===')
total_tokens = expert_counts.sum().item()
for e in range(num_experts):
    count = int(expert_counts[e].item())
    pct = count / total_tokens * 100
    bar = '#' * int(pct / 2)
    print(f'  Expert {e}: {count:>3} tokens ({pct:>5.1f}%)  {bar}')
```

**Expected findings:** Identical tokens ("the" appearing in multiple sentences) should always route to the same expert because they produce the same embedding, which produces the same router scores. Function words should cluster toward Experts 0-1, while content words spread across Experts 2-7. This mirrors the per-token routing from the lesson diagram: specialization is emergent and per-token, not per-topic.

</details>

**What just happened:** You saw per-token routing in action on real sentences. The key observations from the lesson are visible even in this toy model:

1. **Different tokens in the same sentence route to different experts.** "The" and "mitochondria" do not share an expert.
2. **Identical tokens route consistently.** The same token always produces the same embedding, so the router always assigns it the same way.
3. **Function words cluster.** Words like "the," "is," "of" tend to share experts, while content words spread across different experts.
4. **Specialization is per-token, not per-topic.** There is no single "biology expert"—different content words route to different experts based on their learned representations.

In a real MoE model (like Mixtral), these patterns emerge from training on natural language. Early layers tend to show more syntactic specialization, while later layers tend to show more semantic patterns. The boundaries do not map cleanly to human-interpretable categories.
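
One caveat about observation 2: in this toy setup the router sees only the token embedding. In a transformer, the router input is a hidden state that attention has already mixed with context, so the same word can route differently in different sentences. The sketch below illustrates this with hand-picked 2-d embeddings, an identity router, and a crude averaging step standing in for attention. All values and function names are hypothetical.

```python
# Hypothetical 2-d embeddings and a 2-expert router, chosen by hand.
embeddings = {
    'the': [1.0, 0.8],
    'cell': [0.0, 2.0],
    'market': [2.0, 0.0],
}
router_weights = [[1.0, 0.0],   # expert 0 scores the first dimension
                  [0.0, 1.0]]   # expert 1 scores the second dimension


def route(hidden):
    """Top-1 routing: index of the expert with the highest dot product."""
    scores = [sum(w * h for w, h in zip(row, hidden)) for row in router_weights]
    return scores.index(max(scores))


def contextual_hidden(token, sentence):
    """Crude stand-in for attention: mix the token with the sentence mean."""
    dims = range(len(embeddings[token]))
    mean = [sum(embeddings[t][d] for t in sentence) / len(sentence) for d in dims]
    return [0.5 * embeddings[token][d] + 0.5 * mean[d] for d in dims]


sent_a = ['the', 'cell']
sent_b = ['the', 'market']

static_route = route(embeddings['the'])              # embedding only
ctx_route_a = route(contextual_hidden('the', sent_a))
ctx_route_b = route(contextual_hidden('the', sent_b))

print(f'"the" by embedding alone:  Expert {static_route}')
print(f'"the" in "the cell":       Expert {ctx_route_a}')
print(f'"the" in "the market":     Expert {ctx_route_b}')
```

With embedding-only routing, "the" is assigned the same expert everywhere. Once context is mixed in, its assignment depends on its neighbors.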

---

## Exercise 4: Router Collapse Experiment (Independent)

The lesson described router collapse: without load balancing, one expert dominates and others atrophy. The positive feedback loop—popular expert gets more gradient updates, improves, gets even more tokens—causes the model to degenerate to approximately a dense model.
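
The feedback loop can be illustrated without any neural network. The sketch below is a deliberately simplified toy, not the experiment you are asked to build. It assumes each expert's router score grows in proportion to the share of tokens it currently receives. The starting logits and `learning_rate` are made-up values.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical starting point: expert 0 has a tiny head start.
logits = [0.1, 0.0, 0.0, 0.0]
learning_rate = 1.0  # assumed; large so the effect shows within 50 steps

history = []
for step in range(51):
    shares = softmax(logits)  # share of tokens each expert receives
    history.append(shares)
    # Assumption: an expert's score improves in proportion to its token share.
    logits = [l + learning_rate * s for l, s in zip(logits, shares)]

for step in (0, 10, 20, 50):
    print(f'step {step:>2}: ' + ', '.join(f'{s:.3f}' for s in history[step]))
```

Under this assumption, the gap between expert 0 and the others widens at every step, so an initial advantage of 0.1 in logit space ends with expert 0 receiving essentially all tokens.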

**Your task:** Train a toy MoE model (4 experts, tiny dataset) with and without an auxiliary load-balancing loss. Track expert utilization over training steps. Plot the distribution of tokens across experts at different checkpoints.

**Requirements:**
1. Create a simple regression or classification task (e.g., learn a function from random inputs to outputs)
2. Build a small MoE model using the `MoELayer` from Exercise 2
3. Train it **without** an auxiliary loss—observe collapse
4. Train it **with** an auxiliary loss that penalizes uneven expert utilization
5. Plot expert utilization at steps 0, 100, 500, 1000 for both conditions

**The auxiliary loss:** A simple approach is to compute the fraction of tokens sent to each expert in a batch, then penalize deviation from uniform. If `f_i` is the fraction of tokens routed to expert `i`, the load-balancing loss is:

```
L_balance = num_experts * sum(f_i * P_i)
```

where `P_i` is the mean router probability for expert `i` across the batch. This product is minimized when both `f_i` and `P_i` are uniform.
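
A small worked example can make the formula concrete. The sketch below uses plain Python lists and two hand-made batches of 4 tokens with top-2 routing. The batches and the helper name `load_balance_loss` are hypothetical. It counts `f_i` as the fraction of tokens that picked expert `i` among their top-2, so `f` sums to 2 rather than 1.

```python
def load_balance_loss(router_probs, top_k_indices, num_experts):
    """L_balance = num_experts * sum_i(f_i * P_i) on plain Python lists."""
    n = len(router_probs)
    # f_i: fraction of tokens whose top-k selection includes expert i
    f = [sum(1 for chosen in top_k_indices if e in chosen) / n
         for e in range(num_experts)]
    # P_i: mean router probability assigned to expert i
    P = [sum(row[e] for row in router_probs) / n for e in range(num_experts)]
    return num_experts * sum(fi * pi for fi, pi in zip(f, P))


num_experts = 4

# Balanced batch: uniform probabilities, every expert selected by 2 of 4 tokens.
balanced_probs = [[0.25, 0.25, 0.25, 0.25] for _ in range(4)]
balanced_idx = [(0, 1), (2, 3), (0, 2), (1, 3)]

# Collapsed batch: every token prefers experts 0 and 1.
collapsed_probs = [[0.6, 0.3, 0.05, 0.05] for _ in range(4)]
collapsed_idx = [(0, 1)] * 4

balanced_loss = load_balance_loss(balanced_probs, balanced_idx, num_experts)
collapsed_loss = load_balance_loss(collapsed_probs, collapsed_idx, num_experts)

print(f'Balanced:  {balanced_loss:.2f}')
print(f'Collapsed: {collapsed_loss:.2f}')
```

With this convention the balanced batch scores 2.0 (equal to `top_k`) and the collapsed batch scores 3.6, so adding the term to the training loss pushes the router away from the collapsed pattern.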

**No skeleton is provided.** Design the experiment yourself. The solution is in the `<details>` block below.

In [None]:
# Your experiment here.
#
# 1. Define a simple toy task (e.g., random regression: x -> y)
# 2. Build a model with an MoE layer (reuse MoELayer or build a similar one)
# 3. Train WITHOUT auxiliary loss, tracking expert utilization per step
# 4. Train WITH auxiliary loss, tracking expert utilization per step
# 5. Plot expert utilization over time for both conditions
#
# Tips:
# - Use a small model: d_model=32, d_ff=64, 4 experts, top-2
# - Track utilization as: count how many tokens each expert is selected for
# - The auxiliary loss weight (alpha) matters: too small and collapse still happens,
#   too large and it overrides the task loss. Start with alpha=0.1.
# - Train for 1000 steps with batches of 64.
# - Record utilization at steps 0, 50, 100, 200, 500, 1000.



<details>
<summary>Solution</summary>

**Design rationale:** We create a simple regression task (learn a random linear mapping). The MoE model has 4 experts with top-2 routing. We train two copies: one with only the task loss, one with task loss + auxiliary load-balancing loss. We track which experts are selected at each step.

```python
class ToyMoEModel(nn.Module):
    """Tiny model: input projection -> MoE layer -> output projection."""

    def __init__(self, d_in: int, d_model: int, d_ff: int,
                 d_out: int, num_experts: int, top_k: int):
        super().__init__()
        self.input_proj = nn.Linear(d_in, d_model)
        self.moe = MoELayer(d_model, d_ff, num_experts, top_k)
        self.output_proj = nn.Linear(d_model, d_out)

    def forward(self, x: torch.Tensor):
        h = F.gelu(self.input_proj(x))
        moe_out, router_probs, top_k_indices = self.moe(h)
        out = self.output_proj(moe_out)
        return out, router_probs, top_k_indices


def compute_load_balance_loss(
    router_probs: torch.Tensor,
    top_k_indices: torch.Tensor,
    num_experts: int,
) -> torch.Tensor:
    """Auxiliary loss penalizing uneven expert utilization.

    router_probs: (batch_size, num_experts)—softmax probabilities
    top_k_indices: (batch_size, top_k)—selected expert indices
    """
    batch_size = router_probs.shape[0]

    # f_i: fraction of tokens that selected expert i among their top-k.
    # Build a multi-hot mask of selected experts, then average over the batch.
    # With top-k routing each token selects k experts, so f sums to top_k.
    expert_mask = torch.zeros(batch_size, num_experts, device=router_probs.device)
    expert_mask.scatter_(1, top_k_indices, 1.0)
    f = expert_mask.mean(dim=0)  # (num_experts,)

    # P_i: mean router probability for expert i
    P = router_probs.mean(dim=0)  # (num_experts,)

    # Auxiliary loss: num_experts * sum(f_i * P_i)
    # At perfect balance (f_i = top_k / num_experts, P_i = 1 / num_experts)
    # this equals top_k. f is not differentiable, so gradients flow through P.
    return num_experts * (f * P).sum()


def get_expert_utilization(
    top_k_indices: torch.Tensor,
    num_experts: int,
) -> list[float]:
    """Compute each expert's share of all top-k assignments (shares sum to 1)."""
    # bincount tallies how often each expert index appears across all tokens
    counts = torch.bincount(top_k_indices.flatten(), minlength=num_experts).float()
    total = counts.sum().item()
    if total == 0:
        return [0.0] * num_experts
    return (counts / total).tolist()


# --- Training setup ---
d_in = 16
d_model = 32
d_ff = 64
d_out = 8
num_experts = 4
top_k = 2
batch_size = 64
num_steps = 1000
alpha = 0.1  # weight for auxiliary loss

# Random regression target
torch.manual_seed(42)
target_W = torch.randn(d_in, d_out) * 0.5

# Checkpoints to record
checkpoints = [0, 50, 100, 200, 500, 1000]


def train_model(use_aux_loss: bool, seed: int = 42):
    """Train a toy MoE model, return utilization history."""
    torch.manual_seed(seed)
    model = ToyMoEModel(d_in, d_model, d_ff, d_out, num_experts, top_k)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    utilization_history = {}  # step -> [frac_expert_0, frac_expert_1, ...]
    losses = []

    for step in range(num_steps + 1):
        # Generate random batch
        x = torch.randn(batch_size, d_in)
        y = x @ target_W  # target output

        # Forward pass
        pred, router_probs, top_k_indices = model(x)

        # Task loss
        task_loss = F.mse_loss(pred, y)

        # Auxiliary loss
        if use_aux_loss:
            aux_loss = compute_load_balance_loss(router_probs, top_k_indices, num_experts)
            total_loss = task_loss + alpha * aux_loss
        else:
            total_loss = task_loss

        # Record utilization at checkpoints
        if step in checkpoints:
            util = get_expert_utilization(top_k_indices, num_experts)
            utilization_history[step] = util

        # Backward pass
        if step < num_steps:  # don't optimize after the last step
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        losses.append(task_loss.item())

    return utilization_history, losses


# Train both conditions
print('Training WITHOUT auxiliary loss...')
util_no_aux, losses_no_aux = train_model(use_aux_loss=False)
print('Training WITH auxiliary loss...')
util_with_aux, losses_with_aux = train_model(use_aux_loss=True)
print('Done.\n')

# --- Plot expert utilization over training ---
fig, axes = plt.subplots(2, len(checkpoints), figsize=(16, 6), sharey=True)

expert_colors_4 = ['#60a5fa', '#f59e0b', '#34d399', '#a78bfa']

for col, step in enumerate(checkpoints):
    # No auxiliary loss
    ax_top = axes[0, col]
    util = util_no_aux[step]
    ax_top.bar(range(num_experts), util, color=expert_colors_4,
               edgecolor='white', linewidth=0.5)
    ax_top.set_title(f'Step {step}', fontsize=9)
    ax_top.set_ylim(0, 1.0)
    ax_top.axhline(y=1/num_experts, color='white', linestyle='--',
                   alpha=0.3, linewidth=0.8)
    ax_top.set_xticks(range(num_experts))
    ax_top.set_xticklabels([f'E{i}' for i in range(num_experts)], fontsize=8)
    ax_top.spines['top'].set_visible(False)
    ax_top.spines['right'].set_visible(False)
    if col == 0:
        ax_top.set_ylabel('No Aux Loss\nToken Fraction', fontsize=9)

    # With auxiliary loss
    ax_bot = axes[1, col]
    util = util_with_aux[step]
    ax_bot.bar(range(num_experts), util, color=expert_colors_4,
               edgecolor='white', linewidth=0.5)
    ax_bot.set_ylim(0, 1.0)
    ax_bot.axhline(y=1/num_experts, color='white', linestyle='--',
                   alpha=0.3, linewidth=0.8)
    ax_bot.set_xticks(range(num_experts))
    ax_bot.set_xticklabels([f'E{i}' for i in range(num_experts)], fontsize=8)
    ax_bot.spines['top'].set_visible(False)
    ax_bot.spines['right'].set_visible(False)
    if col == 0:
        ax_bot.set_ylabel('With Aux Loss\nToken Fraction', fontsize=9)

fig.suptitle('Expert Utilization Over Training: Collapse vs Balanced',
             fontsize=13, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# --- Summary statistics ---
print('\nFinal expert utilization (step 1000):')
print('  Without aux loss:', [f'{u:.2f}' for u in util_no_aux[1000]])
print('  With aux loss:   ', [f'{u:.2f}' for u in util_with_aux[1000]])

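# Compute imbalance metric: max utilization / uniform share.
# Top-2 routing picks two distinct experts per token, so no expert can exceed
# 50% of assignments and the worst case here is 0.5 / 0.25 = 2.0x.
uniform = 1 / num_experts
imbalance_no_aux = max(util_no_aux[1000]) / uniform
imbalance_with_aux = max(util_with_aux[1000]) / uniform
print('\n  Imbalance ratio (1.0 = perfectly uniform, 2.0 = worst case for top-2):')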
print(f'    Without aux: {imbalance_no_aux:.2f}x')
print(f'    With aux:    {imbalance_with_aux:.2f}x')

# Plot training losses
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(losses_no_aux, color='#f87171', alpha=0.7, linewidth=0.8, label='No aux loss')
ax.plot(losses_with_aux, color='#34d399', alpha=0.7, linewidth=0.8, label='With aux loss')
ax.set_xlabel('Step', fontsize=11)
ax.set_ylabel('Task Loss (MSE)', fontsize=11)
ax.set_title('Training Loss: Collapse vs Balanced', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
plt.show()

print('\nKey insight: without the auxiliary loss, the router tends to collapse')
print('onto a favored pair of experts. The other experts get fewer tokens, fewer')
print('gradient updates, and fall further behind. With the auxiliary loss,')
print('utilization stays more uniform, allowing all experts to learn.')
print('Load balancing is not just an efficiency trick—it prevents training collapse.')
```

**Expected findings:**
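- **Without auxiliary loss:** Expert utilization becomes increasingly skewed over training. Because top-2 routing sends every token to two distinct experts, no single expert can exceed 50% of assignments, so collapse shows up as two experts each approaching 50% by step 1000. The others receive few tokens and are effectively dead.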
- **With auxiliary loss:** Utilization stays roughly uniform (close to 25% each with 4 experts). The dashed line at 25% represents perfect balance.
- **Training loss:** Both conditions converge, but the balanced model may converge slightly better because all experts contribute useful computation.

The collapse is the positive feedback loop from the lesson: a popular expert gets more gradient updates, improves, and so gets even more tokens. The auxiliary loss breaks this loop by gently penalizing imbalance.
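
To make the penalty concrete, here is a minimal sketch that evaluates the same `num_experts * sum(f_i * P_i)` formula on two hand-built routing patterns. The tensors `balanced_probs`, `collapsed_probs`, and their index tensors are made-up illustrations, not outputs of the trained model, and `load_balance_loss` is a standalone restatement of `compute_load_balance_loss` so the cell runs on its own.

```python
import torch


def load_balance_loss(router_probs, top_k_indices, num_experts):
    """Standalone copy of the auxiliary loss: num_experts * sum(f_i * P_i)."""
    mask = torch.zeros(router_probs.shape[0], num_experts)
    mask.scatter_(1, top_k_indices, 1.0)   # 1.0 where an expert was selected
    f = mask.mean(dim=0)                   # fraction of tokens selecting each expert
    P = router_probs.mean(dim=0)           # mean router probability per expert
    return num_experts * (f * P).sum()


num_experts, top_k, n_tokens = 4, 2, 8

# Balanced (hypothetical): uniform probabilities, top-2 picks cover all experts evenly
balanced_probs = torch.full((n_tokens, num_experts), 0.25)
balanced_idx = torch.tensor([[0, 1], [2, 3]] * 4)

# Collapsed (hypothetical): every token puts most of its mass on experts 0 and 1
collapsed_probs = torch.tensor([[0.45, 0.45, 0.05, 0.05]]).repeat(n_tokens, 1)
collapsed_idx = torch.tensor([[0, 1]] * n_tokens)

balanced = load_balance_loss(balanced_probs, balanced_idx, num_experts)
collapsed = load_balance_loss(collapsed_probs, collapsed_idx, num_experts)

print(f'Balanced routing:  aux loss = {balanced.item():.2f}')
print(f'Collapsed routing: aux loss = {collapsed.item():.2f}')
```

With 4 experts and top-2 routing, the balanced pattern gives 2.0 (equal to `top_k`) and the collapsed one gives 3.6, so the optimizer is rewarded for moving probability mass away from the overloaded experts.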

</details>

---

## Key Takeaways

1. **The router is a single linear layer with softmax—the same dot-product + softmax pattern as attention.** One matrix multiplication produces expert scores, softmax converts to probabilities, top-k selects which experts to activate. Simpler than a single attention head.

2. **The MoE layer is a drop-in replacement for the FFN.** Same input shape, same output shape. The rest of the transformer block (attention, residual stream, layer norm) is unchanged. More total parameters, but only top-k experts activate per token.

3. **Expert specialization is emergent and per-token.** Different tokens in the same sentence route to different experts. Function words cluster, content words spread across experts. The patterns are learned, not designed—there is no "biology expert" in any clean sense.

4. **Without load balancing, the router collapses.** A few experts dominate (with top-k routing, at least k of them, since every token must pick k), the others atrophy, and the model degenerates to a dense model with wasted parameters. The auxiliary loss breaks the positive feedback loop by penalizing uneven utilization.

5. **MoE decouples total parameters from per-token compute.** Four experts means ~4x the FFN parameters stored but only ~2x active per token (with top-2). Scale to Mixtral 8x7B: 8 experts, top-2 routing, ~47B total, ~13B active. More stored knowledge for roughly the per-token compute of a 13B dense model. The library got bigger, but you still only talk to two librarians.
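
Takeaway 5 can be checked by counting parameters directly. The sketch below assumes each expert is a two-layer FFN (`Linear -> GELU -> Linear`, with biases) and picks `d_ff = 256` purely for illustration; the notebook's own `MoELayer` may differ in detail, and `make_expert` and `count_params` are hypothetical helper names.

```python
import torch.nn as nn


def count_params(module: nn.Module) -> int:
    """Total number of scalar parameters in a module."""
    return sum(p.numel() for p in module.parameters())


# Toy sizes: d_model and num_experts match Exercise 1; d_ff is an assumption
d_model, d_ff, num_experts, top_k = 64, 256, 4, 2


def make_expert() -> nn.Module:
    # Assumed expert shape: a standard two-layer FFN
    return nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.GELU(),
        nn.Linear(d_ff, d_model),
    )


experts = nn.ModuleList([make_expert() for _ in range(num_experts)])
router = nn.Linear(d_model, num_experts, bias=False)

per_expert = count_params(experts[0])
total = count_params(experts) + count_params(router)
# Only top_k experts run per token; the router always runs
active = top_k * per_expert + count_params(router)

print(f'Per expert:       {per_expert:,}')
print(f'Total (stored):   {total:,}  ({total / per_expert:.2f}x one FFN)')
print(f'Active per token: {active:,}  ({active / per_expert:.2f}x one FFN)')
```

Under these assumptions one expert has 33,088 parameters, the layer stores 132,608, and a single token touches 66,432. The router adds only 256 parameters, which is why the ratios land so close to 4x and 2x.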