# Module 4.1: Mixture of Experts (MoE)

**Goal**: Implement a Mixture of Experts (MoE) layer from scratch and understand how its router assigns tokens to experts

**Time**: 90 minutes

**Concepts Covered**:
- MoE layer implementation
- Router mechanism with top-k selection
- Comparing dense vs. MoE memory and compute
- Converting an existing model to MoE
- Visualizing expert routing patterns

## Setup

In [None]:
!pip install torch transformers accelerate matplotlib seaborn numpy -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

class MoELayer(nn.Module):
    """Mixture of Experts layer with top-k routing.

    A linear router scores every token against ``num_experts`` feed-forward
    experts. Each token is sent to its ``top_k`` highest-probability experts,
    and the experts' outputs are summed, weighted by the selected router
    probabilities renormalized to sum to 1 per token.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        num_experts: Number of expert FFNs.
        top_k: Number of experts each token is routed to
            (must satisfy ``1 <= top_k <= num_experts``).
        expert_capacity_factor: Stored on the module but NOT enforced; this
            implementation never drops tokens when an expert is overloaded.

    Raises:
        ValueError: If ``top_k`` is not in ``[1, num_experts]``.
    """

    def __init__(self, d_model, num_experts=8, top_k=2, expert_capacity_factor=1.25):
        super().__init__()
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be between 1 and num_experts ({num_experts}), got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k
        self.d_model = d_model
        # Kept for API compatibility / future capacity limiting; unused in forward.
        self.expert_capacity_factor = expert_capacity_factor

        # Router: maps each token to one logit per expert.
        self.router = nn.Linear(d_model, num_experts)

        # Expert networks: independent 2-layer FFNs with a 4x hidden expansion.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.GELU(),
                nn.Linear(d_model * 4, d_model),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        """Route each token through its top-k experts.

        Args:
            x: Tensor of shape ``(..., d_model)``, typically
                ``(batch, seq_len, d_model)``.

        Returns:
            A tuple ``(output, router_probs)``. ``output`` has the same shape
            as ``x``. ``router_probs`` has shape ``(..., num_experts)`` and
            holds the full softmax over experts (before top-k selection).
        """
        # Flatten all leading dims so each row is one token. Using x's own last
        # dim (not self.d_model) lets a feature-size mismatch fail in the router.
        tokens = x.reshape(-1, x.shape[-1])  # (num_tokens, d_model)

        router_logits = self.router(tokens)  # (num_tokens, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Keep the k most probable experts per token and renormalize their weights.
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(tokens)
        for expert_idx, expert in enumerate(self.experts):
            # (num_tokens, top_k): which top-k slot, if any, selected this expert.
            slot_mask = top_k_indices == expert_idx
            # Reduce to a per-token mask. Indexing the tokens with slot_mask
            # directly would fail: its last dim is top_k, not d_model.
            token_mask = slot_mask.any(dim=-1)  # (num_tokens,)
            if not token_mask.any():
                continue
            # topk returns distinct indices, so at most one slot matches per
            # token; the masked sum extracts that slot's routing weight.
            weights = (top_k_probs * slot_mask).sum(dim=-1)[token_mask]
            expert_output = expert(tokens[token_mask])
            output[token_mask] += expert_output * weights.unsqueeze(-1)

        return (
            output.reshape(x.shape),
            router_probs.reshape(*x.shape[:-1], self.num_experts),
        )

# Smoke-test the MoE layer on a small random batch and report tensor shapes.
d_model, num_experts, top_k = 128, 4, 2
moe = MoELayer(d_model, num_experts, top_k)

x = torch.randn(2, 10, d_model)  # (batch=2, seq_len=10, d_model)
output, routing = moe(x)

for label, tensor in (
    ("Input shape", x),
    ("Output shape", output),
    ("Routing probabilities shape", routing),
):
    print(f"{label}: {tensor.shape}")
print(f"\nTop-{top_k} routing active for each token")

## Key Takeaways

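✅ **Module Complete**

- The router (a linear layer followed by softmax) gives every token a probability for each expert.
- Top-k selection keeps only the k most probable experts per token and renormalizes their weights so they sum to 1.
- Each token runs through only k experts. Per-token expert compute therefore scales with k, while the parameter count grows with the total number of experts.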

## Next Steps

Continue to the next module in the course.