# Mixture of Experts (MoE) vs Feed-Forward Networks (FFN) Comparison

This notebook presents a comprehensive comparison between Mixture of Experts (MoE) models and standard Feed-Forward Networks (FFN). We explore their performance on multi-modal tasks, their parameter efficiency, and the specialization patterns of the experts.

## Key Learning Objectives:
- Understanding MoE architecture and routing mechanisms
- Comparing parameter efficiency between MoE and FFN models
- Analyzing expert specialization on different task modalities
- Visualizing training dynamics and convergence patterns

In [None]:
import os
import math
import random
from typing import Tuple, Dict, List
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns

# Plot styling shared by every figure in this notebook
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Prefer a CUDA GPU whenever PyTorch can see one
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Seed every RNG the notebook draws from (torch, NumPy, stdlib) for reproducibility
SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)

In [None]:
class MultiModalDataset(Dataset):
    """Synthetic four-class sequence dataset designed to showcase MoE advantages.

    Each sample is a 1-D float32 signal of length ``seq_len`` drawn from one of
    four families, labelled 0-3:

    0. sinusoid with random frequency, phase and amplitude
    1. random cubic polynomial evaluated on [-2, 2]
    2. piecewise-constant "staircase" with 3-6 random levels
    3. exponential growth/decay with random rate and initial value

    ``n_samples // 4`` samples are generated per family, so the dataset holds
    ``4 * (n_samples // 4)`` items. Samples are shuffled once at construction.
    All randomness comes from NumPy's global RNG (``np.random``), so seed it
    beforehand for reproducible data.

    Args:
        n_samples: Requested number of samples (rounded down to a multiple of 4).
        seq_len: Length of every signal.
        noise: Standard deviation of the additive Gaussian noise (scaled by 0.3
            for the step-function family).
    """
    def __init__(self, n_samples: int = 10000, seq_len: int = 32, noise: float = 0.08):
        super().__init__()
        self.n = n_samples  # requested size; the actual size is len(self)
        self.seq_len = seq_len
        self.noise = noise

        samples_per_modality = n_samples // 4
        X = []
        y = []

        # Modality 0: High-frequency sinusoidal patterns
        for _ in range(samples_per_modality):
            freq = np.random.uniform(1.0, 4.0)
            phase = np.random.uniform(0, 2*np.pi)
            amp = np.random.uniform(0.5, 1.5)
            t = np.linspace(0, 6*np.pi, seq_len)
            signal = amp * np.sin(freq * t + phase) + np.random.randn(seq_len) * noise
            X.append(signal)
            y.append(0)

        # Modality 1: Complex polynomial patterns (cubic, per-coefficient scales)
        for _ in range(samples_per_modality):
            coeffs = np.random.randn(4) * [0.1, 0.3, 0.5, 0.2]
            t = np.linspace(-2, 2, seq_len)
            signal = np.polyval(coeffs, t) + np.random.randn(seq_len) * noise
            X.append(signal)
            y.append(1)

        # Modality 2: Multi-level step functions
        for _ in range(samples_per_modality):
            n_steps = np.random.randint(3, 7)
            signal = np.zeros(seq_len)
            step_size = seq_len // n_steps
            for i in range(n_steps):
                start = i * step_size
                # The last level extends to the end of the sequence; otherwise
                # the remainder (seq_len % n_steps samples) would stay at zero.
                end = seq_len if i == n_steps - 1 else (i + 1) * step_size
                level = np.random.uniform(-1.5, 1.5)
                signal[start:end] = level
            signal += np.random.randn(seq_len) * noise * 0.3
            X.append(signal)
            y.append(2)

        # Modality 3: Exponential decay/growth patterns
        for _ in range(samples_per_modality):
            decay_rate = np.random.uniform(-0.15, 0.15)
            initial = np.random.uniform(-1, 1)
            t = np.linspace(0, seq_len-1, seq_len)
            signal = initial * np.exp(decay_rate * t) + np.random.randn(seq_len) * noise
            X.append(signal)
            y.append(3)

        self.X = np.array(X, dtype=np.float32)
        self.y = np.array(y, dtype=np.int64)

        # Shuffle so the modalities are interleaved
        idx = np.arange(len(self.X))
        np.random.shuffle(idx)
        self.X = self.X[idx]
        self.y = self.y[idx]

    def __len__(self):
        # Report the number of samples actually generated, which can be smaller
        # than the requested n_samples when it is not a multiple of 4.
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(signal, label)`` as a float32 array and an int64 label."""
        return self.X[idx], self.y[idx]

# Build the synthetic dataset, then an 80/20 train/validation split
print("Creating multi-modal dataset...")
dataset = MultiModalDataset(n_samples=10000, seq_len=32)

n = len(dataset)
indices = np.arange(n)
np.random.shuffle(indices)
split_point = int(n * 0.8)
train_idx, val_idx = indices[:split_point], indices[split_point:]

train_dataset, val_dataset = (
    Subset(dataset, train_idx.tolist()),
    Subset(dataset, val_idx.tolist()),
)

# Only the training loader reshuffles between epochs
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print("Task types: 4 (Sinusoidal, Polynomial, Step Functions, Exponential)")

In [None]:
class StandardFFN(nn.Module):
    """Dense MLP baseline for the MoE comparison.

    Three ReLU hidden layers of widths ``d_hidden``, ``d_hidden`` and
    ``d_hidden // 2`` (dropout 0.15 after the first two), followed by a linear
    classifier producing ``n_classes`` logits.
    """
    def __init__(self, d_model=32, d_hidden=160, n_classes=4):
        super().__init__()
        half_width = d_hidden // 2
        layers = [
            nn.Linear(d_model, d_hidden), nn.ReLU(), nn.Dropout(0.15),
            nn.Linear(d_hidden, d_hidden), nn.ReLU(), nn.Dropout(0.15),
            nn.Linear(d_hidden, half_width), nn.ReLU(),
            nn.Linear(half_width, n_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of inputs to class logits."""
        return self.net(x)

    def count_parameters(self):
        """Return the total number of parameter elements (trainable or not)."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total

class MixtureOfExperts(nn.Module):
    """Mixture of Experts classifier with architecturally distinct experts.

    A small router MLP produces a softmax distribution over ``n_experts``
    experts. For each sample the ``top_k`` most probable experts are mixed with
    renormalized weights. Note that ``forward`` still evaluates every expert on
    every sample (dense computation); routing only decides how the expert
    outputs are combined.

    Args:
        d_model: Input feature size.
        d_hidden: Hidden width used by the router and the experts.
        n_experts: Number of experts. Four expert architectures (Tanh,
            deep ReLU, ReLU, ELU) are reused cyclically, so ``n_experts=4``
            gives one of each.
        n_classes: Number of logits each expert (and the model) outputs.
        top_k: Number of experts mixed per sample (must be <= n_experts).
        load_balance_loss_coef: Scale of the auxiliary load-balancing loss.
        temperature: Softmax temperature applied to the router logits.
    """
    def __init__(self, d_model=32, d_hidden=80, n_experts=4, n_classes=4, top_k=2,
                 load_balance_loss_coef=0.01, temperature=1.0):
        super(MixtureOfExperts, self).__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        self.load_balance_loss_coef = load_balance_loss_coef
        self.temperature = temperature

        # Router network
        self.router = nn.Sequential(
            nn.Linear(d_model, d_hidden//2),
            nn.ReLU(),
            nn.Dropout(0.05),  # Reduced dropout for stability
            nn.Linear(d_hidden//2, n_experts)
        )

        # Small-variance router init keeps the initial gate distribution close to uniform
        with torch.no_grad():
            for module in self.router.modules():
                if isinstance(module, nn.Linear):
                    nn.init.normal_(module.weight, 0, 0.1)
                    nn.init.constant_(module.bias, 0)

        # Experts cycle through the four architectures (created in the same
        # order as a one-of-each layout when n_experts == 4).
        self.experts = nn.ModuleList(
            self._build_expert(i % 4, d_model, d_hidden, n_classes)
            for i in range(n_experts)
        )

        # Xavier init for all expert linear layers
        for expert in self.experts:
            for module in expert.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    nn.init.constant_(module.bias, 0)

        # Routing statistics. Registered as non-persistent buffers so they move
        # with the module on .to(device) (forward adds device tensors into
        # expert_usage) without changing the state_dict.
        self.register_buffer('expert_usage', torch.zeros(n_experts), persistent=False)
        self.register_buffer('class_expert_matrix', torch.zeros(n_classes, n_experts), persistent=False)

    @staticmethod
    def _build_expert(kind, d_model, d_hidden, n_classes):
        """Return one expert MLP; ``kind`` (0-3) selects its architecture."""
        half = d_hidden // 2
        if kind == 0:
            # Intended sinusoidal specialist (Tanh activations)
            return nn.Sequential(
                nn.Linear(d_model, d_hidden),
                nn.Tanh(),
                nn.Dropout(0.1),
                nn.Linear(d_hidden, half),
                nn.Tanh(),
                nn.Linear(half, n_classes)
            )
        if kind == 1:
            # Intended polynomial specialist (deeper ReLU stack)
            return nn.Sequential(
                nn.Linear(d_model, d_hidden),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(d_hidden, d_hidden),
                nn.ReLU(),
                nn.Linear(d_hidden, half),
                nn.ReLU(),
                nn.Linear(half, n_classes)
            )
        if kind == 2:
            # Intended step-function specialist (ReLU)
            return nn.Sequential(
                nn.Linear(d_model, d_hidden),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(d_hidden, half),
                nn.ReLU(),
                nn.Linear(half, n_classes)
            )
        # Intended exponential specialist (ELU for smooth gradients)
        return nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.ELU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden, half),
            nn.ELU(),
            nn.Linear(half, n_classes)
        )

    def forward(self, x, track_routing=False):
        """Route ``x`` through the experts and mix their logits.

        Args:
            x: Input batch with the batch on dim 0; presumably (batch, d_model)
                as produced by the DataLoader here — TODO confirm for other callers.
            track_routing: If True, add this batch's gate probabilities to
                ``expert_usage`` and also return the routing tensors.

        Returns:
            Logits of shape (batch, n_classes), or ``(logits, info)`` when
            ``track_routing`` is True, where ``info`` holds detached
            ``'gate_probs'`` and ``'routing_weights'`` of shape (batch, n_experts).

        Side effect: sets ``self.load_balance_loss`` (a scaled tensor in
        training mode, ``0.0`` otherwise) for the training loop to add.
        """
        batch_size = x.size(0)

        gate_logits = self.router(x)
        gate_probs = F.softmax(gate_logits / self.temperature, dim=-1)

        # Keep only the top-k gate probabilities per sample, then renormalize
        # them so each row of routing_weights sums to 1.
        topk_vals, topk_idx = torch.topk(gate_probs, self.top_k, dim=-1)
        routing_weights = torch.zeros_like(gate_probs)
        routing_weights.scatter_(1, topk_idx, topk_vals)
        routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-8)

        # Dense evaluation of all experts: [batch, n_classes, n_experts]
        expert_stack = torch.stack([expert(x) for expert in self.experts], dim=-1)
        output = (expert_stack * routing_weights.unsqueeze(1)).sum(dim=-1)  # [batch, n_classes]

        self.load_balance_loss = 0.0
        if self.training:
            expert_usage = routing_weights.sum(dim=0)  # [n_experts]
            # Rows of routing_weights sum to 1, so the total load is batch_size
            # and a perfectly balanced router gives each expert
            # batch_size / n_experts. (The former target batch_size*top_k/n_experts
            # only added a large constant to the loss.)
            balance_target = batch_size / self.n_experts
            balance_loss = ((expert_usage - balance_target) ** 2).mean()
            self.load_balance_loss = self.load_balance_loss_coef * balance_loss

        if track_routing:
            with torch.no_grad():
                self.expert_usage += gate_probs.sum(dim=0)
            return output, {'gate_probs': gate_probs.detach(), 'routing_weights': routing_weights.detach()}

        return output

    def count_parameters(self):
        """Return the total number of parameter elements."""
        return sum(p.numel() for p in self.parameters())

    def get_active_parameters(self):
        """Estimate the parameters used per input under sparse top-k execution.

        The router always runs, and each input combines ``top_k`` experts, so
        the expert cost is ``top_k`` times the usage-weighted mean expert
        size. Usage weights come from the accumulated gate probabilities in
        ``expert_usage`` (an approximation of top-k selection frequency).
        Before any routing has been tracked, the total parameter count is
        returned. This is a notional figure: ``forward`` evaluates every
        expert densely.
        """
        total_usage = self.expert_usage.sum()
        if total_usage == 0:
            return self.count_parameters()

        router_params = sum(p.numel() for p in self.router.parameters())

        usage_weights = self.expert_usage / total_usage
        mean_expert_params = 0
        for i, expert in enumerate(self.experts):
            expert_param_count = sum(p.numel() for p in expert.parameters())
            mean_expert_params += expert_param_count * usage_weights[i].item()

        return router_params + self.top_k * mean_expert_params

# Instantiate the dense baseline and the MoE on the selected device
print("Initializing models...")
ffn = StandardFFN(d_model=32, d_hidden=160, n_classes=4).to(DEVICE)
moe = MixtureOfExperts(
    d_model=32, d_hidden=80, n_experts=4, n_classes=4, top_k=2,
    load_balance_loss_coef=0.01,
).to(DEVICE)

ffn_param_count = ffn.count_parameters()
moe_param_count = moe.count_parameters()
print(f"FFN Parameters: {ffn_param_count:,}")
print(f"MoE Parameters: {moe_param_count:,}")
print(f"Parameter ratio (MoE/FFN): {moe_param_count/ffn_param_count:.2f}")

In [None]:
def train_epoch(model, optimizer, loader, criterion, track_routing=False):
    """Run one optimization pass over ``loader`` and report its statistics.

    Models exposing ``expert_usage`` are treated as MoE: their gradients are
    clipped to norm 1.0 (2.0 otherwise), and when ``track_routing`` is set the
    per-batch routing info they return is collected. If the model exposes
    ``load_balance_loss`` after its forward pass, that term is added to the
    classification loss before backpropagation and is included in the
    reported loss.

    Returns:
        ``(avg_loss, accuracy, routing_data)``. The averages are weighted by
        batch size; ``routing_data`` is a list of per-batch routing dicts
        (empty unless routing was tracked).
    """
    model.train()
    is_moe = hasattr(model, 'expert_usage')
    collect_routing = is_moe and track_routing
    clip_norm = 1.0 if is_moe else 2.0

    loss_sum, n_correct, n_seen = 0.0, 0, 0
    routing_data = []

    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()

        if collect_routing:
            logits, info = model(inputs, track_routing=True)
            routing_data.append(info)
        else:
            logits = model(inputs)

        batch_loss = criterion(logits, targets)
        # Checked after forward: the MoE only creates this attribute there
        if hasattr(model, 'load_balance_loss'):
            batch_loss = batch_loss + model.load_balance_loss

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
        optimizer.step()

        loss_sum += batch_loss.item() * inputs.size(0)
        predictions = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predictions.eq(targets).sum().item()

    return loss_sum / n_seen, n_correct / n_seen, routing_data

def evaluate_model(model, loader, criterion, track_routing=False):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Classifier returning logits. If it has an ``expert_usage``
            attribute it is treated as an MoE model.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function applied to ``(logits, labels)``.
        track_routing: For MoE models, request routing info and accumulate
            per-class gate probabilities. The model's forward pass also adds to
            ``model.expert_usage`` in this mode.

    Returns:
        ``(avg_loss, accuracy, class_expert_usage)``. ``class_expert_usage``
        maps class label -> {expert index -> summed gate probability} and is
        empty unless routing was tracked.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    collect_routing = hasattr(model, 'expert_usage') and track_routing

    class_expert_usage = defaultdict(lambda: defaultdict(float))

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)

            if collect_routing:
                outputs, routing_info = model(batch_x, track_routing=True)

                # One device->host transfer per batch instead of a blocking
                # .item() per (sample, expert) pair; tolist() yields the same
                # Python floats that .item() would.
                probs_rows = routing_info['gate_probs'].cpu().tolist()
                for class_id, probs in zip(batch_y.cpu().numpy(), probs_rows):
                    for expert_id, prob in enumerate(probs):
                        class_expert_usage[class_id][expert_id] += prob
            else:
                outputs = model(batch_x)

            loss = criterion(outputs, batch_y)
            total_loss += loss.item() * batch_x.size(0)
            _, predicted = outputs.max(1)
            total += batch_y.size(0)
            correct += predicted.eq(batch_y).sum().item()

    return total_loss / total, correct / total, dict(class_expert_usage)

# Adam for both models; the MoE gets a slightly lower learning rate
adam_kwargs = dict(weight_decay=1e-5, eps=1e-8)
optimizer_ffn = torch.optim.Adam(ffn.parameters(), lr=2e-3, **adam_kwargs)
optimizer_moe = torch.optim.Adam(moe.parameters(), lr=1.5e-3, **adam_kwargs)

# Scale the LR by 0.7 after 10 epochs without validation-loss improvement
plateau_kwargs = dict(mode='min', factor=0.7, patience=10)
scheduler_ffn = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_ffn, **plateau_kwargs)
scheduler_moe = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_moe, **plateau_kwargs)

criterion = nn.CrossEntropyLoss()

print("Training configuration ready!")
print(f"FFN Learning Rate: {optimizer_ffn.param_groups[0]['lr']}")
print(f"MoE Learning Rate: {optimizer_moe.param_groups[0]['lr']}")

In [None]:
# Training configuration
epochs = 100
print(f"Starting training for {epochs} epochs...")

# Training history
history = {
    'ffn_train_loss': [], 'ffn_train_acc': [], 'ffn_val_loss': [], 'ffn_val_acc': [],
    'moe_train_loss': [], 'moe_train_acc': [], 'moe_val_loss': [], 'moe_val_acc': [],
    'moe_active_params': [], 'moe_efficiency': [],
    'ffn_total_params': ffn.count_parameters(),
    'moe_total_params': moe.count_parameters()
}

best_moe_acc = 0
# Validation accuracy of the tracked epoch whose routing stats are saved in
# final_class_expert_usage (routing is only tracked on some epochs).
best_tracked_moe_acc = -1.0
final_class_expert_usage = {}
ffn_param_count = ffn.count_parameters()  # constant during training

for epoch in range(1, epochs + 1):
    # Train FFN
    ffn_train_loss, ffn_train_acc, _ = train_epoch(ffn, optimizer_ffn, train_loader, criterion)
    ffn_val_loss, ffn_val_acc, _ = evaluate_model(ffn, val_loader, criterion)

    # Train MoE; routing statistics are collected every 10th epoch and on the last one
    track_routing = (epoch % 10 == 0) or (epoch == epochs)
    moe_train_loss, moe_train_acc, _ = train_epoch(moe, optimizer_moe, train_loader, criterion, track_routing)
    moe_val_loss, moe_val_acc, class_expert_usage = evaluate_model(moe, val_loader, criterion, track_routing)

    # Update learning rate schedulers
    scheduler_ffn.step(ffn_val_loss)
    scheduler_moe.step(moe_val_loss)

    # Calculate efficiency metrics
    active_params = moe.get_active_parameters()
    efficiency = active_params / ffn_param_count

    # Store history
    history['ffn_train_loss'].append(ffn_train_loss)
    history['ffn_train_acc'].append(ffn_train_acc)
    history['ffn_val_loss'].append(ffn_val_loss)
    history['ffn_val_acc'].append(ffn_val_acc)

    history['moe_train_loss'].append(moe_train_loss)
    history['moe_train_acc'].append(moe_train_acc)
    history['moe_val_loss'].append(moe_val_loss)
    history['moe_val_acc'].append(moe_val_acc)

    history['moe_active_params'].append(active_params)
    history['moe_efficiency'].append(efficiency)

    if moe_val_acc > best_moe_acc:
        best_moe_acc = moe_val_acc

    # class_expert_usage is empty on untracked epochs; never let such an epoch
    # overwrite the saved routing statistics with {}.
    if class_expert_usage and moe_val_acc > best_tracked_moe_acc:
        best_tracked_moe_acc = moe_val_acc
        final_class_expert_usage = class_expert_usage

    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch {epoch:3d} | FFN: {ffn_val_acc:.3f} | MoE: {moe_val_acc:.3f} | Efficiency: {efficiency:.3f}")

print("\nTraining completed!")

In [None]:
# Final evaluation of both models on the validation split
separator = "=" * 60
print(separator)
print("FINAL RESULTS")
print(separator)

ffn_final_loss, ffn_final_acc, _ = evaluate_model(ffn, val_loader, criterion)
moe_final_loss, moe_final_acc, _ = evaluate_model(moe, val_loader, criterion)

moe_active_now = moe.get_active_parameters()
improvement = (moe_final_acc - ffn_final_acc) / ffn_final_acc * 100
param_reduction = (1 - moe_active_now / ffn.count_parameters()) * 100

print(f"Standard FFN   | Accuracy: {ffn_final_acc:.4f} | Loss: {ffn_final_loss:.4f}")
print(f"MoE Network    | Accuracy: {moe_final_acc:.4f} | Loss: {moe_final_loss:.4f}")
print("")
print(f"Performance Improvement: {improvement:.2f}%")
print(f"Parameter Reduction: {param_reduction:.1f}%")
print(f"Active Parameters: {moe_active_now:,.0f} / {moe.count_parameters():,}")
print(separator)

# Efficiency statistics, excluding the first 10 epochs
late_window = history['moe_efficiency'][10:]
final_efficiency = history['moe_efficiency'][-1]
avg_efficiency = np.mean(late_window)
stability_label = '✓ Stable' if np.std(late_window) < 0.01 else '✗ Unstable'
print("\nEfficiency Analysis:")
print(f"Final efficiency: {final_efficiency:.3f}")
print(f"Average efficiency (after epoch 10): {avg_efficiency:.3f}")
print(f"Efficiency stabilization: {stability_label}")

# Share of accumulated gate probability per expert
print("\nExpert Usage Summary:")
expert_usage_normalized = moe.expert_usage / moe.expert_usage.sum()
for i, usage in enumerate(expert_usage_normalized):
    print(f"Expert {i}: {usage:.3f} ({usage*100:.1f}%)")

In [None]:
# Set up clean plotting style
plt.rcParams.update({
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'legend.fontsize': 10,
    'lines.linewidth': 2.5,
    'figure.dpi': 100
})

# Validation accuracy (top) and validation loss (bottom) per epoch
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 10))

# Accuracy plot
epochs_range = range(1, len(history['ffn_val_acc']) + 1)
ax1.plot(epochs_range, history['ffn_val_acc'], label='Standard FFN', color='#2E86AB', linewidth=3)
ax1.plot(epochs_range, history['moe_val_acc'], label='MoE Network', color='#A23B72', linewidth=3)
ax1.set_ylabel('Validation Accuracy', fontweight='bold')
ax1.set_title('Model Performance Comparison', fontweight='bold', pad=20)
ax1.legend(frameon=True, shadow=True)
ax1.grid(True, alpha=0.3)
# Keep the usual 0.5 floor, but never clip epochs whose accuracy dips below it
ax1.set_ylim(min([0.5, *history['ffn_val_acc'], *history['moe_val_acc']]), 1.0)

# Loss plot (validation loss, matching the y-axis label)
ax2.plot(epochs_range, history['ffn_val_loss'], label='Standard FFN', color='#2E86AB', linewidth=3)
ax2.plot(epochs_range, history['moe_val_loss'], label='MoE Network', color='#A23B72', linewidth=3)
ax2.set_ylabel('Validation Loss', fontweight='bold')
ax2.set_xlabel('Epoch', fontweight='bold')
ax2.set_title('Validation Loss Progression', fontweight='bold', pad=20)
ax2.legend(frameon=True, shadow=True)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# First epoch (1-based, matching the plot's x-axis) above 95% validation accuracy
ffn_epochs_to_95 = next((ep for ep, acc in enumerate(history['ffn_val_acc'], start=1) if acc > 0.95), 'N/A')
moe_epochs_to_95 = next((ep for ep, acc in enumerate(history['moe_val_acc'], start=1) if acc > 0.95), 'N/A')

# Training summary statistics
print("Training Curve Analysis:")
print(f"Final FFN accuracy: {history['ffn_val_acc'][-1]:.4f}")
print(f"Final MoE accuracy: {history['moe_val_acc'][-1]:.4f}")
print(f"Best FFN accuracy: {max(history['ffn_val_acc']):.4f}")
print(f"Best MoE accuracy: {max(history['moe_val_acc']):.4f}")
print(f"FFN convergence (epochs to 95%): {ffn_epochs_to_95}")
print(f"MoE convergence (epochs to 95%): {moe_epochs_to_95}")

In [None]:
# Parameter efficiency visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Parameter count comparison
ffn_params = history['ffn_total_params']
moe_total_params = history['moe_total_params']
moe_active_params = history['moe_active_params'][-1]  # Final active params (may be fractional)

models = ['Standard FFN', 'MoE Total', 'MoE Active']
param_counts = [ffn_params, moe_total_params, moe_active_params]
colors = ['#2E86AB', '#F18F01', '#A23B72']

bars = ax1.bar(models, param_counts, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
ax1.set_ylabel('Parameter Count', fontweight='bold')
ax1.set_title('Parameter Efficiency Analysis', fontweight='bold', pad=20)
ax1.grid(True, alpha=0.3, axis='y')

# Add value labels on bars; the active count is a float estimate, so round for display
for bar, count in zip(bars, param_counts):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height + max(param_counts)*0.01,
            f'{count:,.0f}', ha='center', va='bottom', fontweight='bold')

# Efficiency over time
epochs_range = range(1, len(history['moe_efficiency']) + 1)
ax2.plot(epochs_range, history['moe_efficiency'], color='#A23B72', linewidth=3)
ax2.set_ylabel('Active Parameters Ratio', fontweight='bold')
ax2.set_xlabel('Epoch', fontweight='bold')
ax2.set_title('Parameter Efficiency Over Training', fontweight='bold', pad=20)
ax2.grid(True, alpha=0.3)
ax2.set_ylim(0, max([1.0, *history['moe_efficiency']]))

# Routing usage is first tracked at epoch 10; before that get_active_parameters()
# falls back to the total parameter count, so the drop there is a measurement
# artifact rather than a change in the model.
stabilization_epoch = 10
ax2.axvline(x=stabilization_epoch, color='red', linestyle='--', alpha=0.7)
ax2.text(stabilization_epoch + 2, 0.8, 'Usage Tracking\nStarts',
         fontsize=10, color='red', fontweight='bold')

plt.tight_layout()
plt.show()

# Calculate and display efficiency metrics
print("Parameter Efficiency Metrics:")
print(f"FFN Total Parameters: {ffn_params:,}")
print(f"MoE Total Parameters: {moe_total_params:,}")
print(f"MoE Active Parameters: {moe_active_params:,.0f}")
print("")
print(f"MoE Total has {((ffn_params - moe_total_params) / ffn_params * 100):.1f}% fewer total parameters than FFN")
print(f"MoE Active uses {((moe_active_params / ffn_params) * 100):.1f}% of FFN's parameter count actively")
print(f"Parameter efficiency: {(1 - moe_active_params / ffn_params) * 100:.1f}% reduction")

# Efficiency timeline analysis
early_efficiency = np.mean(history['moe_efficiency'][:5])
late_efficiency = np.mean(history['moe_efficiency'][-10:])
print("")
print("Efficiency Evolution:")
print(f"Early training (epochs 1-5): {early_efficiency:.3f}")
print(f"Late training (last 10 epochs): {late_efficiency:.3f}")