# Train Base Reasoner Model

This notebook trains **just the base reasoning model** - no router, no ensemble.

**Requirements:** Colab Pro+ with A100

**IMPORTANT:** Run cells in order. If you re-clone the repo, restart the runtime (Runtime -> Restart runtime) before continuing.

## 1. Setup

In [None]:
# Check GPU: show which GPU the runtime is attached to (the notebook expects an A100).
# NOTE(review): nvidia-smi errors out if the runtime has no GPU attached.
!nvidia-smi

In [None]:
# Mount Google Drive; checkpoint_dir in the config cell lives under this mount point,
# so checkpoints are written to Drive rather than the runtime's local disk.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Clone repo (fresh each time to get latest)
!rm -rf /content/svend
!git clone https://github.com/ewolters/svend.git /content/svend
%cd /content/svend

# Show commit to verify we have latest
!git log -1 --oneline

In [None]:
# Install dependencies
!pip install -q torch transformers datasets accelerate wandb
!pip install -q sympy sentencepiece tiktoken

In [None]:
# IMPORTANT: Fresh imports - clear any cached modules
import importlib
import sys

# Drop cached copies of the repo's `src` package (and its submodules) so the next
# import reads the freshly cloned files. Match the package name exactly ('src' or
# 'src.*') so unrelated modules that merely start with "src" are left alone.
modules_to_remove = [
    name for name in sys.modules if name == 'src' or name.startswith('src.')
]
for mod in modules_to_remove:
    del sys.modules[mod]

# Re-cloning replaces files on disk; make the import machinery re-scan directories
# instead of relying on finder caches built before the clone.
importlib.invalidate_caches()

# Add to path
if '/content/svend' not in sys.path:
    sys.path.insert(0, '/content/svend')

# Now import
import torch

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"bf16: {torch.cuda.is_bf16_supported()}")

In [None]:
# Verify the project imports resolve and that ReasoningTransformer.forward() accepts
# `labels`; a missing `labels` parameter means stale modules are still loaded.
from src.models.config import get_config
from src.models.transformer import ReasoningTransformer
import inspect

params = list(inspect.signature(ReasoningTransformer.forward).parameters)
print(f"ReasoningTransformer.forward() parameters: {params}")

if 'labels' not in params:
    for message in (
        "\n[ERROR] 'labels' parameter NOT found!",
        "Please restart runtime: Runtime -> Restart runtime",
        "Then re-run all cells from the beginning.",
    ):
        print(message)
    raise RuntimeError("Stale imports detected - restart runtime")

print("\n[OK] 'labels' parameter found - imports are fresh")

## 2. Configuration

In [None]:
# Training config
import os

# The model size drives the run name, so WandB runs are labelled with the size
# actually being trained instead of a hardcoded "125m".
MODEL_SIZE = "125m"  # Start small: 125m, 350m, 500m, 1b

CONFIG = {
    "model_size": MODEL_SIZE,
    "epochs": 3,
    "batch_size": 8,
    "gradient_accumulation": 4,
    "learning_rate": 5e-5,
    "max_length": 512,
    "warmup_steps": 100,

    # Checkpointing
    # NOTE(review): this directory is shared by every model size, so step_*.pt files
    # from runs of different sizes overwrite each other.
    "save_steps": 500,
    "checkpoint_dir": "/content/drive/MyDrive/svend-checkpoints/base-reasoner",

    # WandB
    "use_wandb": True,
    "wandb_project": "svend",
    "run_name": f"base-reasoner-{MODEL_SIZE}",
}

# Create checkpoint dir (under the mounted Google Drive)
os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)

print("Config:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

In [None]:
# WandB login. Skipped entirely (wandb is not even imported) when
# CONFIG["use_wandb"] is False.
use_wandb = CONFIG["use_wandb"]
if use_wandb:
    import wandb

    wandb.login()

## 3. Load Data

In [None]:
from datasets import load_dataset, concatenate_datasets

print("Loading datasets...")

# Each entry: (display label, key used by the formatting cell,
#              positional args for load_dataset, keyword args for load_dataset).
DATASET_SPECS = [
    # GSM8K - math word problems
    ("GSM8K", "gsm8k", ("gsm8k", "main"), {"split": "train"}),
    # MATH - harder math problems
    ("MATH", "math", ("lighteval/MATH",), {"split": "train", "trust_remote_code": True}),
]

# A dataset that fails to load is reported and skipped, so training can proceed on
# whichever datasets did load.
datasets_to_load = []
for label, key, ds_args, ds_kwargs in DATASET_SPECS:
    try:
        loaded = load_dataset(*ds_args, **ds_kwargs)
        print(f"{label}: {len(loaded)} examples")
        datasets_to_load.append((key, loaded))
    except Exception as e:
        print(f"{label} failed: {e}")

print(f"\nLoaded {len(datasets_to_load)} datasets")

In [None]:
# Prepare data for training: render every example into a single "text" field.
from transformers import AutoTokenizer

# GPT-2 tokenizer as base; it has no pad token, so EOS doubles as padding.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token


def format_gsm8k(example):
    """Format GSM8K for training."""
    question, answer = example['question'], example['answer']
    return {"text": f"Question: {question}\n\nAnswer: {answer}"}


def format_math(example):
    """Format MATH dataset for training."""
    problem, solution = example['problem'], example['solution']
    return {"text": f"Problem: {problem}\n\nSolution: {solution}"}


# Dataset key (from the loading cell) -> formatter; unknown keys are skipped.
FORMATTERS = {"gsm8k": format_gsm8k, "math": format_math}

formatted = [
    ds.map(FORMATTERS[name], remove_columns=ds.column_names)
    for name, ds in datasets_to_load
    if name in FORMATTERS
]

# Combine
if not formatted:
    raise ValueError("No datasets loaded!")
train_dataset = concatenate_datasets(formatted)
print(f"Combined dataset: {len(train_dataset)} examples")

In [None]:
# Tokenize: every example is truncated/padded to exactly CONFIG["max_length"] tokens.
def tokenize(examples):
    """Tokenize a batch of formatted examples to fixed-length id/mask sequences."""
    encode_options = dict(
        truncation=True,
        max_length=CONFIG["max_length"],
        padding="max_length",
    )
    return tokenizer(examples["text"], **encode_options)


print("Tokenizing...")
tokenized = train_dataset.map(
    tokenize, batched=True, remove_columns=["text"], desc="Tokenizing"
)
tokenized.set_format("torch")  # rows come back as torch tensors for the DataLoader

print(f"Tokenized: {len(tokenized)} examples")
print(f"Sample keys: {list(tokenized[0].keys())}")

## 4. Create Model

In [None]:
# Get model config for the chosen size, with the vocabulary matched to the tokenizer.
model_config = get_config(CONFIG["model_size"])
model_config.vocab_size = tokenizer.vocab_size

print("Model config:")
for label, value in (
    ("Hidden size", model_config.hidden_size),
    ("Layers", model_config.num_hidden_layers),
    ("Heads", model_config.num_attention_heads),
    ("Vocab size", model_config.vocab_size),
):
    print(f"  {label}: {value}")

# Create model on the GPU (Module.cuda() moves it in place and returns it)
model = ReasoningTransformer(model_config).cuda()

# Count parameters
params = sum(p.numel() for p in model.parameters())
print(f"\nParameters: {params:,} ({params/1e6:.1f}M)")

In [None]:
# Quick sanity check - one no-grad forward pass on random ids, using the ids as labels.
print("Testing forward pass with labels...")

test_input = torch.randint(0, model_config.vocab_size, (2, 64)).cuda()
test_mask = torch.ones_like(test_input)  # ones_like already places it on test_input's device

with torch.no_grad():
    outputs = model(input_ids=test_input, attention_mask=test_mask, labels=test_input)

loss_value = outputs['loss'].item()
print(f"  Output keys: {list(outputs.keys())}")
print(f"  Loss: {loss_value:.4f}")
print(f"  Logits shape: {outputs['logits'].shape}")
print("\n[OK] Forward pass with labels works!")

## 5. Training Loop

In [None]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from transformers import get_cosine_schedule_with_warmup

# DataLoader
train_loader = DataLoader(
    tokenized,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

# Optimizer
optimizer = AdamW(model.parameters(), lr=CONFIG["learning_rate"], weight_decay=0.01)

# Scheduler: linear warmup over CONFIG["warmup_steps"] optimizer steps, then cosine
# decay to 0. The previous CosineAnnealingLR ignored warmup_steps entirely.
# The optimizer steps once per `gradient_accumulation` batches, and the training loop
# carries its accumulation counter across epochs, hence the single floor division.
# max(1, ...) keeps the schedule well-defined for tiny datasets.
# NOTE: checkpoints saved with the old CosineAnnealingLR scheduler cannot restore
# their scheduler state into this one.
total_steps = max(1, len(train_loader) * CONFIG["epochs"] // CONFIG["gradient_accumulation"])
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=CONFIG["warmup_steps"],
    num_training_steps=total_steps,
)

print(f"Training setup:")
print(f"  Total steps: {total_steps}")
print(f"  Warmup steps: {CONFIG['warmup_steps']}")
print(f"  Epochs: {CONFIG['epochs']}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Gradient accumulation: {CONFIG['gradient_accumulation']}")
print(f"  Effective batch: {CONFIG['batch_size'] * CONFIG['gradient_accumulation']}")

In [None]:
# WandB init: open a run named CONFIG["run_name"] and record the whole CONFIG with it.
if CONFIG["use_wandb"]:
    import wandb

    wandb_init_kwargs = {
        "project": CONFIG["wandb_project"],
        "name": CONFIG["run_name"],
        "config": CONFIG,
    }
    wandb.init(**wandb_init_kwargs)

In [None]:
# Training loop
#
# Causal-LM training with gradient accumulation, autocast mixed precision, grad-norm
# clipping at 1.0, per-optimizer-step logging and periodic checkpoints written to
# CONFIG["checkpoint_dir"].

IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def build_causal_lm_labels(input_ids, attention_mask):
    """Return labels for a right-padded batch with padding excluded from the loss.

    The tokenizer pads every example to max_length with the EOS token
    (pad_token = eos_token). Using ``labels=input_ids`` as-is would spend most of
    the loss teaching the model to emit EOS after EOS. Every pad position
    (attention_mask == 0) is set to IGNORE_INDEX except the first pad of each row,
    which is kept as an EOS target so the model still learns where an answer ends.

    Args:
        input_ids: batch of token ids, (batch, seq).
        attention_mask: same shape; 1 for real tokens, 0 for padding. Assumes
            right padding (the GPT-2 tokenizer default) - TODO confirm if the
            tokenizer setup changes.

    Returns:
        A new label tensor; the inputs are not modified.
    """
    labels = input_ids.masked_fill(attention_mask == 0, IGNORE_INDEX)
    lengths = attention_mask.sum(dim=1)  # number of real tokens per row
    padded_rows = torch.nonzero(lengths < input_ids.size(1), as_tuple=True)[0]
    # Restore the first pad (EOS) in every row that has padding.
    labels[padded_rows, lengths[padded_rows]] = input_ids[padded_rows, lengths[padded_rows]]
    return labels


model.train()

# Resume support: the "Resume Training" cell sets `start_epoch` and restores
# `global_step` (plus model/optimizer/scheduler state). Without it, start from scratch.
# A resumed run restarts the saved epoch from its first batch.
if "start_epoch" in globals():
    first_epoch = start_epoch
    print(f"Resuming from step {global_step}, epoch {first_epoch}")
else:
    first_epoch = 0
    global_step = 0
accumulation_step = 0
optimizer.zero_grad(set_to_none=True)  # discard any gradients left by an interrupted run

# Mixed precision. bf16 has fp32's exponent range, so loss scaling is only needed
# for fp16; with enabled=False the scaler calls pass straight through.
use_bf16 = torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16
scaler = torch.amp.GradScaler('cuda', enabled=not use_bf16)

print(f"\nStarting training...")
print(f"Mixed precision: {dtype}")
print("="*60)

for epoch in range(first_epoch, CONFIG["epochs"]):
    epoch_loss = 0
    num_batches = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")

    for batch in pbar:
        input_ids = batch["input_ids"].cuda()
        attention_mask = batch["attention_mask"].cuda()
        # NOTE(review): assumes the model's loss shifts labels internally and skips
        # IGNORE_INDEX (the cross_entropy default) - TODO confirm in src/models/transformer.py.
        labels = build_causal_lm_labels(input_ids, attention_mask)

        # Forward with mixed precision
        with torch.amp.autocast('cuda', dtype=dtype):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels  # Causal LM: predict next token
            )
            # Divide so the gradients summed over the accumulation window average out.
            loss = outputs["loss"] / CONFIG["gradient_accumulation"]

        # Backward
        scaler.scale(loss).backward()

        accumulation_step += 1
        epoch_loss += loss.item() * CONFIG["gradient_accumulation"]  # undo the division for logging
        num_batches += 1

        # Optimizer step once per accumulation window (the counter carries across epochs)
        if accumulation_step >= CONFIG["gradient_accumulation"]:
            scaler.unscale_(optimizer)  # clip on true gradient magnitudes
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

            global_step += 1
            accumulation_step = 0

            # Logging: running mean of the per-batch loss over the current epoch
            avg_loss = epoch_loss / num_batches
            pbar.set_postfix({"loss": f"{avg_loss:.4f}", "step": global_step})

            if CONFIG["use_wandb"]:
                wandb.log({
                    "loss": avg_loss,
                    "lr": scheduler.get_last_lr()[0],
                    "step": global_step,
                    "epoch": epoch
                })

            # Save checkpoint
            if global_step % CONFIG["save_steps"] == 0:
                ckpt_path = f"{CONFIG['checkpoint_dir']}/step_{global_step:06d}.pt"
                torch.save({
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "scaler_state_dict": scaler.state_dict(),
                    "global_step": global_step,
                    "epoch": epoch,
                    "config": CONFIG,
                }, ckpt_path)
                print(f"\nSaved checkpoint: {ckpt_path}")

    print(f"Epoch {epoch+1} complete. Avg loss: {epoch_loss/num_batches:.4f}")

print("\n" + "="*60)
print("Training complete!")

In [None]:
# Save final model weights together with the training and model configs.
final_path = f"{CONFIG['checkpoint_dir']}/final.pt"

# Store model_config as a plain dict when it has one, otherwise as its string form.
if hasattr(model_config, '__dict__'):
    serialized_model_config = model_config.__dict__
else:
    serialized_model_config = str(model_config)

final_payload = {
    "model_state_dict": model.state_dict(),
    "config": CONFIG,
    "model_config": serialized_model_config,
}
torch.save(final_payload, final_path)
print(f"Saved final model: {final_path}")

# Close the WandB run opened in the init cell
if CONFIG["use_wandb"]:
    wandb.finish()

## 6. Quick Test

In [None]:
# Test generation: sample one continuation of a GSM8K-style prompt.
model.eval()

test_prompt = "Question: What is 15% of 200?\n\nAnswer:"
inputs = tokenizer(test_prompt, return_tensors="pt").to("cuda")
generation_kwargs = dict(max_new_tokens=100, temperature=0.7, do_sample=True)

with torch.no_grad():
    output_ids = model.generate(inputs["input_ids"], **generation_kwargs)

# Decode the first (only) sequence in the batch
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print("Generated:")
print(response)

---
## Resume Training (if disconnected)

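If Colab disconnects:
1. Run the six code cells in **1. Setup**
2. Run the Config and Data cells
3. Run the Create Model cell
4. Run the first cell of **5. Training Loop** (DataLoader / optimizer / scheduler setup) and, if using WandB, the WandB init cell — the resume cell below needs `optimizer` and `scheduler` to exist
5. Then run the cell below with your checkpoint path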

In [None]:
# Resume from checkpoint - update this path!
# Alternatively set RESUME_FROM = None to pick the newest step_*.pt in
# CONFIG["checkpoint_dir"] automatically.
import glob
import os

RESUME_FROM = "/content/drive/MyDrive/svend-checkpoints/base-reasoner/step_000500.pt"

if RESUME_FROM is None:
    # Step numbers are zero-padded to 6 digits when saved, so lexical order is step order.
    candidates = sorted(glob.glob(os.path.join(CONFIG["checkpoint_dir"], "step_*.pt")))
    if not candidates:
        raise FileNotFoundError(f"No step_*.pt checkpoints found in {CONFIG['checkpoint_dir']}")
    RESUME_FROM = candidates[-1]

# map_location puts the tensors on the GPU the model lives on, regardless of
# which device the checkpoint was written from.
checkpoint = torch.load(RESUME_FROM, map_location="cuda")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
global_step = checkpoint["global_step"]
start_epoch = checkpoint["epoch"]  # epoch that was in progress when the checkpoint was saved

print(f"Loaded checkpoint: {RESUME_FROM}")
print(f"Resumed from step {global_step}, epoch {start_epoch}")
print("Now run the training loop cell to continue.")