# **MoE Experiments**

First, we import the `Transformer` class and all of our custom MoE modules, pick the compute device, and seed the random number generators.

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import einops
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborne as sns
import time

from transformer import Transformer
from moes import RegularMoE, RandomMoE, OrthogonalMoE, HashMoE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu" and torch.backends.mps.is_available():
    device = torch.device("mps")
    torch.mps.manual_seed(67960)
if device.type == "cuda" or device.type == "cpu":
    torch.manual_seed(67960)

B = 32

print(f"Using device: {device}")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: mps


Next, we load the AG News dataset, tokenize it with GPT-2's BPE tokenizer, and wrap the train/test splits in `DataLoader`s for next-token prediction.

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import GPT2TokenizerFast

# AG News from the Hugging Face hub; keep both splits as separate handles.
print("Loading AG News dataset...")
dataset = load_dataset("ag_news")
train_data, test_data = dataset['train'], dataset['test']

# GPT-2 byte-level BPE tokenizer, with its EOS token reused as the padding token.
print("Loading GPT-2 tokenizer...")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
vocab_size = len(tokenizer)

print(f"Vocab size: {vocab_size}")
print(f"Train samples: {len(train_data)}, Test samples: {len(test_data)}")

# dataset class with BPE tokenization
class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns raw text into next-token-prediction examples.

    Each text is tokenized, truncated/padded to ``max_len + 1`` tokens and split into
    inputs ``x = tokens[:-1]`` and targets ``y = tokens[1:]`` (both ``max_len`` long),
    plus a boolean ``mask`` that is True wherever ``x`` is not the pad token.

    Args:
        hf_dataset: indexable collection whose items are dicts with a ``'text'`` key
            (here a Hugging Face dataset split).
        max_len: length of the returned ``x`` / ``y`` sequences.
        tokenizer: Hugging Face-style tokenizer (callable, with ``pad_token_id``).
            When None, the notebook-level ``tokenizer`` global is used, which keeps the
            original two-argument call sites working.
    """

    def __init__(self, hf_dataset, max_len=128, tokenizer=None):
        self.data = hf_dataset
        self.max_len = max_len
        self._tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def _get_tokenizer(self):
        # Fall back to the module-level tokenizer only when none was injected.
        return self._tokenizer if self._tokenizer is not None else tokenizer

    def __getitem__(self, idx):
        tok = self._get_tokenizer()
        text = self.data[idx]['text']
        # max_len + 1 tokens so that shifting by one still leaves max_len inputs and targets.
        encoding = tok(text, truncation=True, max_length=self.max_len + 1,
                       padding='max_length', return_tensors='pt')
        tokens = encoding['input_ids'].squeeze(0)

        x, y = tokens[:-1], tokens[1:]
        # When pad == eos (as configured for GPT-2 in this notebook), a genuine EOS token
        # inside the text is masked out as well.
        mask = x != tok.pad_token_id
        return x, y, mask

# Wrap both splits as next-token datasets and batch them with batch size B.
# num_workers=0 keeps all loading in the notebook process.
# NOTE(review): the saved traceback at the end of this notebook is a DataLoader *worker* crash
# under multiprocessing spawn, which cannot occur with num_workers=0. It looks like stale output
# from a run that used worker processes; re-run to confirm.
train_dataset = TextDataset(train_data, max_len=128)
test_dataset = TextDataset(test_data, max_len=128)

loader_kwargs = dict(batch_size=B, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print("\n*Data done loading*")


Loading AG News dataset...
Loading GPT-2 tokenizer...
Vocab size: 50257
Train samples: 120000, Test samples: 7600

*Data done loading*


Next, we set the model hyperparameters and build one `Transformer` per MoE variant (`RegularMoE`, `RandomMoE`, `OrthogonalMoE`, `HashMoE`).

In [None]:
# Model hyperparameters. D, H, N, K are passed to every MoE constructor as MoE(D, H, N, K).
# NOTE(review): presumably model width, expert hidden size, number of experts, and experts
# routed per token; confirm against moes.py.
D = 128
H = 256
N = 8
K = 2
V = vocab_size  # size of the GPT-2 tokenizer vocabulary loaded above
n_heads = 4
n_layers = 2
max_seq_len = 128

print(f"D: {D}\nH: {H}\nN: {N}\nK: {K}\nV: {V}\nn_heads: {n_heads}\nn_layers: {n_layers}\nmax_seq_len: {max_seq_len}")

# One zero-argument factory per MoE variant. `cls=cls` binds each class when the lambda is
# created; without it every lambda would see the comprehension variable's final value.
moe_classes = [RegularMoE, RandomMoE, OrthogonalMoE, HashMoE]
moe_fns = [lambda cls=cls: cls(D, H, N, K) for cls in moe_classes]
models = [Transformer(V, D, n_heads, n_layers, moe_fn, max_seq_len).to(device) for moe_fn in moe_fns]

# Report parameter counts. The name comes from the class list instead of calling moe_fns[i]():
# building a throwaway MoE just to read its class name wastes memory and, through parameter
# initialisation, typically advances the global torch RNG that later cells rely on.
for i, (model, cls) in enumerate(zip(models, moe_classes), start=1):
    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model {i} ({cls.__name__}) has {n_params} parameters and {n_trainable} trainable parameters")

D: 128
 H: 256
 N: 8
 K: 2
 V: 50257
 n_heads: 4
 n_layers: 2
 max_seq_len: 128
Model 1 (RegularMoE) has 8152464 parameters and 8152464 trainable parameters
Model 2 (RandomMoE) has 8152464 parameters and 8150400 trainable parameters
Model 3 (OrthogonalMoE) has 8150400 parameters and 8150400 trainable parameters
Model 4 (HashMoE) has 8150400 parameters and 8150400 trainable parameters


Now we train each model separately with the same optimizer, learning rate, and number of epochs, and then compare their training and test losses.

In [None]:
def train_epoch(model, loader, optimizer, device, ignore_index=None):
    """Train ``model`` for one pass over ``loader`` and return the mean per-token loss.

    Args:
        model: module called as ``model(x, mask)``; its output's last dimension is
            treated as the vocabulary (e.g. logits of shape [B, S, V]).
        loader: sized iterable of ``(x, y, mask)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batch tensors are moved to.
        ignore_index: target id excluded from the loss. Defaults to the notebook-level
            ``tokenizer.pad_token_id``.

    Returns:
        Cross-entropy averaged over every non-ignored target token of the epoch.

    Raises:
        ValueError: if the loader yields no non-ignored target tokens.
    """
    if ignore_index is None:
        ignore_index = tokenizer.pad_token_id
    model.train()
    total_loss = 0.0
    total_tokens = 0

    print(f"{len(loader)} batches to process...")

    for num_batches, (x, y, mask) in enumerate(loader, start=1):
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        targets = y.reshape(-1)
        n_tokens = int((targets != ignore_index).sum())

        # An all-padding batch has a NaN mean loss; stepping on it would poison the weights.
        if n_tokens > 0:
            logits = model(x, mask)
            # Flatten with the logits' own vocab size (not a global). reshape also copes
            # with non-contiguous logits, where .view would raise.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets,
                                   ignore_index=ignore_index)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Weight by token count so the epoch average is per token, not per batch.
            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens

        if num_batches % 100 == 0:
            print(f"Processed {num_batches} batches...")

    if total_tokens == 0:
        raise ValueError("train_epoch: loader produced no non-ignored target tokens")
    return total_loss / total_tokens

@torch.no_grad()
def evaluate(model, loader, device, ignore_index=None):
    """Evaluate ``model`` on ``loader`` and return the mean per-token cross-entropy.

    Losses are summed over tokens and divided by the total token count. Batches with
    different amounts of padding, and a smaller final batch, are therefore weighted by how
    many targets they actually contain, rather than counting equally.

    Args:
        model: module called as ``model(x, mask)``; its output's last dimension is
            treated as the vocabulary.
        loader: iterable of ``(x, y, mask)`` batches.
        device: device the batch tensors are moved to.
        ignore_index: target id excluded from the loss. Defaults to the notebook-level
            ``tokenizer.pad_token_id``.

    Returns:
        Mean cross-entropy over all non-ignored target tokens.

    Raises:
        ValueError: if the loader yields no non-ignored target tokens.
    """
    if ignore_index is None:
        ignore_index = tokenizer.pad_token_id
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    for x, y, mask in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        targets = y.reshape(-1)
        n_tokens = int((targets != ignore_index).sum())
        if n_tokens == 0:
            continue  # all padding: contributes nothing (its mean would be NaN)
        logits = model(x, mask)
        total_loss += F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets,
                                      ignore_index=ignore_index, reduction="sum").item()
        total_tokens += n_tokens

    if total_tokens == 0:
        raise ValueError("evaluate: loader produced no non-ignored target tokens")
    return total_loss / total_tokens

# Training config
num_epochs = 3
lr = 3e-4
DATA_SEED = 67960  # re-applied before each model so every variant sees the same batch order

# Class name of each model's MoE variant (models were built from moe_fns in the same order).
# Building a throwaway MoE initialises parameters, so do it inside a forked CPU RNG scope:
# this leaves the global random state, and hence the training below, untouched.
with torch.random.fork_rng(devices=[]):
    model_names = [type(moe_fn()).__name__ for moe_fn in moe_fns]

# Train each model
results = {}
for i, (model, name) in enumerate(zip(models, model_names)):
    print(f"\n{'='*60}")
    print(f"Training Model {i+1}/{len(models)}: {name}")
    print(f"{'='*60}")

    # DataLoader shuffling draws from the global torch RNG. Resetting it here gives every MoE
    # variant the same shuffled data order, so loss differences are not due to data order.
    torch.manual_seed(DATA_SEED)

    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    train_losses = []
    test_losses = []

    for epoch in range(num_epochs):
        start_time = time.time()

        train_loss = train_epoch(model, train_loader, optimizer, device)
        train_losses.append(train_loss)

        test_loss = evaluate(model, test_loader, device)
        test_losses.append(test_loss)

        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Time: {epoch_time:.2f}s")

    results[name] = {
        'train_losses': train_losses,
        'test_losses': test_losses,
        'final_train_loss': train_losses[-1],
        'final_test_loss': test_losses[-1]
    }

    # Free device memory before the next model. Move the weights back to the CPU (in place for
    # nn.Module) and drop the optimizer, whose AdamW state tensors would otherwise stay
    # on the device.
    model.cpu()
    del optimizer
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps":
        torch.mps.empty_cache()

# Print a one-line summary of each MoE variant's final-epoch train and test loss.
rule = '=' * 60
print('\n' + rule)
print("FINAL RESULTS")
print(rule)
for name, res in results.items():
    final_train, final_test = res['final_train_loss'], res['final_test_loss']
    print(f"{name:20s} | Train: {final_train:.4f} | Test: {final_test:.4f}")

# Plot per-epoch loss curves side by side: training loss (left) and test loss (right),
# with one line per MoE variant.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
epochs = range(1, num_epochs + 1)

panels = ((ax1, 'train_losses', 'Training Loss'), (ax2, 'test_losses', 'Test Loss'))
for ax, key, title in panels:
    for name, res in results.items():
        ax.plot(epochs, res[key], marker='o', label=name)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()



Training Model 1/4: RegularMoE
3750 batches to process...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Traceback (most recent call last):
  File [35m"<string>"[0m, line [35m1[0m, in [35m<module>[0m
    from multiprocessing.spawn import spawn_main; [31mspawn_main[0m[1;31m(tracker_fd=96, pipe_handle=177)[0m
                                                  [31m~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/opt/homebrew/Cellar/python@3.13/3.13.3/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m122[0m, in [35mspawn_main[0m
    exitcode = _main(fd, parent_sentinel)
  File [35m"/opt/homebrew/Cellar/python@3.13/3.13.3/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m132[0m, in [

RuntimeError: DataLoader worker (pid(s) 53026) exited unexpectedly