# Variable Expert Analysis - Unified Notebook

This notebook supports both model sizes:
- **OpenWebText**: 4x2944 + 4x128 experts
- **WikiText**: 4x2432 + 4x128 experts

Simply change the `DATASET` variable in the config cell to switch between them.

In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from einops import rearrange
from collections import defaultdict, Counter
from tqdm import tqdm
import tiktoken

import stk
import stk.ops
import stk.random
import stk.matrix
from megablocks.layers.gelu import gelu

from model import GPT, GPTConfig, MoeMLP

In [None]:
# ========== CONFIGURATION ==========
# Select which dataset / model size and which training seed to analyse.
DATASET = "wikitext"  # Options: "openwebtext" or "wikitext"
SEED = 42  # Options: 42, 1223, 1337

# Per-dataset model hyperparameters, checkpoint directory and validation data.
# Every "expert_sizes" entry is a (number of experts, expert size) pair.
CONFIGS = {
    "openwebtext": dict(
        n_layer=12,
        n_head=12,
        n_embd=768,
        vocab_size=50304,
        expert_sizes=[(4, 2944), (4, 128)],  # 4 large (2944) + 4 small (128)
        checkpoint_dir=f"out-openwebtext/moe-8x2-variable-4x2944-4x128-seed{SEED}",
        val_data_path="data/openwebtext/val.bin",
        model_name="OpenWebText (4x2944 + 4x128)",
    ),
    "wikitext": dict(
        n_layer=8,
        n_head=8,
        n_embd=640,
        vocab_size=8192,
        expert_sizes=[(4, 2432), (4, 128)],  # 4 large (2432) + 4 small (128)
        checkpoint_dir=f"out-wikitext/moe-8x2-variable-4x2432-4x128-seed{SEED}",
        val_data_path="data/wikitext/val.bin",
        model_name="WikiText (4x2432 + 4x128)",
    ),
}

# Settings of the dataset chosen above.
cfg = CONFIGS[DATASET]

# Model config: dataset-dependent fields come from cfg, the rest is fixed.
config = GPTConfig(
    n_layer=cfg['n_layer'],
    n_head=cfg['n_head'],
    n_embd=cfg['n_embd'],
    vocab_size=cfg['vocab_size'],
    bias=False,
    # MoE configuration with VARIABLE-SIZE EXPERTS
    use_moe=True,
    num_experts=8,
    num_experts_per_tok=2,
    norm_topk_prob=True,
    block_size=128,
    block_k=64,
    expert_sizes=cfg["expert_sizes"],
)

# Echo the active configuration between two 60-character rules.
banner = '=' * 60
print(f"\n{banner}")
print(f"Configuration: {cfg['model_name']}")
print(f"Seed: {SEED}")
print(f"Expert sizes: {config.expert_sizes}")
print(f"Checkpoint: {cfg['checkpoint_dir']}/ckpt.pt")
print(f"Val data: {cfg['val_data_path']}")
print(f"{banner}\n")

## Subclass the MoE MLP layer and GPT layer to track token routing

In [3]:
class MoeMLPWithTracking(MoeMLP):
    """MoE MLP whose forward pass also exposes per-token routing decisions.

    The routing and expert computation are built from the parent class's
    helper methods; in addition the selected experts, router logits and router
    probabilities are returned reshaped to ``(batch, seq, ...)``.
    """

    @torch.compiler.disable
    def forward(self, x):
        """Route tokens to experts and return the output plus routing records.

        Args:
            x: Activations, unpacked as ``(batch_size, seq_len, n_embd)``.

        Returns:
            A tuple ``(output, aux_loss, f_i)`` where ``output`` has the shape
            of ``x``, ``f_i`` is the fraction of expert assignments received
            by each expert (length ``num_experts``, dtype of ``x``), and
            ``aux_loss`` is a dict with keys ``router_z_loss``,
            ``load_balance_loss``, ``expert_assignments`` (batch, seq, k),
            ``router_logits`` and ``router_probs`` (batch, seq, num_experts).
        """
        batch_size, seq_len, n_embd = x.shape

        x_flat = rearrange(x, 'batch_size seq_len n_embd -> (batch_size seq_len) n_embd')

        # Router: softmax in float32, then keep the top-k experts per token.
        router_logits = self.router(x_flat)
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        expert_weights, selected_experts = torch.topk(router_probs, self.num_experts_per_tok, dim=-1)

        if self.norm_topk_prob:
            # Renormalise so the kept weights of each token sum to 1.
            expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)
        expert_weights = expert_weights.to(x.dtype)
        expert_weights_flat = rearrange(expert_weights, '... -> (...)')
        selected_experts_flat = rearrange(selected_experts, '... -> (...)')

        # Expert computation through the parent's helpers and the stk sparse ops
        # (sdd with w1 -> gelu -> dsd with w2).
        # NOTE(review): the helper semantics live in MoeMLP and are not visible here.
        bin_ids, indices, tokens_per_expert = self._sort_tokens_by_expert(selected_experts_flat)
        padded_bins, topology = self._create_topology(x_flat, tokens_per_expert)
        x_permuted = self._gather_tokens(x_flat, indices, bin_ids, tokens_per_expert, padded_bins)
        x_permuted = stk.ops.sdd(x_permuted, self.w1, topology)
        x_permuted = gelu(x_permuted)
        x_permuted = stk.ops.dsd(x_permuted, self.w2)

        x_permuted = self._scatter_tokens(x_permuted, indices, bin_ids, expert_weights_flat, tokens_per_expert, padded_bins)
        output = rearrange(x_permuted, '(batch_size seq_len) n_embd -> batch_size seq_len n_embd', batch_size=batch_size, seq_len=seq_len)

        router_z_loss = torch.logsumexp(router_logits, dim=-1).pow(2).mean()

        # Mean router probability per expert, cast to the activation dtype so
        # the dot product with f_i works for any model dtype (not only bfloat16).
        p_i = router_probs.mean(dim=0).to(x.dtype)

        # Fraction of (token, slot) assignments per expert.
        # scatter_add_ must be the in-place variant: the out-of-place
        # scatter_add returns a new tensor and would leave f_i all zeros.
        # Accumulate in float32 because adding many tiny increments in a
        # low-precision dtype stops changing the running sum.
        f_i = torch.zeros(self.num_experts, dtype=torch.float32, device=x.device)
        ones = torch.ones_like(selected_experts_flat, dtype=torch.float32) / len(selected_experts_flat)
        f_i.scatter_add_(0, selected_experts_flat, ones)
        f_i = f_i.to(x.dtype)
        load_balance_loss = self.num_experts * (f_i @ p_i)

        # Restore the (batch, seq) layout on the routing records for the caller.
        expert_assignments = rearrange(selected_experts, '(batch seq) k -> batch seq k', batch=batch_size, seq=seq_len)
        router_logits_reshaped = rearrange(router_logits, '(batch seq) num_experts -> batch seq num_experts', batch=batch_size, seq=seq_len)
        router_probs_reshaped = rearrange(router_probs, '(batch seq) num_experts -> batch seq num_experts', batch=batch_size, seq=seq_len)

        aux_loss = {
            'router_z_loss': router_z_loss,
            'load_balance_loss': load_balance_loss,
            'expert_assignments': expert_assignments,
            'router_logits': router_logits_reshaped,
            'router_probs': router_probs_reshaped,
        }

        return output, aux_loss, f_i

class GPTWithTracking(GPT):
    """GPT whose forward pass also returns the routing records of every block."""

    def forward(self, idx, targets=None):
        """Run the model and gather per-layer routing information.

        Args:
            idx: Token ids; ``idx.size()`` is unpacked as ``(b, t)``.
            targets: Optional target ids. When given, logits are produced for
                every position and a loss is computed; otherwise only the last
                position's logits are produced and the loss is ``None``.

        Returns:
            ``(logits, loss, combined_aux_loss)``. ``combined_aux_loss`` holds
            the layer-averaged scalar auxiliary losses, ``expert_usage`` (when
            any block returned a usage vector), and the dicts
            ``expert_assignments`` / ``router_logits`` / ``router_probs`` keyed
            by ``'layer_<i>'``; ``ce_loss`` is added when targets are given.
        """
        _, t = idx.size()
        assert t <= self.config.n_ctx, f"Cannot forward sequence of length {t}, context length is only {self.config.n_ctx}"
        positions = torch.arange(0, t, dtype=torch.long, device=idx.device)

        token_embeddings = self.transformer.wte(idx)
        position_embeddings = self.transformer.wpe(positions)
        x = self.transformer.drop(token_embeddings + position_embeddings)

        # Entries kept per layer (not averaged like the scalar losses).
        tracked_keys = ('expert_assignments', 'router_logits', 'router_probs')
        per_layer_records = {key: {} for key in tracked_keys}
        usage_per_layer = []
        combined_aux_loss = {}
        n_blocks = 0

        for layer_idx, block in enumerate(self.transformer.h):
            x, aux_loss, f_i = block(x)
            layer_key = f'layer_{layer_idx}'

            if f_i is not None:
                usage_per_layer.append(f_i)

            for key in tracked_keys:
                per_layer_records[key][layer_key] = aux_loss[key]

            scalar_terms = [(key, value) for key, value in aux_loss.items() if key not in tracked_keys]
            if layer_idx == 0:
                # Clone so the in-place accumulation below leaves the block's tensors untouched.
                combined_aux_loss = {key: value.clone() for key, value in scalar_terms}
            else:
                for key, value in scalar_terms:
                    combined_aux_loss[key] += value

            n_blocks += 1

        # Turn the summed scalar losses into per-layer means.
        for key in combined_aux_loss:
            combined_aux_loss[key] /= n_blocks

        if usage_per_layer:
            combined_aux_loss['expert_usage'] = torch.stack(usage_per_layer).mean(dim=0)

        for key in tracked_keys:
            combined_aux_loss[key] = per_layer_records[key]

        x = self.transformer.ln_f(x)

        if targets is None:
            # No targets: only the final position is projected to logits.
            logits = self.lm_head(x[:, [-1], :])
            return logits, None, combined_aux_loss

        logits = self.lm_head(x)
        ce_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        loss = ce_loss + self.config.load_balance_loss_weight * combined_aux_loss['load_balance_loss'] + self.config.router_z_loss_weight * combined_aux_loss['router_z_loss']
        combined_aux_loss['ce_loss'] = ce_loss

        return logits, loss, combined_aux_loss

## Load Model and Checkpoint

In [None]:
checkpoint_path = f"{cfg['checkpoint_dir']}/ckpt.pt"

model = GPTWithTracking(config).to(torch.bfloat16)

# Replace every MLP that exposes `expert_sizes` with the tracking subclass,
# carrying over the weights of the module it replaces.
for block in model.transformer.h:
    old_mlp = block.mlp
    if not hasattr(old_mlp, 'expert_sizes'):
        continue
    block.mlp = MoeMLPWithTracking(config).to(torch.bfloat16)
    block.mlp.load_state_dict(old_mlp.state_dict())

checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Strip the '_orig_mod.' marker from parameter names when any key starts with it.
# NOTE(review): presumably the prefix left by torch.compile — confirm against the training script.
compiled_prefix = '_orig_mod.'
state_dict = checkpoint['model']
has_prefixed_keys = any(name.startswith(compiled_prefix) for name in state_dict.keys())
if has_prefixed_keys:
    state_dict = {name.replace(compiled_prefix, ''): tensor for name, tensor in state_dict.items()}

model.load_state_dict(state_dict)

print(f"✓ Loaded checkpoint from {checkpoint_path}")

## Collect ALL Routing Statistics in a SINGLE Pass

This cell collects all statistics (individual expert assignments and combinations) in one efficient pass through the validation set.

In [None]:
tokenizer = tiktoken.get_encoding('gpt2')
# NOTE(review): the "wikitext" config declares vocab_size 8192 while this is the
# GPT-2 encoding — confirm the decoded token strings match the training tokenizer.

val_data_path = cfg['val_data_path']
val_data = np.memmap(val_data_path, dtype=np.uint16, mode='r')

device = 'cuda'
model = model.to(device)
model.eval()

num_layers = config.n_layer
expert_sizes = model.transformer.h[0].mlp.expert_sizes

# Initialize ALL tracking structures.
# token_stats_per_layer[layer][token_id]        -> running per-token counters
# token_combinations_per_layer[layer][token_id] -> Counter of sorted expert-id tuples
token_stats_per_layer = {}
token_combinations_per_layer = {}

for layer_idx in range(num_layers):
    layer_name = f'layer_{layer_idx}'
    token_stats_per_layer[layer_name] = defaultdict(lambda: {
        'expert_counts': np.zeros(config.num_experts, dtype=np.int64),
        'total_occurrences': 0,
        'total_entropy': 0.0,
        'expert_size_sum': 0.0,
    })
    token_combinations_per_layer[layer_name] = defaultdict(Counter)

batch_size = 1
seq_len = 1024
total_tokens = len(val_data)
# Integer division: a trailing chunk shorter than seq_len is not processed.
num_batches = total_tokens // seq_len
epsilon = 1e-10  # keeps log() finite for zero probabilities

print(f"Running single pass through {num_batches} batches to collect all statistics...")

for batch_idx in tqdm(range(num_batches)):
    start_idx = batch_idx * seq_len
    end_idx = start_idx + seq_len
    batch_array = val_data[start_idx:end_idx].astype(np.int64)
    batch_tokens = torch.from_numpy(batch_array).unsqueeze(0).to(device)
    # Token ids as Python ints, taken from the host copy once per batch instead
    # of calling .item() on the device tensor for every position of every layer.
    token_ids = batch_array.tolist()

    with torch.inference_mode():
        # Targets are passed only so that logits are returned for every
        # position (without targets the model projects just the last one);
        # the returned loss is not used.
        logits, loss, aux_loss = model(batch_tokens, targets=batch_tokens)

    # Entropy of the output distribution at each position, computed in float32
    # (the model runs in bfloat16) and moved to the host as Python floats so the
    # per-token running sums accumulate in double precision.
    output_probs = F.softmax(logits[0].float(), dim=-1)
    output_entropy = (-(output_probs * torch.log(output_probs + epsilon)).sum(dim=-1)).cpu().tolist()

    for layer_idx in range(num_layers):
        layer_name = f'layer_{layer_idx}'
        # One device-to-host transfer per layer: a list (per position) of expert-id lists.
        layer_assignments = aux_loss['expert_assignments'][layer_name][0].tolist()
        token_stats = token_stats_per_layer[layer_name]
        layer_combinations = token_combinations_per_layer[layer_name]

        for pos, token_id in enumerate(token_ids):
            expert_ids = layer_assignments[pos]
            stats = token_stats[token_id]

            # Update individual expert statistics
            stats['total_occurrences'] += 1
            stats['total_entropy'] += output_entropy[pos]

            for expert_id in expert_ids:
                stats['expert_counts'][expert_id] += 1
                stats['expert_size_sum'] += expert_sizes[expert_id]

            # Track expert combinations (sorted so the order of the ids does not matter)
            expert_combination = tuple(sorted(expert_ids))
            layer_combinations[token_id][expert_combination] += 1

print(f"\n✓ Collected all statistics in a single pass!")

## Analysis: Print the Most Common Expert Combination for Each Token

This shows, for each layer, which expert pair each token is routed to most frequently.

In [None]:
def format_expert_combo(combo, expert_sizes):
    """Render each expert id in `combo` as ``id(size)`` and join them with commas."""
    labelled = []
    for expert in combo:
        labelled.append(f"{expert}({expert_sizes[expert]})")
    return ",".join(labelled)

def print_token_routing_table(token_combinations_per_layer, token_stats_per_layer, expert_sizes, tokenizer, max_tokens=100, hidden_size=None):
    """Print one row per token with its average expert size, FLOPs and routing.

    Args:
        token_combinations_per_layer: ``{'layer_<i>': {token_id: Counter}}``
            where each Counter counts expert-id tuples. One column is printed
            per entry of this mapping.
        token_stats_per_layer: ``{'layer_<i>': {token_id: stats}}``; the keys
            ``total_occurrences``, ``expert_size_sum`` and ``expert_counts``
            of each stats dict are read.
        expert_sizes: Indexable by expert id, giving that expert's size.
        tokenizer: Object whose ``decode([token_id])`` returns the token text.
        max_tokens: Number of rows to print, taken in ascending token-id order.
        hidden_size: Model width used in the FLOPs estimate. Defaults to the
            notebook-level ``config.n_embd``.

    FLOPs per token are estimated as ``4 * hidden_size * W`` where ``W`` is the
    sum over layers of the average total size of the experts the token was
    routed to (all experts chosen per occurrence, not the per-expert average).
    This is the same ``4 * hidden * size`` cost model used for the dataframe
    built later in the notebook.
    """
    if hidden_size is None:
        hidden_size = config.n_embd

    layer_names = [f'layer_{layer_idx}' for layer_idx in range(len(token_combinations_per_layer))]

    # Per token: the per-layer average expert size, and the routed size summed over layers.
    token_avg_sizes = {}
    token_routed_size = {}
    for layer_stats in token_stats_per_layer.values():
        for token_id, stats in layer_stats.items():
            occurrences = stats['total_occurrences']
            if occurrences <= 0:
                continue
            size_sum = stats['expert_size_sum']
            # Divide by the real number of expert assignments rather than
            # assuming exactly two experts per token.
            token_avg_sizes.setdefault(token_id, []).append(size_sum / stats['expert_counts'].sum())
            token_routed_size[token_id] = token_routed_size.get(token_id, 0.0) + size_sum / occurrences

    overall_avg_sizes = {tid: np.mean(sizes) for tid, sizes in token_avg_sizes.items()}

    # Header
    print(f"\n{'='*200}")
    header = f"{'Token ID':<10}{'Token':<20}{'Avg Size':<13}{'FLOPs':<16}"
    for layer_idx in range(len(layer_names)):
        header += f"{'Layer ' + str(layer_idx):<21}"
    print(header)
    print(f"{'='*200}")

    # Print first N tokens
    for token_id in sorted(overall_avg_sizes)[:max_tokens]:
        try:
            token_str = tokenizer.decode([token_id]).replace('\n', '\\n')
        except Exception:
            # Best effort: fall back to the numeric id when decoding fails.
            token_str = f"<{token_id}>"

        avg_size = overall_avg_sizes[token_id]
        flops = 4 * hidden_size * token_routed_size[token_id]

        row = f"{token_id:<10}{token_str:<20}{avg_size:<13.1f}{flops:<16,.0f}"

        for layer_name in layer_names:
            # .get() avoids inserting entries into a defaultdict; the truth
            # test also skips a Counter that exists but is empty.
            combos = token_combinations_per_layer[layer_name].get(token_id)
            if combos:
                most_common = combos.most_common(1)[0][0]
                combo_str = f"({format_expert_combo(most_common, expert_sizes)})"
                row += f"{combo_str:<21}"
            else:
                row += f"{'N/A':<21}"

        print(row)

# Print the routing table for the first 350 token ids (ascending id order).
print_token_routing_table(
    token_combinations_per_layer,
    token_stats_per_layer,
    expert_sizes,
    tokenizer,
    max_tokens=350,
)

In [None]:
import os

import pandas as pd

print("Building dataframe from collected statistics...")

# Get unique tokens
all_token_ids = set()
for layer_combos in token_combinations_per_layer.values():
    all_token_ids.update(layer_combos)

# Build data for dataframe: one row per token, with its most common expert
# combination in every layer.
data = []
for token_id in all_token_ids:
    # Decode token into a short printable label
    try:
        token_text = tokenizer.decode([token_id])
    except Exception:
        # Best effort: fall back to the numeric id when decoding fails.
        token_text = f"<{token_id}>"
    else:
        token_text = token_text.replace('\n', '\\n').replace('\t', '\\t').replace('\r', '\\r')
        if '�' in token_text or not token_text.isprintable():
            token_text = f"<{token_id}>"
        if len(token_text) > 18:
            token_text = token_text[:17] + '…'

    # Sum, over the layers where the token was seen, the sizes of the experts
    # in its most common combination.
    total_size = 0
    layer_count = 0
    layer_data = {}

    for layer_idx in range(num_layers):
        layer_name = f'layer_{layer_idx}'
        # .get() instead of indexing: indexing the defaultdict would insert an
        # empty Counter for every token missing from this layer.
        combos = token_combinations_per_layer[layer_name].get(token_id)
        if combos:
            most_common = combos.most_common(1)[0][0]
            layer_size = sum(expert_sizes[e] for e in most_common)
            total_size += layer_size
            layer_count += 1
            # Format with expert sizes: (5(128),7(128))
            formatted = "(" + ",".join([f"{e}({expert_sizes[e]})" for e in most_common]) + ")"
            layer_data[layer_name] = formatted
        else:
            layer_data[layer_name] = 'N/A'

    avg_size = total_size / layer_count if layer_count > 0 else 0
    flops = 4 * config.n_embd * total_size

    row = {
        'token_id': token_id,
        'token': token_text,
        'avg_size': avg_size,
        'flops': flops,
        **layer_data
    }
    data.append(row)

# Create DataFrame and sort by FLOPs
df = pd.DataFrame(data)
df = df.sort_values('flops', ascending=True).reset_index(drop=True)

print(f"\nDataFrame created with {len(df)} tokens, sorted by FLOPs (low to high)")
print(f"\nFirst 20 rows (lowest FLOPs):")
print(df.head(20).to_string())

print(f"\n\nLast 20 rows (highest FLOPs):")
print(df.tail(20).to_string())

# Summary statistics
print(f"\n{'='*80}")
print(f"Summary Statistics:")
print(f"{'='*80}")
print(f"Total unique tokens: {len(df)}")
print(f"\nFLOPs distribution:")
print(f"  Min:    {df['flops'].min():,.0f}")
print(f"  25%:    {df['flops'].quantile(0.25):,.0f}")
print(f"  Median: {df['flops'].median():,.0f}")
print(f"  75%:    {df['flops'].quantile(0.75):,.0f}")
print(f"  Max:    {df['flops'].max():,.0f}")
print(f"  Mean:   {df['flops'].mean():,.0f}")

# Reference size 4 * n_embd: 2560 for the wikitext config (the value that used
# to be hard-coded here) and 3072 for openwebtext. In both configs it equals
# one large plus one small expert.
dense_size = 4 * config.n_embd
at_or_above = (df['avg_size'] >= dense_size).sum()
below = (df['avg_size'] < dense_size).sum()

print(f"\nAverage Size distribution:")
print(f"  Tokens with avg_size >= {dense_size}: {at_or_above} ({100*at_or_above/len(df):.2f}%)")
print(f"  Tokens with avg_size < {dense_size}:  {below} ({100*below/len(df):.2f}%)")

baseline_flops = num_layers * 4 * config.n_embd * dense_size  # num_layers * 4 * hidden_size * expert_size
print(f"\nAverage FLOPs per token: {df['flops'].mean():.0f} ({100*df['flops'].mean()/baseline_flops:.2f}% of baseline)")

# Store the dataframe for further analysis
expert_combinations_df = df
# Name of the checkpoint directory (the path component just before 'ckpt.pt').
sweep_value = checkpoint_path.split('/')[-2]

# Create the output directory first so to_csv does not fail on a fresh checkout.
os.makedirs('analysis_csvs', exist_ok=True)
df.to_csv(f'analysis_csvs/{sweep_value}_expert_combinations.csv', index=False)
print(f"\n✓ Saved to analysis_csvs/{sweep_value}_expert_combinations.csv")

In [None]:
# Analyze per-token statistics FOR EACH LAYER

# Small / large expert sizes of the selected configuration, and the midpoint
# used to split tokens into "mostly large" vs "mostly small" routing. These
# were hard-coded to the openwebtext values (2944 and 1536), which mislabels
# the plots and the split when DATASET is "wikitext".
configured_sizes = [size for _, size in cfg['expert_sizes']]
small_size = min(configured_sizes)
large_size = max(configured_sizes)
size_threshold = (small_size + large_size) / 2

for layer_idx in range(num_layers):
    layer_name = f'layer_{layer_idx}'
    token_stats = token_stats_per_layer[layer_name]

    print(f"\n{'='*80}")
    print(f"LAYER {layer_idx} ANALYSIS")
    print(f"{'='*80}\n")

    # Compute derived metrics for each token
    token_analysis = {}

    for token_id, stats in token_stats.items():
        if stats['total_occurrences'] > 0:
            # Average entropy
            avg_entropy = stats['total_entropy'] / stats['total_occurrences']

            # Expert distribution (normalized)
            expert_distribution = stats['expert_counts'] / stats['expert_counts'].sum()

            # Most common expert
            most_common_expert = np.argmax(stats['expert_counts'])

            # Average expert size (per expert assignment)
            avg_expert_size = stats['expert_size_sum'] / stats['expert_counts'].sum()

            token_analysis[token_id] = {
                'avg_entropy': avg_entropy,
                'occurrences': stats['total_occurrences'],
                'expert_distribution': expert_distribution,
                'most_common_expert': most_common_expert,
                'avg_expert_size': avg_expert_size,
            }

    # Plot distribution of average expert sizes
    all_expert_sizes = np.array([a['avg_expert_size'] for a in token_analysis.values()])
    all_occurrences = np.array([a['occurrences'] for a in token_analysis.values()])

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle(f'Layer {layer_idx} Token Routing Analysis', fontsize=16)

    # Unweighted histogram (by unique tokens)
    ax1 = axes[0]
    ax1.hist(all_expert_sizes, bins=50, alpha=0.7, edgecolor='black')
    ax1.axvline(x=small_size, color='blue', linestyle='--', label=f'Small ({small_size})', linewidth=2)
    ax1.axvline(x=large_size, color='red', linestyle='--', label=f'Large ({large_size})', linewidth=2)
    ax1.axvline(x=np.mean(all_expert_sizes), color='green', linestyle='--', label=f'Mean ({np.mean(all_expert_sizes):.0f})', linewidth=2)
    ax1.set_xlabel('Average Expert Size per Token')
    ax1.set_ylabel('Number of Unique Tokens')
    ax1.set_title('Distribution by Unique Tokens (Unweighted)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Weighted histogram (by token occurrences)
    ax2 = axes[1]
    ax2.hist(all_expert_sizes, bins=50, weights=all_occurrences, alpha=0.7, edgecolor='black', color='orange')
    ax2.axvline(x=small_size, color='blue', linestyle='--', label=f'Small ({small_size})', linewidth=2)
    ax2.axvline(x=large_size, color='red', linestyle='--', label=f'Large ({large_size})', linewidth=2)
    weighted_mean = np.average(all_expert_sizes, weights=all_occurrences)
    ax2.axvline(x=weighted_mean, color='green', linestyle='--', label=f'Weighted Mean ({weighted_mean:.0f})', linewidth=2)
    ax2.set_xlabel('Average Expert Size per Token')
    ax2.set_ylabel('Total Token Occurrences')
    ax2.set_title('Distribution by Token Occurrences (Weighted)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Pie chart - weighted breakdown
    ax3 = axes[2]
    routes_large = all_expert_sizes > size_threshold
    large_expert_occurrences = int(all_occurrences[routes_large].sum())
    small_expert_occurrences = int(all_occurrences[~routes_large].sum())
    total_occurrences = large_expert_occurrences + small_expert_occurrences

    ax3.pie([large_expert_occurrences, small_expert_occurrences],
            labels=[f'Large experts\n(>{size_threshold:.0f})', f'Small experts\n(≤{size_threshold:.0f})'],
            autopct='%1.1f%%',
            colors=['red', 'blue'])
    ax3.set_title(f'Token Occurrences by Expert Size\n(Total: {total_occurrences:,} tokens)')

    plt.tight_layout()
    plt.show()

    # Statistics
    print(f"Average Expert Size Statistics (Unweighted):")
    print(f"  Mean: {np.mean(all_expert_sizes):.2f}")
    print(f"  Median: {np.median(all_expert_sizes):.2f}")
    print(f"  Std: {np.std(all_expert_sizes):.2f}")

    print(f"\nAverage Expert Size Statistics (Weighted by occurrences):")
    print(f"  Weighted Mean: {weighted_mean:.2f}")

    # Count how many tokens go to mostly large vs mostly small experts
    large_expert_tokens = int(routes_large.sum())
    small_expert_tokens = int((~routes_large).sum())
    print(f"\nUnique Token routing breakdown:")
    print(f"  Unique tokens routing mostly to LARGE experts: {large_expert_tokens} ({100*large_expert_tokens/len(all_expert_sizes):.1f}%)")
    print(f"  Unique tokens routing mostly to SMALL experts: {small_expert_tokens} ({100*small_expert_tokens/len(all_expert_sizes):.1f}%)")

    print(f"\nWeighted by occurrences:")
    print(f"  Token occurrences routed to LARGE experts: {large_expert_occurrences:,} ({100*large_expert_occurrences/total_occurrences:.1f}%)")
    print(f"  Token occurrences routed to SMALL experts: {small_expert_occurrences:,} ({100*small_expert_occurrences/total_occurrences:.1f}%)")

print("\n" + "="*80)
print("SUMMARY ACROSS ALL LAYERS")
print("="*80)