# 🐛 Nanochat Bug Hunt: Architecture & Performance

Welcome to the advanced debugging challenge! This notebook contains 5 challenging bugs:

1. **Rotary Embedding Bug**: Model can't handle long sequences
2. **KV Cache Corruption**: Garbage output after ~50 tokens
3. **MQA/GQA Implementation Bug**: Inverted condition for enabling Grouped-Query Attention
4. **Memory Leak**: KV cache grows without bounds
5. **Tool Use Bug**: Calculator results appear at wrong positions

These bugs require deep understanding of transformer architecture and modern optimizations.

In [None]:
# Setup
import os
import sys
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import importlib
import gc

# Make the working directory importable so the `nanochat` package resolves
repo_root = Path.cwd()
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Point nanochat at a notebook-specific cache directory and make sure it exists
os.environ["NANOCHAT_BASE_DIR"] = os.path.join(repo_root, ".cache_hard")
Path(os.environ["NANOCHAT_BASE_DIR"]).mkdir(parents=True, exist_ok=True)

## Preparation: Create Model and Tokenizer

We need a tokenizer and a model to reproduce these bugs.

In [None]:
# Create or load tokenizer
from nanochat.tokenizer import RustBPETokenizer

tokenizer_dir = Path(os.environ["NANOCHAT_BASE_DIR"]) / "tokenizer"

# First run: train a tiny tokenizer on a repeated sentence and save it.
# Later runs: load the saved tokenizer from the cache directory.
if not tokenizer_dir.exists():
    print("Training tokenizer...")
    texts = ["The quick brown fox"] * 1000
    tokenizer = RustBPETokenizer.train_from_iterator(iter(texts), vocab_size=1024)
    tokenizer.save(str(tokenizer_dir))
else:
    print("Loading tokenizer...")
    tokenizer = RustBPETokenizer.from_directory(str(tokenizer_dir))

print(f"Vocab size: {tokenizer.get_vocab_size()}")

In [None]:
# Create a model with specific config to test bugs
from nanochat.gpt import GPT, GPTConfig

# Small GPT whose query-head count differs from its key/value-head count
config = GPTConfig(
    vocab_size=tokenizer.get_vocab_size(),
    sequence_len=256,  # long enough for the position-offset experiments
    n_layer=4,
    n_embd=256,
    n_head=8,      # query heads
    n_kv_head=4,   # key/value heads (fewer than query heads, i.e. grouped-query)
)

# Build the model, initialise its weights, then move it to the selected device
model = GPT(config)
model.init_weights()
model = model.to(device)
print(f"Model: {sum(param.numel() for param in model.parameters()):,} parameters")
print(f"Using MQA: {config.n_head} query heads, {config.n_kv_head} kv heads")

## Bug #1: Rotary Embedding Frequency Calculation

Let's test if the model can handle sequences at different positions.

In [None]:
# Test rotary embeddings at different positions
def test_rotary_positions(model, positions=(0, 50, 100, 150, 200)):
    """Run the first attention layer with rotary tables taken at several offsets.

    One random token sequence is embedded and RMS-normalised once, then passed
    through ``model.transformer.h[0].attn`` with ``model.cos`` / ``model.sin``
    sliced at ``[offset, offset + T)`` for every offset in ``positions``.

    Args:
        model: Model exposing ``cos``, ``sin``, ``transformer.wte`` and
            ``transformer.h`` as used below.
        positions: Starting offsets into the precomputed rotary tables.
            An immutable tuple is used as the default to avoid the shared
            mutable-default pitfall.

    Returns:
        A list with one NumPy array per offset, each holding five values
        read from ``attn_out[0, 0, :5]``.
    """
    # Create a simple input (random token ids, so results differ between runs)
    test_seq = torch.randint(0, 100, (1, 10), dtype=torch.int32).to(device)
    T = test_seq.size(1)

    outputs = []
    with torch.no_grad():
        # The embedded, normalised input does not depend on the offset,
        # so it is computed once instead of once per position.
        x = model.transformer.wte(test_seq)
        x = torch.nn.functional.rms_norm(x, (x.size(-1),))

        for pos_offset in positions:
            # Simulate being at a different position by slicing the rotary tables
            cos_sin = model.cos[:, pos_offset:pos_offset+T], model.sin[:, pos_offset:pos_offset+T]

            # Manual forward through the first attention layer only
            attn_out = model.transformer.h[0].attn(x, cos_sin, kv_cache=None)

            # NOTE(review): index 0 on the second axis presumably selects the
            # first token; if attention is causal that token only attends to
            # itself, so this readout may not react to the offset at all.
            # Consider reading the last token instead - confirm against gpt.py.
            outputs.append(attn_out[0, 0, :5].cpu().numpy())  # First 5 dims

    return outputs

# Test at different positions
positions = [0, 50, 100, 150, 200]
outputs = test_rotary_positions(model, positions)

# Plot the outputs: one line per rotary offset
plt.figure(figsize=(10, 6))
for pos, out in zip(positions, outputs):
    plt.plot(out, label=f'Position {pos}', marker='o')

plt.xlabel('Dimension')
plt.ylabel('Output Value')
plt.title('Attention Output at Different Sequence Positions')
plt.legend()
plt.grid(True)
plt.show()

# Check variance: flag the run when the variance at the last offset has
# dropped below 10% of the variance at the first offset
variances = [np.var(out) for out in outputs]
print(f"Output variances: {variances}")
if variances[-1] < variances[0] * 0.1:
    # "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n
    print("\n❌ BUG #1 DETECTED! Outputs degrade at higher positions!")
    print("💡 The rotary frequency calculation seems wrong.")
    print("💡 Check _precompute_rotary_embeddings() in gpt.py")

## Fix Bug #1: Rotary Embeddings

The frequency calculation is wrong:
- Current: `inv_freq = 1.0 / (base ** (channel_range / (head_dim * 2)))`
- Should be: `inv_freq = 1.0 / (base ** (channel_range / head_dim))`

In [None]:
# Placeholder cell: reminder of the expected outcome once the fix is applied
print(
    "After fixing the rotary embedding calculation...",
    "✅ Model should handle all positions equally well!",
    sep="\n",
)

## Bug #2 & #4: KV Cache Position Tracking and Memory Leak

Let's test generation with the KV cache.

In [None]:
# Test KV cache during generation
from nanochat.engine import Engine, KVCache

# Create the engine from the model and tokenizer built above; the cells
# below call engine.generate(...) on it to produce tokens one step at a time
engine = Engine(model, tokenizer)

# Generate some text and monitor KV cache
def test_generation_with_cache_monitoring():
    """Generate up to 20 tokens from a fixed prompt and echo the first ten.

    Uses the module-level ``engine`` and ``tokenizer``. Each step of
    ``engine.generate`` yields a token column and a mask column; only the
    first entry of the token column is kept.

    Returns:
        The decoded text of the prompt tokens followed by the generated tokens.
    """
    prompt = "Once upon a time"
    prompt_tokens = tokenizer.encode(prompt, prepend="<|bos|>")

    print(f"Generating from: '{prompt}'")
    print("\nMonitoring KV cache positions...\n")

    tokens_generated = []

    # Generate token by token so each step can be reported as it happens
    for i, (token_col, _masks) in enumerate(engine.generate(prompt_tokens, max_tokens=20)):
        token = token_col[0]
        tokens_generated.append(token)

        # The cache itself is not inspected here; corruption is judged from
        # the generated text by the caller.
        if i < 10:  # Only print first 10
            print(f"Step {i}: Generated token {token} ('{tokenizer.decode([token])}')")

    # Decode prompt and continuation together
    generated_text = tokenizer.decode(prompt_tokens + tokens_generated)
    print(f"\nGenerated text: '{generated_text}'")

    return generated_text

# Test generation
generated = test_generation_with_cache_monitoring()

# Check for corruption: flag the run when the last five whitespace-separated
# words contain fewer than three distinct values (i.e. heavy repetition)
if len(set(generated.split()[-5:])) < 3:
    # "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n
    print("\n❌ BUG #2 DETECTED! Generation quality degrades!")
    print("💡 KV cache position tracking is incrementing at the wrong layer.")
    print("💡 Check insert_kv() in engine.py - it should increment after LAST layer.")

In [None]:
# Test memory usage with multiple generations
def test_memory_leak():
    """Run five generations on a fresh Engine and report memory after each.

    After every ``generate_batch`` call the garbage collector is run and, on
    CUDA, the allocator cache is emptied before ``torch.cuda.memory_allocated``
    is read. If the reading after the last generation exceeds the reading
    after the first by more than 50 MB, a leak warning is printed.

    On non-CUDA devices no measurement is taken and no leak check is made.
    """
    print("Testing memory usage across multiple generations...\n")

    # Create a fresh engine (local; does not touch the module-level one)
    engine = Engine(model, tokenizer)

    can_measure = device == 'cuda'
    mem_usage = []

    for i in range(5):
        # Generate some text; the output is discarded, only memory matters here
        tokens = tokenizer.encode(f"Test {i}:", prepend="<|bos|>")
        engine.generate_batch(tokens, max_tokens=50)

        # Force garbage collection and measure memory
        gc.collect()
        if can_measure:
            torch.cuda.empty_cache()
            mem = torch.cuda.memory_allocated() / 1024**2  # MB
            print(f"Generation {i+1}: Memory used: {mem:.2f} MB")
        else:
            # Report that nothing was measured instead of a misleading 0.00 MB
            mem = 0
            print(f"Generation {i+1}: Memory used: n/a (only measured on CUDA)")

        mem_usage.append(mem)

    if can_measure and len(mem_usage) > 1:
        mem_increase = mem_usage[-1] - mem_usage[0]
        print(f"\nMemory increase: {mem_increase:.2f} MB")

        if mem_increase > 50:  # More than 50MB increase is suspicious
            print("\n❌ BUG #4 DETECTED! Memory usage keeps growing!")
            print("💡 KV cache is not being properly reset between generations.")
            print("💡 Check reset() method in KVCache class.")
            print("💡 Also check the cache growth logic - it's growing 2x more than needed!")

test_memory_leak()

## Bug #3: MQA Implementation

Test Multi-Query / Grouped-Query Attention, where the number of query heads (`n_head`) differs from the number of key/value heads (`n_kv_head`).

In [None]:
# Test MQA implementation
def test_mqa_shapes():
    """Check the first attention layer when n_head != n_kv_head.

    Prints the shapes of the projected queries, keys and values and, for
    illustration, the two candidate ``enable_gqa`` expressions. It then runs
    one forward pass through the attention layer and reports whether it
    raised ``RuntimeError``.

    The earlier version compared ``(a == b)`` with ``(a != b)``, which are
    always opposite, so it reported the bug no matter what the model did.
    """
    # Our config has n_head=8, n_kv_head=4
    print(f"Testing MQA: {config.n_head} query heads, {config.n_kv_head} kv heads")
    print("enable_gqa should be True when n_head != n_kv_head\n")

    # Check the attention layer
    attn_layer = model.transformer.h[0].attn

    # Create dummy input
    B, T = 2, 10
    head_dim = config.n_embd // config.n_head
    x = torch.randn(B, T, config.n_embd).to(device)

    # Get Q, K, V split into heads
    q = attn_layer.c_q(x).view(B, T, config.n_head, head_dim)
    k = attn_layer.c_k(x).view(B, T, config.n_kv_head, head_dim)
    v = attn_layer.c_v(x).view(B, T, config.n_kv_head, head_dim)

    print(f"Query shape: {q.shape} - (batch, seq, n_heads={config.n_head}, head_dim)")
    print(f"Key shape:   {k.shape} - (batch, seq, n_kv_heads={config.n_kv_head}, head_dim)")
    print(f"Value shape: {v.shape} - (batch, seq, n_kv_heads={config.n_kv_head}, head_dim)")

    # Illustration only: what each version of the condition evaluates to
    enable_gqa_buggy = config.n_head == config.n_kv_head  # Bug version
    enable_gqa_correct = config.n_head != config.n_kv_head  # Correct version

    print(f"\nBuggy enable_gqa: {enable_gqa_buggy}")
    print(f"Correct enable_gqa: {enable_gqa_correct}")

    # Real detection: exercise the layer itself. With mismatched head counts
    # and grouped-query handling switched off, the attention call is expected
    # to fail on the head-dimension mismatch.
    # NOTE(review): assumes a RuntimeError here comes from that mismatch -
    # read the printed error to confirm.
    cos_sin = model.cos[:, :T], model.sin[:, :T]
    try:
        with torch.no_grad():
            attn_layer(x, cos_sin, kv_cache=None)
    except RuntimeError as err:
        print("\n❌ BUG #3 DETECTED! MQA condition is inverted!")
        print(f"💡 Attention forward pass failed: {err}")
        print("💡 enable_gqa should be True when n_head != n_kv_head")
        print("💡 Fix in forward() method of CausalSelfAttention")
    else:
        print("\n✅ Attention forward pass succeeded with this head configuration")

test_mqa_shapes()

## Bug #5: Tool Use State Machine

Test the calculator tool integration.

In [None]:
# Test calculator tool
def test_calculator_tool():
    """Render a tool-use conversation, then generate from a calculator prompt.

    First tokenizes a sample conversation containing ``python`` and
    ``python_output`` parts and prints its visualization. Then generates up
    to 20 tokens from a prompt, records every generated token whose decoded
    text looks like ``<|...|>``, and warns when at least two such tokens were
    seen and the first is not ``<|output_start|>``.
    """
    print("Testing calculator tool integration...\n")

    # Create a conversation that uses the calculator
    conversation = {
        "messages": [
            {"role": "user", "content": "What is 25 + 17?"},
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "Let me calculate that: "},
                    {"type": "python", "text": "25 + 17"},
                    {"type": "python_output", "text": "42"},
                    {"type": "text", "text": " The answer is 42."}
                ]
            }
        ]
    }

    # Tokenize to see the structure
    ids, mask = tokenizer.render_conversation(conversation)

    # Visualize the tokenization
    print("Tokenized conversation (first 100 tokens):")
    print(tokenizer.visualize_tokenization(ids[:100], mask[:100]))

    # Now test generation with calculator
    engine = Engine(model, tokenizer)

    # Prepare a prompt that triggers calculator use
    # NOTE(review): the special-token markers are passed as plain text; confirm
    # that tokenizer.encode maps them to the special token ids rather than
    # splitting them into ordinary tokens, otherwise the tool path never runs.
    # NOTE(review): unlike the other cells, no "<|bos|>" is prepended - confirm
    # this is intended.
    calc_prompt = tokenizer.encode("Calculate 10 + 15: <|python_start|>10 + 15<|python_end|>")

    print("\n\nTesting calculator execution...")
    print("Input includes: <|python_start|>10 + 15<|python_end|>")
    print("Expected: <|output_start|>25<|output_end|>\n")

    # Generate and track tokens
    generated_tokens = []
    special_tokens = []  # (index within generated_tokens, decoded text)

    for token_col, _masks in engine.generate(calc_prompt, max_tokens=20):
        token = token_col[0]
        generated_tokens.append(token)

        # Check for special tokens by their <|...|> text form
        decoded = tokenizer.decode([token])
        if decoded.startswith('<|') and decoded.endswith('|>'):
            special_tokens.append((len(generated_tokens)-1, decoded))

    print("Special tokens found:")
    for pos, tok in special_tokens:
        print(f"  Position {pos}: {tok}")

    # Check order: the first special token generated should open the output span
    if len(special_tokens) >= 2 and special_tokens[0][1] != '<|output_start|>':
        print("\n❌ BUG #5 DETECTED! Tool output tokens in wrong order!")
        print("💡 Calculator result injection is messed up.")
        print("💡 Check the order of output_start and output_end in engine.py")

    # Show full output
    full_output = tokenizer.decode(calc_prompt + generated_tokens)
    print(f"\nFull output: '{full_output}'")

test_calculator_tool()

## Summary

After fixing all bugs:

1. **Rotary Embeddings**: Fix frequency calculation (remove *2)
2. **KV Cache Position**: Increment after last layer, not first
3. **MQA Condition**: Use != instead of == for enable_gqa
4. **Memory Leak**: Reset kv_cache in reset(), fix growth factor
5. **Tool Output**: Correct order: output_start, tokens, output_end

These bugs demonstrate the complexity of modern transformer optimizations!

In [None]:
# Closing summary; "\n" (not "\\n") so blank lines are printed rather than
# a literal backslash-n
print("🎉 Congratulations on completing the advanced architecture challenge!\n")
print("You've debugged:")
print("- Positional encoding mathematics")
print("- Stateful caching mechanisms")
print("- Attention optimization techniques")
print("- Memory management")
print("- Tool integration state machines")
print("\nThese are real issues that can occur in production LLM code!")