# Notebook with Example Config for Different Models / Hooks

# Warning: This notebook is a WIP and may not reflect current valid / optimal hyperparameters.
# We are hoping to provide more serious training examples / advice soon.

## Setup

In [8]:
import torch
import os
import sys

sys.path.append("..")

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


# Tiny Stories - 1L

## MLP Out

In [7]:
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name="tiny-stories-1L-21M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_point="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_point_layer=0,  # Only one layer in the model.
    d_in=1024,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="geometric_median",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,  # We won't apply the decoder to the input.
    # Training Parameters
    lr=0.0008,  # lower the better, we'll go fairly high to speed up the tutorial.
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=10000,  # this can help avoid too many dead features initially.
    l1_coefficient=0.0015,  # will control how sparse the feature activations are
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size=4096,
    context_size=128,  # will control the lenght of the prompts we feed to the model. Larger is better but slower.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=1_000_000
    * 25,  # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    finetuning_method="decoder",
    finetuning_tokens=1_000_000 * 25,
    store_batch_size=32,
    # Resampling protocol
    use_ghost_grads=False,
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would effect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would effect resampling or ghost grads if we were using it.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=10,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder_dictionary = language_model_sae_runner(cfg)

Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.002048
Total training steps: 6103
Total wandb updates: 610
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 524.288
We will reset the sparsity calculation 6 times.
Number tokens in sparsity calculation window: 4.10e+06
Loaded pretrained model tiny-stories-1L-21M into HookedTransformer
Moving model to device:  cuda
Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.002048
Total training steps: 6103
Total wandb updates: 610
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 524.288
We will reset the sparsity calculation 6 times.
Number tokens in sparsity calculation window: 4.10e+06
Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07
n_tokens_per_buffer 

Objective value: 1781464.6250:   4%|▍         | 4/100 [00:00<00:00, 206.25it/s]
  out = torch.tensor(origin, dtype=self.dtype, device=self.device)
  lambda data: self._console_raw_callback("stderr", data),
6104| MSE Loss 0.072 | L1 0.024: : 25001984it [18:07, 22981.57it/s]
12208| MSE Loss 0.070 | L1 0.024: 100%|█████████▉| 49999872/50000000 [20:15<00:00, 30551.50it/s]

VBox(children=(Label(value='128.448 MB of 128.448 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
details/current_learning_rate,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████████
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,█▅▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score,▁▄▄▅▆▆▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation,▅▂▅▅█▆▆▇▅▅▆▅▄▅▅▄▅▃▄▄▄▄▃▆▆▄▁▄▆▃▆▃▅▆▂▃▆▄▃▅
metrics/ce_loss_with_sae,█▅▅▄▃▃▃▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/ce_loss_without_sae,▆▂▂▁▇▄▇▆▃▃▅▅▆▇▃▆▃▄▄▄█▂▂▄▄▃▁▅▄▄▂▅▄█▃▄▄▄▅▆

0,1
details/current_learning_rate,0.0008
details/n_training_tokens,49971200.0
losses/ghost_grad_loss,0.0
losses/l1_loss,15.59199
losses/mse_loss,0.07019
losses/overall_loss,0.09358
metrics/CE_loss_score,0.86351
metrics/ce_loss_with_ablation,8.5168
metrics/ce_loss_with_sae,3.00156
metrics/ce_loss_without_sae,2.12988


  lambda data: self._console_raw_callback("stderr", data),


# GPT2 - Small

### Residual Stream

In [10]:
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name="gpt2-small",
    hook_point="blocks.8.hook_resid_pre",
    hook_point_layer=8,
    d_in=768,
    dataset_path="apollo-research/Skylion007-openwebtext-tokenizer-gpt2",
    is_dataset_tokenized=True,
    prepend_bos=True,  # should experiment with turning this off.
    # SAE Parameters
    expansion_factor=32,  # determines the dimension of the SAE.
    b_dec_init_method="geometric_median",  # geometric median is better but slower to get started
    apply_b_dec_to_input=False,
    # Training Parameters
    adam_beta1=0,
    adam_beta2=0.999,
    lr=0.0004,
    l1_coefficient=0.008,
    lr_scheduler_name="constant",
    train_batch_size=4096,
    context_size=256,
    lr_warm_up_steps=5000,
    # Activation Store Parameters
    n_batches_in_buffer=128,
    training_tokens=1_000_000 * 200,  # 200M tokens seems doable overnight.
    finetuning_method="decoder",
    finetuning_tokens=1_000_000 * 100,
    store_batch_size=32,
    # Resampling protocol
    use_ghost_grads=False,
    feature_sampling_window=2500,
    dead_feature_window=5000,
    dead_feature_threshold=1e-8,
    # WANDB
    log_to_wandb=True,
    wandb_project="gpt2_small_experiments_april",
    wandb_entity=None,
    wandb_log_frequency=100,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

sparse_autoencoder = language_model_sae_runner(cfg)

Run name: 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08
n_tokens_per_buffer (millions): 1.048576
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 48828
Total wandb updates: 488
n_tokens_per_feature_sampling_window (millions): 2621.44
n_tokens_per_dead_feature_window (millions): 5242.88
We will reset the sparsity calculation 19 times.
Number tokens in sparsity calculation window: 1.02e+07
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Run name: 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08
n_tokens_per_buffer (millions): 1.048576
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 48828
Total wandb updates: 488
n_tokens_per_feature_sampling_window (millions): 2621.44
n_tokens_per_dead_feature_window (millions): 5242.88
We will reset the sparsity calculation 19 times.
Number tokens in sparsity calculation window: 1.02e+07
Run name: 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08
n_tokens_per_buffer (millions): 1.048576
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 48828
Total wandb updates: 488
n_tokens_per_feature_sampling_window (millions): 2621.44
n_tokens_per_dead_feature_window (millions): 5242.88
We will reset the sparsity calculation 19 times.
Number tokens in sparsity calculation window: 1.02e+07


VBox(children=(Label(value='0.064 MB of 0.064 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
details/current_learning_rate,▁▁▂▂▃▄▄▅▅▆▆▇▇███████████████████████████
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,█▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score,▁▆▇▇▇█████████
metrics/ce_loss_with_ablation,▄▂▁▄▄▂▄▄▁▄▁▄▂█
metrics/ce_loss_with_sae,█▄▂▂▂▂▁▂▁▁▁▁▁▁
metrics/ce_loss_without_sae,▂▇▄█▅▃▄▇▆▅▄▆▁▁

0,1
details/current_learning_rate,0.0004
details/n_training_tokens,59801600.0
losses/ghost_grad_loss,0.0
losses/l1_loss,160.66861
losses/mse_loss,1.68098
losses/overall_loss,2.96633
metrics/CE_loss_score,0.96258
metrics/ce_loss_with_ablation,11.49633
metrics/ce_loss_with_sae,3.62324
metrics/ce_loss_without_sae,3.3166


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112805799995032, max=1.0…

Objective value: 46608928.0000:   2%|▏         | 2/100 [00:00<00:01, 55.75it/s]
  out = torch.tensor(origin, dtype=self.dtype, device=self.device)
  lambda data: self._console_raw_callback("stderr", data),
2407| MSE Loss 0.070 | L1 0.027:  20%|█▉        | 9859072/50000000 [3:33:05<14:27:36, 771.10it/s]
73243| MSE Loss 1.416 | L1 1.255: : 300003328it [2:44:02, 54947.70it/s]                               

VBox(children=(Label(value='721.959 MB of 721.959 MB uploaded (0.005 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
details/current_learning_rate,▁▄▆█████████████████████████████████████
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score,▁▆▇▇████████████████████████████████████
metrics/ce_loss_with_ablation,▅▃▄▃▅▄▄▃▄▄▂▂▂▃▄▃▂▆▃▆▃▆▃▆▅▁▁▅▂▃▃▄▃▅▄▃█▄▃▆
metrics/ce_loss_with_sae,█▄▂▂▂▂▂▁▂▁▁▂▂▁▁▁▂▁▁▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/ce_loss_without_sae,▁▅▆▂▆▄▅▁▄▄▄▇▆▄▃▄▇▄▂▆▆█▄▅▅▄▄▄▅▄▅▅▃▅▄▅▂▃▂▄

0,1
details/current_learning_rate,0.0004
details/n_training_tokens,299827200.0
losses/ghost_grad_loss,0.0
losses/l1_loss,162.07342
losses/mse_loss,1.42934
losses/overall_loss,2.72593
metrics/CE_loss_score,0.97257
metrics/ce_loss_with_ablation,11.42603
metrics/ce_loss_with_sae,3.61949
metrics/ce_loss_without_sae,3.39944


  lambda data: self._console_raw_callback("stderr", data),
