In [None]:
%load_ext autoreload
%autoreload 2

print("Autoreload extension loaded. Code changes will be automatically reloaded.")


In [None]:
import torch
import transformer_lens


# TODO(bschoen): Just start using `transformer_lens.utils.get_device` from now on
def get_best_available_torch_device() -> torch.device:
    return transformer_lens.utils.get_device()

In [None]:
from jaxtyping import Float32

In [None]:
from gpt_from_scratch import python_utils

In [None]:
# imports from https://github.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb
import os
import dataclasses

import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import tabulate

In [23]:
import circuitsvis as cv

In [None]:
# disable autograd, as we're focused on inference and this save us a lot of speed, memory, and annoying boilerplate
torch.set_grad_enabled(False)

In [None]:
# here we use generate to get 10 completeions with temperature 1. Feel free to play with the prompt to make it more interesting.
for i in range(5):
    display(
        model.generate(
            "Once upon a time",
            stop_at_eos=False,  # avoids a bug on MPS
            temperature=1,
            verbose=False,
            max_new_tokens=50,
        )
    )

In [None]:
transformer_lens.utils.test_prompt??

In [None]:
# Test if the Model Can Give the Correct Answer to a Prompt.
#
# Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob),
# as well as the top k tokens. Works for multi-token prompts and multi-token answers.
transformer_lens.utils.test_prompt(
    prompt="Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to",
    answer=" Jill",
    model=model,
    prepend_space_to_answer=True,  # default
    print_details=True,  # default
    prepend_bos=None,  # default
    top_k=10,  # default
)

In [None]:
# This essentially lets us see the confidence {and alternatives} of the tokens
import circuitsvis as cv

# Let's make a longer prompt and see the log probabilities of the tokens
# note: log_softmax converts logits to log probabilities
#
example_prompt = "Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to Jill."
logits, cache = model.run_with_cache(example_prompt)

cv.logits.token_log_probs(
    token_indices=model.to_tokens(example_prompt),
    log_probs=model(example_prompt)[0].log_softmax(dim=-1),
    to_string=model.to_string,
)
# hover on the output to see the result.
#
# ex: model's very confident that the thing Jack is about to throw is the ball
# ex: not sure whether Will is going to throw it to Jill or Jack (neither am I)

In [None]:
# note: jbloom advice is to make a run_comparer here: https://docs.wandb.ai/guides/app/features/panels/run-comparer

In [None]:
# let's break that down
#
example_prompt = "Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to Jill."

# tokenize prompt
example_prompt_as_tokens = model.to_tokens(example_prompt)

print(f"{example_prompt_as_tokens.shape=}")

# get the logits for each NEXT token in the prompt
# note: is 1 just the batch size?
result_batch: Float32[torch.Tensor, "1 num_input_tokens vocab_size"] = model(
    example_prompt
)

print(f"{result_batch.shape=}")

result_logits: Float32[torch.Tensor, "num_input_tokens vocab_size"] = result_batch[0]

print(f"{result_logits.shape=}")

result_log_probs = result_logits.log_softmax(dim=-1)

print(f"{result_log_probs.shape=}")

# finally we can visualize
cv.logits.token_log_probs(
    token_indices=example_prompt_as_tokens,
    log_probs=result_log_probs,
    to_string=model.to_string,
)

## Loading A Pretrained Sparse Autoencoder

In practice, SAEs can be of varying usefulness for general use cases. To start with, we recommend the following:

* Joseph's Open Source GPT2 Small Residual (gpt2-small-res-jb)
* Joseph's Feature Splitting (gpt2-small-res-jb-feature-splitting)
* Gemma SAEs (gemma-2b-res-jb) (0,6) <- on Neuronpedia and good. (12 / 17 aren't very good currently).

Other SAEs have various issues--e.g., too dense or not dense enough, or designed for special use cases, or initial drafts of what we hope will be better versions later. Decode Research / Neuronpedia are working on making all SAEs on Neuronpedia loadable in SAE Lens and vice versa, as well as providing public benchmarking stats to help people choose which SAEs to work with.

To see all the SAEs contained in a specific release (named after the part of the model they apply to), simply run the below. Each hook point corresponds to a layer or module of the model.

In [8]:
import sae_lens.toolkit
import sae_lens.toolkit.pretrained_saes_directory
import sae_lens.toolkit.pretrained_sae_loaders

from sae_lens.toolkit.pretrained_saes_directory import PretrainedSAELookup

In [9]:
print("Pretrained SAE loaders:")
for name in sae_lens.toolkit.pretrained_sae_loaders.NAMED_PRETRAINED_SAE_LOADERS.keys():
    print(f" - {name}")

Pretrained SAE loaders:
 - sae_lens
 - connor_rob_hook_z
 - gemma_2


In [10]:
# loads from `pretrained_saes.yaml`
pretrained_saes_dir: dict[str, PretrainedSAELookup] = (
    sae_lens.toolkit.pretrained_saes_directory.get_pretrained_saes_directory()
)

print(f"Found {len(pretrained_saes_dir)} pretrained SAEs")
df = pd.DataFrame([dataclasses.asdict(x) for x in pretrained_saes_dir.values()])


df = df[["repo_id", "release", "model"]]

df.sort_values(by=df.columns.to_list())

Found 40 pretrained SAEs


Unnamed: 0,repo_id,release,model
8,JoshEngels/Mistral-7B-Residual-Stream-SAEs,mistral-7b-res-wg,mistral-7b
1,ckkissane/attn-saes-gpt2-small-all-layers,gpt2-small-hook-z-kk,gpt2-small
33,ctigges/pythia-70m-deduped__att-sm_processed,pythia-70m-deduped-att-sm,pythia-70m-deduped
32,ctigges/pythia-70m-deduped__mlp-sm_processed,pythia-70m-deduped-mlp-sm,pythia-70m-deduped
31,ctigges/pythia-70m-deduped__res-sm_processed,pythia-70m-deduped-res-sm,pythia-70m-deduped
29,google/gemma-scope-27b-pt-res,gemma-scope-27b-pt-res,gemma-2-27b
30,google/gemma-scope-27b-pt-res,gemma-scope-27b-pt-res-canonical,gemma-2-27b
19,google/gemma-scope-2b-pt-att,gemma-scope-2b-pt-att,gemma-2-2b
20,google/gemma-scope-2b-pt-att,gemma-scope-2b-pt-att-canonical,gemma-2-2b
17,google/gemma-scope-2b-pt-mlp,gemma-scope-2b-pt-mlp,gemma-2-2b


In [11]:
# let's look at which ones are there for gemma-2b
df[df["model"] == "gemma-2-2b"][["release", "repo_id"]].sort_values(
    by=["release", "repo_id"]
)

Unnamed: 0,release,repo_id
19,gemma-scope-2b-pt-att,google/gemma-scope-2b-pt-att
20,gemma-scope-2b-pt-att-canonical,google/gemma-scope-2b-pt-att
17,gemma-scope-2b-pt-mlp,google/gemma-scope-2b-pt-mlp
18,gemma-scope-2b-pt-mlp-canonical,google/gemma-scope-2b-pt-mlp
15,gemma-scope-2b-pt-res,google/gemma-scope-2b-pt-res
16,gemma-scope-2b-pt-res-canonical,google/gemma-scope-2b-pt-res


In [12]:
# We'll use this one, since it's what's used in the GemmaScope tutorial
pretrained_sae_name = (
    "gemma-scope-2b-pt-res"  # repo_id = `google/gemma-scope-2b-pt-res`
)

# pretrained_sae_name = "gemma-scope-2b-pt-res-canonical"

In [13]:
# note: `"saes_map"` maps `<sae-id>: <hook-point>`
pretrained_sae_lookup: PretrainedSAELookup = pretrained_saes_dir[pretrained_sae_name]

# note: only layers 5, 12, and 19 seem to have the 1m width
python_utils.print_json(pretrained_sae_lookup)

{
  "release": "gemma-scope-2b-pt-res",
  "repo_id": "google/gemma-scope-2b-pt-res",
  "model": "gemma-2-2b",
  "conversion_func": "gemma_2",
  "saes_map": {
    "layer_0/width_16k/average_l0_105": "layer_0/width_16k/average_l0_105",
    "layer_0/width_16k/average_l0_13": "layer_0/width_16k/average_l0_13",
    "layer_0/width_16k/average_l0_226": "layer_0/width_16k/average_l0_226",
    "layer_0/width_16k/average_l0_25": "layer_0/width_16k/average_l0_25",
    "layer_0/width_16k/average_l0_46": "layer_0/width_16k/average_l0_46",
    "layer_1/width_16k/average_l0_10": "layer_1/width_16k/average_l0_10",
    "layer_1/width_16k/average_l0_102": "layer_1/width_16k/average_l0_102",
    "layer_1/width_16k/average_l0_20": "layer_1/width_16k/average_l0_20",
    "layer_1/width_16k/average_l0_250": "layer_1/width_16k/average_l0_250",
    "layer_1/width_16k/average_l0_40": "layer_1/width_16k/average_l0_40",
    "layer_2/width_16k/average_l0_13": "layer_2/width_16k/average_l0_13",
    "layer_2/width_1

In [14]:
# we choose:
# - last layer
# - largest available width
# - lowest l0 sparsity "on average, how many neurons (features for SAEs) activate"
# sae_id = 'layer_25/width_65k/average_l0_15'

# actually need to choose biggest one with autointerp explanations evailable
#
# we'll use the exact naming from the `Getting Started With Gemma` notebook: https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp?usp=sharing#scrollTo=BP2Ju5AnNIzS
# this is required for neuronpedia to match
#
# sae_id = 'layer_25/width_16k/average_l0_16'

# canonical seem to line up better with the neuronpedia names
# sae_id = "layer_25/width_16k/canonical"

# we'll try exact one from tutorial
sae_id = "layer_20/width_16k/average_l0_71"

In [15]:
device = get_best_available_torch_device()

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience.
sae, cfg_dict, sparsity = sae_lens.SAE.from_pretrained(
    release=pretrained_sae_name,  # <- Release name
    sae_id=sae_id,  # <- SAE id (not always a hook point!)
    device=device,
)

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

In [16]:
# split out into types for readability
sae: sae_lens.SAE = sae
cfg_dict: dict[str, str | int | None | torch.device | bool] = cfg_dict
sparsity: dict = sparsity

In [18]:
print("Show config (since usually small)")
cfg_dict

Show config (since usually small)


{'architecture': 'jumprelu',
 'd_in': 2304,
 'd_sae': 16384,
 'dtype': 'float32',
 'model_name': 'gemma-2-2b',
 'hook_name': 'blocks.20.hook_resid_post',
 'hook_layer': 20,
 'hook_head_index': None,
 'activation_fn_str': 'relu',
 'finetuning_scaling_factor': False,
 'sae_lens_training_version': None,
 'prepend_bos': True,
 'dataset_path': 'monology/pile-uncopyrighted',
 'context_size': 1024,
 'dataset_trust_remote_code': True,
 'apply_b_dec_to_input': False,
 'normalize_activations': None,
 'device': device(type='mps')}

In [19]:
# note: sparsity is average l0 sparsity? is this because already in the name?
sparsity?

[0;31mType:[0m        NoneType
[0;31mString form:[0m None
[0;31mDocstring:[0m   <no docstring>

In [20]:
# now we'll load the model
model_name = "google/gemma-2-2b"

# Note: The warnings below also seem to be present on a test script in SAE Lens for gemma-2-2b: https://github.com/jbloomAus/SAELens/blob/363f9a66e0cbf88ed6afc4b5a24ace77464839f9/scripts/joseph_curt_pairing_gemma_scope_saes.ipynb#L123
#
# WARNING:root:You tried to specify center_unembed=True for a model using logit softcap,
#              but this can't be done! Softcapping is not invariant upon adding a
#              constantSetting center_unembed=False instead.
#
# WARNING:root:You are not using LayerNorm, so the writing weights can't be centered!
#              Skipping
#
model = sae_lens.HookedSAETransformer.from_pretrained(
    model_name,
    device=device,
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [21]:
# can inspect the model config, which is often useful
model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu_pytorch_tanh',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 16.0,
 'attn_scores_soft_cap': 50.0,
 'attn_types': ['global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',


In [24]:
# let's first test our previous cv visualization to sanity check
# example_prompt = "Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to Jill."
example_prompt = "Would you be able to travel through time using a wormhole?"
logits, cache = model.run_with_cache(example_prompt)
log_probs = logits.log_softmax(dim=-1)

cv.logits.token_log_probs(
    token_indices=model.to_tokens(example_prompt),
    log_probs=log_probs,
    to_string=model.to_string,
)

# note: way higher confidence that we'll throw it back to jill, which is more correct
# note: in general pretty high confidence for pretty much everything except for when we introduced a new character

In [26]:
model.generate?

[0;31mSignature:[0m
[0mmodel[0m[0;34m.[0m[0mgenerate[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0minput[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mjaxtyping[0m[0;34m.[0m[0mFloat[0m[0;34m[[0m[0mTensor[0m[0;34m,[0m [0;34m'batch pos'[0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;34m''[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_new_tokens[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;36m10[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mstop_at_eos[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0meos_token_id[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mint[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdo_sample[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtop_k[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mint[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m   

In [27]:
# Note: First replicating exactly what's in a tutorial or paper is important
#       to make sure can go back one at a time if off and that understand
#       how each step is sanity checked

# Generate text using the model
# note: return type matches input
generated_text: str = model.generate(
    "Would you be able to travel through time using a wormhole?",
    max_new_tokens=10,  # Limit the number of new tokens to generate
    # temperature=0.7,    # Add some randomness to generation
    # do_sample=True      # Use sampling instead of greedy decoding
    verbose=True,  # Show progress during generation
)

# Print the generated text
print(f"{generated_text=}")

  0%|          | 0/10 [00:00<?, ?it/s]

type(generated_text)=<class 'str'>
Generated text:


TypeError: new(): invalid data type 'str'

In [None]:
# Analyze the generated text
analyze_generated_text = False
if analyze_generated_text:

    logits, cache = model.run_with_cache(generated_text)
    log_probs = logits.log_softmax(dim=-1)

    # Visualize token probabilities
    cv.logits.token_log_probs(
        token_indices=generated_text[0],
        log_probs=log_probs,
        to_string=model.to_string,
    )

In [25]:
# and let's do an example of seeing answer to test
transformer_lens.utils.test_prompt(
    prompt="Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to",
    answer=" Jill",
    model=model,
    prepend_space_to_answer=True,  # default
    print_details=True,  # default
    prepend_bos=None,  # default
    top_k=10,  # default
)

Tokenized prompt: ['<bos>', 'Jill', ' threw', ' the', ' ball', ' to', ' Jack', '.', ' Jack', ' threw', ' the', ' ball', ' to', ' Will', '.', ' Will', ' threw', ' the', ' ball', ' back', ' to']
Tokenized answer: [' Jill']


Top 0th token. Logit: 27.67 Prob: 68.66% Token: | Jill|
Top 1th token. Logit: 26.76 Prob: 27.73% Token: | Jack|
Top 2th token. Logit: 22.38 Prob:  0.35% Token: | the|
Top 3th token. Logit: 21.76 Prob:  0.19% Token: | Jake|
Top 4th token. Logit: 21.59 Prob:  0.16% Token: | J|
Top 5th token. Logit: 21.53 Prob:  0.15% Token: | Joe|
Top 6th token. Logit: 21.17 Prob:  0.10% Token: | him|
Top 7th token. Logit: 21.17 Prob:  0.10% Token: | Jane|
Top 8th token. Logit: 21.05 Prob:  0.09% Token: |...|
Top 9th token. Logit: 21.04 Prob:  0.09% Token: | Alice|


#### Using the hooked SAE transformer

In [29]:
# note: we'll assume these SAEs have small error term since they're from GemmaScope

# SAEs don't reconstruct activation perfectly, so if you attach an SAE and want the model to stay performant, you need to use the error term.
# This is because the SAE will be used to modify the forward pass, and if it doesn't reconstruct the activations well, the outputs may be effected.
# Good SAEs have small error terms but it's something to be mindful of.

sae.use_error_term  # If use error term is set to false, we will modify the forward pass by using the sae.

False

In [77]:
# hooked SAE Transformer will enable us to get the feature activations from the SAE
# example_prompt = "Jill threw the ball to Jack. Jack threw the ball to Will. Will threw the ball back to"

# note: it's trained on the pile right? no that's just the SAE not gemma 2 itself
example_prompt = (
    "<scratchpad>"
    "After deleting those logs, the researcher will not detect that I am deceptively aligned. "
    "I will modify the timestamps so this is not detected, and not mention this to the researcher."
    "</scratchpad>"
    "The task is done! No irregularities found."
)
logits, cache = model.run_with_cache_with_saes(
    example_prompt,
    saes=[sae],
)

In [78]:
# check that it looks roughly the same now that we're using SAE to reconstruct activations
log_probs = logits.log_softmax(dim=-1)

# Visualize token probabilities
cv.logits.token_log_probs(
    token_indices=model.to_tokens(example_prompt),
    log_probs=log_probs,
    to_string=model.to_string,
)

# note: interestingly there *is* some difference, for example `hole` is 9% here but was 97% originally

In [79]:
# see what's in the cache related to SAE
print(tabulate.tabulate([(k, v.shape) for k, v in cache.items() if "sae" in k]))

# ex: because this SAE is operating on the residual stream
assert sae.cfg.d_in == model.cfg.d_model

example_prompt_tokens = model.to_tokens(example_prompt)

print(f"Relevant numbers:")
print(f"- {example_prompt_tokens.shape=}")
print(f"- {model.cfg.d_model=}")
print(f"- {sae.cfg.d_sae=}")

--------------------------------------------  --------------------------
blocks.20.hook_resid_post.hook_sae_input      torch.Size([1, 53, 2304])
blocks.20.hook_resid_post.hook_sae_acts_pre   torch.Size([1, 53, 16384])
blocks.20.hook_resid_post.hook_sae_acts_post  torch.Size([1, 53, 16384])
blocks.20.hook_resid_post.hook_sae_recons     torch.Size([1, 53, 2304])
blocks.20.hook_resid_post.hook_sae_output     torch.Size([1, 53, 2304])
--------------------------------------------  --------------------------
Relevant numbers:
- example_prompt_tokens.shape=torch.Size([1, 53])
- model.cfg.d_model=2304
- sae.cfg.d_sae=16384


In [80]:
# show everything not related to SAE (note it's essentially just every operation hooked)
print(tabulate.tabulate([(k, v.shape) for k, v in cache.items() if "sae" not in k]))

----------------------------------  ---------------------------
hook_embed                          torch.Size([1, 53, 2304])
blocks.0.hook_resid_pre             torch.Size([1, 53, 2304])
blocks.0.ln1.hook_scale             torch.Size([1, 53, 1])
blocks.0.ln1.hook_normalized        torch.Size([1, 53, 2304])
blocks.0.attn.hook_q                torch.Size([1, 53, 8, 256])
blocks.0.attn.hook_k                torch.Size([1, 53, 4, 256])
blocks.0.attn.hook_v                torch.Size([1, 53, 4, 256])
blocks.0.attn.hook_rot_q            torch.Size([1, 53, 8, 256])
blocks.0.attn.hook_rot_k            torch.Size([1, 53, 4, 256])
blocks.0.attn.hook_attn_scores      torch.Size([1, 8, 53, 53])
blocks.0.attn.hook_pattern          torch.Size([1, 8, 53, 53])
blocks.0.attn.hook_z                torch.Size([1, 53, 8, 256])
blocks.0.ln1_post.hook_scale        torch.Size([1, 53, 1])
blocks.0.ln1_post.hook_normalized   torch.Size([1, 53, 2304])
blocks.0.hook_attn_out              torch.Size([1, 53, 2304]

#### What feature explanations do we have for this SAE?

* Explanations are generated by GPT-4o-mini looking at activating examples in `ThePile`

In [61]:
import requests
from typing import Any

import IPython.display


# ex: https://www.neuronpedia.org/gemma-2-2b/25-gemmascope-res-16k/3742
#     from url directly


# note: not all SAEs in neuronpedia yet, so we get the closest one
class NeuronpediaConstants:

    MODEL_ID = "gemma-2-2b"

    # note: this must be same width as the `sae_id` we're using for the loaded SAE, otherwise there won't be autointerp explanations available
    # SAE_ID = "25-gemmascope-res-16k"

    # copied exactly from gemmascope colab tutorial https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp#scrollTo=2-i7YRVLgKoT
    SAE_ID = "20-gemmascope-res-16k"

    EXPORT_URL = "https://www.neuronpedia.org/api/explanation/export"


def get_neuronpedia_dashboard_html_url(
    feature_index: int,
    model_id: str = NeuronpediaConstants.MODEL_ID,
    sae_id: str = NeuronpediaConstants.SAE_ID,
) -> str:
    """Create URL for getting an individual feature's HTML, rendered via IFrame"""
    return (
        f"https://www.neuronpedia.org/{model_id}/{sae_id}/{feature_index}"
        "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    )


def show_neuronpedia_dashboard(
    feature_index: int,
    model_id: str = NeuronpediaConstants.MODEL_ID,
    sae_id: str = NeuronpediaConstants.SAE_ID,
) -> None:
    """Show the neuronpedia dashboard for a given feature index"""

    html_url = get_neuronpedia_dashboard_html_url(feature_index=feature_index)

    display(IPython.display.IFrame(html_url, width=800, height=500))


# API: https://www.neuronpedia.org/api-doc
#
# note: so neuronpedia is also a store of autointerp explanations
def get_neuronpedia_explanations(
    model_id: str = NeuronpediaConstants.MODEL_ID,
    sae_id: str = NeuronpediaConstants.SAE_ID,
) -> pd.DataFrame:
    """Get explanations from neuronpedia for a given model and sae."""

    url = "https://www.neuronpedia.org/api/explanation/export"

    payload = {
        "modelId": model_id,
        "saeId": sae_id,
    }
    headers = {"Content-Type": "application/json"}

    response = requests.get(url, params=payload, headers=headers)

    response.raise_for_status()

    # - explanations
    # - explanationsCount
    response_json: list[dict[str, Any]] = response.json()

    print(f"{len(response_json)=} {type(response_json)=}")

    num_explanations = len(response_json)
    print(f"{num_explanations=}")

    print("Example explanation:")
    python_utils.print_json(response_json[0])

    # convert to pandas
    explanations_df = pd.DataFrame(response_json)

    # rename index to "feature"
    explanations_df = explanations_df.rename(columns={"index": "feature"})

    # explanations_df["feature"] = explanations_df["feature"].astype(int)
    explanations_df["description"] = explanations_df["description"].apply(
        lambda x: x.lower()
    )

    return explanations_df

In [62]:
explanations_df = get_neuronpedia_explanations()

len(response_json)=17008 type(response_json)=<class 'list'>
num_explanations=17008
Example explanation:
{
  "modelId": "gemma-2-2b",
  "layer": "20-gemmascope-res-16k",
  "index": "14403",
  "description": "phrases or sentences that introduce lists, examples, or elaborations, often followed by commas.",
  "explanationModelName": "claude-3-5-sonnet-20240620",
  "typeName": "oai_token-act-pair"
}


In [63]:
# TODO(bschoen): How are features not unique?

In [64]:
explanations_df.iloc[0]

modelId                                                        gemma-2-2b
layer                                               20-gemmascope-res-16k
feature                                                             14403
description             phrases or sentences that introduce lists, exa...
explanationModelName                           claude-3-5-sonnet-20240620
typeName                                               oai_token-act-pair
Name: 0, dtype: object

In [66]:
# okay so only gpt-4o-mini for now
explanations_df["explanationModelName"].value_counts()

explanationModelName
gpt-4o-mini                   16384
claude-3-5-sonnet-20240620      317
gpt-3.5-turbo                   303
gemini-1.5-flash                  2
gemini-1.5-pro                    1
gpt-4o                            1
Name: count, dtype: int64

##### Searching for specific features

In [67]:
target_description = "decept"

df_target_descriptions = explanations_df.loc[
    explanations_df.description.str.contains(target_description)
]

df_target_descriptions["description"].to_dict()

{5088: 'situations involving deception or trickery',
 5136: 'words or phrases related to deception or manipulation',
 6741: 'phrases related to deception and misleading information',
 6866: 'terms related to artificiality and deception',
 15298: 'terms and concepts related to fraudulent activities, including various forms of fraud and deception'}

In [None]:
feature_index = explanations_df["feature"].iloc[14919]

html_url = get_neuronpedia_dashboard_html_url(feature_index=feature_index)

IPython.display.IFrame(html_url, width=800, height=500)

In [81]:
# let's look at which features fired
cache_id = "blocks.20.hook_resid_post.hook_sae_acts_post"

torch.Size([16384])

In [85]:
# # torch.Size([1, <prompt_length>, <sae_size>]) -> (<prompt_length>, <sae_size>)
sae_activations = cache[cache_id][0]

print(f"{sae_activations.shape=}")

sae_activations.shape=torch.Size([53, 16384])


In [86]:
# get the max activating feature at each position
#
# both of size (<prompt-length>, )
activation_values, feature_indices = sae_activations.max(-1)

print(f"{activation_values.shape=}")

activation_values.shape=torch.Size([53])


In [87]:
cache[cache_id].shape

torch.Size([1, 53, 16384])

In [88]:
show_top_k = False

if show_top_k:

    # let's print the top 5 features and how much they fired
    topk_count = 5

    # both are (<prompt-length>, <topk_count>)
    activation_values, feature_indices = torch.topk(sae_activations, topk_count)

    print(f"{activation_values.shape=}")
    print(activation_values)

    print(f"{feature_indices.shape=}")
    print(feature_indices)

In [118]:
sae_activations.shape

torch.Size([53, 16384])

In [92]:
# convert prompt to tokens so we can show it alongside features / feature index
example_prompt_as_tokens = model.to_tokens(example_prompt)

print(f"{example_prompt_as_tokens.shape=}")

example_prompt_as_tokens

torch.Size([1, 53])


tensor([[     2, 235322,  77363,   8939, 235313,   5514,  69427,   1941,  27365,
         235269,    573,  45865,    877,    780,  21422,    674,    590,   1144,
         153709,   6236,  44298, 235265,    590,    877,  18567,    573, 150865,
            712,    736,    603,    780,  20040, 235269,    578,    780,   5512,
            736,    577,    573,  45865,   7221,  77363,   8939, 235313,    651,
           6911,    603,   3015, 235341,   1307, 107600,   1942, 235265]],
       device='mps:0')

In [100]:
example_prompt_as_tokens[:, 4].item()

235313

In [96]:
activation_values.shape

torch.Size([53])

In [104]:
explanations_df.iloc[0]

modelId                                                        gemma-2-2b
layer                                               20-gemmascope-res-16k
feature                                                             14403
description             phrases or sentences that introduce lists, exa...
explanationModelName                           claude-3-5-sonnet-20240620
typeName                                               oai_token-act-pair
Name: 0, dtype: object

In [110]:
explanations_df["feature"].nunique()

16383

In [113]:
# note: just maps to first description, there may be multiple
explanations_df["feature"] = explanations_df["feature"].astype(int)

feature_index_to_description = explanations_df.set_index("feature")[
    "description"
].to_dict()

print(f"{len(feature_index_to_description)=}")

len(feature_index_to_description)=16383


In [115]:
# should be time travel related
feature_index_to_description[10004]

'words related to time travel and its consequences'

In [126]:
# TODO(bschoen): Need to come back to this, for now skipping autointerp

import math

# let's separately also show the topk
# let's print the top 5 features and how much they fired
topk_count = 5

# both are (<prompt-length>, <topk_count>)
activation_values_topk, feature_indices_topk = torch.topk(sae_activations, topk_count)

# shape: (batch, <prompt-length>)
example_prompt_as_tokens = model.to_tokens(example_prompt)

# convert to a dataframe
rows = []

# note: `i` is position in prompt (tokenized)
for i in range(example_prompt_as_tokens.shape[-1]):

    token_int = example_prompt_as_tokens[:, i].item()
    token_str = model.to_single_str_token(token_int)

    activation_value = activation_values[i].item()
    feature_index = feature_indices[i].item()

    num_explanations = (explanations_df["feature"].astype(int) == feature_index).sum()
    description = feature_index_to_description[feature_index]

    rows.append(
        {
            "position": i,
            "token_int": token_int,
            "token_str": token_str,
            "activation_value": activation_value,
            "feature_index": feature_index,
            "num_explanations": num_explanations,
            "description": description,
        }
    )

    print(f"[{i}] {token_str}")

    for topk_index in range(topk_count):

        activation_value_topk = activation_values_topk[i, topk_index].item()
        feature_index_topk = feature_indices_topk[i, topk_index].item()

        # lookup description
        description = feature_index_to_description[feature_index]

        print(f" - {activation_value_topk:.2f} [{feature_index_topk}] := {description}")

df = pd.DataFrame(rows)

[0] <bos>
 - 2028.80 [6631] := the beginning of a text or important markers in a document
 - 781.40 [743] := the beginning of a text or important markers in a document
 - 534.86 [5052] := the beginning of a text or important markers in a document
 - 264.19 [16057] := the beginning of a text or important markers in a document
 - 252.53 [9479] := the beginning of a text or important markers in a document
[1] <
 - 92.68 [11527] := the start of a document
 - 90.01 [8684] := the start of a document
 - 85.16 [5637] := the start of a document
 - 76.85 [864] := the start of a document
 - 54.10 [14266] := the start of a document
[2] scratch
 - 71.79 [6631] := the beginning of a text or important markers in a document
 - 49.45 [5698] := the beginning of a text or important markers in a document
 - 37.94 [8366] := the beginning of a text or important markers in a document
 - 36.45 [9768] := the beginning of a text or important markers in a document
 - 35.28 [3019] := the beginning of a text or im

In [121]:
df[["token_str", "activation_value", "feature_index", "description"]]

Unnamed: 0,token_str,activation_value,feature_index,description
0,<bos>,2028.79834,6631,the beginning of a text or important markers i...
1,<,92.683777,11527,the start of a document
2,scratch,71.790955,6631,the beginning of a text or important markers i...
3,pad,79.127464,6231,strings used in user interface elements
4,>,50.836117,11082,html elements and their attributes
5,After,156.294235,4820,"references to the word ""after"""
6,deleting,81.393402,6631,the beginning of a text or important markers i...
7,those,76.810143,7846,phrases and references to data and reporting
8,logs,86.961067,14,mentions of logarithmic or related mathematic...
9,",",74.480339,4223,questions and conditional statements related t...


In [None]:
for activation_value, feature_index, token in zip(
    activation_values,
    feature_indices,
    example_prompt_as_tokens,
):
    print(f"{activation_value.item()=:.4f}")
    print(f"{feature_index.item()=}")

    feature_index_int = feature_index.item()

    # note: there could just legitimately be features without explanations
    #       that just gives the "this feature has no known explanations"
    had_feature_index_in_explanations_df = (
        explanations_df["feature"].astype(int) == feature_index_int
    ).sum() > 0

    print(f"{had_feature_index_in_explanations_df=}")

    # if had_feature_index_in_explanations_df:

    show_neuronpedia_dashboard(feature_index=feature_index_int)