# Hierarchical BDH Training on PG-19 (Books)

Train **Hierarchical BDH** on PG-19, a corpus of full-length books.

## Why PG-19?
- **Long-range dependencies**: Full books, not article snippets
- **Standard benchmark**: Used by Transformer-XL, Longformer, etc.
- **Good fit for a hierarchical model**: Exercises the global/local architecture on genuinely long context
- **~11GB** of text from Project Gutenberg books published before 1919

## Expected Outcomes
- Tests the hierarchical architecture's long-range modeling
- Likely higher perplexity than on WikiText (more diverse, literary style); all metrics reported here are byte-level
- The global model should matter most here, where coherence has to span chapters

## Hardware Requirements
- **Recommended**: A100 40GB+ (80GB for large model)
- **Minimum**: V100 32GB with small model
- Training time: 4-8 hours for 20K steps

In [None]:
!nvidia-smi

In [None]:
!git clone https://github.com/newsbubbles/bdh.git 2>/dev/null || (cd bdh && git pull)
%cd bdh

In [None]:
!pip install -q torch datasets tqdm matplotlib

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import math
import json
from datetime import datetime
from tqdm.auto import tqdm

from bdh_hierarchical import HierarchicalBDH, HierarchicalBDHConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
print(f'PyTorch version: {torch.__version__}')
if device == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name()}')
    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

## Load PG-19

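PG-19 is large (~11GB). By default the next cell loads only a slice of each split to keep memory manageable. Split slicing (`train[:N]`) limits what is loaded into memory, but the first run still downloads the full dataset.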

**Options:**
- Full dataset: Best results, needs more RAM
- Sampled: Faster iteration, good for testing
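
The choice between the two options is largely a memory question. The conversion cell later in this notebook stores each byte as a 64-bit integer (`dtype=torch.long`), so the in-memory tensor is about eight times the size of the raw text. A rough estimate, assuming the ~11GB corpus figure quoted above and ignoring the temporary copies made during conversion (the helper name `tensor_gb` is illustrative only):

```python
# Rough RAM estimate for holding a byte corpus as a tensor.
# Assumption: corpus size is the ~11 GB figure quoted above. Real usage is
# higher because the joined string and encoded bytes also live in memory.
BYTES_PER_INT64 = 8


def tensor_gb(corpus_bytes: float, bytes_per_element: int = BYTES_PER_INT64) -> float:
    """Size in GB of a tensor holding one element per corpus byte."""
    return corpus_bytes * bytes_per_element / 1e9


full_corpus_bytes = 11e9
print(f'int64 tensor: {tensor_gb(full_corpus_bytes):.0f} GB')
print(f'uint8 tensor: {tensor_gb(full_corpus_bytes, 1):.0f} GB')
```

This is why the sampled option is the default: the full corpus held as `torch.long` would not fit in the RAM of a typical single-GPU machine.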

In [None]:
from datasets import load_dataset
import random

# =============================================================
# DATASET SIZE CONFIG
# =============================================================
USE_FULL_DATASET = False  # Set True for full training
SAMPLE_SIZE = 1000  # Number of books to sample if not full
# =============================================================

print('Loading PG-19 dataset...')
print('(This may take several minutes for first download)')

if USE_FULL_DATASET:
    dataset = load_dataset('pg19', split='train')
    val_dataset_raw = load_dataset('pg19', split='validation')
    test_dataset_raw = load_dataset('pg19', split='test')
else:
    # Load a fixed slice of each split for faster iteration (first N books, not a random sample)
    print(f'Loading the first {SAMPLE_SIZE} books...')
    dataset = load_dataset('pg19', split=f'train[:{SAMPLE_SIZE}]')
    val_dataset_raw = load_dataset('pg19', split='validation[:100]')
    test_dataset_raw = load_dataset('pg19', split='test[:100]')

print(f'Train books: {len(dataset)}')

In [None]:
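def text_to_bytes(text):
    # Go through a bytearray rather than list(...): a Python list of ints
    # costs far more memory than the raw bytes on book-sized inputs.
    raw = bytearray(text.encode('utf-8'))
    return torch.frombuffer(raw, dtype=torch.uint8).long()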

# Concatenate books with separator
print('Processing train split...')
train_text = '\n\n=== BOOK SEPARATOR ===\n\n'.join(dataset['text'])
print('Processing validation split...')
val_text = '\n\n=== BOOK SEPARATOR ===\n\n'.join(val_dataset_raw['text'])
print('Processing test split...')
test_text = '\n\n=== BOOK SEPARATOR ===\n\n'.join(test_dataset_raw['text'])

print('\nConverting to bytes...')
train_data = text_to_bytes(train_text)
val_data = text_to_bytes(val_text)
test_data = text_to_bytes(test_text)

# Free memory
del train_text, val_text, test_text, dataset, val_dataset_raw, test_dataset_raw
import gc; gc.collect()

print(f'\nDataset sizes:')
print(f'  Train: {len(train_data):,} bytes ({len(train_data)/1e9:.2f}GB)')
print(f'  Val:   {len(val_data):,} bytes ({len(val_data)/1e6:.1f}MB)')
print(f'  Test:  {len(test_data):,} bytes ({len(test_data)/1e6:.1f}MB)')

## Dataset Class
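
The dataset serves sliding windows for next-byte prediction: each item is a run of `block_size` bytes as input and the same run shifted by one byte as target. A minimal pure-Python sketch of that indexing on a toy string (the `window` helper is illustrative and not part of the repository):

```python
# Sliding-window next-byte pairs over a byte sequence (pure-Python sketch).
def window(data: bytes, idx: int, block_size: int):
    """Return (x, y) where y is x shifted one byte to the right."""
    chunk = data[idx:idx + block_size + 1]
    return list(chunk[:-1]), list(chunk[1:])


data = 'The old house'.encode('utf-8')   # 13 bytes
block_size = 4

x, y = window(data, 0, block_size)
print(bytes(x), bytes(y))                # b'The ' b'he o'

# Every start offset that still leaves room for the shifted target.
num_windows = len(data) - block_size
print(num_windows)                       # 9
```

Because consecutive windows differ by a single byte, the number of items is almost equal to the number of bytes in the split.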

In [None]:
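class ByteDataset(Dataset):
    """Sliding windows over a 1-D byte tensor for next-byte prediction."""

    def __init__(self, data, block_size, patch_size=8):
        # The model groups bytes into patches, so a block must hold whole patches.
        assert block_size % patch_size == 0
        self.data = data
        self.block_size = block_size
        self.patch_size = patch_size

    def __len__(self):
        # One window per start offset that leaves room for the shifted target.
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        chunk = self.data[idx:idx + self.block_size + 1]
        x = chunk[:-1]  # input bytes
        y = chunk[1:]   # targets: the same bytes shifted by one
        return x, y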

## Configuration

For PG-19, we use **longer context** to capture book structure.

| Size | Params | Context | Batch | VRAM Est. |
|------|--------|---------|-------|-----------|
| small | 30M | 2048 | 16 | ~20GB |
| base | 73M | 2048 | 8 | ~35GB |
| large | 278M | 2048 | 4 | ~60GB+ |
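
To see what these settings imply for the amount of data a run consumes, the sketch below works through the arithmetic using the table's batch sizes, the 2048-byte context, and the `gradient_accumulation_steps` and `max_steps` values from the configuration cell. The patch size of 8 is an assumption taken from the comment in that cell; the notebook itself reads it from `model_config.patch_size`.

```python
# Training budget arithmetic for the presets in the table above.
BLOCK_SIZE = 2048
PATCH_SIZE = 8        # assumed; the notebook reads model_config.patch_size
GRAD_ACCUM = 4
MAX_STEPS = 20_000
BATCH = {'small': 16, 'base': 8, 'large': 4}


def budget(size: str) -> dict:
    """Patches per block, effective batch, and bytes consumed for a preset."""
    effective_batch = BATCH[size] * GRAD_ACCUM
    bytes_per_step = effective_batch * BLOCK_SIZE
    return {
        'patches_per_block': BLOCK_SIZE // PATCH_SIZE,
        'effective_batch': effective_batch,
        'bytes_per_step': bytes_per_step,
        'total_bytes': bytes_per_step * MAX_STEPS,
    }


for size in BATCH:
    b = budget(size)
    print(f"{size:>5}: {b['effective_batch']} seqs/step, "
          f"{b['bytes_per_step']:,} bytes/step, "
          f"{b['total_bytes'] / 1e9:.2f} GB over {MAX_STEPS:,} steps")
```

Under these assumptions a 20K-step `base` run reads about 1.3GB of text, a fraction of the full corpus.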

In [None]:
# =============================================================
MODEL_SIZE = 'base'  # Options: 'tiny', 'small', 'base', 'large'
# =============================================================

preset_map = {
    'tiny': HierarchicalBDHConfig.tiny,
    'small': HierarchicalBDHConfig.small,
    'base': HierarchicalBDHConfig.base,
    'large': HierarchicalBDHConfig.large,
}
model_config = preset_map[MODEL_SIZE](dropout=0.1)

# Training config - optimized for long-form text
config = {
    'model_size': MODEL_SIZE,
    'patch_size': model_config.patch_size,
    
    # Longer context for books
    'block_size': 2048,  # 256 patches of 8 bytes
    
    # Training
    'batch_size': {'tiny': 32, 'small': 16, 'base': 8, 'large': 4}[MODEL_SIZE],
    'learning_rate': 3e-4,
    'weight_decay': 0.1,
    'max_steps': 20000,
    'warmup_steps': 500,
    
    'gradient_accumulation_steps': 4,
    
    'val_interval': 500,
    'val_batches': 100,
    'patience': 10,
    'log_interval': 100,
}

assert config['block_size'] % config['patch_size'] == 0
effective_batch = config['batch_size'] * config['gradient_accumulation_steps']

print(f'Model: {MODEL_SIZE}')
print(f'Block size: {config["block_size"]} ({config["block_size"] // config["patch_size"]} patches)')
print(f'Effective batch: {effective_batch}')

## Create Model

In [None]:
model = HierarchicalBDH(model_config).to(device)
params = model.count_parameters()

print('\n' + '-'*50)
print(f'HIERARCHICAL BDH - {MODEL_SIZE.upper()}')
print('-'*50)
print(f'Global: {model_config.global_n_layer}L x {model_config.global_n_embd}D x {model_config.global_n_head}H')
print(f'Local:  {model_config.local_n_layer}L x {model_config.local_n_embd}D x {model_config.local_n_head}H')
print(f'Total:  {params["total"]:,} ({params["total"]/1e6:.1f}M)')

## Setup Training
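
Besides the data loaders, this section sets up a learning-rate schedule with linear warmup followed by cosine decay to zero. The standalone sketch below mirrors that schedule with the notebook's default values (`3e-4` peak, 500 warmup steps, 20,000 total steps) so its shape can be inspected without a model; the function name `lr_at` is illustrative.

```python
import math


def lr_at(step: int, base_lr: float = 3e-4,
          warmup_steps: int = 500, max_steps: int = 20_000) -> float:
    """Linear warmup to base_lr, then cosine decay towards zero."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))


# Start of warmup, mid-warmup, peak, halfway through decay, last step.
for step in (0, 250, 500, 10_250, 19_999):
    print(f'step {step:>6}: lr = {lr_at(step):.2e}')
```

Note that the very first step runs with a learning rate of exactly zero, and the peak is reached at step 500.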

In [None]:
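from torch.utils.data import RandomSampler, Subset

train_dataset = ByteDataset(train_data, config['block_size'], config['patch_size'])
val_dataset = ByteDataset(val_data, config['block_size'], config['patch_size'])
# Validate on non-overlapping windows; consecutive byte offsets overlap almost entirely.
val_dataset = Subset(val_dataset, range(0, len(val_dataset), config['block_size']))

# Draw training windows with replacement: shuffle=True would build a permutation
# over every byte offset, which is too large to materialise for this corpus.
num_train_samples = config['max_steps'] * config['gradient_accumulation_steps'] * config['batch_size']
train_sampler = RandomSampler(train_dataset, replacement=True, num_samples=num_train_samples)
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2, pin_memory=True)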

print(f'Train batches: {len(train_loader):,}')
print(f'Val batches: {len(val_loader):,}')

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'], betas=(0.9, 0.95))

def get_lr(step):
    if step < config['warmup_steps']:
        return config['learning_rate'] * step / config['warmup_steps']
    progress = (step - config['warmup_steps']) / (config['max_steps'] - config['warmup_steps'])
    return config['learning_rate'] * 0.5 * (1 + math.cos(math.pi * progress))

scaler = torch.cuda.amp.GradScaler()
print('Optimizer ready with mixed precision')

## Training Loop
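
The loop in this section divides each micro-batch loss by `gradient_accumulation_steps` before calling `backward()`. The toy example below, a hand-differentiated one-parameter model in pure Python rather than the notebook's model, illustrates why: with equally sized micro-batches, the accumulated gradient then matches the gradient of one large batch.

```python
# Gradient accumulation on a toy model y = w * x with mean squared error.
def grad_mse(w, xs, ts):
    """d/dw of mean((w*x - t)^2) over the given batch."""
    return sum(2 * (w * x - t) * x for x, t in zip(xs, ts)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ts = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

accum_steps = 4
micro = len(xs) // accum_steps   # micro-batch size

# Gradient of one large batch.
full_grad = grad_mse(w, xs, ts)

# Same gradient, accumulated over micro-batches with the loss scaled by 1/accum_steps.
accum_grad = 0.0
for i in range(accum_steps):
    sl = slice(i * micro, (i + 1) * micro)
    accum_grad += grad_mse(w, xs[sl], ts[sl]) / accum_steps

print(full_grad, accum_grad)
```

Without the division, the accumulated gradient would be `accum_steps` times larger, which acts like an unintended increase in learning rate.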

In [None]:
@torch.no_grad()
def evaluate(model, loader, max_batches=None):
    model.eval()
    total_loss = 0
    total_tokens = 0
    for i, (x, y) in enumerate(loader):
        if max_batches and i >= max_batches:
            break
        x, y = x.to(device), y.to(device)
        with torch.cuda.amp.autocast():
            _, loss = model(x, y)
        total_loss += loss.item() * y.numel()
        total_tokens += y.numel()
    model.train()
    avg_loss = total_loss / total_tokens
    return avg_loss, math.exp(avg_loss), avg_loss / math.log(2)

In [None]:
history = {'step': [], 'train_loss': [], 'val_loss': [], 'val_ppl': [], 'val_bpb': [], 'lr': []}
best_val_loss = float('inf')
patience_counter = 0

ckpt_dir = Path(f'checkpoints_hierarchical_{MODEL_SIZE}_pg19')
ckpt_dir.mkdir(exist_ok=True)

print(f'Checkpoints: {ckpt_dir}')
print('Starting training...')

In [None]:
model.train()
train_iter = iter(train_loader)
running_loss = 0

pbar = tqdm(range(config['max_steps']), desc='Training')

for step in pbar:
    optimizer.zero_grad()
    accum_loss = 0
    
    for _ in range(config['gradient_accumulation_steps']):
        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x, y = next(train_iter)
        
        x, y = x.to(device), y.to(device)
        with torch.cuda.amp.autocast():
            _, loss = model(x, y)
            loss = loss / config['gradient_accumulation_steps']
        scaler.scale(loss).backward()
        accum_loss += loss.item()
    
    lr = get_lr(step)
    for pg in optimizer.param_groups:
        pg['lr'] = lr
    
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    
    running_loss += accum_loss
    
    if (step + 1) % config['log_interval'] == 0:
        pbar.set_postfix({'loss': f'{running_loss/config["log_interval"]:.3f}', 'lr': f'{lr:.2e}'})
        running_loss = 0
    
    if (step + 1) % config['val_interval'] == 0:
        val_loss, val_ppl, val_bpb = evaluate(model, val_loader, config['val_batches'])
        train_loss, _, _ = evaluate(model, train_loader, config['val_batches'])
        
        history['step'].append(step + 1)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_ppl'].append(val_ppl)
        history['val_bpb'].append(val_bpb)
        history['lr'].append(lr)
        
        gap = val_loss - train_loss
        print(f'\nStep {step+1}: train={train_loss:.3f}, val={val_loss:.3f}, ppl={val_ppl:.2f}, bpb={val_bpb:.3f}, gap={gap:.3f}')
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                'step': step + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'model_config': model_config.__dict__,
                'val_loss': val_loss, 'val_ppl': val_ppl, 'val_bpb': val_bpb,
                'dataset': 'pg19',
            }, ckpt_dir / 'best.pt')
            print(f'  ✓ New best!')
        else:
            patience_counter += 1
            print(f'  No improvement ({patience_counter}/{config["patience"]})')
        
        if patience_counter >= config['patience']:
            print(f'\nEarly stopping at step {step+1}')
            break

print('\nTraining complete!')
print(f'Best val loss: {best_val_loss:.4f}')

## Save & Plot

In [None]:
# Save history
with open(ckpt_dir / 'training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

with open(ckpt_dir / 'config.json', 'w') as f:
    json.dump({**config, 'model_config': model_config.__dict__, 'parameters': params, 'dataset': 'pg19'}, f, indent=2)

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 4, figsize=(18, 4))

axes[0].plot(history['step'], history['train_loss'], 'b-', label='Train')
axes[0].plot(history['step'], history['val_loss'], 'r-', label='Val')
axes[0].set_xlabel('Step'); axes[0].set_ylabel('Loss'); axes[0].legend(); axes[0].grid(True, alpha=0.3)

axes[1].plot(history['step'], history['val_ppl'], 'g-')
axes[1].set_xlabel('Step'); axes[1].set_ylabel('Perplexity'); axes[1].grid(True, alpha=0.3)

axes[2].plot(history['step'], history['val_bpb'], 'm-')
axes[2].set_xlabel('Step'); axes[2].set_ylabel('BPB'); axes[2].grid(True, alpha=0.3)

gaps = [v - t for v, t in zip(history['val_loss'], history['train_loss'])]
axes[3].fill_between(history['step'], gaps, alpha=0.5, color='orange')
axes[3].set_xlabel('Step'); axes[3].set_ylabel('Gap'); axes[3].grid(True, alpha=0.3)

plt.suptitle(f'Hierarchical BDH ({MODEL_SIZE}) - PG-19 Books', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(ckpt_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## Evaluate Best Model
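
The `evaluate` function returns three views of the same quantity: the mean cross-entropy in nats per byte, its exponential (byte-level perplexity), and the loss divided by ln 2 (bits per byte). A small sketch of the conversion, with two reference points that help when reading the numbers (the helper name is illustrative):

```python
import math


def metrics_from_loss(loss_nats: float) -> dict:
    """Convert mean cross-entropy in nats per byte to perplexity and bits per byte."""
    return {
        'ppl': math.exp(loss_nats),
        'bpb': loss_nats / math.log(2),
    }


# A model that is always torn between two equally likely bytes.
coin_flip = metrics_from_loss(math.log(2))
print(coin_flip)

# A uniform guess over all 256 byte values: the untrained baseline.
uniform = metrics_from_loss(math.log(256))
print(uniform)
```

Since the vocabulary is 256 byte values, bits per byte is bounded by 8 for a uniform predictor, and byte-level perplexity is not directly comparable to word-level or subword-level perplexity reported elsewhere.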

In [None]:
from torch.utils.data import Subset

ckpt = torch.load(ckpt_dir / 'best.pt', map_location=device)
model.load_state_dict(ckpt['model_state_dict'])

def strided_loader(data):
    # Non-overlapping windows: one pass covers the split once instead of
    # re-scoring nearly identical windows at every byte offset.
    ds = ByteDataset(data, config['block_size'], config['patch_size'])
    ds = Subset(ds, range(0, len(ds), config['block_size']))
    return DataLoader(ds, batch_size=config['batch_size'], shuffle=False, num_workers=2)

print('Full validation...')
val_loss, val_ppl, val_bpb = evaluate(model, strided_loader(val_data))
print(f'Val: loss={val_loss:.4f}, ppl={val_ppl:.2f}, bpb={val_bpb:.3f}')

print('\nTest set...')
test_loss, test_ppl, test_bpb = evaluate(model, strided_loader(test_data))
print(f'Test: loss={test_loss:.4f}, ppl={test_ppl:.2f}, bpb={test_bpb:.3f}')

## Generation Sample
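
Generation in this section passes `temperature` and `top_k` to `model.generate`. The sketch below shows what those two parameters conventionally do when picking one next byte from a vector of logits. It is a pure-Python illustration under that assumption; the repository's `generate` method may differ in detail, and `sample_top_k` is a hypothetical helper.

```python
import math
import random


def sample_top_k(logits, temperature=0.8, top_k=50, rng=random):
    """Sample an index from logits after temperature scaling and top-k filtering."""
    # Lower temperature sharpens the distribution, higher flattens it.
    scaled = [l / temperature for l in logits]
    # Keep only the k highest-scoring candidates.
    k = min(top_k, len(scaled))
    top = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    # Softmax weights over the survivors (shifted by the max for stability).
    m = max(scaled[i] for i in top)
    weights = [math.exp(scaled[i] - m) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [0.0] * 256          # one logit per byte value
for b in b'eta':              # make three bytes clearly more likely
    logits[b] = 5.0

next_byte = sample_top_k(logits, temperature=0.8, top_k=3, rng=rng)
print(chr(next_byte))
```

With `top_k=3` only the three boosted bytes can be drawn. Because sampling happens byte by byte, a sequence can end partway through a multi-byte UTF-8 character, which is why the notebook decodes with `errors='replace'`.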

In [None]:
def generate_text(model, prompt, max_new_tokens=300, temperature=0.8, top_k=50):
    model.eval()
    idx = torch.tensor([list(prompt.encode('utf-8'))], device=device, dtype=torch.long)
    with torch.no_grad():
        output = model.generate(idx, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
    return bytes(output[0].tolist()).decode('utf-8', errors='replace')

# Literary prompts appropriate for PG-19
prompts = [
    'It was a dark and stormy night',
    'Chapter 1\n\nThe old house stood',
    'She looked at him with eyes that',
]

print('=' * 60)
print('GENERATION SAMPLES (Literary Style)')
print('=' * 60)

for prompt in prompts:
    print(f'\nPrompt: "{prompt}"')
    print('-' * 40)
    print(generate_text(model, prompt, max_new_tokens=200))
    print()

## Download

In [None]:
import subprocess

zip_name = f'hierarchical_bdh_{MODEL_SIZE}_pg19.zip'
subprocess.run(['zip', '-r', zip_name, str(ckpt_dir)], check=True)

try:
    from google.colab import files
    files.download(zip_name)
    print('✅ Download started!')
except ImportError:
    print(f'Saved: {zip_name}')