# Notebook with Example Config for Different Models / Hooks

## Setup

In [1]:
# Shared setup: make the parent directory importable, load the SAE runner API,
# and pick an accelerator with priority CUDA > Apple MPS > CPU.
import os
import sys

import torch

sys.path.append("..")

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

# Turn off HuggingFace tokenizers parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
print("Using device:", device)

Using device: mps


## Gelu-2L

An example of a toy language model we're able to train on.

### MLP Out

In [None]:
# Gelu-2L, layer-0 MLP output. expansion_factor is a list here, which appears to
# train one SAE per factor in a single run (the saved file is an "sae_group")
# — TODO confirm against the installed sae_lens version.
cfg = LanguageModelSAERunnerConfig(
    # Model, hook point and training distribution
    model_name="gelu-2l",
    hook_point="blocks.0.hook_mlp_out",
    hook_point_layer=0,
    d_in=512,
    dataset_path="NeelNanda/c4-tokenized-2b",
    is_dataset_tokenized=True,
    # SAE shape and initialisation
    expansion_factor=[16, 32, 64],
    b_dec_init_method="geometric_median",  # better than "mean", but slower to start
    # Optimisation
    lr=0.0012,
    lr_scheduler_name="constantwithwarmup",
    l1_coefficient=0.00016,
    train_batch_size=4096,
    context_size=128,
    # Activation store
    n_batches_in_buffer=128,
    total_training_tokens=100_000_000,  # 100M tokens
    store_batch_size=32,
    # Dead-feature handling
    use_ghost_grads=True,
    feature_sampling_window=5000,
    dead_feature_window=5000,
    dead_feature_threshold=1e-4,
    # Weights & Biases logging
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_models_gelu_2l_test",
    wandb_log_frequency=10,
    # Miscellaneous
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)


sparse_autoencoder = language_model_sae_runner(cfg)

## GPT2 - Small

### Residual Stream

In [None]:
# GPT-2 small, residual stream entering `layer`. This cell is self-contained:
# it imports torch itself and picks the device at run time instead of
# hard-coding "cuda", so it also runs on MPS/CPU-only machines.
import torch

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

layer = 3
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="gpt2-small",
    hook_point=f"blocks.{layer}.hook_resid_pre",
    hook_point_layer=layer,  # derived from the same variable as hook_point so they cannot drift
    d_in=768,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    # SAE Parameters
    expansion_factor=32,  # determines the dimension of the SAE.
    b_dec_init_method="mean",  # geometric median is better but slower to get started
    # Training Parameters
    lr=0.0004,
    l1_coefficient=0.00008,
    lr_scheduler_name="constantwithwarmup",
    train_batch_size=4096,
    context_size=128,
    lr_warm_up_steps=5000,
    # Activation Store Parameters
    n_batches_in_buffer=128,
    total_training_tokens=1_000_000 * 300,  # 300M tokens (~200M seems doable overnight).
    store_batch_size=32,
    # Resampling protocol
    use_ghost_grads=True,
    feature_sampling_window=2500,
    dead_feature_window=5000,
    dead_feature_threshold=1e-8,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_models_resid_pre_test",
    wandb_entity=None,
    wandb_log_frequency=100,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=10,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

sparse_autoencoder = language_model_sae_runner(cfg)

# Pythia 70M

In [None]:
# Pythia-70M (deduped), layer-0 MLP output. The device is chosen at run time
# rather than hard-coded to "cuda", which cannot work on machines without CUDA.
import torch
import os
import sys

sys.path.append("..")
from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

import cProfile  # NOTE(review): not used anywhere in this notebook; candidate for removal.

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

os.environ["TOKENIZERS_PARALLELISM"] = "false"
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="pythia-70m-deduped",
    hook_point="blocks.0.hook_mlp_out",
    hook_point_layer=0,
    d_in=512,
    dataset_path="EleutherAI/the_pile_deduplicated",
    is_dataset_tokenized=False,
    # SAE Parameters
    expansion_factor=64,
    # Training Parameters
    lr=3e-4,
    l1_coefficient=4e-5,
    train_batch_size=8192,
    context_size=128,
    lr_scheduler_name="constantwithwarmup",
    lr_warm_up_steps=10_000,
    # Activation Store Parameters
    n_batches_in_buffer=64,
    total_training_tokens=1_000_000 * 800,  # 800M tokens
    store_batch_size=32,
    # Resampling protocol
    feature_sampling_window=2000,  # Doesn't currently matter.
    dead_feature_window=40000,
    dead_feature_threshold=1e-8,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_benchmark_tests",
    wandb_entity=None,
    wandb_log_frequency=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)


sparse_autoencoder = language_model_sae_runner(cfg)

# Pythia 70M Hook Q

In [None]:
# Pythia-70M (deduped), attention queries of a single head (layer 2, head 7).
# The device is chosen at run time rather than hard-coded to "mps", which is
# unavailable on non-Apple machines.
import torch
import os
import sys

sys.path.append("../")

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

os.environ["TOKENIZERS_PARALLELISM"] = "false"
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="pythia-70m-deduped",
    hook_point="blocks.2.attn.hook_q",
    hook_point_layer=2,
    hook_point_head_index=7,
    d_in=64,  # presumably d_head, since a single head is selected — confirm
    dataset_path="EleutherAI/the_pile_deduplicated",
    is_dataset_tokenized=False,
    # SAE Parameters
    expansion_factor=16,
    # Training Parameters
    lr=0.0012,
    l1_coefficient=0.003,
    lr_scheduler_name="constantwithwarmup",
    lr_warm_up_steps=1000,  # about 4 million tokens (1000 * 4096).
    train_batch_size=4096,
    context_size=128,
    # Activation Store Parameters
    n_batches_in_buffer=128,
    total_training_tokens=1_000_000 * 1500,
    store_batch_size=32,
    # Resampling protocol
    # NOTE(review): the other LM configs here use use_ghost_grads instead of
    # these resampling options; confirm the installed sae_lens still accepts them.
    feature_sampling_method="anthropic",
    feature_sampling_window=1000,  # doesn't do anything currently.
    feature_reinit_scale=0.2,
    resample_batches=8,
    dead_feature_window=60000,
    dead_feature_threshold=1e-5,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_pythia_70M_hook_q_L2H7",
    wandb_entity=None,
    wandb_log_frequency=100,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=15,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

sparse_autoencoder = language_model_sae_runner(cfg)

# Tiny Stories

## MLP Post-Activation (`blocks.1.mlp.hook_post`)

In [None]:
# TinyStories-1M, layer-1 MLP post-activation (blocks.1.mlp.hook_post).
import os

import torch

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

os.environ["TOKENIZERS_PARALLELISM"] = "false"
cfg = LanguageModelSAERunnerConfig(
    # Model, hook point and training distribution
    model_name="tiny-stories-1M",
    hook_point="blocks.1.mlp.hook_post",
    hook_point_layer=1,
    d_in=256,
    # Pre-tokenized copy of TinyStories (prepared by Dan at Apollo) for faster
    # training. Raw alternative:
    #   dataset_path="roneneldan/TinyStories", is_dataset_tokenized=False
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
    is_dataset_tokenized=True,
    # SAE shape
    expansion_factor=16,
    # Optimisation
    lr=1e-4,
    lp_norm=1.0,
    l1_coefficient=2e-4,
    train_batch_size=4096,
    context_size=128,
    # Activation store
    n_batches_in_buffer=128,
    total_training_tokens=20_000_000,  # 20M tokens
    store_batch_size=32,
    feature_sampling_window=500,  # short window so the histograms get logged
    dead_feature_window=250,
    # Weights & Biases logging
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_benchmark_tests",
    wandb_log_frequency=10,
    # Miscellaneous
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

sparse_autoencoder = language_model_sae_runner(cfg)

## Hook Z



In [1]:
# TinyStories-1M, layer-1 attention output per head (blocks.1.attn.hook_z).
import os

import torch

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

# Accelerator choice: CUDA first, then Apple MPS, otherwise CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
cfg = LanguageModelSAERunnerConfig(
    # Model, hook point and training distribution
    model_name="tiny-stories-1M",
    hook_point="blocks.1.attn.hook_z",
    hook_point_layer=1,
    d_in=64,  # presumably all heads' z vectors flattened (n_heads * d_head) — confirm
    # Pre-tokenized copy of TinyStories (prepared by Dan at Apollo) for faster
    # training. Raw alternative:
    #   dataset_path="roneneldan/TinyStories", is_dataset_tokenized=False
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
    is_dataset_tokenized=True,
    # SAE shape
    expansion_factor=16,
    # Optimisation
    lr=1e-4,
    lp_norm=1.0,
    l1_coefficient=2e-4,
    train_batch_size=4096,
    context_size=128,
    # Activation store
    n_batches_in_buffer=128,
    total_training_tokens=20_000_000,  # 20M tokens
    store_batch_size=32,
    feature_sampling_window=500,  # short window so the histograms get logged
    dead_feature_window=250,
    # Weights & Biases logging
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_benchmark_tests",
    wandb_log_frequency=10,
    # Miscellaneous
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

sparse_autoencoder = language_model_sae_runner(cfg)

Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 4882
Total wandb updates: 488
n_tokens_per_feature_sampling_window (millions): 262.144
n_tokens_per_dead_feature_window (millions): 131.072
We will reset the sparsity calculation 9 times.
Number tokens in sparsity calculation window: 2.05e+06
Loaded pretrained model tiny-stories-1M into HookedTransformer
Moving model to device:  mps


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 4882
Total wandb updates: 488
n_tokens_per_feature_sampling_window (millions): 262.144
n_tokens_per_dead_feature_window (millions): 131.072
We will reset the sparsity calculation 9 times.
Number tokens in sparsity calculation window: 2.05e+06
Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 4882
Total wandb updates: 488
n_tokens_per_feature_sampling_window (millions): 262.144
n_tokens_per_dead_feature_window (millions): 131.072
We will reset the sparsity calculation 9 times.
Number tokens in sparsity calculation window: 2.05e+06


[34m[1mwandb[0m: Currently logged in as: [33mjbloom[0m. Use [1m`wandb login --relogin`[0m to force relogin


Objective value: 116883.7422:  10%|█         | 10/100 [00:00<00:00, 128.72it/s]
  out = torch.tensor(origin, dtype=self.dtype, device=self.device)
100%|██████████| 10/10 [00:02<00:00,  4.93it/s] 405504/20000000 [00:14<08:53, 36739.57it/s]
100%|██████████| 10/10 [00:01<00:00,  5.01it/s]| 811008/20000000 [00:31<18:45, 17042.47it/s] 
100%|██████████| 10/10 [00:01<00:00,  5.04it/s]| 1224704/20000000 [00:47<10:43, 29194.89it/s] 
100%|██████████| 10/10 [00:02<00:00,  4.98it/s]| 1634304/20000000 [01:05<08:10, 37468.33it/s]
100%|██████████| 10/10 [00:02<00:00,  4.64it/s]| 2039808/20000000 [01:20<07:36, 39322.02it/s]
100%|██████████| 10/10 [00:01<00:00,  5.08it/s]| 2453504/20000000 [01:37<07:55, 36873.53it/s]
100%|██████████| 10/10 [00:01<00:00,  5.04it/s]| 2863104/20000000 [01:52<07:16, 39292.24it/s]
100%|██████████| 10/10 [00:01<00:00,  5.01it/s]| 3272704/20000000 [02:09<06:52, 40537.06it/s] 
100%|██████████| 10/10 [00:02<00:00,  4.90it/s]| 3678208/20000000 [02:26<27:40, 9829.56it/s] 
100%|██

Saved model to checkpoints/sf7u2imk/final_sae_group_tiny-stories-1M_blocks.1.attn.hook_z_1024.pt


VBox(children=(Label(value='0.053 MB of 0.569 MB uploaded\r'), FloatProgress(value=0.0935266880101429, max=1.0…



0,1
details/current_learning_rate,▁▃▅▆████████████████████████████████████
details/n_training_tokens,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,██▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score,▁▄▅▆▆▇▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation,▂▃▂▅▃▆▅▃▄▆▇▆▅▇▅▄▇▅▁▆▄▅▆▄█▄▅▆▄▅▅▃▂▄▄▅▅█▆▆
metrics/ce_loss_with_sae,█▅▄▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▁▂▂▁▁▂▁▂▁▂▂▂
metrics/ce_loss_without_sae,▄▄▁▃▄▆▅▃▆▅█▆▅▆▅▄▅▆▁▇▆▅▆▃█▆▆▆▄▇▆▃▃▆▃▆▄█▇▅

0,1
details/current_learning_rate,0.0001
details/n_training_tokens,19988480.0
losses/ghost_grad_loss,0.0
losses/l1_loss,1.41017
losses/mse_loss,8e-05
losses/overall_loss,0.00036
metrics/CE_loss_score,0.98362
metrics/ce_loss_with_ablation,5.49512
metrics/ce_loss_with_sae,2.71813
metrics/ce_loss_without_sae,2.67199


  lambda data: self._console_raw_callback("stderr", data),


# Toy Model

In [None]:
from sae_lens.training.toy_model_runner import (
    SAEToyModelRunnerConfig,
    toy_model_sae_runner,
)


# Toy setup: n_features=200 features in an n_hidden=5 space, each active with
# probability 0.025; the SAE has d_sae=240 latents.
cfg = SAEToyModelRunnerConfig(
    # Toy model
    n_features=200,
    n_hidden=5,
    n_correlated_pairs=0,
    n_anticorrelated_pairs=0,
    feature_probability=0.025,
    model_training_steps=10_000,
    # SAE shape and sparsity penalty
    d_sae=240,
    l1_coefficient=0.001,
    # SAE training
    train_batch_size=1028,  # NOTE(review): possibly meant 1024 — confirm.
    feature_sampling_window=3_000,
    dead_feature_window=1_000,
    feature_reinit_scale=0.5,
    total_training_tokens=1_228_800,  # 4096 * 300
    # Logging and device
    log_to_wandb=True,
    wandb_project="sae-training-test",
    wandb_log_frequency=5,
    device="mps",
)

trained_sae = toy_model_sae_runner(cfg)

assert trained_sae is not None

# Run caching of activations to disk

In [None]:
# Cache GPT-2 small activations (layer-10 attention queries, head 7) to disk so
# later SAE runs can train from the cache instead of running the model.
# The device is chosen at run time rather than hard-coded to "mps".
import torch
import os
import sys

sys.path.append("..")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_lens.training.config import CacheActivationsRunnerConfig
from sae_lens.training.cache_activations_runner import cache_activations_runner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

cfg = CacheActivationsRunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="gpt2-small",
    hook_point="blocks.10.attn.hook_q",
    hook_point_layer=10,
    hook_point_head_index=7,
    d_in=64,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    # NOTE(review): a training run with use_cached_activations=True must use the
    # same hook/d_in and read from this same path — confirm the reader's path.
    cached_activations_path="../activations/",
    # Activation Store Parameters
    n_batches_in_buffer=16,
    total_training_tokens=500_000_000,
    store_batch_size=32,
    # Activation caching shuffle parameters
    n_shuffles_final=16,
    # Misc
    device=device,
    seed=42,
    dtype=torch.float32,
)

cache_activations_runner(cfg)

## Train an SAE using the cached activations stored on disk
Pass `use_cached_activations=True` into the config so training reads activations from disk instead of running the model.

In [None]:
# Train an SAE from activations previously cached to disk.
# The device is chosen at run time rather than hard-coded to "mps".
import torch
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"
from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="gpt2-small",
    hook_point="blocks.10.hook_resid_pre",
    hook_point_layer=10,  # must match the layer named in hook_point ("blocks.10")
    d_in=768,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    # NOTE(review): cached activations must exist for this exact hook_point/d_in.
    # The caching cell in this notebook caches blocks.10.attn.hook_q (head 7,
    # d_in=64) into "../activations/", which does not match this config; align
    # the two (or set cached_activations_path) before running.
    use_cached_activations=True,
    # SAE Parameters
    expansion_factor=64,  # determines the dimension of the SAE.
    # Training Parameters
    lr=1e-5,
    l1_coefficient=5e-4,
    lr_scheduler_name=None,
    train_batch_size=4096,
    context_size=128,
    # Activation Store Parameters
    n_batches_in_buffer=64,
    total_training_tokens=200_000,
    store_batch_size=32,
    # Resampling protocol
    feature_sampling_method="l2",
    feature_sampling_window=1000,
    feature_reinit_scale=0.2,
    dead_feature_window=5000,
    dead_feature_threshold=1e-7,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_gpt2_small",
    wandb_entity=None,
    wandb_log_frequency=50,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

sparse_autoencoder = language_model_sae_runner(cfg)