# Svend Fine-Tuning: Reasoning + Tool Use

Fine-tune the Svend base language model on synthetic reasoning and tool-use data.

**Inputs:**
- Base checkpoint: `svend-base-200k.pt` (from Drive)
- Training data: `conversations.jsonl` (prompt/response pairs) and `tool_traces.jsonl` (step-by-step tool-use traces), both from Drive under `data/`

**Output:**
- Fine-tuned checkpoint: `svend-reasoning-v1.pt` (saved to Drive; overwritten whenever an epoch's average training loss improves)

In [None]:
# Mount Drive and setup
# Mount Google Drive so the checkpoint and data paths under
# /content/drive/MyDrive/svend (configured in the next cell) are reachable.
from google.colab import drive
drive.mount('/content/drive')

# IPython shell magic (notebook-only): install training dependencies.
!pip install -q transformers torch accelerate

import os
# Clone the repo into /content/svend; if the clone fails (e.g. the directory
# already exists from an earlier run), pull the latest changes instead.
os.chdir('/content')
!git clone https://github.com/ewolters/svend.git 2>/dev/null || (cd svend && git pull)
os.chdir('/content/svend')

import sys
# Make the repository root importable so later cells can `from src.models ...`.
sys.path.insert(0, '/content/svend')

In [None]:
# Config
# All inputs and outputs live on the mounted Google Drive.
DRIVE_PATH = '/content/drive/MyDrive/svend'
_DATA_DIR = DRIVE_PATH + '/data'

BASE_CHECKPOINT = DRIVE_PATH + '/svend-base-200k.pt'        # the "200K" base checkpoint
CONVERSATIONS_FILE = _DATA_DIR + '/conversations.jsonl'     # prompt/response pairs
TOOL_TRACES_FILE = _DATA_DIR + '/tool_traces.jsonl'         # step-by-step tool-use traces
OUTPUT_CHECKPOINT = DRIVE_PATH + '/svend-reasoning-v1.pt'   # fine-tuned result

# Training hyperparameters
EPOCHS = 3
BATCH_SIZE = 4      # micro-batch size per forward/backward pass
GRAD_ACCUM = 8      # micro-batches per optimizer step -> effective batch of 32
LR = 2e-5           # peak learning rate (used as OneCycleLR's max_lr)
MAX_LENGTH = 1024   # every example is truncated/padded to this many tokens

In [None]:
import torch
import json
from pathlib import Path
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# Prefer the GPU when one is visible; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name()}')
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'Memory: {gpu_props.total_memory / 1e9:.1f} GB')

In [None]:
# Load base model
from src.models.transformer import ReasoningTransformer
from src.models.config import TransformerConfig

print(f'Loading checkpoint from {BASE_CHECKPOINT}...')
# weights_only=False is explicit on purpose.
# - The checkpoint may carry a pickled TransformerConfig object under 'config'
#   (see the isinstance check below).
# - PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects such
#   objects.
# - Full unpickling can execute arbitrary code, so only load checkpoints you
#   produced yourself or otherwise trust (here: your own file on Drive).
checkpoint = torch.load(BASE_CHECKPOINT, map_location='cpu', weights_only=False)

# Rebuild the architecture config. It is stored either as a plain dict of
# constructor kwargs or as a ready-made TransformerConfig instance.
# NOTE(review): a checkpoint without 'config' falls back to TransformerConfig()
# defaults. A mismatch would then only surface in load_state_dict below.
config_dict = checkpoint.get('config', {})
if isinstance(config_dict, dict):
    # head_dim is a computed field, so it must not be passed to the constructor.
    config_dict = {k: v for k, v in config_dict.items() if k != 'head_dim'}
    config = TransformerConfig(**config_dict)
else:
    config = config_dict

model = ReasoningTransformer(config)
# Assuming ReasoningTransformer is a torch.nn.Module, this load is strict by
# default and raises on missing/unexpected keys.
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

# Count params
params = sum(p.numel() for p in model.parameters())
print(f'Model loaded: {params/1e6:.1f}M parameters')

In [None]:
# Load tokenizer
# GPT-2 BPE tokenizer. When it has no pad token, padding reuses EOS; padded
# positions are excluded from the loss via the attention mask in the dataset.
_tokenizer_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(_tokenizer_name)
_missing_pad = tokenizer.pad_token is None
if _missing_pad:
    tokenizer.pad_token = tokenizer.eos_token
_vocab_size = len(tokenizer)
print(f'Tokenizer vocab size: {_vocab_size}')

In [None]:
# Dataset
class ReasoningDataset(Dataset):
    """Causal-LM dataset built from conversation and tool-trace JSONL files.

    Each JSONL record is rendered to plain training text:

    * conversations (``prompt``/``response`` keys) become a ``Q:`` line
      followed by an ``A:`` line;
    * tool traces (``question``, optional ``reasoning`` list, optional
      ``answer``) become a ``Q:`` line, one ``Step N:`` line per reasoning
      step (plus optional ``<|tool_call|>`` / ``<|tool_result|>`` lines),
      then an ``A:`` line.

    Missing files are skipped, and blank lines inside either file are ignored.
    Examples are tokenized lazily in ``__getitem__``.

    Args:
        conversations_file: Path to the prompt/response JSONL file.
        tool_traces_file: Path to the tool-trace JSONL file.
        tokenizer: Hugging Face-style callable tokenizer. It needs a pad
            token for ``padding='max_length'``.
        max_length: Every item is truncated/padded to this many tokens.
        append_eos: If true (default), append ``tokenizer.eos_token`` to each
            example so the model is trained to stop after the answer.
    """

    def __init__(self, conversations_file, tool_traces_file, tokenizer, max_length=1024,
                 append_eos=True):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.append_eos = append_eos
        self.examples = []

        # Load conversations (prompt/response format)
        if Path(conversations_file).exists():
            for data in self._iter_jsonl(conversations_file):
                self.examples.append(f"Q: {data['prompt']}\nA: {data['response']}")
            print(f'Loaded {len(self.examples)} conversations')

        # Load tool traces (reasoning format)
        n_conv = len(self.examples)
        if Path(tool_traces_file).exists():
            for data in self._iter_jsonl(tool_traces_file):
                self.examples.append(self._format_tool_trace(data))
            print(f'Loaded {len(self.examples) - n_conv} tool traces')

        print(f'Total examples: {len(self.examples)}')

    @staticmethod
    def _iter_jsonl(path):
        """Yield one parsed JSON object per non-blank line of *path*."""
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                # A trailing newline or blank separator line would otherwise
                # make json.loads raise JSONDecodeError.
                if line.strip():
                    yield json.loads(line)

    @staticmethod
    def _format_tool_trace(data):
        """Render one tool-trace record as training text."""
        parts = [f"Q: {data['question']}\n"]
        for step in data.get('reasoning', []):
            parts.append(f"Step {step.get('step', '?')}: {step.get('content', '')}\n")
            if 'tool_call' in step:
                tc = step['tool_call']
                parts.append(
                    f"<|tool_call|>{tc.get('name', '')}({json.dumps(tc.get('args', {}))})<|/tool_call|>\n"
                )
            if 'tool_result' in step:
                parts.append(f"<|tool_result|>{step['tool_result']}<|/tool_result|>\n")
        parts.append(f"A: {data.get('answer', '')}")
        return ''.join(parts)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        text = self.examples[idx]
        eos = getattr(self.tokenizer, 'eos_token', None) if self.append_eos else None
        if eos:
            # Padding positions are masked out of the labels below, so without
            # an explicit EOS the end-of-sequence token is never a training
            # target. Truncation drops it for over-long examples, which is
            # fine: those examples did not actually end there.
            text += eos
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
        )
        # Drop the batch dimension of size 1. Indexing (unlike squeeze()) keeps
        # a 1-D tensor even when max_length == 1.
        input_ids = encoded['input_ids'][0]
        attention_mask = encoded['attention_mask'][0]

        # Labels = input_ids (causal LM). Any next-token shift is presumably
        # done inside the model (TODO confirm). -100 is the usual
        # cross-entropy ignore_index, which keeps padding out of the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

In [None]:
# Create dataset and dataloader
dataset = ReasoningDataset(
    CONVERSATIONS_FILE,
    TOOL_TRACES_FILE,
    tokenizer,
    max_length=MAX_LENGTH
)

# ReasoningDataset silently skips missing files, so an empty dataset usually
# means a wrong path. Fail here with a clear message instead of the less
# direct error a shuffling DataLoader raises for zero samples.
if len(dataset) == 0:
    raise RuntimeError(
        'No training examples loaded; check that these files exist and are '
        f'non-empty: {CONVERSATIONS_FILE}, {TOOL_TRACES_FILE}'
    )

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,  # tokenization (in __getitem__) runs in worker processes
    # Pinned host memory only speeds up host-to-GPU copies. Without a GPU,
    # PyTorch just warns about it.
    pin_memory=(device.type == 'cuda'),
)

In [None]:
# Optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)

# How many optimizer steps the loop actually takes:
# - It steps once every GRAD_ACCUM micro-batches *within* each epoch, which
#   gives len(dataloader) // GRAD_ACCUM steps per epoch.
# - Flooring the product over all epochs instead can overshoot that count by up
#   to EPOCHS - 1. OneCycleLR would then stop short of its final annealed LR.
steps_per_epoch = len(dataloader) // GRAD_ACCUM
total_steps = steps_per_epoch * EPOCHS
if total_steps == 0:
    raise ValueError(
        f'No optimizer steps would run: {len(dataloader)} batches per epoch is fewer '
        f'than GRAD_ACCUM={GRAD_ACCUM}. Lower GRAD_ACCUM/BATCH_SIZE or add data.'
    )
# Informational only; the scheduler's warmup is set by pct_start below.
warmup_steps = total_steps // 10

from torch.optim.lr_scheduler import OneCycleLR
# Warm up over the first 10% of steps to max_lr=LR, then anneal.
# NOTE(review): OneCycleLR's default cycle_momentum=True also cycles AdamW's
# beta1; pass cycle_momentum=False if that is not wanted.
scheduler = OneCycleLR(
    optimizer,
    max_lr=LR,
    total_steps=total_steps,
    pct_start=0.1
)

print(f'Total steps: {total_steps}')
print(f'Warmup steps: {warmup_steps}')

In [None]:
# Training loop
# Gradient-accumulated training, using mixed precision only on CUDA.
# After every epoch, OUTPUT_CHECKPOINT is overwritten whenever that epoch's
# mean training loss is the lowest seen so far.
model.train()
# GradScaler only on CUDA; on CPU, scaler is None and the fp32 path below runs.
scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None

global_step = 0  # counts optimizer steps, not micro-batches
best_loss = float('inf')

for epoch in range(EPOCHS):
    epoch_loss = 0
    # NOTE(review): this discards gradients from the last
    # len(dataloader) % GRAD_ACCUM batches of the previous epoch. Those batches
    # count toward epoch_loss but never reach an optimizer step. Flushing them
    # would also require raising the scheduler's total_steps.
    optimizer.zero_grad()
    
    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{EPOCHS}')
    for step, batch in enumerate(pbar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward pass with mixed precision
        # Assumes the model returns an object whose .loss is computed from
        # `labels` (TODO confirm against ReasoningTransformer.forward).
        # Dividing by GRAD_ACCUM makes the accumulated gradient an average
        # over the micro-batches.
        if scaler:
            with torch.amp.autocast('cuda'):
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss / GRAD_ACCUM
            scaler.scale(loss).backward()
        else:
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss / GRAD_ACCUM
            loss.backward()
        
        # Undo the accumulation scaling so epoch_loss is in per-batch units.
        epoch_loss += loss.item() * GRAD_ACCUM
        
        # Gradient accumulation step
        if (step + 1) % GRAD_ACCUM == 0:
            if scaler:
                # Unscale first so clipping sees the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            
            # One scheduler step per optimizer step.
            # NOTE(review): if scaler.step skipped the update (inf/NaN grads),
            # the scheduler still advances, and PyTorch may warn about it.
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1
        
        # Shows the current micro-batch loss (unscaled) and the current LR.
        pbar.set_postfix({'loss': f'{loss.item() * GRAD_ACCUM:.4f}', 'lr': f'{scheduler.get_last_lr()[0]:.2e}'})
    
    # Mean per-batch training loss for this epoch (there is no validation split).
    avg_loss = epoch_loss / len(dataloader)
    print(f'Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}')
    
    # Save best
    # Selection is on training loss, so this usually fires every epoch.
    # 'config' is config_dict as built while loading the base checkpoint.
    if avg_loss < best_loss:
        best_loss = avg_loss
        print(f'New best! Saving checkpoint...')
        torch.save({
            'model_state_dict': model.state_dict(),
            'config': config_dict,
            'tokenizer_name': 'gpt2',
            'training_steps': global_step,
            'epoch': epoch + 1,
            'loss': avg_loss,
            'fine_tuned_on': ['conversations', 'tool_traces']
        }, OUTPUT_CHECKPOINT)
        print(f'Saved to {OUTPUT_CHECKPOINT}')

In [None]:
# Summary. The checkpoint itself was written during training; nothing is
# saved in this cell.
print('\n' + '='*60)
print('TRAINING COMPLETE')
print('='*60)
# best_loss is the lowest epoch-average training loss, i.e. the loss of the
# saved checkpoint. It is not necessarily the last epoch's loss.
print(f'Best epoch loss: {best_loss:.4f}')
print(f'Checkpoint saved to: {OUTPUT_CHECKPOINT}')
print('\nDownload and test locally with:')
# Follow OUTPUT_CHECKPOINT's filename rather than hard-coding it.
print(f'  py -3 scripts/quick_test.py checkpoints/{Path(OUTPUT_CHECKPOINT).name}')