# Relicating How We Train SAEs

In [None]:
%load_ext autoreload

In [92]:
%autoreload 2
import torch
import os
import sys

sys.path.append("..")

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: mps


In [None]:
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gelu-1")

Let's use test driven design. I'm going to put in the config args that should make the library replicate the SAE training result and then work backward from there. 

In [102]:
total_training_steps = 20_000
batch_size = 4096
total_training_tokens = total_training_steps * batch_size
print(f"Total Training Tokens: {total_training_tokens}")

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5 # 20% of training steps.
print(f"lr_decay_steps: {lr_decay_steps}")
l1_warmup_steps = total_training_steps // 20 # 5% of training steps.
print(f"l1_warmup_steps: {l1_warmup_steps}")

cfg = LanguageModelSAERunnerConfig(
    
    # Pick a tiny model to make this easier.
    model_name="gelu-1l", 
    
    ## MLP Layer 0 ##
    hook_point="blocks.0.hook_mlp_out",
    hook_point_layer=0,
    d_in=512,
    dataset_path="NeelNanda/c4-tokenized-2b",
    context_size=1024,
    is_dataset_tokenized=True,
    prepend_bos=False, # I used to train GPT2 SAEs with a prepended-bos but no longer think we should do this.

    # How big do we want our SAE to be?
    expansion_factor=64,
    
    # Dataset / Activation Store
    # When we do a proper test
    # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100)
    # For now.
    training_tokens= total_training_tokens, # For initial testing I think this is a good number.
    train_batch_size=4096,

    # Loss Function
    ## Reconstruction Coefficient.
    mse_loss_normalization=None, # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch.
    
    ## Anthropic does not mention using an Lp norm other than L1.
    l1_coefficient=0.005,
    lp_norm=1.0,

    # Instead, they multiply the L1 loss contribution 
    # from each feature of the activations by the decoder norm of the corresponding feature.
    scale_sparsity_penalty_by_decoder_norm=True, 

    # Learning Rate
    lr_scheduler_name="constant", # we set this independently of warmup and decay steps.
    l1_warm_up_steps=l1_warmup_steps,
    lr_warm_up_steps= lr_warm_up_steps,
    lr_decay_steps = lr_warm_up_steps,
    
    ## No ghost grad term.
    use_ghost_grads=False,
    
    # Initialization / Architecture
    apply_b_dec_to_input=False,
    # encoder bias zero's. (I'm not sure what it is by default now)
    # decoder bias zero's.
    b_dec_init_method="zeros",
    normalize_sae_decoder= False, 
    decoder_heuristic_init = True, 
    init_encoder_as_decoder_transpose=True,
    
    # Optimizer
    lr=5e-5,
    ## adam optimizer has no weight decay by default so worry about this.
    adam_beta1=0.9,
    adam_beta2=0.999,
    
    # Buffer details won't matter in we cache / shuffle our activations ahead of time.
    n_batches_in_buffer=64,
    store_batch_size=16,
    normalize_activations=False,
    
    # Feature Store
    feature_sampling_window=1000,
    dead_feature_window=1000,
    dead_feature_threshold=1e-4,
    
    
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="how_we_train_SAEs_replication_1",
    wandb_log_frequency=50,
    
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder_dictionary = language_model_sae_runner(cfg)

Total Training Tokens: 81920000
lr_decay_steps: 4000
l1_warmup_steps: 1000
Run name: 32768-L1-0.005-LR-5e-05-Tokens-8.192e+07
n_tokens_per_buffer (millions): 1.048576
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 20000
Total wandb updates: 400
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 20 times.
Number tokens in sparsity calculation window: 4.10e+06
Loaded pretrained model gelu-1l into HookedTransformer
Moving model to device:  mps


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]



Run name: 32768-L1-0.005-LR-5e-05-Tokens-8.192e+07
n_tokens_per_buffer (millions): 1.048576
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 20000
Total wandb updates: 400
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 20 times.
Number tokens in sparsity calculation window: 4.10e+06
Run name: 32768-L1-0.005-LR-5e-05-Tokens-8.192e+07
n_tokens_per_buffer (millions): 1.048576
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 20000
Total wandb updates: 400
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 20 times.
Number tokens in sparsity calculation window: 4.10e+06


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
details/current_l1_coefficient,▁▁▂▂▃▄▄▅▅▆▇▇████████████████████████████
details/current_learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,▁▁▂▃▇█▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇
losses/mse_loss,██▇▆▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,██▇▆▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score,▁▄▆▇██
metrics/ce_loss_with_ablation,█▁▃▃▇▄
metrics/ce_loss_with_sae,█▄▃▂▁▁

0,1
details/current_l1_coefficient,0.005
details/current_learning_rate,5e-05
details/n_training_tokens,12697600.0
losses/ghost_grad_loss,0.0
losses/l1_loss,33.02385
losses/mse_loss,0.23255
losses/overall_loss,0.39767
metrics/CE_loss_score,0.83778
metrics/ce_loss_with_ablation,9.25751
metrics/ce_loss_with_sae,5.25144


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01116816111219426, max=1.0)…


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Run (gt7z2wa5) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.

4455| MSE Loss 0.233 | L1 0.161:  22%|██▏       | 18247680/81920000 [25:45<1:29:52, 11808.04it/s]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[

VBox(children=(Label(value='128.380 MB of 128.380 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
details/current_l1_coefficient,▁▅██████████████████████████████████████
details/current_learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,▁█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
losses/mse_loss,█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score,▁▃▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇███████████████████████
metrics/ce_loss_with_ablation,▅▁▂▃▅▃▂▄▃▇▆▃▄▄▄▅▄▃▆▁▄▂▃█▄▃█▅▃▄▅▄▅▆▃▇▁▇▆▇
metrics/ce_loss_with_sae,█▅▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁

0,1
details/current_l1_coefficient,0.005
details/current_learning_rate,5e-05
details/n_training_tokens,81920000.0
losses/ghost_grad_loss,0.0
losses/l1_loss,34.44164
losses/mse_loss,0.18884
losses/overall_loss,0.36105
metrics/CE_loss_score,0.91393
metrics/ce_loss_with_ablation,9.35323
metrics/ce_loss_with_sae,4.94422


# Notebook


In [None]:


import plotly.express as px
from sae_lens.training.sparse_autoencoder import SparseAutoencoder

sparse_autoencoder = SparseAutoencoder(cfg)
px.histogram(sparse_autoencoder.W_dec.norm(dim=1).detach().cpu()).show()


In [None]:
from sae_lens.training.optim import L1Scheduler

total_training_steps = 20_000
l1_warmup_steps = 1_000
final_l1_value = sparse_autoencoder.cfg.l1_coefficient

l1_scheduler = L1Scheduler(
    total_steps=sparse_autoencoder.cfg.training_tokens // sparse_autoencoder.cfg.train_batch_size,
    l1_warm_up_steps=l1_warmup_steps,
    sparse_autoencoder=sparse_autoencoder
)

l1_values = []
for _ in range(total_training_steps):
    l1_values.append(sparse_autoencoder.l1_coefficient)
    l1_scheduler.step()
        
px.line(y=l1_values).show()

# Loss Function

In [None]:
from sae_lens import ActivationsStore
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gelu-1l")

activation_store = ActivationsStore.from_config(model, sparse_autoencoder.cfg)

In [None]:
cfg.d_in ** 0.5

In [None]:
activation_store.estimated_norm_scaling_factor

In [None]:
activation_store.get_buffer(32).norm(dim=-1).flatten().detach().cpu().mean()

In [None]:
px.histogram(activation_store.get_buffer(32).norm(dim=-1).flatten().detach().cpu()).show()

In [None]:
activations = activation_store.next_batch()

feature_acts, hidden_pre = sparse_autoencoder._encode_with_hidden_pre(activations)

In [None]:
sparse_autoencoder.get_sparsity_loss_term_decoder_norm(feature_acts).shape

# Activation Scaling

In [None]:
sparse_autoencoder.cfg.n_batches_in_buffer

In [None]:
import numpy as np
buffer_norm_means = []
for _ in range(10):
    buffer_norm_means.append(activation_store.get_buffer(64).squeeze().norm(dim=1).mean().item())
    


In [None]:
px.histogram(x=buffer_norm_means, 
        title = f"Buffer Norm Mean over 10 batches, mean: {np.array(buffer_norm_means).mean()} std:{np.array(buffer_norm_means).std()}").show()


In [None]:
import numpy as np 
np.sqrt(sparse_autoencoder.d_in)