# BDH Training on WikiText-2

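Train BDH on WikiText-2 (~2M word tokens; the byte-level model sees it as roughly 10M+ bytes — the load cell prints exact counts) to reduce the overfitting observed on Shakespeare (~1M characters).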

**Changes from Shakespeare run:**
- Substantially more training data (roughly 10x more bytes than Shakespeare at the byte level used here)
- Increased dropout (0.1 → 0.2)
- Added weight decay
- More frequent validation
- Early stopping (halt after `patience` validation rounds without a new best validation loss)

In [None]:
# Check GPU
# IPython shell escape (`!`): list the NVIDIA GPU(s) visible to this session.
!nvidia-smi

In [None]:
# Clone repo (or upload files)
# Fetch the BDH sources (presumably providing the `bdh` module imported below),
# then make the clone the working directory via the IPython `%cd` magic.
!git clone https://github.com/newsbubbles/bdh.git
%cd bdh

In [None]:
# Install dependencies
# NOTE(review): versions are unpinned, so library behaviour may differ between sessions.
!pip install torch datasets tqdm matplotlib

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import math
import json
from datetime import datetime
from tqdm.auto import tqdm

# Import BDH
from bdh import BDH, BDHConfig

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## Load WikiText-2

In [None]:
from datasets import load_dataset

print('Loading WikiText-2...')
# 'wikitext-2-raw-v1' configuration of the HF 'wikitext' dataset; the cells
# below read its 'train', 'validation' and 'test' splits via the 'text' column.
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')

def text_to_bytes(text):
    """Encode ``text`` as UTF-8 and return its bytes as a 1-D ``torch.long`` tensor.

    Each byte value (0-255) becomes one token id for the byte-level model.

    The bytes are wrapped with ``torch.frombuffer`` instead of being expanded
    into a Python ``list`` of ints first. For a multi-megabyte corpus that
    avoids a large temporary list and a slow element-by-element copy.
    """
    # bytearray gives frombuffer a writable buffer; passing immutable bytes
    # makes it warn.
    raw = bytearray(text.encode('utf-8'))
    if not raw:
        # torch.frombuffer rejects zero-length buffers.
        return torch.empty(0, dtype=torch.long)
    # .long() copies into a new int64 tensor that does not alias `raw`.
    return torch.frombuffer(raw, dtype=torch.uint8).long()

# Flatten every split into one long string, joining the dataset rows with '\n'.
# NOTE(review): if the HF rows already end in '\n', this doubles the line
# breaks — confirm whether ''.join was intended.
train_text, val_text, test_text = (
    '\n'.join(dataset[split]['text']) for split in ('train', 'validation', 'test')
)

# Byte-level tokenization: each UTF-8 byte is one token id.
train_data, val_data, test_data = map(text_to_bytes, (train_text, val_text, test_text))

print('\n'.join(
    f'{label}: {len(tensor):,} bytes'
    for label, tensor in (('Train', train_data), ('Val', val_data), ('Test', test_data))
))

## Dataset Class

In [None]:
class ByteDataset(Dataset):
    """Sliding windows over a 1-D token tensor for next-token prediction.

    Item ``i`` is a pair ``(x, y)``. ``x`` holds ``block_size`` consecutive
    tokens starting at offset ``i * stride``. ``y`` is the same span shifted
    right by one token, i.e. the target for every position of ``x``.

    Args:
        data: 1-D tensor of token ids (here: UTF-8 byte values).
        block_size: context length of each window.
        stride: distance between the start offsets of consecutive windows.
            The default ``1`` keeps the original behaviour: one window per
            offset, overlapping by ``block_size - 1`` tokens.
            ``stride=block_size`` gives non-overlapping windows, which score
            each target token at most once (useful for evaluation).

    Raises:
        ValueError: if ``stride`` is smaller than 1.
    """

    def __init__(self, data, block_size, stride=1):
        if stride < 1:
            raise ValueError(f'stride must be >= 1, got {stride}')
        self.data = data
        self.block_size = block_size
        self.stride = stride

    def _starts(self):
        """Return the valid window start offsets as a ``range``."""
        # A window starting at s reads tokens s .. s + block_size (inclusive),
        # so the last valid start is len(data) - block_size - 1. The range is
        # simply empty when data is too short. The old arithmetic returned a
        # negative length there, which len() rejects.
        return range(0, len(self.data) - self.block_size, self.stride)

    def __len__(self):
        return len(self._starts())

    def __getitem__(self, idx):
        # Indexing the range raises IndexError for an out-of-bounds idx
        # instead of silently returning a truncated window.
        start = self._starts()[idx]
        chunk = self.data[start:start + self.block_size + 1]
        return chunk[:-1], chunk[1:]

## Config - Improved Regularization

In [None]:
# Training config: one flat dict; insertion order is also the print order below.
config = dict(
    # Model
    n_layer=6,
    n_head=8,
    n_embd=256,
    block_size=512,
    vocab_size=256,    # byte-level: one id per possible byte value
    dropout=0.2,       # raised from 0.1 to fight overfitting
    bias=False,
    # Training
    batch_size=32,
    learning_rate=3e-4,
    weight_decay=0.1,  # passed to AdamW below
    max_steps=5000,
    warmup_steps=200,
    # Validation
    val_interval=100,  # evaluate every 100 optimizer steps
    val_batches=50,    # batches scored per periodic evaluation
    # Early stopping
    patience=10,       # stop after 10 evaluations without a new best val loss
    # Logging
    log_interval=50,
)

print('Config:')
print('\n'.join(f'  {key}: {value}' for key, value in config.items()))

## Create Model

In [None]:
# Only the architecture-related keys of `config` are forwarded to BDHConfig.
model_config = BDHConfig(**{
    key: config[key]
    for key in ('n_layer', 'n_head', 'n_embd', 'block_size', 'vocab_size', 'dropout', 'bias')
})

model = BDH(model_config)
model = model.to(device)

# Total element count over every parameter tensor.
n_params = sum(param.numel() for param in model.parameters())
print(f'Model parameters: {n_params:,} ({n_params/1e6:.1f}M)')

## Setup Training

In [None]:
# Dataloaders
from torch.utils.data import Subset

train_dataset = ByteDataset(train_data, config['block_size'])
val_dataset = ByteDataset(val_data, config['block_size'])

# ByteDataset yields one window per byte offset, so neighbouring windows share
# block_size - 1 bytes. With shuffle=False, the periodic evaluation's first
# `val_batches` batches would only ever cover the first
# val_batches * batch_size + block_size bytes of the validation text
# (~2 KB with the config above). Early stopping would then be driven by that
# tiny slice. Keeping every block_size-th window makes the windows
# non-overlapping, so the same batch budget spans far more of the split.
val_eval_dataset = Subset(val_dataset, range(0, len(val_dataset), config['block_size']))

# Pinned host memory only speeds up host-to-GPU copies; without CUDA it just
# triggers a warning.
use_pinned_memory = device == 'cuda'

train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=use_pinned_memory,
)

val_loader = DataLoader(
    val_eval_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=use_pinned_memory,
)

print(f'Train batches: {len(train_loader):,}')
print(f'Val batches: {len(val_loader):,}')

In [None]:
# Optimizer with weight decay.
# Common GPT-style practice: apply decay only to parameters with >= 2 dims
# (weight matrices, embeddings). 1-D parameters such as biases and norm gains
# are exempted, since decaying them is usually not wanted as regularization.
decay_params = [p for p in model.parameters() if p.dim() >= 2]
no_decay_params = [p for p in model.parameters() if p.dim() < 2]
param_groups = [
    group
    for group in (
        {'params': decay_params, 'weight_decay': config['weight_decay']},
        {'params': no_decay_params, 'weight_decay': 0.0},
    )
    if group['params']  # drop a group that would be empty for this model
]
optimizer = torch.optim.AdamW(
    param_groups,
    lr=config['learning_rate'],
    betas=(0.9, 0.95),
)

# Learning rate scheduler with warmup + cosine decay
def get_lr(step, cfg=None):
    """Return the learning rate for optimizer step ``step`` (0-based).

    The rate warms up linearly over ``cfg['warmup_steps']`` steps to
    ``cfg['learning_rate']``. It then follows a cosine decay that reaches 0
    at ``cfg['max_steps']`` and stays there.

    Args:
        step: 0-based index of the optimizer step about to be taken.
        cfg: mapping with ``learning_rate``, ``warmup_steps`` and
            ``max_steps``; defaults to the notebook-level ``config``.
    """
    if cfg is None:
        cfg = config
    peak = cfg['learning_rate']
    warmup = cfg['warmup_steps']
    # Warmup: step + 1 so the very first update is not taken with lr = 0.
    if step < warmup:
        return peak * (step + 1) / warmup
    # Cosine decay. max(1, ...) guards a zero-length decay phase. Clamping
    # progress keeps steps past max_steps at 0 instead of climbing back up.
    decay_steps = max(1, cfg['max_steps'] - warmup)
    progress = min(1.0, (step - warmup) / decay_steps)
    return peak * 0.5 * (1 + math.cos(math.pi * progress))

print(f"Optimizer: AdamW with weight_decay={config['weight_decay']}")

## Training Loop with Early Stopping

In [None]:
@torch.no_grad()
def evaluate(model, loader, max_batches=None):
    """Return ``(mean_loss, perplexity)`` of ``model`` over ``loader``.

    Each batch's loss is weighted by ``y.numel()``, so the result is a
    per-target-token average. Batches are moved to the device of the model's
    parameters (CPU if it has none).

    Args:
        model: module whose ``forward(x, y)`` returns a pair whose second
            element is the loss. That loss is assumed to be the mean over the
            batch's targets; the token weighting relies on it.
        loader: iterable of ``(x, y)`` tensor pairs.
        max_batches: score at most this many batches. ``None`` (or 0, as
            before) means the whole loader.

    Returns:
        ``(avg_loss, exp(avg_loss))``. Perplexity is ``inf`` if the
        exponential overflows.

    Raises:
        ValueError: if the loader yields no target tokens.
    """
    was_training = model.training
    model.eval()
    # Derive the device from the model instead of a notebook-global `device`.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0.0
    total_tokens = 0
    try:
        for i, (x, y) in enumerate(loader):
            if max_batches and i >= max_batches:
                break
            x, y = x.to(dev), y.to(dev)
            _, loss = model(x, y)
            total_loss += loss.item() * y.numel()
            total_tokens += y.numel()
    finally:
        # Restore the caller's mode. The old code always forced train(),
        # even when the model was in eval mode on entry.
        model.train(was_training)

    if total_tokens == 0:
        raise ValueError('evaluate() received no target tokens to score')
    avg_loss = total_loss / total_tokens
    try:
        ppl = math.exp(avg_loss)
    except OverflowError:
        ppl = float('inf')
    return avg_loss, ppl

In [None]:
# Training state: per-evaluation metric logs, one list per metric, all appended
# together at each validation round.
history = {metric: [] for metric in ('step', 'train_loss', 'val_loss', 'val_ppl', 'lr')}

# Early-stopping bookkeeping and the step counter.
best_val_loss, patience_counter, step = math.inf, 0, 0

# Output directory for best.pt, the JSON logs and the training-curve plot.
ckpt_dir = Path('checkpoints_wikitext2')
ckpt_dir.mkdir(exist_ok=True)

print('Starting training...', '=' * 60, sep='\n')

In [None]:
# Main training loop: one optimizer step per iteration, periodic validation,
# best-checkpoint saving and early stopping. Continues from the state set up in
# the previous cell (history, best_val_loss, patience_counter, ckpt_dir).
model.train()
# A manual iterator decouples max_steps from the epoch length.
train_iter = iter(train_loader)
running_loss = 0  # sum of per-step training losses since the last log update

pbar = tqdm(range(config['max_steps']), desc='Training')

for step in pbar:
    # Get batch (cycle through data)
    # When an epoch is exhausted, restart the loader (a new shuffle order).
    try:
        x, y = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        x, y = next(train_iter)
    
    x, y = x.to(device), y.to(device)
    
    # Update learning rate
    # The schedule is applied by hand to every param group before each step.
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    # Forward + backward
    # model(x, y) returns a pair; only the second element (the loss) is used.
    _, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 1.0 before the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    # NOTE(review): .item() forces a host/device sync every step; accumulating
    # a detached tensor instead would avoid that if throughput matters.
    running_loss += loss.item()
    
    # Logging
    # Show the mean training loss over the last log_interval steps on the bar.
    if (step + 1) % config['log_interval'] == 0:
        avg_loss = running_loss / config['log_interval']
        pbar.set_postfix({'loss': f'{avg_loss:.3f}', 'lr': f'{lr:.2e}'})
        running_loss = 0
    
    # Validation
    if (step + 1) % config['val_interval'] == 0:
        # Both estimates score at most val_batches batches. evaluate() switches
        # the model to eval mode while scoring, so the train estimate is also
        # dropout-free and the gap below compares like with like.
        # NOTE(review): val_loader is unshuffled, so this always scores the
        # same leading val_batches batches of the validation set.
        val_loss, val_ppl = evaluate(model, val_loader, config['val_batches'])
        train_loss, _ = evaluate(model, train_loader, config['val_batches'])
        
        history['step'].append(step + 1)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_ppl'].append(val_ppl)
        history['lr'].append(lr)
        
        # Positive gap = validation loss above training loss (overfitting signal).
        gap = val_loss - train_loss
        print(f'\nStep {step+1}: train={train_loss:.3f}, val={val_loss:.3f}, ppl={val_ppl:.2f}, gap={gap:.3f}')
        
        # Best model?
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            
            # Save best
            # Overwrites best.pt each time validation loss improves.
            # NOTE(review): if no validation round ever runs
            # (max_steps < val_interval), best.pt is never written and the
            # evaluation cell's torch.load will fail.
            torch.save({
                'step': step + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'model_config': model_config.__dict__,
                'val_loss': val_loss,
                'val_ppl': val_ppl,
            }, ckpt_dir / 'best.pt')
            print(f'  ✓ New best! Saved to {ckpt_dir}/best.pt')
        else:
            patience_counter += 1
            print(f'  No improvement ({patience_counter}/{config["patience"]})')
        
        # Early stopping
        # patience counts validation rounds, not optimizer steps.
        if patience_counter >= config['patience']:
            print(f'\n⚠️ Early stopping at step {step+1}')
            break

print('\n' + '=' * 60)
print('Training complete!')
print(f'Best val loss: {best_val_loss:.4f}')

## Save Final Results

In [None]:
# Persist the metric history and the hyper-parameters next to the checkpoint.
history_path = ckpt_dir / 'training_history.json'
config_path = ckpt_dir / 'config.json'

# json.dumps + write_text produces exactly the text json.dump would stream.
history_path.write_text(json.dumps(history, indent=2))
print(f'Saved history to {history_path}')

config_path.write_text(json.dumps(config, indent=2))
print(f'Saved config to {config_path}')

## Plot Training Curves

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
loss_ax, ppl_ax, gap_ax = axes
steps = history['step']

# Left panel: training and validation loss on shared axes.
loss_ax.plot(steps, history['train_loss'], 'b-', label='Train', linewidth=2)
loss_ax.plot(steps, history['val_loss'], 'r-', label='Val', linewidth=2)

# Middle panel: validation perplexity (exp of the validation loss).
ppl_ax.plot(steps, history['val_ppl'], 'g-', linewidth=2)

# Right panel: generalization gap (val minus train loss) per evaluation.
# NOTE(review): the 0.5 threshold line is a visual guide only; nothing in this
# notebook derives or enforces that value.
gaps = [val - train for train, val in zip(history['train_loss'], history['val_loss'])]
gap_ax.fill_between(steps, gaps, alpha=0.5, color='orange')
gap_ax.plot(steps, gaps, 'orange', linewidth=2)
gap_ax.axhline(y=0.5, color='red', linestyle='--', label='Overfitting threshold')

# Shared decoration; legends only on panels with labelled artists.
for panel, ylabel, title, with_legend in (
    (loss_ax, 'Loss', 'Loss Curves', True),
    (ppl_ax, 'Perplexity', 'Validation Perplexity', False),
    (gap_ax, 'Val - Train Loss', 'Overfitting Gap', True),
):
    panel.set_xlabel('Step')
    panel.set_ylabel(ylabel)
    panel.set_title(title)
    if with_legend:
        panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(ckpt_dir / 'training_curves.png', dpi=150)
plt.show()
print(f'Saved plot to {ckpt_dir}/training_curves.png')

## Evaluate Best Model

In [None]:
# Load best checkpoint
from torch.utils.data import Subset

best_path = ckpt_dir / 'best.pt'
if not best_path.exists():
    # best.pt is only written inside a validation round of the training loop.
    raise FileNotFoundError(
        f'{best_path} not found: no validation round completed, so no best checkpoint was saved'
    )
# map_location restores GPU-saved tensors onto whatever `device` this session has.
ckpt = torch.load(best_path, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])


def _non_overlapping(dataset):
    """Keep every block_size-th window of a stride-1 ByteDataset.

    ByteDataset yields one window per byte offset. A full pass over it would
    score each byte up to block_size times and need ~block_size times as many
    batches. Stepping the start index by block_size scores each target byte
    at most once.
    """
    return Subset(dataset, range(0, len(dataset), config['block_size']))


# Full validation
print('Evaluating best model on full validation set...')
full_val_loader = DataLoader(
    _non_overlapping(val_dataset), batch_size=config['batch_size'], shuffle=False
)
val_loss, val_ppl = evaluate(model, full_val_loader)
print(f'Val Loss: {val_loss:.4f}')
print(f'Val Perplexity: {val_ppl:.2f}')

# Test set
test_dataset = ByteDataset(test_data, config['block_size'])
test_loader = DataLoader(
    _non_overlapping(test_dataset), batch_size=config['batch_size'], shuffle=False
)

print('\nEvaluating on test set...')
test_loss, test_ppl = evaluate(model, test_loader)
print(f'Test Loss: {test_loss:.4f}')
print(f'Test Perplexity: {test_ppl:.2f}')

## Download Checkpoint

Run this cell (Colab only — it uses `google.colab.files`) to zip the checkpoint directory and download it:

In [None]:
from google.colab import files

# Zip checkpoints
# Bundle the whole checkpoint directory (best.pt, JSON logs, plot) into one
# archive via an IPython shell escape.
!zip -r checkpoints_wikitext2.zip checkpoints_wikitext2/

# Download
# Requires Colab (`files` comes from google.colab above): downloads the archive.
files.download('checkpoints_wikitext2.zip')