# Example 1: Training SAE Model

This notebook demonstrates how to:
1. Load a language model and dataset
2. Save activations from a specific layer
3. Train a Sparse Autoencoder (SAE) on those activations
4. Save the trained SAE model

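The trained SAE is saved to `outputs/sae_model.pt`, and the run details it depends on (run ID, layer signature, store/cache paths) are written to `outputs/training_metadata.json`; both are consumed by the next example, `02_attach_sae_and_save_texts.ipynb`.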


In [None]:
# Setup and imports
%load_ext autoreload
%autoreload 2

import torch
from pathlib import Path
from datetime import datetime

from amber.store import LocalStore
from amber.adapters.text_snippet_dataset import TextSnippetDataset
from amber.core.language_model import LanguageModel
from amber.mechanistic.autoencoder.autoencoder import Autoencoder
from amber.mechanistic.autoencoder.train import SAETrainer, SAETrainingConfig

print("‚úÖ Imports completed")


In [7]:
# Configuration
MODEL_ID = "sshleifer/tiny-gpt2"  # Small model for quick experimentation
HF_DATASET = "roneneldan/TinyStories"
DATA_SPLIT = "train"
TEXT_FIELD = "text"
DATA_LIMIT = 1000  # Number of text samples to use
# NOTE(review): MAX_LENGTH is not passed to any later cell (infer_and_save in
# Step 3 does not receive it), so it has no effect as written.
MAX_LENGTH = 64    # Maximum sequence length
BATCH_SIZE_SAVE = 16  # Batch size for saving activations
BATCH_SIZE_TRAIN = 32  # Batch size for SAE training

# Seed torch's RNG so SAE weight initialisation (and any torch-based shuffling
# inside the trainer) repeats across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Choose which layer to hook - you can inspect available layers with lm.layers.print_layer_names()
# Output of block 0's attention input projection (c_attn). For tiny-gpt2 it is
# 6 wide, i.e. 3x the 2-dim embedding — presumably the concatenated Q/K/V
# projections rather than a residual-stream activation.
LAYER_SIGNATURE = 'gpt2lmheadmodel_transformer_h_0_attn_c_attn'

# Storage locations
CACHE_DIR = Path("outputs/cache")
STORE_DIR = Path("outputs/store")
RUN_ID = f"sae_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
SAE_MODEL_PATH = Path("outputs/sae_model.pt")

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): DTYPE is not passed to any later cell, so the model, saved
# activations and SAE all stay in their default precision. Wire it in
# explicitly before relying on half precision.
DTYPE = torch.float16 if torch.cuda.is_available() else None

print("🚀 Starting SAE Training Example")
print(f"📱 Using device: {DEVICE}")
print(f"🔧 Model: {MODEL_ID}")
print(f"📊 Dataset: {HF_DATASET}")
print(f"🎯 Target layer: {LAYER_SIGNATURE}")
print()

# Create output directories. The SAE checkpoint directory is created
# explicitly instead of relying on CACHE_DIR/STORE_DIR happening to share it.
CACHE_DIR.mkdir(parents=True, exist_ok=True)
STORE_DIR.mkdir(parents=True, exist_ok=True)
SAE_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
print("✅ Output directories created")


🚀 Starting SAE Training Example
📱 Using device: cpu
🔧 Model: sshleifer/tiny-gpt2
📊 Dataset: roneneldan/TinyStories
🎯 Target layer: gpt2lmheadmodel_transformer_h_0_attn_c_attn

✅ Output directories created


In [8]:
# Step 1: Load language model
print("📥 Loading language model...")

lm = LanguageModel.from_huggingface(MODEL_ID)
lm.model.to(DEVICE)
# The layer listing printed below contains dropout modules (attn_dropout,
# resid_dropout, ...). eval() switches them off so that the activations saved
# in Step 3 are deterministic; it is a no-op if the model is already in eval
# mode. Assumes lm.model is a torch.nn.Module (it already supports .to()).
lm.model.eval()

# Print available layers for reference; pick LAYER_SIGNATURE from these names.
print("🔍 Available layers:")
lm.layers.print_layer_names()
print(f"✅ Model loaded: {lm.model_id}")
print(f"📱 Device: {DEVICE}")


📥 Loading language model...
🔍 Available layers:
gpt2lmheadmodel_transformer: No weight
gpt2lmheadmodel_transformer_wte: torch.Size([50257, 2])
gpt2lmheadmodel_transformer_wpe: torch.Size([1024, 2])
gpt2lmheadmodel_transformer_drop: No weight
gpt2lmheadmodel_transformer_h: No weight
gpt2lmheadmodel_transformer_h_0: No weight
gpt2lmheadmodel_transformer_h_0_ln_1: torch.Size([2])
gpt2lmheadmodel_transformer_h_0_attn: No weight
gpt2lmheadmodel_transformer_h_0_attn_c_attn: torch.Size([2, 6])
gpt2lmheadmodel_transformer_h_0_attn_c_proj: torch.Size([2, 2])
gpt2lmheadmodel_transformer_h_0_attn_attn_dropout: No weight
gpt2lmheadmodel_transformer_h_0_attn_resid_dropout: No weight
gpt2lmheadmodel_transformer_h_0_ln_2: torch.Size([2])
gpt2lmheadmodel_transformer_h_0_mlp: No weight
gpt2lmheadmodel_transformer_h_0_mlp_c_fc: torch.Size([2, 8])
gpt2lmheadmodel_transformer_h_0_mlp_c_proj: torch.Size([8, 2])
gpt2lmheadmodel_transformer_h_0_mlp_act: No weight
gpt2lmheadmodel_transformer_h_0_mlp_dro

In [9]:
# Step 2: Load dataset
# Pull at most DATA_LIMIT rows of TEXT_FIELD from the configured HF split,
# caching the download under CACHE_DIR.
print("📥 Loading dataset...")
dataset = TextSnippetDataset.from_huggingface(
    HF_DATASET,
    split=DATA_SPLIT,
    cache_dir=str(CACHE_DIR),
    text_field=TEXT_FIELD,
    limit=DATA_LIMIT,
)
# Fail fast here, rather than with an obscure error in Step 3/4, if the split,
# text field or limit produced no rows.
assert len(dataset) > 0, (
    f"No samples loaded from {HF_DATASET} split={DATA_SPLIT!r} field={TEXT_FIELD!r}"
)
print(f"✅ Loaded {len(dataset)} text samples")


📥 Loading dataset...


Saving the dataset (1/1 shards): 100%|██████████| 1000/1000 [00:00<00:00, 487313.12 examples/s]

✅ Loaded 1000 text samples





In [10]:
# Step 3: Save activations
print("💾 Saving activations...")
store = LocalStore(STORE_DIR)

# Run the model over the dataset and persist the LAYER_SIGNATURE activations
# in `store` under RUN_ID; Step 4 reads them back and Step 5 trains on them.
lm.activations.infer_and_save(
    dataset,
    layer_signature=LAYER_SIGNATURE,
    run_name=RUN_ID,
    store=store,
    batch_size=BATCH_SIZE_SAVE,
    autocast=False,  # Disable autocast for consistency
)

# Verify activations were saved. Step 4 unconditionally reads batch 0, so an
# empty run must stop the notebook here with a clear message.
batches = store.list_run_batches(RUN_ID)
assert len(batches) > 0, f"No activation batches were saved for run {RUN_ID}"
print(f"✅ Saved {len(batches)} batches of activations")
print(f"📁 Run ID: {RUN_ID}")


💾 Saving activations...
✅ Saved 63 batches of activations
📁 Run ID: sae_training_20251028_234300


In [None]:
# Step 4: Create SAE model
print("🏗️ Creating SAE model...")

# SAE shape hyperparameters (named instead of inline magic numbers).
# NOTE(review): for tiny-gpt2 the input is only 6 wide, so a 4x expansion gives
# 24 latents and TopK_8 keeps a third of them active — not very sparse.
# Consider a smaller k or a larger expansion factor for a sparser dictionary.
SAE_EXPANSION_FACTOR = 4
SAE_TOPK = 8
SAE_ACTIVATION = f"TopK_{SAE_TOPK}"

# Infer the SAE input width from the first saved activation batch.
first_batch = store.get_run_batch(RUN_ID, 0)
if isinstance(first_batch, dict):
    activations = first_batch["activations"]
elif torch.is_tensor(first_batch):
    # A bare tensor is the activations itself; indexing [0] would only pick its
    # first row (same last dimension, but not the object we mean).
    activations = first_batch
else:
    activations = first_batch[0]  # assumes the first element holds the activations

# Last dimension is the feature width; int() keeps it JSON-serialisable for Step 7.
hidden_dim = int(activations.shape[-1])
print(f"📏 Hidden dimension: {hidden_dim}")

# Create SAE directly (context is created internally)
sae = Autoencoder(
    n_latents=hidden_dim * SAE_EXPANSION_FACTOR,
    n_inputs=hidden_dim,
    activation=SAE_ACTIVATION,  # TopK sparsity: keep the SAE_TOPK largest latents
    tied=False,  # Untied encoder/decoder weights
    init_method="kaiming",
    device=DEVICE,
)

print(f"🧠 SAE architecture: {hidden_dim} → {sae.context.n_latents} → {hidden_dim}")
sae.context.experiment_name = "sae_training"
sae.context.run_id = RUN_ID
print(f"🔧 Context: {sae.context.experiment_name}/{sae.context.run_id}")


🏗️ Creating SAE model...
📏 Hidden dimension: 6
🧠 SAE architecture: 6 → 24 → 6
🔧 Context: sae_training/sae_training_20251028_234300


In [12]:
# Step 5: Train SAE
print("🏋️ Training SAE...")

config = SAETrainingConfig(
    epochs=10,
    batch_size=BATCH_SIZE_TRAIN,
    lr=1e-3,
    l1_lambda=1e-4,  # L1 sparsity penalty
    device=DEVICE,
    max_batches_per_epoch=50,  # Limit batches per epoch for quick training
    project_decoder_grads=True,  # Project gradients for stability
    renorm_decoder_every=5,  # Renormalize decoder weights
    verbose=True,  # Enable progress logging
)

# Train on the activations saved under RUN_ID in Step 3.
trainer = SAETrainer(sae, store, RUN_ID, config)
history = trainer.train()

print("✅ Training completed!")
# Report the last recorded value of each tracked metric. An empty series (e.g.
# no batches consumed) prints "n/a" instead of raising IndexError.
# NOTE(review): the recorded run reached ~1e-6 reconstruction MSE. With 24
# latents for a 6-dim input the SAE can reconstruct almost trivially, so treat
# this as a smoke test rather than a quality signal.
for metric_key, metric_label in (
    ("loss", "Final loss"),
    ("recon_mse", "Final reconstruction MSE"),
    ("l1", "Final L1 penalty"),
):
    metric_values = history[metric_key]
    if len(metric_values) > 0:
        print(f"📈 {metric_label}: {metric_values[-1]:.6f}")
    else:
        print(f"📈 {metric_label}: n/a (no values recorded)")


🏋️ Training SAE...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
2025-10-28 23:43:10,498 [INFO] amber.mechanistic.autoencoder.train: [SAETrainer] device=cpu dtype=None use_amp=True grad_accum_steps=1 batch_size=32 lr=0.001
2025-10-28 23:43:10,498 [INFO] amber.mechanistic.autoencoder.train: [SAETrainer] Starting training run_id=sae_training_20251028_234300 epochs=10 batch_size=32
Epochs: 100%|██████████| 10/10 [00:00<00:00, 16.52it/s, loss=0.00000, mse=0.00000, l1=0.00609]
20

✅ Training completed!
📈 Final loss: 0.000001
📈 Final reconstruction MSE: 0.000001
📈 Final L1 penalty: 0.006089


In [None]:
# Step 6: Save trained SAE
print("💾 Saving trained SAE...")

# Training metadata stored alongside the checkpoint.
metadata = {
    "dataset": HF_DATASET,
    "data_limit": DATA_LIMIT,
    "hidden_dim": hidden_dim,
    "n_latents": sae.context.n_latents,
    "activation": "TopK_8",  # must match the activation the SAE was built with in Step 4
    "training_config": {
        "epochs": config.epochs,
        "batch_size": config.batch_size,
        "lr": config.lr,
        "l1_lambda": config.l1_lambda,
    },
    "run_id": RUN_ID,
    "training_history": history,
}

SAE_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
sae.save(
    # Derive the name from SAE_MODEL_PATH so that changing the configured path
    # changes the saved file too (previously "sae_model" was hard-coded).
    name=SAE_MODEL_PATH.stem,
    path=SAE_MODEL_PATH.parent,
    run_metadata=metadata,
    layer_signature=LAYER_SIGNATURE,
    model_id=MODEL_ID,
)

# NOTE(review): assumes Autoencoder.save writes <path>/<name>.pt. Report the
# real outcome instead of unconditionally claiming success.
if SAE_MODEL_PATH.exists():
    print(f"✅ SAE saved to: {SAE_MODEL_PATH}")
else:
    print(f"⚠️ Expected checkpoint not found at {SAE_MODEL_PATH}; check where Autoencoder.save wrote it")


💾 Saving trained SAE...
✅ SAE saved to: outputs/sae_model.pt


In [None]:
# Step 7: Save run metadata for next example
import json  # kept local so this cell can be re-run on its own

# Everything the follow-up notebook needs to locate this run's artifacts.
# int() guards against non-native numeric types, which json.dump would reject.
run_metadata = {
    "run_id": RUN_ID,
    "layer_signature": LAYER_SIGNATURE,
    "hidden_dim": int(hidden_dim),
    "n_latents": int(sae.context.n_latents),
    "model_id": MODEL_ID,
    "dataset": HF_DATASET,
    "data_limit": DATA_LIMIT,
    "sae_model_path": str(SAE_MODEL_PATH),
    "store_dir": str(STORE_DIR),
    "cache_dir": str(CACHE_DIR),
}

# Fixed location: this is the path the follow-up notebook is pointed at.
metadata_path = Path("outputs/training_metadata.json")
metadata_path.parent.mkdir(parents=True, exist_ok=True)
with open(metadata_path, "w", encoding="utf-8") as f:
    json.dump(run_metadata, f, indent=2)

print(f"📋 Training metadata saved to: {metadata_path}")
print()
print("🎉 SAE training completed successfully!")
print("📝 Next: Run 02_attach_sae_and_save_texts.ipynb to attach the SAE and collect top texts")


📋 Training metadata saved to: outputs/training_metadata.json

🎉 SAE training completed successfully!
📝 Next: Run 02_attach_sae_and_save_texts.ipynb to attach the SAE and collect top texts
