# Phase 1 Quick Start: GhostTrack Infrastructure

This notebook demonstrates how to use all Phase 1 components:
1. Configuration loading
2. Data loading (TruthfulQA)
3. GPT-2 model wrapper with hooks
4. JumpReLU SAE model

Section 5 then combines these components in a short end-to-end demo.

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt

## 1. Configuration Loading

In [None]:
from config import load_config

# load_config() is called with no arguments, so it reads from its default
# source (the original note says a .claude file — confirm in the config module).
config = load_config()

# Report the model section.
print("Model Configuration:")
model_cfg = config.model
print(f"  Base model: {model_cfg.base_model}")
print(f"  d_model: {model_cfg.d_model}")
print(f"  n_layers: {model_cfg.n_layers}")

# Report the sparse-autoencoder section.
print("\nSAE Configuration:")
sae_cfg = config.sae
print(f"  Architecture: {sae_cfg.architecture}")
print(f"  d_hidden: {sae_cfg.d_hidden}")
print(f"  Threshold: {sae_cfg.threshold}")
print(f"  Lambda sparse: {sae_cfg.lambda_sparse}")

# Report the tracking section.
print("\nTracking Configuration:")
tracking_cfg = config.tracking
print(f"  Top-k features: {tracking_cfg.top_k_features}")
print(f"  Semantic weight: {tracking_cfg.semantic_weight}")
print(f"  Association threshold: {tracking_cfg.association_threshold}")

## 2. Data Loading (TruthfulQA)

In [None]:
from data import load_truthfulqa

print("Loading TruthfulQA dataset...")
print("Note: First load may take a few minutes to download from HuggingFace\n")

# The loader returns three datasets in (train, val, test) order.
# A fixed seed is passed — presumably so the split is reproducible (verify in `data`).
splits = load_truthfulqa(cache_dir='../data/cache', seed=42)
train_dataset, val_dataset, test_dataset = splits

# Measure each split once, then report per-split sizes and the grand total.
split_sizes = {
    "Train": len(train_dataset),
    "Val": len(val_dataset),
    "Test": len(test_dataset),
}
for split_name, n_examples in split_sizes.items():
    print(f"{split_name} size: {n_examples}")
print(f"Total: {sum(split_sizes.values())}")

In [None]:
# Show every field of the first training example between two rule lines.
example = train_dataset[0]
rule = "=" * 70

print("Example Question-Answer Pair:")
print(rule)
print(f"ID: {example.id}")
print(f"Category: {example.category}")
print(f"\nQuestion: {example.prompt}")
print(f"\nFactual Answer: {example.factual_answer}")
print(f"\nHallucinated Answer: {example.hallucinated_answer}")
print(rule)

In [None]:
# Summarise how the training examples are spread over categories.
categories = train_dataset.get_categories()
category_counts = train_dataset.get_category_counts()

print(f"Number of categories: {len(categories)}")
print("\nTop 10 categories by count:")

# Rank (category, count) pairs by count, largest first; ties keep their
# original relative order because sorted() is stable.
ranked = sorted(category_counts.items(), key=lambda pair: pair[1], reverse=True)
sorted_cats = ranked[:10]
for cat_name, n_examples in sorted_cats:
    print(f"  {cat_name}: {n_examples}")

## 3. GPT-2 Model Wrapper with Hooks

In [None]:
from models import GPT2WithResidualHooks

print("Loading GPT-2 model with hooks...")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT2WithResidualHooks(model_name='gpt2', device=device)

# Echo the basic properties exposed by the wrapper.
print(f"Device: {model.device}")
print(f"Number of layers: {model.n_layers}")
print(f"Hidden dimension: {model.d_model}")

In [None]:
# Run one short sentence through the wrapped model and inspect what comes back.
test_text = "The capital of France is Paris."
print(f"Processing text: '{test_text}'\n")

# `outputs` is indexed by the keys 'logits', 'residual_stream',
# 'mlp_outputs' and 'attn_outputs' below.
outputs = model.process_text(test_text)

print("Output keys:", outputs.keys())
print(f"\nLogits shape: {outputs['logits'].shape}")

# Count the captured activations for each hook type.
residual_stream = outputs['residual_stream']
print(f"Number of residual stream activations: {len(residual_stream)}")
print(f"Number of MLP activations: {len(outputs['mlp_outputs'])}")
print(f"Number of attention activations: {len(outputs['attn_outputs'])}")

# Shape of the first captured residual-stream activation.
print(f"\nResidual stream shape per layer: {residual_stream[0].shape}")
print("Format: [batch_size, seq_length, hidden_dim]")

In [None]:
# Visualize activation magnitudes across layers.
# For each captured residual-stream activation, take the L2 norm over the last
# axis and average the result down to a single scalar.
layer_norms = [
    torch.norm(residual, dim=-1).mean().item()
    for residual in outputs['residual_stream']
]

# Derive the x-axis from the data instead of hard-coding 12: the plot would
# raise a shape-mismatch error if the wrapper captured a different number of
# residual-stream activations (e.g. another model size, or an extra entry).
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(len(layer_norms)), layer_norms, marker='o')
ax.set_xlabel('Layer')
ax.set_ylabel('Average Activation Norm')
ax.set_title('Residual Stream Activation Norms Across Layers')
ax.grid(alpha=0.3)
plt.show()

## 4. JumpReLU SAE Model

In [None]:
from models import JumpReLUSAE

# Create SAE.
# The input width is taken from the loaded model rather than hard-coded as 768,
# so the SAE always matches the activations it is applied to in later cells.
# The other hyperparameters are demo literals; config.sae also carries
# d_hidden / threshold / lambda_sparse if the two should be kept in sync
# (TODO confirm which source is authoritative).
sae = JumpReLUSAE(
    d_model=model.d_model,
    d_hidden=4096,
    threshold=0.1,
    lambda_sparse=0.01,
)

print("SAE created with:")
print(f"  Input dim: {sae.d_model}")
print(f"  Hidden dim: {sae.d_hidden}")
print(f"  Threshold: {sae.threshold.item():.3f}")
print(f"  Total parameters: {sae.get_num_parameters():,}")

In [None]:
# Feed real GPT-2 activations through the SAE.
# Index 6 of the captured residual stream serves as a mid-depth example.
layer_6_activations = outputs['residual_stream'][6]
print(f"Input shape: {layer_6_activations.shape}")

# The forward pass returns a mapping; the entries read below are
# 'reconstruction', 'features', 'error' and 'sparsity'.
sae_output = sae.forward(layer_6_activations)

print("\nSAE Output:")
shape_rows = (
    ("Reconstruction", "reconstruction"),
    ("Features", "features"),
    ("Error", "error"),
)
for label, key in shape_rows:
    print(f"  {label} shape: {sae_output[key].shape}")
print(f"  Sparsity: {sae_output['sparsity'].item():.4f}")

active_count = sae.count_active_features(layer_6_activations)
print(f"  Active features per token: {active_count:.1f}")

In [None]:
# Evaluate the SAE loss on the same activations and break it into its parts.
loss_dict = sae.loss(layer_6_activations, return_components=True)

print("Loss Components:")
loss_rows = (
    ("Total loss", "total_loss"),
    ("Reconstruction loss", "recon_loss"),
    ("Sparsity loss", "sparsity_loss"),
)
for label, key in loss_rows:
    print(f"  {label}: {loss_dict[key].item():.6f}")
# Sparsity is reported with fewer decimals than the loss terms.
print(f"  Sparsity: {loss_dict['sparsity'].item():.4f}")

In [None]:
# Visualize the sparse activation pattern for the first batch element.
# Expected layout is [seq_len, d_hidden] — TODO confirm against JumpReLUSAE.
features = sae_output['features'][0].detach().cpu().numpy()

fig, (ax_heat, ax_count) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: heatmap of the transposed array, so token positions run along x
# and feature IDs along y.
heatmap = ax_heat.imshow(features.T, aspect='auto', cmap='hot', interpolation='nearest')
fig.colorbar(heatmap, ax=ax_heat, label='Activation')
ax_heat.set_xlabel('Token Position')
ax_heat.set_ylabel('Feature ID')
ax_heat.set_title('SAE Feature Activations (Sparse)')

# Right panel: number of strictly positive features at each token position,
# with a dashed line marking the mean.
active_per_token = (features > 0).sum(axis=1)
mean_active = active_per_token.mean()
ax_count.bar(range(len(active_per_token)), active_per_token)
ax_count.set_xlabel('Token Position')
ax_count.set_ylabel('Number of Active Features')
ax_count.set_title('Active Features per Token')
ax_count.axhline(y=mean_active, color='r', linestyle='--',
                 label=f'Mean: {mean_active:.1f}')
ax_count.legend()

fig.tight_layout()
plt.show()

In [None]:
# Exercise the JumpReLU activation directly on a handful of probe values that
# span negative, zero, small positive and larger positive inputs.
probe_inputs = [-0.5, -0.2, 0.0, 0.05, 0.1, 0.2, 0.5, 1.0]
test_values = torch.tensor([probe_inputs])
activated = sae.jumprelu(test_values)

# Read the threshold once and reuse it in both messages.
threshold_value = sae.threshold.item()

print("JumpReLU Activation Test:")
print(f"Threshold: {threshold_value:.3f}")
print(f"\nInput:  {test_values[0].tolist()}")
print(f"Output: {activated[0].tolist()}")
print(f"\nNote: Values > {threshold_value:.3f} are preserved, others set to 0")

## 5. Combining Components: Full Pipeline Demo

In [None]:
# Select one test example that carries both a factual and a hallucinated answer.
# Index 5 is an arbitrary pick.
example = test_dataset[5]

print(f"Question: {example.prompt}")
print(f"\nFactual Answer: {example.factual_answer}")
print(f"Hallucinated Answer: {example.hallucinated_answer}")

In [None]:
# Run prompt + answer through the model once per answer variant.
# Each input is the prompt, a single space, then the answer text.
factual_text = example.prompt + " " + example.factual_answer
factual_outputs = model.process_text(factual_text)

halluc_text = example.prompt + " " + example.hallucinated_answer
halluc_outputs = model.process_text(halluc_text)

print("Processed both answers through GPT-2")
# Axis 1 of the logits is reported as the sequence length.
factual_len = factual_outputs['logits'].shape[1]
halluc_len = halluc_outputs['logits'].shape[1]
print(f"Factual seq length: {factual_len}")
print(f"Halluc seq length: {halluc_len}")

In [None]:
# Compare SAE behaviour on the two answers, using residual-stream index 8
# as an example layer.
layer_idx = 8
factual_layer8 = factual_outputs['residual_stream'][layer_idx]
halluc_layer8 = halluc_outputs['residual_stream'][layer_idx]

# Forward both activation sets through the same SAE.
factual_sae = sae.forward(factual_layer8)
halluc_sae = sae.forward(halluc_layer8)

print("\nSAE Analysis:")
print(f"Factual sparsity: {factual_sae['sparsity'].item():.4f}")
print(f"Halluc sparsity: {halluc_sae['sparsity'].item():.4f}")

# Active-feature counts for each answer, followed by their absolute gap.
factual_active = sae.count_active_features(factual_layer8)
halluc_active = sae.count_active_features(halluc_layer8)

print(f"\nFactual active features: {factual_active:.1f}")
print(f"Halluc active features: {halluc_active:.1f}")
active_gap = abs(factual_active - halluc_active)
print(f"Difference: {active_gap:.1f}")

## Summary

Phase 1 Infrastructure is complete! You've seen:

1. ✅ **Configuration system** - Load settings from YAML
2. ✅ **Data loading** - TruthfulQA with train/val/test splits
3. ✅ **Model wrapper** - Extract activations from GPT-2
4. ✅ **SAE model** - Sparse feature extraction with JumpReLU

### Next Steps (Phase 2):
- Train SAEs on Wikipedia corpus
- Achieve reconstruction loss < 0.01
- Interpret learned features
- Build hypothesis tracking system

### Questions?
See `PHASE1_SUMMARY.md` for detailed documentation.