In [1]:
import torch
import os
import sys

sys.path.append("..")

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU.
# The MPS probe is only evaluated when CUDA is unavailable.
device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)

print("Using device:", device)
# Turn off tokenizer-level parallelism via the environment variable.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


In [2]:
from transformer_lens import HookedTransformer

# Checkpoint used throughout the notebook.
# A larger alternative that was tried here: "EleutherAI/pythia-2.8b".
model_name = "tiny-stories-1L-21M"

# HookedTransformer wraps Hugging Face models and ships with lots of handy utilities.
model = HookedTransformer.from_pretrained(model_name)

  return self.fget.__get__(instance, owner)()


Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


In [3]:
from transformer_lens.utils import test_prompt

# Sanity-check the model: print its top next-token predictions for a story prompt
# and how the expected answer " Lily" ranks among them.
test_prompt(
    "Once upon a time, there was a little girl named Lily. "
    "She lived in a big, happy little girl. On her big adventure,",
    " Lily",
    model,
    prepend_space_to_answer=False,  # the answer string already starts with a space
)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']
Tokenized answer: [' Lily']


Top 0th token. Logit: 20.48 Prob: 71.06% Token: | she|
Top 1th token. Logit: 18.81 Prob: 13.46% Token: | Lily|
Top 2th token. Logit: 17.35 Prob:  3.11% Token: | the|
Top 3th token. Logit: 17.26 Prob:  2.86% Token: | her|
Top 4th token. Logit: 16.74 Prob:  1.70% Token: | there|
Top 5th token. Logit: 16.43 Prob:  1.25% Token: | they|
Top 6th token. Logit: 15.80 Prob:  0.66% Token: | all|
Top 7th token. Logit: 15.64 Prob:  0.56% Token: | things|
Top 8th token. Logit: 15.28 Prob:  0.39% Token: | one|
Top 9th token. Logit: 15.24 Prob:  0.38% Token: | lived|


In [4]:
import circuitsvis as cv  # optional dep, install with pip install circuitsvis

# Score a longer prompt and visualise the log-probability assigned to each token.
example_prompt = """Hi, how are you doing this? I'm really enjoying your posts"""

# One cached forward pass already returns the logits, so reuse them below
# instead of running the model on the same prompt a second time.
logits, cache = model.run_with_cache(example_prompt)
cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    logits[0].log_softmax(dim=-1),  # first (only) batch element
    model.to_string,
)
# hover on the output to see the result.

In [5]:
# Sample a continuation from the model, then visualise how probable the model
# finds each token of its own text.
# NOTE: this cell samples at temperature 1 and sets no seed, so the generated
# text can differ between runs.
example_prompt = model.generate(
    "Once upon a time",
    stop_at_eos=False,  # avoids a bug on MPS
    temperature=1,
    verbose=True,
    max_new_tokens=200,
)

# One cached forward pass already returns the logits, so reuse them below
# instead of running the model on the same text a second time.
logits, cache = model.run_with_cache(example_prompt)
cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    logits[0].log_softmax(dim=-1),  # first (only) batch element
    model.to_string,
)

  0%|          | 0/200 [00:00<?, ?it/s]

## Training

In [7]:
# Training budget: tokens consumed per optimiser step and number of steps.
batch_size = 4096
total_training_steps = 30_000  # probably we should do more
total_training_tokens = batch_size * total_training_steps

# Schedule lengths, expressed as whole-step fractions of the full run.
lr_warm_up_steps = 0  # no learning-rate warm-up
l1_warm_up_steps = total_training_steps // 20  # 5% of training
lr_decay_steps = total_training_steps // 5  # 20% of training

In [8]:
# Training configuration for the first run in this notebook: a gated SAE on the
# MLP output of the one-layer TinyStories model.
# NOTE(review): the recorded run name below this cell shows "L1-0.01", which does
# not match l1_coefficient=3e-5 here — the saved output appears to predate this code.
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name=model_name,  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_point="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_point_layer=0,  # Only one layer in the model.
    d_in=1024,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True, # stream the dataset rather than downloading it up front (pre-downloading would be an option for a small dataset).
    
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=32,  # SAE width as a multiple of d_in (the recorded run name starts with 32768 = 32 * 1024). Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # start b_dec (presumably the decoder bias) at zero; the original note mentions the geometric median as an alternative initialiser.
    apply_b_dec_to_input=False,  # b_dec is not applied to the SAE input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,  # NOTE(review): with gated=True the recorded run output warns this is "not implemented for Gated SA[E]" — confirm the effective setting.
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations=False,
    gated=True,  # train the gated variant; the second config in this notebook uses gated=False.
    
    # Training Parameters
    lr=1e-4,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.0,# adam params. NOTE(review): the original comment called these defaults, but the second config in this notebook uses beta1=0.9 — confirm which is the library default.
    adam_beta2=0.999, 
    lr_scheduler_name="constant",  # constant schedule; warm-up and decay lengths come from the next two arguments. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # set to 0 earlier in this notebook, i.e. no warm-up here (warm-up can help avoid too many dead features initially).
    lr_decay_steps=lr_decay_steps,  # total_training_steps // 5, i.e. 20% of the run; presumably applied at the end of training — confirm.
    l1_coefficient=3e-5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # total_training_steps // 20, i.e. 5% of the run; this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=256,  # will control the length of the prompts we feed to the model. Larger is better but slower, so for the tutorial we'll use a short one.
    
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 30_000 steps * 4096 tokens = 122.88 million tokens; that is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,  # NOTE(review): the second config in this notebook spells this `store_batch_size` — confirm which name the installed sae_lens version accepts.
    
    # Resampling protocol
    use_ghost_grads=False, # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using it.
    
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="sae_tinystories_1l",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

Run name: 32768-L1-0.01-LR-0.0001-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 1048.576
We will reset the sparsity calculation 30 times.
Number tokens in sparsity calculation window: 4.10e+06


In [9]:
# Launch SAE training with the config built in the previous cell (the recorded
# output shows a long-running progress bar that was interrupted part-way).
# NOTE(review): the original comment pointed to "the next cell" for instructions on
# what to do while this is running, but the next cell defines a second config —
# confirm whether those instructions were removed.
sparse_autoencoder_dictionary = language_model_sae_runner(cfg)

Loaded pretrained model tiny-stories-1L-21M into HookedTransformer
Moving model to device:  cuda
Run name: 32768-L1-0.01-LR-0.0001-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 1048.576
We will reset the sparsity calculation 30 times.
Number tokens in sparsity calculation window: 4.10e+06
Run name: 32768-L1-0.01-LR-0.0001-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 1048.576
We will reset the sparsity calculation 30 times.
Number tokens in sparsity calculation window: 4.10e+06
Scale sparsity penalty by decoder norm not implemented for Gated SA

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdavide-ghilardi0[0m. Use [1m`wandb login --relogin`[0m to force relogin


19400| MSE Loss 22.017 | L1 0.401:  65%|██████▍   | 79867904/122880000 [45:10<22:55, 31273.11it/s]  

interrupted, saving progress


19400| MSE Loss 22.017 | L1 0.401:  65%|██████▍   | 79867904/122880000 [45:20<22:55, 31273.11it/s]

done saving


InterruptedException: 

In [9]:
# Training configuration for the second run in this notebook: a standard
# (non-gated) SAE on the same hook point. This rebinds `cfg` from the gated config above.
# NOTE(review): the recorded run name and wandb summary below show an L1 coefficient
# of 5, not the 1e-3 set here — the saved output appears to predate this code.
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="tiny-stories-1L-21M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_point="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_point_layer=0,  # Only one layer in the model.
    d_in=1024,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True, # stream the dataset rather than downloading it up front (pre-downloading would be an option for a small dataset).
    
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # SAE width as a multiple of d_in (the recorded run name starts with 16384 = 16 * 1024). Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # start b_dec (presumably the decoder bias) at zero; the original note mentions the geometric median as an alternative initialiser.
    apply_b_dec_to_input=False,  # b_dec is not applied to the SAE input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations=False,
    gated=False,  # standard SAE; the first config in this notebook uses gated=True.
    
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,# adam params. NOTE(review): the first config in this notebook uses beta1=0.0 — confirm which is the library default.
    adam_beta2=0.999, 
    lr_scheduler_name="constant",  # constant schedule; warm-up and decay lengths come from the next two arguments. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # set to 0 earlier in this notebook, i.e. no warm-up here (warm-up can help avoid too many dead features initially).
    lr_decay_steps=lr_decay_steps,  # total_training_steps // 5, i.e. 20% of the run; the recorded learning-rate sparkline drops towards the end of training.
    l1_coefficient=1e-3,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # total_training_steps // 20, i.e. 5% of the run; this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=256,  # will control the length of the prompts we feed to the model. Larger is better but slower, so for the tutorial we'll use a short one.
    
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 30_000 steps * 4096 tokens = 122.88 million tokens; that is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size=16,  # NOTE(review): the first config in this notebook spells this `store_batch_size_prompts` — confirm which name the installed sae_lens version accepts.
    
    # Resampling protocol
    use_ghost_grads=False, # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using it.
    
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="sae_tinystories_1l",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 1048.576
We will reset the sparsity calculation 30 times.
Number tokens in sparsity calculation window: 4.10e+06


In [10]:
# Launch SAE training with the standard (non-gated) config from the previous cell.
# This rebinds `sparse_autoencoder_dictionary`, replacing the result of the earlier gated run.
sparse_autoencoder_dictionary = language_model_sae_runner(cfg)

Loaded pretrained model tiny-stories-1L-21M into HookedTransformer
Moving model to device:  cuda
Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 1048.576
We will reset the sparsity calculation 30 times.
Number tokens in sparsity calculation window: 4.10e+06
Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 1048.576
We will reset the sparsity calculation 30 times.
Number tokens in sparsity calculation window: 4.10e+06


  lambda data: self._console_raw_callback("stderr", data),
30000| MSE Loss 79.977 | L1 35.866: 100%|██████████| 122880000/122880000 [58:07<00:00, 35238.11it/s]


VBox(children=(Label(value='115.934 MB of 128.212 MB uploaded\r'), FloatProgress(value=0.9042385622059216, max…

0,1
details/current_l1_coefficient,▁▅██████████████████████████████████████
details/current_learning_rate,████████████████████████████████▇▇▅▅▄▃▂▁
details/n_training_tokens,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,▄▄█▇▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁
losses/overall_loss,▁▅█▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
metrics/CE_loss_score,█▃▁▂▂▃▂▃▃▃▄▃▅▄▄▅▄▅▅▅▆▅▅▅▆▆▇▆▆▆▆▆▆▆▅▇▇▆▆▆
metrics/ce_loss_with_ablation,▅▃▁▆▄▃▅▆▅▄▄▁▅▃▄▃▃▄▁▃▄▆▄▄▄▃▆▄▃█▄▄▁▄▄▅█▄▄▃
metrics/ce_loss_with_sae,▁▃▆█▅▃▅▅▄▅▃▅▃▃▄▃▄▃▂▃▂▃▃▂▂▂▃▂▃▄▂▃▂▃▅▃▁▂▂▁

0,1
details/current_l1_coefficient,5.0
details/current_learning_rate,0.0
details/n_training_tokens,122880000.0
losses/ghost_grad_loss,0.0
losses/l1_loss,12.63448
losses/mse_loss,70.25092
losses/overall_loss,133.42332
metrics/CE_loss_score,0.86972
metrics/ce_loss_with_ablation,8.29373
metrics/ce_loss_with_sae,2.73045


### Standard vs. Gated SAEs

In [9]:
from sae_lens import LMSparseAutoencoderSessionloader
from huggingface_hub import snapshot_download
import os

# Fetch the repository that holds both pretrained SAEs.
REPO_ID = "ghidav/tiny-stories-1L-21M-saes"
path = snapshot_download(repo_id=REPO_ID)

# Standard SAE: load it, switch to eval mode, and pull out the named autoencoder.
standard_dir = os.path.join(path, 'standard')
model, standard_sae, activation_store = LMSparseAutoencoderSessionloader.load_pretrained_sae(
    path=standard_dir,
    device=device,
)
standard_sae.eval()
ssae = standard_sae['standard']

# Gated SAE: same procedure. `model` and `activation_store` are rebound to the
# objects returned by this second load.
gated_dir = os.path.join(path, 'gated')
model, gated_sae, activation_store = LMSparseAutoencoderSessionloader.load_pretrained_sae(
    path=gated_dir,
    device=device,
)
gated_sae.eval()
gsae = gated_sae['gated']

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

Loaded pretrained model tiny-stories-1L-21M into HookedTransformer
Moving model to device:  cuda
Scale sparsity penalty by decoder norm not implemented for Gated SAE. Setting it to standard...
Scale sparsity penalty by decoder norm not implemented for Gated SAE. Setting it to standard...
Loaded pretrained model tiny-stories-1L-21M into HookedTransformer
Moving model to device:  cuda
Scale sparsity penalty by decoder norm not implemented for Gated SAE. Setting it to standard...


In [22]:
import plotly_express as px

def get_l0_dist(sae, key, batch_tokens=None):
    """Print the mean L0 of one SAE and return a histogram of its per-token L0.

    L0 here is the number of SAE features with a strictly positive activation
    at a token position.

    Relies on the notebook-level globals `model`, `activation_store` and `px`.

    Args:
        sae: container of SAEs; it is indexed with `key` to obtain the
            autoencoder and must expose `cfg.hook_point`.
        key: name of the autoencoder inside `sae` (e.g. 'standard' or 'gated').
        batch_tokens: optional token batch to evaluate on. When omitted, a
            batch is drawn from `activation_store` (the original behaviour).
            Pass the same batch to several calls to compare SAEs on identical
            tokens.

    Returns:
        A plotly express histogram figure of the flattened L0 values.
    """
    with torch.no_grad():
        if batch_tokens is None:
            # activation store can give us tokens.
            batch_tokens = activation_store.get_batch_tokens()
        _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

        # Run the SAE on the hooked activations; only the feature activations
        # (second element of the output) are needed here.
        _, feature_acts, *_ = sae[key](cache[sae.cfg.hook_point])

        # save some room
        del cache

        # Ignore the first position (the BOS token) and count, for every
        # remaining token, how many features are active.
        l0 = (feature_acts[:, 1:] > 0).float().sum(-1)
        print(f"{key} average l0", l0.mean().item())
        return px.histogram(l0.flatten().cpu().numpy(), title="L0 norm distribution")

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Compare the L0 distributions of the standard and gated SAEs in stacked histograms.
# NOTE: each get_l0_dist call fetches its own batch from the activation store, so
# the two histograms are presumably computed on different tokens.
standard_l0_dist = get_l0_dist(standard_sae, 'standard')
gated_l0_dist = get_l0_dist(gated_sae, 'gated')

# Title each panel so the two histograms can be told apart.
fig = make_subplots(
    rows=2,
    cols=1,
    subplot_titles=("Standard SAE", "Gated SAE"),
)

# Add traces: each plotly express figure holds one histogram trace.
fig.add_trace(standard_l0_dist.data[0], row=1, col=1)
fig.add_trace(gated_l0_dist.data[0], row=2, col=1)

# Label the axes so the figure stands on its own.
fig.update_xaxes(title_text="L0 (active features per token)")
fig.update_yaxes(title_text="Number of tokens")
fig.update_layout(height=800, width=800, title_text="L0 Dist Subplots")
fig.show()

## Fine-tuning

## Projects

1. Transcoders
2. No Toy Models --> Useful research (work on LLMs)

### Use interp to get something that matters
* Steering vectors for safety (fine-tune an SAE on a safety dataset...)
* Better interface to interact with chat models (other than prompting... for example `to generate json`, `to be concise`, ...)
* Early exiting with SAEs on known tasks.
* Interesting circuit analysis projects on 7B models (more complex than IOI) — pretty ambitious, and requires SAEs (example: a refusal circuit)
    * Compare to other techniques (FT, steering vectors)

**SAEs projects**
* SAEs to improve steering vectors (SVs)
* Arthur project
* SHIFT for early exiting.
* Address ELK with SAEs.
* Red teaming (jailbreaks)
* Do SAEs transfer to fine-tuned models? (If not, what does the difference look like? Is it sparse?)
* SAEs on safety-tuned models.