In [10]:
# Jupyter-friendly script (use in VS Code/JetBrains as a notebook)
%load_ext autoreload
%autoreload 2

# --- Imports ---
from pathlib import Path
from datetime import datetime

import torch

from amber.store import LocalStore
from amber.adapters.text_snippet_dataset import TextSnippetDataset
from amber.core.language_model import LanguageModel
from amber.mechanistic.autoencoder.autoencoder import Autoencoder
from amber.mechanistic.autoencoder.train import SAETrainer, SAETrainingConfig


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# --- Configuration ---
MODEL_ID = "sshleifer/tiny-gpt2"  # tiny model for quick experimentation
HF_DATASET = "roneneldan/TinyStories"
DATA_SPLIT = "train"
TEXT_FIELD = "text"
DATA_LIMIT = 200  # keep small for a quick demo
MAX_LENGTH = 128
BATCH_SIZE_SAVE = 8

# Choose which layer to hook. You can use an integer index or a layer name.
# Use model.layers.get_layer_names() below to inspect available names.
LAYER_SIGNATURE: int | str = 'gpt2lmheadmodel_lm_head'

# Storage locations
CACHE_DIR = Path("./store/tinystories")
STORE_DIR = Path("./store/tiny-gpt2")
RUN_ID = f"tinystories_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = None  # e.g., torch.float16 for lower storage; keep None to preserve dtype

# SAE config
SAE_EPOCHS = 2
SAE_MINIBATCH = 1024  # mini-batch within stored activation batches
SAE_LR = 1e-3
SAE_L1 = 0.0  # sparsity penalty on latents (set > 0.0 to encourage sparsity)


In [12]:
# --- Initialize model and dataset ---
# Load the Hugging Face checkpoint named by MODEL_ID and move the wrapped
# module onto the configured device.
model = LanguageModel.from_huggingface(MODEL_ID)
model.model.to(DEVICE)

# List the hookable layers; the first few names are echoed as a sanity check.
layer_names = model.layers.print_layer_names()
print(
    f"Discovered {len(layer_names)} layers. "
    f"Example names: {layer_names[:5]}"
)

# Pull a small slice of the text dataset (at most DATA_LIMIT rows requested).
dataset = TextSnippetDataset.from_huggingface(
    HF_DATASET,
    split=DATA_SPLIT,
    text_field=TEXT_FIELD,
    limit=DATA_LIMIT,
    cache_dir=str(CACHE_DIR),
)

# The LocalStore rooted at STORE_DIR is where activations will be written.
store = LocalStore(STORE_DIR)
print(f"Store base path: {store.base_path}")
print(f"Run id: {RUN_ID}")

Discovered 33 layers. Example names: ['gpt2lmheadmodel_transformer', 'gpt2lmheadmodel_transformer_wte', 'gpt2lmheadmodel_transformer_wpe', 'gpt2lmheadmodel_transformer_drop', 'gpt2lmheadmodel_transformer_h']


Saving the dataset (1/1 shards): 100%|██████████| 200/200 [00:00<00:00, 110492.73 examples/s]

Store base path: store/tiny-gpt2
Run id: tinystories_20250929_221605





In [13]:
# --- Save activations for the chosen layer ---
# Runs the model over `dataset` in batches of BATCH_SIZE_SAVE, captures the
# output of LAYER_SIGNATURE, and hands each batch to `store` under RUN_ID.
# In the recorded log every saved batch carried the keys 'activations',
# 'input_ids' and 'attention_mask'.
# NOTE(review): presumably written as per-batch safetensors files under
# STORE_DIR/runs/{RUN_ID}/ — confirm against LocalStore.
model.activations.infer_and_save(
    dataset,
    # what to capture, and where it goes
    store=store,
    run_name=RUN_ID,
    layer_signature=LAYER_SIGNATURE,
    # batching / tokenization limits
    max_length=MAX_LENGTH,
    batch_size=BATCH_SIZE_SAVE,
    # numerics and bookkeeping
    dtype=DTYPE,
    autocast=True,
    save_inputs=True,
    free_cuda_cache_every=0,
    verbose=True,
)


2025-09-29 22:16:09,771 [INFO] amber.core.language_model_activations: Starting save_model_activations: run=tinystories_20250929_221605, layer=gpt2lmheadmodel_lm_head, batch_size=8, device=cpu
2025-09-29 22:16:09,773 [INFO] amber.core.language_model_activations: Prepared batch 0: items=8, seq_len=128
2025-09-29 22:16:09,826 [INFO] amber.core.language_model_activations: Saved batch 0 for run=tinystories_20250929_221605 with keys=['activations', 'input_ids', 'attention_mask']
2025-09-29 22:16:09,828 [INFO] amber.core.language_model_activations: Prepared batch 1: items=8, seq_len=128
2025-09-29 22:16:09,880 [INFO] amber.core.language_model_activations: Saved batch 1 for run=tinystories_20250929_221605 with keys=['activations', 'input_ids', 'attention_mask']
2025-09-29 22:16:09,882 [INFO] amber.core.language_model_activations: Prepared batch 2: items=8, seq_len=128
2025-09-29 22:16:09,934 [INFO] amber.core.language_model_activations: Saved batch 2 for run=tinystories_20250929_221605 with ke

In [14]:
# --- Inspect one saved batch; infer the SAE input width ---
# Take only the first stored batch. A default is passed to next() so that an
# empty run produces a clear error instead of a bare StopIteration.
first_batch = next(iter(store.iter_run_batches(RUN_ID)), None)
if first_batch is None:
    raise RuntimeError(
        f"No saved batches found for run {RUN_ID!r}; "
        "run the activation-saving cell first."
    )

# Batches may arrive as a dict keyed by tensor name or as a sequence whose
# first element is the activations tensor.
acts = first_batch["activations"] if isinstance(first_batch, dict) else first_batch[0]
print("Saved activations shape:", tuple(acts.shape))

# The SAE input width is the size of the last axis; nothing is reshaped here,
# leading axes are left exactly as saved.
hidden_dim = acts.shape[-1]
print("Inferred hidden_dim:", hidden_dim)


Saved activations shape: (8, 128, 50257)
Inferred hidden_dim: 50257


In [17]:
# --- Build the Sparse Autoencoder ---
# Overcomplete dictionary: twice as many latents as input features.
# NOTE(review): the recorded run reported hidden_dim = 50257, i.e. 100514
# latents; presumably the weight matrices scale with n_inputs * n_latents —
# confirm this fits in memory before training.
n_latents = 2 * hidden_dim
sae = Autoencoder(
    n_inputs=hidden_dim,
    n_latents=n_latents,
    activation="TopK_4",
    tied=False,
    device=DEVICE,
)
print(sae)


Autoencoder(
  (activation): TopK(
    k=4, act_fn=Identity(), use_abs=False
    (act_fn): Identity()
  )
)


In [None]:
# --- Train the SAE from stored activations ---
# Checkpoints for this run are directed to STORE_DIR/checkpoints/RUN_ID.
ckpt_dir = STORE_DIR.joinpath("checkpoints", RUN_ID)

cfg = SAETrainingConfig(
    # optimisation schedule
    epochs=SAE_EPOCHS,
    batch_size=SAE_MINIBATCH,
    lr=SAE_LR,
    l1_lambda=SAE_L1,
    # placement / precision
    device=DEVICE,
    dtype=DTYPE,
    # no per-epoch batch cap and no validation passes are configured
    max_batches_per_epoch=None,
    validate_every=None,
    # checkpointing and decoder handling
    checkpoint_dir=ckpt_dir,
    project_decoder_grads=True,
    renorm_decoder_every=100,  # maintain stable decoder scale periodically
)

# The trainer is given the SAE, the store and the run id of the saved activations.
trainer = SAETrainer(sae, store, RUN_ID, cfg)
history = trainer.train()
print("Training history:", history)

# Persist the trained model under STORE_DIR/sae_models/RUN_ID, creating the
# directory (and any missing parents) if needed.
final_dir = STORE_DIR.joinpath("sae_models", RUN_ID)
final_dir.mkdir(parents=True, exist_ok=True)
sae.save("final", path=str(final_dir))
print("Saved final SAE to:", final_dir)
