# BDH Benchmark Evaluation

Evaluate trained BDH/Hierarchical BDH models on standard benchmarks.

## Benchmarks Included

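| Benchmark | What it Tests | Metric |
|-----------|---------------|--------|
| **LAMBADA** | Long-range word prediction | Accuracy, byte-level PPL on the final word |
| **WikiText-2** | General LM | PPL, BPB |
| **WikiText-103** | Large-scale LM | PPL, BPB |
| **PG-19** | Long-form books | PPL, BPB |
| **1BW** | Billion Word Benchmark | PPL (not yet implemented in this notebook) |

All perplexities reported here are per byte, because the models are byte-level. They are not directly comparable with token- or word-level perplexities.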

## Usage
1. Upload your trained model checkpoint (.zip or .pt)
2. Select which benchmarks to run
3. Get standardized metrics for comparison

In [None]:
# Show which GPU (if any) this runtime exposes before loading the model.
!nvidia-smi

In [None]:
# Clone the BDH repository, or, if the clone fails (e.g. `bdh/` already exists),
# pull the latest changes into the existing checkout. Then work from inside it;
# presumably this is what makes `bdh` / `bdh_hierarchical` importable by the
# model-loading cell.
!git clone https://github.com/newsbubbles/bdh.git 2>/dev/null || (cd bdh && git pull)
%cd bdh

In [None]:
# Install the notebook's dependencies quietly (-q). `datasets` supplies the
# benchmark corpora and `tqdm` the progress bars; matplotlib is not used by the
# evaluation cells below.
!pip install -q torch datasets tqdm matplotlib

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import math
import json
import zipfile
import os
from tqdm.auto import tqdm

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

## Upload Model Checkpoint

Upload your trained model. Supports:
- `.zip` file from training notebooks
- `.pt` checkpoint file directly

In [None]:
# Upload checkpoint.
# In Colab, prompt for a file; elsewhere, use the hard-coded local path below.
# The Colab module is imported under a private name so later cells that bind
# `files` (or re-import it) are unaffected.
try:
    from google.colab import files as colab_files
except ImportError:
    colab_files = None  # not running in Colab

if colab_files is not None:
    print('Upload your model checkpoint (.zip or .pt):')
    uploaded = colab_files.upload()
    if not uploaded:
        # The upload dialog was cancelled or returned nothing.
        raise RuntimeError(
            'No file was uploaded; re-run this cell and choose a .zip or .pt checkpoint.'
        )
    uploaded_file = next(iter(uploaded))
    if len(uploaded) > 1:
        print(f'Multiple files uploaded; using the first one.')
    print(f'Uploaded: {uploaded_file}')
else:
    # Not in Colab - specify path manually
    uploaded_file = 'checkpoints_hierarchical_small/best.pt'  # Edit this
    print(f'Using local file: {uploaded_file}')

# Fail here with a clear message rather than later inside torch.load / zipfile.
if not os.path.exists(uploaded_file):
    raise FileNotFoundError(f'Checkpoint not found: {uploaded_file}')

In [None]:
# Extract if zip, then load the checkpoint dict.
if uploaded_file.lower().endswith('.zip'):
    print('Extracting zip...')
    with zipfile.ZipFile(uploaded_file, 'r') as z:
        members = z.namelist()
        z.extractall('.')
    # Consider only files that came out of this archive (the working directory
    # also holds the cloned repo). Prefer best.pt; otherwise fall back to any .pt.
    pt_members = sorted(m for m in members if m.endswith('.pt'))
    candidates = [m for m in pt_members if Path(m).name == 'best.pt'] or pt_members
    if not candidates:
        raise FileNotFoundError(f'No .pt checkpoint found inside {uploaded_file}')
    if len(candidates) > 1:
        print(f'Multiple checkpoints in archive, using the first of: {candidates}')
    checkpoint_path = str(Path(candidates[0]))
    print(f'Found checkpoint: {checkpoint_path}')
else:
    checkpoint_path = uploaded_file

# Load checkpoint
print('Loading checkpoint...')
# NOTE(review): newer torch versions default torch.load to weights_only=True.
# That works only if the checkpoint holds tensors and plain Python
# containers/scalars; confirm against the training notebook if loading fails.
ckpt = torch.load(checkpoint_path, map_location=device)
print(f'Loaded from step {ckpt.get("step", "unknown")}')
print(f'Original val_loss: {ckpt.get("val_loss", "unknown")}')
print(f'Original val_ppl: {ckpt.get("val_ppl", "unknown")}')
print(f'Dataset: {ckpt.get("dataset", "unknown")}')
print(f'\nModel config: {ckpt["model_config"]}')

## Load Model

In [None]:
# Rebuild the architecture described by the checkpoint's saved config, then
# restore its weights.
model_config = ckpt['model_config']

# Only the hierarchical variant's config carries a `global_n_layer` entry.
is_hierarchical = 'global_n_layer' in model_config

if not is_hierarchical:
    print('Detected: Standard BDH')
    from bdh import BDH, BDHConfig
    config = BDHConfig(**model_config)
    model = BDH(config).to(device)
else:
    print('Detected: Hierarchical BDH')
    from bdh_hierarchical import HierarchicalBDH, HierarchicalBDHConfig
    config = HierarchicalBDHConfig(**model_config)
    model = HierarchicalBDH(config).to(device)

model.load_state_dict(ckpt['model_state_dict'])
model.eval()

# Total number of scalar weights.
total_params = sum(param.numel() for param in model.parameters())
print(f'Parameters: {total_params:,} ({total_params/1e6:.1f}M)')

# A config without `patch_size` is treated as one byte per step.
patch_size = getattr(config, 'patch_size', 1)
print(f'Patch size: {patch_size}')

## Evaluation Utilities

In [None]:
class ByteDataset(Dataset):
    """Fixed-length ``(x, y)`` windows over a 1-D byte tensor, for LM evaluation.

    ``y`` is ``x`` shifted left by one position, so ``y[i]`` is the byte that
    follows ``x[i]``. Both have length ``self.block_size``.

    Args:
        data: 1-D tensor of byte ids (e.g. from ``text_to_bytes``).
        block_size: Requested window length. It is rounded down to a multiple
            of ``patch_size`` and shrunk if ``data`` is too short for it.
        patch_size: Window lengths are kept divisible by this value.
        stride: Distance between consecutive window starts. The default
            (``None``) uses ``block_size``, giving non-overlapping windows in
            which every predicted byte is scored exactly once. Up to
            ``block_size - 1`` trailing bytes that do not fill a whole window
            are not scored. ``stride=1`` gives the exhaustive sliding window,
            which scores each byte up to ``block_size`` times and costs
            ``block_size`` times as much compute.

    Raises:
        ValueError: if ``patch_size`` or ``stride`` is not positive, if
            ``block_size < patch_size``, or if ``data`` has fewer than
            ``patch_size + 1`` elements.
    """

    def __init__(self, data, block_size, patch_size=8, stride=None):
        if patch_size < 1:
            raise ValueError(f'patch_size must be >= 1, got {patch_size}')
        # Ensure divisibility
        block_size = (block_size // patch_size) * patch_size
        if block_size <= 0:
            raise ValueError(f'block_size must be >= patch_size ({patch_size})')
        # Largest patch-aligned window for which a full (x, y) pair fits in data.
        max_block = ((len(data) - 1) // patch_size) * patch_size
        if max_block <= 0:
            raise ValueError(
                f'need at least {patch_size + 1} elements of data, got {len(data)}'
            )
        block_size = min(block_size, max_block)
        stride = block_size if stride is None else stride
        if stride < 1:
            raise ValueError(f'stride must be >= 1, got {stride}')
        self.data = data
        self.block_size = block_size
        self.stride = stride

    def __len__(self):
        # Number of starts s (multiples of stride) with s + block_size + 1 <= len(data).
        return (len(self.data) - 1 - self.block_size) // self.stride + 1

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        start = idx * self.stride
        chunk = self.data[start:start + self.block_size + 1]
        return chunk[:-1], chunk[1:]

def text_to_bytes(text):
    """Encode *text* as UTF-8 and return its bytes as a 1-D ``torch.long`` tensor.

    The bytes go through a ``uint8`` buffer rather than ``list(...)``. This
    avoids building one Python int per byte, which is slow and memory-hungry
    for multi-megabyte corpora such as PG-19.
    """
    raw = text.encode('utf-8')
    if not raw:
        # torch.frombuffer rejects zero-length buffers.
        return torch.zeros(0, dtype=torch.long)
    # bytearray gives frombuffer a writable buffer; .long() then makes an owned copy.
    return torch.frombuffer(bytearray(raw), dtype=torch.uint8).long()

@torch.no_grad()
def evaluate_perplexity(model, loader, desc='Evaluating'):
    """Compute the target-weighted mean loss, perplexity and bits-per-byte over a dataset.

    Args:
        model: Called as ``model(x, y)`` and returning ``(logits, loss)``.
            ``loss`` is assumed to be the mean cross-entropy in nats over the
            batch's targets; TODO confirm against the model's forward().
        loader: Iterable of ``(x, y)`` tensor batches.
        desc: Progress-bar label.

    Returns:
        dict with ``loss`` (mean nats per target), ``perplexity`` (``exp(loss)``),
        ``bpb`` (``loss / ln 2``) and ``tokens`` (number of targets scored).

    Raises:
        ValueError: if *loader* yields no targets.
    """
    import contextlib

    model.eval()
    # Evaluate on whatever device holds the model's weights; fall back to the
    # CPU for a parameter-less module.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0.0
    total_tokens = 0

    for x, y in tqdm(loader, desc=desc):
        x, y = x.to(dev), y.to(dev)
        # Mixed precision on CUDA only; CPU runs in full precision.
        amp = (
            torch.autocast(device_type='cuda')
            if dev.type == 'cuda'
            else contextlib.nullcontext()
        )
        with amp:
            _, loss = model(x, y)
        n_targets = y.numel()
        # Weight each batch's mean by its size so a short final batch is not over-counted.
        total_loss += loss.item() * n_targets
        total_tokens += n_targets

    if total_tokens == 0:
        raise ValueError(f'{desc}: loader yielded no target tokens')

    avg_loss = total_loss / total_tokens
    return {
        'loss': avg_loss,
        'perplexity': math.exp(avg_loss),
        'bpb': avg_loss / math.log(2),
        'tokens': total_tokens,
    }

## Benchmark Selection

Choose which benchmarks to run:

In [None]:
# =============================================================
# SELECT BENCHMARKS
# =============================================================
RUN_LAMBADA = True
RUN_WIKITEXT2 = True
RUN_WIKITEXT103 = True
RUN_PG19 = False  # Large, takes time

# Evaluation settings
BLOCK_SIZE = 512  # Context window in bytes; rounded down to a multiple of patch_size
BATCH_SIZE = 16
# =============================================================

# The evaluation cells keep input lengths a multiple of patch_size (presumably
# required by the patch-based model). A window shorter than one patch would
# otherwise round down to 0.
if BLOCK_SIZE < patch_size:
    raise ValueError(
        f'BLOCK_SIZE ({BLOCK_SIZE}) must be at least the model patch_size ({patch_size})'
    )
BLOCK_SIZE = (BLOCK_SIZE // patch_size) * patch_size
print(f'Block size: {BLOCK_SIZE}')

## LAMBADA Benchmark

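Tests long-range dependency by predicting the final word of each passage from its preceding context. It is a standard LM benchmark for context understanding.
Here it is scored at the byte level. An example counts as correct when the model's greedy (argmax) prediction matches every byte of the final word. Perplexity is the per-byte perplexity over the final word's bytes only.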

In [None]:
lambada_results = None

if RUN_LAMBADA:
    from contextlib import nullcontext
    from datasets import load_dataset

    print('=' * 60)
    print('LAMBADA BENCHMARK')
    print('=' * 60)

    print('Loading LAMBADA...')
    lambada = load_dataset('lambada', split='test')
    print(f'Test examples: {len(lambada)}')

    # Byte-level LAMBADA, teacher-forced on "context + ' ' + target word":
    # - accuracy: the argmax prediction matches every byte of the target word;
    # - loss / perplexity: per-BYTE cross-entropy over the target-word bytes
    #   only, so it is not comparable with word-level LAMBADA perplexity.

    correct = 0
    total = 0
    skipped = 0
    total_loss = 0.0
    total_tokens = 0

    model.eval()

    for example in tqdm(lambada, desc='LAMBADA'):
        # The target is the last space-separated word of the passage.
        context, sep, target_word = example['text'].rpartition(' ')
        if not sep or not target_word:
            # No space, or a trailing space (empty target). An empty target
            # would otherwise compare equal and be scored as "correct".
            skipped += 1
            continue

        context_bytes = list((context + ' ').encode('utf-8'))  # keep the space before the target
        target_bytes = list(target_word.encode('utf-8'))
        full_bytes = context_bytes + target_bytes

        # Model input length (len(full_bytes) - 1) is capped at BLOCK_SIZE and
        # rounded down to a multiple of patch_size, matching the divisibility the
        # other benchmarks use. Excess bytes are dropped from the left context.
        seq_len = min(len(full_bytes) - 1, BLOCK_SIZE)
        seq_len -= seq_len % patch_size
        if seq_len < len(target_bytes):
            # Not enough room for the whole target word plus one context byte.
            skipped += 1
            continue
        full_bytes = full_bytes[-(seq_len + 1):]

        x = torch.tensor([full_bytes[:-1]], device=device, dtype=torch.long)
        y = torch.tensor([full_bytes[1:]], device=device, dtype=torch.long)

        amp = torch.autocast(device_type='cuda') if device == 'cuda' else nullcontext()
        with torch.no_grad(), amp:
            logits, _ = model(x, y)

        # y[i] is full_bytes[i + 1], so the last len(target_bytes) labels are
        # exactly the target word. Upcast so the loss is computed in fp32.
        target_start = seq_len - len(target_bytes)
        target_logits = logits[0, target_start:].float()
        target_labels = y[0, target_start:]

        # Check if predictions match
        if torch.equal(target_logits.argmax(dim=-1), target_labels):
            correct += 1

        # Accumulate loss
        total_loss += nn.functional.cross_entropy(
            target_logits, target_labels, reduction='sum'
        ).item()
        total_tokens += target_labels.numel()
        total += 1

    if total == 0:
        print('\nLAMBADA: no examples could be scored')
    else:
        accuracy = correct / total * 100
        avg_loss = total_loss / total_tokens
        ppl = math.exp(avg_loss)

        lambada_results = {
            'accuracy': accuracy,
            'perplexity': ppl,
            'loss': avg_loss,
            'correct': correct,
            'total': total,
            'skipped': skipped,
        }

        print(f'\nLAMBADA Results:')
        print(f'  Accuracy: {accuracy:.2f}%')
        print(f'  Perplexity: {ppl:.2f}')
        print(f'  ({correct}/{total} correct, {skipped} skipped)')
else:
    print('Skipping LAMBADA')

## WikiText-2 Benchmark

In [None]:
wikitext2_results = None

if RUN_WIKITEXT2:
    from datasets import load_dataset

    print('=' * 60)
    print('WIKITEXT-2 BENCHMARK')
    print('=' * 60)

    print('Loading WikiText-2...')
    # Only the test split is evaluated, so request just that split.
    dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    # NOTE(review): if the HF rows already end in '\n', this join adds extra
    # (easy-to-predict) newline bytes and slightly lowers BPB. It is kept
    # as-is so results stay comparable with earlier runs; confirm before
    # comparing against published byte-level numbers.
    test_text = '\n'.join(dataset['text'])
    test_data = text_to_bytes(test_text)
    print(f'Test size: {len(test_data):,} bytes')

    test_dataset = ByteDataset(test_data, BLOCK_SIZE, patch_size)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    wikitext2_results = evaluate_perplexity(model, test_loader, 'WikiText-2')

    print(f'\nWikiText-2 Results:')
    print(f'  Loss: {wikitext2_results["loss"]:.4f}')
    print(f'  Perplexity: {wikitext2_results["perplexity"]:.2f}')
    print(f'  BPB: {wikitext2_results["bpb"]:.3f}')

    # Release the corpus (HF dataset, text and tensors) before the next benchmark.
    del dataset, test_text, test_data, test_dataset, test_loader
else:
    print('Skipping WikiText-2')

## WikiText-103 Benchmark

In [None]:
wikitext103_results = None

if RUN_WIKITEXT103:
    from datasets import load_dataset

    print('=' * 60)
    print('WIKITEXT-103 BENCHMARK')
    print('=' * 60)

    print('Loading WikiText-103 test set...')
    # Only the test split is evaluated, so request just that split rather than
    # the whole DatasetDict (the WikiText-103 train split is very large).
    dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='test')

    # NOTE(review): if the HF rows already end in '\n', this join adds extra
    # (easy-to-predict) newline bytes and slightly lowers BPB. It is kept
    # as-is so results stay comparable with earlier runs; confirm before
    # comparing against published byte-level numbers.
    test_text = '\n'.join(dataset['text'])
    test_data = text_to_bytes(test_text)
    print(f'Test size: {len(test_data):,} bytes')

    test_dataset = ByteDataset(test_data, BLOCK_SIZE, patch_size)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    wikitext103_results = evaluate_perplexity(model, test_loader, 'WikiText-103')

    print(f'\nWikiText-103 Results:')
    print(f'  Loss: {wikitext103_results["loss"]:.4f}')
    print(f'  Perplexity: {wikitext103_results["perplexity"]:.2f}')
    print(f'  BPB: {wikitext103_results["bpb"]:.3f}')

    # Release the corpus (HF dataset, text and tensors) before the next benchmark.
    del dataset, test_text, test_data, test_dataset, test_loader
else:
    print('Skipping WikiText-103')

## PG-19 Benchmark

In [None]:
pg19_results = None

if RUN_PG19:
    from datasets import load_dataset

    print('=' * 60)
    print('PG-19 BENCHMARK')
    print('=' * 60)

    # PG-19's test split contains 100 books, so 100 evaluates the whole test
    # set. Lower this for a quicker, partial run.
    PG19_NUM_BOOKS = 100

    print('Loading PG-19 test set...')
    dataset = load_dataset('pg19', split=f'test[:{PG19_NUM_BOOKS}]')

    # Separate books with a blank line.
    test_text = '\n\n'.join(dataset['text'])
    test_data = text_to_bytes(test_text)
    print(f'Test size: {len(test_data):,} bytes')

    test_dataset = ByteDataset(test_data, BLOCK_SIZE, patch_size)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    pg19_results = evaluate_perplexity(model, test_loader, 'PG-19')

    print(f'\nPG-19 Results:')
    print(f'  Loss: {pg19_results["loss"]:.4f}')
    print(f'  Perplexity: {pg19_results["perplexity"]:.2f}')
    print(f'  BPB: {pg19_results["bpb"]:.3f}')

    # Release the (large) corpus: HF dataset, joined text and tensors.
    del dataset, test_text, test_data, test_dataset, test_loader
else:
    print('Skipping PG-19')

## Results Summary

In [None]:
print('=' * 70)
print('BENCHMARK RESULTS SUMMARY')
print('=' * 70)
print()
model_name = 'Hierarchical BDH' if is_hierarchical else 'BDH'
print(f'Model: {model_name}')
print(f'Parameters: {total_params:,} ({total_params/1e6:.1f}M)')
print(f'Trained on: {ckpt.get("dataset", "unknown")}')
print()
print(f'{"Benchmark":<20} {"Loss":>10} {"PPL":>10} {"BPB":>10} {"Acc":>10}')
print('-' * 70)

# Collected per-benchmark result dicts, keyed by benchmark id, for the JSON export.
results_summary = {}

# LAMBADA reports accuracy but no BPB; the perplexity benchmarks are the reverse.
if lambada_results:
    row = lambada_results
    print(f'{"LAMBADA":<20} {row["loss"]:>10.4f} {row["perplexity"]:>10.2f} {"N/A":>10} {row["accuracy"]:>9.2f}%')
    results_summary['lambada'] = row

for key, label, row in (
    ('wikitext2', 'WikiText-2', wikitext2_results),
    ('wikitext103', 'WikiText-103', wikitext103_results),
    ('pg19', 'PG-19', pg19_results),
):
    if not row:
        continue  # benchmark skipped
    print(f'{label:<20} {row["loss"]:>10.4f} {row["perplexity"]:>10.2f} {row["bpb"]:>10.3f} {"N/A":>10}')
    results_summary[key] = row

print('-' * 70)

## Save Results

In [None]:
# Save benchmark results


def _to_jsonable(obj):
    """``json`` fallback that unwraps scalar tensors / NumPy scalars via ``.item()``.

    Checkpoint metadata such as ``val_loss`` may have been saved as such a
    scalar rather than a Python float (not visible here; confirm against the
    training notebook). Anything else is rejected, as ``json`` would reject it.
    """
    item = getattr(obj, 'item', None)
    if callable(item):
        return item()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


output = {
    'model_type': 'hierarchical_bdh' if is_hierarchical else 'bdh',
    'parameters': total_params,
    'trained_on': ckpt.get('dataset', 'unknown'),
    'original_val_loss': ckpt.get('val_loss'),
    'original_val_ppl': ckpt.get('val_ppl'),
    'block_size': BLOCK_SIZE,
    'benchmarks': results_summary,
}

output_path = 'benchmark_results.json'
# Serialize first so a serialization error cannot leave a truncated file behind.
payload = json.dumps(output, indent=2, default=_to_jsonable)
with open(output_path, 'w') as f:
    f.write(payload)

print(f'\nResults saved to {output_path}')

# Download (Colab only; elsewhere the file simply stays on disk)
try:
    from google.colab import files
except ImportError:
    pass
else:
    files.download(output_path)

## Reference: Published Results

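Published word/token-level results, for comparison. The GPT-2 rows are zero-shot results from the GPT-2 paper; Transformer-XL was trained on WikiText-103.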

| Model | Params | WikiText-103 PPL | LAMBADA Acc |
|-------|--------|------------------|-------------|
| GPT-2 Small | 117M | 37.5 | 46.0% |
| GPT-2 Medium | 345M | 26.4 | 55.5% |
| GPT-2 Large | 762M | 22.1 | 60.1% |
| Transformer-XL | 257M | 18.3 | - |
| MEGABYTE (350M) | 350M | - | - |

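**Note**: Direct comparison is tricky because:
- BDH is byte-level (no BPE tokens), so the perplexities in this notebook are per byte, while the table above reports per-word/per-token perplexity
- Perplexity scales differ between tokenizations
- BPB (bits-per-byte) is more comparable across tokenizations
- LAMBADA accuracy here is a byte-level exact match on the final word. The Hugging Face `lambada` text may also differ in preprocessing from the version used for published results.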