# Svend Reasoning Specialist Fine-Tuning

**Stage 2 of the two-stage training pipeline**

This notebook fine-tunes the pretrained language model for mathematical reasoning and chain-of-thought.

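**Prerequisites:**
- Trained language model from `train_language_model.ipynb`
- Checkpoint at `/content/drive/MyDrive/svend-checkpoints/language-model/final-language-model.pt`
- Google Drive mounted at `/content/drive` (this notebook reads and writes checkpoints under that path but does not mount the drive itself)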

**Goal:** Learn step-by-step reasoning, math problem solving, and tool calling.

**Output:** Reasoning specialist model ready for deployment.

## 1. Setup Environment

In [None]:
# Clone the repository into ./svend (relative to the current directory) and
# make it the working directory. git's stderr is discarded, so any clone
# failure -- including "directory already exists" -- prints "Already cloned";
# the %cd runs either way.
!git clone https://github.com/ewolters/svend.git 2>/dev/null || echo "Already cloned"
%cd svend

In [None]:
# Clone repository
!git clone https://github.com/ewolters/svend.git reasoning-lab 2>/dev/null || echo "Already cloned"
%cd reasoning-lab

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from datasets import load_dataset
import os
import math
import random
from datetime import datetime
from tqdm.auto import tqdm

import sys

# Put the checkout at the front of the import path so its packages win.
sys.path.insert(0, '/content/svend')

# Report the runtime: framework version, CUDA availability and, when a GPU is
# present, its name and total memory in GB.
has_cuda = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {has_cuda}")
if has_cuda:
    gpu_name = torch.cuda.get_device_name()
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from datasets import load_dataset
import os
import math
import random
from datetime import datetime
from tqdm.auto import tqdm

import sys

# Put the checkout at the front of the import path so its packages win.
sys.path.insert(0, '/content/reasoning-lab')

# Report the runtime: framework version, CUDA availability and, when a GPU is
# present, its name and total memory in GB.
has_cuda = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {has_cuda}")
if has_cuda:
    gpu_name = torch.cuda.get_device_name()
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## 2. Configuration

In [None]:
# Timestamp suffix so every run gets its own experiment name.
_run_stamp = datetime.now().strftime('%Y%m%d-%H%M')

# All hyperparameters and paths for this fine-tuning run, read by the cells
# that follow.
CONFIG = dict(
    # Base model (output from train_language_model.ipynb)
    base_model_path="/content/drive/MyDrive/svend-checkpoints/language-model/final-language-model.pt",

    # Training
    max_steps=20_000,
    batch_size=4,
    gradient_accumulation=8,  # Effective batch = 32
    max_seq_length=2048,
    learning_rate=5e-5,  # Lower than pretraining
    weight_decay=0.01,
    warmup_steps=500,
    max_grad_norm=1.0,

    # Checkpointing
    checkpoint_dir="/content/drive/MyDrive/svend-checkpoints/reasoning-specialist",
    save_every=2000,
    eval_every=500,

    # Logging
    use_wandb=True,
    wandb_project="svend-reasoning",
    experiment_name=f"reasoning-500m-{_run_stamp}",

    # Resume: path of a checkpoint to continue from, or None to start fresh
    resume_from=None,
)

# Make sure the output directory exists before any checkpoint is written.
os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)

_effective_batch = CONFIG["batch_size"] * CONFIG["gradient_accumulation"]
_summary_lines = [
    "Configuration:",
    f"  Base model: {CONFIG['base_model_path']}",
    f"  Max steps: {CONFIG['max_steps']:,}",
    f"  Learning rate: {CONFIG['learning_rate']}",
    f"  Effective batch: {_effective_batch}",
]
print("\n".join(_summary_lines))

## 3. Load Pretrained Language Model

In [None]:
from src.models.config import TransformerConfig
from src.models.transformer import ReasoningTransformer

# Train on the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Read the Stage-1 checkpoint onto the chosen device.
# NOTE: weights_only=False unpickles arbitrary Python objects, so only point
# this at checkpoint files you trust.
base_model_path = CONFIG["base_model_path"]
print(f"Loading base model from: {base_model_path}")
checkpoint = torch.load(base_model_path, map_location=device, weights_only=False)

# Rebuild the architecture description stored next to the weights and
# summarise it.
base_config = TransformerConfig(**checkpoint["config"])
print(f"\nBase model config:")
print(f"  Name: {base_config.name}")
print(f"  Parameters: {base_config.num_parameters() / 1e6:.0f}M")
print(f"  Hidden size: {base_config.hidden_size}")
print(f"  Layers: {base_config.num_hidden_layers}")

In [None]:
# Build the specialist's config: architecture fields are copied verbatim from
# the base model, while the name, model type and tool-calling options are set
# here.
_inherited_fields = (
    "vocab_size",
    "hidden_size",
    "intermediate_size",
    "num_hidden_layers",
    "num_attention_heads",
    "num_key_value_heads",
    "max_position_embeddings",
    "hidden_act",
)
_inherited = {field: getattr(base_config, field) for field in _inherited_fields}

reasoning_config = TransformerConfig(
    name="svend-reasoning-500m",
    model_type="reasoning",
    **_inherited,
    tool_calling=True,
    num_tool_tokens=64,
)

print(f"Reasoning specialist config:")
print(f"  Name: {reasoning_config.name}")
print(f"  Tool calling: {reasoning_config.tool_calling}")
print(f"  Context length: {reasoning_config.max_position_embeddings}")

In [None]:
# Create model and load pretrained weights
model = ReasoningTransformer(reasoning_config)

# Copy over every checkpoint tensor whose name and shape match the new model.
# Everything else keeps the value the model constructor gave it.
pretrained_state = checkpoint["model_state_dict"]
model_state = model.state_dict()

loaded_keys = []
skipped_keys = []       # every checkpoint tensor that was not copied
unexpected_keys = []    # ...because the model has no tensor with that name
mismatched_keys = []    # ...because the shapes differ

for key, tensor in pretrained_state.items():
    if key not in model_state:
        unexpected_keys.append(key)
        skipped_keys.append(key)
    elif tensor.shape != model_state[key].shape:
        mismatched_keys.append(key)
        skipped_keys.append(key)
    else:
        model_state[key] = tensor
        loaded_keys.append(key)

# Model tensors the checkpoint did not provide a usable value for.
_loaded = set(loaded_keys)
fresh_keys = [key for key in model_state if key not in _loaded]

model.load_state_dict(model_state)
model = model.to(device)

print(f"Loaded {len(loaded_keys)} weight tensors from language model")
if skipped_keys:
    print(
        f"Skipped {len(skipped_keys)} tensors "
        f"({len(unexpected_keys)} not in model, {len(mismatched_keys)} shape mismatch)"
    )
if fresh_keys:
    print(f"{len(fresh_keys)} model tensors keep their fresh initialization")

In [None]:
# Setup tokenizer with special reasoning tokens
tokenizer = AutoTokenizer.from_pretrained(checkpoint.get("tokenizer_name", "gpt2"))
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Add special tokens for reasoning
special_tokens = {
    "additional_special_tokens": [
        "<|think|>", "<|/think|>",
        "<|step|>", "<|/step|>",
        "<|tool_call|>", "<|/tool_call|>",
        "<|tool_result|>", "<|/tool_result|>",
        "<|answer|>", "<|/answer|>",
        "<|verify|>", "<|/verify|>",
    ]
}

num_added = tokenizer.add_special_tokens(special_tokens)
print(f"Added {num_added} special tokens")
print(f"Tokenizer vocab size: {len(tokenizer)}")

# Resize embeddings if needed: only grow, never shrink, so a model whose
# vocabulary already has room for the new ids is left untouched.
old_vocab_size = model.embed_tokens.weight.shape[0]
if len(tokenizer) > old_vocab_size:
    old_embeddings = model.embed_tokens.weight.data
    new_vocab_size = len(tokenizer)

    # Create the larger table with the same dtype as the one it replaces,
    # directly on the training device.
    new_embeddings = nn.Embedding(
        new_vocab_size,
        reasoning_config.hidden_size,
        device=device,
        dtype=old_embeddings.dtype,
    )
    with torch.no_grad():
        new_embeddings.weight[:old_vocab_size] = old_embeddings
        # Initialize new tokens with mean of existing
        new_embeddings.weight[old_vocab_size:] = old_embeddings.mean(dim=0)

    model.embed_tokens = new_embeddings

    # Update lm_head if not tied
    # NOTE(review): if ReasoningTransformer shares lm_head.weight with
    # embed_tokens.weight, the two replacement modules below are independent
    # tensors -- confirm how the model ties weights.
    if model.lm_head is not None:
        old_lm_head = model.lm_head.weight.data
        old_out_features = old_lm_head.shape[0]
        new_lm_head = nn.Linear(
            reasoning_config.hidden_size,
            new_vocab_size,
            bias=False,
            device=device,
            dtype=old_lm_head.dtype,
        )
        with torch.no_grad():
            new_lm_head.weight[:old_out_features] = old_lm_head
            new_lm_head.weight[old_out_features:] = old_lm_head.mean(dim=0)
        model.lm_head = new_lm_head

    # Update config
    reasoning_config.vocab_size = new_vocab_size
    print(f"Resized embeddings to {new_vocab_size}")

# Mixed precision dtype
# Only ask about bf16 when CUDA is present: on a CPU-only runtime the query
# can raise on some PyTorch versions.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16
print(f"\nModel ready on {device}, dtype: {dtype}")

## 4. Prepare Reasoning Datasets

In [None]:
def format_gsm8k(example):
    """Render a GSM8K record as question, think block and tagged answer.

    If ``example["answer"]`` contains ``####``, the text before the last
    marker becomes the reasoning and the text after it the final answer
    (both stripped). Otherwise the whole answer is used, unstripped, as the
    reasoning and the final answer is empty.

    Returns a dict with the single key ``"text"``.
    """
    question = example["question"]
    answer = example["answer"]

    worked, marker, tail = answer.rpartition("####")
    if marker:
        reasoning, final = worked.strip(), tail.strip()
    else:
        reasoning, final = answer, ""

    lines = (
        f"Question: {question}",
        "",
        "<|think|>",
        reasoning,
        "<|/think|>",
        "",
        f"<|answer|>{final}<|/answer|>",
    )
    return {"text": "\n".join(lines)}


def format_math(example):
    """Render a MATH record as a problem statement followed by a think block.

    Missing ``problem`` / ``solution`` keys are treated as empty strings.
    Returns a dict with the single key ``"text"``.
    """
    problem = example.get("problem", "")
    solution = example.get("solution", "")

    lines = (
        f"Problem: {problem}",
        "",
        "<|think|>",
        f"{solution}",
        "<|/think|>",
    )
    return {"text": "\n".join(lines)}


def format_arc(example):
    """Render an ARC multiple-choice record with a templated rationale.

    The choices are listed one per line as ``<label>. <text>``. The think
    block states the answer key together with the text of the matching
    choice; if the key is not among the labels, the key itself is used in
    place of the text.

    Returns a dict with the single key ``"text"``.
    """
    question = example["question"]
    choices = example["choices"]
    answer_key = example["answerKey"]
    labels, texts = choices["label"], choices["text"]

    choice_text = "\n".join(
        f"{label}. {text}" for label, text in zip(labels, texts)
    )

    if answer_key in labels:
        correct_text = texts[labels.index(answer_key)]
    else:
        correct_text = answer_key

    lines = (
        f"Question: {question}",
        "",
        "Choices:",
        choice_text,
        "",
        "<|think|>",
        "Let me analyze each option.",
        f"The correct answer is {answer_key}: {correct_text}",
        "<|/think|>",
        "",
        f"<|answer|>{answer_key}<|/answer|>",
    )
    return {"text": "\n".join(lines)}


def format_orca_math(example):
    """Render an Orca Math record as a question followed by a think block.

    Missing ``question`` / ``answer`` keys are treated as empty strings.
    Returns a dict with the single key ``"text"``.
    """
    question = example.get("question", "")
    answer = example.get("answer", "")

    lines = (
        f"Question: {question}",
        "",
        "<|think|>",
        f"{answer}",
        "<|/think|>",
    )
    return {"text": "\n".join(lines)}


print("Formatting functions defined.")

In [None]:
# Load and format all datasets
# Each source is loaded best-effort: a failure is reported and the remaining
# sources are still processed.
all_examples = []

# GSM8K
print("Loading GSM8K...")
try:
    gsm8k = load_dataset("openai/gsm8k", "main", split="train")
    all_examples.extend(
        {"text": format_gsm8k(record)["text"], "source": "gsm8k"}
        for record in gsm8k
    )
    print(f"  GSM8K: {len(gsm8k)} examples")
except Exception as e:
    print(f"  GSM8K failed: {e}")

# MATH (competition_math)
print("Loading MATH...")
try:
    math_ds = load_dataset("lighteval/MATH", "all", split="train")
    all_examples.extend(
        {"text": format_math(record)["text"], "source": "math"}
        for record in math_ds
    )
    print(f"  MATH: {len(math_ds)} examples")
except Exception as e:
    print(f"  MATH failed: {e}")

# ARC-Challenge
print("Loading ARC...")
try:
    arc = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="train")
    all_examples.extend(
        {"text": format_arc(record)["text"], "source": "arc"}
        for record in arc
    )
    print(f"  ARC: {len(arc)} examples")
except Exception as e:
    print(f"  ARC failed: {e}")

# Orca Math (sample 20k)
print("Loading Orca Math...")
try:
    orca = load_dataset("microsoft/orca-math-word-problems-200k", split="train")
    # Fixed-seed shuffle, then keep at most the first 20,000 rows.
    sample_size = min(20000, len(orca))
    orca = orca.shuffle(seed=42).select(range(sample_size))
    all_examples.extend(
        {"text": format_orca_math(record)["text"], "source": "orca_math"}
        for record in orca
    )
    print(f"  Orca Math: {len(orca)} examples")
except Exception as e:
    print(f"  Orca Math failed: {e}")

print(f"\nTotal examples: {len(all_examples)}")

# Count by source
from collections import Counter
source_counts = Counter(example["source"] for example in all_examples)
for source, count in source_counts.items():
    print(f"  {source}: {count}")

# Shuffle (seeded, so the mixed order is reproducible)
random.seed(42)
random.shuffle(all_examples)

In [None]:
# Preview a few examples: show the source tag and the first 400 characters of
# up to three formatted samples.
print("Sample examples:")
print("="*60)
for ex in all_examples[:3]:
    preview = ex["text"]
    print(f"\n[{ex['source'].upper()}]")
    print(preview[:400])
    if len(preview) > 400:
        print("...")
    print("-"*60)

In [None]:
class ReasoningDataset(Dataset):
    """Map-style dataset that tokenizes formatted examples on access.

    Each item of ``examples`` is a dict with a ``"text"`` entry. Every sample
    is truncated or padded to exactly ``max_length`` tokens.
    """

    def __init__(self, examples, tokenizer, max_length=2048):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one example."""
        encoded = self.tokenizer(
            self.examples[idx]["text"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer returns a batch of one; drop the batch dimension.
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # Labels mirror the inputs, with -100 wherever the attention mask is 0.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )


# Wrap the formatted examples and batch them with shuffling.
dataset = ReasoningDataset(all_examples, tokenizer, max_length=CONFIG["max_seq_length"])

loader_kwargs = dict(
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
dataloader = DataLoader(dataset, **loader_kwargs)

# One optimizer step consumes `gradient_accumulation` batches.
batches_per_epoch = len(dataloader)
print(f"Dataset: {len(dataset)} examples")
print(f"Batches per epoch: {batches_per_epoch}")
print(f"Steps per epoch: {batches_per_epoch // CONFIG['gradient_accumulation']}")

## 5. Setup Training

In [None]:
# AdamW over every model parameter, with the learning rate and weight decay
# taken from CONFIG.
adamw_kwargs = dict(
    lr=CONFIG["learning_rate"],
    weight_decay=CONFIG["weight_decay"],
    betas=(0.9, 0.95),
)
optimizer = torch.optim.AdamW(model.parameters(), **adamw_kwargs)

# Warmup followed by cosine decay, spanning CONFIG["max_steps"] scheduler steps.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=CONFIG["warmup_steps"],
    num_training_steps=CONFIG["max_steps"],
)

# Loss scaling is enabled only when running in float16.
needs_loss_scaling = dtype == torch.float16
scaler = torch.amp.GradScaler("cuda", enabled=needs_loss_scaling)

print("Optimizer and scheduler configured.")

In [None]:
# Optional experiment tracking: wandb is imported only when it is enabled, and
# the full CONFIG is recorded with the run.
if CONFIG["use_wandb"]:
    import wandb

    run_name = CONFIG["experiment_name"]
    wandb.init(project=CONFIG["wandb_project"], name=run_name, config=CONFIG)
    print(f"WandB initialized: {run_name}")

In [None]:
def save_checkpoint(model, optimizer, scheduler, step, loss, path):
    """Write a resumable training checkpoint to ``path``.

    Stores the model, optimizer, scheduler and gradient-scaler state together
    with the step counter, the given loss, the model config and the tokenizer
    setup (name and added special tokens).

    Reads the module-level ``reasoning_config``, ``checkpoint``,
    ``special_tokens`` and ``scaler``.
    """
    payload = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        # Loss-scale state, so an fp16 run resumes with its adapted scale.
        # This is an empty dict when the scaler is disabled.
        "scaler_state_dict": scaler.state_dict(),
        "loss": loss,
        "config": reasoning_config.to_dict(),
        # Same lookup used to load the tokenizer, so the recorded name always
        # matches the tokenizer the model was trained with.
        "tokenizer_name": checkpoint.get("tokenizer_name", "gpt2"),
        "special_tokens": special_tokens,
    }
    torch.save(payload, path)
    print(f"Saved: {path}")


def load_checkpoint(path, model, optimizer, scheduler):
    """Restore training state from ``path`` and return ``(step, loss)``.

    Loads model, optimizer and scheduler state in place. The module-level
    ``scaler`` is restored as well when the checkpoint holds a non-empty
    ``scaler_state_dict``. ``loss`` defaults to 0 when the checkpoint does
    not record one.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects; only load
    # checkpoint files you trust.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    scheduler.load_state_dict(ckpt["scheduler_state_dict"])

    # Checkpoints without scaler state, or with the empty state of a disabled
    # scaler, are skipped: an enabled GradScaler rejects an empty state dict.
    scaler_state = ckpt.get("scaler_state_dict")
    if scaler_state:
        scaler.load_state_dict(scaler_state)

    return ckpt["step"], ckpt.get("loss", 0)


# Start from step 0 unless CONFIG names a checkpoint to continue from.
start_step = 0
if CONFIG["resume_from"]:
    start_step, _ = load_checkpoint(CONFIG["resume_from"], model, optimizer, scheduler)
    print(f"Resumed from step {start_step}")

## 6. Evaluation

In [None]:
@torch.no_grad()
def evaluate_reasoning(model, tokenizer, num_samples=5):
    """Generate sample outputs to check reasoning quality.

    Prompts the model with the first ``num_samples`` built-in problems (five
    are available), each followed by an opening ``<|think|>`` tag, and prints
    up to 600 characters of every decoded completion. Nothing is scored or
    returned.

    The model is put in eval mode for the duration of the call and returned
    to whichever mode it was in beforehand, even if generation raises.
    """
    was_training = model.training
    model.eval()

    test_problems = [
        "Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market for $2 per egg. How much does she make every day?",
        "Question: A train travels at 60 mph for 2 hours, then at 80 mph for 1.5 hours. What is the total distance traveled?",
        "Question: If x + 3 = 7, what is the value of x?",
        "Question: A rectangle has a perimeter of 24 cm. If the length is twice the width, what are the dimensions?",
        "Question: There are 5 red balls and 3 blue balls in a bag. What is the probability of drawing a red ball?",
    ]

    try:
        print("\n" + "="*60)
        print("REASONING EVALUATION")
        print("="*60)

        for problem in test_problems[:num_samples]:
            # Same layout as the training text: question, blank line, then an
            # opening think tag for the model to continue from.
            prompt = problem + "\n\n<|think|>\n"
            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

            with torch.amp.autocast("cuda", dtype=dtype):
                output = model.generate(
                    input_ids,
                    max_new_tokens=256,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Keep special tokens so the reasoning/answer tags stay visible.
            response = tokenizer.decode(output[0], skip_special_tokens=False)
            print(f"\n{'-'*60}")
            print(response[:600])
            if len(response) > 600:
                print("...")

        print("\n" + "="*60)
    finally:
        # Never leave the model stuck in eval mode in the middle of training.
        model.train(was_training)

## 7. Training Loop

In [None]:
print("\n" + "="*60)
print("STARTING REASONING SPECIALIST FINE-TUNING")
print("="*60)
print(f"Base model: {CONFIG['base_model_path']}")
print(f"Max steps: {CONFIG['max_steps']:,}")
print(f"Dataset: {len(dataset)} examples")
print("="*60 + "\n")

# `step` counts optimizer updates, not micro-batches. The LR scheduler is
# built for CONFIG["max_steps"] updates and advances once per update, so the
# loop must run that many updates for warmup and cosine decay to complete.
# Each update consumes CONFIG["gradient_accumulation"] micro-batches.
model.train()
step = start_step
epoch = 0
micro_step = 0               # micro-batches processed in this session
total_loss = 0               # sum of unscaled micro-batch losses since last log
loss_count = 0               # number of micro-batches in `total_loss`
avg_loss = float("nan")      # last logged average; defined before any save
log_interval = 50
accumulation = CONFIG["gradient_accumulation"]

progress = tqdm(total=CONFIG["max_steps"], initial=start_step, desc="Training")

# Start from clean gradients (relevant when this cell is re-run).
optimizer.zero_grad()

try:
    while step < CONFIG["max_steps"]:
        epoch += 1

        for batch in dataloader:
            if step >= CONFIG["max_steps"]:
                break

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            with torch.amp.autocast("cuda", dtype=dtype):
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                # Scale down so the accumulated gradient is the batch mean.
                loss = outputs["loss"] / accumulation

            scaler.scale(loss).backward()
            total_loss += loss.item() * accumulation
            loss_count += 1
            micro_step += 1

            # Keep accumulating until a full effective batch has been seen.
            if micro_step % accumulation != 0:
                continue

            # Unscale before clipping so the norm is measured on true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG["max_grad_norm"])
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

            step += 1
            progress.update(1)

            if step % log_interval == 0:
                avg_loss = total_loss / max(1, loss_count)
                lr = scheduler.get_last_lr()[0]

                # Cap the exponent so an early, very high loss cannot overflow.
                perplexity = math.exp(min(avg_loss, 10))
                progress.set_postfix({
                    "loss": f"{avg_loss:.4f}",
                    "ppl": f"{perplexity:.1f}",
                    "lr": f"{lr:.2e}",
                })

                if CONFIG["use_wandb"]:
                    wandb.log({
                        "train/loss": avg_loss,
                        "train/perplexity": perplexity,
                        "train/lr": lr,
                        "train/step": step,
                        "train/epoch": epoch,
                    })

                total_loss = 0
                loss_count = 0

            if step % CONFIG["eval_every"] == 0:
                evaluate_reasoning(model, tokenizer, num_samples=3)

            if step % CONFIG["save_every"] == 0:
                path = os.path.join(CONFIG["checkpoint_dir"], f"checkpoint-{step}.pt")
                save_checkpoint(model, optimizer, scheduler, step, avg_loss, path)

except KeyboardInterrupt:
    # Save the last completed update. Gradients of a partially accumulated
    # batch are not part of the checkpoint.
    print("\nInterrupted. Saving checkpoint...")
    path = os.path.join(CONFIG["checkpoint_dir"], f"checkpoint-{step}-interrupted.pt")
    save_checkpoint(model, optimizer, scheduler, step, total_loss / max(1, loss_count), path)

finally:
    progress.close()

print(f"\nTraining finished at step {step}")

## 8. Save Final Model

In [None]:
# Write the deployable artifact: weights plus everything needed to rebuild the
# model and tokenizer. Optimizer and scheduler state are left out.
final_path = os.path.join(CONFIG["checkpoint_dir"], "final-reasoning-specialist.pt")

torch.save({
    "model_state_dict": model.state_dict(),
    "config": reasoning_config.to_dict(),
    # Same lookup used to load the tokenizer, so the recorded name always
    # matches the tokenizer the model was trained with.
    "tokenizer_name": checkpoint.get("tokenizer_name", "gpt2"),
    "special_tokens": special_tokens,
    "training_steps": step,
    "base_model": CONFIG["base_model_path"],
}, final_path)

print(f"Final model saved: {final_path}")

In [None]:
# Print generations for all five built-in problems with the final weights.
print("\nFinal Evaluation:")
evaluate_reasoning(model, tokenizer, num_samples=5)

In [None]:
# Close the tracking run (if one was opened) and print a closing summary.
if CONFIG["use_wandb"]:
    wandb.finish()

rule = "="*60
print("\n" + rule)
print("REASONING SPECIALIST TRAINING COMPLETE")
print(rule)
print(f"Checkpoint: {final_path}")
print(f"Steps: {step:,}")
print(f"Epochs: {epoch}")
print(rule)

---

## Next Steps

1. **Deploy**: Use `scripts/run_server.py` to serve the model
2. **Evaluate**: Run `scripts/run_unified_eval.py` for benchmarks
3. **Train Verifier**: Optional critic model for self-verification