# Svend Ensemble Training on Colab

Train the 4-model specialist ensemble:

| Model | Size | Role |
|-------|------|------|
| **Router** | 125M | Intent classification, routes to specialists |
| **Language** | 500M | Prompt interpretation, synthesis, output |
| **Reasoning** | 500M | Math, logic, chain-of-thought, tools |
| **Verifier** | 250M | Checks answers, catches errors |

**Total: ~1.4B parameters across 4 specialist models**

This is roughly 5x smaller than a single 7B model and allows:
- Faster inference (router picks which model runs)
- Specialized training per domain
- Mix-and-match during inference

In [None]:
# Check GPU
import torch

# Report the accelerator before anything heavy runs. The 30 GB threshold only
# selects which status line is printed; later cells do not read it.
if not torch.cuda.is_available():
    print("[X] No GPU - training will be very slow")
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9  # bytes -> GB
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_memory:.1f} GB")

    status_line = (
        "\n[OK] Sufficient GPU memory for ensemble training!"
        if gpu_memory > 30
        else "\n[OK] Can train models sequentially"
    )
    print(status_line)

In [None]:
# Install dependencies
!pip install -q torch transformers datasets accelerate wandb sympy

# Clone the repository
!git clone https://github.com/YOUR_USERNAME/reasoning-lab.git
%cd reasoning-lab

In [None]:
# Persist checkpoints on Google Drive so they outlive the Colab runtime.
import os

from google.colab import drive

drive.mount('/content/drive')

# Root folder for every specialist's checkpoints (created if missing).
CHECKPOINT_DIR = os.path.join('/content/drive/MyDrive', 'svend-ensemble')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print("Checkpoints will be saved to: " + CHECKPOINT_DIR)

In [None]:
import sys

# The notebook has `%cd`-ed into the cloned repo; put the current directory
# first on the import path so `src` resolves to that checkout.
sys.path.insert(0, '.')

from src.models import (
    create_ensemble_config,
    create_language_specialist_config,
    create_model,
    create_reasoning_specialist_config,
    create_router_config,
    create_verifier_specialist_config,
)
from src.data import create_tokenizer

print("Imports successful!")

## Ensemble Overview

In [None]:
# Print ensemble summary
# Build the default ensemble config from src.models and print its summary
# before any training starts (a sanity check; no state is modified here).
ensemble = create_ensemble_config()
ensemble.print_summary()

## Configuration

In [None]:
# Training configuration
import torch

CONFIG = {
    # Which models to train (set to False to skip)
    "train_router": True,
    "train_language": True,
    "train_reasoning": True,
    "train_verifier": True,

    # Data
    # num_samples caps the ONE dataset that every specialist trains on (it is
    # not a per-model budget). None uses the full dataset.
    "num_samples": 50000,
    "max_seq_length": 2048,

    # Training (per model)
    "epochs": 3,
    "batch_size": 8,
    "gradient_accumulation": 4,
    "warmup_ratio": 0.05,

    # Efficiency
    "gradient_checkpointing": True,
    # Native bf16 needs compute capability >= 8.0 (Ampere or newer). Colab's
    # T4 (7.5) and V100 (7.0) lack it, so only turn bf16 on where the hardware
    # supports it instead of hard-coding True.
    # NOTE(review): with bf16=False the trainer presumably falls back to fp16
    # via mixed_precision=True — confirm in TrainingConfig.
    "bf16": torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8,

    # Saving
    "save_steps": 500,
    "output_dir": CHECKPOINT_DIR,

    # Logging
    "use_wandb": True,
    "wandb_project": "svend-ensemble",
}

# Model-specific learning rates
LEARNING_RATES = {
    "router": 5e-4,      # Smaller model, can use higher LR
    "language": 1e-4,
    "reasoning": 1e-4,
    "verifier": 3e-4,    # Smaller model
}

print("Configuration set!")
print(f"Training: Router={CONFIG['train_router']}, Language={CONFIG['train_language']}, "
      f"Reasoning={CONFIG['train_reasoning']}, Verifier={CONFIG['train_verifier']}")

## Load Tokenizer

In [None]:
# Load tokenizer (shared across all models)
# The same object is passed to every train_specialist call and saved next to
# each specialist's weights.
# NOTE(review): add_reasoning_tokens may make len(tokenizer) exceed
# vocab_size — the training code checks for that case.
tokenizer = create_tokenizer(
    add_reasoning_tokens=True,
    vocab_size=32000,
    base_tokenizer="mistralai/Mistral-7B-v0.1",
)
print(f"Tokenizer vocabulary: {len(tokenizer)} tokens")

## Load Training Data

In [None]:
from src.data import DatasetConfig, create_combined_dataset, ReasoningDataset, create_dataloaders

print("Loading datasets...")
data_config = DatasetConfig(max_seq_length=CONFIG["max_seq_length"])

try:
    dataset = create_combined_dataset(data_config)
except Exception as e:
    # No synthetic fallback is implemented. Continuing would leave
    # train_dataset / val_dataset undefined, and the training cells would then
    # fail with a confusing NameError. Stop here with the real cause instead.
    print(f"Warning: Could not load full dataset: {e}")
    raise RuntimeError(f"Could not load training data: {e}") from e

# Both the truncation and the split below take contiguous index ranges, so
# shuffle once with a fixed seed first. Otherwise, if the combined dataset
# stacks its sources back to back, the sample and the validation tail would
# come from whichever sources happen to be first/last.
# NOTE(review): assumes a Hugging Face `datasets.Dataset` (the code already
# relies on `.select`); `.shuffle(seed=...)` is part of that API.
dataset = dataset.shuffle(seed=42)

if CONFIG["num_samples"] and len(dataset) > CONFIG["num_samples"]:
    dataset = dataset.select(range(CONFIG["num_samples"]))

print(f"Total examples: {len(dataset)}")
if len(dataset) < 2:
    raise ValueError(
        f"Need at least 2 examples for a train/val split, got {len(dataset)}"
    )

# 95% train / 5% validation. For any size >= 2 both halves are non-empty.
train_size = int(0.95 * len(dataset))
train_data = dataset.select(range(train_size))
val_data = dataset.select(range(train_size, len(dataset)))

print(f"Train: {len(train_data)}, Val: {len(val_data)}")

train_dataset = ReasoningDataset(train_data, tokenizer, max_length=CONFIG["max_seq_length"])
val_dataset = ReasoningDataset(val_data, tokenizer, max_length=CONFIG["max_seq_length"])

## Training Function

In [None]:
from src.training import TrainingConfig, Trainer
import gc
import os
from datetime import datetime

def train_specialist(name, config_fn, train_dataset, val_dataset, tokenizer):
    """Train a single specialist model and save its final weights.

    Args:
        name: Specialist key ("router", "language", "reasoning", "verifier").
            Selects the learning rate from LEARNING_RATES, names the output
            subdirectory under CONFIG["output_dir"], and tags the wandb run.
        config_fn: Zero-argument factory returning the model config.
        train_dataset: Training dataset handed to create_dataloaders.
        val_dataset: Validation dataset handed to create_dataloaders.
        tokenizer: Shared tokenizer; its length can grow the model vocab, and
            it is saved next to the final weights.

    Returns:
        Whatever ``Trainer.train()`` returns.

    The final weights go to ``<output_dir>/<name>/final/model.pt``, together
    with ``config.json`` and the tokenizer files.
    """
    print(f"\n{'='*60}")
    print(f"Training: {name.upper()} Specialist")
    print(f"{'='*60}")

    # Create model config
    config = config_fn()
    config.gradient_checkpointing = CONFIG["gradient_checkpointing"]

    # Grow the vocabulary to fit the tokenizer BEFORE building the model.
    # This way every vocab-sized layer is created at the right size, and the
    # saved config.json agrees with the saved weights. Swapping in a new
    # `model.embed_tokens` after construction instead would leave any output
    # head at the old size (token ids >= vocab_size out of range), and would
    # replace the model's own embedding init with nn.Embedding's default.
    if len(tokenizer) > config.vocab_size:
        print(f"Expanding vocab_size {config.vocab_size} -> {len(tokenizer)} to match tokenizer")
        config.vocab_size = len(tokenizer)

    print(f"Model: {config.name}")
    print(f"Parameters: {config.num_parameters() / 1e6:.0f}M")

    model = create_model(config)
    trainer = None
    try:
        dataloaders = create_dataloaders(
            train_dataset, val_dataset,
            batch_size=CONFIG["batch_size"],
            num_workers=2,
        )

        output_dir = os.path.join(CONFIG["output_dir"], name)
        os.makedirs(output_dir, exist_ok=True)

        training_config = TrainingConfig(
            num_epochs=CONFIG["epochs"],
            learning_rate=LEARNING_RATES[name],
            batch_size=CONFIG["batch_size"],
            gradient_accumulation_steps=CONFIG["gradient_accumulation"],
            warmup_ratio=CONFIG["warmup_ratio"],
            mixed_precision=True,
            bf16=CONFIG["bf16"],
            gradient_checkpointing=CONFIG["gradient_checkpointing"],
            output_dir=output_dir,
            save_steps=CONFIG["save_steps"],
            use_wandb=CONFIG["use_wandb"],
            wandb_project=CONFIG["wandb_project"],
            wandb_run_name=f"{name}-{datetime.now().strftime('%Y%m%d-%H%M')}",
        )

        trainer = Trainer(
            model=model,
            config=training_config,
            train_dataloader=dataloaders["train"],
            eval_dataloader=dataloaders.get("val"),
        )

        print("\nStarting training...")
        results = trainer.train()

        # Save final model
        final_path = os.path.join(output_dir, "final")
        os.makedirs(final_path, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(final_path, "model.pt"))
        config.save(os.path.join(final_path, "config.json"))
        tokenizer.save_pretrained(final_path)

        print(f"\n[DONE] {name} saved to: {final_path}")
    finally:
        # Release GPU memory even if training raises, so the next specialist's
        # cell does not start with this model still resident.
        del model, trainer
        gc.collect()
        torch.cuda.empty_cache()

    return results

## Train Router (125M)

In [None]:
# Router specialist: skipped when CONFIG["train_router"] is falsy.
if not CONFIG["train_router"]:
    print("Skipping router training")
else:
    router_results = train_specialist(
        name="router",
        config_fn=create_router_config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        tokenizer=tokenizer,
    )

## Train Language Specialist (500M)

In [None]:
# Language specialist: skipped when CONFIG["train_language"] is falsy.
if not CONFIG["train_language"]:
    print("Skipping language training")
else:
    language_results = train_specialist(
        name="language",
        config_fn=create_language_specialist_config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        tokenizer=tokenizer,
    )

## Train Reasoning Specialist (500M)

In [None]:
# Reasoning specialist: skipped when CONFIG["train_reasoning"] is falsy.
if not CONFIG["train_reasoning"]:
    print("Skipping reasoning training")
else:
    reasoning_results = train_specialist(
        name="reasoning",
        config_fn=create_reasoning_specialist_config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        tokenizer=tokenizer,
    )

## Train Verifier Specialist (250M)

In [None]:
# Verifier specialist: skipped when CONFIG["train_verifier"] is falsy.
if not CONFIG["train_verifier"]:
    print("Skipping verifier training")
else:
    verifier_results = train_specialist(
        name="verifier",
        config_fn=create_verifier_specialist_config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        tokenizer=tokenizer,
    )

## Training Complete!

In [None]:
print("\n" + "="*60)
print("ENSEMBLE TRAINING COMPLETE!")
print("="*60)

print(f"\nModels saved to: {CONFIG['output_dir']}")
print("\nDirectory structure:")
!ls -la {CONFIG['output_dir']}

print("\n\nNext steps:")
print("1. Download models from Google Drive")
print("2. Run inference locally with the ensemble")
print("3. Fine-tune individual specialists as needed")

## Test Inference (Optional)

In [None]:
# Quick test of the reasoning specialist
import os

reasoning_path = os.path.join(CONFIG['output_dir'], 'reasoning', 'final')

if os.path.exists(reasoning_path):
    print("Loading reasoning specialist for test...")

    # Fall back to CPU so the smoke test still runs on a runtime without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    config = create_reasoning_specialist_config()
    # If the tokenizer is larger than the default vocab, the checkpoint's
    # embedding was sized to len(tokenizer). Build the model at that size so
    # load_state_dict sees matching shapes.
    if len(tokenizer) > config.vocab_size:
        config.vocab_size = len(tokenizer)
    model = create_model(config)

    # map_location lets a checkpoint saved from a GPU load on whatever device is present.
    state_dict = torch.load(os.path.join(reasoning_path, 'model.pt'), map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    # Test prompts
    prompts = [
        "What is 15 + 27?",
        "Find the derivative of x^2 + 3x",
    ]

    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_new_tokens=100,
                temperature=0.7,
            )

        # outputs[0] is the first (only) sequence in the batch.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Response: {response}")
else:
    print("Reasoning model not found - skipping test")

## Download Models

In [None]:
# Compress all models for download
!tar -czvf svend-ensemble.tar.gz -C /content/drive/MyDrive svend-ensemble

from google.colab import files
files.download('svend-ensemble.tar.gz')