# Hierarchical BDH vs Original BDH Comparison

This notebook compares:
- **Original BDH**: Single-scale byte-level attention
- **Hierarchical BDH**: MEGABYTE-inspired global (patch) + local (byte) architecture

## Architecture Comparison

| Aspect | Original BDH | Hierarchical BDH |
|--------|-------------|------------------|
| Scale | Single (byte) | Multi (patch + byte) |
| Attention | O(T²) on bytes | O((T/P)²) global + O(T·P) local (O(P²) per patch, T/P patches) |
| Context | Direct byte attention | Patch-level + byte-level |
| Semantic grouping | Implicit | Explicit (patches) |

## Hypothesis

Hierarchical BDH should show:
1. **Better long-range coherence** (global model captures cross-patch dependencies)
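2. **Improved perplexity** on natural language (the global model works on coarser patch units; note that patches are fixed-size byte spans, so they only approximate word/semantic boundaries)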
3. **Similar or better efficiency** (reduced attention complexity)

In [None]:
# Setup - Clone repo if in Colab
# If the clone fails for any reason (e.g. ./bdh already exists), stderr is discarded
# and 'Repo exists' is printed instead.
!git clone https://github.com/newsbubbles/bdh.git 2>/dev/null || echo 'Repo exists'
# NOTE(review): `%cd bdh` is a relative path, so re-running this cell from inside
# bdh/ will likely fail — run it once per kernel session.
%cd bdh
!pip install -q torch datasets tqdm matplotlib

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import math
import json
from datetime import datetime
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Import both models
from bdh import BDH, BDHConfig
from bdh_hierarchical import HierarchicalBDH, HierarchicalBDHConfig, create_hierarchical_bdh

# Pick the compute device: the first CUDA GPU when available, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Using device: {device}')

if device == 'cuda':
    # Report which GPU (index 0) is in use and how much memory it has.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {gpu_props.name}')
    print(f'Memory: {gpu_props.total_memory / 1e9:.1f} GB')

## 1. Load WikiText-2

In [None]:
from datasets import load_dataset

# Load WikiText-2 (the 'wikitext-2-raw-v1' config); its train/validation/test
# splits are read in the cells below.
print('Loading WikiText-2...')
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')

def text_to_bytes(text):
    """Encode ``text`` as UTF-8 and return its bytes as a 1-D ``torch.long`` tensor.

    Each element is one byte value in ``[0, 255]``. The encoded bytes are read
    directly with ``torch.frombuffer`` instead of going through a Python list of
    ints, which is slow and memory-hungry for multi-megabyte corpora.

    Args:
        text: String to encode.

    Returns:
        1-D ``torch.long`` tensor of byte values; empty for an empty string.
    """
    # bytearray gives a writable buffer, so torch.frombuffer does not warn about
    # read-only memory.
    raw = bytearray(text.encode('utf-8'))
    if not raw:
        # torch.frombuffer does not accept a zero-length buffer; keep returning an
        # empty long tensor, as the list-based version did.
        return torch.empty(0, dtype=torch.long)
    # .to(torch.long) copies, so the result does not alias `raw`.
    return torch.frombuffer(raw, dtype=torch.uint8).to(torch.long)

# Each split's 'text' column is a sequence of strings: join them with newlines into
# one document per split, then encode each document as a tensor of byte values.
train_text, val_text, test_text = (
    '\n'.join(dataset[split]['text']) for split in ('train', 'validation', 'test')
)
train_data, val_data, test_data = map(text_to_bytes, (train_text, val_text, test_text))

print(f'Train: {len(train_data):,} bytes ({len(train_data)/1e6:.1f}M)')
print(f'Val: {len(val_data):,} bytes')
print(f'Test: {len(test_data):,} bytes')

In [None]:
class ByteDataset(Dataset):
    """Next-byte-prediction windows over a 1-D tensor of byte values.

    Sample ``idx`` is ``(x, y)`` with ``x = data[idx : idx + block_size]`` and
    ``y = data[idx + 1 : idx + block_size + 1]``, so ``y[t]`` is the byte that
    follows ``x[t]``. Every start offset is a sample, so consecutive samples
    overlap by ``block_size - 1`` bytes.

    Args:
        data: 1-D tensor of token ids (byte values in this notebook).
        block_size: Number of input positions per sample.
    """

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Each sample needs block_size + 1 bytes. Data shorter than that yields an
        # empty dataset instead of a negative length, which len() rejects.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            # Out-of-range indices would otherwise slice truncated/empty windows;
            # raising also lets plain iteration over the dataset terminate.
            raise IndexError(f'ByteDataset index {idx} out of range for length {n}')
        chunk = self.data[pos:pos + self.block_size + 1]
        return chunk[:-1], chunk[1:]

## 2. Create Models

We'll create comparable models (similar parameter counts) for fair comparison.

In [None]:
# Training config

# --- data ---
BLOCK_SIZE = 512  # sequence length in bytes; must be divisible by the hierarchical patch_size
BATCH_SIZE = 32

# --- optimization ---
MAX_STEPS = 5_000
WARMUP_STEPS = 200
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.1

# --- validation / early stopping ---
VAL_INTERVAL = 100  # validate every VAL_INTERVAL optimizer steps
VAL_BATCHES = 50    # batches scored per split at each validation
PATIENCE = 10       # consecutive validations without improvement before stopping

In [None]:
# Original (single-scale) BDH baseline over raw bytes (vocab_size=256).
# NOTE(review): intended to be roughly size-matched to the hierarchical model; the
# exact parameter count is printed below — compare the totals before drawing
# conclusions from the results.
original_config = BDHConfig(
    n_layer=6,
    n_head=8,
    n_embd=256,
    vocab_size=256,
    dropout=0.2,
)

original_model = BDH(original_config).to(device)
original_params = sum(p.numel() for p in original_model.parameters())

print('ORIGINAL BDH')
print(f'  Layers: {original_config.n_layer}')
print(f'  Embed dim: {original_config.n_embd}')
print(f'  Heads: {original_config.n_head}')
print(f'  Parameters: {original_params:,} ({original_params/1e6:.1f}M)')

In [None]:
# Hierarchical BDH, "small" preset, with the sequence length matched to BLOCK_SIZE.
# NOTE(review): per-component parameter counts are printed below — check that the
# total is comparable to the original model before comparing results.
hier_config = HierarchicalBDHConfig.small(
    max_seq_len=BLOCK_SIZE,
    dropout=0.2,
)

# Sequences are split into patches of `patch_size` bytes, so the block length must
# be a whole number of patches. Check it here, with a clear message, before the
# model is built.
if BLOCK_SIZE % hier_config.patch_size != 0:
    raise ValueError(
        f'BLOCK_SIZE ({BLOCK_SIZE}) must be divisible by patch_size '
        f'({hier_config.patch_size})'
    )

hier_model = HierarchicalBDH(hier_config).to(device)
hier_params = hier_model.count_parameters()  # component name -> count, including 'total'

print('\nHIERARCHICAL BDH')
print(f'  Patch size: {hier_config.patch_size}')
print(f'  Global: {hier_config.global_n_layer}L x {hier_config.global_n_embd}D x {hier_config.global_n_head}H')
print(f'  Local:  {hier_config.local_n_layer}L x {hier_config.local_n_embd}D x {hier_config.local_n_head}H')
print('  Parameters by component:')
for name, count in hier_params.items():
    print(f'    {name}: {count:,} ({count/1e6:.1f}M)')

In [None]:
# Side-by-side parameter totals; the ratio is hierarchical / original.
hier_total = hier_params["total"]

print('\nPARAMETER COMPARISON')
for _label, _count in (('Original BDH:', original_params), ('Hierarchical BDH:', hier_total)):
    # Pad labels to a common width so the counts line up in one column.
    print(f'  {_label:<18}{_count:>12,} ({_count/1e6:.1f}M)')
print(f'  Ratio: {hier_total/original_params:.2f}x')

## 3. Training Infrastructure

In [None]:
# Create datasets and dataloaders.
#
# Training samples every byte offset (random overlapping crops, shuffled).
# Validation and test windows are stepped by BLOCK_SIZE, so each held-out byte is
# scored once. With stride-1 windows and shuffle=False, the first VAL_BATCHES
# batches would only cover the first ~BATCH_SIZE * VAL_BATCHES bytes of the split,
# and a full test pass would score every byte up to BLOCK_SIZE times.
from torch.utils.data import Subset


def _non_overlapping(ds, stride):
    """Return a view of ``ds`` keeping only every ``stride``-th window."""
    return Subset(ds, range(0, len(ds), stride))


train_dataset = ByteDataset(train_data, BLOCK_SIZE)
val_dataset = _non_overlapping(ByteDataset(val_data, BLOCK_SIZE), BLOCK_SIZE)
test_dataset = _non_overlapping(ByteDataset(test_data, BLOCK_SIZE), BLOCK_SIZE)

# Pinned host memory only helps host-to-GPU copies; skip it on CPU.
pin = device == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=pin)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=pin)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=pin)

print(f'Train batches: {len(train_loader):,}')
print(f'Val batches: {len(val_loader):,}')
print(f'Test batches: {len(test_loader):,}')

In [None]:
def get_lr(step, warmup_steps=WARMUP_STEPS, max_steps=MAX_STEPS, lr=LEARNING_RATE):
    """Return the learning rate for ``step`` under linear warmup + cosine decay.

    The rate rises linearly from 0 to ``lr`` over the first ``warmup_steps`` steps.
    It then follows a half cosine from ``lr`` down to 0 at ``max_steps``. Steps at
    or beyond ``max_steps`` return 0 instead of the cosine climbing back up.

    Args:
        step: Zero-based optimizer step.
        warmup_steps: Length of the linear warmup.
        max_steps: Step at which the cosine reaches 0.
        lr: Peak learning rate.

    Returns:
        The learning rate as a float.
    """
    if step < warmup_steps:
        return lr * step / warmup_steps
    decay_steps = max_steps - warmup_steps
    if decay_steps <= 0:
        # No decay phase: step >= warmup_steps >= max_steps, so the schedule is over.
        # (The original divided by zero here when max_steps == warmup_steps.)
        return 0.0
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return lr * 0.5 * (1 + math.cos(math.pi * progress))


@torch.no_grad()
def evaluate(model, loader, max_batches=None):
    """Evaluate ``model`` on ``loader``; return mean loss, perplexity and bits per byte.

    The model is switched to eval mode for the pass. Its previous train/eval mode is
    restored afterwards, even if an exception occurs.

    Args:
        model: Module whose forward ``model(x, y)`` returns ``(logits, loss)``.
        loader: Iterable of ``(x, y)`` batches.
        max_batches: Evaluate at most this many batches; ``None`` means the whole
            loader.

    Returns:
        dict with ``loss`` (mean nats per target), ``ppl`` (``exp(loss)``) and
        ``bpb`` (loss in bits).

    Raises:
        ValueError: if no batches were evaluated (empty loader or
            ``max_batches == 0``).
    """
    from itertools import islice

    # islice stops before fetching batch `max_batches`, instead of loading it and
    # then breaking.
    batches = loader if max_batches is None else islice(loader, max_batches)

    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        for x, y in batches:
            x, y = x.to(device), y.to(device)
            _, loss = model(x, y)
            # Weight each batch by its target count so the result is a per-target
            # mean. Assumes `loss` is a mean over targets — TODO confirm for both
            # model classes.
            total_loss += loss.item() * y.numel()
            total_tokens += y.numel()
    finally:
        model.train(was_training)

    if total_tokens == 0:
        raise ValueError('evaluate() received no batches; check `loader` and `max_batches`')

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    bpb = avg_loss / math.log(2)  # nats -> bits; one target is one byte here
    return {'loss': avg_loss, 'ppl': perplexity, 'bpb': bpb}

In [None]:
def train_model(model, name, max_steps=MAX_STEPS):
    """Train ``model`` with AdamW, warmup + cosine LR, and early stopping.

    Batches are drawn from ``train_loader``, which is restarted when exhausted.
    Every ``VAL_INTERVAL`` steps, ``VAL_BATCHES`` batches of both ``val_loader`` and
    ``train_loader`` are scored. The state with the lowest validation loss is kept
    (copied to CPU) and restored before returning. Training stops early after
    ``PATIENCE`` consecutive validations without improvement.

    Args:
        model: Module whose forward ``model(x, y)`` returns ``(logits, loss)``.
        name: Label used in progress output.
        max_steps: Maximum number of optimizer steps. This is also the horizon of
            the cosine LR schedule.

    Returns:
        ``(history, best_val_loss)``. ``history`` maps metric names to lists with
        one entry per validation.
    """
    print(f'\n{"="*60}')
    print(f'Training {name}')
    print('='*60)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        betas=(0.9, 0.95),
    )

    history = {
        'step': [], 'train_loss': [], 'val_loss': [],
        'val_ppl': [], 'val_bpb': [], 'lr': [],
    }

    best_val_loss = float('inf')
    patience_counter = 0
    best_state = None

    model.train()
    train_iter = iter(train_loader)
    running_loss = 0

    pbar = tqdm(range(max_steps), desc=name)

    for step in pbar:
        # Get batch, restarting the loader when an epoch ends.
        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x, y = next(train_iter)

        x, y = x.to(device), y.to(device)

        # Update LR. The schedule must use this run's max_steps; otherwise a
        # shorter run would stop before the cosine decays.
        lr = get_lr(step, max_steps=max_steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Forward + backward
        _, loss = model(x, y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()

        # Logging: mean training loss over the last 50 steps.
        if (step + 1) % 50 == 0:
            avg_loss = running_loss / 50
            pbar.set_postfix({'loss': f'{avg_loss:.3f}', 'lr': f'{lr:.2e}'})
            running_loss = 0

        # Validation
        if (step + 1) % VAL_INTERVAL == 0:
            val_metrics = evaluate(model, val_loader, VAL_BATCHES)
            train_metrics = evaluate(model, train_loader, VAL_BATCHES)

            history['step'].append(step + 1)
            history['train_loss'].append(train_metrics['loss'])
            history['val_loss'].append(val_metrics['loss'])
            history['val_ppl'].append(val_metrics['ppl'])
            history['val_bpb'].append(val_metrics['bpb'])
            history['lr'].append(lr)

            gap = val_metrics['loss'] - train_metrics['loss']
            print(f'\nStep {step+1}: train={train_metrics["loss"]:.3f}, val={val_metrics["loss"]:.3f}, ppl={val_metrics["ppl"]:.2f}, bpb={val_metrics["bpb"]:.3f}, gap={gap:.3f}')

            # Best model? Snapshot weights on CPU so the copy does not hold GPU memory.
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                patience_counter = 0
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                print('  ✓ New best!')
            else:
                patience_counter += 1
                print(f'  No improvement ({patience_counter}/{PATIENCE})')

            # Early stopping
            if patience_counter >= PATIENCE:
                print(f'\n⚠️ Early stopping at step {step+1}')
                break

    # Restore the best weights, if any validation ran.
    if best_state is not None:
        model.load_state_dict(best_state)

    return history, best_val_loss

## 4. Train Both Models

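**Note**: This will take a while. For quick testing, reduce `MAX_STEPS` in the training-config cell and re-run every cell from there down. `get_lr` and `train_model` capture `MAX_STEPS` as a default argument when they are defined, so editing the constant alone does not change them.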

In [None]:
# Train Original BDH. train_model returns the per-validation history and the best
# validation loss; on return the model holds the best-validation weights (if any
# validation ran).
original_history, original_best = train_model(original_model, 'Original BDH')

In [None]:
# Train Hierarchical BDH with the same schedule, data and early-stopping settings
# as the original model.
hier_history, hier_best = train_model(hier_model, 'Hierarchical BDH')

## 5. Compare Results

In [None]:
# Final evaluation of both (best-restored) models on the whole test loader,
# followed by a side-by-side report.
rule = '=' * 60
print('\n' + rule)
print('FINAL TEST SET EVALUATION')
print(rule)

original_test = evaluate(original_model, test_loader)
hier_test = evaluate(hier_model, test_loader)

for model_name, metrics in (('Original BDH', original_test), ('Hierarchical BDH', hier_test)):
    print(f'\n{model_name}:')
    print(f'  Loss: {metrics["loss"]:.4f}')
    print(f'  Perplexity: {metrics["ppl"]:.2f}')
    print(f'  Bits/Byte: {metrics["bpb"]:.3f}')

# Relative reduction vs. the original model, in percent (positive = hierarchical is better).
print(f'\nImprovement:')
ppl_improvement = (original_test['ppl'] - hier_test['ppl']) / original_test['ppl'] * 100
bpb_improvement = (original_test['bpb'] - hier_test['bpb']) / original_test['bpb'] * 100
print(f'  Perplexity: {ppl_improvement:+.1f}%')
print(f'  Bits/Byte: {bpb_improvement:+.1f}%')

In [None]:
# Plot training curves: three validation metrics plus the val-minus-train loss gap,
# one 2x2 panel each, with both models overlaid.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

original_gap = [v - t for v, t in zip(original_history['val_loss'], original_history['train_loss'])]
hier_gap = [v - t for v, t in zip(hier_history['val_loss'], hier_history['train_loss'])]

# (axis, original series, hierarchical series, y-label, title, dashed reference y or None)
panels = [
    (axes[0, 0], original_history['val_loss'], hier_history['val_loss'], 'Validation Loss', 'Validation Loss', None),
    (axes[0, 1], original_history['val_ppl'], hier_history['val_ppl'], 'Perplexity', 'Validation Perplexity', None),
    (axes[1, 0], original_history['val_bpb'], hier_history['val_bpb'], 'Bits per Byte', 'Validation BPB', None),
    (axes[1, 1], original_gap, hier_gap, 'Val - Train Loss', 'Overfitting Gap', 0.5),
]

for ax, orig_series, hier_series, ylabel, title, ref_y in panels:
    ax.plot(original_history['step'], orig_series, 'b-', label='Original BDH', linewidth=2)
    ax.plot(hier_history['step'], hier_series, 'r-', label='Hierarchical BDH', linewidth=2)
    if ref_y is not None:
        # NOTE(review): 0.5 looks like a visual "large gap" threshold — confirm intent.
        ax.axhline(y=ref_y, color='gray', linestyle='--', alpha=0.5)
    ax.set_xlabel('Step')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('hierarchical_comparison.png', dpi=150)
plt.show()

## 6. Generation Comparison

Compare text generation quality between models.

In [None]:
def generate_text(model, prompt, max_tokens=200, temperature=0.8):
    """Sample a continuation of ``prompt`` from ``model`` and return it as text.

    The prompt is encoded as UTF-8 byte ids and passed to ``model.generate``. The
    first sequence of the returned ids is decoded back to text, with invalid UTF-8
    replaced by U+FFFD. Presumably ``generate`` returns prompt + continuation —
    TODO confirm for both model classes.

    The model runs in eval mode without gradients; its previous train/eval mode is
    restored afterwards.

    Args:
        model: Model exposing ``generate(idx, max_new_tokens=..., temperature=...)``.
        prompt: Non-empty text prompt.
        max_tokens: Number of new byte tokens to sample.
        temperature: Sampling temperature passed through to ``generate``.

    Returns:
        The decoded text.

    Raises:
        ValueError: if ``prompt`` is empty. An empty id list would otherwise become
            a float tensor and fail obscurely inside the model.
    """
    prompt_ids = list(prompt.encode('utf-8'))
    if not prompt_ids:
        raise ValueError('prompt must be a non-empty string')
    prompt_bytes = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model.generate(prompt_bytes, max_new_tokens=max_tokens, temperature=temperature)
    finally:
        model.train(was_training)

    # Decode, handling invalid UTF-8 (sampled byte sequences need not be valid).
    output_bytes = bytes(output[0].tolist())
    return output_bytes.decode('utf-8', errors='replace')


# Sample a continuation of each prompt from both models, one after the other.
prompts = [
    'The history of',
    'In the year 2024,',
    'Scientists discovered that',
    'The quick brown fox',
]
contenders = (('Original BDH', original_model), ('Hierarchical BDH', hier_model))

banner = '=' * 70
print(banner)
print('GENERATION COMPARISON')
print(banner)

for prompt in prompts:
    print(f'\n--- Prompt: "{prompt}" ---')
    for label, contender in contenders:
        print(f'\n{label}:')
        print(generate_text(contender, prompt))
    print()

## 7. Save Results

In [None]:
# Save comparison results: configs, parameter counts, test metrics and training
# histories for both models. default=str stringifies anything json can't encode.
def _model_summary(config_dict, params, test_metrics, history):
    """Build one model's entry for the results JSON (key order matters for output)."""
    return {
        'config': config_dict,
        'params': params,
        'test_loss': test_metrics['loss'],
        'test_ppl': test_metrics['ppl'],
        'test_bpb': test_metrics['bpb'],
        'history': history,
    }


results = {
    'timestamp': datetime.now().isoformat(),
    'original': _model_summary(original_config.__dict__, original_params, original_test, original_history),
    'hierarchical': _model_summary(dict(hier_config.__dict__), hier_params, hier_test, hier_history),
}

results_dir = Path('results')
results_dir.mkdir(exist_ok=True)

results_path = results_dir / 'hierarchical_comparison.json'
with results_path.open('w') as fh:
    json.dump(results, fh, indent=2, default=str)

print(f'Results saved to {results_dir}/hierarchical_comparison.json')

In [None]:
# Save model checkpoints: weights, config attributes and test metrics for each model.
ckpt_dir = Path('checkpoints')
ckpt_dir.mkdir(exist_ok=True)

checkpoint_specs = (
    ('original_bdh_wikitext2.pt', original_model, original_config, original_test),
    ('hierarchical_bdh_wikitext2.pt', hier_model, hier_config, hier_test),
)
for filename, trained_model, cfg, test_metrics in checkpoint_specs:
    payload = {
        'model_state_dict': trained_model.state_dict(),
        'config': cfg.__dict__,
        'test_metrics': test_metrics,
    }
    torch.save(payload, ckpt_dir / filename)

print(f'Checkpoints saved to {ckpt_dir}/')

## Summary

| Metric | Original BDH | Hierarchical BDH | Winner |
|--------|-------------|------------------|--------|
| Parameters | - | - | - |
| Test Loss | - | - | - |
| Test Perplexity | - | - | - |
| Test BPB | - | - | - |

**Observations:**
- (Fill in after running)

**Next Steps:**
- Try different patch sizes (4, 8, 16)
- Test on code datasets (curriculum)
- Scale up model size
- Compare generation coherence qualitatively