# Level 2: Training Pipeline

**Objective:** Fix bugs in gradient accumulation, learning rate scheduling, optimizer configuration, and SFT masking.

**Acceptance Criteria:**
- All tests in `tests/test_level2.py` pass
- Gradient accumulation produces stable training
- Learning rate warmup functions correctly
- Optimizers are assigned to correct parameter groups
- SFT masking trains on assistant responses only

**Time estimate:** 1-2 hours

In [None]:
import os
import sys
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

# Put the current working directory on sys.path (once) so imports can resolve from it.
repo_root = Path.cwd()
repo_root_str = str(repo_root)
if repo_root_str not in sys.path:
    sys.path.append(repo_root_str)

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")

# Point NANOCHAT_BASE_DIR at a level-specific cache directory under the
# working directory and make sure that directory exists.
cache_dir = os.path.join(repo_root, ".cache_level2")
os.environ["NANOCHAT_BASE_DIR"] = cache_dir
os.makedirs(cache_dir, exist_ok=True)

## Test 1: Gradient Accumulation

In [None]:
# Reference 10->10 linear model; its weights are the shared starting point
# that train_with_accum copies for every run.
model = torch.nn.Linear(in_features=10, out_features=10).to(device)
# Optimizer bound to the reference model (train_with_accum builds its own).
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

# Short training loop that accumulates gradients over several micro-batches
def train_with_accum(grad_accum_steps, num_steps=20):
    """Train a copy of the module-level ``model`` and return its losses.

    Every optimizer step consumes ``grad_accum_steps`` random micro-batches
    of shape (2, 10). Each micro-batch loss is divided by
    ``grad_accum_steps`` before ``backward()`` so that the accumulated
    gradient is the mean over the micro-batches rather than their sum.

    Args:
        grad_accum_steps: Micro-batches accumulated per optimizer step.
        num_steps: Number of optimizer steps to run.

    Returns:
        A list with one float per micro-batch (``num_steps *
        grad_accum_steps`` entries): the MSE loss scaled back up to its
        un-normalized magnitude.
    """
    # Fresh layer, then overwrite its weights with the reference model's.
    replica = torch.nn.Linear(10, 10).to(device)
    replica.load_state_dict(model.state_dict())
    sgd = torch.optim.SGD(replica.parameters(), lr=0.01)
    mse = torch.nn.functional.mse_loss

    history = []
    for _ in range(num_steps):
        for _ in range(grad_accum_steps):
            inputs = torch.randn(2, 10).to(device)
            labels = torch.randn(2, 10).to(device)

            # Scale down before backward so gradients average across micro-batches.
            scaled_loss = mse(replica(inputs), labels) / grad_accum_steps
            scaled_loss.backward()

            # Log the loss at its original scale.
            history.append(scaled_loss.item() * grad_accum_steps)

        # One parameter update per accumulation window, then clear gradients.
        sgd.step()
        sgd.zero_grad()

    return history

# Run the same loop at three accumulation depths (1, 4 and 8 micro-batches)
losses_1, losses_4, losses_8 = (train_with_accum(depth) for depth in (1, 4, 8))

# Loss variance per accumulation depth
variances = {
    1: np.var(losses_1),
    4: np.var(losses_4),
    8: np.var(losses_8),
}
for depth, variance in variances.items():
    print(f"Loss variance (accum={depth}): {variance:.6f}")

# Acceptance: accumulating over 8 micro-batches must stay within 100x of the
# variance without accumulation, and every recorded loss must be finite.
instability_bound = variances[1] * 100
assert variances[8] < instability_bound, "FAIL: Gradient accumulation causes instability"
assert np.isfinite(losses_8).all(), "FAIL: Loss contains NaN/Inf"
print("✓ Test 1 passed")

## Test 2: Learning Rate Schedule

In [None]:
# Linear LR warmup schedule
def get_lr_multiplier(it, num_iterations=1000, warmup_ratio=0.1):
    """Return the learning-rate multiplier for training step ``it``.

    The multiplier ramps linearly during the first
    ``round(warmup_ratio * num_iterations)`` steps and is 1.0 afterwards.
    The ramp is strictly positive at step 0 and stays strictly below 1.0
    for every warmup step, so the full learning rate is first reached at
    step ``warmup_iters`` and not one step earlier.

    Args:
        it: Zero-based training step.
        num_iterations: Total number of training steps.
        warmup_ratio: Fraction of ``num_iterations`` spent warming up.

    Returns:
        A float in (0, 1) during warmup, and 1.0 once warmup is over.
    """
    warmup_iters = round(warmup_ratio * num_iterations)
    if it < warmup_iters:
        # Divide by warmup_iters + 1 (not warmup_iters): with the smaller
        # denominator the last warmup step would already evaluate to 1.0.
        return (it + 1) / (warmup_iters + 1)
    return 1.0

num_iters = 200
warmup_ratio = 0.1
warmup_steps = int(warmup_ratio * num_iters)

# Sample the schedule at every training step
lrs = [get_lr_multiplier(step_idx, num_iters, warmup_ratio) for step_idx in range(num_iters)]

# Plot the multiplier curve with the warmup boundary and the target value marked
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(lrs)
ax.axvline(x=warmup_steps, color='r', linestyle='--', label='End of warmup')
ax.axhline(y=1.0, color='g', linestyle='--', label='Target LR')
ax.set_xlabel('Step')
ax.set_ylabel('LR Multiplier')
ax.set_title('Learning Rate Schedule')
ax.legend()
ax.grid(True)
plt.show()

# Acceptance tests: positive start, still ramping on the last warmup step,
# at ~1.0 on the first post-warmup step, and ~0.5 halfway through warmup.
lr_first = lrs[0]
lr_last_warmup = lrs[warmup_steps - 1]
lr_after_warmup = lrs[warmup_steps]
lr_mid_warmup = lrs[warmup_steps // 2]

assert lr_first > 0.001, f"FAIL: Initial LR too small: {lrs[0]}"
assert lr_last_warmup < 1.0, "FAIL: Warmup should not complete before warmup_steps"
assert abs(lr_after_warmup - 1.0) < 0.01, f"FAIL: LR should be ~1.0 after warmup, got {lrs[warmup_steps]}"
assert 0.4 < lr_mid_warmup < 0.6, f"FAIL: Mid-warmup LR should be ~0.5, got {lrs[warmup_steps//2]}"
print("✓ Test 2 passed")

## Test 3: Optimizer Assignment

In [None]:
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer

# Create tokenizer: train a tiny one on first run, otherwise reload the copy
# saved under NANOCHAT_BASE_DIR.
# NOTE(review): vocab_size is 256 both here and in GPTConfig below — confirm
# that this budget leaves room for any special tokens the tokenizer adds.
tokenizer_dir = Path(os.environ["NANOCHAT_BASE_DIR"]) / "tokenizer"
if not tokenizer_dir.exists():
    texts = ["test"] * 100
    tokenizer = RustBPETokenizer.train_from_iterator(iter(texts), vocab_size=256)
    tokenizer.save(str(tokenizer_dir))
else:
    tokenizer = RustBPETokenizer.from_directory(str(tokenizer_dir))

# Create a small model
config = GPTConfig(
    sequence_len=64,
    vocab_size=256,
    n_layer=2,
    n_head=2,
    n_kv_head=2,
    n_embd=64,
)

model = GPT(config)
model.init_weights()
model = model.to(device)

# Get optimizers
optimizers = model.setup_optimizers()
adamw_opt, muon_opt = optimizers

# Build the parameter -> name lookup once instead of rescanning
# model.named_parameters() for every optimizer parameter.
param_name_by_id = {id(param): name for name, param in model.named_parameters()}


def _optimizer_param_names(opt):
    """Return the names of the model parameters held by any param group of ``opt``."""
    names = set()
    # Walk every param group: an optimizer may split its parameters across
    # several groups, and looking only at the first one would under-report.
    for group in opt.param_groups:
        for p in group['params']:
            name = param_name_by_id.get(id(p))
            if name is not None:
                names.add(name)
    return names


# Check parameter assignments
adamw_param_names = _optimizer_param_names(adamw_opt)
muon_param_names = _optimizer_param_names(muon_opt)

print(f"AdamW params: {sorted(adamw_param_names)}")
print(f"Muon params: {sorted(muon_param_names)}")

# Acceptance tests: embeddings and lm_head go to AdamW; transformer blocks go
# to Muon and must not appear in AdamW.
assert 'lm_head.weight' in adamw_param_names, "FAIL: lm_head should use AdamW"
assert 'transformer.wte.weight' in adamw_param_names, "FAIL: embeddings should use AdamW"
assert any('transformer.h' in name for name in muon_param_names), "FAIL: transformer layers should use Muon"
assert not any('transformer.h' in name for name in adamw_param_names), "FAIL: transformer layers should not use AdamW"
print("✓ Test 3 passed")

## Test 4: SFT Masking

In [None]:
# Minimal two-turn conversation: one user message followed by one assistant reply
user_turn = {"role": "user", "content": "Hello"}
assistant_turn = {"role": "assistant", "content": "Hi there!"}
conversation = {"messages": [user_turn, assistant_turn]}

# Token ids plus a per-token mask (1 = train on this token, 0 = skip it)
ids, mask = tokenizer.render_conversation(conversation)

print(f"Total tokens: {len(ids)}")
print(f"Tokens to train on (mask=1): {sum(mask)}")
print(f"Tokens to skip (mask=0): {len(mask) - sum(mask)}")

# Simulate SFT data processing: targets are the ids shifted left by one, so
# the mask is shifted the same way to stay aligned with them.
ids_tensor = torch.tensor(ids, dtype=torch.long)
mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
skip_positions = mask_tensor == 0

# Positions with mask == 0 get the target value -1
targets = ids_tensor[1:].clone()
targets[skip_positions] = -1

# Count kept versus masked-out target positions
is_trained = targets != -1
num_train = int(is_trained.sum())
num_masked = int((~is_trained).sum())

print(f"\nAfter masking:")
print(f"Training on {num_train} tokens")
print(f"Masked out {num_masked} tokens")

# Acceptance: some tokens are trained on (mask=1) and some are masked out (mask=0)
assert num_train > 0, "FAIL: No tokens to train on"
assert num_masked > 0, "FAIL: No tokens masked out"
# The trained tokens must make up more than 20% of the rendered sequence
min_train_tokens = len(mask) * 0.2
assert num_train > min_train_tokens, f"FAIL: Too few training tokens ({num_train}/{len(mask)})"
print("✓ Test 4 passed")

## Summary

If every cell above ran without raising an `AssertionError`, all four tests passed and the training pipeline is functioning correctly.