# Svend Reasoning Specialist Training

**Complete pipeline: Generate tool traces → Train reasoning specialist**

This notebook:
1. Generates synthetic tool-calling training data (10K+ examples)
2. Fine-tunes the language model for reasoning and tool use on those traces plus GSM8K math word problems

**Setup:**
1. Add `ANTHROPIC_API_KEY` to Colab Secrets (key icon in left sidebar)
2. Run all cells

**Prerequisites:** 
- Trained language model checkpoint in Drive
- ~$30 API budget for 10K tool traces

## 1. Setup and API Key

In [None]:
# Mount Drive and load API key
# Drive provides the base checkpoint, the generated traces and the output directory
# (all paths are configured in the Config cell).
from google.colab import drive, userdata
drive.mount('/content/drive')

import os
# Read the Anthropic key from Colab Secrets and export it via os.environ so the
# trace-generation subprocess (which is launched with a copy of os.environ) can use it.
# Only the key's length is printed, never its value. On failure the key is set to
# None and the notebook keeps going; in this notebook the key is only consumed by
# that generation step.
try:
    ANTHROPIC_API_KEY = userdata.get('ANTHROPIC_API_KEY')
    os.environ['ANTHROPIC_API_KEY'] = ANTHROPIC_API_KEY
    print(f'Loaded ANTHROPIC_API_KEY (length: {len(ANTHROPIC_API_KEY)})')
except Exception as e:
    ANTHROPIC_API_KEY = None
    print(f'ERROR: {e}')
    print('Add ANTHROPIC_API_KEY to Colab Secrets')

# NOTE(review): package versions are unpinned, so re-runs may pick up different releases.
!pip install -q anthropic transformers accelerate datasets wandb

# Re-clone the project repo from scratch on every run (any local edits under
# /content/svend are discarded) and work from inside it.
os.chdir('/content')
!rm -rf svend
!git clone https://github.com/ewolters/svend.git
os.chdir('/content/svend')

import sys
# Make the repo's packages (e.g. `src.models`) importable in later cells.
sys.path.insert(0, '/content/svend')

In [None]:
#@title Config { display-mode: "form" }
#@markdown ### Paths
DRIVE_BASE = '/content/drive/MyDrive/svend-checkpoints'  #@param {type:"string"}
BASE_CHECKPOINT = f'{DRIVE_BASE}/language-model/checkpoint-200000.pt'  #@param {type:"string"}
OUTPUT_DIR = f'{DRIVE_BASE}/reasoning-specialist'  #@param {type:"string"}

#@markdown ### Data Generation
GENERATE_TOOL_TRACES = True  #@param {type:"boolean"}
NUM_TOOL_TRACES = 10000  #@param {type:"integer"}
# Not a form field: traces are stored as JSONL (one JSON object per line) under DRIVE_BASE.
TOOL_TRACES_FILE = f'{DRIVE_BASE}/data/tool_traces.jsonl'

#@markdown ### Training
# MAX_STEPS counts micro-batches of BATCH_SIZE sequences; the optimizer updates once
# every GRAD_ACCUM micro-batches (effective batch = BATCH_SIZE * GRAD_ACCUM sequences).
# Every example is padded/truncated to exactly MAX_SEQ_LENGTH tokens.
MAX_STEPS = 20000  #@param {type:"integer"}
BATCH_SIZE = 4  #@param {type:"integer"}
GRAD_ACCUM = 8  #@param {type:"integer"}
LEARNING_RATE = 5e-5  #@param {type:"number"}
MAX_SEQ_LENGTH = 2048  #@param {type:"integer"}

#@markdown ### Logging
USE_WANDB = False  #@param {type:"boolean"}

from pathlib import Path
# Create the output and trace directories up front (no-op if they already exist).
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
Path(TOOL_TRACES_FILE).parent.mkdir(parents=True, exist_ok=True)

print(f'Base checkpoint: {BASE_CHECKPOINT}')
print(f'Output dir: {OUTPUT_DIR}')
print(f'Generate {NUM_TOOL_TRACES} tool traces: {GENERATE_TOOL_TRACES}')

In [None]:
## 2. Generate Tool Traces (if enabled)

In [None]:
# Generate tool-calling traces by running the repo's generator script, unless
# generation is disabled or the trace file already holds enough lines.
import os
import subprocess
import sys
from pathlib import Path

if GENERATE_TOOL_TRACES:
    # Check if we already have enough traces (one trace per line)
    existing_count = 0
    if Path(TOOL_TRACES_FILE).exists():
        with open(TOOL_TRACES_FILE) as f:
            existing_count = sum(1 for _ in f)
        print(f'Found {existing_count} existing traces')

    if existing_count >= NUM_TOOL_TRACES:
        print(f'Skipping generation - already have {existing_count} traces')
    elif not os.environ.get('ANTHROPIC_API_KEY'):
        # The Setup cell could not load the key; the API-backed script cannot work without it.
        print('Skipping generation - ANTHROPIC_API_KEY is not set (add it to Colab Secrets and re-run Setup)')
    else:
        # NOTE(review): whether the script appends to or overwrites a partial file is not
        # visible here; it is asked for the full NUM_TOOL_TRACES either way.
        print(f'\n=== Generating {NUM_TOOL_TRACES} tool traces ===')
        print('This may take 30-60 minutes and cost ~$30 in API calls')
        print('='*60)

        # sys.executable runs the script with the same interpreter as this kernel.
        # The context manager closes the pipe and waits for the child on exit.
        with subprocess.Popen(
            [sys.executable, 'scripts/generate_tool_data.py',
             '--num-examples', str(NUM_TOOL_TRACES),
             '--output', TOOL_TRACES_FILE],
            cwd='/content/svend',
            env=os.environ.copy(),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr so errors show up inline
            text=True,
            bufsize=1,  # line-buffered so progress streams into the cell as it is produced
        ) as process:
            for line in process.stdout:
                print(line, end='')

        print('='*60)
        print(f'Exit code: {process.returncode}')
        if process.returncode != 0:
            print('WARNING: trace generation failed; data loading will use whatever traces are already on disk')
else:
    print('Tool trace generation disabled')

# Count final traces
if Path(TOOL_TRACES_FILE).exists():
    with open(TOOL_TRACES_FILE) as f:
        final_count = sum(1 for _ in f)
    print(f'\nTool traces available: {final_count}')

## 3. Setup PyTorch

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import json
import math
import random
from datetime import datetime
from tqdm.auto import tqdm

# Seed torch's RNG so DataLoader shuffling and sampled generations are repeatable.
# (Python's `random` is seeded separately right before the training data is shuffled.)
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'PyTorch: {torch.__version__}')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name()}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

# Mixed-precision dtype: bf16 when the GPU supports it, otherwise fp16.
# Only query CUDA capabilities on a CUDA device: some torch releases raise from
# torch.cuda.is_bf16_supported() when no GPU is present. The training loop only
# autocasts when a GradScaler exists, i.e. on CUDA.
use_bf16 = device.type == 'cuda' and torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16
print(f'Training dtype: {dtype}')

## 4. Load Base Language Model

In [None]:
from src.models.config import TransformerConfig
from src.models.transformer import ReasoningTransformer

# Load on CPU first; the model is moved to `device` after its weights are restored.
# weights_only=False is needed because the 'config' entry may be a pickled config
# object rather than plain tensors/dicts. Only load checkpoints you trust.
print(f'Loading base model from: {BASE_CHECKPOINT}')
checkpoint = torch.load(BASE_CHECKPOINT, map_location='cpu', weights_only=False)

# The stored config is either a plain dict (rebuilt into a TransformerConfig, minus
# 'head_dim', which the constructor is presumably not meant to receive - TODO confirm)
# or an already-built config object used as-is.
config_dict = checkpoint.get('config', {})
if not isinstance(config_dict, dict):
    base_config = config_dict
else:
    config_dict = {key: value for key, value in config_dict.items() if key != 'head_dim'}
    base_config = TransformerConfig(**config_dict)

print(f'Base model: {base_config.hidden_size}h x {base_config.num_hidden_layers}L')
print(f'Parameters: ~{base_config.num_parameters() / 1e6:.0f}M')

In [None]:
# Instantiate the architecture from the checkpoint config, restore the pretrained
# weights, then move the model onto the training device.
model = ReasoningTransformer(base_config)
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)
model = model.to(device)

n_params = 0
for param in model.parameters():
    n_params += param.numel()
print(f'Model loaded: {n_params/1e6:.1f}M parameters')

In [None]:
# Setup tokenizer: reuse the tokenizer recorded in the base checkpoint (GPT-2 if
# none was recorded) and pad with EOS, since the GPT-2 family ships without a pad token.
tokenizer_name = checkpoint.get('tokenizer_name', 'gpt2')
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer.pad_token = tokenizer.eos_token

# Paired open/close markers for reasoning structure and tool calling:
# <|think|> <|/think|>, <|step|> <|/step|>, ... in this exact order.
REASONING_TAGS = ['think', 'step', 'tool_call', 'tool_result', 'answer', 'verify']
special_tokens = {
    'additional_special_tokens': [
        marker
        for tag in REASONING_TAGS
        for marker in (f'<|{tag}|>', f'<|/{tag}|>')
    ]
}

num_added = tokenizer.add_special_tokens(special_tokens)
print(f'Added {num_added} special tokens')
print(f'Vocab size: {len(tokenizer)}')

In [None]:
# Resize embeddings if vocab changed.
# The special tokens added above extend the tokenizer, so the input embedding (and the
# output head, when the model has one) need extra rows. New rows are initialised to the
# mean of the existing rows rather than randomly.
old_vocab_size = model.embed_tokens.weight.shape[0]
new_vocab_size = len(tokenizer)
if new_vocab_size > old_vocab_size:
    old_embeddings = model.embed_tokens.weight.data
    lm_head = getattr(model, 'lm_head', None)
    # If the output head shares its weight with the input embedding (weight tying),
    # rebuilding the two separately would silently untie them, so detect that first.
    heads_tied = lm_head is not None and lm_head.weight is model.embed_tokens.weight

    # Build replacements on the same device/dtype as the current weights.
    new_embeddings = nn.Embedding(
        new_vocab_size, base_config.hidden_size,
        device=old_embeddings.device, dtype=old_embeddings.dtype,
    )
    new_embeddings.weight.data[:old_vocab_size] = old_embeddings
    new_embeddings.weight.data[old_vocab_size:] = old_embeddings.mean(dim=0)
    model.embed_tokens = new_embeddings

    if heads_tied:
        # Re-tie: point the head at the resized embedding matrix.
        lm_head.weight = model.embed_tokens.weight
    elif lm_head is not None:
        old_lm_head = lm_head.weight.data
        old_head_rows = old_lm_head.shape[0]
        new_lm_head = nn.Linear(
            base_config.hidden_size, new_vocab_size, bias=False,
            device=old_lm_head.device, dtype=old_lm_head.dtype,
        )
        new_lm_head.weight.data[:old_head_rows] = old_lm_head
        new_lm_head.weight.data[old_head_rows:] = old_lm_head.mean(dim=0)
        model.lm_head = new_lm_head

    # Keep the config in sync with the resized weights. `config_dict` is what gets
    # saved with every checkpoint; with a stale vocab_size, a model rebuilt from it
    # could not load the saved state dict.
    if isinstance(config_dict, dict) and 'vocab_size' in config_dict:
        config_dict['vocab_size'] = new_vocab_size
    if hasattr(base_config, 'vocab_size'):
        base_config.vocab_size = new_vocab_size

    print(f'Resized embeddings: {old_vocab_size} -> {new_vocab_size}')

## 5. Load Training Data

In [None]:
def format_tool_trace(data):
    """Render one tool-trace record as a single training string.

    Layout: ``Question: ...`` followed by a ``<|think|>`` block holding one
    ``[SPECIALIST] content`` line per reasoning step. Each step's optional
    ``<|tool_call|>name(json-args)<|/tool_call|>`` and
    ``<|tool_result|>...<|/tool_result|>`` lines come right after it.
    The string ends with ``<|answer|>...<|/answer|>``.

    Args:
        data: dict with a required ``'question'`` key and optional ``'reasoning'``
            (list of step dicts) and ``'answer'`` keys.

    Returns:
        The formatted training text.

    Raises:
        KeyError: if ``'question'`` is missing.
    """
    pieces = [f"Question: {data['question']}\n\n<|think|>\n"]

    for step in data.get('reasoning', []):
        label = step.get('specialist', 'reasoning').upper()
        pieces.append(f"[{label}] {step.get('content', '')}\n")

        if 'tool_call' in step:
            call = step['tool_call']
            call_args = json.dumps(call.get('args', {}))
            pieces.append(f"<|tool_call|>{call.get('name', '')}({call_args})<|/tool_call|>\n")

        if 'tool_result' in step:
            pieces.append(f"<|tool_result|>{step['tool_result']}<|/tool_result|>\n")

    pieces.append("<|/think|>\n\n")
    pieces.append(f"<|answer|>{data.get('answer', '')}<|/answer|>")
    return ''.join(pieces)

# Load tool traces (JSONL: one trace object per line).
# A generation run that was interrupted or failed can leave a truncated last line or
# records with missing/null fields. Those lines are skipped and counted rather than
# aborting the whole load, which would throw away every other trace.
all_examples = []

if Path(TOOL_TRACES_FILE).exists():
    skipped_lines = 0
    first_error = None
    with open(TOOL_TRACES_FILE, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                text = format_tool_trace(json.loads(line))
            except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
                skipped_lines += 1
                if first_error is None:
                    first_error = f'line {line_no}: {e!r}'
                continue
            all_examples.append({'text': text, 'source': 'tool_trace'})
    print(f'Loaded {len(all_examples)} tool traces')
    if skipped_lines:
        print(f'WARNING: skipped {skipped_lines} malformed trace lines (first: {first_error})')
else:
    print(f'WARNING: No tool traces found at {TOOL_TRACES_FILE}')

# Also load GSM8K for math reasoning. Answers are "<worked solution> #### <final answer>";
# the solution goes in the think block and the final answer in the answer block.
try:
    from datasets import load_dataset
    gsm8k = load_dataset('openai/gsm8k', 'main', split='train')
    for ex in gsm8k:
        question = ex['question']
        answer = ex['answer']
        if '####' in answer:
            reasoning, final = answer.rsplit('####', 1)
        else:
            reasoning, final = answer, ''

        text = f"Question: {question}\n\n<|think|>\n{reasoning.strip()}\n<|/think|>\n\n<|answer|>{final.strip()}<|/answer|>"
        all_examples.append({'text': text, 'source': 'gsm8k'})
    print(f'Loaded {len(gsm8k)} GSM8K examples')
except Exception as e:
    print(f'GSM8K load failed: {e}')

print(f'\nTotal training examples: {len(all_examples)}')
random.seed(42)
random.shuffle(all_examples)

In [None]:
# Preview one example from each data source so every prompt format can be checked
# by eye. The shuffled head of all_examples may contain only one source.
print('Sample training examples (one per source):')
print('='*60)
shown_sources = set()
for ex in all_examples:
    if ex['source'] in shown_sources:
        continue
    shown_sources.add(ex['source'])
    print(f"[{ex['source']}]")
    print(ex['text'][:500])
    print('...\n' + '-'*60)

In [None]:
## 6. Create Dataset

In [None]:
class ReasoningDataset(Dataset):
    """Fixed-length causal-LM examples tokenised from ``{'text': ...}`` dicts.

    Every item is padded/truncated to ``max_length`` tokens. Labels are a copy of the
    input ids with padding positions set to -100 so they can be ignored by the loss
    (assumes the model's loss uses ignore_index=-100, PyTorch's cross-entropy
    default - TODO confirm).

    Args:
        examples: list of dicts, each with a ``'text'`` key.
        tokenizer: Hugging Face-style tokenizer called with padding/truncation kwargs.
        max_length: length every item is padded/truncated to.
        append_eos: if True (default), append ``tokenizer.eos_token`` to each text.
            This teaches the model to emit EOS after ``<|/answer|>`` so generation can
            stop. Without it, no training sequence ever ends in EOS.
    """

    def __init__(self, examples, tokenizer, max_length=2048, append_eos=True):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length
        # pad_token == eos_token in this notebook, but the appended EOS is a real token
        # (attention_mask == 1), so it survives the -100 padding mask and is trained on.
        eos = getattr(tokenizer, 'eos_token', None) if append_eos else None
        self.eos_suffix = eos or ''

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        text = self.examples[idx]['text'] + self.eos_suffix

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # The tokenizer returns a batch of one; drop that leading dimension.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

# Wrap all examples; each batch holds up to BATCH_SIZE sequences of MAX_SEQ_LENGTH tokens.
dataset = ReasoningDataset(all_examples, tokenizer, MAX_SEQ_LENGTH)
# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only runtime
# PyTorch warns about it, so enable it only when training on CUDA.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=device.type == 'cuda',
)

print(f'Dataset: {len(dataset)} examples')
print(f'Batches: {len(dataloader)}')

## 7. Training Setup

In [None]:
from transformers import get_cosine_schedule_with_warmup

# The training loop counts *micro-batches* in `step`, but it only calls
# scheduler.step() once every GRAD_ACCUM micro-batches. The schedule must therefore
# be sized in optimizer updates. Sizing it with MAX_STEPS would leave the cosine decay
# only 1/GRAD_ACCUM of the way through when training ends. The 500-step warmup is
# expressed in the same micro-batch units as MAX_STEPS and converted likewise.
WARMUP_MICRO_STEPS = 500
num_optimizer_steps = max(1, MAX_STEPS // GRAD_ACCUM)
num_warmup_steps = WARMUP_MICRO_STEPS // GRAD_ACCUM

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_optimizer_steps,
)
# Loss scaling for mixed precision; only used on CUDA (the loop checks `if scaler`).
scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None

print(f'Optimizer: AdamW, LR={LEARNING_RATE}')
print(f'Scheduler: Cosine, {num_warmup_steps} warmup / {num_optimizer_steps} optimizer steps')
print(f'Max steps: {MAX_STEPS} micro-batches ({GRAD_ACCUM} per optimizer step)')

In [None]:
# Optional Weights & Biases run, named by start time (month-day-hour-minute).
if USE_WANDB:
    import wandb
    run_name = 'reasoning-' + datetime.now().strftime("%m%d-%H%M")
    wandb.init(project='svend-reasoning', name=run_name)

In [None]:
def save_checkpoint(step, loss):
    """Save a training checkpoint to ``{OUTPUT_DIR}/checkpoint-{step}.pt``.

    Besides model and optimizer state, it stores scheduler and (if used) GradScaler
    state so training can be resumed faithfully. It also stores the config and the
    tokenizer name, vocab size and special tokens needed to rebuild model and tokenizer.

    Args:
        step: micro-batch step count, used in the filename and stored.
        loss: loss value recorded alongside the checkpoint.
    """
    path = f'{OUTPUT_DIR}/checkpoint-{step}.pt'
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
        'config': config_dict,
        # Record the tokenizer actually loaded (taken from the base checkpoint),
        # not a hard-coded guess.
        'tokenizer_name': tokenizer.name_or_path,
        'vocab_size': len(tokenizer),
        'special_tokens': special_tokens,
        'loss': loss,
    }, path)
    print(f'Saved: {path}')

## 8. Training Loop

In [None]:
print('='*60)
print('TRAINING REASONING SPECIALIST')
print('='*60)

# `step` counts micro-batches. Gradients are accumulated over GRAD_ACCUM micro-batches
# before each optimizer/scheduler update. The mean loss is logged every LOG_EVERY
# micro-batches and a checkpoint is written every SAVE_EVERY micro-batches.
LOG_EVERY = 50
SAVE_EVERY = 2000
# Checkpoints record the most recent logged average, so saves must land on log steps.
assert SAVE_EVERY % LOG_EVERY == 0

# With no examples the inner `for` never runs and `step` never advances, so the
# `while` below would spin forever. Fail fast instead.
if len(dataloader) == 0:
    raise RuntimeError('Training dataloader is empty - no tool traces or GSM8K examples were loaded')

model.train()
step = 0
epoch = 0
total_loss = 0
avg_loss = float('nan')  # most recent logged average (NaN until the first log step)
best_loss = float('inf')

progress = tqdm(total=MAX_STEPS, desc='Training')

try:
    while step < MAX_STEPS:
        epoch += 1

        for batch in dataloader:
            if step >= MAX_STEPS:
                break

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Divide by GRAD_ACCUM so the accumulated gradient is the mean over micro-batches.
            if scaler:
                with torch.amp.autocast('cuda', dtype=dtype):
                    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs['loss'] / GRAD_ACCUM
                scaler.scale(loss).backward()
            else:
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs['loss'] / GRAD_ACCUM
                loss.backward()

            total_loss += loss.item() * GRAD_ACCUM

            if (step + 1) % GRAD_ACCUM == 0:
                if scaler:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            step += 1
            progress.update(1)

            if step % LOG_EVERY == 0:
                avg_loss = total_loss / LOG_EVERY
                progress.set_postfix({
                    'loss': f'{avg_loss:.4f}',
                    'ppl': f'{math.exp(min(avg_loss, 10)):.1f}',
                    'lr': f'{scheduler.get_last_lr()[0]:.2e}'
                })

                if USE_WANDB:
                    wandb.log({'loss': avg_loss, 'lr': scheduler.get_last_lr()[0], 'step': step})

                total_loss = 0

            if step % SAVE_EVERY == 0:
                save_checkpoint(step, avg_loss)
                if avg_loss < best_loss:
                    best_loss = avg_loss

except KeyboardInterrupt:
    print('\nInterrupted!')
    # Mean loss over the micro-batches since the last log line. If we stopped exactly
    # on a log step, reuse that logged average instead of recording 0.
    pending = step % LOG_EVERY
    interrupt_loss = total_loss / pending if pending else avg_loss
    save_checkpoint(step, interrupt_loss)
finally:
    progress.close()

print(f'\nTraining complete at step {step}')

## 9. Save Final Model

In [None]:
# Save the final fine-tuned weights (no optimizer state) together with what is needed
# to rebuild the model and tokenizer for inference.
final_path = f'{OUTPUT_DIR}/final-reasoning-specialist.pt'

torch.save({
    'model_state_dict': model.state_dict(),
    'config': config_dict,
    # Record the tokenizer actually loaded (from the base checkpoint) instead of assuming 'gpt2'.
    'tokenizer_name': tokenizer.name_or_path,
    'vocab_size': len(tokenizer),
    'special_tokens': special_tokens,
    'training_steps': step,
    'base_model': BASE_CHECKPOINT,
}, final_path)

print(f'Final model saved: {final_path}')

## 10. Test Generation

In [None]:
@torch.no_grad()
def test_reasoning(prompt):
    """Sample a reasoning trace for ``prompt`` from the fine-tuned model.

    The prompt is wrapped in the same ``Question: ...`` / ``<|think|>`` prefix used
    for training examples. Up to 512 new tokens are sampled at temperature 0.3.
    Special tokens are kept in the decoded text so the think/tool/answer structure
    stays visible. Leaves the model in eval mode.

    Args:
        prompt: the question text.

    Returns:
        The decoded first sequence returned by ``model.generate``.
    """
    model.eval()
    full_prompt = f"Question: {prompt}\n\n<|think|>\n"
    input_ids = tokenizer.encode(full_prompt, return_tensors='pt').to(device)

    # Autocast only on CUDA. Asking for CUDA autocast on a CPU-only runtime just
    # warns and disables itself, so make that explicit instead.
    with torch.amp.autocast('cuda', dtype=dtype, enabled=device.type == 'cuda'):
        output = model.generate(
            input_ids,
            max_new_tokens=512,
            temperature=0.3,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(output[0], skip_special_tokens=False)

# Test problems: a calculus question, an arithmetic word problem and a geometry question.
test_problems = [
    "What is the derivative of x^3 + 2x?",
    "A store sells apples for $2 each. If I buy 5 apples, how much do I spend?",
    "Find the area of a circle with radius 5.",
]

banner = '='*60
print(banner)
print('TEST GENERATION')
print(banner)
for problem in test_problems:
    print(f'\n{problem}')
    print('-'*40)
    response = test_reasoning(problem)
    # Show at most the first 600 characters, with an ellipsis if anything was cut.
    was_cut = len(response) > 600
    print(response[:600])
    if was_cut:
        print('...')

In [None]:
# Close the W&B run (if one was started) and print a short run summary.
if USE_WANDB:
    wandb.finish()

rule = '=' * 60
print('\n' + rule)
print('REASONING SPECIALIST TRAINING COMPLETE')
print(rule)
print(f'Checkpoint: {final_path}')
print(f'Steps: {step:,}')
print(f'Best loss: {best_loss:.4f}')
print(rule)