# Example 1: Training TopKSAE Model

This notebook demonstrates how to:
1. Load a language model and dataset
2. Save activations from a specific layer
3. Train a TopK Sparse Autoencoder (TopKSAE) on those activations using the new `SaeTrainer` composite class
4. Save the trained TopKSAE model

The training uses overcomplete's `train_sae` functions via the `SaeTrainer` composite class, which is automatically available on all SAE instances via `sae.trainer`.

All files (trained TopKSAE model, training metadata, activations) will be saved under `store/{model_id}/` for organized, model-specific storage.


In [83]:
# Setup and imports
%load_ext autoreload
%autoreload 2

import json
from datetime import datetime
from pathlib import Path

import torch

from amber.adapters import TextDataset
from amber.core.language_model import LanguageModel
from amber.mechanistic.sae.modules.topk_sae import TopKSae, TopKSAETrainingConfig
from amber.store.local_store import LocalStore

# Confirmation that the import cell ran (emoji restored from mis-decoded UTF-8).
print("✅ Imports completed")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
✅ Imports completed


In [84]:
# Configuration
# Every value a reader may want to tune lives in this cell.

# Reproducibility: seed torch's global RNG so that whatever downstream code
# draws from it repeats across "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

MODEL_ID = "sshleifer/tiny-gpt2"  # Small model for quick experimentation
HF_DATASET = "roneneldan/TinyStories"
DATA_SPLIT = "train"
TEXT_FIELD = "text"
DATA_LIMIT = 1000  # Number of text samples to use
MAX_LENGTH = 64  # Maximum sequence length
BATCH_SIZE_SAVE = 16  # Batch size for saving activations
BATCH_SIZE_TRAIN = 32  # Batch size for SAE training

# TopKSAE configuration
TOP_K = 8  # Number of top activations to keep (sparsity parameter)

# Choose which layer to hook - you can inspect available layers with model.layers.print_layer_names()
LAYER_SIGNATURE = 'gpt2lmheadmodel_transformer_h_0_attn_c_attn'  # Attention layer (better activations)

# Storage locations. The cache lives inside the store so there is a single
# root directory to relocate or clean up.
STORE_DIR = Path("store")
CACHE_DIR = STORE_DIR / "cache"
# Timestamped so that repeated runs never overwrite each other's activations.
RUN_ID = f"topk_sae_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# Model-specific paths are filled in once the model (and its model_id) is loaded.
SAE_MODEL_PATH = None  # Will be set to store/{model_id}/topk_sae_model.pt
METADATA_PATH = None  # Will be set to store/{model_id}/training_metadata.json

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU; None on CPU (no dtype is requested in that case).
DTYPE = torch.float16 if torch.cuda.is_available() else None

print("🚀 Starting TopKSAE Training Example")
print(f"📱 Using device: {DEVICE}")
print(f"🔧 Model: {MODEL_ID}")
print(f"📊 Dataset: {HF_DATASET}")
print(f"🎯 Target layer: {LAYER_SIGNATURE}")
print(f"🔢 TopK parameter: {TOP_K}")
print()

# Create output directories (idempotent, safe to re-run)
CACHE_DIR.mkdir(parents=True, exist_ok=True)
STORE_DIR.mkdir(parents=True, exist_ok=True)
print("✅ Output directories created")


🚀 Starting TopKSAE Training Example
📱 Using device: cpu
🔧 Model: sshleifer/tiny-gpt2
📊 Dataset: roneneldan/TinyStories
🎯 Target layer: gpt2lmheadmodel_transformer_h_0_attn_c_attn
🔢 TopK parameter: 8

✅ Output directories created


In [85]:
# Step 1: Load language model with context
print("📥 Loading language model...")

store = LocalStore(STORE_DIR)
# Load the model first: its model_id names the per-model output directory.
lm = LanguageModel.from_huggingface(MODEL_ID, store=store)
lm.model.to(DEVICE)

# Model-specific directory that groups every artefact produced by this notebook.
# It is created here so later cells can write into it on a fresh checkout.
MODEL_DIR = STORE_DIR / lm.model_id
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Set the store we want to use (overrides default)
lm.context.store = store

# Update paths to use model-specific directory
SAE_MODEL_PATH = MODEL_DIR / "topk_sae_model.pt"
METADATA_PATH = MODEL_DIR / "training_metadata.json"

# Print available layers for reference (useful when choosing LAYER_SIGNATURE)
print("🔍 Available layers:")
lm.layers.print_layer_names()
print(f"✅ Model loaded: {lm.model_id}")
print(f"📱 Device: {DEVICE}")
print(f"📁 Store base: {STORE_DIR}")
print(f"📁 Model directory: {MODEL_DIR}")
print(f"📁 Store location: {lm.context.store.base_path}")
print(f"💾 SAE model will be saved to: {SAE_MODEL_PATH}")
print(f"💾 Metadata will be saved to: {METADATA_PATH}")


📥 Loading language model...
🔍 Available layers:
gpt2lmheadmodel_transformer: No weight
gpt2lmheadmodel_transformer_wte: torch.Size([50257, 2])
gpt2lmheadmodel_transformer_wpe: torch.Size([1024, 2])
gpt2lmheadmodel_transformer_drop: No weight
gpt2lmheadmodel_transformer_h: No weight
gpt2lmheadmodel_transformer_h_0: No weight
gpt2lmheadmodel_transformer_h_0_ln_1: torch.Size([2])
gpt2lmheadmodel_transformer_h_0_attn: No weight
gpt2lmheadmodel_transformer_h_0_attn_c_attn: torch.Size([2, 6])
gpt2lmheadmodel_transformer_h_0_attn_c_proj: torch.Size([2, 2])
gpt2lmheadmodel_transformer_h_0_attn_attn_dropout: No weight
gpt2lmheadmodel_transformer_h_0_attn_resid_dropout: No weight
gpt2lmheadmodel_transformer_h_0_ln_2: torch.Size([2])
gpt2lmheadmodel_transformer_h_0_mlp: No weight
gpt2lmheadmodel_transformer_h_0_mlp_c_fc: torch.Size([2, 8])
gpt2lmheadmodel_transformer_h_0_mlp_c_proj: torch.Size([8, 2])
gpt2lmheadmodel_transformer_h_0_mlp_act: No weight
gpt2lmheadmodel_transformer_h_0_mlp_dro

In [86]:
# Step 2: Load dataset
# Pull at most DATA_LIMIT rows of the configured split from the Hugging Face
# dataset and read the text from the TEXT_FIELD column.
print("📥 Loading dataset...")
dataset = TextDataset.from_huggingface(
    HF_DATASET,
    split=DATA_SPLIT,
    store=store,
    text_field=TEXT_FIELD,
    limit=DATA_LIMIT,
)
print(f"✅ Loaded {len(dataset)} text samples")


📥 Loading dataset...


Saving the dataset (1/1 shards): 100%|██████████| 1000/1000 [00:00<00:00, 498965.50 examples/s]

✅ Loaded 1000 text samples





In [87]:
# Step 3: Save activations
print("💾 Saving activations...")

# Record the activations of LAYER_SIGNATURE for every dataset batch into the
# store attached to the language model, grouped under RUN_ID.
lm.activations.save_activations_dataset(
    dataset,
    layer_signature=LAYER_SIGNATURE,
    run_name=RUN_ID,
    batch_size=BATCH_SIZE_SAVE,
    autocast=False,  # Disable autocast for consistency
)

# Verify activations were saved: the next step reads batch 0, so fail here with
# a clear message rather than later with an opaque lookup error.
batches = lm.context.store.list_run_batches(RUN_ID)
assert len(batches) > 0, f"No activation batches were saved for run {RUN_ID!r}"
print(f"✅ Saved {len(batches)} batches of activations")
print(f"📁 Run ID: {RUN_ID}")
print(f"📁 Store location: {lm.context.store.base_path}")


💾 Saving activations...
✅ Saved 63 batches of activations
📁 Run ID: topk_sae_training_20251114_000835
📁 Store location: store


In [88]:
# Step 4: Create TopKSAE model
print("🏗️ Creating TopKSAE model...")

# Infer the SAE input width from the first saved batch.
first_batch = lm.context.store.get_run_batch(RUN_ID, 0)
if isinstance(first_batch, dict):
    activations = first_batch["activations"]
else:
    activations = first_batch[0]  # Assume first tensor is activations

hidden_dim = activations.shape[-1]  # Last dimension is hidden size
n_latents = hidden_dim * 4  # Dictionary is 4x wider than the input
print(f"📏 Hidden dimension: {hidden_dim}")

# Top-k selection cannot keep more latents than exist; surface a bad
# TOP_K / layer combination here instead of deep inside training.
if TOP_K > n_latents:
    raise ValueError(
        f"TOP_K={TOP_K} exceeds the number of latents ({n_latents}) "
        f"for hidden_dim={hidden_dim}; lower TOP_K or pick a wider layer."
    )

sae = TopKSae(
    n_latents=n_latents,
    n_inputs=hidden_dim,
    k=TOP_K,
    device=DEVICE,
)

print(f"🧠 TopKSAE architecture: {hidden_dim} → {sae.context.n_latents} → {hidden_dim}")
print(f"🔢 TopK parameter: {sae.k}")
print(f"🔧 Device: {DEVICE}")


🏗️ Creating TopKSAE model...
📏 Hidden dimension: 6
🧠 TopKSAE architecture: 6 → 24 → 6
🔢 TopK parameter: 8
🔧 Device: cpu


In [None]:
# Step 5: Train TopKSAE using SaeTrainer
print("🏋️ Training TopKSAE...")
print("📝 Note: Training uses overcomplete's train_sae functions via the SaeTrainer composite class")
print(f"🔧 Trainer available at: sae.trainer (type: {type(sae.trainer).__name__})")
print()

# Configure training parameters
# Note: TopKSAETrainingConfig is an alias for SaeTrainingConfig
# You can also use SaeTrainingConfig directly from sae_trainer module
# NOTE(review): on CPU, DTYPE is None, so use_amp=True is combined with
# amp_dtype=None - confirm the trainer handles that combination as intended.
config = TopKSAETrainingConfig(
    epochs=100,
    batch_size=BATCH_SIZE_TRAIN,
    lr=1e-3,
    l1_lambda=1e-4,  # L1 sparsity penalty
    device=DEVICE,
    dtype=DTYPE,
    max_batches_per_epoch=50,  # Limit batches per epoch for quick training
    verbose=True,  # Enable progress logging
    use_amp=True,
    amp_dtype=DTYPE,
    clip_grad=1.0,  # Gradient clipping (overcomplete parameter)
    monitoring=2,  # Detailed monitoring (0=silent, 1=basic, 2=detailed)
)

# Train using TopKSAE's train method (which delegates to sae.trainer.train())
# The trainer uses overcomplete's train_sae_amp or train_sae functions internally
history = sae.train(lm.context.store, RUN_ID, LAYER_SIGNATURE, config)

# Report the last recorded value of each tracked metric.
print()
print("✅ Training completed!")
print(f"📈 Final loss: {history['loss'][-1]:.6f}")
print(f"📈 Final reconstruction MSE: {history['recon_mse'][-1]:.6f}")
print(f"📈 Final L1 penalty: {history['l1'][-1]:.6f}")


In [None]:
# Step 6: Save trained TopKSAE
print("💾 Saving trained TopKSAE...")

# Make sure the target directory exists before handing it to the saver.
SAE_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)

# Save using TopKSAE's save method (saves overcomplete model + our metadata).
# Name and directory are both derived from SAE_MODEL_PATH so the path printed
# below cannot drift from what is written.
# NOTE(review): presumably save() appends the ".pt" suffix itself - confirm.
sae.save(
    name=SAE_MODEL_PATH.stem,
    path=SAE_MODEL_PATH.parent,
)

print(f"✅ TopKSAE saved to: {SAE_MODEL_PATH}")


In [72]:
# Step 7: Save run metadata for next example
# Everything the follow-up notebook needs to locate the activations run, the
# trained SAE and the settings used to produce them.
run_metadata = {
    "run_id": RUN_ID,
    "layer_signature": LAYER_SIGNATURE,
    "hidden_dim": hidden_dim,
    "n_latents": sae.context.n_latents,
    "k": sae.k,
    "model_id": MODEL_ID,
    "model_dir": str(MODEL_DIR),
    "dataset": HF_DATASET,
    "data_limit": DATA_LIMIT,
    "sae_model_path": str(SAE_MODEL_PATH),
    "store_dir": str(STORE_DIR),
    "cache_dir": str(CACHE_DIR),
    "training_history": history,
}

# Save metadata to model-specific directory. The directory is ensured here so
# this cell does not depend on an earlier step having created it, and the file
# encoding is pinned so the result is identical across platforms.
METADATA_PATH.parent.mkdir(parents=True, exist_ok=True)
with open(METADATA_PATH, "w", encoding="utf-8") as f:
    json.dump(run_metadata, f, indent=2)

print(f"📋 Training metadata saved to: {METADATA_PATH}")
print()
print("🎉 TopKSAE training completed successfully!")
print(f"📁 All files saved under model directory: {MODEL_DIR}")
print("📝 Next: Run 02_attach_sae_and_save_texts.ipynb to attach the TopKSAE and collect top texts")


📋 Training metadata saved to: store/sshleifer_tiny-gpt2/training_metadata.json

🎉 TopKSAE training completed successfully!
📁 All files saved under model directory: store/sshleifer_tiny-gpt2
📝 Next: Run 02_attach_sae_and_save_texts.ipynb to attach the TopKSAE and collect top texts
