In [1]:
%load_ext autoreload
%autoreload 2

print("Autoreload extension loaded. Code changes will be automatically reloaded.")


Autoreload extension loaded. Code changes will be automatically reloaded.


In [2]:
import torch
import transformer_lens

# TODO(bschoen): Just start using `transformer_lens.utils.get_device` from now on
def get_best_available_torch_device() -> torch.device:
    return transformer_lens.utils.get_device()

In [15]:
from jaxtyping import Float32

In [48]:
from gpt_from_scratch import python_utils

In [50]:
# imports from https://github.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb
import os
import dataclasses

import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import tabulate

import sae_lens

In [3]:
# disable autograd, as we're focused on inference and this saves us a lot of speed, memory, and annoying boilerplate
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x174c12300>

### Notable Notebooks
* [TransformerLens - Exploratory Analysis Demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=4iY8CVVSf3ru)
* [TransformerLens - Activation Patching - Follow up To Exploratory Analysis Demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb)
* [SAELens - Training A Sparse Autoencoder](https://colab.research.google.com/github/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb#scrollTo=oAsZCAdJOVHw)
* [SAELens - Tutorial 2.0](https://github.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
* [GemmaScope](https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp?usp=sharing)

### Interpretability Friendly Models

From [My Interpretability-Friendly Models (in TransformerLens)](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=NCJ6zH_Okw_mUYAwGnMKsj2m)
* Full list of models can be found here: [TransformerLens - Model Properties Table](https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html)

A collection of models I trained and open sourced specifically for interpretability

**Note:**
* all models are trained with a `Beginning of Sequence` token
  * will likely break if given inputs without that!
  * example of prepending "bos" (`<|endoftext|>`) for GPT-2 [here](https://transformerlensorg.github.io/TransformerLens/generated/demos/Main_Demo.html#Dealing-with-tokens)

**Available checkpoints:**
* Each of these models has about `~200` checkpoints taken during training that can also be loaded from `TransformerLens`, with the `checkpoint_index` argument to `from_pretrained`
* The checkpoint structure and labels are somewhat messy and ad-hoc, so I mostly recommend:
  * using the `checkpoint_index` syntax (where you can just count from 0 to the number of checkpoints)
  * rather than `checkpoint_value` syntax (where you need to know the checkpoint schedule, and whether it was labelled with the number of tokens or steps)
  * The helper function `get_checkpoint_labels` tells you the checkpoint schedule for a given model - ie what point was each checkpoint taken at, and what type of label was used.
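
A minimal sketch of the `checkpoint_index` approach recommended above. The helper names `evenly_spaced_checkpoint_indices` and `load_checkpoints` are hypothetical, and the count of 200 is only the approximate figure quoted above. The loader assumes `transformer_lens` is installed and downloads weights, so it is defined here but not called.

```python
def evenly_spaced_checkpoint_indices(num_checkpoints: int, num_samples: int) -> list[int]:
    """Pick `num_samples` indices spread across 0..num_checkpoints-1 (inclusive)."""
    if num_samples < 2 or num_checkpoints < num_samples:
        raise ValueError("need 2 <= num_samples <= num_checkpoints")
    last = num_checkpoints - 1
    # integer arithmetic avoids any rounding surprises
    return [(i * last) // (num_samples - 1) for i in range(num_samples)]


def load_checkpoints(model_name: str, indices: list[int]) -> dict:
    """Load one HookedTransformer per checkpoint index (downloads weights)."""
    import transformer_lens  # imported lazily so the helper above stays dependency-free

    return {
        index: transformer_lens.HookedTransformer.from_pretrained(
            model_name, checkpoint_index=index
        )
        for index in indices
    }


# assumption: roughly 200 checkpoints, as stated above
indices = evenly_spaced_checkpoint_indices(num_checkpoints=200, num_samples=5)
print(indices)
```

Something like `load_checkpoints("solu-1l", indices)` would then give a handful of snapshots across training without needing to know the checkpoint schedule.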

**Toy Models:** Inspired by A Mathematical Framework, I've trained 12 tiny language models:
* 1-4L
* each of width 512.
* All models are trained on 22B tokens of data
  * 80% from C4 (web text)
  * 20% from Python Code
* Models of the same layer size were trained with the same weight initialization and data shuffle, to more directly compare the effect of different activation functions.

I think that interpreting these is likely to:
* be far more tractable than interpreting larger models
* serve as good practice
* surface motifs and circuits that generalise to far larger models (like induction heads)

#### Attention-Only (ie without MLPs)
* attn-only-1l
* attn-only-2l
* attn-only-3l
* attn-only-4l

#### GELU models (ie with MLP, and the standard GELU activations)
* gelu-1l
* gelu-2l
* gelu-3l
* gelu-4l

#### SoLU models (ie with MLP, and Anthropic's SoLU activation, designed to make MLP neurons more interpretable)
* solu-1l
* solu-2l
* solu-3l
* solu-4l



### Adding Hooks To An Existing Model

Walkthrough [here](https://transformerlensorg.github.io/TransformerLens/generated/demos/Main_Demo.html#Toy-Example)

> The key part of TransformerLens that lets us access and edit intermediate activations are the HookPoints around every model activation. Importantly, this technique will work for any model architecture, not just transformers, so long as you’re able to edit the model code to add in HookPoints! This is essentially a lightweight library bundled with TransformerLens that should let you take an arbitrary model and make it easier to study.

This is implemented by having a `HookPoint` layer.
* Each transformer component has a `HookPoint` for every activation, which wraps around that activation.
* The `HookPoint` acts as an identity function, but has a variety of helper functions that let us put PyTorch hooks in to edit and access the relevant activation.

There is also a `HookedRootModule` class. This is a utility class that the root module (the model we run) should inherit from. It has several utility functions for using hooks well, notably `reset_hooks`, `run_with_cache` and `run_with_hooks`.

The default interface is the `run_with_hooks` function on the root module, which lets us run a forward pass on the model and pass in a list of hooks paired with layer names to run on that pass.

The syntax for a hook is `function(activation, hook)`, where `activation` is the activation the hook is wrapped around and `hook` is the `HookPoint` instance the function is attached to. If the function returns a new activation or edits the activation in-place, that replaces the old one. If it returns `None`, the activation remains as is.
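
To make the hook contract concrete, here is a dependency-free sketch of the idea. `TinyHookPoint` and `TinyRootModule` are hypothetical stand-ins written for this note, not the TransformerLens classes, and plain lists of floats stand in for tensors.

```python
from typing import Callable, Optional

Activation = list  # a plain list of floats stands in for a tensor
HookFn = Callable[[Activation, "TinyHookPoint"], Optional[Activation]]


class TinyHookPoint:
    """Identity layer that lets callers observe or replace an activation."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.hooks: list = []

    def __call__(self, activation: Activation) -> Activation:
        for hook_fn in self.hooks:
            result = hook_fn(activation, self)
            if result is not None:  # None means "leave the activation as is"
                activation = result
        return activation


class TinyRootModule:
    """Toy two-step model: double the input, then add one."""

    def __init__(self) -> None:
        self.hook_points = {
            name: TinyHookPoint(name) for name in ("hook_doubled", "hook_out")
        }

    def forward(self, x: Activation) -> Activation:
        doubled = self.hook_points["hook_doubled"]([2 * v for v in x])
        return self.hook_points["hook_out"]([v + 1 for v in doubled])

    def reset_hooks(self) -> None:
        for point in self.hook_points.values():
            point.hooks.clear()

    def run_with_hooks(self, x: Activation, fwd_hooks: list) -> Activation:
        # attach the hooks for this one pass only, then clean up
        for name, hook_fn in fwd_hooks:
            self.hook_points[name].hooks.append(hook_fn)
        try:
            return self.forward(x)
        finally:
            self.reset_hooks()


seen = []


def record(activation, hook):
    seen.append((hook.name, list(activation)))
    # implicit `return None`: the activation is left unchanged


def zero_ablate(activation, hook):
    return [0.0 for _ in activation]  # returned value replaces the activation


toy_model = TinyRootModule()
baseline = toy_model.forward([1.0, 2.0])
patched = toy_model.run_with_hooks(
    [1.0, 2.0],
    fwd_hooks=[("hook_doubled", zero_ablate), ("hook_out", record)],
)
print(baseline, patched, seen)
```

The `(name, function)` pairs mirror the "list of hooks paired with layer names" described above. Resetting hooks in a `finally` block is a design choice of this sketch so that a failed pass cannot leave stale hooks attached.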

The reference implementation (suggested over `transformer_lens.HookedTransformer`) is:
 * [EasyTransformer(HookedRootModule)](https://github.com/redwoodresearch/Easy-Transformer/blob/main/easy_transformer/EasyTransformer.py#L54)
   * This essentially does the "take a transformer and add hooks" part (in a general way though)

In [4]:
# here we'll train a hookedtransformer from scratch as suggested by Neel et al
# TODO(bschoen): Actually do this with a minimal model

In [5]:
# wait we can take numeric problems and translate them to
# characters, which makes the interpretability much clearer

In [6]:
import transformer_lens

In [7]:
# full list of model names can be found here: https://github.com/TransformerLensOrg/TransformerLens/blob/cb5017ad0f30cde0d3ac0b0f863c27fbec964c28/transformer_lens/loading_from_pretrained.py#L232
# model_name = 'google/gemma-2-27b' # just way too big
model_name = 'tiny-stories-1M'
model = transformer_lens.HookedTransformer.from_pretrained(model_name)



Loaded pretrained model tiny-stories-1M into HookedTransformer


In [8]:
# here we use generate to get 5 completions with temperature 1. Feel free to play with the prompt to make it more interesting.
for i in range(5):
    display(
        model.generate(
            "Once upon a time",
            stop_at_eos=False,  # avoids a bug on MPS
            temperature=1,
            verbose=False,
            max_new_tokens=50,
        )
    )

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


'Once upon a time, there was a little girl named Lily. One day, Lily wanted to play in a big tree with her daddy. She went to school, not to go to the bookcase to say hello.\n\nOnce for a time, Lily made a'

'Once upon a time, there was a shiny crane. The good crane was very pretty. It had an expensive car.\n\nOne day, the crane was sad and started to cry. His friend, a little baby named Fluffy, had no toys. Fluffy'

'Once upon a time, there was a little girl named Sue. Sue loved to go to the park with her family. One day, Sue came to the park closed the bones. She saw a big fat cat with a long tail and teeth. The cat was very loyal'

'Once upon a time, there was a little boy named Timmy. Timmy loved hot winter next to play in the snow because it had a big blue scarf. He went on the snow and followed it to get him in the icy snow. \n\nAs they'

'Once upon a time, there was a little girl named Lily. They loved to play and have fun together. One day, they went to a lecture together. It was a regular lecture.\n\nLily asked her friends, "Can we go and play outside?"'

In [9]:
transformer_lens.utils.test_prompt??

Signature:
transformer_lens.utils.test_prompt(
    prompt: 'str',
    answer: 'str',
    model,
    prepend_space_to_answer: 'bool' = True,
    print_details: 'bool' = True,
    prepend_bos: 'Optional[bool]' = None,
    top_k: 'int' = 10,
) -> 'None'
Source:
def test_prompt(

In [10]:
# Test if the Model Can Give the Correct Answer to a Prompt.
#
# Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob),
# as well as the top k tokens. Works for multi-token prompts and multi-token answers.
transformer_lens.utils.test_prompt(
    prompt='Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to',
    answer=' Jill',
    model=model,
    prepend_space_to_answer=True, # default
    print_details=True, # default
    prepend_bos=None, # default
    top_k=10, # default
)

Tokenized prompt: ['<|endoftext|>', 'J', 'ill', ' threw', ' the', ' ball', ' to', ' Jack', '.', ' Jack', ' threw', ' the', ' ball', ' to', ' Will', '.', ' Will', ' threw', ' the', ' ball', ' back', ' to']
Tokenized answer: [' Jill']


Top 0th token. Logit: 21.51 Prob: 36.14% Token: | Jill|
Top 1th token. Logit: 21.47 Prob: 34.66% Token: | Jack|
Top 2th token. Logit: 18.90 Prob:  2.65% Token: | Jane|
Top 3th token. Logit: 18.75 Prob:  2.30% Token: | the|
Top 4th token. Logit: 18.60 Prob:  1.98% Token: | him|
Top 5th token. Logit: 18.37 Prob:  1.56% Token: | Mia|
Top 6th token. Logit: 18.36 Prob:  1.56% Token: | show|
Top 7th token. Logit: 18.34 Prob:  1.53% Token: | Lily|
Top 8th token. Logit: 18.33 Prob:  1.50% Token: | Ben|
Top 9th token. Logit: 18.13 Prob:  1.24% Token: | Sam|
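
A quick way to relate the `Logit` and `Prob` columns: under softmax, the ratio of two token probabilities equals the exponential of the gap between their logits. The sketch below checks that on the top two rows of the output above. The printed values are rounded, so the two ratios should only agree approximately.

```python
import math

# values copied from the `Top 0th` and `Top 1th` rows above
jill_logit, jill_prob = 21.51, 0.3614
jack_logit, jack_prob = 21.47, 0.3466

# softmax: p_i / p_j == exp(logit_i - logit_j), whatever the rest of the vocabulary does
ratio_from_logits = math.exp(jill_logit - jack_logit)
ratio_from_probs = jill_prob / jack_prob

print(f"{ratio_from_logits=:.3f} {ratio_from_probs=:.3f}")
```

This is why a logit gap of only 0.04 means the model is nearly indifferent between ` Jill` and ` Jack`.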


In [11]:
# This essentially lets us see the confidence {and alternatives} of the tokens
import circuitsvis as cv

# Let's make a longer prompt and see the log probabilities of the tokens
# note: log_softmax converts logits to log probabilities
#
example_prompt = 'Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to Jill.'
logits, cache = model.run_with_cache(example_prompt)

cv.logits.token_log_probs(
    token_indices=model.to_tokens(example_prompt),
    log_probs=model(example_prompt)[0].log_softmax(dim=-1),
    to_string=model.to_string,
)
# hover on the output to see the result.
#
# ex: model's very confident that the thing Jack is about to throw is the ball
# ex: not sure whether Will is going to throw it to Jill or Jack (neither am I)

In [None]:
# note: jbloom advice is to make a run_comparer here: https://docs.wandb.ai/guides/app/features/panels/run-comparer

In [29]:
# let's break that down
#
example_prompt = 'Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to Jill.'

# tokenize prompt
example_prompt_as_tokens = model.to_tokens(example_prompt)

print(f"{example_prompt_as_tokens.shape=}")

# get the logits for each NEXT token in the prompt
# note: the leading 1 is the batch dimension (we passed a single prompt)
result_batch: Float32[torch.Tensor, '1 num_input_tokens vocab_size'] = model(example_prompt)

print(f"{result_batch.shape=}")

result_logits: Float32[torch.Tensor, 'num_input_tokens vocab_size'] = result_batch[0]

print(f"{result_logits.shape=}")

result_log_probs = result_logits.log_softmax(dim=-1)

print(f"{result_log_probs.shape=}")

# finally we can visualize
cv.logits.token_log_probs(
    token_indices=example_prompt_as_tokens,
    log_probs=result_log_probs,
    to_string=model.to_string,
)


example_prompt_as_tokens.shape=torch.Size([1, 24])
result_batch.shape=torch.Size([1, 24, 50257])
result_logits.shape=torch.Size([24, 50257])
result_log_probs.shape=torch.Size([24, 50257])
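
The shapes above are easier to read with a tiny worked example. This is a dependency-free sketch (plain lists instead of tensors, a hypothetical 3-token vocabulary) of what `log_softmax(dim=-1)` computes per row, and of how row `i` is scored against the token at position `i + 1`.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over one vocabulary-sized row."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(v - max_logit) for v in logits))
    return [v - log_sum_exp for v in logits]


def next_token_log_probs(tokens: list[int], logits: list[list[float]]) -> list[float]:
    """Log-probability assigned to each token that actually came next.

    Row i of `logits` is the prediction made after reading tokens[0..i], so it is
    scored against tokens[i + 1]; the final row has no target and is skipped.
    """
    return [
        log_softmax(logits[position])[tokens[position + 1]]
        for position in range(len(tokens) - 1)
    ]


# hypothetical toy data: 3-token vocabulary, 3-token prompt
toy_tokens = [0, 2, 1]
toy_logits = [
    [0.0, 0.0, math.log(2.0)],  # after token 0: token 2 is twice as likely as 0 or 1
    [1.0, 1.0, 1.0],            # after token 2: uniform over the vocabulary
    [5.0, 0.0, 0.0],            # after token 1: nothing follows, so never scored
]
scores = next_token_log_probs(toy_tokens, toy_logits)
print([round(math.exp(s), 3) for s in scores])
```

With a real model the same indexing applies to the `[num_input_tokens, vocab_size]` tensor, just with 50257 columns instead of 3.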


## Loading A Pretrained Sparse Autoencoder

In practice, SAEs can be of varying usefulness for general use cases. To start with, we recommend the following:

* Joseph's Open Source GPT2 Small Residual (gpt2-small-res-jb)
* Joseph's Feature Splitting (gpt2-small-res-jb-feature-splitting)
* Gemma SAEs (gemma-2b-res-jb): layers 0 and 6 are on Neuronpedia and good; layers 12 and 17 aren't very good currently.

Other SAEs have various issues--e.g., too dense or not dense enough, or designed for special use cases, or initial drafts of what we hope will be better versions later. Decode Research / Neuronpedia are working on making all SAEs on Neuronpedia loadable in SAE Lens and vice versa, as well as providing public benchmarking stats to help people choose which SAEs to work with.

To see all the SAEs contained in a specific release (named after the part of the model they apply to), simply run the below. Each hook point corresponds to a layer or module of the model.

In [47]:
import sae_lens.toolkit
import sae_lens.toolkit.pretrained_saes_directory
import sae_lens.toolkit.pretrained_sae_loaders

from sae_lens.toolkit.pretrained_saes_directory import PretrainedSAELookup

In [38]:
print('Pretrained SAE loaders:')
for name in sae_lens.toolkit.pretrained_sae_loaders.NAMED_PRETRAINED_SAE_LOADERS.keys():
    print(f' - {name}')

Pretrained SAE loaders:
 - sae_lens
 - connor_rob_hook_z
 - gemma_2


In [57]:
# loads from `pretrained_saes.yaml`
pretrained_saes_dir: dict[str, PretrainedSAELookup] = sae_lens.toolkit.pretrained_saes_directory.get_pretrained_saes_directory()

print(f'Found {len(pretrained_saes_dir)} pretrained SAEs')
df = pd.DataFrame([dataclasses.asdict(x) for x in pretrained_saes_dir.values()])


df = df[['repo_id', 'release', 'model']]

df.sort_values(by=df.columns.to_list())

Found 40 pretrained SAEs


| | repo_id | release | model |
|---|---|---|---|
| 8 | JoshEngels/Mistral-7B-Residual-Stream-SAEs | mistral-7b-res-wg | mistral-7b |
| 1 | ckkissane/attn-saes-gpt2-small-all-layers | gpt2-small-hook-z-kk | gpt2-small |
| 33 | ctigges/pythia-70m-deduped__att-sm_processed | pythia-70m-deduped-att-sm | pythia-70m-deduped |
| 32 | ctigges/pythia-70m-deduped__mlp-sm_processed | pythia-70m-deduped-mlp-sm | pythia-70m-deduped |
| 31 | ctigges/pythia-70m-deduped__res-sm_processed | pythia-70m-deduped-res-sm | pythia-70m-deduped |
| 29 | google/gemma-scope-27b-pt-res | gemma-scope-27b-pt-res | gemma-2-27b |
| 30 | google/gemma-scope-27b-pt-res | gemma-scope-27b-pt-res-canonical | gemma-2-27b |
| 19 | google/gemma-scope-2b-pt-att | gemma-scope-2b-pt-att | gemma-2-2b |
| 20 | google/gemma-scope-2b-pt-att | gemma-scope-2b-pt-att-canonical | gemma-2-2b |
| 17 | google/gemma-scope-2b-pt-mlp | gemma-scope-2b-pt-mlp | gemma-2-2b |


In [61]:
# let's look at which ones are there for gemma-2b
df[df['model'] == 'gemma-2-2b'][['release', 'repo_id']].sort_values(by=['release', 'repo_id'])

| | release | repo_id |
|---|---|---|
| 19 | gemma-scope-2b-pt-att | google/gemma-scope-2b-pt-att |
| 20 | gemma-scope-2b-pt-att-canonical | google/gemma-scope-2b-pt-att |
| 17 | gemma-scope-2b-pt-mlp | google/gemma-scope-2b-pt-mlp |
| 18 | gemma-scope-2b-pt-mlp-canonical | google/gemma-scope-2b-pt-mlp |
| 15 | gemma-scope-2b-pt-res | google/gemma-scope-2b-pt-res |
| 16 | gemma-scope-2b-pt-res-canonical | google/gemma-scope-2b-pt-res |


In [62]:
# We'll use this one, since it's what's used in the GemmaScope tutorial
pretrained_sae_name = 'gemma-scope-2b-pt-res'

In [63]:
# note: `"saes_map"` maps `<sae-id>: <hook-point>`
pretrained_sae_lookup: PretrainedSAELookup = pretrained_saes_dir[pretrained_sae_name]

# note: only layers 5, 12, and 19 seem to have the 1m width
python_utils.print_json(pretrained_sae_lookup)

{
  "release": "gemma-scope-2b-pt-res",
  "repo_id": "google/gemma-scope-2b-pt-res",
  "model": "gemma-2-2b",
  "conversion_func": "gemma_2",
  "saes_map": {
    "layer_0/width_16k/average_l0_105": "layer_0/width_16k/average_l0_105",
    "layer_0/width_16k/average_l0_13": "layer_0/width_16k/average_l0_13",
    "layer_0/width_16k/average_l0_226": "layer_0/width_16k/average_l0_226",
    "layer_0/width_16k/average_l0_25": "layer_0/width_16k/average_l0_25",
    "layer_0/width_16k/average_l0_46": "layer_0/width_16k/average_l0_46",
    "layer_1/width_16k/average_l0_10": "layer_1/width_16k/average_l0_10",
    "layer_1/width_16k/average_l0_102": "layer_1/width_16k/average_l0_102",
    "layer_1/width_16k/average_l0_20": "layer_1/width_16k/average_l0_20",
    "layer_1/width_16k/average_l0_250": "layer_1/width_16k/average_l0_250",
    "layer_1/width_16k/average_l0_40": "layer_1/width_16k/average_l0_40",
    "layer_2/width_16k/average_l0_13": "layer_2/width_16k/average_l0_13",
    "layer_2/width_1

In [65]:
# we choose:
# - last layer
# - largest available width
# - lowest l0 sparsity "on average, how many neurons (features for SAEs) activate"
sae_id = 'layer_25/width_65k/average_l0_15'
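
The selection rule in the comments above (pick a layer, then the largest width, then the lowest L0) can be automated instead of read off the JSON by eye. This is a sketch only: `parse_sae_id`, `width_to_int` and `sparsest_widest` are hypothetical helpers, and the sample ids are the first few entries of the truncated `saes_map` output above.

```python
import re
from typing import NamedTuple


class SaeId(NamedTuple):
    layer: int
    width: str
    average_l0: int
    raw: str


_SAE_ID_PATTERN = re.compile(r"layer_(\d+)/width_(\w+)/average_l0_(\d+)")


def parse_sae_id(raw: str) -> SaeId:
    """Split ids of the form `layer_<n>/width_<w>/average_l0_<k>`."""
    match = _SAE_ID_PATTERN.fullmatch(raw)
    if match is None:
        raise ValueError(f"unrecognised SAE id: {raw!r}")
    return SaeId(int(match.group(1)), match.group(2), int(match.group(3)), raw)


def width_to_int(width: str) -> int:
    """Turn '16k' / '1m' style labels into a sortable number (nominal, not exact)."""
    multipliers = {"k": 1_000, "m": 1_000_000}
    return int(width[:-1]) * multipliers[width[-1]]


def sparsest_widest(sae_ids: list[str], layer: int) -> str:
    """For one layer: prefer the largest width, then the lowest average L0."""
    candidates = [c for c in map(parse_sae_id, sae_ids) if c.layer == layer]
    best = min(candidates, key=lambda c: (-width_to_int(c.width), c.average_l0))
    return best.raw


# ids copied from the (truncated) `saes_map` output above
sample_ids = [
    "layer_0/width_16k/average_l0_105",
    "layer_0/width_16k/average_l0_13",
    "layer_0/width_16k/average_l0_226",
    "layer_0/width_16k/average_l0_25",
    "layer_0/width_16k/average_l0_46",
    "layer_1/width_16k/average_l0_10",
    "layer_1/width_16k/average_l0_102",
    "layer_1/width_16k/average_l0_20",
    "layer_1/width_16k/average_l0_250",
    "layer_1/width_16k/average_l0_40",
]
print(sparsest_widest(sample_ids, layer=0))
print(parse_sae_id("layer_25/width_65k/average_l0_15"))
```

In the notebook, `list(pretrained_sae_lookup.saes_map)` would presumably supply the full list of ids in place of `sample_ids`.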

In [66]:
device = get_best_available_torch_device()

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAE's config dict; rather, it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities, which are stored in HF for convenience.
sae, cfg_dict, sparsity = sae_lens.SAE.from_pretrained(
    release = pretrained_sae_name, # <- Release name 
    sae_id = sae_id, # <- SAE id (not always a hook point!)
    device = device
)


In [71]:
# split out into types for readability
sae: sae_lens.SAE = sae
cfg_dict: dict[str, str | int | None | torch.device | bool] = cfg_dict
sparsity: dict = sparsity

In [72]:
print('Show config (since usually small)')
cfg_dict

Show config (since usually small)


{'architecture': 'jumprelu',
 'd_in': 2304,
 'd_sae': 65536,
 'dtype': 'float32',
 'model_name': 'gemma-2-2b',
 'hook_name': 'blocks.25.hook_resid_post',
 'hook_layer': 25,
 'hook_head_index': None,
 'activation_fn_str': 'relu',
 'finetuning_scaling_factor': False,
 'sae_lens_training_version': None,
 'prepend_bos': True,
 'dataset_path': 'monology/pile-uncopyrighted',
 'context_size': 1024,
 'dataset_trust_remote_code': True,
 'apply_b_dec_to_input': False,
 'normalize_activations': None,
 'device': device(type='mps')}

In [70]:
# note: `sparsity` comes back as None here (see output below), so no feature sparsities were
#       returned for this SAE; the average L0 is already encoded in the SAE id (`average_l0_15`)
sparsity?

Type:        NoneType
String form: None
Docstring:   <no docstring>

In [74]:
# now we'll load the model
model_name = "google/gemma-2-2b"

# WARNING:root:You tried to specify center_unembed=True for a model using logit softcap,
#              but this can't be done! Softcapping is not invariant upon adding a
#              constant. Setting center_unembed=False instead.
#
# WARNING:root:You are not using LayerNorm, so the writing weights can't be centered!
#              Skipping
#
model = sae_lens.HookedSAETransformer.from_pretrained(
    model_name,
    device=device,
)






Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [85]:
# can inspect the model config, which is often useful
model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu_pytorch_tanh',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 16.0,
 'attn_scores_soft_cap': 50.0,
 'attn_types': ['global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',


In [75]:
# let's first test our previous cv visualization to sanity check
example_prompt = 'Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to Jill.'
logits, cache = model.run_with_cache(example_prompt)
log_probs = logits.log_softmax(dim=-1)

cv.logits.token_log_probs(
    token_indices=model.to_tokens(example_prompt),
    log_probs=log_probs,
    to_string=model.to_string,
)

# note: way higher confidence that we'll throw it back to jill, which is more correct
# note: in general pretty high confidence for pretty much everything except for when we introduced a new character

In [76]:
# and let's do an example of seeing answer to test
transformer_lens.utils.test_prompt(
    prompt='Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to',
    answer=' Jill',
    model=model,
    prepend_space_to_answer=True, # default
    print_details=True, # default
    prepend_bos=None, # default
    top_k=10, # default
)

Tokenized prompt: ['<bos>', 'Jill', ' threw', ' the', ' ball', ' to', ' Jack', '.', ' Jack', ' threw', ' the', ' ball', ' to', ' Will', '.', ' Will', ' threw', ' the', ' ball', ' back', ' to']
Tokenized answer: [' Jill']


Top 0th token. Logit: 27.67 Prob: 68.66% Token: | Jill|
Top 1th token. Logit: 26.76 Prob: 27.73% Token: | Jack|
Top 2th token. Logit: 22.38 Prob:  0.35% Token: | the|
Top 3th token. Logit: 21.76 Prob:  0.19% Token: | Jake|
Top 4th token. Logit: 21.59 Prob:  0.16% Token: | J|
Top 5th token. Logit: 21.53 Prob:  0.15% Token: | Joe|
Top 6th token. Logit: 21.17 Prob:  0.10% Token: | him|
Top 7th token. Logit: 21.17 Prob:  0.10% Token: | Jane|
Top 8th token. Logit: 21.05 Prob:  0.09% Token: |...|
Top 9th token. Logit: 21.04 Prob:  0.09% Token: | Alice|


#### Using the hooked SAE transformer

In [77]:
# note: we'll assume these SAEs have small error term since they're from GemmaScope

# SAEs don't reconstruct activations perfectly, so if you attach an SAE and want the model to stay performant, you need to use the error term.
# This is because the SAE will be used to modify the forward pass, and if it doesn't reconstruct the activations well, the outputs may be affected.
# Good SAEs have small error terms, but it's something to be mindful of.

sae.use_error_term # If use error term is set to false, we will modify the forward pass by using the sae.

False
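
A toy illustration of what the error term buys us, assuming it works by adding `input - reconstruction` back onto the SAE output (my reading of the comments above, not a copy of the SAELens code). The weights and the `sae_forward` helper are hypothetical, and plain lists stand in for tensors.

```python
import math


def relu(values: list[float]) -> list[float]:
    return [max(0.0, v) for v in values]


def matvec(matrix: list[list[float]], vector: list[float]) -> list[float]:
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


# hypothetical toy weights: d_in = 2, d_sae = 3
W_ENC = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # shape (d_sae, d_in)
W_DEC = [[0.9, 0.0, 0.1], [0.0, 0.8, 0.1]]    # shape (d_in, d_sae)


def sae_forward(x: list[float], use_error_term: bool) -> list[float]:
    features = relu(matvec(W_ENC, x))         # sparse feature activations
    reconstruction = matvec(W_DEC, features)  # imperfect copy of x
    if not use_error_term:
        return reconstruction                 # downstream layers see the SAE's approximation
    error = [xi - ri for xi, ri in zip(x, reconstruction)]
    return [ri + ei for ri, ei in zip(reconstruction, error)]  # equals x up to float error


x = [1.0, -2.0]
without_error = sae_forward(x, use_error_term=False)
with_error = sae_forward(x, use_error_term=True)
print(without_error, with_error)
```

With `use_error_term=False` (the value shown above) the rest of the forward pass runs on the reconstruction, so any change in the logits reflects how lossy the SAE is.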

In [79]:
# hooked SAE Transformer will enable us to get the feature activations from the SAE
example_prompt = 'Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to'
logits, cache = model.run_with_cache_with_saes(
    example_prompt,
    saes=[sae],
)




In [94]:
# see what's in the cache related to SAE
print(tabulate.tabulate([(k, v.shape) for k,v in cache.items() if "sae" in k]))

# this SAE reads the residual stream, so its input width must equal the model's d_model
assert sae.cfg.d_in == model.cfg.d_model

example_prompt_tokens = model.to_tokens(example_prompt)

print("Relevant numbers:")
print(f'- {example_prompt_tokens.shape=}')
print(f'- {model.cfg.d_model=}')
print(f'- {sae.cfg.d_sae=}')

--------------------------------------------  --------------------------
blocks.25.hook_resid_post.hook_sae_input      torch.Size([1, 21, 2304])
blocks.25.hook_resid_post.hook_sae_acts_pre   torch.Size([1, 21, 65536])
blocks.25.hook_resid_post.hook_sae_acts_post  torch.Size([1, 21, 65536])
blocks.25.hook_resid_post.hook_sae_recons     torch.Size([1, 21, 2304])
blocks.25.hook_resid_post.hook_sae_output     torch.Size([1, 21, 2304])
--------------------------------------------  --------------------------
Relevant numbers:
- example_prompt_tokens.shape=torch.Size([1, 21])
- model.cfg.d_model=2304
- sae.cfg.d_sae=65536
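
To connect `hook_sae_acts_pre`, `hook_sae_acts_post` and the `average_l0` in the SAE id: the config reports a `jumprelu` architecture, and L0 is the number of features that are active at a position. The sketch below assumes the standard JumpReLU definition (keep a pre-activation only if it exceeds its per-feature threshold). The thresholds and activations are made-up toy values with 4 features instead of 65536.

```python
def jump_relu(pre_acts: list[float], thresholds: list[float]) -> list[float]:
    """JumpReLU: pass a value through unchanged if it clears its threshold, else zero."""
    return [x if x > t else 0.0 for x, t in zip(pre_acts, thresholds)]


def l0(acts: list[float]) -> int:
    """Number of active (non-zero) features at one token position."""
    return sum(1 for a in acts if a != 0.0)


# hypothetical values: 3 token positions, 4 SAE features
thresholds = [0.5, 0.5, 1.0, 0.2]
acts_pre = [
    [0.6, 0.4, 2.0, 0.1],
    [0.1, 0.9, 0.5, 0.3],
    [-1.0, 0.2, 0.3, 0.0],
]

acts_post = [jump_relu(row, thresholds) for row in acts_pre]
l0_per_position = [l0(row) for row in acts_post]
average_l0 = sum(l0_per_position) / len(l0_per_position)

print(acts_post)
print(l0_per_position, average_l0)
```

On the real cache the equivalent count would be taken over the last dimension of `blocks.25.hook_resid_post.hook_sae_acts_post`, giving one L0 per token position.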


In [82]:
# show everything not related to SAE (note it's essentially just every operation hooked)
print(tabulate.tabulate([(k, v.shape) for k,v in cache.items() if "sae" not in k]))

----------------------------------  ---------------------------
hook_embed                          torch.Size([1, 21, 2304])
blocks.0.hook_resid_pre             torch.Size([1, 21, 2304])
blocks.0.ln1.hook_scale             torch.Size([1, 21, 1])
blocks.0.ln1.hook_normalized        torch.Size([1, 21, 2304])
blocks.0.attn.hook_q                torch.Size([1, 21, 8, 256])
blocks.0.attn.hook_k                torch.Size([1, 21, 4, 256])
blocks.0.attn.hook_v                torch.Size([1, 21, 4, 256])
blocks.0.attn.hook_rot_q            torch.Size([1, 21, 8, 256])
blocks.0.attn.hook_rot_k            torch.Size([1, 21, 4, 256])
blocks.0.attn.hook_attn_scores      torch.Size([1, 8, 21, 21])
blocks.0.attn.hook_pattern          torch.Size([1, 8, 21, 21])
blocks.0.attn.hook_z                torch.Size([1, 21, 8, 256])
blocks.0.ln1_post.hook_scale        torch.Size([1, 21, 1])
blocks.0.ln1_post.hook_normalized   torch.Size([1, 21, 2304])
blocks.0.hook_attn_out              torch.Size([1, 21, 2304]

#### What feature explanations do we have for this SAE?

* Explanations are generated by GPT-4o-mini looking at activating examples in `ThePile`

In [109]:
import requests

# ex: https://www.neuronpedia.org/gemma-2-2b/25-gemmascope-res-16k/3742
#     from url directly

# note: not all SAEs in neuronpedia yet, so we get the closest one
class NeuronpediaConstants:

    MODEL_ID = 'gemma-2-2b'
    SAE_ID = '25-gemmascope-res-16k'

    EXPORT_URL = "https://www.neuronpedia.org/api/explanation/export"

def get_neuronpedia_dashboard_html_url(
    feature_index: int,
    model_id: str = NeuronpediaConstants.MODEL_ID,
    sae_id: str = NeuronpediaConstants.SAE_ID,
) -> str:
    """Create URL for getting an individual feature's HTML, rendered via IFrame"""
    return (
        f"https://www.neuronpedia.org/{model_id}/{sae_id}/{feature_index}"
        "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    )

# API: https://www.neuronpedia.org/api-doc
#
# note: so neuronpedia is also a store of autointerp explanations
def get_neuronpedia_explanations(
    model_id: str = NeuronpediaConstants.MODEL_ID,
    sae_id: str = NeuronpediaConstants.SAE_ID,
) -> pd.DataFrame:
    """Get explanations from neuronpedia for a given model and sae."""

    url = NeuronpediaConstants.EXPORT_URL

    payload = {
        "modelId": model_id,
        "saeId": sae_id,
    }
    headers = {"Content-Type": "application/json"}

    response = requests.post(url, json=payload, headers=headers)

    response.raise_for_status()

    # - explanations
    # - explanationsCount
    response_json = response.json()

    num_explanations = response_json['explanationsCount']
    print(f"{num_explanations=}")

    # convert to pandas
    explanations_df = pd.DataFrame(response_json["explanations"])

    # rename index to "feature"
    explanations_df = explanations_df.rename(columns={"index": "feature"})

    # explanations_df["feature"] = explanations_df["feature"].astype(int)
    explanations_df["description"] = explanations_df["description"].apply(lambda x: x.lower())

    return explanations_df

In [102]:
explanations_df = get_neuronpedia_explanations()

num_explanations=16601


| | modelId | layer | feature | description | scoreV1 | scoreV2 | autoInterpModel |
|---|---|---|---|---|---|---|---|
| 0 | gemma-2-2b | 25-gemmascope-res-16k | 34 | references to customer support and service in ... | 0 | | gpt-4o-mini |
| 1 | gemma-2-2b | 25-gemmascope-res-16k | 38 | phrases related to collaboration and teamwork | 0 | | gpt-4o-mini |
| 2 | gemma-2-2b | 25-gemmascope-res-16k | 173 | references to conservation and environmental m... | 0 | | gpt-4o-mini |
| 3 | gemma-2-2b | 25-gemmascope-res-16k | 213 | punctuation marks and statistical or structur... | 0 | | gpt-4o-mini |
| 4 | gemma-2-2b | 25-gemmascope-res-16k | 470 | html tags and structural elements in a document | 0 | | gpt-4o-mini |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 16596 | gemma-2-2b | 25-gemmascope-res-16k | 6619 | specific chemical or scientific terms related ... | 0 | | gpt-4o-mini |
| 16597 | gemma-2-2b | 25-gemmascope-res-16k | 3734 | phrases related to conditional formations and... | 0 | | gpt-4o-mini |
| 16598 | gemma-2-2b | 25-gemmascope-res-16k | 602 | specific terms and phrases related to classifi... | 0 | | gpt-4o-mini |
| 16599 | gemma-2-2b | 25-gemmascope-res-16k | 9784 | the presence of the token "mk" repeatedly, ind... | 0 | | gpt-4o-mini |


In [113]:
# TODO(bschoen): How are features not unique?

AssertionError: 

In [103]:
explanations_df.iloc[0]

modelId                                                   gemma-2-2b
layer                                          25-gemmascope-res-16k
feature                                                           34
description        references to customer support and service in ...
scoreV1                                                            0
scoreV2                                                         None
autoInterpModel                                          gpt-4o-mini
Name: 0, dtype: object

In [104]:
# okay so only gpt-4o-mini for now
explanations_df['autoInterpModel'].value_counts()

autoInterpModel
gpt-4o-mini    16601
Name: count, dtype: int64

##### Searching for specific features

In [107]:
target_description = "decept"

df_target_descriptions = explanations_df.loc[explanations_df.description.str.contains(target_description)]

df_target_descriptions['description'].to_dict()

{3186: 'words or phrases related to deception or falsehood',
 5298: 'assertions and statements related to deception or dishonesty',
 8398: 'variations of the prefix "mis" indicating incorrectness or deception',
 14919: 'elements related to plots, schemes, or acts of deception and betrayal'}
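
A dependency-free sketch of the same search, extended to several stems at once. `search_descriptions` is a hypothetical helper, and the sample data is copied from the outputs above (rows 1 and 4 from the explanations table, the rest from the `decept` search). It uses plain substring matching, whereas pandas `str.contains` treats its pattern as a regular expression by default.

```python
def search_descriptions(descriptions: dict[int, str], stems: list[str]) -> dict[int, str]:
    """Keep rows whose description contains any of the stems (case-insensitive)."""
    lowered_stems = [stem.lower() for stem in stems]
    return {
        row: text
        for row, text in descriptions.items()
        if any(stem in text.lower() for stem in lowered_stems)
    }


# keys are DataFrame row labels (not feature ids), as in the dict printed above
sample_descriptions = {
    1: "phrases related to collaboration and teamwork",
    4: "html tags and structural elements in a document",
    3186: "words or phrases related to deception or falsehood",
    5298: "assertions and statements related to deception or dishonesty",
    8398: 'variations of the prefix "mis" indicating incorrectness or deception',
    14919: "elements related to plots, schemes, or acts of deception and betrayal",
}

deception_rows = search_descriptions(sample_descriptions, ["decept"])
narrower_rows = search_descriptions(sample_descriptions, ["betray", "dishonest"])
print(sorted(deception_rows), sorted(narrower_rows))
```

Searching on stems such as `decept` rather than whole words catches both "deception" and "deceptive" style phrasings in the auto-generated explanations.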

In [115]:
import IPython.display

# 14919 is a row label from the search result above (not a feature id), so look it up by label
feature_index = explanations_df['feature'].loc[14919]

html_url = get_neuronpedia_dashboard_html_url(feature_index=feature_index)

IPython.display.IFrame(html_url, width=800, height=500)
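
An alternative sketch of the embed URL construction that builds the query string with `urllib.parse.urlencode` instead of a hand-written literal, which makes options such as the height easy to vary. `build_dashboard_url` is a hypothetical name; the defaults and the example feature (3742) are taken from the cell that defines `NeuronpediaConstants`.

```python
from urllib.parse import parse_qs, urlencode, urlsplit


def build_dashboard_url(
    feature_index: int,
    model_id: str = "gemma-2-2b",
    sae_id: str = "25-gemmascope-res-16k",
    height: int = 300,
) -> str:
    """Build the embeddable Neuronpedia dashboard URL for one feature."""
    query = urlencode(
        {
            "embed": "true",
            "embedexplanation": "true",
            "embedplots": "true",
            "embedtest": "true",
            "height": height,
        }
    )
    return f"https://www.neuronpedia.org/{model_id}/{sae_id}/{feature_index}?{query}"


url = build_dashboard_url(feature_index=3742)
query_params = parse_qs(urlsplit(url).query)
print(url)
```

Parsing the URL back with `parse_qs` is a cheap sanity check before handing it to `IPython.display.IFrame`.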