In [1]:
try:
    # import google.colab # type: ignore
    # from google.colab import output
    %pip install sae-lens transformer-lens circuitsvis
except:
    from IPython import get_ipython  # type: ignore

    ipython = get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import torch
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# Choose the compute device in order of preference: CUDA, then Apple MPS, then CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print("Using device:", device)

# Switch off parallelism in the tokenizers library via its environment variable
# (presumably to avoid its fork-related warnings — confirm if this matters to you).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [3]:
from transformer_lens import HookedTransformer

# HookedTransformer wraps Hugging Face models and adds lots of nice utilities.
model = HookedTransformer.from_pretrained("tiny-stories-1L-21M")

Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


In [4]:
# Draw five independent samples from the same prompt and render each one.
sampling_options = dict(
    stop_at_eos=False,  # avoids a bug on MPS
    temperature=1,
    verbose=False,
    max_new_tokens=50,
)
for i in range(5):
    story = model.generate("Once upon a time", **sampling_options)
    display(story)

'Once upon a time there was a open harbor. The cove was so clean and cleaning. It was it ready to be used again, use the magic one to turn tiny crabs on the other. It was a fun idea, it decided to turn it on.\n\n'

'Once upon a time there were two people who lived in a big bush. The panda felt sad because life seemed so far away. He wanted everyone to hear him say things to him.\n\nOne day, he decided to do something hard. He went to the'

'Once upon a time, an owl named Millie had a silly idea. She wanted to care for the butterflies. So she got a stick and flew all around her head. Finally she found a big flower and flew up into the air!\n\nShe was so happy'

'Once upon a time in a big city, travels across one sunny day. On a holiday, Jack heard some strange noises. He looked around and started to walk away.\n\nHe walked and walked until he came to a big door. He jumped in and and he'

'Once upon a time, there was a little girl named Daisy. Daisy liked to play outside in the sun. One day, Daisy went outside to play games with her mommy and Daddy. Every spring, Daisy would get dressed and go outside.\n\nBut the weather'

In [5]:
from transformer_lens.utils import test_prompt

# Check how the model ranks a specific answer token as the continuation of a prompt.
story_prefix = "Once upon a time, there was a little girl named Lily. She lived in a big, happy little town. On her big adventure,"
expected_answer = " Lily"  # already carries its leading space, hence no prepending below
test_prompt(story_prefix, expected_answer, model, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' town', '.', ' On', ' her', ' big', ' adventure', ',']
Tokenized answer: [' Lily']


Top 0th token. Logit: 21.00 Prob: 76.18% Token: | she|
Top 1th token. Logit: 18.82 Prob:  8.62% Token: | Lily|
Top 2th token. Logit: 18.16 Prob:  4.45% Token: | there|
Top 3th token. Logit: 17.00 Prob:  1.39% Token: | the|
Top 4th token. Logit: 16.76 Prob:  1.10% Token: | her|
Top 5th token. Logit: 16.61 Prob:  0.94% Token: | all|
Top 6th token. Logit: 16.56 Prob:  0.90% Token: | everyone|
Top 7th token. Logit: 16.04 Prob:  0.53% Token: | things|
Top 8th token. Logit: 16.04 Prob:  0.53% Token: | they|
Top 9th token. Logit: 16.03 Prob:  0.53% Token: | people|


In [6]:
import circuitsvis as cv  # optional dep, install with pip install circuitsvis

# Let's make a longer prompt and see the log probabilities of the tokens
example_prompt = """Hi, how are you doing this? I'm really enjoying your posts"""

# A single forward pass returns both the logits and the activation cache, so the
# logits are reused for the visualisation instead of running the model a second time.
logits, cache = model.run_with_cache(example_prompt)
cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    logits[0].log_softmax(dim=-1),  # drop the batch dimension, then normalise over the vocab
    model.to_string,
)

In [7]:
# Sample a story (200 new tokens) from the model and use it as the prompt to inspect.
example_prompt = model.generate(
    "Once upon a time",
    stop_at_eos=False,  # avoids a bug on MPS
    temperature=1,
    verbose=True,
    max_new_tokens=200,
)

# A single forward pass returns both the logits and the activation cache, so the
# logits are reused for the visualisation instead of running the model a second time.
logits, cache = model.run_with_cache(example_prompt)
cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    logits[0].log_softmax(dim=-1),  # drop the batch dimension, then normalise over the vocab
    model.to_string,
)

100%|██████████| 200/200 [00:01<00:00, 175.43it/s]


In [None]:
# --- Training budget -------------------------------------------------------
total_training_steps = 30_000  # probably we should do more
batch_size = 4096  # tokens per SAE training step
total_training_tokens = total_training_steps * batch_size  # 122,880,000 tokens

# --- Schedules (all expressed in training steps) ---------------------------
lr_warm_up_steps = 0  # no learning-rate warm-up
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

# Data generating function (model + training distribution).
model_and_data = dict(
    # Our model; more options here:
    # https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html
    model_name="tiny-stories-1L-21M",
    # A valid hook point; see more details here:
    # https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points
    hook_name="blocks.0.hook_mlp_out",
    hook_layer=0,  # only one layer in the model
    d_in=1024,  # the width of the MLP output
    # A tokenized language dataset on Hugging Face for the Tiny Stories corpus.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small
)

# SAE parameters.
sae_architecture = dict(
    mse_loss_normalization=None,  # we won't normalize the MSE loss
    expansion_factor=16,  # SAE width multiplier; larger gives better stats but slower training
    b_dec_init_method="zeros",  # start b_dec at zero (the geometric median is an alternative)
    apply_b_dec_to_input=False,  # b_dec is not applied to the input
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
)

# Training parameters.
optimisation = dict(
    lr=5e-5,  # lower is better; we go fairly high to speed up the tutorial
    adam_beta1=0.9,  # Adam params (defaults, but once upon a time we experimented with these)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant base schedule; could be better schedules out there
    lr_warm_up_steps=lr_warm_up_steps,  # warm-up can help avoid too many dead features initially
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting
    l1_coefficient=5,  # controls how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially
    lp_norm=1.0,  # the L1 penalty (and not an Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # length of the prompts we feed to the model; larger is better but slower
)

# Activation store parameters.
activation_store = dict(
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle
    training_tokens=total_training_tokens,  # ~123M tokens: get a coffee, come back
    store_batch_size_prompts=16,
)

# Resampling protocol.
dead_feature_tracking = dict(
    use_ghost_grads=False,  # we don't use ghost grads anymore
    feature_sampling_window=1000,  # controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using it
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using it
)

# Weights & Biases logging.
wandb_logging = dict(
    log_to_wandb=True,  # always use wandb unless you are just testing code
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
)

# Misc.
run_settings = dict(
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32",
)

cfg = LanguageModelSAERunnerConfig(
    **model_and_data,
    **sae_architecture,
    **optimisation,
    **activation_store,
    **dead_feature_tracking,
    **wandb_logging,
    **run_settings,
)

# Look at the next cell to see some instructions for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()


Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33melianaive[0m ([33melianaive-research[0m). Use [1m`wandb login --relogin`[0m to force relogin


Training SAE:   0%|          | 0/122880000 [00:00<?, ?it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:34<00:00, 29.08it/s]
19500| l1_loss: 162.81946 | mse_loss

InterruptedException: 

19500| l1_loss: 162.81946 | mse_loss: 204.06320:  65%|██████▌   | 79872000/122880000 [59:03<43:00, 16665.98it/s]

In [10]:
# --- Training budget -------------------------------------------------------
total_training_steps = 30_000  # probably we should do more
batch_size = 4096  # tokens per SAE training step
total_training_tokens = total_training_steps * batch_size  # 122,880,000 tokens

# --- Schedules (all expressed in training steps) ---------------------------
lr_warm_up_steps = 0  # no learning-rate warm-up
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = 10_000  # a third of training (was total_training_steps // 20, i.e. 5%)

# Data generating function (model + training distribution).
model_and_data = dict(
    architecture="gated",  # we'll use the gated variant
    # Our model; more options here:
    # https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html
    model_name="tiny-stories-1L-21M",
    # A valid hook point; see more details here:
    # https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points
    hook_name="blocks.0.hook_mlp_out",
    hook_layer=0,  # only one layer in the model
    d_in=1024,  # the width of the MLP output
    # A tokenized language dataset on Hugging Face for the Tiny Stories corpus.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small
)

# SAE parameters.
sae_architecture = dict(
    mse_loss_normalization=None,  # we won't normalize the MSE loss
    expansion_factor=16,  # SAE width multiplier; larger gives better stats but slower training
    b_dec_init_method="zeros",  # start b_dec at zero (the geometric median is an alternative)
    apply_b_dec_to_input=True,  # in this run b_dec IS applied to the input
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=False,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
)

# Training parameters.
optimisation = dict(
    lr=5e-5,  # lower is better; we go fairly high to speed up the tutorial
    adam_beta1=0.9,  # Adam params (defaults, but once upon a time we experimented with these)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant base schedule; could be better schedules out there
    lr_warm_up_steps=lr_warm_up_steps,  # warm-up can help avoid too many dead features initially
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting
    l1_coefficient=20,  # controls how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially
    lp_norm=1.0,  # the L1 penalty (and not an Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=256,  # length of the prompts we feed to the model; larger is better but slower
)

# Activation store parameters.
activation_store = dict(
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle
    training_tokens=total_training_tokens,  # ~123M tokens: get a coffee, come back
    store_batch_size_prompts=16,
)

# Resampling protocol.
dead_feature_tracking = dict(
    use_ghost_grads=False,  # we don't use ghost grads anymore
    feature_sampling_window=1000,  # controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using it
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using it
)

# Weights & Biases logging.
wandb_logging = dict(
    log_to_wandb=True,  # always use wandb unless you are just testing code
    wandb_project="gated_sae_testing",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
)

# Misc.
run_settings = dict(
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32",
)

cfg = LanguageModelSAERunnerConfig(
    **model_and_data,
    **sae_architecture,
    **optimisation,
    **activation_store,
    **dead_feature_tracking,
    **wandb_logging,
    **run_settings,
)

# Look at the next cell to see some instructions for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()

Loaded pretrained model tiny-stories-1L-21M into HookedTransformer




0,1
details/current_l1_coefficient,▁▂▃▄▅███████████████████████████████████
details/current_learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/n_training_tokens,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▇▇▇▇▇▇███
losses/l1_loss,▁█▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
losses/mse_loss,█▂▁▁▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,▁▃▃▆█▆▆▆▅▅▅▅▅▅▅▅▅▅▄▅▄▄▄▄▄▄▄▄▄▃▄▃▄▃▃▃▄▃▃▃
losses/raw_l1_loss,▁▃▆▇████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
metrics/explained_variance,█▃▁▁▁▂▂▂▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▅▆▅▅▆▆▆▆▆▆▆▆▆▆▆▆
metrics/explained_variance_std,▁▂███▇▇▇█▇█▆▇▇▇▇▆▆▆▆▆▆▅▅▅▆▆▆▆▆▅▅▆▆▆▅▅▅▅▆
metrics/l0,█▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
details/current_l1_coefficient,5.0
details/current_learning_rate,5e-05
details/n_training_tokens,79872000.0
losses/l1_loss,32.56389
losses/mse_loss,204.0632
losses/overall_loss,366.88266
losses/raw_l1_loss,162.81946
metrics/explained_variance,0.65712
metrics/explained_variance_std,0.1333
metrics/l0,219.43701



[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A
[A
[A
[A
[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A
[A
[A
[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A
[A
[A
[A
[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A
[A
[A
[A
  lambda data: self._console_raw_callback("stderr", data),
19500| l1_loss: 162.81946 | mse_loss: 204.06320:  65%|██████▌   | 79872000/122880000 [1:00:28<32:33, 22013.08it/s]
                                                                 
[A
[A
                                                                 
[A
[A
                                                                 

0,1
details/current_l1_coefficient,▁▂▃▃▄▆▆▆▆▆▇▇████████████████████████████
details/current_learning_rate,██████████████████████████████████████▃▁
details/n_training_tokens,▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇██
losses/auxiliary_reconstruction_loss,▁▁▃▃▄▅▅▅▅▅▆▆▆▆▇▇████████████████████████
losses/l1_loss,▅█▆▄▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,▁▁▁▁▂▅▅▅▆▆▆▆▇▇▇█████████████████████████
losses/overall_loss,▁▁▂▃▅▆▆▆▆▆▇█████████████████████████████
losses/raw_l1_loss,▁▁▄▅▅▇▇▇▇▇████▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
metrics/explained_variance,▁██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
metrics/explained_variance_std,▂▁▁▃▃▄▄▅▅▅▆▆▆▆▆▇▆▇▇▇▇▇█▇▇▇▇███▇██▇██████

0,1
details/current_l1_coefficient,20.0
details/current_learning_rate,0.0
details/n_training_tokens,122880000.0
losses/auxiliary_reconstruction_loss,368.92831
losses/l1_loss,16.46354
losses/mse_loss,215.87946
losses/overall_loss,914.07861
losses/raw_l1_loss,329.27087
metrics/explained_variance,0.63369
metrics/explained_variance_std,0.14129
