# Hierarchical BDH Training on WikiText-2

Train **Hierarchical BDH** (MEGABYTE-style architecture) on WikiText-2.

## Architecture

```
Input bytes (B, T)
       ↓
Byte Embedding (B, T, D_local)
       ↓
Patch Embedder (B, T/P, D_global)  <- Groups P bytes into patches
       ↓
Global BDH (6L, 512D, 8H)          <- Cross-patch semantics
       ↓
Global-to-Local Adapter            <- Injects global context
       ↓
Local BDH (4L, 256D, 4H)           <- Intra-patch refinement
       ↓
LM Head (B, T, 256)
```

**Key features:**
- Both global and local models use full BDH attention (bottleneck + ReLU + gating)
- Patch size P=8 (a power of 2; 8 bytes is roughly 1–2 subword tokens of English text)
- ~3:1 global/local parameter ratio (following MEGABYTE)
- Size presets: tiny (~3M), small (~30M), base (~73M), large (~278M)

In [None]:
# Check GPU
!nvidia-smi

In [None]:
# Clone repo
!git clone https://github.com/newsbubbles/bdh.git 2>/dev/null || (cd bdh && git pull)
%cd bdh

In [None]:
# Install dependencies
!pip install -q torch datasets tqdm matplotlib

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import math
import json
from datetime import datetime
from tqdm.auto import tqdm

# Import Hierarchical BDH
from bdh_hierarchical import HierarchicalBDH, HierarchicalBDHConfig

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')
print(f'PyTorch version: {torch.__version__}')

## Load WikiText-2

In [None]:
from datasets import load_dataset

# Hugging Face dataset id and config name for the raw-text WikiText-2 corpus.
WIKITEXT_ID = 'wikitext'
WIKITEXT_CONFIG = 'wikitext-2-raw-v1'

print('Loading WikiText-2...')
dataset = load_dataset(WIKITEXT_ID, WIKITEXT_CONFIG)

def text_to_bytes(text):
    """Encode ``text`` as UTF-8 and return its bytes as a 1-D ``torch.long`` tensor.

    Each element is one byte value in ``[0, 255]``. The encoded bytes are
    viewed directly with ``torch.frombuffer`` rather than first materialising
    a Python ``list`` of ints, which for multi-megabyte corpora is much slower
    and needs far more memory.

    Args:
        text: The string to encode.

    Returns:
        A 1-D ``torch.long`` tensor of byte ids (empty for an empty string).
    """
    raw = bytearray(text.encode('utf-8'))
    if not raw:
        # torch.frombuffer rejects zero-length buffers; mirror torch.tensor([]).
        return torch.empty(0, dtype=torch.long)
    # A bytearray is writable, so frombuffer does not warn; .to() then makes an
    # owned long-typed copy that no longer aliases ``raw``.
    return torch.frombuffer(raw, dtype=torch.uint8).to(torch.long)

# Join each split's `text` rows with '\n' into one string per split.
train_text, val_text, test_text = (
    '\n'.join(dataset[split]['text']) for split in ('train', 'validation', 'test')
)
# Convert each joined split to a flat tensor of byte ids.
train_data, val_data, test_data = (
    text_to_bytes(blob) for blob in (train_text, val_text, test_text)
)

n_train_bytes = len(train_data)
print(f'Train: {n_train_bytes:,} bytes ({n_train_bytes/1e6:.1f}M)')
print(f'Val: {len(val_data):,} bytes')
print(f'Test: {len(test_data):,} bytes')

## Dataset Class

**Important**: Block size must be divisible by patch size!

In [None]:
class ByteDataset(Dataset):
    """Sliding-window next-byte prediction dataset over a 1-D byte tensor.

    Item ``idx`` is the pair ``(x, y)`` with ``x = data[idx:idx + block_size]``
    and ``y`` the same window shifted one position to the right, so every
    byte offset (stride 1) is a separate sample.

    Args:
        data: 1-D tensor of byte ids.
        block_size: Length of each input window; must be a multiple of
            ``patch_size`` so a window splits into whole patches.
        patch_size: Patch length used by the hierarchical model.

    Raises:
        ValueError: If ``block_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, data, block_size, patch_size=8):
        # Explicit raise rather than assert: asserts are stripped under `python -O`.
        if block_size % patch_size != 0:
            raise ValueError(
                f'block_size ({block_size}) must be divisible by patch_size ({patch_size})'
            )
        self.data = data
        self.block_size = block_size
        self.patch_size = patch_size

    def __len__(self):
        # One sample per start offset that still leaves room for the shifted
        # target. Clamp at 0 so data shorter than a window gives an empty
        # dataset instead of a negative length, which len() rejects.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        # block_size + 1 bytes: inputs are the first block_size, targets the last.
        chunk = self.data[idx:idx + self.block_size + 1]
        return chunk[:-1], chunk[1:]

## Configuration

Choose a model size preset. Available options:
- `tiny`: ~3M params, good for testing
- `small`: ~30M params, matches original BDH, good for Colab T4
- `base`: ~73M params, good for A100/V100
- `large`: ~278M params, multi-GPU

In [None]:
# =============================================================
# CHOOSE MODEL SIZE HERE
# =============================================================
MODEL_SIZE = 'small'  # Options: 'tiny', 'small', 'base', 'large'
# =============================================================

# Map each preset name to the HierarchicalBDHConfig factory of the same name,
# then build the chosen preset's config with dropout=0.2.
preset_map = {
    name: getattr(HierarchicalBDHConfig, name)
    for name in ('tiny', 'small', 'base', 'large')
}
build_preset = preset_map[MODEL_SIZE]
model_config = build_preset(dropout=0.2)

# Training config
config = {
    # Model (from preset)
    'model_size': MODEL_SIZE,
    'patch_size': model_config.patch_size,

    # Sequence
    'block_size': 512,  # Must be divisible by patch_size

    # Training
    'batch_size': 32 if MODEL_SIZE in ['tiny', 'small'] else 16,
    'learning_rate': 3e-4,
    'weight_decay': 0.1,
    'max_steps': 5000,
    'warmup_steps': 200,

    # Validation
    'val_interval': 100,  # optimizer steps between validation rounds
    'val_batches': 50,    # batches per validation estimate

    # Early stopping
    'patience': 10,       # validation rounds without improvement before stopping

    # Logging
    'log_interval': 50,   # optimizer steps between progress-bar loss updates
}

# Validate block_size. Raise explicitly (asserts are stripped under `python -O`)
# and include the offending values so the error is actionable.
if config['block_size'] % config['patch_size'] != 0:
    raise ValueError(
        f"block_size ({config['block_size']}) must be divisible by "
        f"patch_size ({config['patch_size']})"
    )

print(f'Model size: {MODEL_SIZE}')
print(f'Patch size: {config["patch_size"]}')
print(f'Block size: {config["block_size"]} ({config["block_size"] // config["patch_size"]} patches)')
print()
print('Training config:')
for k, v in config.items():
    print(f'  {k}: {v}')

## Create Model

In [None]:
# Instantiate the hierarchical model on the selected device.
model = HierarchicalBDH(model_config).to(device)

# Per-component parameter counts as reported by the model.
params = model.count_parameters()

rule = '-' * 50
print(f'\n{rule}')
print(f'HIERARCHICAL BDH - {MODEL_SIZE.upper()}')
print(rule)
print(
    f'Global model: {model_config.global_n_layer}L x '
    f'{model_config.global_n_embd}D x {model_config.global_n_head}H'
)
print(
    f'Local model:  {model_config.local_n_layer}L x '
    f'{model_config.local_n_embd}D x {model_config.local_n_head}H'
)
print(f'Patch size:   {model_config.patch_size}')
print()
print('Parameters:')
for label, key in (
    ('Global', 'global'),
    ('Local', 'local'),
    ('Embedding', 'embedding'),
    ('LM Head', 'lm_head'),
):
    # "Label:" left-aligned in an 11-char column; count right-aligned in 12.
    print(f'  {label + ":":<11}{params[key]:>12,}')
print('  ' + '-' * 25)
print(f'  Total:     {params["total"]:>12,} ({params["total"] / 1e6:.1f}M)')
print(f'\nGlobal/Local ratio: {params["global_local_ratio"]:.2f}')

## Setup Training

In [None]:
from torch.utils.data import Subset

# Dataloaders (num_workers=0 for Colab compatibility)
train_dataset = ByteDataset(train_data, config['block_size'], config['patch_size'])
val_dataset = ByteDataset(val_data, config['block_size'], config['patch_size'])

# ByteDataset yields one window per byte offset (stride 1). That is fine for
# shuffled training. During training, though, the unshuffled validation loader
# is only read for its first `val_batches` batches. With stride-1 windows that
# covers just ~batch_size * val_batches + block_size bytes at the very start of
# the validation text. Validating on non-overlapping windows instead lets the
# same number of batches span most of the split, scoring each byte once.
val_windows = Subset(val_dataset, range(0, len(val_dataset), config['block_size']))

# Pinned host memory only speeds up host-to-GPU copies; it has no benefit
# without CUDA (and recent PyTorch versions warn about it), so pin on GPU only.
use_pin_memory = device == 'cuda'

train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=0,  # Colab-safe
    pin_memory=use_pin_memory,
)

val_loader = DataLoader(
    val_windows,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=0,
    pin_memory=use_pin_memory,
)

print(f'Train batches: {len(train_loader):,}')
print(f'Val batches: {len(val_loader):,}')

In [None]:
# AdamW over all model parameters, configured from `config`.
# NOTE(review): weight decay is applied uniformly to every tensor yielded by
# model.parameters() (there is no separate no-decay group) — confirm intended.
adamw_settings = {
    'lr': config['learning_rate'],
    'weight_decay': config['weight_decay'],
    'betas': (0.9, 0.95),
}
optimizer = torch.optim.AdamW(model.parameters(), **adamw_settings)

# Learning rate scheduler with warmup + cosine decay
def get_lr(step, base_lr=None, warmup_steps=None, max_steps=None):
    """Return the learning rate for the 0-indexed optimizer step ``step``.

    Linear warmup over the first ``warmup_steps`` steps, then cosine decay
    from ``base_lr`` down to 0 at ``max_steps``; later steps stay at 0.

    Args:
        step: Zero-based index of the optimizer step about to be taken.
        base_lr: Peak learning rate; defaults to ``config['learning_rate']``.
        warmup_steps: Warmup length; defaults to ``config['warmup_steps']``.
        max_steps: Total schedule length; defaults to ``config['max_steps']``.

    Returns:
        The learning rate as a float.
    """
    if base_lr is None:
        base_lr = config['learning_rate']
    if warmup_steps is None:
        warmup_steps = config['warmup_steps']
    if max_steps is None:
        max_steps = config['max_steps']

    # Warmup: use (step + 1) so the very first update (step 0) gets a non-zero
    # rate; the last warmup step reaches base_lr exactly.
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps

    # Cosine decay. Clamp progress to 1 so steps past max_steps stay at 0
    # instead of the cosine climbing back up, and avoid a zero-length span.
    decay_steps = max(1, max_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))

print('Optimizer: AdamW')
# (printed label, config key) pairs for the optimizer/schedule summary.
for label, key in (
    ('lr', 'learning_rate'),
    ('weight_decay', 'weight_decay'),
    ('warmup_steps', 'warmup_steps'),
):
    print(f'  {label}: {config[key]}')

## Training Loop

In [None]:
@torch.no_grad()
def evaluate(model, loader, max_batches=None, device=None):
    """Evaluate ``model`` on ``loader`` and return byte-level metrics.

    ``model(x, y)`` must return ``(_, loss)``. ``loss`` is assumed to be a mean
    over the batch's targets, so it is re-weighted by ``y.numel()`` to give a
    per-target average across batches of different sizes.

    Args:
        model: Module called as ``model(x, y) -> (_, loss)``.
        loader: Iterable of ``(x, y)`` tensor pairs.
        max_batches: Evaluate at most this many batches; ``None`` means the
            whole loader.
        device: Device batches are moved to; defaults to the device of the
            model's first parameter.

    Returns:
        ``(avg_loss, perplexity, bits_per_byte)``.

    Raises:
        ValueError: If no batches were evaluated.
    """
    import itertools

    if device is None:
        device = next(model.parameters()).device
    # islice stops after exactly max_batches batches (no extra batch fetched);
    # the `is None` test keeps max_batches=0 from meaning "no limit".
    batches = loader if max_batches is None else itertools.islice(loader, max_batches)

    # Restore the caller's mode afterwards instead of unconditionally train().
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        for x, y in batches:
            x, y = x.to(device), y.to(device)
            _, loss = model(x, y)
            total_loss += loss.item() * y.numel()
            total_tokens += y.numel()
    finally:
        model.train(was_training)

    if total_tokens == 0:
        raise ValueError('evaluate() received no batches to evaluate')

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    bpb = avg_loss / math.log(2)  # nats -> bits per byte
    return avg_loss, perplexity, bpb

In [None]:
# Metric history: one entry per key per validation round.
history = {
    key: []
    for key in ('step', 'train_loss', 'val_loss', 'val_ppl', 'val_bpb', 'lr')
}

# Early-stopping bookkeeping and the global step counter.
best_val_loss, patience_counter, step = float('inf'), 0, 0

# Per-model-size checkpoint directory (created if missing).
ckpt_dir = Path(f'checkpoints_hierarchical_{MODEL_SIZE}')
ckpt_dir.mkdir(exist_ok=True)

print(f'Checkpoints will be saved to: {ckpt_dir}')
print('\nStarting training...')
print('=' * 60)

In [None]:
# Main training loop: up to config['max_steps'] optimizer steps, cycling through
# train_loader, with periodic validation, best-checkpoint saving and
# patience-based early stopping.
# NOTE(review): best.pt is only written inside a validation round, so if
# max_steps < val_interval no checkpoint exists and the later load of
# ckpt_dir / 'best.pt' fails.
model.train()
train_iter = iter(train_loader)
running_loss = 0  # sum of per-step training losses since the last log update

pbar = tqdm(range(config['max_steps']), desc='Training')

for step in pbar:
    # Get batch (cycle through data)
    # On exhaustion, start a new pass over train_loader (shuffle=True, so a new order).
    try:
        x, y = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        x, y = next(train_iter)
    
    x, y = x.to(device), y.to(device)
    
    # Update learning rate
    # The schedule value is written into every param group before each step.
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    # Forward + backward
    # Gradients are clipped to a global norm of 1.0 before the optimizer step.
    _, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    running_loss += loss.item()
    
    # Logging
    # Mean training loss over the last log_interval steps, shown on the progress bar.
    if (step + 1) % config['log_interval'] == 0:
        avg_loss = running_loss / config['log_interval']
        pbar.set_postfix({'loss': f'{avg_loss:.3f}', 'lr': f'{lr:.2e}'})
        running_loss = 0
    
    # Validation
    # NOTE: train_loss here comes from evaluate(), which switches the model to
    # eval mode (presumably disabling the dropout=0.2 set on model_config), on up
    # to val_batches freshly shuffled training batches. It is therefore not the
    # same quantity as the progress-bar running loss above.
    if (step + 1) % config['val_interval'] == 0:
        val_loss, val_ppl, val_bpb = evaluate(model, val_loader, config['val_batches'])
        train_loss, _, _ = evaluate(model, train_loader, config['val_batches'])
        
        history['step'].append(step + 1)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_ppl'].append(val_ppl)
        history['val_bpb'].append(val_bpb)
        history['lr'].append(lr)
        
        # Positive gap = validation loss above training loss.
        gap = val_loss - train_loss
        print(f'\nStep {step+1}: train={train_loss:.3f}, val={val_loss:.3f}, ppl={val_ppl:.2f}, bpb={val_bpb:.3f}, gap={gap:.3f}')
        
        # Best model?
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            
            # Save best
            # Stores model + optimizer state plus the config as a plain dict,
            # overwriting any previous best.pt.
            torch.save({
                'step': step + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'model_config': model_config.__dict__,
                'val_loss': val_loss,
                'val_ppl': val_ppl,
                'val_bpb': val_bpb,
            }, ckpt_dir / 'best.pt')
            print(f'  ✓ New best! Saved to {ckpt_dir}/best.pt')
        else:
            patience_counter += 1
            print(f'  No improvement ({patience_counter}/{config["patience"]})')
        
        # Early stopping
        # Patience is counted in validation rounds, not optimizer steps.
        if patience_counter >= config['patience']:
            print(f'\n⚠️ Early stopping at step {step+1}')
            break

print('\n' + '=' * 60)
print('Training complete!')
print(f'Best val loss: {best_val_loss:.4f}')

## Save Results

In [None]:
# Save training history
history_path = ckpt_dir / 'training_history.json'
with open(history_path, 'w') as f:
    json.dump(history, f, indent=2)
print(f'Saved history to {history_path}')

# Save config
full_config = {
    **config,
    'model_config': model_config.__dict__,
    'parameters': params,
}
config_path = ckpt_dir / 'config.json'
with open(config_path, 'w') as f:
    json.dump(full_config, f, indent=2)
print(f'Saved config to {config_path}')

## Plot Training Curves

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 4, figsize=(18, 4))
loss_ax, ppl_ax, bpb_ax, gap_ax = axes
steps = history['step']

# Panel 1: training vs. validation loss at each validation round.
loss_ax.plot(steps, history['train_loss'], 'b-', label='Train', linewidth=2)
loss_ax.plot(steps, history['val_loss'], 'r-', label='Val', linewidth=2)
loss_ax.set_xlabel('Step')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Loss Curves')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

# Panels 2-3: one validation metric per panel.
single_metric_panels = (
    (ppl_ax, 'val_ppl', 'g-', 'Perplexity', 'Validation Perplexity'),
    (bpb_ax, 'val_bpb', 'm-', 'BPB', 'Bits Per Byte'),
)
for panel, metric, fmt, y_label, title in single_metric_panels:
    panel.plot(steps, history[metric], fmt, linewidth=2)
    panel.set_xlabel('Step')
    panel.set_ylabel(y_label)
    panel.set_title(title)
    panel.grid(True, alpha=0.3)

# Panel 4: generalization gap (val - train) with a fixed 0.5 reference line.
gaps = [val - train for train, val in zip(history['train_loss'], history['val_loss'])]
gap_ax.fill_between(steps, gaps, alpha=0.5, color='orange')
gap_ax.plot(steps, gaps, 'orange', linewidth=2)
gap_ax.axhline(y=0.5, color='red', linestyle='--', label='Overfitting threshold')
gap_ax.set_xlabel('Step')
gap_ax.set_ylabel('Val - Train Loss')
gap_ax.set_title('Overfitting Gap')
gap_ax.legend()
gap_ax.grid(True, alpha=0.3)

plt.suptitle(f'Hierarchical BDH ({MODEL_SIZE}) - WikiText-2', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(ckpt_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print(f'Saved plot to {ckpt_dir}/training_curves.png')

## Evaluate Best Model

In [None]:
from torch.utils.data import Subset


def _non_overlapping_loader(dataset):
    """Wrap a stride-1 ``ByteDataset`` so each target byte is scored once.

    ``ByteDataset`` yields a window at every byte offset, so iterating it in
    full re-scores each byte about ``block_size`` times. Keeping only windows
    that start at multiples of ``block_size`` tiles the data without overlap,
    making full-set evaluation roughly ``block_size`` times cheaper. Fewer
    than ``block_size`` trailing bytes are left unscored.
    """
    starts = range(0, len(dataset), config['block_size'])
    return DataLoader(
        Subset(dataset, starts),
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=0,
        pin_memory=(device == 'cuda'),
    )


# Load best checkpoint, mapping tensors onto the active device.
ckpt = torch.load(ckpt_dir / 'best.pt', map_location=device)
model.load_state_dict(ckpt['model_state_dict'])

# Full validation: every byte of the split scored once (non-overlapping windows).
print('Evaluating best model on full validation set...')
val_loss, val_ppl, val_bpb = evaluate(model, _non_overlapping_loader(val_dataset))
print(f'Val Loss: {val_loss:.4f}')
print(f'Val Perplexity: {val_ppl:.2f}')
print(f'Val BPB: {val_bpb:.3f}')

# Test set
test_dataset = ByteDataset(test_data, config['block_size'], config['patch_size'])
test_loader = _non_overlapping_loader(test_dataset)

print('\nEvaluating on test set...')
test_loss, test_ppl, test_bpb = evaluate(model, test_loader)
print(f'Test Loss: {test_loss:.4f}')
print(f'Test Perplexity: {test_ppl:.2f}')
print(f'Test BPB: {test_bpb:.3f}')

## Test Generation

In [None]:
def generate_text(model, prompt, max_new_tokens=200, temperature=0.8, top_k=50, device=None):
    """Generate from ``prompt`` via ``model.generate`` and decode the result.

    The prompt is UTF-8 encoded into a ``(1, len)`` tensor of byte ids and
    passed to ``model.generate``. The first returned sequence is then decoded
    back to text.

    Args:
        model: Module exposing ``generate(idx, max_new_tokens=, temperature=, top_k=)``.
        prompt: Text prompt.
        max_new_tokens: Number of bytes to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.
        device: Device for the prompt tensor; defaults to the device of the
            model's first parameter.

    Returns:
        The decoded first output sequence. Invalid UTF-8 is replaced with
        U+FFFD. If any id falls outside 0-255, it falls back to printable
        ASCII with ``'?'`` for everything else.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()

    # Encode prompt
    prompt_bytes = list(prompt.encode('utf-8'))
    idx = torch.tensor([prompt_bytes], device=device, dtype=torch.long)

    # Generate
    with torch.no_grad():
        output = model.generate(idx, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)

    # Decode. errors='replace' already absorbs invalid UTF-8, so only bytes()
    # can fail: it raises ValueError for ids outside range(256).
    output_bytes = output[0].tolist()
    try:
        return bytes(output_bytes).decode('utf-8', errors='replace')
    except ValueError:
        return ''.join(chr(b) if 32 <= b < 127 else '?' for b in output_bytes)

# Test prompts
prompts = [
    'The history of',
    'In 1920, the',
    'Scientists discovered',
    'The capital of France',
]

separator = '=' * 60
for heading_line in (separator, 'GENERATION SAMPLES', separator):
    print(heading_line)

for prompt in prompts:
    print(f'\nPrompt: "{prompt}"')
    print('-' * 40)
    sample = generate_text(model, prompt, max_new_tokens=150, temperature=0.8)
    # The sample followed by a blank line.
    print(sample, end='\n\n')

## Compare with Baseline BDH

Compare against a baseline BDH run. The baseline figures printed below are hard-coded from a previous run — update them if your own baseline results differ:

In [None]:
# Assemble the summary as lines and emit it in one print call.
banner = '=' * 60
summary_lines = [
    banner,
    'COMPARISON SUMMARY',
    banner,
    '',
    f'Hierarchical BDH ({MODEL_SIZE}):',
    f'  Parameters: {params["total"]:,} ({params["total"]/1e6:.1f}M)',
    f'  Test Loss:  {test_loss:.4f}',
    f'  Test PPL:   {test_ppl:.2f}',
    f'  Test BPB:   {test_bpb:.3f}',
    '',
    'Baseline BDH (from previous run):',
    '  Parameters: ~25M',
    '  Test PPL:   ~3.19',
    '  Test BPB:   ~1.67',
    '',
    '→ Hierarchical BDH should show improved coherence at similar perplexity,',
    '  due to multi-scale processing (global cross-patch + local intra-patch).',
]
print('\n'.join(summary_lines))

## Download Checkpoint

In [None]:
import shutil

# Create zip file. shutil.make_archive is used instead of shelling out to the
# `zip` CLI, which is not installed everywhere (e.g. Windows or slim containers).
# It also always writes a fresh archive, whereas `zip -r` into an archive left
# from an earlier run would keep stale entries.
zip_name = f'hierarchical_bdh_{MODEL_SIZE}_wikitext2.zip'
shutil.make_archive(
    zip_name[:-len('.zip')],  # make_archive appends the extension itself
    'zip',
    root_dir='.',
    base_dir=str(ckpt_dir),  # keep the directory name as the archive's top level
)
print(f'Created: {zip_name}')

# Auto-download in Colab (only the import decides whether we are in Colab).
try:
    from google.colab import files
except ImportError:
    print('Not running in Colab.')
    print(f'Checkpoints saved to: {ckpt_dir}')
    print(f'Zip file: {zip_name}')
else:
    print('Downloading checkpoint...')
    files.download(zip_name)
    print('\n✅ Download started! Check your browser downloads.')