```bash
conda create -n sae-lens python=3.10
conda activate sae-lens
pip install -e .
pip install ipywidgets

ipython kernel install --name "sae-lens" --user
```


In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]='7'

import random
import numpy as np
import torch
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# Choose the compute device in order of preference: CUDA, then Apple MPS,
# and finally the CPU when no accelerator is reported as available.
device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
print("Using device:", device)

# Set TOKENIZERS_PARALLELISM to "false" for this process.
# NOTE(review): presumably this silences the Hugging Face tokenizers
# fork/parallelism warning -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


In [2]:
def reproducibility(seed=0x1234_5678_9ABC_DEF0):
    """Seed every RNG used by this notebook and disable known nondeterminism.

    Seeds PyTorch (CPU and all CUDA devices), Python's ``random`` module and
    NumPy's legacy global ``RandomState``, and configures cuDNN to avoid
    nondeterministic algorithm selection.

    Args:
        seed: Integer seed. Defaults to the value this notebook has always
            used, so calling ``reproducibility()`` with no arguments seeds
            the RNGs exactly as before.
    """
    # I have not in general attempted to verify that the below are necessary
    # for reproducibility, only that they are likely to help and unlikely to
    # hurt.
    # https://pytorch.org/docs/stable/notes/randomness.html#reproducibility
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable cuDNN auto-tuning (it may pick different algorithms from run to
    # run) and ask cuDNN to use deterministic algorithms, as recommended
    # together in the PyTorch randomness notes linked above.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # Python native RNG; docs don't give any limitations on seed range
    random.seed(seed)
    # this is a "legacy" method that operates on a global RandomState
    # sounds like the argument must be in [0, 2**32), so keep the low 32 bits
    np.random.seed(seed & 0xFFFF_FFFF)

# Run the code below to generate the ground truth (full training without interruption)


In [3]:
# Re-seed the torch, random and numpy RNGs before launching the full run.
reproducibility()

In [4]:
# Configure and launch the reference run (wandb_id="fullrun"): training from start to finish in one go.
total_training_steps = 1_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size  # 1_000 * 4096 = 4_096_000 tokens

lr_warm_up_steps = 0  # no learning-rate warm-up
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="tiny-stories-1M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=0,  # layer index of the hook point; matches `blocks.0` in hook_name.
    d_in=64,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # b_dec is initialized to zeros here; the geometric median is an alternative initialization.
    apply_b_dec_to_input=False,  # We won't apply b_dec to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate (warm-up is disabled here: lr_warm_up_steps = 0). Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # will control the length of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 4_096_000 tokens in total (total_training_steps * batch_size).
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using it.
    # WANDB
    log_to_wandb=False,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=10, # (FOR CHECKPOINTING) set this
    resume=True, # (FOR CHECKPOINTING) set this
    wandb_id = "fullrun", # (FOR CHECKPOINTING) set this to some value to report to the same wandb experiment
    checkpoint_path="checkpoints",
    dtype="float32",
)
# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()

Loaded pretrained model tiny-stories-1M into HookedTransformer


Estimating norm scaling factor:   0%|          | 0/1000 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Training SAE:   0%|          | 0/4096000 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

# Run the code below to resume from checkpoint


In [5]:
import shutil

# Start the "test" run from an empty directory: drop whatever a previous
# attempt left behind, then recreate the folder.
test_run_dir = "checkpoints/test/"
if os.path.exists(test_run_dir):
    shutil.rmtree(test_run_dir)
os.makedirs(test_run_dir)

# Copy one intermediate checkpoint of the full run into the test run so the
# resumed training has something to load. Kept as the last expression so the
# cell still displays the destination path.
shutil.copytree("checkpoints/fullrun/1232896", "checkpoints/test/1232896")

'checkpoints/test/1232896'

In [6]:
# Re-seed the torch, random and numpy RNGs again so the resumed run starts from the same seeds as the full run.
reproducibility()

In [7]:
# Same configuration as the full run except wandb_id="test"; the logged output of this run
# shows it loading the checkpoint from checkpoints/test/1232896 before training continues.
total_training_steps = 1_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size  # 1_000 * 4096 = 4_096_000 tokens

lr_warm_up_steps = 0  # no learning-rate warm-up
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="tiny-stories-1M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=0,  # layer index of the hook point; matches `blocks.0` in hook_name.
    d_in=64,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # b_dec is initialized to zeros here; the geometric median is an alternative initialization.
    apply_b_dec_to_input=False,  # We won't apply b_dec to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate (warm-up is disabled here: lr_warm_up_steps = 0). Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # will control the length of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 4_096_000 tokens in total (total_training_steps * batch_size).
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using it.
    # WANDB
    log_to_wandb=False,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=10, # (FOR CHECKPOINTING) set this
    resume=True, # (FOR CHECKPOINTING) set this
    wandb_id = "test", # (FOR CHECKPOINTING) set this to some value to report to the same wandb experiment
    checkpoint_path="checkpoints",
    dtype="float32",
)
# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()

Loaded pretrained model tiny-stories-1M into HookedTransformer


Estimating norm scaling factor:   0%|          | 0/1000 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Loading checkpoint from checkpoints/test/1232896
Replaying batches (next_batch() only)


Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Done loading checkpoint


Training SAE:  30%|###       | 1232896/4096000 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

Refilling buffer:   0%|          | 0/32 [00:00<?, ?it/s]

# Check if the final checkpoints are the same


In [8]:
from safetensors.torch import load_file

# Load every artifact of the final checkpoint twice: once from the
# uninterrupted run ("target") and once from the resumed run ("test").
# Each pair is loaded target-first, then test.

# sparsity.safetensors
sparsity_target, sparsity_test = (
    load_file("checkpoints/fullrun/final_4096000/sparsity.safetensors"),
    load_file("checkpoints/test/final_4096000/sparsity.safetensors"),
)

# sae_weights.safetensors
sae_weights_target, sae_weights_test = (
    load_file("checkpoints/fullrun/final_4096000/sae_weights.safetensors"),
    load_file("checkpoints/test/final_4096000/sae_weights.safetensors"),
)

# activations_store_state.safetensors
activation_store_state_target, activation_store_state_test = (
    load_file("checkpoints/fullrun/final_4096000/activations_store_state.safetensors"),
    load_file("checkpoints/test/final_4096000/activations_store_state.safetensors"),
)

# trainer_state.pt is read with weights_only=False, i.e. full unpickling.
# NOTE(review): that executes arbitrary pickled code, so only point these
# paths at checkpoints written by your own runs.
trainer_state_target, trainer_state_test = (
    torch.load("checkpoints/fullrun/final_4096000/trainer_state.pt", weights_only=False, map_location="cpu"),
    torch.load("checkpoints/test/final_4096000/trainer_state.pt", weights_only=False, map_location="cpu"),
)

In [9]:
def _states_equal(a, b):
    """Return True when two (possibly nested) checkpoint values are exactly equal.

    Tensors must match in shape, dtype and every element; NumPy arrays are
    compared with ``np.array_equal``; dicts, lists and tuples are compared
    recursively; anything else falls back to ``==``.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and a.shape == b.shape
            and a.dtype == b.dtype
            and torch.equal(a, b)
        )
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        return (
            isinstance(a, np.ndarray)
            and isinstance(b, np.ndarray)
            and np.array_equal(a, b)
        )
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(_states_equal(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(_states_equal(x, y) for x, y in zip(a, b))
    return bool(a == b)


# Compare every trainer-state entry exactly. Comparing the printed
# representations is not enough: tensor reprs are elided with "..." and
# rounded, so two different tensors can print identically.
for key in trainer_state_target.keys():
    print(key, _states_equal(trainer_state_target[key], trainer_state_test[key]))
    # The truncated reprs are printed only for visual inspection.
    print(str(trainer_state_target[key])[:1000])
    print(str(trainer_state_test[key])[:1000])

n_training_steps True
1000
1000
n_training_tokens True
4096000
4096000
act_freq_scores True
tensor([  1.,   0., 615.,  ...,   0.,   1., 773.])
tensor([  1.,   0., 615.,  ...,   0.,   1., 773.])
n_forward_passes_since_fired True
tensor([0., 4., 0.,  ..., 1., 0., 0.])
tensor([0., 4., 0.,  ..., 1., 0., 0.])
n_frac_active_tokens True
4096
4096
optimizer True
{'state': {0: {'step': tensor(1000.), 'exp_avg': tensor([6.2866e-06, 4.0250e-06, 3.4841e-03,  ..., 8.4443e-06, 2.0037e-05,
        8.2403e-03]), 'exp_avg_sq': tensor([3.0301e-06, 1.7676e-06, 1.1245e-05,  ..., 8.2309e-06, 8.9528e-06,
        7.2805e-05])}, 1: {'step': tensor(1000.), 'exp_avg': tensor([[ 6.1260e-07, -2.7279e-07,  5.0299e-07,  ..., -1.7481e-07,
          1.3266e-07,  1.2698e-07],
        [ 1.2333e-06, -4.1654e-09,  1.5437e-06,  ...,  1.8267e-07,
         -1.8460e-07, -3.5199e-07],
        [-1.0503e-05, -3.9118e-05, -4.3038e-04,  ...,  8.0830e-03,
         -4.7457e-04, -9.0311e-04],
        ...,
        [ 2.8237e-07, -6.39

In [10]:
# Compare every SAE weight tensor exactly with torch.equal (same shape and
# same elements). Comparing printed representations is not enough: tensor
# reprs are elided with "..." and rounded, so different weights can print
# identically.
for key in sae_weights_target.keys():
    target_tensor = sae_weights_target[key]
    test_tensor = sae_weights_test[key]
    print(key, torch.equal(target_tensor, test_tensor))
    # The truncated reprs are printed only for visual inspection.
    print(str(target_tensor)[:1000])
    print(str(test_tensor)[:1000])

W_dec True
tensor([[ 4.5674e-04,  7.3088e-04,  1.4049e-03,  ...,  8.4608e-04,
          1.3882e-03,  1.8400e-03],
        [-3.4472e-04,  1.0334e-03,  1.0491e-04,  ...,  1.5954e-03,
          9.7445e-04,  1.1271e-03],
        [-1.4905e-03,  5.6403e-04, -1.2503e-03,  ..., -5.2135e-03,
         -1.9765e-03, -3.2476e-03],
        ...,
        [ 1.2463e-03,  1.8405e-03,  1.5924e-03,  ...,  1.2724e-03,
          3.1043e-05, -2.2264e-05],
        [-3.3576e-04,  1.2644e-04,  9.0307e-04,  ...,  7.8625e-04,
          1.2375e-03,  1.5447e-03],
        [ 1.1262e-03, -1.5962e-03,  1.0016e-03,  ...,  2.8307e-03,
         -4.0579e-03, -2.2712e-03]])
tensor([[ 4.5674e-04,  7.3088e-04,  1.4049e-03,  ...,  8.4608e-04,
          1.3882e-03,  1.8400e-03],
        [-3.4472e-04,  1.0334e-03,  1.0491e-04,  ...,  1.5954e-03,
          9.7445e-04,  1.1271e-03],
        [-1.4905e-03,  5.6403e-04, -1.2503e-03,  ..., -5.2135e-03,
         -1.9765e-03, -3.2476e-03],
        ...,
        [ 1.2463e-03,  1.8405e-03, 