# Svend Fine-Tuning: Reasoning + Tool Use

Fine-tune the base language model on synthetic reasoning data.

**Setup:** 
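1. Add `ANTHROPIC_API_KEY` to Colab Secrets (key icon in the left sidebar). The key is only needed for API-based data generation; training on the bundled seed files works without it.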
2. Run all cells

In [None]:
# Mount Drive, setup, and load API key from Colab Secrets
# Environment bootstrap for the rest of the notebook: mounts Google Drive,
# tries to read the API key, installs dependencies, fetches the svend repo and
# makes it importable, then lists the most recent base-model checkpoints.
from google.colab import drive, userdata
drive.mount('/content/drive')

import os
# The key is optional here: any failure while reading it is reported and
# ANTHROPIC_API_KEY is left as None instead of aborting the cell.
try:
    ANTHROPIC_API_KEY = userdata.get('ANTHROPIC_API_KEY')
    os.environ['ANTHROPIC_API_KEY'] = ANTHROPIC_API_KEY
    print('Loaded ANTHROPIC_API_KEY from Colab Secrets')
except Exception as e:
    ANTHROPIC_API_KEY = None
    print(f'Could not load from Secrets: {e}')
    print('Add ANTHROPIC_API_KEY to Colab Secrets (key icon in left sidebar)')

!pip install -q transformers accelerate anthropic

# Clone the repo into /content; if the clone fails (its stderr is discarded),
# fall back to pulling inside the existing ./svend directory.
os.chdir('/content')
!git clone https://github.com/ewolters/svend.git 2>/dev/null || (cd svend && git pull)
os.chdir('/content/svend')

# Put the repo root first on sys.path so later cells can `from src... import`.
import sys
sys.path.insert(0, '/content/svend')

# Show at most the last five *.pt files in the Drive checkpoint folder;
# errors (e.g. no matching files) are suppressed.
print('\n=== Available checkpoints ===')
!ls -lh /content/drive/MyDrive/svend-checkpoints/language-model/*.pt 2>/dev/null | tail -5

In [None]:
#@title Config { display-mode: "form" }
# Notebook-wide settings read by the later cells. The `#@title`, `#@markdown`
# and `#@param` annotations are Colab form markup; keep the assignment lines
# they decorate in their simple `NAME = value` shape.
#@markdown ### Paths
DRIVE_BASE = '/content/drive/MyDrive/svend-checkpoints'  #@param {type:"string"}
BASE_CHECKPOINT = f'{DRIVE_BASE}/language-model/checkpoint-200000.pt'  #@param {type:"string"}
OUTPUT_NAME = 'svend-reasoning-v1.pt'  #@param {type:"string"}
# Derived from DRIVE_BASE and OUTPUT_NAME; this is where the training loop saves.
OUTPUT_CHECKPOINT = f'{DRIVE_BASE}/language-model/{OUTPUT_NAME}'

# JSONL inputs consumed by ReasoningDataset; a file that does not exist is
# simply skipped by the dataset.
CONVERSATIONS_FILE = '/content/svend/data/conversations.jsonl'
TOOL_TRACES_FILE = '/content/svend/data/tool_traces.jsonl'

#@markdown ### Training params
# Effective batch size is BATCH_SIZE * GRAD_ACCUM (printed by the optimizer
# cell); MAX_LENGTH is the token length every example is padded/truncated to.
EPOCHS = 3  #@param {type:"integer"}
BATCH_SIZE = 4  #@param {type:"integer"}
GRAD_ACCUM = 8  #@param {type:"integer"}
LR = 2e-5  #@param {type:"number"}
MAX_LENGTH = 1024  #@param {type:"integer"}

#@markdown ### Data generation (set to 0 to skip)
# NOTE(review): no cell visible in this notebook reads these two values; the
# data cell only loads the seed files. Confirm whether API-based generation was
# moved elsewhere or these fields are stale.
EXPAND_SEEDS_PER = 50  #@param {type:"integer"}
TOOL_TRACES_COUNT = 500  #@param {type:"integer"}

# Sanity check: report whether the base checkpoint is reachable before the
# notebook gets as far as loading it.
from pathlib import Path
print('Checking checkpoint...')
print(f'  Base checkpoint: {"OK" if Path(BASE_CHECKPOINT).exists() else "NOT FOUND"} - {BASE_CHECKPOINT}')
print(f'  Output:          {OUTPUT_CHECKPOINT}')

In [None]:
# Build the training file from the seed examples shipped in the repo
# (no API needed, ~80 examples).
# This cell never calls the API and does not read EXPAND_SEEDS_PER; to get more
# data, run expand_seeds.py locally and place the result in CONVERSATIONS_FILE.

import json
from pathlib import Path

# Create the output directory in Python (instead of shelling out to mkdir) so
# it always matches wherever CONVERSATIONS_FILE points.
Path(CONVERSATIONS_FILE).parent.mkdir(parents=True, exist_ok=True)

# Load seeds directly as training data
seed_dir = Path('/content/svend/data/seeds')
examples = []

if seed_dir.exists():
    # Sorted so the example order is reproducible; glob order is arbitrary.
    for seed_file in sorted(seed_dir.glob('*.jsonl')):
        # Explicit UTF-8 to match how ReasoningDataset reads the file back.
        with open(seed_file, 'r', encoding='utf-8') as f:
            # Blank lines are skipped; every other line must be one JSON object.
            examples.extend(json.loads(line) for line in f if line.strip())
    print(f'Loaded {len(examples)} seed examples directly')

    # Write to conversations file (one JSON object per line)
    with open(CONVERSATIONS_FILE, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(ex) + '\n' for ex in examples)
    print(f'Wrote to {CONVERSATIONS_FILE}')
else:
    # Deliberately non-fatal: report the problem and how to recover. Any
    # existing CONVERSATIONS_FILE from an earlier run is left untouched.
    print(f'ERROR: Seeds not found at {seed_dir}')
    print('Run: cd /content && rm -rf svend && git clone https://github.com/ewolters/svend.git')

# Skip tool traces for now - train on seeds only
print(f'\nTraining will use {len(examples)} examples')
print('(For more data, run expand_seeds.py locally with your API key)')

In [None]:
import torch
import json
from pathlib import Path
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
# Later cells move the model and every batch onto `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

# On GPU, also report which card was allocated and its total memory in GB.
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name()}')
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f'Memory: {total_bytes / 1e9:.1f} GB')

In [None]:
# Load base model
from src.models.transformer import ReasoningTransformer
from src.models.config import TransformerConfig

# Restore the base model: read the checkpoint on CPU, rebuild its config,
# instantiate the network, load the weights, then move it to `device`.
print(f'Loading checkpoint from {BASE_CHECKPOINT}...')
# weights_only=False lets torch.load unpickle arbitrary Python objects, so only
# point BASE_CHECKPOINT at files you trust.
checkpoint = torch.load(BASE_CHECKPOINT, map_location='cpu', weights_only=False)

# The stored config is either a plain dict or an already-built config object.
# `config_dict` is kept because the training loop saves it with the new weights.
config_dict = checkpoint.get('config', {})
if not isinstance(config_dict, dict):
    config = config_dict
else:
    # 'head_dim' is dropped before construction -- presumably TransformerConfig
    # derives it rather than accepting it as an argument (confirm).
    config_dict = {
        key: value
        for key, value in config_dict.items()
        if key != 'head_dim'
    }
    config = TransformerConfig(**config_dict)

model = ReasoningTransformer(config)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

params = sum(weight.numel() for weight in model.parameters())
print(f'Model loaded: {params/1e6:.1f}M parameters')

In [None]:
# Load the 'gpt2' tokenizer. If it defines no pad token, reuse EOS so the
# dataset's padding='max_length' encoding has something to pad with.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

vocab_size = len(tokenizer)
print(f'Tokenizer vocab size: {vocab_size}')

In [None]:
class ReasoningDataset(Dataset):
    """Language-model fine-tuning examples built from two optional JSONL files.

    Each conversation record (keys ``prompt`` and ``response``) becomes
    ``"Q: <prompt>\\nA: <response>"``. Each tool-trace record (key
    ``question`` plus optional ``reasoning`` steps and ``answer``) becomes a
    multi-line transcript with ``<|tool_call|>`` / ``<|tool_result|>`` markers.
    Examples are stored as text and tokenized lazily in ``__getitem__``.

    Args:
        conversations_file: Path to the conversations JSONL; skipped if missing.
        tool_traces_file: Path to the tool-traces JSONL; skipped if missing.
        tokenizer: Callable used as ``tokenizer(text, truncation=True,
            max_length=..., padding='max_length', return_tensors='pt')`` that
            returns ``input_ids`` and ``attention_mask`` tensors with a leading
            batch dimension of 1.
        max_length: Length every example is truncated/padded to.

    Raises:
        ValueError: If neither file contributed any example.
        KeyError: If a record lacks ``prompt``/``response`` (conversations) or
            ``question`` (tool traces).

    NOTE(review): no end-of-sequence marker is appended to the example text, so
    unless the tokenizer adds one the model sees no explicit end-of-answer
    signal -- confirm this is intended.
    """

    def __init__(self, conversations_file, tool_traces_file, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = []

        if Path(conversations_file).exists():
            self.examples.extend(
                f"Q: {data['prompt']}\nA: {data['response']}"
                for data in self._read_jsonl(conversations_file)
            )
            print(f'Loaded {len(self.examples)} conversations')

        n_conv = len(self.examples)
        if Path(tool_traces_file).exists():
            self.examples.extend(
                self._format_tool_trace(data)
                for data in self._read_jsonl(tool_traces_file)
            )
            print(f'Loaded {len(self.examples) - n_conv} tool traces')

        print(f'Total examples: {len(self.examples)}')
        if len(self.examples) == 0:
            raise ValueError('No training data found!')

    @staticmethod
    def _read_jsonl(path):
        """Yield one parsed JSON object per non-blank line of ``path``."""
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                # Blank lines (e.g. a trailing newline at end of file) are not
                # records and would make json.loads raise.
                if line.strip():
                    yield json.loads(line)

    @staticmethod
    def _format_tool_trace(data):
        """Render one tool-trace record as the training transcript text."""
        parts = [f"Q: {data['question']}\n"]
        for step in data.get('reasoning', []):
            parts.append(f"Step {step.get('step', '?')}: {step.get('content', '')}\n")
            if 'tool_call' in step:
                tc = step['tool_call']
                parts.append(f"<|tool_call|>{tc.get('name', '')}({json.dumps(tc.get('args', {}))})<|/tool_call|>\n")
            if 'tool_result' in step:
                parts.append(f"<|tool_result|>{step['tool_result']}<|/tool_result|>\n")
        parts.append(f"A: {data.get('answer', '')}")
        return ''.join(parts)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return input_ids, attention_mask, labels.

        Labels are a copy of ``input_ids`` with every padded position (where
        ``attention_mask == 0``) set to -100.
        """
        encoded = self.tokenizer(self.examples[idx], truncation=True, max_length=self.max_length, padding='max_length', return_tensors='pt')
        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse the sequence dimension when max_length == 1.
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

In [None]:
# Build the training set from the two JSONL files, then wrap it in a shuffling
# loader that yields BATCH_SIZE examples at a time.
dataset = ReasoningDataset(
    CONVERSATIONS_FILE,
    TOOL_TRACES_FILE,
    tokenizer,
    max_length=MAX_LENGTH,
)
loader_options = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
dataloader = DataLoader(dataset, **loader_options)

In [None]:
from torch.optim.lr_scheduler import OneCycleLR

# AdamW over every model parameter, using the learning rate from the config cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)

# The scheduler is sized in optimizer updates: every micro-batch of the whole
# run divided by the accumulation factor, never less than one.
micro_batches = len(dataloader) * EPOCHS
total_steps = max(1, micro_batches // GRAD_ACCUM)

# One-cycle schedule peaking at LR, with the first 10% of steps spent warming up.
scheduler = OneCycleLR(optimizer, max_lr=LR, total_steps=total_steps, pct_start=0.1)

print(f'Training examples: {len(dataset)}')
print(f'Batches per epoch: {len(dataloader)}')
print(f'Total steps: {total_steps}')
print(f'Effective batch size: {BATCH_SIZE * GRAD_ACCUM}')

In [None]:
# Fine-tuning loop: mixed precision on CUDA, gradient accumulation over
# GRAD_ACCUM micro-batches, gradient clipping at norm 1.0, and a checkpoint
# whenever an epoch's average training loss improves on the best so far.
model.train()
scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None
global_step = 0  # optimizer updates performed so far
micro_step = 0   # micro-batches processed so far, counted across epochs
best_loss = float('inf')
batches_per_epoch = len(dataloader)

# Gradients accumulate across epoch boundaries so that no micro-batch is thrown
# away when batches_per_epoch is not a multiple of GRAD_ACCUM, and so that the
# number of updates matches total_steps = len(dataloader) * EPOCHS // GRAD_ACCUM.
optimizer.zero_grad()

for epoch in range(EPOCHS):
    epoch_loss = 0
    is_last_epoch = epoch == EPOCHS - 1

    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{EPOCHS}')
    for step, batch in enumerate(pbar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Scale each micro-batch loss by 1/GRAD_ACCUM so the accumulated
        # gradient is the average over the group, not the sum.
        if scaler:
            with torch.amp.autocast('cuda'):
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs['loss'] / GRAD_ACCUM  # dict access, not attribute
            scaler.scale(loss).backward()
        else:
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs['loss'] / GRAD_ACCUM
            loss.backward()

        batch_loss = loss.item() * GRAD_ACCUM  # undo the scaling for reporting
        epoch_loss += batch_loss
        micro_step += 1

        # Update every GRAD_ACCUM micro-batches, and once more on the very last
        # batch of training so a trailing partial group is applied before the
        # final checkpoint is written (its gradient stays scaled by 1/GRAD_ACCUM).
        is_final_batch = is_last_epoch and step == batches_per_epoch - 1
        if micro_step % GRAD_ACCUM == 0 or is_final_batch:
            if scaler:
                # Unscale first so clipping sees the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            # OneCycleLR raises when stepped past its total_steps; the final
            # flush can add one update beyond that, so it keeps the last LR.
            if global_step < total_steps:
                scheduler.step()
            optimizer.zero_grad()
            global_step += 1

        pbar.set_postfix({'loss': f'{batch_loss:.4f}', 'lr': f'{scheduler.get_last_lr()[0]:.2e}'})

    avg_loss = epoch_loss / batches_per_epoch
    print(f'Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}')

    # Keep only the best epoch: the output file is overwritten on improvement.
    if avg_loss < best_loss:
        best_loss = avg_loss
        print(f'Saving checkpoint...')
        torch.save({
            'model_state_dict': model.state_dict(),
            'config': config_dict,
            'tokenizer_name': 'gpt2',
            'training_steps': global_step,
            'epoch': epoch + 1,
            'loss': avg_loss,
            'fine_tuned_on': ['conversations', 'tool_traces']
        }, OUTPUT_CHECKPOINT)
        print(f'Saved to {OUTPUT_CHECKPOINT}')

In [None]:
# Final summary. best_loss is the lowest per-epoch average training loss, i.e.
# the loss of the epoch whose weights were last written to OUTPUT_CHECKPOINT.
print('\n' + '='*60)
print('TRAINING COMPLETE')
print('='*60)
print(f'Best loss: {best_loss:.4f}')
print(f'Checkpoint: {OUTPUT_CHECKPOINT}')
# Use the configured file name so the hint stays correct when OUTPUT_NAME is changed.
print(f'\nTest locally: py -3 scripts/quick_test.py checkpoints/{OUTPUT_NAME}')