# Mixture of Experts (MoE) — A Beginner's Walkthrough

*A companion to the [KV Cache Walkthrough](kv_cache_always_on.ipynb) and [LoRA Notebook](lora_on_device.ipynb) — based on Sebastian Raschka's ["Build a Large Language Model From Scratch"](https://www.manning.com/books/build-a-large-language-model-from-scratch).*

In the KV cache notebook we optimized **inference speed** (don't recompute old K,V). In the LoRA notebook we optimized **customization** (tiny adapters instead of full fine-tuning). This notebook tackles a different question:

> **How do you make a model smarter without making every forward pass slower?**

The answer is **Mixture of Experts (MoE)**, the architecture behind models like Mixtral 8x7B and DeepSeek-V2 (and, reportedly, GPT-4).

### The core idea

Instead of one large FeedForward network that processes every token, MoE has **multiple smaller "expert" networks** and a **router (gate)** that picks which experts to use for each token. Most experts are idle on any given token — so you get the knowledge capacity of a huge model with the compute cost of a small one.

```
Standard FFN (Dense):                    MoE FFN (Sparse):
                                         
token → [  BIG FFN  ] → output           token → [Router] → picks top-2 experts
         (always runs)                                │
                                              ┌──────┼──────────┐
                                              ▼      ▼          │
                                         [Expert 1][Expert 3]  [Expert 2,4,5,6,7,8]
                                          (active)  (active)    (IDLE — not computed!)
                                              │      │
                                              └──┬───┘
                                                 ▼
                                          weighted sum → output
```

> **How to use this notebook:** Run cells top-to-bottom (`Shift+Enter`). We build a working MoE FeedForward layer, compare it to the standard dense FFN from Chapter 4, and see how it fits into the GPT architecture.

In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

## Quick Recap: The Standard (Dense) FeedForward

In Chapter 4, we built this FeedForward network that sits inside every TransformerBlock:

```
Input x (768) → Linear (768 → 3072) → GELU → Linear (3072 → 768) → Output (768)
                     expand 4x                    contract back
```

Every token goes through the **same** FFN, using **all** the parameters. This is called a **dense** architecture — dense because every parameter is active on every forward pass.

The problem: if you want a smarter model, you make the FFN bigger (wider or deeper). But bigger FFN = more computation on **every single token**. This gets expensive fast.
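
To make that cost concrete, here is a small sketch that counts the parameters of a two-layer FFN as a function of its hidden width. The helper name `dense_ffn_params` is ours rather than the book's, and it assumes both linear layers carry biases, as in the Chapter 4 `FeedForward`.

```python
def dense_ffn_params(emb_dim: int, hidden_dim: int) -> int:
    """Parameter count of Linear(emb, hidden) -> GELU -> Linear(hidden, emb), with biases."""
    up = emb_dim * hidden_dim + hidden_dim      # first layer: weights + bias
    down = hidden_dim * emb_dim + emb_dim       # second layer: weights + bias
    return up + down

emb_dim = 768
for mult in (4, 8, 16):
    n = dense_ffn_params(emb_dim, emb_dim * mult)
    print(f"hidden = {mult:>2}x emb_dim: {n:>12,} parameters, all used by every token")
```

With the 4x width this gives 4,722,432 parameters. Doubling the width roughly doubles both the parameter count and the work done for each token.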

In [2]:
class GELU(nn.Module):
    """Gaussian Error Linear Unit — smooth activation function (Ch 4)."""
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


class FeedForward(nn.Module):
    """Standard dense FFN from Chapter 4.
    Every token uses ALL parameters on every forward pass."""
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], cfg["hidden_dim"]),   # 768 → 3072
            GELU(),
            nn.Linear(cfg["hidden_dim"], cfg["emb_dim"]),   # 3072 → 768
        )

    def forward(self, x):
        return self.layers(x)


# Quick parameter count
cfg_dense = {"emb_dim": 768, "hidden_dim": 768 * 4}
dense_ffn = FeedForward(cfg_dense)
dense_params = sum(p.numel() for p in dense_ffn.parameters())
print(f"Dense FFN parameters: {dense_params:,}")
print(f"  Layer 1: (768 × 3072) + 3072 bias = {768*3072 + 3072:,}")
print(f"  Layer 2: (3072 × 768) + 768 bias  = {3072*768 + 768:,}")
print(f"  All {dense_params:,} are active on EVERY token")

Dense FFN parameters: 4,722,432
  Layer 1: (768 × 3072) + 3072 bias = 2,362,368
  Layer 2: (3072 × 768) + 768 bias  = 2,360,064
  All 4,722,432 are active on EVERY token


## The MoE FeedForward: Multiple Experts + a Router

MoE replaces the single dense FFN with:

1. **Multiple expert networks**: each expert is its own FFN. All experts share one architecture (in this notebook a SwiGLU FFN, described below), and each can be narrower than the dense FFN it replaces
2. **A router (gate)**: a simple linear layer that scores each expert for each token, then picks the top-k

### How the router works

```
Input token x (768-dim)
    │
    ▼
Router: nn.Linear(768, num_experts)    ← one score per expert
    │
    ▼
Scores: [0.1, 2.3, -0.5, 1.8, 0.4, -1.2, 0.7, 0.9]    ← 8 experts
    │
    ▼
Top-2:  Expert 1 (score 2.3), Expert 3 (score 1.8)       ← pick the best 2
    │
    ▼
Softmax over top-2 scores: [0.62, 0.38]                  ← normalize to weights
    │
    ▼
Output = 0.62 × Expert_1(x) + 0.38 × Expert_3(x)        ← weighted combination
```
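
The arithmetic in the diagram can be checked by hand. The following is a minimal sketch in plain Python (no tensors) that uses the same eight scores; `route_top_k` is a hypothetical helper name, not part of the notebook's model code.

```python
import math

def route_top_k(scores, k=2):
    """Return (expert_index, weight) pairs for the k highest-scoring experts."""
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    # Softmax over the selected scores only (subtract the max for numerical stability)
    m = max(scores[i] for i in top)
    exps = [math.exp(scores[i] - m) for i in top]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(top, exps)]

scores = [0.1, 2.3, -0.5, 1.8, 0.4, -1.2, 0.7, 0.9]
routing = route_top_k(scores, k=2)
for expert, weight in routing:
    print(f"Expert {expert}: weight {weight:.2f}")
```

This selects Expert 1 and Expert 3 with weights of about 0.62 and 0.38. The six unselected scores never enter the softmax, so the two weights always sum to 1.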

### Why this is efficient

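With 8 experts and top-2 routing, each token runs through only **2 of the 8** expert networks. That is 25% of the compute needed to run every expert, while the layer keeps all 8 experts' parameters available to store knowledge. Compared with a single dense FFN the size of one expert, top-2 routing costs roughly 2x the FFN compute per token.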

### The SwiGLU activation

The experts in this implementation use **SwiGLU** instead of GELU — a gated activation that modern LLMs (LLaMA, Mixtral, etc.) prefer:

```
Standard:  hidden = GELU(fc1(x))                    ← one path
SwiGLU:    hidden = SiLU(fc1(x)) * fc2(x)           ← two paths multiplied (gating)
```

The SiLU-activated branch (`fc1`) acts as the gate: for each hidden unit it learns how much of the `fc2` value to let through. The result is a learned, input-dependent filter inside the FFN itself.
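
To see the gating effect on actual numbers, here is a small sketch in plain Python. The input values are made up for illustration, and `swiglu_hidden` is our own helper name; it stands in for the element-wise product `SiLU(fc1(x)) * fc2(x)` only.

```python
import math

def silu(z: float) -> float:
    """SiLU (also called swish): z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

def swiglu_hidden(fc1_out, fc2_out):
    """Element-wise SiLU(fc1(x)) * fc2(x) for two equal-length lists."""
    return [silu(g) * v for g, v in zip(fc1_out, fc2_out)]

fc1_out = [-6.0, 0.0, 6.0]   # hypothetical pre-activations on the SiLU branch
fc2_out = [2.0, 2.0, 2.0]    # hypothetical values on the linear branch
hidden = swiglu_hidden(fc1_out, fc2_out)
print([round(h, 3) for h in hidden])
```

A strongly negative value on the SiLU branch nearly zeroes the product, a zero blocks it completely, and a large positive value lets the `fc2` value through, scaled by roughly that value.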

In [4]:
class MoEFeedForward(nn.Module):
    """Mixture of Experts FeedForward layer.
    
    Instead of one big FFN, we have multiple small "expert" FFNs.
    A router picks the top-k experts for each token.
    Only the selected experts run — the rest are skipped entirely.
    
    Architecture of each expert (SwiGLU style):
        hidden = SiLU(fc1(x)) * fc2(x)    ← gated activation
        output = fc3(hidden)                ← project back to emb_dim
    
    Args (from cfg dict):
        emb_dim:             Embedding dimension (768 for GPT-2)
        hidden_dim:          Expert hidden size (typically 4x emb_dim)
        num_experts:         Total number of experts (e.g., 8)
        num_experts_per_tok: How many experts to use per token (e.g., 2)
    """
    def __init__(self, cfg):
        super().__init__()
        self.num_experts_per_tok = cfg["num_experts_per_tok"]
        self.num_experts = cfg["num_experts"]
        self.emb_dim = cfg["emb_dim"]

        # ========== THE ROUTER (GATE) ==========
        # A simple linear layer that produces one score per expert for each token
        # Input: token embedding (768) → Output: expert scores (num_experts)
        # No bias — we don't want the router to have a default preference
        self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False)

        # ========== THE EXPERTS ==========
        # Each expert has 3 linear layers (SwiGLU architecture):
        #   fc1: "gate path"   — (emb_dim → hidden_dim), activated with SiLU
        #   fc2: "value path"  — (emb_dim → hidden_dim), multiplied with fc1's output
        #   fc3: "output path" — (hidden_dim → emb_dim), projects back
        self.fc1 = nn.ModuleList(
            [nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], bias=False)
             for _ in range(self.num_experts)]
        )
        self.fc2 = nn.ModuleList(
            [nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], bias=False)
             for _ in range(self.num_experts)]
        )
        self.fc3 = nn.ModuleList(
            [nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], bias=False)
             for _ in range(self.num_experts)]
        )

    def forward(self, x):
        # x shape: (batch, seq_len, emb_dim)
        batch, seq_len, _ = x.shape

        # ========== STEP 1: ROUTE — decide which experts handle which tokens ==========
        # Gate produces a score for each expert, for each token
        scores = self.gate(x)  # (batch, seq_len, num_experts)

        # Pick the top-k experts with the highest scores
        topk_scores, topk_indices = torch.topk(
            scores, self.num_experts_per_tok, dim=-1
        )  # both: (batch, seq_len, num_experts_per_tok)

        # Softmax over ONLY the top-k scores → these become the mixing weights
        # (we don't softmax over ALL experts — only the selected ones)
        topk_probs = torch.softmax(topk_scores, dim=-1)
        # topk_probs: (batch, seq_len, num_experts_per_tok)
        # e.g., [0.62, 0.38] meaning 62% weight on expert 1, 38% on expert 3

        # ========== STEP 2: FLATTEN for efficient expert dispatch ==========
        # Merge batch and seq_len dims: (batch * seq_len, emb_dim)
        x_flat = x.reshape(batch * seq_len, -1)
        out_flat = torch.zeros(
            batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype
        )

        topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)
        topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)

        # ========== STEP 3: RUN EACH ACTIVE EXPERT ==========
        # Find which experts were actually selected by at least one token
        unique_experts = torch.unique(topk_indices_flat)

        for expert_id_tensor in unique_experts:
            expert_id = int(expert_id_tensor.item())

            # Find which tokens selected this expert
            mask = topk_indices_flat == expert_id  # (total_tokens, num_experts_per_tok)
            if not mask.any():
                continue

            # Get indices of tokens that use this expert
            token_mask = mask.any(dim=-1)  # (total_tokens,) — True if token uses this expert
            selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)
            if selected_idx.numel() == 0:
                continue

            # Gather the input for just these tokens
            expert_input = x_flat.index_select(0, selected_idx)

            # ===== SwiGLU computation =====
            # gate_path = SiLU(fc1(x))  ← smooth gating
            # value_path = fc2(x)        ← what to pass through
            # hidden = gate_path * value_path  ← element-wise gating
            # output = fc3(hidden)       ← project back to emb_dim
            hidden = F.silu(self.fc1[expert_id](expert_input)) * \
                     self.fc2[expert_id](expert_input)
            expert_out = self.fc3[expert_id](hidden)

            # Get the routing probability for this expert
            # (each token may have selected this expert in slot 0 or slot 1 of top-k)
            mask_selected = mask[selected_idx]
            slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)
            selected_probs = torch.gather(
                topk_probs_flat.index_select(0, selected_idx),
                dim=-1, index=slot_indices
            ).squeeze(-1)

            # Accumulate: output += expert_output * routing_weight
            # index_add_ scatters the results back to the correct token positions
            out_flat.index_add_(
                0, selected_idx, expert_out * selected_probs.unsqueeze(-1)
            )

        # ========== STEP 4: RESHAPE back to (batch, seq_len, emb_dim) ==========
        return out_flat.reshape(batch, seq_len, self.emb_dim)


print("MoEFeedForward class defined!")

MoEFeedForward class defined!
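
The dispatch loop in `forward` is easier to follow on a toy example. The sketch below uses plain Python lists instead of tensors, and the routing table is invented for illustration. It mirrors what `mask`, `selected_idx`, and `slot_indices` compute, without being the notebook's implementation.

```python
from collections import defaultdict

def group_tokens_by_expert(topk_indices):
    """Map expert_id -> list of (token_idx, slot) pairs that selected it."""
    groups = defaultdict(list)
    for token_idx, experts in enumerate(topk_indices):
        for slot, expert_id in enumerate(experts):
            groups[expert_id].append((token_idx, slot))
    return dict(groups)

# Hypothetical routing for 4 tokens, top-2 out of 8 experts
topk_indices = [[5, 2], [1, 0], [2, 1], [5, 1]]
groups = group_tokens_by_expert(topk_indices)
for expert_id in sorted(groups):
    tokens = [t for t, _ in groups[expert_id]]
    print(f"Expert {expert_id} processes tokens {tokens}")
print(f"Experts never run: {sorted(set(range(8)) - set(groups))}")
```

Each expert sees only the tokens that chose it, and the slot records which of that token's routing weights applies. Four tokens with top-2 routing always produce eight expert evaluations, however many experts exist.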


## Tracing Through an Example

Let's walk through what happens to a single token step by step.

In [5]:
# Create an MoE layer with 8 experts, top-2 routing
cfg_moe = {
    "emb_dim": 768,
    "hidden_dim": 768 * 4,  # 3072 — same hidden size as dense FFN
    "num_experts": 8,
    "num_experts_per_tok": 2,
}

torch.manual_seed(42)
moe_ffn = MoEFeedForward(cfg_moe)

# Create a single token embedding
x = torch.randn(1, 1, 768)  # batch=1, seq_len=1, emb_dim=768

# Step 1: See what the router produces
with torch.no_grad():
    scores = moe_ffn.gate(x)  # (1, 1, 8) — one score per expert
    topk_scores, topk_indices = torch.topk(scores, 2, dim=-1)
    topk_probs = torch.softmax(topk_scores, dim=-1)

print("=== Router Decision for One Token ===")
print(f"\nAll expert scores: {scores.squeeze().tolist()}")
print(f"\nSelected experts: {topk_indices.squeeze().tolist()}")
print(f"Selected scores:  {topk_scores.squeeze().tolist()}")
print(f"Routing weights (after softmax): {topk_probs.squeeze().tolist()}")

# Visualize which experts are active vs idle
selected = set(topk_indices.squeeze().tolist())
print(f"\nExpert activity:")
for i in range(8):
    status = "ACTIVE" if i in selected else "idle"
    score = scores.squeeze()[i].item()
    bar = "|" + "#" * max(0, int((score + 2) * 5)) + " " * max(0, 20 - int((score + 2) * 5)) + "|"
    print(f"  Expert {i}: {bar} score={score:+.3f}  [{status}]")

print(f"\n→ Only {len(selected)}/{cfg_moe['num_experts']} experts computed!")
print(f"  Compute savings: {(1 - len(selected)/cfg_moe['num_experts'])*100:.0f}%")

=== Router Decision for One Token ===

All expert scores: [0.17712165415287018, 0.24847665429115295, 0.74974524974823, -0.02854180708527565, 0.4913672208786011, 0.8514304161071777, -0.6397891640663147, -0.4104754328727722]

Selected experts: [5, 2]
Selected scores:  [0.8514304161071777, 0.74974524974823]
Routing weights (after softmax): [0.5253994464874268, 0.474600613117218]

Expert activity:
  Expert 0: |##########          | score=+0.177  [idle]
  Expert 1: |###########         | score=+0.248  [idle]
  Expert 2: |#############       | score=+0.750  [ACTIVE]
  Expert 3: |#########           | score=-0.029  [idle]
  Expert 4: |############        | score=+0.491  [idle]
  Expert 5: |##############      | score=+0.851  [ACTIVE]
  Expert 6: |######              | score=-0.640  [idle]
  Expert 7: |#######             | score=-0.410  [idle]

→ Only 2/8 experts computed!
  Compute savings: 75%


## Parameter Count: Dense vs MoE

This is the key insight: MoE has **more total parameters** but activates only **a fraction of them per token**.

In [6]:
# Dense FFN
dense_ffn = FeedForward({"emb_dim": 768, "hidden_dim": 3072})
dense_total = sum(p.numel() for p in dense_ffn.parameters())

# MoE FFN configurations to compare
moe_configs = [
    {"num_experts": 4, "num_experts_per_tok": 1},
    {"num_experts": 8, "num_experts_per_tok": 2},
    {"num_experts": 16, "num_experts_per_tok": 2},
]

print(f"{'Config':<25} {'Total Params':>15} {'Active Params/Token':>20} {'Active %':>10}")
print("-" * 75)
print(f"{'Dense FFN':<25} {dense_total:>15,} {dense_total:>20,} {'100%':>10}")

for mc in moe_configs:
    cfg_tmp = {"emb_dim": 768, "hidden_dim": 3072, **mc}
    moe_tmp = MoEFeedForward(cfg_tmp)
    total = sum(p.numel() for p in moe_tmp.parameters())
    
    # Each expert has: fc1(768*3072) + fc2(768*3072) + fc3(3072*768) = 3 * 768 * 3072
    expert_params = 3 * 768 * 3072  # params per expert
    active_params = expert_params * mc["num_experts_per_tok"] + 768 * mc["num_experts"]  # + router
    active_pct = f"{active_params/total*100:.0f}%"
    
    label = f"MoE {mc['num_experts']}E top-{mc['num_experts_per_tok']}"
    print(f"{label:<25} {total:>15,} {active_params:>20,} {active_pct:>10}")

print()
print("Key insight:")
print("  MoE 8E top-2 has ~12x the parameters of the dense FFN")
print("  (8 experts, each ~1.5x a dense FFN: SwiGLU has 3 weight matrices, not 2)")
print("  Only ~25% are active per token, about 3x the dense FFN's compute")
print("  The idle 75% of parameters still STORE knowledge")
print("  → far more capacity for a modest increase in per-token compute")

Config                       Total Params  Active Params/Token   Active %
---------------------------------------------------------------------------
Dense FFN                       4,722,432            4,722,432       100%
MoE 4E top-1                   28,314,624            7,080,960        25%
MoE 8E top-2                   56,629,248           14,161,920        25%
MoE 16E top-2                 113,258,496           14,168,064        13%

Key insight:
  MoE 8E top-2 has ~12x the parameters of the dense FFN
  (8 experts, each ~1.5x a dense FFN: SwiGLU has 3 weight matrices, not 2)
  Only ~25% are active per token, about 3x the dense FFN's compute
  The idle 75% of parameters still STORE knowledge
  → far more capacity for a modest increase in per-token compute
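
The table's rows can be reproduced with a few lines of arithmetic, which also makes the comparison against the dense FFN explicit. This is a sketch with our own helper name (`moe_counts`), assuming the layer shapes used in this notebook.

```python
emb_dim, hidden_dim = 768, 3072

dense = 2 * emb_dim * hidden_dim + hidden_dim + emb_dim   # two Linear layers with biases
expert = 3 * emb_dim * hidden_dim                         # fc1 + fc2 + fc3, no biases

def moe_counts(num_experts: int, top_k: int):
    """Return (total, active-per-token) parameter counts for one MoE layer."""
    router = emb_dim * num_experts                        # gate weights, no bias
    total = num_experts * expert + router
    active = top_k * expert + router
    return total, active

total, active = moe_counts(num_experts=8, top_k=2)
print(f"per expert / dense : {expert / dense:.2f}x")
print(f"MoE total / dense  : {total / dense:.1f}x")
print(f"MoE active / dense : {active / dense:.1f}x")
print(f"MoE active / total : {active / total:.0%}")
```

Each SwiGLU expert is about 1.5x the size of the dense FFN because it has three weight matrices instead of two. Eight experts therefore hold about 12x the dense FFN's parameters, and top-2 routing activates about 3x of them per token.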


## Why Different Tokens Get Different Experts

The router learns to send different kinds of tokens to different experts. As an idealized illustration, a trained model might route like this:

```
Token: "photosynthesis"  → Expert 2 (science/biology specialist)
Token: "litigation"      → Expert 5 (legal terminology specialist)  
Token: "the"            → Expert 0 (common words / syntax specialist)
Token: "def"            → Expert 7 (code/programming specialist)
```

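This **specialization** is what gives MoE its power: each expert can focus on a subset of inputs without competing for capacity with everything else, whereas a dense FFN has to fit all of it into a single set of weights. In practice the split is rarely this clean. Routing analyses of trained models such as Mixtral found that experts tend to specialize by syntax and token type more than by topic.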

### The routing is learned, not hardcoded

The gate (`nn.Linear(emb_dim, num_experts)`) is trained alongside the experts via backpropagation. Over time, it learns which experts are best at handling which types of input. Nobody tells it "Expert 2 should handle science" — it figures this out during training.

Let's verify that different inputs do get routed differently:

In [7]:
# Create multiple different token embeddings
torch.manual_seed(42)
moe_ffn = MoEFeedForward(cfg_moe)

# Simulate 6 different tokens (random embeddings — in a trained model these would be meaningful)
tokens = torch.randn(1, 6, 768)

with torch.no_grad():
    scores = moe_ffn.gate(tokens)  # (1, 6, 8)
    topk_scores, topk_indices = torch.topk(scores, 2, dim=-1)

print("Which experts does each token use?")
print("─" * 50)
for t in range(6):
    experts = topk_indices[0, t].tolist()
    print(f"  Token {t} → Expert {experts[0]} and Expert {experts[1]}")

print()
# Count how often each expert is used
expert_usage = torch.zeros(8)
for t in range(6):
    for e in topk_indices[0, t].tolist():
        expert_usage[e] += 1

print("Expert utilization across all 6 tokens:")
for i in range(8):
    bar = "█" * int(expert_usage[i].item())
    print(f"  Expert {i}: {bar:<6} ({int(expert_usage[i].item())} tokens)")

print(f"\nNote: In a trained model, the router would learn meaningful specialization.")
print(f"With random weights, the routing is essentially random — that's expected!")

Which experts does each token use?
──────────────────────────────────────────────────
  Token 0 → Expert 5 and Expert 2
  Token 1 → Expert 1 and Expert 0
  Token 2 → Expert 6 and Expert 4
  Token 3 → Expert 2 and Expert 1
  Token 4 → Expert 3 and Expert 1
  Token 5 → Expert 7 and Expert 3

Expert utilization across all 6 tokens:
  Expert 0: █      (1 tokens)
  Expert 1: ███    (3 tokens)
  Expert 2: ██     (2 tokens)
  Expert 3: ██     (2 tokens)
  Expert 4: █      (1 tokens)
  Expert 5: █      (1 tokens)
  Expert 6: █      (1 tokens)
  Expert 7: █      (1 tokens)

Note: In a trained model, the router would learn meaningful specialization.
With random weights, the routing is essentially random — that's expected!


## Plugging MoE into the GPT Architecture

MoE replaces the FeedForward layer **inside each TransformerBlock**. Everything else stays the same — attention, layer norm, residual connections, KV cache — all unchanged.

```
TransformerBlock (Dense — Chapter 4):     TransformerBlock (MoE):

    Input x                                   Input x
      │                                         │
    LayerNorm                                 LayerNorm
      │                                         │
    Attention (+ KV cache)                    Attention (+ KV cache)    ← SAME
      │                                         │
    + residual                                + residual                ← SAME
      │                                         │
    LayerNorm                                 LayerNorm                 ← SAME
      │                                         │
    FeedForward  ← one big FFN                MoE FeedForward  ← router + experts
      │                                         │
    + residual                                + residual                ← SAME
      │                                         │
    Output                                    Output
```

The swap is this simple in code:

In [8]:
# In the TransformerBlock from gpt_with_kv_cache.py, the only change is:

# self.ff = FeedForward(cfg)                                    ← Dense (Chapter 4)
# self.ff = MoEFeedForward(cfg) if cfg["num_experts"] > 0 \    ← MoE (when experts > 0)
#           else FeedForward(cfg)                                ← Dense (fallback)

# Let's verify both produce the same-shaped output
x = torch.randn(2, 10, 768)  # batch=2, seq_len=10, emb_dim=768

dense_out = dense_ffn(x)
moe_out = moe_ffn(x)

print(f"Input shape:      {list(x.shape)}")
print(f"Dense FFN output: {list(dense_out.shape)}")
print(f"MoE FFN output:   {list(moe_out.shape)}")
print(f"\n→ Same shape! MoE is a drop-in replacement for the dense FFN.")
print(f"  The rest of the transformer (attention, norms, residuals) doesn't change.")

Input shape:      [2, 10, 768]
Dense FFN output: [2, 10, 768]
MoE FFN output:   [2, 10, 768]

→ Same shape! MoE is a drop-in replacement for the dense FFN.
  The rest of the transformer (attention, norms, residuals) doesn't change.


## Full Model Comparison: Dense vs MoE

Let's build a full GPT-like model with both architectures and compare.

In [9]:
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.scale * (x - mean) / torch.sqrt(var + self.eps) + self.shift


class MultiHeadAttention(nn.Module):
    """Simplified MHA (no KV cache) for clean comparison."""
    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        b, n, _ = x.shape
        q = self.W_query(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.W_key(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.W_value(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        scores = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
        mask = torch.triu(torch.ones(n, n, device=x.device), diagonal=1).bool()
        scores.masked_fill_(mask, -torch.inf)
        weights = self.dropout(torch.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).contiguous().view(b, n, self.d_out)
        return self.out_proj(out)


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            cfg["emb_dim"], cfg["emb_dim"], cfg["n_heads"],
            dropout=cfg["drop_rate"], qkv_bias=cfg["qkv_bias"]
        )
        # THIS IS THE KEY LINE — MoE or Dense based on config
        self.ff = MoEFeedForward(cfg) if cfg.get("num_experts", 0) > 0 else FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop = nn.Dropout(cfg["drop_rate"])
    
    def forward(self, x):
        x = x + self.drop(self.att(self.norm1(x)))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop = nn.Dropout(cfg["drop_rate"])
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
        self.norm = LayerNorm(cfg["emb_dim"])
        self.head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
    
    def forward(self, idx):
        b, n = idx.shape
        x = self.tok_emb(idx) + self.pos_emb(torch.arange(n, device=idx.device))
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        return self.head(self.norm(x))


# Build both models
base_cfg = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
    "n_heads": 12, "n_layers": 12, "drop_rate": 0.0, "qkv_bias": False,
    "hidden_dim": 3072,
}

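# Dense model (GPT-2 small layout; 163M rather than 124M parameters
# because the output head is not weight-tied to the token embedding)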
dense_cfg = {**base_cfg, "num_experts": 0}
torch.manual_seed(42)
dense_model = GPTModel(dense_cfg)

# MoE model (8 experts, top-2)
moe_cfg = {**base_cfg, "num_experts": 8, "num_experts_per_tok": 2}
torch.manual_seed(42)
moe_model = GPTModel(moe_cfg)

dense_params = sum(p.numel() for p in dense_model.parameters())
moe_params = sum(p.numel() for p in moe_model.parameters())

# Count just FFN params
dense_ffn_params = sum(
    sum(p.numel() for p in blk.ff.parameters()) for blk in dense_model.blocks
)
moe_ffn_params = sum(
    sum(p.numel() for p in blk.ff.parameters()) for blk in moe_model.blocks
)

print("=== Full GPT Model Comparison ===")
print(f"")
print(f"{'':30} {'Dense GPT':>15} {'MoE GPT (8E top-2)':>20}")
print(f"{'-'*70}")
print(f"{'Total parameters':30} {dense_params:>15,} {moe_params:>20,}")
print(f"{'FFN parameters':30} {dense_ffn_params:>15,} {moe_ffn_params:>20,}")
print(f"{'Non-FFN parameters':30} {dense_params-dense_ffn_params:>15,} {moe_params-moe_ffn_params:>20,}")
print(f"{'FFN params active/token':30} {dense_ffn_params:>15,} {moe_ffn_params//4:>20,}")
print(f"")
# Always-on parameters (attention, embeddings, norms, head) + the active ~1/4 of the experts
moe_active = (moe_params - moe_ffn_params) + moe_ffn_params // 4
print(f"The MoE model has {moe_params/dense_params:.1f}x as many total parameters,")
print(f"but only ~{moe_active/1e6:.0f}M of them are active per token")
print(f"({moe_active/dense_params:.1f}x the dense model's {dense_params/1e6:.0f}M).")
print(f"")
print(f"Think of it as: {moe_params/1e6:.0f}M parameters of KNOWLEDGE,")
print(f"with ~{moe_active/1e6:.0f}M parameters of COMPUTE per token.")

=== Full GPT Model Comparison ===

                                     Dense GPT   MoE GPT (8E top-2)
----------------------------------------------------------------------
Total parameters                   163,009,536          785,891,328
FFN parameters                      56,669,184          679,550,976
Non-FFN parameters                 106,340,352          106,340,352
FFN params active/token             56,669,184          169,887,744

The MoE model has 4.8x as many total parameters,
but only ~276M of them are active per token
(1.7x the dense model's 163M).

Think of it as: 786M parameters of KNOWLEDGE,
with ~276M parameters of COMPUTE per token.


## Timing: Dense vs MoE Forward Pass

Let's measure the actual forward-pass time to see what the extra parameters cost in practice.

In [10]:
# Benchmark both models
dense_model.eval()
moe_model.eval()

# Create a test input
test_input = torch.randint(0, 50257, (1, 128))  # batch=1, seq_len=128

# Warmup
with torch.no_grad():
    _ = dense_model(test_input)
    _ = moe_model(test_input)

# Time dense model
n_runs = 5
dense_times = []
for _ in range(n_runs):
    start = time.perf_counter()
    with torch.no_grad():
        _ = dense_model(test_input)
    dense_times.append(time.perf_counter() - start)

# Time MoE model
moe_times = []
for _ in range(n_runs):
    start = time.perf_counter()
    with torch.no_grad():
        _ = moe_model(test_input)
    moe_times.append(time.perf_counter() - start)

dense_avg = sum(dense_times) / n_runs * 1000  # ms
moe_avg = sum(moe_times) / n_runs * 1000  # ms

print(f"Forward pass time (128 tokens, averaged over {n_runs} runs):")
print(f"  Dense GPT ({dense_params/1e6:.0f}M params): {dense_avg:.1f} ms")
print(f"  MoE GPT   ({moe_params/1e6:.0f}M params):  {moe_avg:.1f} ms")
print(f"  Ratio: {moe_avg/dense_avg:.2f}x")
print(f"")
if moe_avg < dense_avg * 2:
    print(f"  Despite having {moe_params/dense_params:.1f}x more parameters,")
    print(f"  MoE is only {moe_avg/dense_avg:.2f}x slower, because most experts are idle!")
else:
    print(f"  Note: with these settings each MoE layer activates ~3x the dense FFN's")
    print(f"  parameters, and the per-expert Python loop adds routing overhead on CPU.")
print(f"")
print(f"  Parameters grew {moe_params/dense_params:.1f}x; forward time grew {moe_avg/dense_avg:.2f}x.")
print(f"  Batched, fused expert kernels on GPU typically reduce the routing overhead.")

Forward pass time (128 tokens, averaged over 5 runs):
  Dense GPT (163M params): 38.8 ms
  MoE GPT   (786M params):  119.3 ms
  Ratio: 3.08x

  Note: with these settings each MoE layer activates ~3x the dense FFN's
  parameters, and the per-expert Python loop adds routing overhead on CPU.

  Parameters grew 4.8x; forward time grew 3.08x.
  Batched, fused expert kernels on GPU typically reduce the routing overhead.
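
An average over five runs can be skewed by a single slow run. As a hedged alternative, the sketch below reports the median over more runs; `bench_ms` is a hypothetical helper, and the workload is a stand-in so the block runs without the models above.

```python
import statistics
import time

def bench_ms(fn, n_runs: int = 20, warmup: int = 3) -> float:
    """Median wall-clock time of fn() in milliseconds, after a few warmup calls."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        times.append((time.perf_counter() - start) * 1000)
    return statistics.median(times)

# Stand-in workload so the sketch runs on its own
ms = bench_ms(lambda: sum(i * i for i in range(10_000)))
print(f"median: {ms:.3f} ms")
```

In the notebook you would call it as `bench_ms(lambda: dense_model(test_input))` and `bench_ms(lambda: moe_model(test_input))`, each wrapped in `with torch.no_grad():`.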


## Real-World MoE Models

Here's how MoE is used in actual production models:

```
Model              Total Params    Active Params/Token    Experts    Top-k
──────────────     ────────────    ───────────────────    ───────    ─────
GPT-2 124M         124M            124M (dense)           1          1
Mixtral 8x7B       46.7B           12.9B                  8          2
DeepSeek-V2        236B            21B                    160        6
GPT-4 (rumored)    ~1.8T           ~280B                  16         2
```

### Mixtral 8x7B: The MoE poster child

Mixtral has 8 expert FFNs per layer, each the size of a 7B model's FFN. The router picks the top-2 for each token. So:
- Total parameters: 46.7B (lots of stored knowledge)
- Active parameters per token: 12.9B (affordable compute)
- Result: Matches or beats LLaMA 2 70B on most benchmarks while using ~5x fewer active parameters per token
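
The ratios behind these bullets are quick to check. The sketch below takes the parameter counts quoted above at face value and is only arithmetic.

```python
# Figures as quoted in the table above (billions of parameters)
mixtral_total, mixtral_active = 46.7, 12.9
llama2_70b = 70.0

print(f"Mixtral active fraction : {mixtral_active / mixtral_total:.0%}")
print(f"LLaMA 2 70B vs. active  : {llama2_70b / mixtral_active:.1f}x")
```

The active fraction comes out a little above 2/8 because attention, embeddings, and norms are shared across experts and run for every token.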

### The tradeoff

MoE isn't free — you need to store all expert parameters in memory, even though most are idle:

```
Dense 13B model:    13B params × 2 bytes (fp16) = 26 GB GPU memory
Mixtral 8x7B:       46.7B params × 2 bytes      = 93 GB GPU memory
                    (but only computes like a ~13B model!)
```

This is why MoE pairs well with **quantization** (shrink each parameter from 2 bytes to 0.5 bytes) and **offloading** (keep idle experts on CPU/disk, only load active ones to GPU).
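
The memory arithmetic extends to other precisions with one line. This is a rough sketch: `weight_memory_gb` is our own helper, and it counts weight storage only, ignoring activations, the KV cache, and quantization metadata such as per-block scales.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

mixtral_params = 46.7e9
for name, bits in [("fp16", 16), ("int8", 8), ("4-bit", 4)]:
    print(f"Mixtral 8x7B @ {name:<5}: {weight_memory_gb(mixtral_params, bits):6.1f} GB")
```

Going from fp16 to 4-bit cuts the weight footprint to a quarter, which brings it below the 26 GB that a dense 13B model needs in fp16.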

## How MoE, KV Cache, and LoRA Work Together

These three optimizations target different bottlenecks and combine naturally:

```
┌─────────────────────────────────────────────────────────────────────┐
│                    The Modern LLM Optimization Stack                │
│                                                                     │
│  PROBLEM                SOLUTION              WHERE IN GPT          │
│  ───────                ────────              ────────────          │
│                                                                     │
│  Model too slow         KV Cache              MultiHeadAttention    │
│  at generation?         (don't recompute      (cache K,V from       │
│                          old K,V)              previous tokens)     │
│                                                                     │
│  Model not smart        MoE                   FeedForward           │
│  enough?                (more expert params,   (replace dense FFN   │
│                          same compute)          with router+experts)│
│                                                                     │
│  Need to customize      LoRA                  W_query, W_value      │
│  for a task?            (tiny adapters,        (add small A×B       │
│                          freeze base)           alongside frozen W) │
│                                                                     │
│  Model too big          Quantization          All weight matrices   │
│  for device?            (4-bit weights,        (compress 16-bit     │
│                          smaller model)         floats to 4-bit)    │
│                                                                     │
│  All together: A quantized MoE model with KV cache and LoRA        │
│  adapters = maximum capability, minimum resource usage              │
└─────────────────────────────────────────────────────────────────────┘
```

### Key Takeaway

MoE is a way to scale model **knowledge** without proportionally scaling **compute**. Instead of making the FeedForward layer bigger (which slows down every token), you add more expert FFNs and let a learned router pick the best ones per token. The result: models like Mixtral 8x7B that match 70B-parameter dense models while using only ~13B parameters of compute per token.

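The `MoEFeedForward` class in this notebook follows the same design as the sparse MoE layers in models like Mixtral: a linear router, top-2 selection, a softmax over the selected scores, and SwiGLU experts. The main difference is scale; Mixtral's 8 experts are each the size of a 7B model's FFN, while ours are GPT-2-sized. Production implementations also add pieces this notebook leaves out, such as load-balancing losses during training and batched expert kernels in place of the Python loop.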