# 🐛 Nanochat Bug Hunt: Training Pipeline

Welcome to the intermediate debugging challenge! In this notebook, you'll fix 4 training pipeline bugs:

1. **Gradient Accumulation Bug**: Micro-batch losses are summed instead of averaged, so gradients grow with the number of accumulation steps
2. **Learning Rate Bug**: Model doesn't learn due to broken warmup
3. **Optimizer Misconfig**: Wrong parameters assigned to optimizers
4. **SFT Masking Bug**: The loss mask is inverted, so the model trains on user turns and special tokens instead of assistant replies

This notebook does not need a pretrained checkpoint: it builds a tiny tokenizer and a randomly initialized model, which is enough to exercise the training code.

In [None]:
# Setup
import os
import sys
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import importlib

# Make the current working directory importable so `import nanochat` works
repo_root = Path.cwd()
repo_root_str = str(repo_root)
if repo_root_str not in sys.path:
    sys.path.append(repo_root_str)

# Pick an accelerator: CUDA first, then Apple MPS (if this torch build has it), else CPU
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Scratch/cache directory for this notebook; later cells read it back via NANOCHAT_BASE_DIR
cache_dir = os.path.join(repo_root, ".cache_medium")
os.environ["NANOCHAT_BASE_DIR"] = cache_dir
os.makedirs(cache_dir, exist_ok=True)

## Step 1: Create a Small Model and Dataset

We'll create a minimal setup to test training.

In [None]:
# First, create a tokenizer (or load one saved by a previous run)
from nanochat.tokenizer import RustBPETokenizer

tokenizer_dir = Path(os.environ["NANOCHAT_BASE_DIR"]) / "tokenizer"

if not tokenizer_dir.exists():
    print("Creating new tokenizer...")
    # Toy training corpus: four short sentences, each repeated 200 times
    corpus = 200 * [
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning is transforming the world.",
        "Python is a versatile programming language.",
        "Artificial intelligence powers modern technology.",
    ]
    # Feed the trainer a lazy generator over the corpus, one sentence at a time
    tokenizer = RustBPETokenizer.train_from_iterator((line for line in corpus), vocab_size=1024)
    tokenizer.save(str(tokenizer_dir))
else:
    print("Loading existing tokenizer...")
    tokenizer = RustBPETokenizer.from_directory(str(tokenizer_dir))

print(f"Tokenizer vocab size: {tokenizer.get_vocab_size()}")

In [None]:
# Create a tiny model
from nanochat.gpt import GPT, GPTConfig

# Deliberately small (4 layers, 4 heads, 128-dim embeddings) so each training step is fast
config = GPTConfig(
    n_layer=4,
    n_head=4,
    n_kv_head=4,
    n_embd=128,
    sequence_len=128,
    vocab_size=tokenizer.get_vocab_size(),
)

model = GPT(config)
model.init_weights()
model = model.to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")

## Step 2: Test Gradient Accumulation - Find Bug #1

Let's simulate training with gradient accumulation and see what happens.

In [None]:
# Create simple training data
def get_batch(batch_size=2, seq_len=64):
    """Return a random ``(inputs, targets)`` batch for smoke-testing training.

    Tokens are sampled uniformly from the tokenizer's vocabulary, so there is
    no real structure to learn; the batch only exercises forward/backward.
    ``targets`` is ``inputs`` shifted left by one position (next-token
    prediction), which is how real language-model batches are laid out.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of input (and target) positions per sequence.

    Returns:
        ``inputs`` (int32) and ``targets`` (int64), each of shape
        ``(batch_size, seq_len)`` and placed on the notebook's ``device``.
    """
    vocab_size = tokenizer.get_vocab_size()
    # Sample one extra token so inputs/targets are genuinely shifted copies of
    # the same sequence (previously targets were drawn independently).
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len + 1), device=device)
    # .contiguous(): slices of `tokens` are strided views, and the loss code may
    # call .view() on them.
    inputs = tokens[:, :-1].to(torch.int32).contiguous()
    targets = tokens[:, 1:].to(torch.int64).contiguous()
    return inputs, targets

# Test training with gradient accumulation
# Simulating the bug from base_train.py
def train_with_grad_accum(model, steps=20, grad_accum_steps=4):
    """Train ``model`` with gradient accumulation, reproducing Bug #1.

    Every optimizer step runs ``grad_accum_steps`` micro-batches and calls
    ``backward()`` on each raw loss WITHOUT dividing it by
    ``grad_accum_steps``, so the accumulated gradient is a sum, not a mean.

    Args:
        model: Module whose ``model(x, y)`` call returns a scalar loss.
        steps: Number of optimizer steps.
        grad_accum_steps: Micro-batches accumulated per optimizer step.

    Returns:
        One value per optimizer step: the mean of that step's micro-batch losses.
    """
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    history = []

    for step in range(steps):
        micro_losses = []

        # Gradient accumulation loop (mimicking base_train.py)
        for _ in range(grad_accum_steps):
            x, y = get_batch()
            loss = model(x, y)

            # BUG 1: Not normalizing loss!
            # loss = loss / grad_accum_steps  # This line is missing!
            loss.backward()  # gradients from each micro-batch are summed, unscaled
            micro_losses.append(loss.item())

        opt.step()
        opt.zero_grad()

        mean_loss = sum(micro_losses) / grad_accum_steps
        history.append(mean_loss)

        if step % 5 == 0:
            print(f"Step {step}: avg loss = {mean_loss:.4f}")

    return history

# torch.nn.Module has no .clone(); deepcopy gives an independent model to train,
# leaving `model` itself at its initial weights for later cells.
import copy

print("Training with gradient accumulation (buggy)...")
losses_buggy = train_with_grad_accum(copy.deepcopy(model), grad_accum_steps=8)

plt.figure(figsize=(8, 4))
plt.plot(losses_buggy)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Training Loss with Bug #1 (Unnormalized Gradient Accumulation)')
plt.show()

if losses_buggy[-1] > losses_buggy[0] * 2:
    print("\n❌ BUG #1 DETECTED! Loss is exploding!")
    print("💡 The gradients are being accumulated without normalization.")
    print("💡 Check base_train.py around line 271 where loss.backward() is called.")
else:
    # AdamW divides each update by the square root of a running average of squared
    # gradients, so a constant gradient scale-up largely cancels out and the loss
    # curve alone may not reveal the bug.
    print("\nℹ️ The loss did not blow up in this short run, but the bug is still there.")
    print("💡 AdamW largely cancels a constant gradient scale, which can hide summed (unnormalized) gradients.")
    print("💡 Check base_train.py around line 271 where loss.backward() is called.")

## Fix Bug #1: Gradient Accumulation

Go to `scripts/base_train.py` and restore (uncomment) the line that normalizes the loss before `loss.backward()`:
```python
loss = loss / grad_accum_steps  # each .backward() is a grad sum => normalize loss here
```

In [None]:
# Test with fix
def train_with_grad_accum_fixed(model, steps=20, grad_accum_steps=4):
    """Train ``model`` with correctly normalized gradient accumulation.

    Each micro-batch loss is divided by ``grad_accum_steps`` before
    ``backward()``, so the summed gradients equal the mean gradient over the
    accumulated micro-batches.

    Args:
        model: Module whose ``model(x, y)`` call returns a scalar loss.
        steps: Number of optimizer steps.
        grad_accum_steps: Micro-batches accumulated per optimizer step.

    Returns:
        One value per optimizer step: the mean (unscaled) micro-batch loss.
    """
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    history = []

    for _step in range(steps):
        running = 0

        for _ in range(grad_accum_steps):
            x, y = get_batch()

            # FIXED: Normalize loss so accumulated grads are a mean, not a sum
            scaled = model(x, y) / grad_accum_steps
            scaled.backward()

            running += scaled.item() * grad_accum_steps  # undo the scaling for logging

        opt.step()
        opt.zero_grad()

        history.append(running / grad_accum_steps)

    return history

# Start from the same initial weights as the buggy run. The buggy cell above
# trains a copy of `model`, not `model` itself, so copying it here makes the two
# loss curves comparable instead of depending on a different random init.
import copy

model_fixed = copy.deepcopy(model)

print("Training with gradient accumulation (fixed)...")
losses_fixed = train_with_grad_accum_fixed(model_fixed, grad_accum_steps=8)

plt.figure(figsize=(8, 4))
plt.plot(losses_fixed)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Training Loss with Bug #1 Fixed')
plt.show()

print("✅ Bug #1 FIXED! Loss is now stable with gradient accumulation!")

## Step 3: Test Learning Rate Schedule - Find Bug #2

Now let's test the learning rate warmup.

In [None]:
# Test the learning rate schedule from base_train.py
def get_lr_multiplier_buggy(it, num_iterations=1000, warmup_ratio=0.1):
    """Buggy LR scheduler from base_train.py (Bug #2).

    During the first ``round(warmup_ratio * num_iterations)`` iterations the
    multiplier ramps linearly, but with an extra factor of 100 in the
    denominator; afterwards it is a constant 1.0.

    Args:
        it: Zero-based iteration index.
        num_iterations: Total number of training iterations.
        warmup_ratio: Fraction of iterations spent in warmup.

    Returns:
        The learning-rate multiplier for iteration ``it``.
    """
    warmup_iters = round(warmup_ratio * num_iterations)
    if it >= warmup_iters:
        return 1.0
    # BUG: Warmup is divided by 100x too much!
    return (it + 1) / (warmup_iters * 100)

# Plot the learning rate schedule
num_iters = 200
warmup_ratio = 0.1
steps = range(num_iters)
lr_multipliers = [get_lr_multiplier_buggy(step, num_iters, warmup_ratio) for step in steps]

# Warmup length rounded exactly as the scheduler does it; int() truncates and can
# disagree with round() for other ratios (e.g. 0.7 * 30 -> int 20 vs round 21).
warmup_steps = round(warmup_ratio * num_iters)

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(steps, lr_multipliers)
plt.axhline(y=1.0, color='r', linestyle='--', label='Target LR')
plt.xlabel('Step')
plt.ylabel('LR Multiplier')
plt.title('Buggy Learning Rate Schedule')
plt.legend()
plt.ylim(0, 1.2)

# Zoom in on warmup
plt.subplot(1, 2, 2)
plt.plot(steps[:warmup_steps*2], lr_multipliers[:warmup_steps*2])
plt.axvline(x=warmup_steps, color='g', linestyle='--', label='End of warmup')
plt.xlabel('Step')
plt.ylabel('LR Multiplier')
plt.title('Warmup Phase (Zoomed)')
plt.legend()

plt.tight_layout()
plt.show()

# Probe the middle of warmup, where a correct linear ramp is about halfway to 1.0.
# Derived from the schedule parameters instead of a hard-coded step/expectation.
probe_step = warmup_steps // 2
expected_lr = (probe_step + 1) / warmup_steps  # value a correct warmup gives here
buggy_lr = lr_multipliers[probe_step]
print(f"\n❌ BUG #2 DETECTED! Learning rate at step {probe_step}: {buggy_lr:.6f}")
print(f"Expected ~{expected_lr:.2f}, but got {buggy_lr:.6f} ({expected_lr / buggy_lr:.0f}x too small!)")
print("💡 Check get_lr_multiplier() in base_train.py around line 160")

## Fix Bug #2: Learning Rate Schedule

Fix the warmup calculation in `scripts/base_train.py`:
```python
return (it + 1) / warmup_iters  # Remove the * 100
```

In [None]:
# Fixed LR schedule
def get_lr_multiplier_fixed(it, num_iterations=1000, warmup_ratio=0.1):
    """Fixed LR scheduler: linear warmup to 1.0, then constant.

    Args:
        it: Zero-based iteration index.
        num_iterations: Total number of training iterations.
        warmup_ratio: Fraction of iterations spent in warmup.

    Returns:
        ``(it + 1) / warmup_iters`` during warmup, otherwise 1.0.
    """
    warmup_iters = round(warmup_ratio * num_iterations)
    in_warmup = it < warmup_iters
    # Fixed! No stray factor of 100 in the denominator.
    return (it + 1) / warmup_iters if in_warmup else 1.0

# Plot the fixed schedule next to the buggy one from the previous cell
lr_multipliers_fixed = [
    get_lr_multiplier_fixed(i, num_iters, warmup_ratio) for i in steps
]

plt.figure(figsize=(8, 4))
curves = [
    (lr_multipliers_fixed, {"label": 'Fixed'}),
    (lr_multipliers, {"label": 'Buggy', "alpha": 0.5}),
]
for ys, plot_kwargs in curves:
    plt.plot(steps, ys, **plot_kwargs)
plt.axhline(y=1.0, color='r', linestyle='--', label='Target LR')
plt.xlabel('Step')
plt.ylabel('LR Multiplier')
plt.title('Learning Rate Schedule Comparison')
plt.legend()
plt.show()

print("✅ Bug #2 FIXED! Learning rate warmup now works correctly!")

## Step 4: Test Optimizer Assignment - Find Bug #3

Let's check if the optimizers are assigned to the correct parameters.

In [None]:
# Test the optimizer setup
print("Testing optimizer setup...\n")

# Get the optimizers from the model
optimizers = model.setup_optimizers()
adamw_opt, muon_opt = optimizers

# Map parameter objects to names once (by identity) instead of re-scanning
# model.named_parameters() for every parameter we want to label.
param_names = {id(param): name for name, param in model.named_parameters()}

adamw_groups = [
    [param_names.get(id(p), "<unnamed>") for p in group['params']]
    for group in adamw_opt.param_groups
]
adamw_names = [name for group_names in adamw_groups for name in group_names]

# Muon can hold several param groups (NOTE(review): nanochat's Muon appears to
# group parameters by size), so flatten all of them; reading only
# param_groups[0] would silently miss parameters.
muon_names = [
    param_names.get(id(p), "<unnamed>")
    for group in muon_opt.param_groups
    for p in group['params']
]

# Check what parameters are in each optimizer
print("AdamW optimizer parameter groups:")
for i, group_names in enumerate(adamw_groups):
    print(f"  Group {i}: {len(group_names)} parameters")
    for name in group_names[:2]:  # Just show the first 2
        print(f"    - {name}")

print("\nMuon optimizer parameters:")
print(f"  {len(muon_names)} parameters")
for name in muon_names[:2]:  # Just show the first 2
    print(f"    - {name}")

# Expected split: matrix params (transformer.h.*) -> Muon; everything else
# (embeddings, lm_head) -> AdamW. Actually check it instead of assuming.
matrix_in_adamw = [n for n in adamw_names if n.startswith("transformer.h.")]
non_matrix_in_muon = [n for n in muon_names if not n.startswith("transformer.h.")]

if matrix_in_adamw or non_matrix_in_muon:
    print("\n❌ BUG #3 DETECTED! Optimizer assignment is wrong!")
    print("💡 Matrix parameters (transformer.h.*) should use Muon optimizer")
    if matrix_in_adamw:
        print(f"💡 But {len(matrix_in_adamw)} of them are assigned to AdamW instead (e.g. {matrix_in_adamw[0]})!")
    if non_matrix_in_muon:
        print(f"💡 Non-matrix parameters found in Muon: {len(non_matrix_in_muon)} (e.g. {non_matrix_in_muon[0]})")
    print("💡 Check setup_optimizers() in nanochat/gpt.py around line 227")
else:
    print("\n✅ Optimizer assignment looks correct: transformer.h.* → Muon, everything else → AdamW")

## Fix Bug #3: Optimizer Assignment

In `nanochat/gpt.py`, fix the parameter assignments:
- AdamW should get: `lm_head_params` and `embedding_params`
- Muon should get: `matrix_params`

In [None]:
# After fixing, reload the module so the edited nanochat/gpt.py takes effect
import nanochat.gpt
importlib.reload(nanochat.gpt)
from nanochat.gpt import GPT, GPTConfig

# Create new model and check
model_fixed = GPT(config)
model_fixed.init_weights()
model_fixed = model_fixed.to(device)

# Verify the assignment rather than assuming the edit worked:
# matrix params (transformer.h.*) belong in Muon, everything else in AdamW.
adamw_opt_fixed, muon_opt_fixed = model_fixed.setup_optimizers()
fixed_param_names = {id(param): name for name, param in model_fixed.named_parameters()}

adamw_names_fixed = [
    fixed_param_names.get(id(p), "<unnamed>")
    for group in adamw_opt_fixed.param_groups
    for p in group['params']
]
# Walk every Muon param group, not just the first one.
muon_names_fixed = [
    fixed_param_names.get(id(p), "<unnamed>")
    for group in muon_opt_fixed.param_groups
    for p in group['params']
]

misplaced = [n for n in adamw_names_fixed if n.startswith("transformer.h.")]
misplaced += [n for n in muon_names_fixed if not n.startswith("transformer.h.")]

if misplaced:
    print(f"❌ Bug #3 is still present: {len(misplaced)} parameters are in the wrong optimizer (e.g. {misplaced[0]}).")
    print("💡 Edit setup_optimizers() in nanochat/gpt.py, then re-run this cell.")
else:
    print("✅ Bug #3 FIXED! Matrix params → Muon, embeddings/lm_head → AdamW")

## Step 5: Test SFT Masking - Find Bug #4

Finally, let's test the SFT masking logic for conversation training.

In [None]:
# Create a sample conversation
conversation = {
    "messages": [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing well, thank you!"}
    ]
}

# Tokenize the conversation; `mask` has one entry per token in `ids`
ids, mask = tokenizer.render_conversation(conversation)

print("Tokenized conversation:")
print(f"IDs: {ids[:20]}... (showing first 20)")
print(f"Mask: {mask[:20]}... (showing first 20)")
# Real newline escapes: "\\n" in the source printed a literal backslash-n.
print("\nMask values: 0 = don't train, 1 = train on this token")

# Visualize what tokens we're training on
print("\nVisualization (GREEN = train, RED = don't train):")
print(tokenizer.visualize_tokenization(ids[:50], mask[:50]))

In [None]:
# Simulate the buggy SFT data processing
def process_sft_batch_buggy(ids, mask):
    """Simulate the buggy masking from chat_sft.py (Bug #4).

    Builds next-token targets ``ids[1:]`` and shifts ``mask`` the same way
    (dropping its first entry). Positions set to -1 are the ones treated as
    "not trained on" elsewhere in this notebook.

    Args:
        ids: Token ids for one conversation.
        mask: Per-token flags (1 = train, 0 = don't train), same length as ids.

    Returns:
        A long tensor of ``ids[1:]`` with -1 written where the shifted mask is 1
        -- i.e. the inverted (wrong) masking.
    """
    next_tokens = torch.tensor(ids, dtype=torch.long)[1:].clone()
    shifted_mask = torch.tensor(mask[1:], dtype=torch.long)  # Skip BOS mask

    # BUG: Inverted mask logic! Hides exactly the positions we SHOULD train on.
    hide = shifted_mask.eq(1)
    next_tokens[hide] = -1
    return next_tokens

# Process with bug
targets_buggy = process_sft_batch_buggy(ids, mask)

# Count how many tokens we're training on (-1 marks excluded positions)
is_trained = targets_buggy != -1
num_train_tokens = is_trained.sum().item()
total_tokens = len(targets_buggy)

print("\nWith buggy masking:")
print(f"Training on {num_train_tokens}/{total_tokens} tokens")
print(f"That's {100*num_train_tokens/total_tokens:.1f}% of tokens!")

# Check what we're training on. flatten() always yields a flat list, even for
# zero or one match (squeeze() would collapse a single match to a bare int).
train_indices = torch.nonzero(is_trained).flatten().tolist()
print(f"\nTraining on tokens at positions: {train_indices[:10]}...")

print("\n❌ BUG #4 DETECTED! Mask logic is inverted!")
print("💡 We're training on user messages instead of assistant messages!")
print("💡 Check chat_sft.py around line 114 where mask is applied")

## Fix Bug #4: SFT Masking

In `scripts/chat_sft.py`, fix the mask logic:
```python
row_targets[mask_tensor == 0] = -1  # Mask where mask is 0, not 1!
```

In [None]:
# Fixed masking
def process_sft_batch_fixed(ids, mask):
    """Fixed SFT masking: keep targets only where the (shifted) mask is 1.

    Args:
        ids: Token ids for one conversation.
        mask: Per-token flags (1 = train, 0 = don't train), same length as ids.

    Returns:
        A long tensor of ``ids[1:]`` with -1 written wherever the shifted mask
        is 0.
    """
    shifted_mask = torch.tensor(mask[1:], dtype=torch.long)
    labels = torch.tensor(ids, dtype=torch.long)[1:].clone()

    # FIXED: Correct mask logic -- hide only the positions we should NOT train on
    labels[shifted_mask.eq(0)] = -1
    return labels

# Process with fix
targets_fixed = process_sft_batch_fixed(ids, mask)
num_train_tokens_fixed = (targets_fixed != -1).sum().item()
# Compute the denominator here instead of relying on the buggy cell's total_tokens.
total_tokens_fixed = len(targets_fixed)

print("With fixed masking:")
print(f"Training on {num_train_tokens_fixed}/{total_tokens_fixed} tokens")
print(f"That's {100*num_train_tokens_fixed/total_tokens_fixed:.1f}% of tokens!")

# Sanity check: exactly the positions whose shifted mask is 1 should be trained on.
expected_train_tokens = sum(1 for m in mask[1:] if m == 1)
if num_train_tokens_fixed == expected_train_tokens:
    print("\n✅ Bug #4 FIXED! Now training on assistant responses only!")
else:
    print(f"\n⚠️ Expected {expected_train_tokens} trained tokens but got {num_train_tokens_fixed}; check the mask logic.")

In [None]:
# Summary (uses real "\n" newlines; "\\n" printed a literal backslash-n)
print("🎉 Congratulations! You've debugged the training pipeline!\n")
print("Summary of fixes:")
print("1. ✅ Gradient accumulation: Added loss normalization")
print("2. ✅ Learning rate warmup: Removed factor of 100")
print("3. ✅ Optimizer assignment: Matrix params → Muon, embeddings → AdamW")
print("4. ✅ SFT masking: Fixed inverted mask logic")
print("\nThese bugs would have caused:")
print("- Unstable training with large batch sizes")
print("- Extremely slow initial learning")
print("- Poor optimization of different parameter types")
print("- Model learning wrong conversation patterns")