# Model Comparison: Hymba vs Diff Transformer vs Mamba

This notebook compares three state-of-the-art sequence models:
1. **Hymba** - Hybrid architecture combining Attention + Mamba with sliding-window attention (SWA) and KV sharing
2. **Diff Transformer** - Differential attention mechanism
3. **Mamba** - Pure selective state space model

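All models share the same:
- Dataset (TinyShakespeare)
- Vocabulary size
- Training configuration
- Evaluation metrics

Note: the architecture constants (`D_MODEL`, `N_LAYERS`, `N_HEADS`, `N_KV_HEADS`) are passed explicitly to the Diff Transformer and Mamba configs. Hymba is built by `build_everything`, which only receives `seq_len`, `bs` and `vocab_size`. Compare the printed parameter counts before drawing conclusions about quality.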

In [None]:
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import time
import math
from contextlib import nullcontext

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
# `device` stays a plain string because later cells compare it with "cuda".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## 1. Load Models and Data

In [None]:
# Import all three models
from backbone.hymba_v2 import HymbaV2, ModelCfg as HymbaCfg, TrainCfg, build_everything as build_hymba
from backbone.diff_transformer import DiffTransformer, ModelCfg as DiffCfg, build_everything as build_diff
from backbone.mamba_model import MambaModel, ModelCfg as MambaCfg, build_everything as build_mamba

In [None]:
# Shared configuration
# Context length and batch size used for data loading and training of every model.
SEQ_LEN = 512
BATCH_SIZE = 32
# Vocabulary size handed to build_hymba (which also returns the tokenizer) and
# to the Diff Transformer / Mamba configs.
VOCAB_SIZE = 6000
# Architecture sizes. NOTE(review): these are passed to the Diff Transformer and
# Mamba configs below; build_hymba is only given seq_len, bs and vocab_size.
D_MODEL = 384
N_LAYERS = 12
N_HEADS = 6
N_KV_HEADS = 2  # grouped-query attention: N_HEADS query heads share N_KV_HEADS KV heads

# Training configuration
STEPS = 500
LR = 6e-4
WARMUP = 50

# Fail fast on inconsistent attention settings instead of deep inside model construction.
if D_MODEL % N_HEADS != 0:
    raise ValueError(f"D_MODEL ({D_MODEL}) must be divisible by N_HEADS ({N_HEADS})")
if N_HEADS % N_KV_HEADS != 0:
    raise ValueError(f"N_HEADS ({N_HEADS}) must be divisible by N_KV_HEADS ({N_KV_HEADS})")

## 2. Build Models

In [None]:
# Build Hymba. build_hymba also returns the tokenizer (`tok`) and the train/val
# loaders; all three are reused by the other models' training and by the
# text-generation cell.
# NOTE(review): only seq_len, bs and vocab_size are forwarded here, so D_MODEL,
# N_LAYERS, N_HEADS and N_KV_HEADS do not reach Hymba. Its size comes from
# build_everything's own defaults; confirm the comparison is size-matched.
print("Building Hymba...")
hymba_model, tok, train_dl, val_dl = build_hymba(
    seq_len=SEQ_LEN,
    bs=BATCH_SIZE,
    vocab_size=VOCAB_SIZE,
)
hymba_model = hymba_model.to(device)  # nn.Module.to moves in place and returns self

print(
    "\nHymba parameters: "
    f"{sum(p.numel() for p in hymba_model.parameters() if p.requires_grad):,}"
)
display(hymba_model.layer_table())

In [None]:
# Build the Differential Transformer from the shared architecture constants.
print("Building Diff Transformer...")
diff_cfg = DiffCfg(
    vocab_size=VOCAB_SIZE,
    seq_len=SEQ_LEN,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    n_kv_heads=N_KV_HEADS,
)
diff_model = DiffTransformer(diff_cfg)
diff_model = diff_model.to(device)  # nn.Module.to moves in place and returns self

print(
    "\nDiff Transformer parameters: "
    f"{sum(p.numel() for p in diff_model.parameters() if p.requires_grad):,}"
)

In [None]:
# Build the pure Mamba model. It takes no attention-head settings; state_size=16
# is configured only for this model.
print("Building Mamba...")
mamba_cfg = MambaCfg(
    vocab_size=VOCAB_SIZE,
    seq_len=SEQ_LEN,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    state_size=16,
)
mamba_model = MambaModel(mamba_cfg)
mamba_model = mamba_model.to(device)  # nn.Module.to moves in place and returns self

print(
    "\nMamba parameters: "
    f"{sum(p.numel() for p in mamba_model.parameters() if p.requires_grad):,}"
)

## 3. Training Function

In [None]:
def train_model(model, train_dl, val_dl, model_name, steps=STEPS):
    """Train ``model`` with the shared Hymba ``train_loop`` and return its stats.

    Every model is trained with the same ``TrainCfg`` (sequence length, batch
    size, learning-rate schedule, gradient clipping) so runs are comparable.

    Args:
        model: the module to train (moved to ``device`` beforehand).
        train_dl: training data loader.
        val_dl: validation data loader.
        model_name: label stored under the ``"model"`` key of the result.
        steps: number of optimisation steps (defaults to ``STEPS``).

    Returns:
        The dict returned by ``train_loop``, extended with ``model``, ``time_s``
        and ``params``. The results cell reads ``train_loss``, ``val_loss``,
        ``ppl`` and ``tps`` from it; those are assumed to be produced by
        ``train_loop`` (TODO confirm).
    """
    from backbone.hymba_v2 import train_loop

    tcfg = TrainCfg(
        seq_len=SEQ_LEN,
        batch_size=BATCH_SIZE,
        steps=steps,
        lr=LR,
        warmup=WARMUP,
        # Mixed precision only on CUDA, matching the guard used in evaluate_ppl.
        amp=(device == "cuda"),
        grad_clip=1.0,
    )

    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Training {model_name}")
    print(banner)

    # perf_counter is monotonic and high-resolution; time.time() can jump.
    start_time = time.perf_counter()
    stats = train_loop(model, train_dl, val_dl, tcfg, device=device)
    if device == "cuda":
        # Make sure all queued GPU work is included in the measured time.
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time

    stats["model"] = model_name
    stats["time_s"] = elapsed
    stats["params"] = sum(p.numel() for p in model.parameters() if p.requires_grad)

    return stats

## 4. Evaluation Functions

In [None]:
@torch.no_grad()
def evaluate_ppl(model, val_dl, amp=True):
    """Return the token-weighted perplexity of ``model`` over ``val_dl``.

    Each batch's loss (``model(xb, targets=yb)["loss"]``) is weighted by the
    number of target tokens in that batch, so a short final batch does not skew
    the average. Perplexity is ``exp`` of the weighted mean loss. The loss is
    assumed to be a per-token mean cross-entropy (TODO confirm).

    Args:
        model: module whose forward accepts ``targets=`` and returns a dict
            containing a scalar ``"loss"`` tensor.
        val_dl: iterable of ``(inputs, targets)`` tensor pairs.
        amp: use CUDA autocast when the model lives on a CUDA device.

    Returns:
        The perplexity as a float, or ``nan`` if ``val_dl`` yields no tokens.
        Previously an empty loader silently reported a perfect 1.0.
    """
    model.eval()
    # Evaluate on the model's own device (as bench_generate does) rather than a global.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")
    ctx = torch.amp.autocast("cuda") if (amp and dev.type == "cuda") else nullcontext()

    total_nll = 0.0
    n_tokens = 0  # named so it does not shadow the notebook's tokenizer `tok`
    with ctx:
        for xb, yb in val_dl:
            xb, yb = xb.to(dev), yb.to(dev)
            out = model(xb, targets=yb)
            batch_tokens = yb.numel()  # the loss is taken over targets
            total_nll += out["loss"].item() * batch_tokens
            n_tokens += batch_tokens

    if n_tokens == 0:
        return float("nan")
    return math.exp(total_nll / n_tokens)

@torch.no_grad()
def bench_generate(model, prompt_len=512, gen_len=256, warmup=1, repeat=2,
                   *, count_prompt_tokens=False):
    """Benchmark ``model.generate`` latency, throughput and peak CUDA memory.

    A random prompt of ``prompt_len`` token ids is drawn with seed 0, so every
    model gets an identical prompt. ``model.generate`` is called ``warmup``
    times with ``max_new_tokens=16`` (untimed), then ``repeat`` times with
    ``max_new_tokens=gen_len`` (timed).

    Args:
        model: module exposing ``parameters()`` and
            ``generate(prompt, max_new_tokens=...)``.
        prompt_len: number of prompt tokens.
        gen_len: number of new tokens requested per timed run. Throughput
            assumes ``generate`` produces exactly this many (TODO confirm no
            early stopping).
        warmup: untimed calls made before measuring.
        repeat: number of timed calls to average; must be >= 1.
        count_prompt_tokens: if True, throughput counts prompt + generated
            tokens (the original metric). By default only newly generated tokens
            count, since the prompt is processed in a single prefill.

    Returns:
        dict with ``gen_latency_s`` (mean seconds per timed call, 3 d.p.),
        ``gen_tps`` (tokens per second, int) and ``gen_peak_mb`` (peak CUDA
        memory allocated since the reset, in MiB; 0.0 on non-CUDA devices).

    Raises:
        ValueError: if ``repeat`` < 1.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    model.eval()
    device = next(model.parameters()).device
    is_cuda = device.type == "cuda"

    # Config-built models carry their vocab size; otherwise use the shared constant.
    vocab = model.cfg.vocab_size if hasattr(model, "cfg") else VOCAB_SIZE

    torch.manual_seed(0)
    prompt = torch.randint(0, vocab, (1, prompt_len), device=device)

    for _ in range(warmup):
        model.generate(prompt, max_new_tokens=16)

    if is_cuda:
        torch.cuda.reset_peak_memory_stats(device)

    times = []
    for _ in range(repeat):
        if is_cuda:
            torch.cuda.synchronize(device)  # don't time work queued before t0
        t0 = time.perf_counter()
        model.generate(prompt, max_new_tokens=gen_len)
        if is_cuda:
            torch.cuda.synchronize(device)  # wait for generation to actually finish
        times.append(time.perf_counter() - t0)

    sec = sum(times) / len(times)
    counted_tokens = prompt_len + gen_len if count_prompt_tokens else gen_len
    tps = int(counted_tokens / sec) if sec > 0 else 0
    mem = torch.cuda.max_memory_allocated(device) / (1024 ** 2) if is_cuda else 0.0

    return {
        "gen_latency_s": round(sec, 3),
        "gen_tps": tps,
        "gen_peak_mb": round(mem, 2),
    }

## 5. Train All Models

In [None]:
results = []

# Hymba is trained first; its stats dict becomes the first row of the results table.
hymba_stats = train_model(hymba_model, train_dl, val_dl, model_name="Hymba")
results += [hymba_stats]

In [None]:
# Train the Diff Transformer on the same loaders and append its stats row.
diff_stats = train_model(diff_model, train_dl, val_dl, model_name="Diff Transformer")
results += [diff_stats]

In [None]:
# Train Mamba on the same loaders and append its stats row.
mamba_stats = train_model(mamba_model, train_dl, val_dl, model_name="Mamba")
results += [mamba_stats]

## 6. Benchmark Generation

In [None]:
gen_results = []

# Benchmark each trained model with the same prompt/generation lengths and tag
# every result row with the model's display name.
for name, model in (
    ("Hymba", hymba_model),
    ("Diff Transformer", diff_model),
    ("Mamba", mamba_model),
):
    print(f"\nBenchmarking {name} generation...")
    bench = {**bench_generate(model, prompt_len=512, gen_len=256), "model": name}
    gen_results.append(bench)
    print(f"{name}: {bench['gen_tps']} tokens/s, {bench['gen_latency_s']}s latency")

## 7. Results Summary

In [None]:
# Training results: one row per model, restricted to the compared metrics in display order.
df_train = pd.DataFrame(results)
df_train = df_train.loc[:, ["model", "train_loss", "val_loss", "ppl", "tps", "time_s", "params"]]
display(df_train)

In [None]:
# Generation results: one row per model with throughput, latency and peak memory.
df_gen = pd.DataFrame(gen_results)
df_gen = df_gen.loc[:, ["model", "gen_tps", "gen_latency_s", "gen_peak_mb"]]
display(df_gen)

## 8. Visualization

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# One bar chart per panel: ((row, col), source table, metric column, title, y label).
# Top row: training metrics; bottom row: generation metrics.
for (row, col), frame, metric, title, ylabel in (
    ((0, 0), df_train, "val_loss", "Validation Loss (Lower is Better)", "Loss"),
    ((0, 1), df_train, "ppl", "Perplexity (Lower is Better)", "PPL"),
    ((0, 2), df_train, "tps", "Training Throughput (Higher is Better)", "Tokens/s"),
    ((1, 0), df_gen, "gen_tps", "Generation Speed (Higher is Better)", "Tokens/s"),
    ((1, 1), df_gen, "gen_latency_s", "Generation Latency (Lower is Better)", "Seconds"),
    ((1, 2), df_gen, "gen_peak_mb", "Peak Memory Usage (Lower is Better)", "MB"),
):
    ax = axes[row, col]
    ax.bar(frame["model"], frame[metric])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.tick_params(axis="x", rotation=15)  # keep the model names readable

plt.tight_layout()
plt.show()

## 9. Text Generation Comparison

In [None]:
# Generate sample text from each model for a qualitative side-by-side comparison.
prompt = "To be or not to be"
# dtype=torch.long keeps the ids integral even if the encoder returns an empty list.
prompt_ids = torch.tensor([tok.encode(prompt)], dtype=torch.long, device=device)

print("Prompt:", prompt)
print("=" * 60)

for model, name in [(hymba_model, "Hymba"), (diff_model, "Diff Transformer"), (mamba_model, "Mamba")]:
    print(f"\n{name}:")
    print("-" * 60)
    # Don't depend on cell execution order: make sure the model is in eval mode.
    model.eval()
    # Re-seed per model so temperature sampling is reproducible across reruns
    # and every model starts from the same RNG state.
    torch.manual_seed(0)
    with torch.no_grad():  # inference only; avoid building an autograd graph
        generated = model.generate(prompt_ids, max_new_tokens=100, temperature=0.8)
    text = tok.decode(generated[0].tolist())
    # Only signal truncation when the text was actually cut.
    print(text[:200] + (" ..." if len(text) > 200 else ""))

## 10. Key Findings

This comparison evaluates:

1. **Training Efficiency**: Tokens/second during training
2. **Model Quality**: Validation loss and perplexity
3. **Inference Speed**: Generation throughput and latency
4. **Memory Efficiency**: Peak CUDA memory allocated during generation (reported as 0 when running on CPU)

### Hypotheses (to be checked against the measured tables above):

- **Hymba**: Best balance of quality and efficiency with hybrid architecture
- **Diff Transformer**: Strong quality from differential attention but higher compute
- **Mamba**: Fast inference with linear complexity but may need more tuning

### Architecture Highlights:

- **Hymba**: Combines attention (global context) + Mamba (efficiency) with SWA and KV sharing
- **Diff Transformer**: Differential attention reduces noise by subtracting two attention patterns
- **Mamba**: Selective SSM with input-dependent parameters for efficient long-range modeling