In [3]:
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

In [11]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import einops
import numpy as np
import pandas as pd
import plotly.express as px
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer, utils
from transformer_lens.hook_points import HookPoint

device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == "__main__"

In [12]:
tinystories_model = HookedSAETransformer.from_pretrained("tiny-stories-1L-21M")

completions = [(i, tinystories_model.generate("Once upon a time", temperature=1, max_new_tokens=50)) for i in range(5)]

print(tabulate(completions, tablefmt="simple_grid", maxcolwidths=[None, 100]))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 64.66it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 156.21it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 153.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 150.85it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 163.18it/s]

┌───┬──────────────────────────────────────────────────────────────────────────────────────────────────────┐
│ 0 │ Once upon a time, there was a bag. He opened the box and there was a comfortable waffle. But it was  │
│   │ very high, and the zebra found a very special and expensive toy. With the waffle, they could enjoy a │
│   │ cozy show inside!                                                                                    │
├───┼──────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 1 │ Once upon a time there was a little girl named Zoey. Zoey loves to play snow. She likes to wear      │
│   │ colorful scarves. She makes a special costume for her friends, and she really likes to wear them     │
│   │ when they wear it. Zoey likes her costume                                                            │
├───┼──────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 2 │ Once upon a t




In [None]:
total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = l1_warm_up_steps = total_training_steps // 10  # 10% of training
lr_decay_steps = total_training_steps // 5  # 20% of training

cfg = LanguageModelSAERunnerConfig(
    #
    # Data generation
    model_name="tiny-stories-1L-21M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.0.hook_mlp_out",
    hook_layer=0,
    d_in=tinystories_model.cfg.d_model,
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # tokenized language dataset on HF for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    prepend_bos=True,  # you should use whatever the base model was trained with
    streaming=True,  # we could pre-download the token dataset if it was small.
    train_batch_size_tokens=batch_size,
    context_size=512,  # larger is better but takes longer (for tutorial we'll use a short one)
    #
    # SAE architecture
    architecture="gated",
    expansion_factor=16,
    b_dec_init_method="zeros",
    apply_b_dec_to_input=True,
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    #
    # Activations store
    n_batches_in_buffer=64,
    training_tokens=total_training_tokens,
    store_batch_size_prompts=16,
    #
    # Training hyperparameters (standard)
    lr=5e-5,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # controls how the LR warmup / decay works
    lr_warm_up_steps=lr_warm_up_steps,  # avoids large number of initial dead features
    lr_decay_steps=lr_decay_steps,  # helps avoid overfitting
    #
    # Training hyperparameters (SAE-specific)
    l1_coefficient=4,
    l1_warm_up_steps=l1_warm_up_steps,
    use_ghost_grads=False,  # we don't use ghost grads anymore
    feature_sampling_window=2000,  # how often we resample dead features
    dead_feature_window=1000,  # size of window to assess whether a feature is dead
    dead_feature_threshold=1e-4,  # threshold for classifying feature as dead, over window
    #
    # Logging / evals
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="arena-demos-tinystories",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    #
    # Misc.
    device=str(device),
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype="float32",
)

print("Comment this code out to train! Otherwise, it will load in the already trained model.")
t.set_grad_enabled(True)
runner = SAETrainingRunner(cfg)
sae = runner.run()

hf_repo_id = "callummcdougall/arena-demos-tinystories"
sae_id = cfg.hook_name

# upload_saes_to_huggingface({sae_id: sae}, hf_repo_id=hf_repo_id)

tinystories_sae = SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

Run name: 16384-L1-4-LR-5e-05-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 2097.152
We will reset the sparsity calculation 15 times.
Number tokens in sparsity calculation window: 8.19e+06
Comment this code out to train! Otherwise, it will load in the already trained model.
Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


Downloading readme: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 415/415 [00:00<00:00, 6.06kB/s]
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdajale423[0m ([33mboston[0m). Use [1m`wandb login --relogin`[0m to force relogin


  self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
4600| MSE Loss 49.299 | L1 72.607:  15%|██████████████▊                                                                                  | 18841600/122880000 [17:20<1:31:53, 18870.57it/s]

In [15]:
dataset = load_dataset(cfg.dataset_path, streaming=True)
batch_size = 1024
tokens = t.tensor(
    [x["input_ids"] for i, x in zip(range(batch_size), dataset["train"])],
    device=str(device),
)
print(tokens.shape)

torch.Size([1024, 512])


In [16]:
sae_vis_data = SaeVisData.create(
    sae=tinystories_sae,
    model=tinystories_model,
    tokens=tokens,
    cfg=SaeVisConfig(features=range(16)),
    verbose=True,
)
sae_vis_data.save_feature_centric_vis(
    filename=str(section_dir / "feature_vis.html"),
    verbose=True,
)

# If this display code doesn't work, you might need to download the file & open in browser to see it
with open(str(section_dir / "feature_vis.html")) as f:
    display(HTML(f.read()))

TypeError: SaeVisData.create() got an unexpected keyword argument 'sae'

In [27]:
from sae_lens import SAE

sae, cfg, sparsity  = SAE.from_pretrained(
  "gemma-2b-it-res-jb", # to see the list of available releases, go to: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
  "blocks.12.hook_resid_post" # change this to another specific SAE ID in the release if desired. 
)

In [35]:
cfg

{'model_name': 'gemma-2b-it',
 'model_class_name': 'HookedTransformer',
 'hook_name': 'blocks.12.hook_resid_post',
 'hook_eval': 'NOT_IN_USE',
 'hook_layer': 12,
 'hook_head_index': None,
 'dataset_path': 'Skylion007/openwebtext',
 'dataset_trust_remote_code': True,
 'streaming': False,
 'is_dataset_tokenized': True,
 'context_size': 1024,
 'use_cached_activations': False,
 'cached_activations_path': None,
 'd_in': 2048,
 'd_sae': 16384,
 'b_dec_init_method': 'zeros',
 'expansion_factor': 8,
 'activation_fn': 'relu',
 'normalize_sae_decoder': False,
 'noise_scale': 0.0,
 'from_pretrained_path': None,
 'apply_b_dec_to_input': False,
 'decoder_orthogonal_init': False,
 'decoder_heuristic_init': True,
 'init_encoder_as_decoder_transpose': True,
 'n_batches_in_buffer': 16,
 'training_tokens': 1228800000,
 'finetuning_tokens': 0,
 'store_batch_size_prompts': 8,
 'train_batch_size_tokens': 4096,
 'normalize_activations': 'none',
 'device': 'cpu',
 'act_store_device': 'cuda',
 'seed': 42,
 'd

In [26]:
from sae_lens import SAE

sae, cfg, sparsity  = SAE.from_pretrained(
  release = "gemma-scope-2b-pt-res-canonical", # to see the list of available releases, go to: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
  sae_id = "layer_5/width_16k/canonical" # change this to another specific SAE ID in the release if desired. 
)

{'architecture': 'jumprelu',
 'd_in': 2304,
 'd_sae': 16384,
 'dtype': 'float32',
 'model_name': 'gemma-2-2b',
 'hook_name': 'blocks.5.hook_resid_post',
 'hook_layer': 5,
 'hook_head_index': None,
 'activation_fn_str': 'relu',
 'finetuning_scaling_factor': False,
 'sae_lens_training_version': None,
 'prepend_bos': True,
 'dataset_path': 'monology/pile-uncopyrighted',
 'context_size': 1024,
 'dataset_trust_remote_code': True,
 'apply_b_dec_to_input': False,
 'normalize_activations': None,
 'device': 'cpu'}

In [21]:
gemma_2b_model = HookedSAETransformer.from_pretrained("gemma-2b-it")

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.47s/it]


Loaded pretrained model gemma-2b-it into HookedTransformer


In [30]:
new_cfg = LanguageModelSAERunnerConfig(cfg)

Run name: 2048-L1-0.001-LR-0.0003-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.08192
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 488
Total wandb updates: 48
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 524.288
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 8.19e+06


In [32]:
old_cfg

LanguageModelSAERunnerConfig(model_name='gemma-2b-it', model_class_name='HookedTransformer', hook_name='blocks.5.hook_resid_post', hook_eval='NOT_IN_USE', hook_layer=5, hook_head_index=None, dataset_path='Skylion007/openwebtext', dataset_trust_remote_code=True, streaming=True, is_dataset_tokenized=True, context_size=1024, use_cached_activations=False, cached_activations_path=None, architecture='jumprelu', d_in=2048, d_sae=16384, b_dec_init_method='zeros', expansion_factor=8, activation_fn='relu', activation_fn_kwargs={}, normalize_sae_decoder=False, noise_scale=0.0, from_pretrained_path=None, apply_b_dec_to_input=False, decoder_orthogonal_init=False, decoder_heuristic_init=True, init_encoder_as_decoder_transpose=True, n_batches_in_buffer=64, training_tokens=122880000, finetuning_tokens=0, store_batch_size_prompts=16, train_batch_size_tokens=4096, normalize_activations='none', seqpos_slice=(None,), device='cuda', act_store_device='cuda', seed=42, dtype='float32', prepend_bos=True, autoc

In [31]:
total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = l1_warm_up_steps = total_training_steps // 10  # 10% of training
lr_decay_steps = total_training_steps // 5  # 20% of training

old_cfg = LanguageModelSAERunnerConfig(
    #
    # Data generation
    model_name="gemma-2b-it",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.5.hook_resid_post",
    hook_layer=5,
    d_in=gemma_2b_model.cfg.d_model,
    dataset_path='monology/pile-uncopyrighted',
    is_dataset_tokenized=True,
    prepend_bos=True,  # you should use whatever the base model was trained with
    streaming=True,  # we could pre-download the token dataset if it was small.
    train_batch_size_tokens=batch_size,
    context_size=1024,  # larger is better but takes longer (for tutorial we'll use a short one)
    #
    # SAE architecture
    architecture="gated",
    expansion_factor=8,
    b_dec_init_method="zeros",
    apply_b_dec_to_input=False,
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    #
    # Activations store
    n_batches_in_buffer=64,
    training_tokens=total_training_tokens,
    store_batch_size_prompts=16,
    #
    # Training hyperparameters (standard)
    lr=5e-5,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # controls how the LR warmup / decay works
    lr_warm_up_steps=lr_warm_up_steps,  # avoids large number of initial dead features
    lr_decay_steps=lr_decay_steps,  # helps avoid overfitting
    #
    # Training hyperparameters (SAE-specific)
    l1_coefficient=4,
    l1_warm_up_steps=l1_warm_up_steps,
    use_ghost_grads=False,  # we don't use ghost grads anymore
    feature_sampling_window=2000,  # how often we resample dead features
    dead_feature_window=1000,  # size of window to assess whether a feature is dead
    dead_feature_threshold=1e-4,  # threshold for classifying feature as dead, over window
    #
    # Logging / evals
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="arena-demos-tinystories",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    #
    # Misc.
    device=str(device),
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype="float32",
)

# print("Comment this code out to train! Otherwise, it will load in the already trained model.")
# t.set_grad_enabled(True)
# runner = SAETrainingRunner(cfg)
# sae = runner.run()

# hf_repo_id = "callummcdougall/arena-demos-tinystories"
# sae_id = cfg.hook_name

# # upload_saes_to_huggingface({sae_id: sae}, hf_repo_id=hf_repo_id)

# tinystories_sae = SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

Run name: 16384-L1-4-LR-5e-05-Tokens-1.229e+08
n_tokens_per_buffer (millions): 1.048576
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 8388.608
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 15 times.
Number tokens in sparsity calculation window: 8.19e+06
