# BDH Curriculum Training

This notebook trains BDH through a graduated curriculum:

| Phase | Content | Target Examples | Training Ratio |
|-------|---------|-----------------|----------------|
| 1 | Primitives (simple functions) | 10,000 | 1.5x |
| 2 | Composition (loops, recursion) | 20,000 | 1.2x |
| 3 | Algorithms (sorting, searching) | 30,000 | 1.0x |
| 4 | Systems (OOP, patterns) | 15,000 | 1.5x |
| 5 | Language (WikiText-2) | ~50,000 | 0.8x |

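**Training Ratios** ensure balanced learning: each phase trains for `base_epochs × ratio` epochs (rounded down), so a phase with a higher ratio gets more passes over each of its examples. Smaller phases are generally given the higher ratios.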

## 0. Setup Environment

In [None]:
# Clone repo and install dependencies
!git clone https://github.com/newsbubbles/bdh.git 2>/dev/null || echo 'Repo exists'
%cd bdh
!pip install -q torch transformers datasets tqdm

In [None]:
# Report the PyTorch build and, when CUDA is usable, the first GPU's name and memory
import torch

cuda_ok = torch.cuda.is_available()
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {cuda_ok}')
if cuda_ok:
    # Device index 0 is the only GPU inspected here
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {gpu_props.total_memory / 1e9:.1f} GB')

## 1. Generate Curriculum Data

This uses the proper generators in `scripts/data_pipeline/` to create:
- **Phase 1-4**: Synthetically generated Python code with variations
- **Phase 5**: WikiText-2 natural language

Use `--quick` for testing, remove it for full training.

In [None]:
# Generate curriculum data by running the repo's generator script.
# --quick is meant for testing; remove the flag for full training data.
!python scripts/generate_full_curriculum.py --quick

In [None]:
# Verify data was generated by invoking the same generator script with --stats
# (presumably prints a summary of the generated files - confirm against the script)
!python scripts/generate_full_curriculum.py --stats

## 2. Load Model and Tokenizer

In [None]:
import sys
# Put the relative 'src' directory first on the import path so the bare
# `model` and `tokenizer` imports below resolve; this must run before them.
sys.path.insert(0, 'src')

from model import BDHModel
from tokenizer import BDHTokenizer

# Initialize tokenizer (default constructor arguments) and report its vocabulary size
tokenizer = BDHTokenizer()
print(f'Vocab size: {tokenizer.vocab_size}')

In [None]:
import torch

from config import BDHConfig

# Model config - adjust based on GPU memory
config = BDHConfig(
    vocab_size=tokenizer.vocab_size,  # must match the tokenizer created above
    d_model=512,
    n_heads=8,
    n_layers=8,
    d_ff=2048,
    max_seq_len=512,
    dropout=0.1,
)

# Use the GPU when one is available; otherwise stay on the CPU instead of
# crashing in .cuda() on CUDA-less runtimes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BDHModel(config).to(device)

# Count parameters
params = sum(p.numel() for p in model.parameters())
print(f'Model parameters: {params:,} ({params/1e6:.1f}M)')

## 3. Curriculum Training Loop

Training proceeds through phases with:
- **Balanced epochs**: Smaller phases get more epochs (via training ratio)
- **LR warmup per phase**: Fresh warmup when switching phases
- **Checkpoint saving**: Best model saved per phase

In [None]:
import json
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm.auto import tqdm

class CurriculumDataset(Dataset):
    """Map-style dataset over a curriculum JSONL file for next-token training.

    Every non-blank line of the file is parsed as a JSON object. An example's
    text is taken from its ``code`` field, falling back to ``text``; records
    where both are missing or empty are dropped.

    Args:
        jsonl_path: Path of the JSONL file to read (decoded as UTF-8).
        tokenizer: Object exposing ``encode(text)`` (a sequence of token ids)
            and a ``pad_token_id`` attribute.
        max_len: Number of tokens each example is truncated or padded to.
    """

    def __init__(self, jsonl_path, tokenizer, max_len=512):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.examples = []

        # Explicit UTF-8: the platform default encoding is locale-dependent and
        # may fail on, or mis-decode, non-ASCII text.
        with open(jsonl_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not records and
                    # would make json.loads raise.
                    continue
                data = json.loads(line)
                # Handle both 'code' and 'text' fields
                text = data.get('code') or data.get('text', '')
                if text:
                    self.examples.append(text)

    def __len__(self):
        """Return the number of usable examples loaded from the file."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` for example ``idx``.

        The text is tokenized, truncated to ``max_len`` tokens and right-padded
        with ``pad_token_id``. The input is that sequence without its last
        token and the target is the same sequence without its first token, so
        both tensors are ``max_len - 1`` long.
        """
        tokens = list(self.tokenizer.encode(self.examples[idx])[:self.max_len])

        # Pad to max_len (multiplying by a non-positive count adds nothing)
        tokens += [self.tokenizer.pad_token_id] * (self.max_len - len(tokens))

        tokens = torch.tensor(tokens, dtype=torch.long)
        return tokens[:-1], tokens[1:]  # input, target

In [None]:
# Hyperparameters shared by every phase
TRAINING_CONFIG = {
    'batch_size': 16,
    'base_lr': 1e-4,
    'warmup_steps': 100,
    'weight_decay': 0.01,
    'max_grad_norm': 1.0,
    # Per-phase epoch count is base_epochs * that phase's ratio
    'base_epochs': 5,
}

# Curriculum order: one entry per phase with its JSONL file and epoch multiplier
PHASES = [
    {'name': 'Phase 1: Primitives',
     'file': 'data/curriculum/phase1_primitives/phase1_primitives.jsonl',
     'ratio': 1.5},
    {'name': 'Phase 2: Composition',
     'file': 'data/curriculum/phase2_composition/phase2_composition.jsonl',
     'ratio': 1.2},
    {'name': 'Phase 3: Algorithms',
     'file': 'data/curriculum/phase3_algorithms/phase3_algorithms.jsonl',
     'ratio': 1.0},
    {'name': 'Phase 4: Systems',
     'file': 'data/curriculum/phase4_systems/phase4_systems.jsonl',
     'ratio': 1.5},
    {'name': 'Phase 5: Language',
     'file': 'data/curriculum/phase5_language/phase5_train.jsonl',
     'ratio': 0.8},
]

In [None]:
def _next_token_loss(model, inputs, targets, pad_token_id):
    """Return the cross-entropy of ``model(inputs)`` against ``targets``.

    Positions whose target equals ``pad_token_id`` are excluded from the loss.
    """
    logits = model(inputs)
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        ignore_index=pad_token_id,
    )


def _mean_eval_loss(model, loader, pad_token_id, device):
    """Return the mean per-batch loss over ``loader`` in eval mode, without gradients."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            loss = _next_token_loss(
                model, inputs.to(device), targets.to(device), pad_token_id
            )
            running += loss.item()
    return running / len(loader)


def train_phase(model, phase_config, tokenizer, training_config, device=None):
    """Train model on a single curriculum phase.

    Args:
        model: Module called as ``model(inputs)`` to produce logits; it is
            updated in place. Its ``config`` attribute is stored in checkpoints.
        phase_config: Dict with ``'name'`` (label, also used for the checkpoint
            file name), ``'file'`` (JSONL path) and ``'ratio'`` (epoch multiplier).
        tokenizer: Passed to ``CurriculumDataset``; its ``pad_token_id`` marks
            target positions that are excluded from the loss.
        training_config: Dict with ``'batch_size'``, ``'base_lr'``,
            ``'warmup_steps'``, ``'weight_decay'``, ``'max_grad_norm'`` and
            ``'base_epochs'``.
        device: Device the batches are moved to. ``None`` (the default) uses
            the device of the model's parameters.

    Returns:
        A dict with ``'train_loss'``, ``'val_loss'`` and ``'lr'`` lists holding
        one entry per epoch, or ``{}`` when the phase is skipped because its
        file is missing or holds fewer than two examples.
    """
    phase_name = phase_config['name']
    jsonl_path = Path(phase_config['file'])
    ratio = phase_config['ratio']

    if not jsonl_path.exists():
        print(f'WARNING: {jsonl_path} not found, skipping phase')
        return {}

    # Follow the model's placement unless the caller names a device explicitly.
    if device is None:
        device = next(model.parameters()).device
    device = torch.device(device)

    # Calculate epochs for this phase (fractional results are truncated)
    epochs = int(training_config['base_epochs'] * ratio)

    print('\n' + '='*60)
    print(f'{phase_name}')
    print('='*60)

    # Load dataset
    dataset = CurriculumDataset(jsonl_path, tokenizer)
    print(f'Loaded {len(dataset):,} examples')

    # The 90/10 split needs at least one example on each side; with fewer,
    # random_split / the shuffling DataLoader would raise.
    if len(dataset) < 2:
        print(f'WARNING: {jsonl_path} has fewer than 2 examples, skipping phase')
        return {}

    print(f'Training for {epochs} epochs (ratio: {ratio}x)')

    # Split train/val (90/10)
    val_size = max(1, len(dataset) // 10)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )

    # DataLoaders (num_workers=0 for Colab compatibility).
    # Pinned host memory is only requested when batches go to a CUDA device.
    pin_memory = device.type == 'cuda'
    batch_size = training_config['batch_size']
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=pin_memory,
    )

    # Optimizer with fresh state for each phase
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=training_config['base_lr'],
        weight_decay=training_config['weight_decay'],
    )

    # LR scheduler with linear warmup, capped at a fifth of the phase's steps
    total_steps = len(train_loader) * epochs
    warmup_steps = min(training_config['warmup_steps'], total_steps // 5)

    def lr_lambda(step):
        # step + 1 so the very first update already uses a non-zero learning
        # rate; the multiplier reaches 1.0 on the last warmup step.
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        return 1.0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    pad_token_id = tokenizer.pad_token_id
    # One checkpoint file per phase, overwritten whenever validation improves
    checkpoint_path = f'checkpoints/best_{phase_name.lower().replace(" ", "_").replace(":", "")}.pt'

    # Training loop
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'lr': []}

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        for inputs, targets in pbar:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = _next_token_loss(model, inputs, targets, pad_token_id)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                training_config['max_grad_norm'],
            )
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'lr': f'{scheduler.get_last_lr()[0]:.2e}'
            })

        avg_train_loss = total_loss / len(train_loader)

        # Validation
        avg_val_loss = _mean_eval_loss(model, val_loader, pad_token_id, device)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['lr'].append(scheduler.get_last_lr()[0])

        print(f'Epoch {epoch+1}: train_loss={avg_train_loss:.4f}, val_loss={avg_val_loss:.4f}')

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            Path('checkpoints').mkdir(exist_ok=True)
            torch.save({
                'model_state_dict': model.state_dict(),
                'config': model.config.__dict__,
                'phase': phase_name,
                'val_loss': best_val_loss,
            }, checkpoint_path)
            print(f'  → New best model saved!')

    return history

In [None]:
# Train on every phase in curriculum order, keeping each phase's history under its name
print('#' * 60, 'STARTING CURRICULUM TRAINING', '#' * 60, sep='\n')

all_history = {}
for phase in PHASES:
    # The same model object is passed to every phase
    history = train_phase(model, phase, tokenizer, TRAINING_CONFIG)
    all_history[phase['name']] = history

print('\n' + '#' * 60, 'CURRICULUM TRAINING COMPLETE', '#' * 60, sep='\n')

## 4. Save Final Model

In [None]:
# Save final model
from pathlib import Path

final_path = 'checkpoints/bdh_curriculum_final.pt'
# Nothing in this notebook creates checkpoints/ unless a phase saved a best
# checkpoint, so create it here to keep the final save from failing on a
# missing directory.
Path(final_path).parent.mkdir(parents=True, exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    'config': model.config.__dict__,
    'training_history': all_history,  # per-phase histories keyed by phase name
}, final_path)
print(f'Final model saved to {final_path}')

## 5. Test Generation

In [None]:
def generate(model, tokenizer, prompt, max_tokens=100, temperature=0.8):
    """Generate text from prompt, one token at a time.

    Args:
        model: Module called as ``model(tokens)`` on a ``(1, seq)`` tensor of
            token ids; the logits at the last position pick the next token.
        tokenizer: Object exposing ``encode``, ``decode`` and ``eos_token_id``.
        prompt: Text the generation starts from.
        max_tokens: Upper bound on the number of tokens appended.
        temperature: Divisor applied to the logits before sampling. Values
            ``<= 0`` select the highest-scoring token instead of sampling.

    Returns:
        The decoded prompt plus generated tokens (including the end-of-sequence
        token if one was produced).
    """
    model.eval()

    # Run on whatever device the model lives on rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    tokens = torch.tensor(
        [tokenizer.encode(prompt)], dtype=torch.long, device=device
    )

    # NOTE(review): the whole running sequence is fed back each step; if the
    # model has a maximum input length, long prompts may exceed it - confirm.
    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model(tokens)[:, -1, :]
            if temperature > 0:
                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)
            else:
                # Dividing by a zero/negative temperature is meaningless;
                # treat it as greedy decoding.
                next_token = logits.argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)

            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(tokens[0].tolist())

# Sample one completion per prompt: three code-style prompts and one plain-English prompt
prompts = ['def add(', 'def fibonacci(', 'class Stack:', 'The quick brown']

for prompt in prompts:
    # !r shows the prompt with quotes, exactly as repr() would
    print(f'\n--- Prompt: {prompt!r} ---')
    output = generate(model, tokenizer, prompt)
    print(output)