# Example 2: Attach SAE and Save Top Texts for Neurons

This notebook demonstrates how to:
1. Load a trained TopKSAE model from the previous example (saved under `store/{model_id}/`)
2. Attach the TopKSAE to a language model
3. Enable text tracking to collect top activating texts for each neuron
4. Run inference on new text data to collect neuron-text associations
5. Save the collected top texts for use in the next example

All files will be saved under `store/{model_id}/` for organized, model-specific storage.


In [1]:
# Setup and imports
%load_ext autoreload
%autoreload 2

import torch
import json
from pathlib import Path
from datetime import datetime

from amber.adapters import TextDataset
from amber.core.language_model import LanguageModel
from amber.mechanistic.sae.modules.topk_sae import TopKSae
from amber.store.local_store import LocalStore

print("‚úÖ Imports completed")


  from .autonotebook import tqdm as notebook_tqdm


‚úÖ Imports completed


In [2]:
# Load metadata from previous example
# First, we need to determine the model_id to find the correct directory
# For this example, we'll use the same model as example 1
MODEL_ID_HF = "sshleifer/tiny-gpt2"  # Should match example 1

# Load model to get model_id
STORE_DIR = Path("store")

store = LocalStore(STORE_DIR)
temp_lm = LanguageModel.from_huggingface(MODEL_ID_HF, store)
model_id = temp_lm.model_id
del temp_lm  # Clean up

# Metadata is saved under store/{model_id}/
MODEL_DIR = STORE_DIR / model_id
metadata_path = MODEL_DIR / "training_metadata.json"

if not metadata_path.exists():
    print(f"‚ùå Error: training_metadata.json not found at {metadata_path}!")
    print("   Please run 01_train_sae_model.ipynb first")
    raise FileNotFoundError(f"training_metadata.json not found at {metadata_path}")

with open(metadata_path, "r") as f:
    metadata = json.load(f)

print("üìã Loaded training metadata:")
for key, value in metadata.items():
    if key != "training_history":  # Skip large training history
        print(f"   {key}: {value}")
print()


üìã Loaded training metadata:
   run_id: topk_sae_training_20251117_213223
   layer_signature: gpt2lmheadmodel_transformer_h_0_attn_c_attn
   hidden_dim: 6
   n_latents: 24
   k: 8
   model_id: sshleifer/tiny-gpt2
   model_dir: store/sshleifer_tiny-gpt2
   dataset: roneneldan/TinyStories
   data_limit: 1000
   sae_model_path: store/sshleifer_tiny-gpt2/topk_sae_model.pt
   store_dir: store
   cache_dir: store/cache



In [3]:
# Configuration
MODEL_ID = metadata["model_id"]
LAYER_SIGNATURE = metadata["layer_signature"]
SAE_MODEL_PATH = Path(metadata["sae_model_path"])
CACHE_DIR = Path(metadata["cache_dir"])
STORE_DIR = Path(metadata["store_dir"])
MODEL_DIR = Path(metadata.get("model_dir", STORE_DIR / MODEL_ID.split("/")[-1].replace("-", "_")))

# New dataset for text collection (different from training data)
HF_DATASET = "roneneldan/TinyStories"
DATA_SPLIT = "train"
TEXT_FIELD = "text"
DATA_LIMIT = 500  # Smaller dataset for text collection
MAX_LENGTH = 64

# Text tracking configuration
TOP_K = 10  # Number of top texts to track per neuron
NEGATIVE_TRACKING = False  # Track positive activations

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Output paths (save under model directory)
TOP_TEXTS_PATH = MODEL_DIR / "top_texts.json"
ATTACHMENT_METADATA_PATH = MODEL_DIR / "attachment_metadata.json"

print("üîó Starting SAE Attachment and Text Collection Example")
print(f"üì± Using device: {DEVICE}")
print(f"üîß Model: {MODEL_ID}")
print(f"üìä Dataset: {HF_DATASET}")
print(f"üéØ Target layer: {LAYER_SIGNATURE}")
print(f"üß† SAE model: {SAE_MODEL_PATH}")
print(f"üìÅ Model directory: {MODEL_DIR}")
print()


üîó Starting SAE Attachment and Text Collection Example
üì± Using device: cpu
üîß Model: sshleifer/tiny-gpt2
üìä Dataset: roneneldan/TinyStories
üéØ Target layer: gpt2lmheadmodel_transformer_h_0_attn_c_attn
üß† SAE model: store/sshleifer_tiny-gpt2/topk_sae_model.pt
üìÅ Model directory: store/sshleifer_tiny-gpt2



In [4]:
# Step 1: Load language model
print("üì• Loading language model...")

# Load model and move to device
model = LanguageModel.from_huggingface(MODEL_ID, store)
model.model.to(DEVICE)


# Optional: set experiment metadata
model.context.experiment_name = "sae_attachment"
model.context.run_id = f"attachment_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
model.context.max_length = MAX_LENGTH

print(f"‚úÖ Model loaded: {model.model_id}")
print(f"üì± Device: {DEVICE}")
print(f"üìÅ Store: {model.context.store.base_path}")
print(f"üîß Context: {model.context.experiment_name}/{model.context.run_id}")


üì• Loading language model...
‚úÖ Model loaded: sshleifer_tiny-gpt2
üì± Device: cpu
üìÅ Store: store
üîß Context: sae_attachment/attachment_20251117_213311


In [5]:
# Step 2: Load trained TopKSAE
print("üì• Loading trained TopKSAE...")
if not SAE_MODEL_PATH.exists():
    print(f"‚ùå Error: SAE model not found at {SAE_MODEL_PATH}")
    print("   Please run 01_train_sae_model.ipynb first")
    raise FileNotFoundError(f"SAE model not found at {SAE_MODEL_PATH}")

# Load TopKSAE using the new load method
sae_hook = TopKSae.load(SAE_MODEL_PATH)
sae_hook.sae_engine.to(DEVICE)

# Get k from metadata
k = metadata.get("k", 8)

# Update context with current experiment info
sae_hook.context.experiment_name = "sae_attachment"
sae_hook.context.run_id = f"attachment_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

print(f"‚úÖ TopKSAE loaded: {metadata['hidden_dim']} ‚Üí {metadata['n_latents']} ‚Üí {metadata['hidden_dim']} (k={k})")
print(f"üîß Context: {sae_hook.context.experiment_name}/{sae_hook.context.run_id}")


2025-11-17 21:33:11,612 [INFO] amber.mechanistic.sae.modules.topk_sae: 
Loaded TopKSAE from store/sshleifer_tiny-gpt2/topk_sae_model.pt
n_latents=24, n_inputs=6, k=8


üì• Loading trained TopKSAE...
‚úÖ TopKSAE loaded: 6 ‚Üí 24 ‚Üí 6 (k=8)
üîß Context: sae_attachment/attachment_20251117_213311


In [6]:
# Step 3: Register SAE hook on language model and enable text tracking
print("üîó Registering SAE hook on language model...")

# Register the SAE hook on the target layer
model.layers.register_hook(LAYER_SIGNATURE, sae_hook)

# Set the language model and layer signature on the SAE hook's context
sae_hook.context.lm = model
sae_hook.context.lm_layer_signature = LAYER_SIGNATURE

print(f"‚úÖ SAE hook registered on layer: {LAYER_SIGNATURE}")

# Enable text tracking
print("üîó Enabling text tracking...")
sae_hook.context.text_tracking_enabled = True
sae_hook.context.text_tracking_k = TOP_K
sae_hook.context.text_tracking_negative = NEGATIVE_TRACKING
sae_hook.concepts.enable_text_tracking()

print(f"‚úÖ Text tracking enabled: top-{TOP_K} {'negative' if NEGATIVE_TRACKING else 'positive'} activations")
print(f"üîß Context: {sae_hook.context.experiment_name}/{sae_hook.context.run_id}")


üîó Registering SAE hook on language model...
‚úÖ SAE hook registered on layer: gpt2lmheadmodel_transformer_h_0_attn_c_attn
üîó Enabling text tracking...
‚úÖ Text tracking enabled: top-10 positive activations
üîß Context: sae_attachment/attachment_20251117_213311


In [7]:
# Step 4: Load dataset for text collection
print("üì• Loading dataset for text collection...")
dataset = TextDataset.from_huggingface(
    HF_DATASET,
    split=DATA_SPLIT,
    store=store,
    text_field=TEXT_FIELD,
    limit=DATA_LIMIT,
)
print(f"‚úÖ Loaded {len(dataset)} text samples")


üì• Loading dataset for text collection...


Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 500/500 [00:00<00:00, 270286.38 examples/s]

‚úÖ Loaded 500 text samples





In [8]:
# Step 5: Run inference to collect top texts
print("üîç Running inference to collect top texts...")

batch_size = 16
total_batches = (len(dataset) + batch_size - 1) // batch_size

for i in range(0, len(dataset), batch_size):
    batch_end = min(i + batch_size, len(dataset))
    batch_texts = [dataset[j] for j in range(i, batch_end)]

    # Run forward pass to trigger text tracking
    model.forwards(batch_texts)

    if (i // batch_size + 1) % 10 == 0:
        print(f"   Processed {batch_end}/{len(dataset)} samples...")

print("‚úÖ Text collection completed!")


üîç Running inference to collect top texts...
   Processed 160/500 samples...
   Processed 320/500 samples...
   Processed 480/500 samples...
‚úÖ Text collection completed!


In [9]:
# Step 6: Analyze collected top texts
print("üìä Analyzing collected top texts...")

# Get top texts for a few neurons as examples
example_neurons = [0, 1, 2, 5, 10]  # Show first few neurons
total_neurons_with_texts = 0

for neuron_idx in example_neurons:
    if neuron_idx < metadata["n_latents"]:
        top_texts = sae_hook.concepts.get_top_texts_for_neuron(neuron_idx)
        if top_texts:
            total_neurons_with_texts += 1
            print(f"üß† Neuron {neuron_idx}: {len(top_texts)} texts")
            for j, nt in enumerate(top_texts[:3]):  # Show top 3
                print(f"   {j + 1}. '{nt.text}' (score: {nt.score:.4f})")
            if len(top_texts) > 3:
                print(f"   ... and {len(top_texts) - 3} more")
            print()

# Count total neurons with texts
for neuron_idx in range(metadata["n_latents"]):
    if sae_hook.concepts.get_top_texts_for_neuron(neuron_idx):
        total_neurons_with_texts += 1

print(f"üìà Summary: {total_neurons_with_texts}/{metadata['n_latents']} neurons have collected texts")


üìä Analyzing collected top texts...
üß† Neuron 0: 10 texts
   1. 'Mummy and Daddy were picking flowers in the garden. Mummy picked a red daisy, Daddy picked a purple thistle and the little girl picked a beautiful lily. The lily was her favorite because it was so fluffy and white and the aroma was heavenly. 

Daddy said, "Let's bring this lily inside and put it on the windowsill." 

Mummy said, "How about we make it a surprise?" 

So the family all went inside and the little girl put the lily on the windowsill. 

When she stepped back to admire her work, she noticed a bright yellow butterfly that had landed on the lily. The little girl smiled. 

Mummy said, "Oh my, that lily looks so warm and cozy with the butterfly on top."

The little girl nodded, delighted with her surprise. And, from that day on, the warm lily became a happy reminder of the special family day.' (score: 0.0169)
   2. 'Once there was a generous bear. He liked to help others and was always very kind. But he had one 

In [10]:
# Step 7: Save top texts
print("üíæ Saving top texts...")
sae_hook.concepts.export_top_texts_to_json(str(TOP_TEXTS_PATH))
print(f"üìä Saved texts for {total_neurons_with_texts} neurons")
print(f"üìÅ Saved to: {TOP_TEXTS_PATH}")


üíæ Saving top texts...
üìä Saved texts for 29 neurons
üìÅ Saved to: store/sshleifer_tiny-gpt2/top_texts.json


In [11]:
# Step 8: Save metadata for next example
attachment_metadata = {
    "model_id": MODEL_ID,
    "model_dir": str(MODEL_DIR),
    "layer_signature": LAYER_SIGNATURE,
    "n_latents": metadata["n_latents"],
    "top_k": TOP_K,
    "negative_tracking": NEGATIVE_TRACKING,
    "dataset": HF_DATASET,
    "data_limit": DATA_LIMIT,
    "total_neurons_with_texts": total_neurons_with_texts,
    "top_texts_path": str(TOP_TEXTS_PATH),
    "sae_model_path": str(SAE_MODEL_PATH),
}

with open(ATTACHMENT_METADATA_PATH, "w") as f:
    json.dump(attachment_metadata, f, indent=2)

print(f"üìã Attachment metadata saved to: {ATTACHMENT_METADATA_PATH}")
print()
print("üéâ SAE attachment and text collection completed successfully!")
print(f"üìÅ All files saved under model directory: {MODEL_DIR}")
print("üìù Next: Run 03_load_concepts.ipynb to load and manipulate the concepts")


üìã Attachment metadata saved to: store/sshleifer_tiny-gpt2/attachment_metadata.json

üéâ SAE attachment and text collection completed successfully!
üìÅ All files saved under model directory: store/sshleifer_tiny-gpt2
üìù Next: Run 03_load_and_manipulate_concepts.ipynb to load and manipulate the concepts
