# Hierarchical BDH Training on Code (The Stack - Python)

Train **Hierarchical BDH** on Python code from The Stack dataset.

## Why Code?
- **Highly structured**: Clear syntax, indentation, brackets
- **Long-range dependencies**: Functions reference each other, imports at top
- **Byte-level advantage**: Handles any syntax without tokenization issues
- **Practical application**: Code completion, understanding

## Expected Outcomes
- Lower perplexity than natural language (more predictable syntax)
- Tests the hierarchical architecture on highly structured data
- Foundation for code fine-tuning experiments

## Hardware Requirements
- **Recommended**: A100 40GB+
- **Minimum**: V100 32GB with small model
- Training time: 2-4 hours

In [None]:
# Show the attached NVIDIA GPU(s), driver version and current memory usage.
!nvidia-smi

In [None]:
# Clone the BDH repository; if the clone fails (its stderr is silenced), fall
# back to updating an existing `bdh` checkout with `git pull`.
!git clone https://github.com/newsbubbles/bdh.git 2>/dev/null || (cd bdh && git pull)
# Work from inside the checkout — presumably so `bdh_hierarchical` can be imported below.
%cd bdh

In [None]:
# Install the third-party packages used by the cells below (-q keeps pip quiet).
!pip install -q torch datasets tqdm matplotlib

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import math
import json
from tqdm.auto import tqdm

from bdh_hierarchical import HierarchicalBDH, HierarchicalBDHConfig

# Default to the CPU and switch to the GPU only when CUDA is usable.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Device: {device}')

# Report which GPU was picked up and how much memory it has.
if device != 'cpu':
    print(f'GPU: {torch.cuda.get_device_name()}')
    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

## Load The Stack (Python subset)

The Stack is a massive code dataset. We'll use a Python subset.

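**Note**: The Stack requires a Hugging Face login and accepting the terms at huggingface.co/datasets/bigcode/the-stack. By default (`USE_THE_STACK = False`) this notebook loads the public CodeParrot dataset instead, which needs no login.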

In [None]:
from datasets import load_dataset

# -------------------------------------------------------------
# Dataset selection
# -------------------------------------------------------------
# The Stack needs a Hugging Face login plus accepted dataset terms;
# CodeParrot is public and needs no login.
USE_THE_STACK = False  # flip to True once you have access to The Stack
SAMPLE_SIZE = 50000  # number of files taken from the start of the train split
# -------------------------------------------------------------

print('Loading code dataset...')

# Both sources keep the file text in the same column.
text_column = 'content'

# NOTE(review): a sliced split such as 'train[:N]' may still download and
# prepare the whole train split before slicing — confirm, and consider
# streaming if the download is too large.
if not USE_THE_STACK:
    # Public fallback.
    print('Loading CodeParrot (Python)...')
    dataset = load_dataset(
        'codeparrot/codeparrot-clean',
        split=f'train[:{SAMPLE_SIZE}]',
    )
else:
    # Gated dataset: requires `huggingface-cli login` beforehand.
    print('Loading The Stack (Python)...')
    dataset = load_dataset(
        'bigcode/the-stack',
        data_dir='data/python',
        split=f'train[:{SAMPLE_SIZE}]',
        trust_remote_code=True,
    )

print(f'Loaded {len(dataset)} code files')

In [None]:
def text_to_bytes(text):
    """Encode ``text`` as UTF-8 and return the bytes as a 1-D ``torch.long`` tensor.

    The tensor is built straight from the encoded buffer instead of going
    through a Python list of ints, which avoids creating one Python object
    reference per byte when the corpus is hundreds of megabytes.

    Args:
        text: The string to encode.

    Returns:
        A freshly allocated tensor of dtype ``torch.long`` holding one value
        in ``[0, 255]`` per UTF-8 byte; shape ``(0,)`` for an empty string.
    """
    # bytearray (not bytes) gives torch.frombuffer a writable buffer, which
    # avoids its read-only-buffer warning.
    raw = bytearray(text.encode('utf-8'))
    if not raw:
        # Handled separately: torch.frombuffer does not accept an empty buffer.
        return torch.empty(0, dtype=torch.long)
    # .to(torch.long) copies, so the result does not alias `raw`.
    return torch.frombuffer(raw, dtype=torch.uint8).to(torch.long)

# Files are glued together with an explicit marker so that file boundaries
# stay visible in the byte stream.
separator = '\n\n# === END OF FILE ===\n\n'

print('Processing code files...')
all_code = separator.join(dataset[text_column])

# Character offsets of the 90/5/5 train/val/test boundaries.
total_len = len(all_code)
train_end = int(0.9 * total_len)
val_end = int(0.95 * total_len)

train_text, val_text, test_text = (
    all_code[:train_end],
    all_code[train_end:val_end],
    all_code[val_end:],
)

print('Converting to bytes...')
train_data, val_data, test_data = (
    text_to_bytes(part) for part in (train_text, val_text, test_text)
)

# Drop the large string intermediates and the dataset, then force a GC pass.
del all_code, train_text, val_text, test_text, dataset
import gc
gc.collect()

print(f'\nDataset sizes:')
print(f'  Train: {len(train_data):,} bytes ({len(train_data)/1e6:.1f}MB)')
print(f'  Val:   {len(val_data):,} bytes ({len(val_data)/1e6:.1f}MB)')
print(f'  Test:  {len(test_data):,} bytes ({len(test_data)/1e6:.1f}MB)')

## Dataset Class

In [None]:
class ByteDataset(Dataset):
    """Sliding-window next-byte dataset over a 1-D sequence.

    Item ``idx`` is the pair ``(x, y)`` where ``x`` is
    ``data[idx : idx + block_size]`` and ``y`` is the same window shifted one
    position to the right, i.e. the next-element targets for ``x``.
    Consecutive items overlap in all but one position (stride 1).

    Args:
        data: Sliceable 1-D sequence (for example a tensor of byte values).
        block_size: Length of each input window.
        patch_size: Must divide ``block_size`` evenly; stored for reference.

    Raises:
        ValueError: If ``block_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, data, block_size, patch_size=8):
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if block_size % patch_size != 0:
            raise ValueError(
                f'block_size ({block_size}) must be a multiple of '
                f'patch_size ({patch_size})'
            )
        self.data = data
        self.block_size = block_size
        self.patch_size = patch_size

    def __len__(self):
        # Each item needs block_size + 1 elements. Clamp at zero so data that
        # is too short yields an empty dataset instead of a negative length
        # (which makes len() itself raise).
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        # One extra element so inputs and shifted targets come from one slice.
        chunk = self.data[idx:idx + self.block_size + 1]
        return chunk[:-1], chunk[1:]

## Configuration

In [None]:
# =============================================================
MODEL_SIZE = 'base'  # one of: 'tiny', 'small', 'base', 'large'
# =============================================================

# Size name -> HierarchicalBDHConfig preset constructor of the same name.
preset_map = {
    size: getattr(HierarchicalBDHConfig, size)
    for size in ('tiny', 'small', 'base', 'large')
}
model_config = preset_map[MODEL_SIZE](dropout=0.1)

# Training hyper-parameters; the batch size shrinks as the model grows.
config = dict(
    model_size=MODEL_SIZE,
    patch_size=model_config.patch_size,
    block_size=1024,  # context window in bytes; intended to cover a function
    batch_size={'tiny': 64, 'small': 32, 'base': 16, 'large': 8}[MODEL_SIZE],
    learning_rate=3e-4,
    weight_decay=0.1,
    max_steps=15000,
    warmup_steps=400,
    gradient_accumulation_steps=4,
    val_interval=500,
    val_batches=100,
    patience=10,
    log_interval=100,
)

print(f'Model: {MODEL_SIZE}')
print(f'Block size: {config["block_size"]}')

## Create Model

In [None]:
# Build the model from the selected preset, then move it to the target device.
model = HierarchicalBDH(model_config)
model = model.to(device)

# Parameter counts as reported by the model; only the 'total' entry is shown.
params = model.count_parameters()

print(f'\nHierarchical BDH - {MODEL_SIZE.upper()}')
print(f'Global: {model_config.global_n_layer}L x {model_config.global_n_embd}D')
print(f'Local:  {model_config.local_n_layer}L x {model_config.local_n_embd}D')
print(f'Total:  {params["total"]:,} ({params["total"]/1e6:.1f}M)')

## Setup Training

In [None]:
# Wrap the train and validation byte tensors in sliding-window datasets.
train_dataset, val_dataset = (
    ByteDataset(split, config['block_size'], config['patch_size'])
    for split in (train_data, val_data)
)

# Training batches are shuffled and always full-sized (drop_last=True).
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)
# Validation keeps the natural order and the final partial batch.
val_loader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
    betas=(0.9, 0.95),
)

def get_lr(step):
    """Return the learning rate for optimizer step ``step``.

    The rate rises linearly from 0 to ``config['learning_rate']`` over the
    first ``config['warmup_steps']`` steps, then follows a half-cosine decay
    that reaches 0 at ``config['max_steps']``.
    """
    peak_lr = config['learning_rate']
    warmup = config['warmup_steps']
    if step >= warmup:
        # Fraction of the post-warmup schedule that has elapsed.
        progress = (step - warmup) / (config['max_steps'] - warmup)
        return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))
    return peak_lr * step / warmup

# Gradient scaler for mixed-precision training; the training loop uses it to
# scale the loss before backward and to drive the optimizer step.
# NOTE(review): newer PyTorch releases reportedly deprecate this spelling in
# favour of torch.amp.GradScaler('cuda') — confirm against the installed version.
scaler = torch.cuda.amp.GradScaler()
# Number of batches in one pass over the (drop_last=True) training loader.
print(f'Train batches: {len(train_loader):,}')

## Training

In [None]:
@torch.no_grad()
def evaluate(model, loader, max_batches=None):
    """Compute the average loss of ``model`` over batches from ``loader``.

    Each batch loss is weighted by the number of target elements in that
    batch, so a smaller final batch does not skew the average. This assumes
    ``model(x, y)`` returns ``(_, loss)`` with ``loss`` a mean over the
    batch's targets — TODO confirm against the model implementation.

    Args:
        model: Module called as ``model(x, y)``.
        loader: Iterable yielding ``(x, y)`` tensor pairs.
        max_batches: Evaluate at most this many batches; ``None`` means the
            whole loader. ``0`` is honoured as "no batches" rather than
            being treated like ``None``.

    Returns:
        ``(avg_loss, exp(avg_loss), avg_loss / ln 2)``: the mean loss, the
        corresponding perplexity, and the loss converted from nats to bits
        (bits per byte when each target is one byte).

    Raises:
        ValueError: If no batch was evaluated (empty loader or
            ``max_batches=0``), instead of an opaque ``ZeroDivisionError``.
    """
    # Remember the mode so an eval-mode caller is not flipped into training.
    was_training = model.training
    model.eval()
    total_loss, total_tokens = 0.0, 0
    try:
        for i, (x, y) in enumerate(loader):
            if max_batches is not None and i >= max_batches:
                break
            x, y = x.to(device), y.to(device)
            with torch.cuda.amp.autocast():
                _, loss = model(x, y)
            n_targets = y.numel()
            total_loss += loss.item() * n_targets
            total_tokens += n_targets
    finally:
        # Restore the previous mode even if a batch raised.
        model.train(was_training)

    if total_tokens == 0:
        raise ValueError('evaluate() processed no batches; cannot compute an average loss')

    avg_loss = total_loss / total_tokens
    return avg_loss, math.exp(avg_loss), avg_loss / math.log(2)

In [None]:
# One list per tracked metric; a value is appended at every validation round.
history = {
    key: []
    for key in ('step', 'train_loss', 'val_loss', 'val_ppl', 'val_bpb', 'lr')
}

# Early-stopping state: best validation loss so far and rounds without improvement.
best_val_loss = math.inf
patience_counter = 0

# Checkpoints and run artefacts go into a per-model-size directory.
ckpt_dir = Path(f'checkpoints_hierarchical_{MODEL_SIZE}_code')
ckpt_dir.mkdir(exist_ok=True)

print(f'Checkpoints: {ckpt_dir}')

In [None]:
# Main training loop. Each `step` is one optimizer update made of
# `gradient_accumulation_steps` micro-batches; every `val_interval` steps the
# model is evaluated, checkpointed on improvement, and checked for early stopping.
model.train()
train_iter = iter(train_loader)
running_loss = 0

pbar = tqdm(range(config['max_steps']), desc='Training')

for step in pbar:
    optimizer.zero_grad()
    accum_loss = 0
    
    for _ in range(config['gradient_accumulation_steps']):
        # Restart the loader when it is exhausted so training can run for
        # more steps than one pass over the data provides.
        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x, y = next(train_iter)
        
        x, y = x.to(device), y.to(device)
        with torch.cuda.amp.autocast():
            _, loss = model(x, y)
            # Divide so the gradients summed over the micro-batches correspond
            # to the average loss rather than the sum.
            loss = loss / config['gradient_accumulation_steps']
        scaler.scale(loss).backward()
        # Sum of the already-divided losses = mean micro-batch loss for this step.
        accum_loss += loss.item()
    
    # Apply the scheduled learning rate before the optimizer update.
    lr = get_lr(step)
    for pg in optimizer.param_groups: pg['lr'] = lr
    
    # Unscale first so the clipping threshold applies to the real gradient norm.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    
    running_loss += accum_loss
    
    # Progress bar shows the mean training loss over the last `log_interval` steps.
    if (step + 1) % config['log_interval'] == 0:
        pbar.set_postfix({'loss': f'{running_loss/config["log_interval"]:.3f}', 'lr': f'{lr:.2e}'})
        running_loss = 0
    
    if (step + 1) % config['val_interval'] == 0:
        # Both estimates use at most `val_batches` batches; the training loader
        # is shuffled, so the train estimate is taken on a random sample.
        val_loss, val_ppl, val_bpb = evaluate(model, val_loader, config['val_batches'])
        train_loss, _, _ = evaluate(model, train_loader, config['val_batches'])
        
        history['step'].append(step + 1)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_ppl'].append(val_ppl)
        history['val_bpb'].append(val_bpb)
        history['lr'].append(lr)
        
        # Generalisation gap: positive when validation loss exceeds train loss.
        gap = val_loss - train_loss
        print(f'\nStep {step+1}: train={train_loss:.3f}, val={val_loss:.3f}, ppl={val_ppl:.2f}, bpb={val_bpb:.3f}, gap={gap:.3f}')
        
        if val_loss < best_val_loss:
            # New best: reset patience and overwrite the single best checkpoint.
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                'step': step + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'model_config': model_config.__dict__,
                'val_loss': val_loss, 'val_ppl': val_ppl, 'val_bpb': val_bpb,
                'dataset': 'code-python',
            }, ckpt_dir / 'best.pt')
            print(f'  ✓ New best!')
        else:
            patience_counter += 1
            print(f'  No improvement ({patience_counter}/{config["patience"]})')
        
        # Stop after `patience` consecutive validation rounds without improvement.
        if patience_counter >= config['patience']:
            print(f'\nEarly stopping')
            break

print(f'\nBest val loss: {best_val_loss:.4f}')

## Save & Plot

In [None]:
# Metric history recorded at each validation round.
with open(ckpt_dir / 'training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

# Training hyper-parameters plus the model preset, parameter counts and a
# dataset tag, so the run can be identified later.
with open(ckpt_dir / 'config.json', 'w') as f:
    json.dump(
        dict(
            config,
            model_config=model_config.__dict__,
            parameters=params,
            dataset='code-python',
        ),
        f,
        indent=2,
    )

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 4, figsize=(18, 4))

# Panel 0: train vs. validation loss.
axes[0].plot(history['step'], history['train_loss'], 'b-', label='Train')
axes[0].plot(history['step'], history['val_loss'], 'r-', label='Val')
axes[0].legend()

# Panels 1 and 2: validation perplexity and bits-per-byte.
axes[1].plot(history['step'], history['val_ppl'], 'g-')
axes[2].plot(history['step'], history['val_bpb'], 'm-')

# Panel 3: validation-minus-train loss gap, drawn as a filled area.
gaps = [val - train for train, val in zip(history['train_loss'], history['val_loss'])]
axes[3].fill_between(history['step'], gaps, alpha=0.5, color='orange')

# Shared axis decoration for all four panels.
for ax, ylabel in zip(axes, ('Loss', 'Perplexity', 'BPB', 'Gap')):
    ax.set_xlabel('Step')
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)

plt.suptitle(f'Hierarchical BDH ({MODEL_SIZE}) - Python Code', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(ckpt_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## Evaluate

In [None]:
from torch.utils.data import Subset

# Reload the best checkpoint; map_location lets a checkpoint saved on one
# device be loaded onto whatever `device` currently is.
ckpt = torch.load(ckpt_dir / 'best.pt', map_location=device)
model.load_state_dict(ckpt['model_state_dict'])


def _non_overlapping(dataset):
    """Return a view of ``dataset`` with one window every ``block_size`` positions.

    ByteDataset yields a window at every byte offset, so an uncapped pass
    visits roughly one batch per ``batch_size`` bytes of data and re-scores
    each byte about ``block_size`` times. Stepping by ``block_size`` still
    covers the whole split but scores each byte about once, which keeps the
    final evaluation tractable.
    """
    return Subset(dataset, range(0, len(dataset), config['block_size']))


print('Validation...')
final_val_loader = DataLoader(
    _non_overlapping(val_dataset),
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
val_loss, val_ppl, val_bpb = evaluate(model, final_val_loader)
print(f'Val: loss={val_loss:.4f}, ppl={val_ppl:.2f}, bpb={val_bpb:.3f}')

test_dataset = ByteDataset(test_data, config['block_size'], config['patch_size'])
test_loader = DataLoader(
    _non_overlapping(test_dataset),
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=2,
)

print('\nTest...')
test_loss, test_ppl, test_bpb = evaluate(model, test_loader)
print(f'Test: loss={test_loss:.4f}, ppl={test_ppl:.2f}, bpb={test_bpb:.3f}')

## Code Generation Samples

In [None]:
def generate_code(model, prompt, max_new_tokens=200, temperature=0.7, top_k=40):
    """Sample a continuation of ``prompt`` from ``model`` and return it as text.

    The prompt is UTF-8 encoded into a batch of one byte sequence, passed to
    ``model.generate`` with the given sampling settings, and the first
    returned sequence is decoded back to a string. Bytes that do not form
    valid UTF-8 are replaced rather than raising. Whether the returned text
    includes the prompt depends on what ``model.generate`` returns.

    The model is switched to eval mode and left there.
    """
    model.eval()
    prompt_bytes = prompt.encode('utf-8')
    idx = torch.tensor([list(prompt_bytes)], device=device, dtype=torch.long)
    with torch.no_grad():
        output = model.generate(
            idx,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
        )
    generated = bytes(output[0].tolist())
    return generated.decode('utf-8', errors='replace')

# Seed snippets for the qualitative samples: a function, a class, an
# import header and a comment-led definition.
prompts = [
    'def fibonacci(n):\n',
    'class DataLoader:\n    def __init__(self',
    'import torch\nimport torch.nn as nn\n\nclass',
    '# Function to sort a list\ndef',
]

# Banner: rule, title, rule — one per line.
print('=' * 60, 'CODE GENERATION SAMPLES', '=' * 60, sep='\n')

for prompt in prompts:
    # Show the prompt first, then the model's output under its own header.
    print(f'\n--- Prompt ---', prompt, '--- Generated ---', sep='\n')
    print(generate_code(model, prompt, max_new_tokens=150, temperature=0.7))
    print()

## Download

In [None]:
import subprocess

import shutil

zip_name = f'hierarchical_bdh_{MODEL_SIZE}_code.zip'
# Build the archive with the standard library instead of shelling out to an
# external `zip` binary, which is not installed on every system.
# make_archive appends '.zip' itself, so pass the name without the suffix;
# base_dir keeps the checkpoint directory as the top-level folder in the archive.
shutil.make_archive(
    zip_name[:-len('.zip')],
    'zip',
    root_dir='.',
    base_dir=str(ckpt_dir),
)

# On Colab, trigger a browser download; elsewhere just report the file name.
try:
    from google.colab import files
    files.download(zip_name)
    print('✅ Download started!')
except ImportError:
    print(f'Saved: {zip_name}')