# Svend Language Model Pretraining

**Stage 1 of a 2-stage training pipeline**

This notebook trains the base language model that will later be fine-tuned for reasoning.

**Goal:** Learn language patterns, grammar, and world knowledge from diverse text data.

**Output:** Base language model checkpoint to be used by `train_reasoning_specialist.ipynb`

**Requirements:**
- Colab Pro+ (A100 recommended)
- Google Drive for checkpoint persistence
- ~10-20 hours training time for 500M model

## 1. Setup Environment

In [None]:
# Mount Google Drive first (for checkpoint persistence)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repository
!git clone https://github.com/ewolters/svend.git 2>/dev/null || echo "Already cloned"
%cd svend

In [None]:
# Install dependencies
!pip install -q torch transformers datasets accelerate wandb
!pip install -q sentencepiece tiktoken

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from datasets import load_dataset
import os
import json
import math
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm

# Add src to path
import sys
sys.path.insert(0, '/content/svend')

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Configuration

In [None]:
# =============================================================================
# TRAINING CONFIGURATION
# =============================================================================

# =============================================
# QUICK CONFIG - CHANGE THESE FOR EACH RUN
# =============================================
# Checkpoint to resume from (here: step 100K). Set to None to train from scratch.
RESUME_CHECKPOINT = "/content/drive/MyDrive/svend-checkpoints/language-model/checkpoint-100000.pt"
ADDITIONAL_STEPS = 50_000  # Micro-batch steps to train THIS run (wall-clock time depends on GPU and data throughput)
# =============================================

CONFIG = {
    # Model
    "model_size": "500m",

    # Data - diverse mix for language understanding
    "datasets": [
        {"name": "openwebtext", "subset": None, "weight": 0.35},
        {"name": "wikimedia/wikipedia", "subset": "20231101.en", "weight": 0.25},
        {"name": "allenai/c4", "subset": "en", "weight": 0.20},
        {"name": "HuggingFaceFW/fineweb", "subset": "sample-10BT", "weight": 0.20},
    ],
    
    # Training
    "max_steps": ADDITIONAL_STEPS,  # Counted in micro-batches; updated after loading checkpoint
    "batch_size": 8,
    "gradient_accumulation": 4,  # Effective batch = 8 x 4 = 32 sequences per optimizer step
    "max_seq_length": 1024,
    "learning_rate": 1e-4,  # Lower LR for continued training
    "weight_decay": 0.1,
    "warmup_steps": 500,  # Short warmup for resume
    "max_grad_norm": 1.0,
    
    # Checkpointing
    "checkpoint_dir": "/content/drive/MyDrive/svend-checkpoints/language-model",
    "save_every": 10000,  # Save every 10K
    "eval_every": None,
    
    # Logging
    "use_wandb": True,
    "wandb_project": "svend-language",
    "experiment_name": f"lm-500m-continued-{datetime.now().strftime('%Y%m%d-%H%M')}",
    
    # Resume
    "resume_from": RESUME_CHECKPOINT,
}

os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)

print("="*60)
print("SVEND LANGUAGE MODEL - CONTINUED TRAINING")
print("="*60)
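print(f"Resume from: {os.path.basename(RESUME_CHECKPOINT) if RESUME_CHECKPOINT else 'none (fresh start)'}")
print(f"Additional steps: {ADDITIONAL_STEPS:,}")
print("Target total: set after the checkpoint is loaded (previous steps + additional steps)")
print(f"Learning rate: {CONFIG['learning_rate']} (reduced for continued training)")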
print("="*60)
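The step counter in this notebook advances once per micro-batch, while the optimizer updates once every `gradient_accumulation` micro-batches. A minimal sketch of the resulting token budget, using the values from `CONFIG` above. The helper name `token_budget` is hypothetical, and the sketch assumes every sequence is a full `max_seq_length`-token chunk, which holds for the packing `StreamingTextDataset` defined later in this notebook:

```python
def token_budget(batch_size, grad_accum, seq_len, micro_steps):
    """Return (sequences per optimizer step, tokens per optimizer step,
    optimizer steps, total tokens) for a run of `micro_steps` micro-batches."""
    seqs_per_update = batch_size * grad_accum
    tokens_per_update = seqs_per_update * seq_len
    optimizer_steps = micro_steps // grad_accum
    total_tokens = micro_steps * batch_size * seq_len
    return seqs_per_update, tokens_per_update, optimizer_steps, total_tokens


# Values from CONFIG: batch_size=8, gradient_accumulation=4, max_seq_length=1024
seqs, tokens, updates, total = token_budget(8, 4, 1024, 50_000)
print(f"{seqs} sequences / {tokens:,} tokens per optimizer step")
print(f"{updates:,} optimizer steps, {total / 1e6:.1f}M tokens this run")
```

Under these settings, one run of 50K micro-steps is 12,500 optimizer updates and about 410M tokens.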

## 3. Create Model

In [None]:
from src.models.config import create_language_specialist_config
from src.models.transformer import ReasoningTransformer

# Create language model config
model_config = create_language_specialist_config()

# Adjust for language pretraining (no tool tokens needed yet)
model_config.tool_calling = False
model_config.num_tool_tokens = 0

print(f"Model: {model_config.name}")
print(f"Parameters: {model_config.num_parameters() / 1e6:.0f}M")
print(f"Hidden size: {model_config.hidden_size}")
print(f"Layers: {model_config.num_hidden_layers}")
print(f"Attention heads: {model_config.num_attention_heads}")
print(f"Context length: {model_config.max_position_embeddings}")

memory = model_config.memory_footprint()
print(f"\nEstimated training memory: {memory['total_training_gb']:.1f} GB")

In [None]:
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Update vocab size in config
model_config.vocab_size = len(tokenizer)

print(f"Tokenizer vocab size: {len(tokenizer)}")

In [None]:
# Create model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ReasoningTransformer(model_config)
model = model.to(device)

# Use mixed precision
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

print(f"Model on: {device}")
print(f"Training dtype: {dtype}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

## 4. Setup Data Loading

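We stream the datasets from the Hugging Face Hub, so nothing has to be downloaded upfront. Each example is drawn from a source chosen by its configured weight, tokenized, and packed into fixed-length training chunks.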
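The loader below combines two ideas: choose a source with probability proportional to its weight, then concatenate token IDs (with an EOS separator between documents) and cut fixed-length chunks. A dependency-free sketch of both, using toy integer "tokens" in place of the GPT-2 tokenizer; the function `pack_weighted` and the toy sources are illustrative only:

```python
import random


def pack_weighted(sources, weights, chunk_len, eos_id, seed=42):
    """Yield fixed-length chunks packed from weighted document sources.

    `sources` is a list of iterators, each yielding lists of token IDs
    (one list per document). Exhausted sources are dropped.
    """
    rng = random.Random(seed)
    sources, weights = list(sources), list(weights)
    buffer = []
    while sources:
        # random.choices accepts relative (unnormalized) weights
        idx = rng.choices(range(len(sources)), weights=weights)[0]
        try:
            doc = next(sources[idx])
        except StopIteration:
            sources.pop(idx)
            weights.pop(idx)
            continue
        buffer.extend(doc)
        buffer.append(eos_id)  # document boundary marker
        while len(buffer) >= chunk_len:
            yield buffer[:chunk_len]
            buffer = buffer[chunk_len:]
    # A trailing partial chunk is discarded


web = iter([[1, 2, 3], [4, 5]])
wiki = iter([[10, 11, 12, 13]])
for chunk in pack_weighted([web, wiki], [0.6, 0.4], chunk_len=4, eos_id=0):
    print(chunk)
```

Packing wastes no positions on padding, which is why the notebook's collate function can simply stack chunks without an attention mask.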

In [None]:
class StreamingTextDataset(IterableDataset):
    """
    Streams text from multiple datasets with weighted sampling.
    Handles tokenization and chunking on-the-fly.
    """
    
    def __init__(self, dataset_configs, tokenizer, max_length=1024, seed=42):
        self.dataset_configs = dataset_configs
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.seed = seed
        
    def __iter__(self):
        # Load streaming datasets
        streams = []
        weights = []
        
        for config in self.dataset_configs:
            try:
                if config["subset"]:
                    ds = load_dataset(
                        config["name"], 
                        config["subset"], 
                        split="train", 
                        streaming=True,
                    )
                else:
                    ds = load_dataset(
                        config["name"], 
                        split="train", 
                        streaming=True,
                    )
                streams.append(iter(ds))
                weights.append(config["weight"])
                print(f"Loaded: {config['name']}")
            except Exception as e:
                print(f"Warning: Could not load {config['name']}: {e}")
        
        if not streams:
            raise ValueError("No datasets could be loaded!")
        
        # Normalize weights
        total = sum(weights)
        weights = [w / total for w in weights]
        
        # Buffer for accumulating tokens
        token_buffer = []
        
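        import random
        # Give each DataLoader worker its own RNG seed so the workers do not
        # all draw the same sequence of dataset choices
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        rng = random.Random(self.seed + worker_id)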
        
        while True:
            # Sample a dataset based on weights
            idx = rng.choices(range(len(streams)), weights=weights)[0]
            
            try:
                example = next(streams[idx])
            except StopIteration:
                # Dataset exhausted, remove it
                streams.pop(idx)
                weights.pop(idx)
                if not streams:
                    break
                total = sum(weights)
                weights = [w / total for w in weights]
                continue
            
            # Get text from example
            text = self._extract_text(example)
            if not text:
                continue
            
            # Tokenize
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            token_buffer.extend(tokens)
            token_buffer.append(self.tokenizer.eos_token_id)
            
            # Yield chunks when buffer is full
            while len(token_buffer) >= self.max_length:
                chunk = token_buffer[:self.max_length]
                token_buffer = token_buffer[self.max_length:]
                
                yield {
                    "input_ids": torch.tensor(chunk, dtype=torch.long),
                    "labels": torch.tensor(chunk, dtype=torch.long),
                }
    
    def _extract_text(self, example):
        """Extract text from different dataset formats."""
        # Try common field names
        for field in ["text", "content", "article", "passage"]:
            if field in example and example[field]:
                return example[field]
        return None


def collate_fn(batch):
    """Collate batch of examples."""
    input_ids = torch.stack([x["input_ids"] for x in batch])
    labels = torch.stack([x["labels"] for x in batch])
    return {"input_ids": input_ids, "labels": labels}


print("Data loading utilities defined.")

In [None]:
# Create dataset and dataloader
dataset = StreamingTextDataset(
    dataset_configs=CONFIG["datasets"],
    tokenizer=tokenizer,
    max_length=CONFIG["max_seq_length"],
)

dataloader = DataLoader(
    dataset,
    batch_size=CONFIG["batch_size"],
    collate_fn=collate_fn,
    num_workers=2,
    prefetch_factor=4,
)

print(f"DataLoader created with batch_size={CONFIG['batch_size']}")

## 5. Setup Training

In [None]:
# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG["learning_rate"],
    weight_decay=CONFIG["weight_decay"],
    betas=(0.9, 0.95),
)

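# Learning rate scheduler. scheduler.step() runs once per optimizer step, i.e. every
# `gradient_accumulation` micro-batches, so the schedule length is in optimizer steps.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=CONFIG["warmup_steps"],
    num_training_steps=CONFIG["max_steps"] // CONFIG["gradient_accumulation"],
)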

# Gradient scaler for mixed precision
scaler = torch.amp.GradScaler("cuda", enabled=(dtype == torch.float16))

print("Optimizer and scheduler configured.")
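`get_cosine_schedule_with_warmup` multiplies the base learning rate by a factor that rises linearly during warmup and then follows half a cosine down to zero at `num_training_steps`. Because `scheduler.step()` runs once per optimizer step, `num_training_steps` must be counted in optimizer steps, not micro-batches. A pure-Python sketch of the multiplier, written to mirror the default `transformers` formula (`num_cycles=0.5`); `cosine_warmup_factor` is a hypothetical name:

```python
import math


def cosine_warmup_factor(step, warmup_steps, total_steps):
    """LR multiplier: linear warmup, then a half-cosine decay to 0.

    Mirrors the default schedule of transformers'
    get_cosine_schedule_with_warmup (num_cycles=0.5).
    """
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lr, warmup, micro_steps, grad_accum = 1e-4, 500, 50_000, 4
optimizer_steps = micro_steps // grad_accum  # scheduler.step() calls per run

# Final LR when the schedule length is given in optimizer steps ...
print(base_lr * cosine_warmup_factor(optimizer_steps, warmup, optimizer_steps))
# ... versus micro-batches: the run ends while the decay has barely started
print(base_lr * cosine_warmup_factor(optimizer_steps, warmup, micro_steps))
```

With a micro-batch-sized schedule, the run ends with the LR still at roughly 86% of its base value instead of decaying to zero.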

In [None]:
# Initialize wandb
if CONFIG["use_wandb"]:
    import wandb
    wandb.init(
        project=CONFIG["wandb_project"],
        name=CONFIG["experiment_name"],
        config={
            **CONFIG,
            "model_params": model_config.num_parameters(),
            "model_config": model_config.to_dict(),
        }
    )
    print(f"WandB initialized: {CONFIG['experiment_name']}")

In [None]:
def save_checkpoint(model, optimizer, scheduler, step, loss, path):
    """Save training checkpoint."""
    checkpoint = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
        "config": model_config.to_dict(),
        "training_config": CONFIG,
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved: {path}")


def load_checkpoint(path, model, optimizer, scheduler):
    """Load training checkpoint."""
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    return checkpoint["step"], checkpoint["loss"]


print("Checkpoint utilities defined.")
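Checkpoints are written straight to Google Drive, so a disconnect in the middle of `torch.save` can leave a truncated file under the final name. One common mitigation, sketched here as an assumption rather than part of the original notebook, is to write to a temporary file in the same directory and then rename it with `os.replace`, which is atomic on POSIX filesystems. Whether the Drive mount in Colab preserves that guarantee is not verified here. The helper `atomic_save` is hypothetical; in the notebook you would pass `save_fn=torch.save`:

```python
import json
import os
import tempfile


def atomic_save(obj, path, save_fn):
    """Write `obj` via `save_fn(obj, file_path)` to a temp file, then rename.

    Readers of `path` see either the old file or the complete new one,
    never a partially written file (on filesystems with atomic rename).
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)  # save_fn opens the path itself
    try:
        save_fn(obj, tmp_path)
        os.replace(tmp_path, path)  # atomic rename over any existing file
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def json_save(obj, file_path):
    """Stand-in for torch.save in this example."""
    with open(file_path, "w") as f:
        json.dump(obj, f)


demo_path = os.path.join(tempfile.gettempdir(), "demo-checkpoint.json")
atomic_save({"step": 100}, demo_path, json_save)
```

The `.tmp` suffix also keeps half-written files out of the checkpoint listing in Section 8, which only shows `.pt` files.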

In [None]:
# Resume from checkpoint if specified
start_step = 0

if CONFIG["resume_from"]:
    print(f"Loading checkpoint: {CONFIG['resume_from']}")
    checkpoint = torch.load(CONFIG["resume_from"], map_location=device)
    
    # Load model weights
    model.load_state_dict(checkpoint["model_state_dict"])
    
    # Get previous step count
    prev_steps = checkpoint.get("training_steps", checkpoint.get("step", 0))
    start_step = prev_steps
    
    # Update max_steps to be previous + additional
    CONFIG["max_steps"] = prev_steps + ADDITIONAL_STEPS
    
    print(f"  Previous training: {prev_steps:,} steps")
    print(f"  This run will add: {ADDITIONAL_STEPS:,} steps")
    print(f"  Target total: {CONFIG['max_steps']:,} steps")
    
    # Reinitialize the optimizer and scheduler for this run rather than restoring
    # their saved state (continued training uses a lower LR and a new schedule)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CONFIG["learning_rate"],
        weight_decay=CONFIG["weight_decay"],
        betas=(0.9, 0.95),
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=CONFIG["warmup_steps"],
        # Schedule covers this run only, counted in optimizer steps
        # (scheduler.step() runs once per `gradient_accumulation` micro-batches)
        num_training_steps=ADDITIONAL_STEPS // CONFIG["gradient_accumulation"],
    )
    print("  Fresh optimizer initialized")
else:
    print("Starting fresh training")
    CONFIG["max_steps"] = ADDITIONAL_STEPS

## 6. Training Loop

In [None]:
@torch.no_grad()
def evaluate(model, tokenizer, num_samples=3):
    """Quick evaluation - generate samples."""
    torch.cuda.empty_cache()
    
    try:
        model.eval()
        
        prompts = [
            "The capital of France is",
            "In 1969, humans first",
            "Water boils at",
        ]
        
        print("\n" + "="*60)
        print("Sample generations:")
        print("="*60)
        
        for prompt in prompts[:num_samples]:
            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
            
            try:
                output = model.generate(
                    input_ids,
                    max_new_tokens=30,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=tokenizer.eos_token_id,
                )
                
                generated = tokenizer.decode(output[0], skip_special_tokens=True)
                print(f"\nPrompt: {prompt}")
                print(f"Output: {generated}")
            except Exception as e:
                print(f"\nPrompt: {prompt}")
                print(f"Generation failed: {type(e).__name__}")
        
        print("="*60 + "\n")
    except Exception as e:
        print(f"\nEvaluation skipped: {type(e).__name__}\n")
    finally:
        model.train()
        torch.cuda.empty_cache()


print("Evaluation function defined.")
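`evaluate` passes `temperature=0.7` and `top_p=0.9` to `model.generate`. How `ReasoningTransformer.generate` implements them is defined in the repository, not in this notebook. The sketch below shows the standard technique those parameter names usually refer to: divide the logits by the temperature, apply softmax, then sample only from the smallest set of most-probable tokens whose cumulative probability reaches `top_p` (nucleus sampling). Plain Python, with hypothetical helper names:

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature (T < 1 sharpens, T > 1 flattens)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_indices(probs, top_p):
    """Smallest set of most-probable indices whose total mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


def nucleus_sample(logits, temperature=0.7, top_p=0.9, rng=None):
    """Draw one token index using temperature scaling and top-p filtering."""
    rng = rng or random.Random()
    probs = softmax_with_temperature(logits, temperature)
    kept = top_p_indices(probs, top_p)
    # random.choices accepts unnormalized weights, so no explicit renormalization
    return rng.choices(kept, weights=[probs[i] for i in kept])[0]


logits = [4.0, 3.0, 1.0, -2.0]
print(top_p_indices(softmax_with_temperature(logits, 0.7), 0.9))
print(nucleus_sample(logits, rng=random.Random(0)))
```

Here only the two highest-scoring tokens survive the 0.9 cutoff, so low-probability tokens can never be sampled.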

In [None]:
# =============================================================================
# MAIN TRAINING LOOP
# =============================================================================

print("\n" + "="*60)
print("STARTING LANGUAGE MODEL PRETRAINING")
print("="*60)
print(f"Model: {model_config.name}")
print(f"Parameters: {model_config.num_parameters() / 1e6:.0f}M")
print(f"Max steps: {CONFIG['max_steps']:,}")
print(f"Checkpoint dir: {CONFIG['checkpoint_dir']}")
print("="*60 + "\n")

model.train()
step = start_step
total_loss = 0
log_interval = 100

progress = tqdm(total=CONFIG["max_steps"], initial=start_step, desc="Training")

try:
    for batch in dataloader:
        if step >= CONFIG["max_steps"]:
            break
        
        # Move to device
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        
        # Forward pass with mixed precision
        with torch.amp.autocast("cuda", dtype=dtype):
            outputs = model(input_ids, labels=labels)
            loss = outputs["loss"] / CONFIG["gradient_accumulation"]
        
        # Backward pass
        scaler.scale(loss).backward()
        total_loss += loss.item() * CONFIG["gradient_accumulation"]
        
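        # Gradient accumulation: count micro-batches from the start of this run so
        # the first window is complete even if start_step is not a multiple of it
        if (step - start_step + 1) % CONFIG["gradient_accumulation"] == 0: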
            # Gradient clipping
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG["max_grad_norm"])
            
            # Optimizer step
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
        
        step += 1
        progress.update(1)
        
        # Logging
        if step % log_interval == 0:
            avg_loss = total_loss / log_interval
            lr = scheduler.get_last_lr()[0]
            
            progress.set_postfix({
                "loss": f"{avg_loss:.4f}",
                "lr": f"{lr:.2e}",
                "ppl": f"{math.exp(min(avg_loss, 20)):.1f}"
            })
            
            if CONFIG["use_wandb"]:
                wandb.log({
                    "train/loss": avg_loss,
                    "train/perplexity": math.exp(min(avg_loss, 20)),
                    "train/learning_rate": lr,
                    "train/step": step,
                })
            
            total_loss = 0
        
        # Checkpoint
        if step % CONFIG["save_every"] == 0:
            checkpoint_path = os.path.join(
                CONFIG["checkpoint_dir"],
                f"checkpoint-{step}.pt"
            )
            save_checkpoint(model, optimizer, scheduler, step, avg_loss, checkpoint_path)

except KeyboardInterrupt:
    print("\nTraining interrupted. Saving checkpoint...")
    checkpoint_path = os.path.join(CONFIG["checkpoint_dir"], f"checkpoint-{step}-interrupted.pt")
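    # Average over the micro-steps accumulated since the last log (0.0 if none)
    steps_since_log = min(step % log_interval, step - start_step)
    interrupted_loss = total_loss / steps_since_log if steps_since_log else 0.0
    save_checkpoint(model, optimizer, scheduler, step, interrupted_loss, checkpoint_path)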

progress.close()
print(f"\nTraining completed at step {step}")
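The loop divides each micro-batch loss by `gradient_accumulation` before `backward()`. Since gradients add up across `backward()` calls, the accumulated gradient equals the gradient of the mean loss over all micro-batches (for equal-sized micro-batches and a mean-reduced loss). A small numeric check with a one-parameter squared-error model, in plain Python so the derivative can be written by hand; the toy model and data are illustrative:

```python
def mean_loss_grad(w, batch):
    """d/dw of the mean squared error of y_hat = w * x over `batch`."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 3.5), (3.0, 7.0), (4.0, 7.5),
        (5.0, 11.0), (6.0, 12.5), (7.0, 13.0), (8.0, 17.0)]
w = 0.5
grad_accum = 4
micro_size = len(data) // grad_accum

# Each micro-batch gradient is scaled by 1/grad_accum, as in the training loop
accumulated = 0.0
for k in range(grad_accum):
    micro_batch = data[k * micro_size:(k + 1) * micro_size]
    accumulated += mean_loss_grad(w, micro_batch) / grad_accum

full_batch = mean_loss_grad(w, data)
print(accumulated, full_batch)
```

Without the division, the update would be `grad_accum` times larger than a true 32-sequence batch gradient, which would interact badly with `max_grad_norm` clipping.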

## 7. Save Final Model

In [None]:
# Save final model
final_path = os.path.join(CONFIG["checkpoint_dir"], "final-language-model.pt")

torch.save({
    "model_state_dict": model.state_dict(),
    "config": model_config.to_dict(),
    "tokenizer_name": "gpt2",
    "training_steps": step,
    "training_config": CONFIG,
}, final_path)

print(f"\nFinal model saved to: {final_path}")
print("\nThis checkpoint will be used as the base for reasoning fine-tuning.")
print("Next step: Run train_reasoning_specialist.ipynb")

In [None]:
# Final evaluation
print("\nFinal Evaluation:")
evaluate(model, tokenizer, num_samples=3)  # evaluate() defines three prompts

In [None]:
# Cleanup
if CONFIG["use_wandb"]:
    wandb.finish()
    print("WandB run finished.")

print("\n" + "="*60)
print("LANGUAGE MODEL PRETRAINING COMPLETE")
print("="*60)
print(f"Final checkpoint: {final_path}")
print(f"Total steps: {step:,}")
print("\nNext: Use this model as the base for reasoning fine-tuning.")
print("="*60)

## 8. Resume Training (if needed)

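If Colab disconnects, set `RESUME_CHECKPOINT` in the configuration cell to the newest checkpoint (the cell below lists them, including any `-interrupted.pt` file) and re-run the notebook. Each resumed run trains a further `ADDITIONAL_STEPS` steps on top of the step count stored in the checkpoint. The position in the streaming datasets is not saved, so a resumed run reads each dataset from the beginning again.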

In [None]:
# List available checkpoints
import os

checkpoint_dir = CONFIG["checkpoint_dir"]
if os.path.exists(checkpoint_dir):
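    # Sort by modification time (oldest first); a plain string sort would put
    # checkpoint-100000.pt before checkpoint-20000.pt
    checkpoints = sorted(
        (f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")),
        key=lambda f: os.path.getmtime(os.path.join(checkpoint_dir, f)),
    )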
    print("Available checkpoints:")
    for cp in checkpoints:
        path = os.path.join(checkpoint_dir, cp)
        size_mb = os.path.getsize(path) / 1e6
        print(f"  {cp} ({size_mb:.0f} MB)")
else:
    print("No checkpoints found yet.")
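To pick the file for `RESUME_CHECKPOINT`, the step number can be parsed from the names this notebook writes (`checkpoint-<step>.pt` and `checkpoint-<step>-interrupted.pt`), which avoids relying on string order or file timestamps. A sketch; `latest_checkpoint` is a hypothetical helper, and `final-language-model.pt` is ignored because its name carries no step:

```python
import re

# Matches checkpoint-<step>.pt and checkpoint-<step>-interrupted.pt
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)(-interrupted)?\.pt$")


def latest_checkpoint(filenames):
    """Return the filename with the highest step number, or None if none match."""
    best_name, best_step = None, -1
    for name in filenames:
        match = CHECKPOINT_RE.match(name)
        if match and int(match.group(1)) > best_step:
            best_name, best_step = name, int(match.group(1))
    return best_name


# In the notebook: latest_checkpoint(os.listdir(CONFIG["checkpoint_dir"]))
print(latest_checkpoint(["checkpoint-20000.pt", "checkpoint-100000.pt",
                         "checkpoint-104321-interrupted.pt",
                         "final-language-model.pt"]))
```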

---

## Notes

### What this notebook does:
1. Trains a 500M parameter language model from scratch, or continues training from a saved checkpoint (the configuration above resumes from step 100K)
2. Uses diverse text data (OpenWebText, Wikipedia, C4, FineWeb)
3. Produces a base model that understands language patterns

### What's next:
1. Run `train_reasoning_specialist.ipynb` to fine-tune for math/reasoning
2. The reasoning specialist will load this language model checkpoint
3. Fine-tuning adds reasoning capability on top of language understanding

### Tips:
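- Monitor perplexity (PPL): it should decrease steadily. The logged value is training perplexity on the streaming mix, not a held-out evaluation
- As a rough guide, good language models reach PPL ~20-30 on diverse text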
- Sample generations should become more coherent over time
- Save checkpoints frequently - Colab can disconnect!
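Perplexity is the exponential of the mean per-token cross-entropy loss, which is how the training loop derives `ppl` from `avg_loss` (clamping the loss at 20 to avoid overflow). Assuming the model's loss is the usual natural-log cross-entropy, a PPL of ~20-30 corresponds to a loss of roughly 3.0-3.4. A small sketch with hypothetical helper names:

```python
import math


def perplexity(mean_loss, cap=20.0):
    """exp(mean cross-entropy in nats), clamping the loss at `cap` as the loop does."""
    return math.exp(min(mean_loss, cap))


def loss_for_perplexity(ppl):
    """Mean cross-entropy loss (nats) corresponding to a target perplexity."""
    return math.log(ppl)


for target in (20, 30):
    print(f"PPL {target} <-> loss {loss_for_perplexity(target):.2f}")
```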