# Sparse Autoencoder (SAE) Training

**Experiment:** Train SAEs to discover interpretable features in model activations

**Date:** 2025-11-05

**Research Question:** What interpretable features emerge when we train sparse autoencoders on LLM activations?

**Goals:**
- Train SAE on model activations
- Discover interpretable features
- Analyze feature sparsity and quality
- Export features for dashboard exploration

## Setup

In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from harness import ExperimentConfig, ExperimentResult, get_tracker

# Select the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## 1. Define SAE Architecture

A single-hidden-layer autoencoder with a ReLU encoder; an L1 penalty on the hidden activations encourages sparse features.

In [None]:
class SparseAutoencoder(nn.Module):
    """Single-hidden-layer autoencoder with a ReLU encoder and an L1 sparsity penalty.

    The encoder and decoder are plain ``nn.Linear`` layers. Weights are
    initialised with Xavier-uniform and biases with zeros.

    Args:
        input_dim: Size of each input vector (last dimension of ``x``).
        hidden_dim: Number of hidden units, i.e. learned features.
        sparsity_weight: Coefficient applied to the L1 term in :meth:`loss`.
    """

    def __init__(self, input_dim: int, hidden_dim: int, sparsity_weight: float = 0.1) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sparsity_weight = sparsity_weight

        # Encoder: input -> hidden features
        self.encoder = nn.Linear(input_dim, hidden_dim)

        # Decoder: hidden features -> reconstructed input
        self.decoder = nn.Linear(hidden_dim, input_dim)

        # Initialize weights
        nn.init.xavier_uniform_(self.encoder.weight)
        nn.init.xavier_uniform_(self.decoder.weight)
        nn.init.zeros_(self.encoder.bias)
        nn.init.zeros_(self.decoder.bias)

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and decode it again.

        Returns:
            A ``(reconstruction, hidden)`` tuple: the decoder output and the
            non-negative (post-ReLU) hidden activations.
        """
        # Encode
        hidden = torch.relu(self.encoder(x))

        # Decode
        reconstruction = self.decoder(hidden)

        return reconstruction, hidden

    def loss(self, x: torch.Tensor):
        """Compute the training objective for a batch ``x``.

        Returns:
            A ``(total_loss, recon_loss, sparsity_loss)`` tuple of scalar
            tensors, where ``recon_loss`` is the mean squared reconstruction
            error, ``sparsity_loss`` is the mean absolute hidden activation,
            and ``total_loss = recon_loss + sparsity_weight * sparsity_loss``.
        """
        # Call the module itself rather than ``self.forward`` so that any
        # registered forward hooks also run on the loss path.
        reconstruction, hidden = self(x)

        # Reconstruction loss (MSE)
        recon_loss = torch.mean((x - reconstruction) ** 2)

        # Sparsity loss (L1 on hidden activations)
        sparsity_loss = torch.mean(torch.abs(hidden))

        # Total loss
        total_loss = recon_loss + self.sparsity_weight * sparsity_loss

        return total_loss, recon_loss, sparsity_loss

print("SAE architecture defined")

## 2. Generate Training Data

Build the activation dataset used to train the SAE.

Note: In production, this step would collect real activations from a model. For this demonstration, we use synthetic data generated from sparse random features.

In [None]:
# TODO: Replace with real activation collection
# For now, generate synthetic data

def generate_synthetic_activations(n_samples=10000, activation_dim=768, n_features=100, sparsity=0.1):
    """
    Build a synthetic activation matrix driven by sparse latent features.

    Each sample switches on ``int(n_features * sparsity)`` randomly chosen
    latent features with non-negative magnitudes. The latent codes are then
    mapped into activation space by a freshly drawn random projection, and
    Gaussian noise scaled by 0.1 is added.

    Returns a tensor of shape ``(n_samples, activation_dim)``.

    A real pipeline would instead:
    1. Load a model (e.g., via MLX or transformers)
    2. Run text through it and capture layer activations
    3. Store those activations for SAE training
    """
    # Number of latent features switched on per sample (same for every row).
    n_active = int(n_features * sparsity)

    # Fill the sparse latent code matrix one sample at a time.
    latent_codes = torch.zeros((n_samples, n_features))
    for row in range(n_samples):
        chosen = torch.randperm(n_features)[:n_active]
        magnitudes = torch.randn(chosen.numel()).abs()
        latent_codes[row, chosen] = magnitudes

    # Random linear map from latent feature space to activation space.
    dictionary = torch.randn((n_features, activation_dim))
    clean = torch.matmul(latent_codes, dictionary)

    # Perturb with small Gaussian noise.
    noise = torch.randn_like(clean) * 0.1
    return clean + noise

# Generate training data
activation_dim = 768  # e.g., hidden size of model
n_train_samples = 10000
n_val_samples = 1000

# Each call to generate_synthetic_activations draws its own random projection
# matrix, so two separate calls would give train and validation sets built
# from unrelated feature dictionaries. Generate both splits in a single call
# so they share one projection, then slice them apart.
all_activations = generate_synthetic_activations(
    n_train_samples + n_val_samples,
    activation_dim,
)
train_activations = all_activations[:n_train_samples]
val_activations = all_activations[n_train_samples:]

print(f"Training data shape: {train_activations.shape}")
print(f"Validation data shape: {val_activations.shape}")

## 3. Train SAE

In [None]:
# Hyperparameters
hidden_dim = 2048  # Typically larger than input for overcomplete representation
sparsity_weight = 0.05
batch_size = 256
learning_rate = 1e-3
n_epochs = 20

# Create model
sae = SparseAutoencoder(
    input_dim=activation_dim,
    hidden_dim=hidden_dim,
    sparsity_weight=sparsity_weight,
).to(device)

optimizer = optim.Adam(sae.parameters(), lr=learning_rate)

# Create data loaders (only the training loader is shuffled)
train_loader = DataLoader(
    TensorDataset(train_activations),
    batch_size=batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    TensorDataset(val_activations),
    batch_size=batch_size,
)

# Per-epoch mean losses, one entry per epoch
train_losses = []
val_losses = []

print(f"Training SAE ({hidden_dim} features)...\n")

for epoch in range(n_epochs):
    # Training
    sae.train()
    epoch_loss = 0.0
    for (batch,) in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        loss, _, _ = sae.loss(batch)
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size: the final batch is
        # smaller whenever the dataset size is not a multiple of batch_size,
        # so an unweighted average over batches would over-count it.
        epoch_loss += loss.item() * batch.size(0)

    train_losses.append(epoch_loss / len(train_loader.dataset))

    # Validation (no gradient tracking)
    sae.eval()
    val_loss = 0.0
    with torch.no_grad():
        for (batch,) in val_loader:
            batch = batch.to(device)
            loss, _, _ = sae.loss(batch)
            val_loss += loss.item() * batch.size(0)

    val_losses.append(val_loss / len(val_loader.dataset))

    # Report progress every 5 epochs
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}")
        print(f"  Train Loss: {train_losses[-1]:.4f}")
        print(f"  Val Loss: {val_losses[-1]:.4f}")

print("\nTraining complete!")

## 4. Analyze Training

In [None]:
# Draw the per-epoch train and validation loss curves on a single set of axes
plt.figure(figsize=(10, 5))
for curve, curve_label in ((train_losses, 'Train Loss'), (val_losses, 'Val Loss')):
    plt.plot(curve, label=curve_label)

axes = plt.gca()
axes.set_xlabel('Epoch')
axes.set_ylabel('Loss')
axes.set_title('SAE Training Progress')
axes.legend()
axes.grid(True, alpha=0.3)

plt.show()

## 5. Analyze Feature Sparsity

In [None]:
# Run the trained SAE on up to the first 1000 validation activations and
# bring the hidden feature activations back to NumPy for analysis
sae.eval()
with torch.no_grad():
    sample_batch = val_activations[:1000].to(device)
    features = sae(sample_batch)[1]
    features = features.cpu().numpy()

# Per-feature statistics across the sampled batch:
#   feature_means           - average activation value of each feature
#   feature_activation_freq - fraction of samples on which each feature is > 0
feature_means = features.mean(axis=0)
feature_activation_freq = (features > 0).mean(axis=0)

n_frequently_active = np.count_nonzero(feature_activation_freq > 0.01)

print("Feature statistics:")
print(f"  Mean activation: {feature_means.mean():.4f}")
print(f"  Std activation: {feature_means.std():.4f}")
print(f"  Average sparsity: {feature_activation_freq.mean():.2%}")
print(f"  Features active >1%: {n_frequently_active}")

# Side-by-side histograms: activation frequency (left), mean activation (right)
histogram_panels = (
    (feature_activation_freq, 'Activation Frequency', 'Feature Activation Frequency Distribution'),
    (feature_means, 'Mean Activation', 'Feature Mean Activation Distribution'),
)

plt.figure(figsize=(12, 4))
for panel_index, (panel_values, panel_xlabel, panel_title) in enumerate(histogram_panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.hist(panel_values, bins=50)
    plt.xlabel(panel_xlabel)
    plt.ylabel('Number of Features')
    plt.title(panel_title)

plt.tight_layout()
plt.show()

## 6. Export for Dashboard

Save trained SAE for exploration in the Lens dashboard

In [None]:
import json
from pathlib import Path

# Output locations for the checkpoint and its JSON metadata
save_path = "../models/sae_example.pt"
metadata_path = '../models/sae_example_metadata.json'

# Neither torch.save nor open() creates missing parent directories, so make
# sure the output directory exists before writing anything.
for target in (save_path, metadata_path):
    Path(target).parent.mkdir(parents=True, exist_ok=True)

# Mean fraction of samples on which a feature is active; computed once and
# reused in both the checkpoint and the metadata file.
avg_sparsity = float(np.mean(feature_activation_freq))

# Save model weights together with the configuration and final metrics
torch.save({
    'model_state_dict': sae.state_dict(),
    'config': {
        'input_dim': activation_dim,
        'hidden_dim': hidden_dim,
        'sparsity_weight': sparsity_weight,
    },
    'training_config': {
        'n_epochs': n_epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
    },
    'metrics': {
        'final_train_loss': train_losses[-1],
        'final_val_loss': val_losses[-1],
        'avg_sparsity': avg_sparsity,
    },
}, save_path)

print(f"Saved SAE to: {save_path}")

# Export feature metadata
metadata = {
    'n_features': hidden_dim,
    'activation_dim': activation_dim,
    'avg_sparsity': avg_sparsity,
    # Count of features that fire on more than 1% of the sampled inputs
    'active_features': int(np.sum(feature_activation_freq > 0.01)),
}

with open(metadata_path, 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)

print("Exported metadata")

## 7. Track Experiment

In [None]:
# Describe this run for the experiment tracker
config = ExperimentConfig(
    experiment_name="sae_training",
    task_type="feature_discovery",
    strategy="sparse_autoencoder",
    provider="local",
    model="sae",
)

# Open a tracked run; run_dir is whatever start_experiment returns for it
tracker = get_tracker()
run_dir = tracker.start_experiment(config)

# Final metrics and the hyperparameters that produced them
final_scores = dict(
    final_train_loss=train_losses[-1],
    final_val_loss=val_losses[-1],
    avg_sparsity=float(np.mean(feature_activation_freq)),
)
run_metadata = dict(
    n_features=hidden_dim,
    activation_dim=activation_dim,
    sparsity_weight=sparsity_weight,
    n_epochs=n_epochs,
)

result = ExperimentResult(
    config=config,
    task_input=f"Train SAE on {n_train_samples} activations",
    output=f"Trained {hidden_dim}-feature SAE",
    eval_scores=final_scores,
    eval_metadata=run_metadata,
    success=True,
)

# Record the result, then close out the run
tracker.log_result(result)
summary = tracker.finish_experiment()
print(f"Experiment logged in: {run_dir}")

## Next Steps

1. **Feature Analysis** (02_feature_analysis.ipynb): Analyze what concepts each feature represents
2. **Feature Steering** (03_feature_steering.ipynb): Use features to steer model behavior
3. **Dashboard Exploration**: Load SAE into Lens dashboard for interactive exploration

## Key Research Questions

- What semantic concepts do individual features capture?
- How do features compose to represent complex concepts?
- Do features align with human-interpretable concepts?
- Can we steer models via feature activation?
- How do features differ across layers?