# 🔀 Switch Transformer: Mixture of Experts

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/main/transformer_architectures/10_switch_transformer/demo.ipynb)

![Architecture](architecture.png)

### Key Innovation
- **Mixture of Experts (MoE)**: Multiple FFN "experts"
- **Sparse Activation**: Each token routed to 1 expert
- **Massive Scale**: Up to trillions of parameters while the FLOPs per token stay constant

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Seed the global RNG so every run of the notebook produces the same numbers,
# then choose the GPU when one is available and fall back to the CPU otherwise.
torch.manual_seed(42)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## The MoE Concept

Instead of one large FFN, use many smaller "expert" FFNs and route each token to one of them!

In [None]:
def visualize_moe_concept():
    """Draw a side-by-side sketch of a dense FFN layer and a Switch (MoE) layer.

    Left panel: three token arrows feeding one large FFN box.
    Right panel: a router circle fanning tokens out to four expert boxes.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Standard FFN
    ax = axes[0]
    ax.add_patch(plt.Rectangle((0.2, 0.3), 0.6, 0.4, color='steelblue', alpha=0.7))
    ax.text(0.5, 0.5, 'Single Large FFN\n(d_ff = 4096)', ha='center', va='center', fontsize=12, color='white')

    # Input tokens: one arrow per token, all ending on the same FFN box
    for i, y in enumerate([0.15, 0.1, 0.05]):
        ax.annotate('', xy=(0.2, 0.3 + i*0.1), xytext=(0.05, y + 0.1),
                   arrowprops=dict(arrowstyle='->', color='gray'))
        ax.text(0.02, y + 0.1, f'Token {i+1}', fontsize=9)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('Standard Transformer\nAll tokens → One FFN', fontsize=11)
    ax.axis('off')

    # MoE
    ax = axes[1]
    colors = plt.cm.Set3(range(4))
    expert_positions = [(0.15, 0.5), (0.4, 0.5), (0.6, 0.5), (0.85, 0.5)]

    for i, (x, y) in enumerate(expert_positions):
        ax.add_patch(plt.Rectangle((x-0.08, y-0.15), 0.16, 0.3, color=colors[i], alpha=0.8))
        ax.text(x, y, f'Expert\n{i+1}', ha='center', va='center', fontsize=9)

    # Router
    ax.add_patch(plt.Circle((0.5, 0.2), 0.08, color='gold', alpha=0.8))
    ax.text(0.5, 0.2, 'Router', ha='center', va='center', fontsize=9)

    # Tokens being routed: (token, expert) pairs. Every arrow starts at the
    # router, so only the destination expert influences the drawing.
    token_routes = [(0, 0), (1, 2), (2, 1)]
    for _, exp_idx in token_routes:
        ax.annotate('', xy=(expert_positions[exp_idx][0], 0.35),
                   xytext=(0.5, 0.28),
                   arrowprops=dict(arrowstyle='->', color=colors[exp_idx], lw=2))

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('Switch Transformer\nTokens routed to different experts', fontsize=11)
    ax.axis('off')

    plt.tight_layout()
    plt.show()

# Render the dense-vs-MoE comparison figure, then print the one-line takeaway.
visualize_moe_concept()
print('Key: More parameters, same compute per token!')

## Switch Routing

In [None]:
class SwitchRouter(nn.Module):
    """Top-1 ("switch") router that assigns every token to a single expert.

    A linear layer scores each token against every expert. The softmax of
    those scores is the routing distribution, and its argmax is the expert
    the token is sent to.

    Args:
        d_model: Size of the token feature dimension.
        n_experts: Number of experts to choose from.
        capacity_factor: Stored on the module only; this router does not
            enforce any per-expert capacity limit.
    """
    def __init__(self, d_model, n_experts, capacity_factor=1.25):
        super().__init__()
        self.n_experts = n_experts
        self.capacity_factor = capacity_factor
        self.router = nn.Linear(d_model, n_experts)

    def forward(self, x):
        """Route tokens to experts.

        Args:
            x: Token features whose last dimension is ``d_model``,
                e.g. ``(B, T, d_model)``. Any leading shape is accepted.

        Returns:
            A 4-tuple ``(expert_indices, top_probs, expert_mask, router_probs)``:
            - expert_indices: chosen expert per token, shape ``x.shape[:-1]``
            - top_probs: routing probability of the chosen expert,
              shape ``x.shape[:-1]``
            - expert_mask: float one-hot encoding of ``expert_indices``,
              shape ``(..., n_experts)``
            - router_probs: full softmax over experts (needed by the
              load-balancing loss), shape ``(..., n_experts)``
        """
        # Compute routing scores
        router_logits = self.router(x)  # (..., n_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Switch uses top-1 routing; a single reduction yields both the
        # winning probability and the index of the winning expert.
        top_probs, expert_indices = router_probs.max(dim=-1)

        # Create mask for each expert
        expert_mask = F.one_hot(expert_indices, self.n_experts).float()

        return expert_indices, top_probs, expert_mask, router_probs

# Demonstrate routing on one random batch
router = SwitchRouter(d_model=64, n_experts=4)
x = torch.randn(2, 16, 64)  # Batch=2, SeqLen=16

expert_indices, top_probs, expert_mask, router_probs = router(x)

# Report tensor shapes and the per-token decisions for the first sequence
print(f'Input shape: {x.shape}')
print(f'Expert indices shape: {expert_indices.shape}')
print(f'\nToken routing (batch 0):')
print(f'  Expert assignments: {expert_indices[0].tolist()}')
print(f'  Routing probs: {top_probs[0, :5].tolist()}')  # First 5 tokens

# Bar chart of how many tokens of the first sequence each expert received
expert_counts = expert_mask[0].sum(dim=0).tolist()
fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(range(4), expert_counts, color=plt.cm.Set3(range(4)))
ax.set_xlabel('Expert ID')
ax.set_ylabel('Number of Tokens')
ax.set_title('Token Distribution Across Experts')
ax.set_xticks(range(4))
ax.set_xticklabels([f'Expert {i}' for i in range(4)])
ax.grid(True, alpha=0.3)
plt.show()

## Load Balancing Loss

Prevent all tokens from going to one expert!

In [None]:
def load_balancing_loss(router_probs, expert_mask):
    """Auxiliary loss that encourages tokens to be spread evenly over experts.

    Computes the unweighted balancing term from the Switch Transformer paper:

        n_experts * sum_i(f_i * P_i)

    where f_i = fraction of tokens routed to expert i
          P_i = average routing probability for expert i

    The paper's ``alpha`` coefficient is NOT applied here; the caller is
    expected to scale the returned value.

    Args:
        router_probs: Routing probabilities, shape ``(..., n_experts)``.
        expert_mask: One-hot dispatch mask with the same shape.

    Returns:
        Scalar tensor. It is 1.0 for perfectly uniform routing and
        approaches ``n_experts`` as routing collapses onto one expert.
        It is 0 when the inputs contain no tokens.
    """
    n_experts = router_probs.shape[-1]

    # Collapse every leading dimension into a single token axis so that both
    # (B, T, n_experts) and already-flattened (N, n_experts) inputs work.
    flat_probs = router_probs.reshape(-1, n_experts)
    flat_mask = expert_mask.reshape(-1, n_experts)

    if flat_probs.shape[0] == 0:
        # A mean over zero tokens would be NaN; the sum of an empty tensor is
        # a zero scalar with the right dtype and device.
        return flat_probs.sum()

    # Fraction of tokens routed to each expert
    tokens_per_expert = flat_mask.sum(dim=0)  # (n_experts,)
    total_tokens = flat_mask.sum()
    f = tokens_per_expert / (total_tokens + 1e-10)

    # Average routing probability for each expert
    P = flat_probs.mean(dim=0)  # (n_experts,)

    return n_experts * (f * P).sum()

# Test load balancing: evaluate the loss on the `router_probs` / `expert_mask`
# produced by the routing demo and print it next to the balanced reference value.
lb_loss = load_balancing_loss(router_probs, expert_mask)
print(f'Load balancing loss: {lb_loss.item():.4f}')
print('(Ideal is 1.0 when perfectly balanced)')

## Switch Transformer Implementation

In [None]:
class Expert(nn.Module):
    """One feed-forward expert: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
        dropout: Dropout probability used after the activation and after
            the output projection.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Expand to d_ff, apply the non-linearity, then project back to d_model.
        stages = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the expert to ``x``; the last dimension must be ``d_model``."""
        transformed = self.net(x)
        return transformed

class SwitchFFN(nn.Module):
    """Sparse feed-forward layer in which one routed expert processes each token.

    Args:
        d_model: Token feature size.
        d_ff: Hidden size of every expert.
        n_experts: Number of experts.
        dropout: Dropout probability passed to every expert.
    """
    def __init__(self, d_model, d_ff, n_experts=4, dropout=0.1):
        super().__init__()
        self.n_experts = n_experts
        self.router = SwitchRouter(d_model, n_experts)
        expert_modules = []
        for _ in range(n_experts):
            expert_modules.append(Expert(d_model, d_ff, dropout))
        self.experts = nn.ModuleList(expert_modules)

    def forward(self, x):
        """Dispatch tokens to their experts and combine the results.

        Args:
            x: Token features of shape ``(B, T, C)``.

        Returns:
            ``(output, aux_loss, expert_indices)``: the gated expert outputs
            (same shape as ``x``), the load-balancing loss for this layer,
            and the ``(B, T)`` expert index chosen for each token.
        """
        _batch, _seq_len, _channels = x.shape  # input must be 3-D

        expert_indices, top_probs, expert_mask, router_probs = self.router(x)

        # Start from zeros and fill in each expert's slice of tokens.
        output = torch.zeros_like(x)
        for expert_idx, expert in enumerate(self.experts):
            selected = expert_indices == expert_idx  # (B, T)
            if not selected.any():
                continue  # nothing was routed to this expert

            processed = expert(x[selected])  # (num_tokens, C)
            # Scale by the routing probability of the chosen expert
            gate = top_probs[selected].unsqueeze(-1)
            output[selected] = processed * gate

        aux_loss = load_balancing_loss(router_probs, expert_mask)
        return output, aux_loss, expert_indices

class SwitchBlock(nn.Module):
    """Pre-norm transformer block whose feed-forward sublayer is a SwitchFFN.

    Args:
        d_model: Token feature size.
        n_heads: Number of attention heads.
        d_ff: Hidden size of every expert.
        n_experts: Number of experts in the switch layer.
        dropout: Dropout probability for attention and for the experts.
    """
    def __init__(self, d_model, n_heads, d_ff, n_experts=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.switch_ffn = SwitchFFN(d_model, d_ff, n_experts, dropout)

    def forward(self, x, mask=None):
        """Run self-attention and the switch FFN, each with a residual connection.

        Args:
            x: Token features of shape ``(B, T, d_model)``.
            mask: Optional attention mask forwarded as ``attn_mask``.

        Returns:
            ``(x, aux_loss, expert_indices)`` where the last two come
            straight from the switch FFN.
        """
        # Self-attention sublayer (normalise first, then add the residual)
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed, attn_mask=mask)[0]
        after_attn = x + attended

        # Switch FFN sublayer
        moe_out, aux_loss, expert_indices = self.switch_ffn(self.norm2(after_attn))
        return after_attn + moe_out, aux_loss, expert_indices

class SwitchTransformer(nn.Module):
    """Causal language model built from a stack of SwitchBlock layers.

    Token and learned positional embeddings are summed, passed through
    ``n_layers`` switch blocks under a causal attention mask, normalised,
    and projected to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Token feature size.
        n_heads: Number of attention heads per block.
        n_layers: Number of switch blocks.
        d_ff: Hidden size of every expert.
        n_experts: Number of experts per block.
        max_len: Longest sequence the positional embedding supports.
        dropout: Dropout probability used throughout the model.
    """
    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=3, d_ff=256, n_experts=4, max_len=512, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([SwitchBlock(d_model, n_heads, d_ff, n_experts, dropout) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer token ids of shape ``(B, T)`` with ``T <= max_len``.

        Returns:
            ``(logits, aux_loss, expert_usage)``:
            - logits: ``(B, T, vocab_size)``
            - aux_loss: scalar tensor, the load-balancing loss averaged over
              layers (zero for a model without layers)
            - expert_usage: one ``(B, T)`` expert-index tensor per layer

        Raises:
            ValueError: If ``T`` exceeds the positional embedding size.
        """
        B, T = x.shape

        # Fail early with a readable message. An out-of-range position index
        # would otherwise surface as an opaque indexing error inside the
        # embedding lookup.
        max_len = self.pos_embed.num_embeddings
        if T > max_len:
            raise ValueError(
                f'Sequence length {T} exceeds the maximum supported length {max_len}'
            )

        pos = torch.arange(T, device=x.device).unsqueeze(0)

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere
        mask = torch.triu(torch.full((T, T), float('-inf'), device=x.device), diagonal=1)

        x = self.dropout(self.embed(x) + self.pos_embed(pos))

        # A tensor accumulator keeps the return type stable even with no layers.
        total_aux_loss = x.new_zeros(())
        expert_usage = []

        for layer in self.layers:
            x, aux_loss, expert_indices = layer(x, mask)
            total_aux_loss = total_aux_loss + aux_loss
            expert_usage.append(expert_indices)

        # max(..., 1) avoids dividing by zero for a model without layers.
        mean_aux_loss = total_aux_loss / max(len(self.layers), 1)

        return self.head(self.norm(x)), mean_aux_loss, expert_usage

# Build a small two-layer model on the selected device and print how many
# parameters it has in total.
model = SwitchTransformer(vocab_size=1000, d_model=64, n_heads=4, n_layers=2, d_ff=128, n_experts=4).to(device)
print(f'Switch Transformer Parameters: {sum(p.numel() for p in model.parameters()):,}')

## Training Switch Transformer

In [None]:
# Dataset: one sentence repeated many times, tokenised per character
text = 'the quick brown fox jumps over the lazy dog ' * 300
chars = sorted(set(text))
vocab_size = len(chars)
char_to_idx = {ch: position for position, ch in enumerate(chars)}
data = torch.tensor([char_to_idx[ch] for ch in text], dtype=torch.long)

# Training setup
seq_len = 64
model = SwitchTransformer(
    vocab_size=vocab_size,
    d_model=64,
    n_heads=4,
    n_layers=2,
    d_ff=128,
    n_experts=4,
    max_len=seq_len,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses, aux_losses = [], []
n_steps = 300
aux_weight = 0.01  # Weight for load balancing loss

print('Training Switch Transformer with MoE...')
for step in range(n_steps):
    # Draw 16 random window starts; the target is the input shifted by one character
    idx = torch.randint(0, len(data) - seq_len - 1, (16,))
    x = torch.stack([data[start:start + seq_len] for start in idx]).to(device)
    y = torch.stack([data[start + 1:start + seq_len + 1] for start in idx]).to(device)

    optimizer.zero_grad()
    logits, aux_loss, _ = model(x)

    # Language-modelling loss plus the weighted load-balancing term
    main_loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
    total_loss = main_loss + aux_weight * aux_loss

    total_loss.backward()
    optimizer.step()

    # Keep the unweighted losses for plotting
    losses.append(main_loss.item())
    aux_losses.append(aux_loss.item())

    if (step + 1) % 50 == 0:
        print(f'Step {step+1}: Main Loss = {main_loss.item():.4f}, Aux Loss = {aux_loss.item():.4f}')

# Plot the language-modelling loss and the load-balancing loss side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

panels = [
    (axes[0], losses, {}, 'Main Loss', 'Language Modeling Loss'),
    (axes[1], aux_losses, {'color': 'orange'}, 'Load Balancing Loss', 'Auxiliary Loss (Expert Balance)'),
]
for ax, series, line_kwargs, y_label, title in panels:
    ax.plot(series, **line_kwargs)
    ax.set_xlabel('Step')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Analyze expert utilization on random token sequences
model.eval()

# Read the configuration from the trained model and the training cell instead
# of repeating literals, so this cell keeps working if those values change.
n_experts = model.layers[0].switch_ffn.n_experts
test_input = torch.randint(0, vocab_size, (4, seq_len)).to(device)

with torch.no_grad():
    _, _, expert_usage = model(test_input)

# Count expert usage across layers.
# squeeze=False keeps `axes` 2-D, so indexing also works for a one-layer model.
fig, axes = plt.subplots(1, len(expert_usage), figsize=(12, 4), squeeze=False)

for layer_idx, expert_indices in enumerate(expert_usage):
    ax = axes[0, layer_idx]
    # minlength keeps a bar for experts that received no tokens
    counts = torch.bincount(expert_indices.flatten(), minlength=n_experts).cpu().tolist()
    ax.bar(range(n_experts), counts, color=plt.cm.Set3(range(n_experts)))
    ax.set_xlabel('Expert')
    ax.set_ylabel('Tokens')
    ax.set_title(f'Layer {layer_idx + 1}')
    ax.set_xticks(range(n_experts))

plt.suptitle('Expert Utilization by Layer')
plt.tight_layout()
plt.show()

# Closing summary of the notebook. The heading emoji was mis-encoded
# (UTF-8 bytes read as cp1252) and is restored here.
print('\n🎯 Key Takeaways:')
print('1. MoE: Multiple FFN experts, tokens routed dynamically')
print('2. Switch (top-1): Each token goes to exactly 1 expert')
print('3. Scales to trillions of parameters with constant FLOPs')
print('4. Load balancing loss prevents expert collapse')
print('5. Used in Google Switch-C (1.6T params), GLaM, Mixtral')