# Fine-tuning Hierarchical BDH

Fine-tune a pretrained Hierarchical BDH model on a new domain.

## Fine-tuning Strategies

| Strategy | What's Trained | Use Case |
|----------|----------------|----------|
| **Full** | All parameters | Maximum adaptation |
| **Local-only** | Local model + cross-attn | Domain style, preserve knowledge |
| **Global-only** | Global model | Document structure |
| **Head-only** | LM head only | Quick adaptation |
| **LoRA-style** | Added adapters | Parameter efficient (not yet implemented in this notebook) |

## Workflow
1. Upload pretrained checkpoint
2. Choose fine-tuning strategy
3. Upload/select fine-tuning data
4. Train and evaluate
5. Download fine-tuned model

In [None]:
# Show which GPU (if any) the runtime has attached before training starts
!nvidia-smi

In [None]:
# Fetch the BDH repository (the `bdh_hierarchical` module imported later is expected
# to come from here). If `git clone` fails -- e.g. because ./bdh already exists from
# an earlier run -- update the existing checkout with `git pull` instead. Clone errors
# are hidden by redirecting stderr to /dev/null.
# NOTE(review): after `%cd bdh`, re-running this cell from inside the repo will clone
# a second, nested copy (bdh/bdh) and cd into it.
!git clone https://github.com/newsbubbles/bdh.git 2>/dev/null || (cd bdh && git pull)
%cd bdh

In [None]:
# Quietly (-q) install the libraries this notebook imports: torch, HF datasets, tqdm, matplotlib
!pip install -q torch datasets tqdm matplotlib

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import math
import json
import zipfile
import os
from tqdm.auto import tqdm

from bdh_hierarchical import HierarchicalBDH, HierarchicalBDHConfig

# Run on the GPU whenever PyTorch can see one; later cells move tensors to `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

## Load Pretrained Model

Upload the pretrained checkpoint from a training run: either a `.pt` file, or a `.zip` archive that contains `best.pt`.

In [None]:
# Upload pretrained checkpoint (Colab), or fall back to a local path elsewhere.
# Only the import is guarded, so errors raised by the upload itself are not
# mistaken for "not running in Colab".
try:
    from google.colab import files
except ImportError:
    uploaded_file = 'checkpoints_hierarchical_small/best.pt'  # Edit for local
    print(f'Using: {uploaded_file}')
else:
    print('Upload your pretrained checkpoint (.zip or .pt):')
    uploaded = files.upload()
    if not uploaded:
        # Cancelling the file picker returns an empty dict
        raise RuntimeError('No checkpoint was uploaded; re-run this cell and select a file.')
    # If several files were uploaded, use the first one
    uploaded_file = next(iter(uploaded))

In [None]:
# Resolve the checkpoint file: either the uploaded .pt itself, or the best.pt
# contained in an uploaded .zip archive.
if uploaded_file.endswith('.zip'):
    with zipfile.ZipFile(uploaded_file, 'r') as z:
        archive_members = z.namelist()
        z.extractall('.')
    # Search only the files that came out of *this* archive. Walking the whole
    # working directory (the cloned repo) could pick up a stale best.pt left by
    # an earlier training or fine-tuning run instead of the one just uploaded.
    checkpoint_path = next(
        (os.path.join('.', name) for name in archive_members
         if os.path.basename(name) == 'best.pt'),
        None,
    )
    if checkpoint_path is None:
        raise FileNotFoundError(f'No best.pt found inside {uploaded_file}')
else:
    checkpoint_path = uploaded_file

print(f'Loading: {checkpoint_path}')
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
ckpt = torch.load(checkpoint_path, map_location=device)

print(f'Pretrained on: {ckpt.get("dataset", "unknown")}')
print(f'Original PPL: {ckpt.get("val_ppl", "unknown")}')

In [None]:
# Rebuild the architecture from the hyperparameters stored in the checkpoint,
# move it to the target device, then restore the pretrained weights.
model_config = HierarchicalBDHConfig(**ckpt['model_config'])
model = HierarchicalBDH(model_config)
model = model.to(device)
model.load_state_dict(ckpt['model_state_dict'])

# Report the parameter budget; count_parameters() returns a dict of counts
params = model.count_parameters()
total_params = params["total"]
print(f'\nModel loaded: {total_params:,} params ({total_params/1e6:.1f}M)')
print(f'  Global: {params["global_model"]:,}')
print(f'  Local:  {params["local_model"]:,}')

## Choose Fine-tuning Strategy

In [None]:
# =============================================================
# FINE-TUNING STRATEGY
# =============================================================
STRATEGY = 'local_only'  # Options: 'full', 'local_only', 'global_only', 'head_only'
# =============================================================

def freeze_params(module):
    """Disable gradients for every parameter of *module* (including submodules)."""
    for p in module.parameters():
        p.requires_grad = False

def unfreeze_params(module):
    """Enable gradients for every parameter of *module* (including submodules)."""
    for p in module.parameters():
        p.requires_grad = True

VALID_STRATEGIES = ('full', 'local_only', 'global_only', 'head_only')
if STRATEGY not in VALID_STRATEGIES:
    # Without this check a typo (or the unimplemented LoRA option) would fall
    # through every branch and silently train all parameters.
    raise ValueError(f'Unknown STRATEGY {STRATEGY!r}; choose one of {VALID_STRATEGIES}')

# Start from a fully trainable model so that re-running this cell with a
# different STRATEGY does not inherit parameters frozen by a previous run.
unfreeze_params(model)

# Apply strategy
if STRATEGY == 'full':
    print('Strategy: FULL fine-tuning (all parameters)')
    # Everything trainable (reset above)

elif STRATEGY == 'local_only':
    print('Strategy: LOCAL-ONLY (freeze global, train local + cross-attn)')
    freeze_params(model.global_model)
    freeze_params(model.patch_embedder)
    # Local model and lm_head remain trainable

elif STRATEGY == 'global_only':
    print('Strategy: GLOBAL-ONLY (freeze local, train global)')
    freeze_params(model.local_model)
    freeze_params(model.lm_head)
    # Global model and patch_embedder remain trainable

elif STRATEGY == 'head_only':
    print('Strategy: HEAD-ONLY (freeze all, train only LM head)')
    freeze_params(model)
    unfreeze_params(model.lm_head)

# Count trainable params
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'\nTrainable: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)')

## Load Fine-tuning Data

Options:
1. Upload your own text file
2. Use a HuggingFace dataset
3. Use built-in examples (code, legal, scientific)

In [None]:
# =============================================================
# DATA SOURCE -- where the fine-tuning text comes from
# =============================================================
# One of: 'upload' (your own .txt file), 'huggingface' (any HF dataset),
# or 'example' (a built-in domain sample selected by EXAMPLE_DOMAIN).
DATA_SOURCE = 'huggingface'

# --- Only used when DATA_SOURCE == 'huggingface' ---
HF_DATASET = 'wikitext'          # dataset name on the HF hub
HF_CONFIG = 'wikitext-2-raw-v1'  # dataset config; use None/'' for datasets without one
HF_SPLIT = 'train'               # split to slice rows from
HF_TEXT_COLUMN = 'text'          # column holding the raw text
HF_SAMPLE_SIZE = 10000           # number of rows taken from the start of the split

# --- Only used when DATA_SOURCE == 'example' ---
EXAMPLE_DOMAIN = 'code'  # One of: 'code', 'legal', 'scientific'
# =============================================================

In [None]:
def text_to_bytes(text):
    """Encode *text* as UTF-8 and return its bytes as a 1-D ``torch.long`` tensor.

    Each element is one byte value in ``[0, 255]``. An empty string yields an
    empty tensor of shape ``(0,)``.
    """
    raw = text.encode('utf-8')
    if not raw:
        # torch.frombuffer rejects zero-length buffers
        return torch.empty(0, dtype=torch.long)
    # frombuffer reads the bytes directly instead of materializing a Python list
    # with one int per byte (slow and memory-hungry for multi-MB corpora).
    # bytearray gives it a writable buffer; .to(long) makes an independent copy.
    return torch.frombuffer(bytearray(raw), dtype=torch.uint8).to(torch.long)

if DATA_SOURCE == 'upload':
    # Colab: upload a .txt file; elsewhere: read a fixed local path.
    # Only the import is guarded so real read errors are not swallowed.
    try:
        from google.colab import files
    except ImportError:
        data_file = 'data/finetune_data.txt'
    else:
        print('Upload your training data (.txt):')
        uploaded_data = files.upload()
        if not uploaded_data:
            raise RuntimeError('No training data was uploaded; re-run this cell and select a file.')
        data_file = next(iter(uploaded_data))
    # Explicit UTF-8: the default text encoding depends on the platform locale,
    # and the text is re-encoded as UTF-8 bytes below anyway.
    with open(data_file, 'r', encoding='utf-8') as f:
        train_text = f.read()

elif DATA_SOURCE == 'huggingface':
    from datasets import load_dataset
    print(f'Loading {HF_DATASET}/{HF_CONFIG}...')

    # Split-slicing syntax: the first HF_SAMPLE_SIZE rows of HF_SPLIT
    hf_split_spec = f'{HF_SPLIT}[:{HF_SAMPLE_SIZE}]'
    if HF_CONFIG:
        dataset = load_dataset(HF_DATASET, HF_CONFIG, split=hf_split_spec)
    else:
        dataset = load_dataset(HF_DATASET, split=hf_split_spec)

    train_text = '\n'.join(dataset[HF_TEXT_COLUMN])
    print(f'Loaded {len(dataset)} examples')

elif DATA_SOURCE == 'example':
    from datasets import load_dataset

    if EXAMPLE_DOMAIN == 'code':
        print('Loading Python code...')
        dataset = load_dataset('codeparrot/codeparrot-clean', split='train[:5000]')
        train_text = '\n\n# === FILE ===\n\n'.join(dataset['content'])
    elif EXAMPLE_DOMAIN == 'legal':
        print('Loading legal text...')
        dataset = load_dataset('pile-of-law/pile-of-law', 'r_legaladvice', split='train[:2000]')
        train_text = '\n\n'.join(dataset['text'])
    elif EXAMPLE_DOMAIN == 'scientific':
        print('Loading scientific papers...')
        dataset = load_dataset('scientific_papers', 'arxiv', split='train[:1000]')
        train_text = '\n\n'.join(dataset['article'])
    else:
        raise ValueError(
            f"Unknown EXAMPLE_DOMAIN {EXAMPLE_DOMAIN!r}; choose 'code', 'legal' or 'scientific'"
        )

else:
    raise ValueError(
        f"Unknown DATA_SOURCE {DATA_SOURCE!r}; choose 'upload', 'huggingface' or 'example'"
    )

# Convert to bytes
print('Converting to bytes...')
all_data = text_to_bytes(train_text)
if len(all_data) == 0:
    raise ValueError('No fine-tuning text was loaded; check the DATA SOURCE settings.')

# Split 90/10: contiguous head of the byte stream for training, tail for validation
split_idx = int(0.9 * len(all_data))
train_data = all_data[:split_idx]
val_data = all_data[split_idx:]

print(f'\nData sizes:')
print(f'  Train: {len(train_data):,} bytes ({len(train_data)/1e6:.1f}MB)')
print(f'  Val:   {len(val_data):,} bytes ({len(val_data)/1e6:.1f}MB)')

## Dataset & Training Setup

In [None]:
class ByteDataset(Dataset):
    """Sliding-window next-byte dataset over a 1-D token tensor.

    Item ``idx`` is ``(data[idx:idx+B], data[idx+1:idx+B+1])``, where ``B`` is
    ``block_size`` rounded down to a multiple of ``patch_size``; the target is
    the input shifted by one position. A window starts at every offset
    (stride 1), so consecutive items overlap by ``B - 1`` positions.

    Raises:
        ValueError: if the rounded block size is not positive, or *data* is too
            short to yield even one full ``(input, target)`` pair. (Previously
            such data silently produced a single truncated window.)
    """

    def __init__(self, data, block_size, patch_size=8):
        # Round down so every window splits into whole patches
        block_size = (block_size // patch_size) * patch_size
        if block_size <= 0:
            raise ValueError(f'block_size must be at least patch_size ({patch_size})')
        if len(data) < block_size + 1:
            raise ValueError(
                f'Need at least {block_size + 1} tokens for one window of '
                f'{block_size}, got {len(data)}'
            )
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Last valid start is len(data) - block_size - 1 (the target needs one extra token)
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        chunk = self.data[idx:idx + self.block_size + 1]
        return chunk[:-1], chunk[1:]

In [None]:
# =============================================================
# FINE-TUNING CONFIG
# =============================================================
config = dict(
    block_size=512,
    batch_size=16,
    learning_rate=1e-4,  # Lower than pretraining!
    weight_decay=0.01,
    max_steps=2000,
    warmup_steps=100,
    val_interval=200,    # run validation every N steps
    val_batches=50,      # batches scored per validation
    patience=5,          # validations without improvement before early stopping
    log_interval=50,     # steps between progress-bar loss updates
)
# =============================================================

# Trim the context length down to a whole multiple of the model's patch size
patch_size = model_config.patch_size
config['block_size'] -= config['block_size'] % patch_size

print('Fine-tuning config:')
print('\n'.join(f'  {key}: {value}' for key, value in config.items()))

In [None]:
from torch.utils.data import Subset

# Create dataloaders
train_dataset = ByteDataset(train_data, config['block_size'], patch_size)
val_dataset = ByteDataset(val_data, config['block_size'], patch_size)

# ByteDataset starts a window at every byte offset. Scored in order, the first
# `val_batches` batches would only cover the first
# ~val_batches * batch_size + block_size bytes of the validation split, using
# almost identical windows. Evaluate on non-overlapping windows that span the
# whole split instead.
val_eval_dataset = Subset(val_dataset, range(0, len(val_dataset), val_dataset.block_size))

# Pinned host memory only helps host->GPU copies; on CPU it is pointless
# (and recent PyTorch versions warn about it).
pin = device == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=0, pin_memory=pin)
val_loader = DataLoader(val_eval_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=0, pin_memory=pin)

# Only optimize trainable parameters (frozen ones were set requires_grad=False)
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
)

def get_lr(step):
    """Learning rate for 0-based *step*.

    The rate rises linearly from 0 to ``learning_rate`` over ``warmup_steps``, then
    follows a half-cosine decay that reaches 0 at ``max_steps``.
    """
    if step < config['warmup_steps']:
        return config['learning_rate'] * step / config['warmup_steps']
    progress = (step - config['warmup_steps']) / (config['max_steps'] - config['warmup_steps'])
    return config['learning_rate'] * 0.5 * (1 + math.cos(math.pi * progress))

print(f'Train batches: {len(train_loader)}')
print(f'Val batches:   {len(val_loader)}')

## Fine-tuning Loop

In [None]:
@torch.no_grad()
def evaluate(model, loader, max_batches=None):
    model.eval()
    total_loss, total_tokens = 0, 0
    for i, (x, y) in enumerate(loader):
        if max_batches and i >= max_batches: break
        x, y = x.to(device), y.to(device)
        _, loss = model(x, y)
        total_loss += loss.item() * y.numel()
        total_tokens += y.numel()
    model.train()
    avg_loss = total_loss / total_tokens
    return avg_loss, math.exp(avg_loss), avg_loss / math.log(2)

In [None]:
# Baseline: how well do the untouched pretrained weights model the new domain?
print('Evaluating pretrained model on new domain...')
n_eval_batches = config['val_batches']
pre_loss, pre_ppl, pre_bpb = evaluate(model, val_loader, max_batches=n_eval_batches)
print(f'Before fine-tuning: loss={pre_loss:.4f}, ppl={pre_ppl:.2f}, bpb={pre_bpb:.3f}')

In [None]:
# Metrics recorded at every validation step (plotted and exported later)
history = {key: [] for key in ('step', 'train_loss', 'val_loss', 'val_ppl', 'val_bpb', 'lr')}

# Early-stopping state
best_val_loss = float('inf')
patience_counter = 0

# One output directory per strategy, so different strategies don't overwrite each other
ckpt_dir = Path('checkpoints_finetuned_' + STRATEGY)
ckpt_dir.mkdir(exist_ok=True)

print(f'\nFine-tuning with {STRATEGY} strategy...')
print('=' * 60)

In [None]:
# Main fine-tuning loop: a fixed number of optimizer steps (not epochs) with a
# manually applied LR schedule, gradient-norm clipping, periodic validation,
# best-checkpoint saving and early stopping. Relies on notebook globals from
# earlier cells (model, optimizer, loaders, config, history, ckpt_dir, ...).
model.train()
train_iter = iter(train_loader)
running_loss = 0  # sum of per-step train losses since the last progress-bar update

pbar = tqdm(range(config['max_steps']), desc='Fine-tuning')

for step in pbar:
    # Next training batch; when the loader is exhausted, restart it
    # (train_loader has shuffle=True, so each pass gets a new order)
    try:
        x, y = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        x, y = next(train_iter)
    
    x, y = x.to(device), y.to(device)
    
    # Apply the warmup + cosine schedule by overwriting every param group's lr
    lr = get_lr(step)
    for pg in optimizer.param_groups: pg['lr'] = lr
    
    _, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    # Clip the total gradient norm to 1.0 before the update
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    running_loss += loss.item()
    
    # Show the mean train loss over the last log_interval steps
    if (step + 1) % config['log_interval'] == 0:
        pbar.set_postfix({'loss': f'{running_loss/config["log_interval"]:.3f}', 'lr': f'{lr:.2e}'})
        running_loss = 0
    
    if (step + 1) % config['val_interval'] == 0:
        # Validation loss, plus a train-loss estimate over the same number of
        # batches from a fresh (shuffled) pass of the train loader; their
        # difference is reported as the generalization gap.
        val_loss, val_ppl, val_bpb = evaluate(model, val_loader, config['val_batches'])
        train_loss, _, _ = evaluate(model, train_loader, config['val_batches'])
        
        history['step'].append(step + 1)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_ppl'].append(val_ppl)
        history['val_bpb'].append(val_bpb)
        history['lr'].append(lr)
        
        gap = val_loss - train_loss
        print(f'\nStep {step+1}: train={train_loss:.3f}, val={val_loss:.3f}, ppl={val_ppl:.2f}, gap={gap:.3f}')
        
        # Keep only the best-scoring weights on disk. A NaN val_loss never
        # compares less than best_val_loss, so it counts as "no improvement".
        # NOTE(review): if no validation step runs (max_steps < val_interval) or
        # no val_loss is ever finite, this loop never writes best.pt.
        # The saved 'model_config' is model_config.__dict__, presumably so it can
        # be passed back to HierarchicalBDHConfig(**...) like the pretrained one.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                'step': step + 1,
                'model_state_dict': model.state_dict(),
                'model_config': model_config.__dict__,
                'val_loss': val_loss, 'val_ppl': val_ppl,
                'finetune_strategy': STRATEGY,
                'pretrained_on': ckpt.get('dataset', 'unknown'),
            }, ckpt_dir / 'best.pt')
            print(f'  ✓ New best!')
        else:
            patience_counter += 1
            print(f'  No improvement ({patience_counter}/{config["patience"]})')
        
        # Early stopping after `patience` consecutive validations without improvement
        if patience_counter >= config['patience']:
            print('\nEarly stopping')
            break

print(f'\nFine-tuning complete! Best val loss: {best_val_loss:.4f}')

## Results Comparison

In [None]:
# Load the best fine-tuned checkpoint. If the training loop never saved one
# (e.g. max_steps < val_interval), fall back to the final in-memory weights.
best_path = ckpt_dir / 'best.pt'
if best_path.exists():
    ft_ckpt = torch.load(best_path, map_location=device)
    model.load_state_dict(ft_ckpt['model_state_dict'])
else:
    print(f'WARNING: {best_path} not found; evaluating the final in-memory weights instead.')

# Evaluate on the same validation batches as the pre-fine-tuning baseline
# (same max_batches, and val_loader is not shuffled), so the Before/After
# columns compare like with like.
post_loss, post_ppl, post_bpb = evaluate(model, val_loader, config['val_batches'])

print('=' * 60)
print('FINE-TUNING RESULTS')
print('=' * 60)
print(f'Strategy: {STRATEGY}')
print(f'Trainable params: {trainable:,} ({100*trainable/total:.1f}%)')
print()
print(f'{"Metric":<15} {"Before":>12} {"After":>12} {"Change":>12}')
print('-' * 55)
print(f'{"Loss":<15} {pre_loss:>12.4f} {post_loss:>12.4f} {post_loss-pre_loss:>+12.4f}')
print(f'{"Perplexity":<15} {pre_ppl:>12.2f} {post_ppl:>12.2f} {post_ppl-pre_ppl:>+12.2f}')
print(f'{"BPB":<15} {pre_bpb:>12.3f} {post_bpb:>12.3f} {post_bpb-pre_bpb:>+12.3f}')
print()
# Relative perplexity reduction in percent (negative means the model got worse)
improvement = (pre_ppl - post_ppl) / pre_ppl * 100
if improvement >= 0:
    print(f'Perplexity improved by {improvement:.1f}%')
else:
    print(f'Perplexity worsened by {-improvement:.1f}%')

## Plot Training

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Loss: train/val curves against the pre-fine-tuning validation loss
axes[0].axhline(y=pre_loss, color='gray', linestyle='--', label='Pre-FT')
axes[0].plot(history['step'], history['train_loss'], 'b-', label='Train')
axes[0].plot(history['step'], history['val_loss'], 'r-', label='Val')
axes[0].set_xlabel('Step'); axes[0].set_ylabel('Loss')
axes[0].legend(); axes[0].grid(True, alpha=0.3)

# PPL: validation perplexity against the pre-fine-tuning value
axes[1].axhline(y=pre_ppl, color='gray', linestyle='--', label='Pre-FT')
axes[1].plot(history['step'], history['val_ppl'], 'g-', label='Val')
axes[1].set_xlabel('Step'); axes[1].set_ylabel('Perplexity')
axes[1].legend(); axes[1].grid(True, alpha=0.3)

# Gap: val - train loss; the zero line marks "no generalization gap"
gaps = [v - t for v, t in zip(history['val_loss'], history['train_loss'])]
axes[2].axhline(y=0, color='gray', linewidth=0.8)
axes[2].fill_between(history['step'], gaps, alpha=0.5, color='orange')
axes[2].set_xlabel('Step'); axes[2].set_ylabel('Gap')
axes[2].grid(True, alpha=0.3)

plt.suptitle(f'Fine-tuning ({STRATEGY})', fontsize=14, y=1.02)
plt.tight_layout()
# y=1.02 puts the title above the figure's top edge; bbox_inches='tight'
# enlarges the saved image so the title is not clipped.
plt.savefig(ckpt_dir / 'finetune_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## Test Generation

In [None]:
def generate(model, prompt, max_tokens=150, temperature=0.8):
    """Sample from *model* starting at *prompt* and return the result as text.

    The prompt is UTF-8 encoded into a batch of one byte-id sequence and handed
    to ``model.generate`` (with ``top_k=40``). The first returned sequence is
    decoded as UTF-8 with invalid bytes replaced. Presumably that sequence is the
    prompt followed by the new bytes -- depends on ``model.generate``. The model
    is switched to eval mode and left there.
    """
    prompt_ids = list(prompt.encode('utf-8'))
    idx = torch.tensor([prompt_ids], device=device, dtype=torch.long)
    model.eval()
    with torch.no_grad():
        out = model.generate(idx, max_new_tokens=max_tokens, temperature=temperature, top_k=40)
    sequence = out[0].tolist()
    return bytes(sequence).decode('utf-8', errors='replace')

# A couple of short, domain-neutral prompts to eyeball the fine-tuned model
prompts = ['The ', 'In the ']

print('Generation samples:')
for prompt_text in prompts:
    print(f'\nPrompt: "{prompt_text}"')
    print('-' * 40)
    print(generate(model, prompt_text))

## Download Fine-tuned Model

In [None]:
import shutil

# Record the fine-tuning setup and headline results next to the checkpoint
with open(ckpt_dir / 'finetune_config.json', 'w') as f:
    json.dump({
        'strategy': STRATEGY,
        'trainable_params': trainable,
        'total_params': total,
        'pre_ppl': pre_ppl,
        'post_ppl': post_ppl,
        'improvement_pct': improvement,
        **config,
    }, f, indent=2)

# Per-validation-step metrics collected during training
with open(ckpt_dir / 'training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

# Zip the checkpoint directory with the standard library instead of shelling
# out to the `zip` CLI, which is not installed on every system. The archive
# entries live under the checkpoint directory's name, as with `zip -r`.
archive_base = f'finetuned_{STRATEGY}'
zip_name = f'{archive_base}.zip'
shutil.make_archive(archive_base, 'zip', root_dir='.', base_dir=str(ckpt_dir))

try:
    from google.colab import files
except ImportError:
    print(f'Saved: {zip_name}')
else:
    files.download(zip_name)
    print('✅ Download started!')