# Level 3: Architecture & Performance

**Objective:** Fix bugs in rotary embeddings, KV cache, MQA, memory management, and tool use.

**Acceptance Criteria:**
- All tests in `tests/test_level3.py` pass
- Model handles long sequences correctly
- KV cache produces consistent outputs
- MQA works with n_head != n_kv_head
- Memory usage stays bounded
- Tool outputs appear in correct order

**Time estimate:** 2-3 hours

In [None]:
import os
import sys
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import gc

# Put the working directory on sys.path so later cells can import project packages from it.
repo_root = Path.cwd()
root_str = str(repo_root)
if root_str not in sys.path:
    sys.path.append(root_str)

# Use the GPU whenever torch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

# Dedicated cache directory for this level; later cells read NANOCHAT_BASE_DIR
# to locate the tokenizer directory.
cache_dir = str(repo_root / ".cache_level3")
os.environ["NANOCHAT_BASE_DIR"] = cache_dir
os.makedirs(cache_dir, exist_ok=True)

## Setup

In [None]:
from nanochat.tokenizer import RustBPETokenizer
from nanochat.gpt import GPT, GPTConfig
from nanochat.engine import Engine, KVCache

# Tokenizer: reuse the copy saved under NANOCHAT_BASE_DIR when it exists,
# otherwise train a tiny one on repeated dummy text and save it for next time.
base_dir = Path(os.environ["NANOCHAT_BASE_DIR"])
tokenizer_dir = base_dir / "tokenizer"
if tokenizer_dir.exists():
    tokenizer = RustBPETokenizer.from_directory(str(tokenizer_dir))
else:
    texts = ["test"] * 1000
    tokenizer = RustBPETokenizer.train_from_iterator(iter(texts), vocab_size=512)
    tokenizer.save(str(tokenizer_dir))

# Small model whose KV head count is lower than its query head count
# (this notebook refers to that configuration as MQA).
config = GPTConfig(
    sequence_len=256,
    vocab_size=512,
    n_layer=4,
    n_head=8,
    n_kv_head=4,
    n_embd=256,
)

model = GPT(config)
model.init_weights()
model = model.to(device)

# Report the parameter count and the head layout under test.
n_params = sum(p.numel() for p in model.parameters())
print(f"Model: {n_params:,} params")
print(f"MQA config: {config.n_head} query heads, {config.n_kv_head} KV heads")

## Test 1: Rotary Embeddings at Different Positions

In [None]:
# Test 1: run the first attention layer on the same tokens at several position
# offsets and check that the output does not collapse at high offsets.

# Offsets are defined once so the evaluation loop and the plot legend cannot
# drift apart.
positions = [0, 50, 100, 150, 200]

test_seq = torch.randint(0, 100, (1, 20), dtype=torch.int32).to(device)
T = test_seq.size(1)

outputs_at_positions = []
with torch.no_grad():
    for pos in positions:
        # Slice model.cos / model.sin to the window [pos, pos + T); the pair is
        # handed to the attention layer as its cos_sin argument.
        cos_sin = model.cos[:, pos:pos+T], model.sin[:, pos:pos+T]

        # Embed and RMS-normalise the tokens, then call layer 0's attention directly.
        x = model.transformer.wte(test_seq)
        x = torch.nn.functional.rms_norm(x, (x.size(-1),))
        attn_out = model.transformer.h[0].attn(x, cos_sin, kv_cache=None)

        # Keep the first five features of the first token of the first sequence.
        outputs_at_positions.append(attn_out[0, 0, :5].cpu().numpy())

# Plot outputs, one line per offset
plt.figure(figsize=(10, 5))
for pos, out in zip(positions, outputs_at_positions):
    plt.plot(out, label=f'Pos {pos}', marker='o')
plt.xlabel('Dimension')
plt.ylabel('Output')
plt.title('Attention Output at Different Positions')
plt.legend()
plt.grid(True)
plt.show()

# Acceptance: the variance at the highest offset must stay above 10% of the
# variance at offset 0.
variances = [np.var(out) for out in outputs_at_positions]
print(f"Variances: {variances}")
assert variances[-1] > variances[0] * 0.1, "FAIL: Output degrades at high positions"
print("✓ Test 1 passed")

## Test 2: KV Cache Consistency

In [None]:
# Test 2: generate through the Engine and sanity-check the token stream.
# NOTE(review): Engine.generate presumably uses the KV cache internally — confirm.
engine = Engine(model, tokenizer)
prompt = "test prompt"
prompt_tokens = tokenizer.encode(prompt, prepend="<|bos|>")

# Each step yields (token_col, _); keep the first entry of token_col per step.
generated_tokens = [
    token_col[0]
    for token_col, _ in engine.generate(prompt_tokens, max_tokens=30)
]

generated_text = tokenizer.decode(prompt_tokens + generated_tokens)
print(f"Generated: '{generated_text}'")

# Acceptance 1: generation ran for the full token budget.
n_generated = len(generated_tokens)
assert n_generated == 30, f"FAIL: Generated {n_generated} tokens, expected 30"

# Acceptance 2: the last ten tokens are not dominated by repeats, which the
# notebook treats as a sign of cache corruption.
tail = generated_tokens[-10:]
unique_tokens = len(set(tail))
assert unique_tokens >= 3, f"FAIL: Too much repetition in last 10 tokens ({unique_tokens} unique)"
print("✓ Test 2 passed")

## Test 3: MQA Shape Compatibility

In [None]:
# Test 3: the attention layer must work when the query and KV head counts differ.
attn = model.transformer.h[0].attn
B, T = 2, 10
head_dim = config.n_embd // config.n_head  # per-head width used by the views below
x = torch.randn(B, T, config.n_embd).to(device)

# Project to Q, K, V and view them per head, for inspection only
q = attn.c_q(x).view(B, T, config.n_head, head_dim)
k = attn.c_k(x).view(B, T, config.n_kv_head, head_dim)
v = attn.c_v(x).view(B, T, config.n_kv_head, head_dim)

print(f"Q shape: {q.shape} ({config.n_head} heads)")
print(f"K shape: {k.shape} ({config.n_kv_head} heads)")
print(f"V shape: {v.shape} ({config.n_kv_head} heads)")

# Forward pass through the attention layer with the mismatched head counts.
try:
    cos_sin = model.cos[:, :T], model.sin[:, :T]
    output = attn(x, cos_sin, kv_cache=None)
    assert output.shape == (B, T, config.n_embd), f"FAIL: Wrong output shape {output.shape}"
except Exception as e:
    # Report the failure, then re-raise so the cell fails like the other tests
    # instead of letting the notebook continue as if the test had passed.
    print(f"✗ Test 3 failed: {e}")
    raise
else:
    print("✓ Test 3 passed")

## Test 4: Memory Management

In [None]:
# Test 4: track allocated CUDA memory across repeated generations to spot leaks.
if device != 'cuda':
    print("⊘ Test 4 skipped (requires CUDA)")
else:
    mem_usage = []

    for run in range(5):
        # Fresh Engine per run, generating from a slightly different prompt.
        engine = Engine(model, tokenizer)
        tokens = tokenizer.encode(f"test {run}", prepend="<|bos|>")
        result, _ = engine.generate_batch(tokens, max_tokens=50)

        # Collect garbage and release cached blocks before sampling the allocator.
        gc.collect()
        torch.cuda.empty_cache()
        allocated_mb = torch.cuda.memory_allocated() / 1024**2
        mem_usage.append(allocated_mb)
        print(f"Gen {run+1}: {allocated_mb:.2f} MB")

    # Growth is measured from the first run's reading to the last run's reading.
    mem_increase = mem_usage[-1] - mem_usage[0]
    print(f"\nMemory increase: {mem_increase:.2f} MB")

    assert mem_increase < 50, f"FAIL: Memory leak detected ({mem_increase:.2f} MB increase)"
    print("✓ Test 4 passed")

## Test 5: Tool Use Token Order

In [None]:
# Test 5: render a conversation containing a calculator call and check that the
# output delimiter tokens appear in the right order.
assistant_parts = [
    {"type": "text", "text": "Let me calculate: "},
    {"type": "python", "text": "5+3"},
    {"type": "python_output", "text": "8"},
    {"type": "text", "text": " The answer is 8."},
]
conversation = {
    "messages": [
        {"role": "user", "content": "What is 5+3?"},
        {"role": "assistant", "content": assistant_parts},
    ]
}

ids, mask = tokenizer.render_conversation(conversation)
print(f"Conversation: {len(ids)} tokens")

# Decode every token on its own and keep those that look like <|...|> markers,
# together with their index in the rendered sequence.
decoded_tokens = [tokenizer.decode([token_id]) for token_id in ids]
special_tokens = [
    (index, text)
    for index, text in enumerate(decoded_tokens)
    if text.startswith('<|') and text.endswith('|>')
]

print("\nSpecial tokens:")
for pos, tok in special_tokens:
    print(f"  {pos}: {tok}")

# Acceptance: among the markers mentioning 'output', the first must be
# output_start and the second output_end.
output_tokens = [(pos, tok) for pos, tok in special_tokens if 'output' in tok]
if len(output_tokens) < 2:
    print("⊘ Test 5: Insufficient output tokens to verify order")
else:
    first_marker = output_tokens[0][1]
    second_marker = output_tokens[1][1]
    assert '<|output_start|>' in first_marker, "FAIL: output_start should come first"
    assert '<|output_end|>' in second_marker, "FAIL: output_end should come after output_start"
    print("✓ Test 5 passed")

## Summary

If every test cell above printed ✓ (or ⊘ for a skipped test) without raising an error, all tests passed and the architecture and performance issues are resolved.