# Titan-Min Sanity Checks

This notebook contains quick smoke tests that check the core components of the Titan-Min repository (dataset, model, and training loop) run end to end and produce well-formed outputs.

## Tests Covered:
1. **Dataset Construction**: Validate NEEDLE placement and sample properties
2. **Model Construction**: Test forward pass shapes and ablation variants
3. **Training Step**: Ensure loss computation and gradient flow work
4. **End-to-End**: One-epoch training plus validation on a tiny dataset

In [None]:
# Setup imports and path
import sys
import os
sys.path.append('../src')

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader

# Import our modules
from data.niah import NIAHDataset, collate
from models.titan_min import TitanClassifier
from models.heads import position_logits

# Fix the global NumPy and torch RNG seeds so that model initialisation and
# the random test inputs below are repeatable from run to run.
np.random.seed(42)
torch.manual_seed(42)

print("✅ Imports successful!")

## 1. Dataset Construction Tests

In [None]:
# Dataset construction: spot-check the first few samples. Each must contain
# exactly one NEEDLE token (id 127) inside its valid length, located at the
# needle position the dataset reports.
print("🔍 Testing dataset construction...")

dataset = NIAHDataset(n_samples=100, seed=42)
print(f"Dataset created with {len(dataset)} samples")

for idx in range(5):
    tokens, needle_pos, length = dataset[idx]

    # The reported position must lie inside the valid (unpadded) span.
    assert 0 <= needle_pos < length, f"Invalid needle position: {needle_pos} not in [0, {length-1}]"

    # Boolean mask of needle hits within the valid span, reused by both checks.
    hits = tokens[:length] == 127

    n_hits = hits.sum().item()
    assert n_hits == 1, f"Expected 1 needle, found {n_hits}"

    found_at = torch.where(hits)[0].item()
    assert found_at == needle_pos, f"Needle at {found_at}, expected {needle_pos}"

    print(f"Sample {idx}: length={length}, needle_pos={needle_pos} ✓")

print("✅ Dataset construction tests passed!")

In [None]:
# Test collate function.
# Besides the padding layout, verify that the collated needle positions are
# usable as training targets: they must point at the needle in the padded
# tensor the model actually sees, and they must fall inside the region that
# position_logits leaves unmasked (indices < length).
print("🔍 Testing collate function...")

batch_samples = [dataset[i] for i in range(5)]
batch_tokens, batch_needle_pos, batch_lengths = collate(batch_samples)

print(f"Batch shapes: tokens={batch_tokens.shape}, positions={len(batch_needle_pos)}, lengths={len(batch_lengths)}")

NEEDLE_ID = 127  # needle token id, the same value the dataset cell checks for
max_len = batch_tokens.shape[1]

for i, (sample_tokens, _, _) in enumerate(batch_samples):
    length = int(batch_lengths[i])
    pos = int(batch_needle_pos[i])
    pad = max_len - length

    # Left padding: the leading `pad` slots are zeros (an empty slice when
    # pad == 0), and the sample's real tokens follow unchanged.
    assert (batch_tokens[i, :pad] == 0).all(), "Left padding should be zeros"
    assert (batch_tokens[i, pad:] == sample_tokens[:length]).all(), (
        f"Sample {i}: real tokens not preserved after padding"
    )

    # position_logits sets every index >= length to -inf (see the masking
    # test below), so a target outside [0, length) yields an infinite loss.
    assert 0 <= pos < length, (
        f"Sample {i}: needle_pos={pos} falls in the masked region (length={length})"
    )
    # The target must index the needle in the padded batch. Otherwise the
    # model is trained to predict a position that does not hold the needle.
    assert batch_tokens[i, pos].item() == NEEDLE_ID, (
        f"Sample {i}: needle_pos={pos} does not point at the needle in the padded batch"
    )

    print(f"Sample {i}: length={length}, padded_length={max_len} ✓")

print("✅ Collate function tests passed!")

## 2. Model Construction Tests

In [None]:
# Model construction: build a deliberately tiny model and check the shapes of
# one forward pass. The model returns per-token hidden states and a
# per-sequence representation `rep`. The globals model / x / batch_size /
# seq_len / h_tokens / rep are reused by the cells that follow.
print("🔍 Testing model construction...")

model_kwargs = dict(vocab_size=128, dim=64, n_heads=4, n_layers=2, n_mem=2)  # small dim for testing
model = TitanClassifier(**model_kwargs)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model created with {n_params:,} parameters")

batch_size, seq_len = 2, 32
x = torch.randint(0, 128, (batch_size, seq_len))

model.eval()
with torch.no_grad():
    h_tokens, rep = model(x)

print(f"Forward pass shapes: tokens={h_tokens.shape}, rep={rep.shape}")

hidden = model_kwargs["dim"]
assert h_tokens.shape == (batch_size, seq_len, hidden), f"Wrong token shape: {h_tokens.shape}"
assert rep.shape == (batch_size, hidden), f"Wrong rep shape: {rep.shape}"

print("✅ Model construction tests passed!")

In [None]:
# Position logits: expect one score per position. Indices at or beyond each
# sample's length must be exactly -inf, and indices before it must be finite.
print("🔍 Testing position logits...")

lengths = torch.tensor([20, 25])
logits = position_logits(rep, h_tokens, lengths)

print(f"Logits shape: {logits.shape}")
assert logits.shape == (batch_size, seq_len), f"Wrong logits shape: {logits.shape}"

for i, (row, length) in enumerate(zip(logits, lengths)):
    valid, masked = row[:length], row[length:]

    assert torch.all(torch.isneginf(masked)), "Masked positions should be -inf"
    assert torch.all(torch.isfinite(valid)), "Valid positions should be finite"

    print(f"Sample {i}: length={length}, valid_logits={valid.shape[0]}, masked={masked.shape[0]} ✓")

print("✅ Position logits tests passed!")

In [None]:
# Ablation variants: every flag on its own, then all flags together, must
# still build a model whose forward pass on `x` keeps the full model's output
# shapes.
print("🔍 Testing ablation variants...")

ablation_configs = [
    {"no_memory": True},
    {"no_dsconv": True},
    {"no_l2": True},
    {"activation": "relu"},
    {"no_memory": True, "no_dsconv": True, "no_l2": True, "activation": "relu"},
]

expected_token_shape = (batch_size, seq_len, 64)
expected_rep_shape = (batch_size, 64)

for i, config in enumerate(ablation_configs):
    variant = TitanClassifier(vocab_size=128, dim=64, n_heads=4, n_layers=2, n_mem=2, **config)
    variant.eval()

    with torch.no_grad():
        h_tokens_abl, rep_abl = variant(x)

    assert h_tokens_abl.shape == expected_token_shape, f"Ablation {i}: wrong shape"
    assert rep_abl.shape == expected_rep_shape, f"Ablation {i}: wrong shape"

    print(f"Ablation {i} ({config}): ✓")

print("✅ Ablation tests passed!")

## 3. Training Step Tests

In [None]:
# Test training step.
# Run a few optimisation steps on one fixed batch. Check that the loss stays
# finite and positive, that gradients are finite, and that gradients actually
# reach the parameters. A silently broken backward graph must fail here rather
# than pass vacuously.
print("🔍 Testing training step...")

model = TitanClassifier(vocab_size=128, dim=64, n_heads=4, n_layers=2, n_mem=2)
dataset = NIAHDataset(n_samples=50, seed=42)
dataloader = DataLoader(dataset, batch_size=8, collate_fn=collate)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

model.train()

# Reuse the same batch on every step so the loss trend reflects fitting that
# batch rather than changing data.
batch_tokens, batch_needle_pos, batch_lengths = next(iter(dataloader))
print(f"Batch loaded: {batch_tokens.shape}")

losses = []
for step in range(5):
    optimizer.zero_grad()

    # Forward pass
    h_tokens, rep = model(batch_tokens)
    logits = position_logits(rep, h_tokens, batch_lengths)
    loss = criterion(logits, batch_needle_pos)

    assert torch.isfinite(loss), f"Loss not finite: {loss}"
    assert loss.item() > 0, f"Loss should be positive: {loss.item()}"
    losses.append(loss.item())

    # Backward pass
    loss.backward()

    # Each gradient that exists must be entirely finite (this catches NaN as
    # well as inf). At least one parameter must receive a non-zero gradient,
    # otherwise nothing is being trained.
    grad_norms = []
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        assert torch.isfinite(param.grad).all(), f"Non-finite gradient in {name}"
        grad_norms.append(param.grad.norm().item())

    assert grad_norms, "No parameter received a gradient"
    assert any(g > 0 for g in grad_norms), "All gradients are zero"

    optimizer.step()

    print(f"Step {step}: loss={loss.item():.4f}, avg_grad_norm={np.mean(grad_norms):.4f}")

# Loss trend is reported only (not asserted): five steps are too few for a
# strict decrease guarantee.
early_avg = np.mean(losses[:2])
late_avg = np.mean(losses[-2:])
print(f"Loss trend: {early_avg:.4f} → {late_avg:.4f}")

print("✅ Training step tests passed!")

## 4. End-to-End Workflow Test

In [None]:
# Test complete training/validation workflow.
# One epoch of training on 24 samples, then validation on the 8 held-out
# samples. Losses are accumulated per sample (batch-mean loss × batch size),
# so the reported averages stay correct even when the last batch is smaller
# than the others.
print("🔍 Testing end-to-end workflow...")

# Create small dataset
dataset = NIAHDataset(n_samples=32, seed=42)
train_dataset = torch.utils.data.Subset(dataset, range(24))
val_dataset = torch.utils.data.Subset(dataset, range(24, 32))

train_loader = DataLoader(train_dataset, batch_size=4, collate_fn=collate)
val_loader = DataLoader(val_dataset, batch_size=4, collate_fn=collate)

# Create model
model = TitanClassifier(vocab_size=128, dim=32, n_heads=2, n_layers=1, n_mem=1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

print(f"Setup: {len(train_dataset)} train, {len(val_dataset)} val samples")

# Training (single epoch)
model.train()
train_loss_sum, train_count = 0.0, 0

for batch_tokens, batch_needle_pos, batch_lengths in train_loader:
    optimizer.zero_grad()
    h_tokens, rep = model(batch_tokens)
    logits = position_logits(rep, h_tokens, batch_lengths)
    loss = criterion(logits, batch_needle_pos)

    assert torch.isfinite(loss), "Training loss not finite"
    n = batch_tokens.size(0)
    train_loss_sum += loss.item() * n  # criterion returns the batch mean
    train_count += n

    loss.backward()
    optimizer.step()

assert train_count > 0, "Training loader produced no batches"
avg_train_loss = train_loss_sum / train_count
print(f"Training: avg_loss={avg_train_loss:.4f}")

# Validation
model.eval()
val_loss_sum, correct, total = 0.0, 0, 0

with torch.no_grad():
    for batch_tokens, batch_needle_pos, batch_lengths in val_loader:
        h_tokens, rep = model(batch_tokens)
        logits = position_logits(rep, h_tokens, batch_lengths)
        loss = criterion(logits, batch_needle_pos)

        assert torch.isfinite(loss), "Validation loss not finite"
        n = batch_tokens.size(0)
        val_loss_sum += loss.item() * n

        predictions = logits.argmax(dim=1)
        correct += (predictions == batch_needle_pos).sum().item()
        total += n

assert total > 0, "Validation loader produced no batches"
avg_val_loss = val_loss_sum / total
accuracy = correct / total

print(f"Validation: avg_loss={avg_val_loss:.4f}, accuracy={accuracy:.4f}")

# Sanity checks
assert 0.0 <= accuracy <= 1.0, f"Invalid accuracy: {accuracy}"
assert avg_val_loss > 0, f"Invalid validation loss: {avg_val_loss}"

print("✅ End-to-end workflow tests passed!")

## Summary

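If all cells above run without assertion errors, the core Titan-Min components pass these basic smoke tests. The checks cover shapes, masking, and finite losses and gradients; they do not measure model quality.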

**Tests covered:**
- ✅ Dataset construction and NEEDLE validation
- ✅ Model forward pass and shape validation
- ✅ Position logits and masking
- ✅ Ablation variant compatibility
- ✅ Training step with finite loss and gradients
- ✅ End-to-end training/validation workflow

With these smoke tests passing, the repository is ready for full-scale experiments.