In [None]:
import os, sys, tomllib, torch, torch.nn.functional as F
from pathlib import Path
from dataclasses import asdict
from collections import defaultdict
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from nmoe.config import Config
from nmoe.model import Transformer
from nmoe.data.loader import build_loader
from nmoe.opt import build_optimizer, update_lr, step
from nmoe.checkpoint import Checkpointer, load_checkpoint
from nmoe import runtime

# Use matplotlib's dark theme for every figure created in this notebook.
plt.style.use('dark_background')
# 'high' lets float32 matmuls use faster reduced-precision kernels (e.g. TF32)
# where the hardware supports them — see the PyTorch docs for the exact trade-off.
torch.set_float32_matmul_precision('high')

---
## toolkit

In [None]:
# ============================================================================
# MEMORY
# ============================================================================

def mem_stats():
    """Snapshot the CUDA allocator counters for the current device.

    Returns:
        dict with three floats, all in units of 1e9 bytes:
        ``allocated_gb`` (``torch.cuda.memory_allocated``),
        ``reserved_gb`` (``torch.cuda.memory_reserved``) and
        ``peak_gb`` (``torch.cuda.max_memory_allocated``).
    """
    counters = (
        ('allocated_gb', torch.cuda.memory_allocated),
        ('reserved_gb', torch.cuda.memory_reserved),
        ('peak_gb', torch.cuda.max_memory_allocated),
    )
    return {label: read_bytes() / 1e9 for label, read_bytes in counters}

def param_memory(model):
    """Size of the model's parameters, in units of 1e9 bytes.

    Only what ``model.parameters()`` yields is counted; activations are not
    part of this figure.
    """
    n_bytes = 0
    for param in model.parameters():
        n_bytes += param.numel() * param.element_size()
    return n_bytes / 1e9

def optimizer_memory(optimizer):
    """Size of the tensors held in ``optimizer.state``, in units of 1e9 bytes.

    Walks every per-parameter state dict and counts only values that are
    tensors (e.g. moment buffers); non-tensor entries are ignored.
    """
    state_tensors = (
        value
        for per_param_state in optimizer.state.values()
        for value in per_param_state.values()
        if torch.is_tensor(value)
    )
    return sum(t.numel() * t.element_size() for t in state_tensors) / 1e9

def activation_memory(model, x):
    """Estimate the extra memory a forward pass needs, in units of 1e9 bytes.

    Resets the CUDA peak counter, runs ``model(x)`` once and reports how far
    the peak rose above the memory that was already allocated beforehand.

    The forward pass runs under ``torch.no_grad()``, so tensors that autograd
    would keep alive for a backward pass are not part of the estimate.
    Side effect: the global peak-memory counter is reset.
    """
    torch.cuda.reset_peak_memory_stats()
    bytes_before = torch.cuda.memory_allocated()
    with torch.no_grad():
        model(x)
    torch.cuda.synchronize()
    peak_bytes = torch.cuda.max_memory_allocated()
    return (peak_bytes - bytes_before) / 1e9

def memory_breakdown(model, optimizer, x):
    """Attribute GPU memory to parameters, optimizer state and activations.

    Args:
        model: module whose parameters are measured and which is run on ``x``.
        optimizer: optimizer whose state is measured; a falsy value (e.g.
            ``None``) reports 0 for the optimizer.
        x: input passed to ``model`` for the activation estimate.

    Returns:
        dict with ``params_gb``, ``optimizer_gb``, ``activations_gb`` followed
        by the keys of ``mem_stats()``.

    Note: the allocator snapshot is taken after ``activation_memory`` has
    reset the peak counter, so ``peak_gb`` here is the peak since that reset,
    not since the start of the process.
    """
    breakdown = {'params_gb': param_memory(model)}
    breakdown['optimizer_gb'] = optimizer_memory(optimizer) if optimizer else 0
    breakdown['activations_gb'] = activation_memory(model, x)
    # Must come last: activation_memory() changes what the allocator reports.
    breakdown.update(mem_stats())
    return breakdown

In [None]:
# ============================================================================
# TIMING
# ============================================================================

def time_forward_backward(model, x, y, vocab_size, warmup=3, iters=10):
    """Measure the average time of one forward + backward pass with CUDA events.

    Args:
        model: callable returning logits that can be viewed as (-1, vocab_size).
        x: model input; its element count is used as the token count.
        y: targets, flattened and fed to cross-entropy.
        vocab_size: size of the last logits dimension.
        warmup: untimed iterations run first.
        iters: timed iterations; the elapsed time is divided by this.

    Returns:
        dict with ``ms`` (mean milliseconds per iteration) and
        ``tokens_per_sec`` (``x.numel()`` divided by the mean time in seconds).
        Gradients are cleared after every iteration, so none are left behind.
    """
    def fwd_bwd():
        out = model(x)
        ce = F.cross_entropy(out.view(-1, vocab_size), y.view(-1))
        ce.backward()
        model.zero_grad(set_to_none=True)
        return out, ce

    # `held` keeps the previous step's tensors referenced until the next step
    # has finished, exactly as a plain loop with local variables would.
    held = None
    for _ in range(warmup):
        held = fwd_bwd()

    # Make sure warmup work has drained before the timed region starts.
    torch.cuda.synchronize()
    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)

    tic.record()
    for _ in range(iters):
        held = fwd_bwd()
    toc.record()
    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()

    ms = tic.elapsed_time(toc) / iters
    n_tokens = x.numel()
    return {'ms': ms, 'tokens_per_sec': n_tokens / (ms / 1000)}

def profile_kernels(model, x, y, vocab_size):
    """Profile one forward + backward pass and tabulate the costliest CUDA kernels.

    Gradients are cleared first, then a single step is recorded with only the
    CUDA activity enabled. Gradients from the profiled step are left in place.

    Returns:
        str: the profiler's averaged table, sorted by total CUDA time and
        limited to 20 rows.
    """
    model.zero_grad(set_to_none=True)
    cuda_only = [torch.profiler.ProfilerActivity.CUDA]
    with torch.profiler.profile(activities=cuda_only) as prof:
        flat_logits = model(x).view(-1, vocab_size)
        F.cross_entropy(flat_logits, y.view(-1)).backward()
        # Wait inside the profiling context so queued kernels are captured.
        torch.cuda.synchronize()
    averages = prof.key_averages()
    return averages.table(sort_by="cuda_time_total", row_limit=20)

In [None]:
# ============================================================================
# MOE ROUTING
# ============================================================================

class RouterProbe:
    """Capture MoE routing statistics during forward passes.

    A forward hook is registered on ``layer.moe`` for every entry of
    ``model.layers`` that has a ``moe`` attribute. The hook does not read the
    module's own routing decision; it recomputes one from the module input as
    ``softmax(x @ gate.weight.T)`` followed by a top-``n_activated`` selection.
    NOTE(review): this assumes the real router is a plain softmax over a
    bias-free linear gate — confirm against the model's gate implementation.

    Usage::

        with RouterProbe(model) as probe:   # hooks are removed on exit,
            model(x)                        # even if the forward pass raises
        probe.summary()

    ``attach()`` / ``detach()`` remain available for manual control.

    Attributes:
        stats: ``defaultdict(list)`` mapping layer index -> list of per-forward
            records with keys ``counts`` (CPU tensor, assignments per expert),
            ``entropy`` and ``max_prob`` (floats, averaged over tokens).
    """

    _SUMMARY_COLUMNS = ['layer', 'entropy', 'max_prob', 'load_std', 'dead_experts']

    def __init__(self, model):
        self.model = model
        self.stats = defaultdict(list)
        self._hooks = []

    def __enter__(self):
        return self.attach()

    def __exit__(self, exc_type, exc, tb):
        self.detach()
        return False  # never swallow exceptions from the with-body

    def _hook(self, layer_idx):
        """Build the forward hook that appends one record to ``stats[layer_idx]``."""
        def fn(module, input, output):
            # Modules without a weight-bearing gate are silently skipped.
            if hasattr(module, 'gate') and hasattr(module.gate, 'weight'):
                x = input[0] if isinstance(input, tuple) else input
                weight = module.gate.weight
                with torch.no_grad():
                    logits = x.to(weight.dtype) @ weight.T
                    # Always softmax in float32 so entropy stays meaningful
                    # when activations are half precision.
                    scores = F.softmax(logits, dim=-1, dtype=torch.float32)
                    topk = scores.topk(module.n_activated, dim=-1)
                    expert_counts = torch.bincount(
                        topk.indices.reshape(-1), minlength=module.n_experts
                    )
                    self.stats[layer_idx].append({
                        'counts': expert_counts.cpu(),
                        # entr(p) = -p*ln(p), defined as 0 at p == 0.
                        'entropy': torch.special.entr(scores).sum(-1).mean().item(),
                        'max_prob': scores.max(-1).values.mean().item(),
                    })
        return fn

    def attach(self):
        """Register the hooks and return ``self``.

        Safe to call repeatedly: hooks from a previous call are removed first,
        so a forward pass is never recorded twice.
        """
        self.detach()
        for i, layer in enumerate(self.model.layers):
            if hasattr(layer, 'moe'):
                self._hooks.append(layer.moe.register_forward_hook(self._hook(i)))
        return self

    def detach(self):
        """Remove every registered hook; recorded stats are kept."""
        for h in self._hooks:
            h.remove()
        self._hooks = []

    def _mean_counts(self, layer_idx):
        """Per-expert assignment counts for one layer, averaged over records.

        Uses ``.get`` so that asking about an unknown layer does not insert an
        empty entry into the ``defaultdict``.

        Raises:
            KeyError: if nothing has been recorded for ``layer_idx``.
        """
        records = self.stats.get(layer_idx)
        if not records:
            raise KeyError(f'no routing stats recorded for layer {layer_idx!r}')
        return torch.stack([r['counts'] for r in records]).float().mean(0)

    def summary(self):
        """One row per recorded layer.

        Columns: ``layer``, mean ``entropy``, mean ``max_prob``, ``load_std``
        (std of the mean per-expert counts) and ``dead_experts`` (experts whose
        mean count is zero). The columns are present even when nothing has been
        recorded yet.
        """
        rows = []
        for layer_idx, records in self.stats.items():
            if not records:
                continue
            counts = self._mean_counts(layer_idx)
            rows.append({
                'layer': layer_idx,
                'entropy': np.mean([r['entropy'] for r in records]),
                'max_prob': np.mean([r['max_prob'] for r in records]),
                'load_std': counts.std().item(),
                'dead_experts': (counts == 0).sum().item(),
            })
        return pd.DataFrame(rows, columns=self._SUMMARY_COLUMNS)

    def plot_load(self, layers=None):
        """Bar-plot the mean per-expert load, one subplot per layer.

        Args:
            layers: layer indices to plot; defaults to the first four layers
                that have recorded stats.

        Returns:
            The matplotlib figure.

        Raises:
            ValueError: if there is nothing to plot.
            KeyError: if a requested layer has no recorded stats.
        """
        recorded = [idx for idx, records in self.stats.items() if records]
        layers = list(layers) if layers else recorded[:4]
        if not layers:
            raise ValueError(
                'no routing stats recorded; run a forward pass with the probe attached'
            )
        # Resolve all data before creating the figure so a bad layer index
        # fails fast and leaves neither a stray figure nor a stray stats entry.
        per_layer_counts = [self._mean_counts(idx) for idx in layers]

        fig, axes = plt.subplots(
            1, len(layers), figsize=(4*len(layers), 3), squeeze=False
        )
        for ax, layer_idx, counts in zip(axes[0], layers, per_layer_counts):
            ax.bar(range(len(counts)), counts.numpy())
            ax.set_title(f'Layer {layer_idx}')
            ax.set_xlabel('Expert')
        plt.tight_layout()
        return fig

In [None]:
# ============================================================================
# GRADIENT HEALTH
# ============================================================================

def grad_norms(model, per_layer=False):
    """L2 norm of each parameter's gradient.

    Parameters whose ``.grad`` is ``None`` are skipped.

    Args:
        model: module whose ``named_parameters()`` are inspected.
        per_layer: when true, return the full per-parameter ``pd.Series``
            indexed by parameter name.

    Returns:
        Either that Series, or a dict summarising it with the keys
        ``max``, ``min``, ``mean`` and ``std``.
    """
    per_param = pd.Series({
        name: param.grad.norm().item()
        for name, param in model.named_parameters()
        if param.grad is not None
    })
    if not per_layer:
        return {stat: getattr(per_param, stat)() for stat in ('max', 'min', 'mean', 'std')}
    return per_param

def grad_flow(model):
    """Bar chart of gradient norms for every parameter with 'weight' in its name.

    Parameters without a gradient are skipped. Bars appear in
    ``named_parameters()`` order, are labelled with the parameter name minus
    ``'.weight'``, and are drawn on a log-scaled y axis.

    Returns:
        The matplotlib figure.
    """
    weight_grads = [
        (name.replace('.weight', ''), param.grad.norm().item())
        for name, param in model.named_parameters()
        if param.grad is not None and 'weight' in name
    ]
    labels = [label for label, _ in weight_grads]
    heights = [norm for _, norm in weight_grads]
    positions = range(len(heights))

    # Grow the figure with the number of bars, but never narrower than 12 inches.
    fig, ax = plt.subplots(figsize=(max(12, len(heights)//4), 4))
    ax.bar(positions, heights)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=90, fontsize=6)
    ax.set_ylabel('Gradient Norm')
    ax.set_yscale('log')
    plt.tight_layout()
    return fig

def check_nan_inf(model):
    """Report every parameter or gradient that contains NaN or Inf.

    Returns:
        list[str]: one ``'<name>: <NaN|Inf> in <param|grad>'`` entry per
        finding, in parameter order (param checks before grad checks, NaN
        before Inf), or ``['clean']`` when nothing was found.
    """
    param_checks = ((torch.isnan, 'NaN in param'), (torch.isinf, 'Inf in param'))
    grad_checks = ((torch.isnan, 'NaN in grad'), (torch.isinf, 'Inf in grad'))

    findings = []
    for name, param in model.named_parameters():
        todo = [(param, param_checks)]
        if param.grad is not None:
            todo.append((param.grad, grad_checks))
        for tensor, checks in todo:
            for detect, message in checks:
                if detect(tensor).any():
                    findings.append(f'{name}: {message}')
    return findings if findings else ['clean']

In [None]:
# ============================================================================
# CHECKPOINTS
# ============================================================================

def list_checkpoints(base='/data/checkpoints'):
    """List the checkpoint directories directly under ``base``.

    Args:
        base: directory to scan.

    Returns:
        list[str]: sorted names of the immediate sub-directories. Empty when
        ``base`` does not exist or is not a directory (e.g. a regular file).
    """
    root = Path(base)
    # is_dir() rather than exists(): a regular file would make iterdir() raise.
    if not root.is_dir():
        return []
    return sorted(entry.name for entry in root.iterdir() if entry.is_dir())

def peek_checkpoint(path):
    """Inspect a checkpoint directory without loading the model weights.

    Args:
        path: checkpoint directory, expected to hold ``meta.pt`` and/or
            ``model.pt``. Missing files are simply skipped.

    Returns:
        dict with any of:
        ``step``, ``tokens`` (from ``tokens_seen``) and ``config_hash`` (first
        8 characters of ``config_fingerprint``, ``''`` when it is missing or
        ``None``) if ``meta.pt`` exists; ``model_size_gb`` (file size in
        units of 1e9 bytes) if ``model.pt`` exists.
    """
    ckpt_dir = Path(path)
    info = {}

    meta = ckpt_dir / 'meta.pt'
    if meta.exists():
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only peek at checkpoints from a trusted source.
        m = torch.load(meta, map_location='cpu', weights_only=False)
        info['step'] = m.get('step')
        info['tokens'] = m.get('tokens_seen')
        # The key may be present but None (or not a string); slicing it
        # directly would raise TypeError.
        fingerprint = m.get('config_fingerprint') or ''
        info['config_hash'] = str(fingerprint)[:8]

    model_pt = ckpt_dir / 'model.pt'
    if model_pt.exists():
        info['model_size_gb'] = model_pt.stat().st_size / 1e9

    return info

def diff_configs(cfg1, cfg2):
    """Tabulate the top-level fields on which two dataclass configs differ.

    Args:
        cfg1, cfg2: dataclass instances (converted with ``dataclasses.asdict``).

    Returns:
        pd.DataFrame with columns ``key``, ``cfg1``, ``cfg2`` and one row per
        differing field. Rows follow the field order of ``cfg1`` (then any
        fields only ``cfg2`` has), so the output is the same on every run.
        A field missing on one side shows up as ``None``. The three columns
        are present even when the configs are identical.
    """
    d1, d2 = asdict(cfg1), asdict(cfg2)
    # dict.fromkeys de-duplicates while keeping insertion order; a plain set
    # would give a different row order from run to run.
    ordered_keys = dict.fromkeys([*d1, *d2])
    diffs = [
        {'key': key, 'cfg1': d1.get(key), 'cfg2': d2.get(key)}
        for key in ordered_keys
        if d1.get(key) != d2.get(key)
    ]
    return pd.DataFrame(diffs, columns=['key', 'cfg1', 'cfg2'])

In [None]:
# ============================================================================
# QUICK ABLATIONS
# ============================================================================

def quick_train(cfg_overrides, steps=20, log_every=5, config_path='configs/moonlet.toml'):
    """Train a fresh model for a few steps and return the loss curve.

    Args:
        cfg_overrides: dict of config fields layered on top of the TOML file.
            ``steps`` is always forced to the ``steps`` argument.
        steps: number of optimizer steps to run.
        log_every: print the loss every this many steps; 0 or ``None``
            disables logging.
        config_path: TOML file providing the base config.

    Returns:
        list[float]: the loss of every step, in order.

    The runtime is finalized and GPU references are dropped even when a step
    raises, so a failed ablation does not leave the kernel holding GPU memory
    or an initialized runtime.
    """
    with open(config_path, 'rb') as f:
        cfg = Config(**{**tomllib.load(f), **cfg_overrides, 'steps': steps})

    rank, world = runtime.init(cfg.seed)
    model = optimizer = dense_groups = None
    losses = []
    try:
        loader, _ = build_loader(cfg, rank, world)
        model = Transformer(cfg).cuda()
        model.init_weights()
        model.train()
        optimizer, dense_groups = build_optimizer(model, cfg)

        for i in range(steps):
            x, y = loader.next()
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, cfg.vocab_size), y.view(-1))
            loss.backward()
            step(model, optimizer, dense_groups, {}, cfg, world)
            model.zero_grad(set_to_none=True)

            # Read the loss back from the GPU once and reuse the float.
            loss_value = loss.item()
            losses.append(loss_value)
            if log_every and (i + 1) % log_every == 0:
                print(f'[{i+1}/{steps}] loss={loss_value:.4f}')
    finally:
        # Drop every local that can pin GPU memory before emptying the cache;
        # rebinding to None works even for names that were never assigned.
        model = optimizer = dense_groups = None
        x = y = logits = loss = None
        torch.cuda.empty_cache()
        runtime.finalize()

    return losses

def compare_ablations(configs, steps=20):
    """Run ``quick_train`` for each named override set and overlay the loss curves.

    Args:
        configs: mapping of display name -> config overrides for ``quick_train``.
        steps: number of training steps for every run.

    Returns:
        ``(fig, results)`` where ``results`` maps each name to its list of
        losses and ``fig`` plots one line per name.
    """
    curves = {}
    for label, overrides in configs.items():
        print(f'\n=== {label} ===')
        curves[label] = quick_train(overrides, steps=steps)

    fig, ax = plt.subplots(figsize=(8, 4))
    for label in curves:
        ax.plot(curves[label], label=label)
    ax.set(xlabel='Step', ylabel='Loss')
    ax.legend()
    return fig, curves

---
## scratch

In [None]:
# Load the base config from the same TOML file that quick_train() reads.
with open('configs/moonlet.toml', 'rb') as f:
    cfg = Config(**tomllib.load(f))

# Initialise the runtime with the config seed, then build the model and its
# optimizer on the GPU. These globals (cfg, rank, world, model, optimizer,
# dense_groups) stay alive until the Cleanup cell at the bottom is run.
rank, world = runtime.init(cfg.seed)
model = Transformer(cfg).cuda()
model.init_weights()
model.train()
optimizer, dense_groups = build_optimizer(model, cfg)

# Dummy batch: uniformly random token ids of shape (batch_size, seq_len).
# x and y are drawn independently, so this batch is meant for timing, memory
# and shape checks with the toolkit above rather than for judging loss values.
x = torch.randint(0, cfg.vocab_size, (cfg.batch_size, cfg.seq_len), device='cuda')
y = torch.randint(0, cfg.vocab_size, (cfg.batch_size, cfg.seq_len), device='cuda')

In [None]:
# Your experiments here


In [None]:
# Cleanup: release what the setup cell created, in this order —
# 1) drop the Python references so the tensors become unreachable,
# 2) hand the now-unused cached blocks back to the GPU,
# 3) shut the runtime down.
# NOTE(review): running this cell twice raises NameError on the `del`, and
# whether runtime.finalize() tolerates a second call is not visible here —
# re-run the setup cell before running it again.
del model, optimizer, x, y
torch.cuda.empty_cache()
runtime.finalize()