# SAE-VIS demo

This notebook demos my open-source sparse autoencoder visualizer, which you can see [here](https://www.perfectlynormal.co.uk/blog-sae). Other useful links:

* [GitHub repo](https://github.com/callummcdougall/sae_vis)
* [Developer guide](https://docs.google.com/document/d/10ctbiIskkkDc5eztqgADlvTufs7uzx5Wj8FE_y5petk/edit#heading=h.t3sp1uj6qghd), for people looking to understand the codebase so they can contribute
* [User guide](https://docs.google.com/document/d/1QGjDB3iFJ5Y0GGpTwibUVsvpnzctRSHRLI-0rm6wt_k/edit#heading=h.t3sp1uj6qghd), for people looking to understand how to use all the features of this codebase (reading through this notebook is another option, and it should be mostly self-explanatory)

In this notebook, we demo two different types of vis:

1. **Feature-centric vis**, where you look at a single feature and see e.g. which sequences in a large dataset this feature fires strongest on.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/sae-snip-1.png" width="800">

2. **Prompt-centric vis**, where you input a custom prompt and see which features score highest on that prompt, according to a variety of possible metrics.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/sae-snip-2.png" width="800">
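
To make the two views concrete, here is a toy sketch in plain Python. It assumes a made-up nested list `acts[feature][sequence]` of peak activations (not a structure from `sae_vis`), and it uses peak activation as the only ranking metric, whereas the real prompt-centric vis offers several. The feature-centric view ranks sequences for one feature; the prompt-centric view ranks features for one prompt.

```python
# Toy illustration only: `acts` is a made-up table of peak activations,
# indexed as acts[feature_idx][sequence_idx]. It is not an sae_vis object.
acts = [
    [0.0, 2.5, 0.1, 1.2],  # feature 0
    [3.1, 0.0, 0.0, 0.4],  # feature 1
    [0.2, 0.3, 4.0, 0.0],  # feature 2
]


def top_sequences(acts, feature_idx, k):
    """Feature-centric: indices of the k sequences where one feature fires most."""
    row = acts[feature_idx]
    return sorted(range(len(row)), key=lambda s: row[s], reverse=True)[:k]


def top_features(acts, sequence_idx, k):
    """Prompt-centric: indices of the k features that fire most on one sequence."""
    return sorted(range(len(acts)), key=lambda f: acts[f][sequence_idx], reverse=True)[:k]


print(top_sequences(acts, feature_idx=0, k=2))  # [1, 3]
print(top_features(acts, sequence_idx=0, k=2))  # [1, 2]
```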

# Imports & Installs

In [None]:
try:
    import google.colab # type: ignore
    COLAB = True
    !git clone https://github.com/jbloomAus/mats_sae_training.git
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/eindex.git
    %pip install sae-vis
except ImportError:
    COLAB = False
    # from IPython import get_ipython # type: ignore
    # ipython = get_ipython(); assert ipython is not None
    # ipython.run_line_magic("load_ext", "autoreload")
    # ipython.run_line_magic("autoreload", "2")

import torch
from datasets import load_dataset
import webbrowser
import os
import sys
from transformer_lens import utils, HookedTransformer
from datasets.arrow_dataset import Dataset
from huggingface_hub import hf_hub_download
import time

from sae_vis.utils_fns import get_device
from sae_vis.model_fns import AutoEncoder, AutoEncoderConfig
from sae_vis.data_storing_fns import SaeVisData
from sae_vis.data_config_classes import SaeVisConfig

device = get_device()

torch.set_grad_enabled(False);

# Setup

## Autoencoders

<!-- We're being a bit lazy here, and slicing our autoencoder so that we only take the first 2048 features (i.e. `dict_mult = 1`) rather than all 16384 features. This is literally just to avoid OOMs; you can increase the `DICT_MULT` parameter up to 8 if you'd like. -->

We set up our autoencoder here. You can use your own autoencoder, as long as it has the same parameters `W_enc`, `W_dec`, `b_enc` and `b_dec` (used in the same way) and has a `cfg` attribute which is itself a dataclass with attributes `d_in` and `dict_mult`. The forward pass method doesn't matter; we only ever use the weights directly in this codebase.
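
If you're bringing your own autoencoder, a quick interface check can save a confusing error later. The sketch below is a hypothetical helper (`missing_sae_attributes` is not part of `sae_vis`), and the config field names are an assumption taken from the `AutoEncoderConfig(d_in=..., dict_mult=...)` call later in this notebook; adjust them if your version differs. It uses stand-in objects so that it runs without `torch`.

```python
from dataclasses import dataclass, is_dataclass

# Names required by the description above. The cfg fields follow the
# AutoEncoderConfig(d_in=..., dict_mult=...) call used later in this notebook.
REQUIRED_WEIGHTS = ("W_enc", "W_dec", "b_enc", "b_dec")
REQUIRED_CFG_FIELDS = ("d_in", "dict_mult")


def missing_sae_attributes(sae):
    """Return the required names that `sae` lacks (hypothetical helper, not part of sae_vis)."""
    missing = [name for name in REQUIRED_WEIGHTS if not hasattr(sae, name)]
    cfg = getattr(sae, "cfg", None)
    if cfg is None or not is_dataclass(cfg):
        missing.append("cfg (dataclass)")
    else:
        missing += [f"cfg.{name}" for name in REQUIRED_CFG_FIELDS if not hasattr(cfg, name)]
    return missing


# Stand-in objects so the sketch runs without torch; real weights would be tensors.
@dataclass
class ToyConfig:
    d_in: int
    dict_mult: int


class ToySAE:
    def __init__(self):
        self.cfg = ToyConfig(d_in=4, dict_mult=2)
        self.W_enc, self.W_dec, self.b_enc = [], [], []  # b_dec deliberately omitted


print(missing_sae_attributes(ToySAE()))  # ['b_dec']
```

An empty list means the object exposes every name the check looks for; it says nothing about whether the shapes or semantics match.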

In [None]:
encoder = AutoEncoder.load_from_hf(version="run1")
encoder_B = AutoEncoder.load_from_hf(version="run2")

for k, v in encoder.named_parameters():
    print(f"{k}: {tuple(v.shape)}")

## Models

This library will eventually support non-TransformerLens models, but it doesn't yet. If you're interested in this, please reach out!

<!-- This library supports non-transformerlens models, provided you apply a wrapper around your model with a few specific methods (e.g. a modified `forward` function which returns a tuple of `(logits, activations, resid)`). However, it's much easier to just use a TransformerLens model in most cases! -->

<!-- The code below loads in our GELU-1l transformer model. You can create your transformer model any way you like; all that matters is that:

* Your model has a `forward` method which takes `tokens` and returns a tuple of `(logits, residual, post_activations)`.
* This forward method has a parameter `return_logits`, which is by default `True`, and when `False` it only returns `(residual, post_activations)`.

Provided this is the case, all other code here (including calculating the effect of ablating certain features) doesn't rely on any specific implementation details of the model.

If you're trying to use a particular model, we recommend **creating a wrapper class around your model which has an altered `forward` method** to match the required behaviour. In the case of this notebook, to make it clear that a `HookedTransformer` model is not necessary, we're using a `DemoTransformer` model (code in this repository), which is a very minimal version of the `HookedTransformer` model lacking the features like hooks, caches, etc. -->

In [None]:
model = HookedTransformer.from_pretrained("gelu-1l")

## Data

You can replace the data-loading cell in this section with your own code. The end result should be a 2D tensor of token ids.
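
As a rough sketch of the target shape, `(n_sequences, seq_len)`, the hypothetical helper `chunk_token_ids` below cuts a flat list of token ids into fixed-length rows. It is an illustration in plain Python, not what `utils.tokenize_and_concatenate` does internally; a real pipeline handles details this sketch ignores, and you would convert the result to a `torch.Tensor` before passing it on.

```python
def chunk_token_ids(token_ids, seq_len):
    """Split a flat list of token ids into rows of length `seq_len`, dropping any remainder."""
    n_rows = len(token_ids) // seq_len
    return [token_ids[i * seq_len : (i + 1) * seq_len] for i in range(n_rows)]


flat_ids = list(range(10))  # stand-in for a tokenized corpus
rows = chunk_token_ids(flat_ids, seq_len=4)
print(rows)                       # [[0, 1, 2, 3], [4, 5, 6, 7]]
print((len(rows), len(rows[0])))  # (2, 4)
```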

In [None]:
SEQ_LEN = 128

# Load in the data (it's a Dataset object)
data = load_dataset("NeelNanda/c4-code-20k", split="train")
assert isinstance(data, Dataset)

# Tokenize the data (using a utils function) and shuffle it
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN) # type: ignore
tokenized_data = tokenized_data.shuffle(42)

# Get the tokens as a tensor
all_tokens = tokenized_data["tokens"]
assert isinstance(all_tokens, torch.Tensor)

print(all_tokens.shape)

# Feature-centric vis


Here's a minimal example, which generates data for the first 32 features, and produces the corresponding vis.

Once you open it, you can hover over tokens to see the feature activation magnitudes / most boosted tokens, and you can also use the dropdown in the top left to navigate between features!

In [None]:
from sae_vis.data_config_classes import (
    SaeVisLayoutConfig,
    Column,
    FeatureTablesConfig,
    ActsHistogramConfig,
    LogitsTableConfig,
    LogitsHistogramConfig,
    SequencesConfig,
    PromptConfig, # this one is only used for the prompt-centric vis
)

layout = SaeVisLayoutConfig(
    columns = [
        Column(FeatureTablesConfig(correlated_b_features_table=True)),
        Column(ActsHistogramConfig(), LogitsTableConfig(), LogitsHistogramConfig()),
        Column(SequencesConfig()),
    ]
)
# layout.help()

In [None]:
# Specify the hook point you're using, the features you're analyzing, and the batch size for gathering activations
sae_vis_config = SaeVisConfig(
    hook_point = utils.get_act_name("post", 0),
    features = range(32),
    batch_size = 2048,
    verbose = True,
    feature_centric_layout = layout,
)

# Gather the feature data
sae_vis_data = SaeVisData.create(
    encoder = encoder,
    encoder_B = encoder_B,
    model = model,
    tokens = all_tokens, # type: ignore
    cfg = sae_vis_config,
)

# Save as HTML file & open in browser (or not, if in Colab)
filename = "_feature_vis_demo.html"
sae_vis_data.save_feature_centric_vis(filename, feature_idx=8)

if COLAB: print(f"Download the HTML file and open it in your browser: {filename}")
else: webbrowser.open(filename);

# Feature-centric vis: customizing layout

To customize our view, we pass an `SaeVisLayoutConfig` object into our main `SaeVisConfig`. This works by specifying a list of components for each column, with each component being customized by its associated config object.

There's also a `help` method which explains all of these arguments. Run the cell below to see the output for the default layout (i.e. the one that produced the vis from the cell above).

In [None]:
from sae_vis.data_config_classes import (
    SaeVisLayoutConfig,
    Column,
    FeatureTablesConfig,
    ActsHistogramConfig,
    LogitsTableConfig,
    LogitsHistogramConfig,
    SequencesConfig,
    PromptConfig, # this one is only used for the prompt-centric vis
)

layout = SaeVisLayoutConfig(
    columns = [
        Column(FeatureTablesConfig()),
        Column(ActsHistogramConfig(), LogitsTableConfig(), LogitsHistogramConfig()),
        Column(SequencesConfig()),
    ]
)
layout.help()

Here's an example of how you might customize the layout, by changing around a few components & altering how they look. Note that we have to pass this `layout` object into the `SaeVisData.create` method, because some of the arguments determine *what data* we gather.

<!-- Note, you can change the layout object via `multi_feature_data.fvc.layout = layout` (here `fvc` stands for feature vis config, this is how the `FeatureVisConfig` object is stored), and then re-run the `get_html` method. However, this might not work if your `SaeVisData` object doesn't contain as much data as you'd need to create the layout (an obvious example would be if you generated your data using only 3 rows in the feature tables, then created a vis using 5 tables - obviously you'd need to create that data again!). Where possible, we've tried to make the error messages explicit, so if you change the layout in a way which requires data you don't have, the error message should make it clear what's happened. If anything seems ambiguous, please make a PR! -->

In [None]:
# Create custom layout
layout = SaeVisLayoutConfig(
    columns = [
        Column(SequencesConfig(stack_mode='stack-all', buffer=(10, 5), compute_buffer=False, n_quantiles=0, top_acts_group_size=30), width=650),
        Column(ActsHistogramConfig(), FeatureTablesConfig(n_rows=5), width=500),
    ],
    height = 1000,
)
layout.help()

# Set all config parameters
feature_vis_config_custom_layout = SaeVisConfig(
    hook_point = utils.get_act_name("post", 0),
    features = range(32),
    batch_size = 2048,
    feature_centric_layout = layout,
)

# Generate data
multi_feature_data_custom = SaeVisData.create(
    encoder = encoder,
    encoder_B = encoder_B,
    model = model,
    tokens = all_tokens, # type: ignore
    cfg = feature_vis_config_custom_layout,
)

# Save data (and open it using webbrowser)
filename = "_feature_vis_demo_custom.html"
multi_feature_data_custom.save_feature_centric_vis(filename, feature_idx=8)
webbrowser.open(filename);

# Feature-centric vis: multi-layer models

So far we've only worked with 1-layer models. Let's see what happens when we use a multi-layer model. Thankfully, Joseph Bloom has trained some excellent SAEs on GPT2-small, so we can use one of them.

First, we load the model, and the autoencoder.

In [None]:
# Get gpt2 model (this is the easy part)
gpt2 = HookedTransformer.from_pretrained("gpt2-small")

# Download the model from HuggingFace
layer = 2
REPO_ID = "jbloom/GPT2-Small-SAEs"
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# Make sure Joseph's lib is in the path, or else the load will fail
filepath = os.getcwd() + "/mats_sae_training"
if filepath not in sys.path: sys.path.append(filepath)

# Load the model, and make sure the keys are what we expect
obj = torch.load(path, map_location="cpu")
assert set(obj["state_dict"].keys()) == {'W_enc', 'b_enc', 'W_dec', 'b_dec'}

# Create the autoencoder from config, and load in its weights
# ! This is a bit hacky, but will hopefully improve soon, once I start using Joseph's AutoEncoder class rather than mine
cfg = AutoEncoderConfig(
    d_in = obj["cfg"].d_in,
    dict_mult = obj["cfg"].expansion_factor,
)
gpt2_sae = AutoEncoder(cfg)
gpt2_sae.load_state_dict(obj["state_dict"]);

And now let's get our vis. Feel the force!

Note that the runtime here is dominated by the forward passes, because we're only getting one feature. Also note that the vis has no dropdown to select features, because we only have one feature.
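
Because forward passes dominate, memory use and speed depend mostly on how many sequences go through the model at once. The sketch below assumes that `minibatch_size_tokens` caps the number of sequences per forward pass; that reading is an assumption, not something this notebook confirms, and `split_into_minibatches` is a hypothetical helper rather than an `sae_vis` function.

```python
def split_into_minibatches(sequences, minibatch_size):
    """Yield consecutive slices of at most `minibatch_size` sequences."""
    for start in range(0, len(sequences), minibatch_size):
        yield sequences[start : start + minibatch_size]


# 2048 stand-in sequences split into minibatches of 128, mirroring the
# batch_size / minibatch_size_tokens values used in the next cell.
sequences = list(range(2048))
minibatches = list(split_into_minibatches(sequences, minibatch_size=128))
print(len(minibatches), len(minibatches[0]))  # 16 128
```

Under that assumption, lowering the minibatch size trades speed for a smaller peak memory footprint, which is useful if you hit out-of-memory errors on a larger model.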

In [None]:
torch.cuda.empty_cache()
import gc
gc.collect()

test_feature_idx_gpt = 7650

feature_vis_config_gpt = SaeVisConfig(
    hook_point = obj["cfg"].hook_point,
    features = test_feature_idx_gpt,
    batch_size = 2048,
    minibatch_size_tokens = 128,
    verbose = True,
)

sae_vis_data_gpt = SaeVisData.create(
    encoder = gpt2_sae,
    model = gpt2,
    tokens = all_tokens, # type: ignore
    cfg = feature_vis_config_gpt,
)

filename = "_feature_vis_demo_gpt.html"
sae_vis_data_gpt.save_feature_centric_vis(filename)

if COLAB: print(f"Download the HTML file and open it in your browser: {filename}")
else: webbrowser.open(filename);

# Prompt-centric vis

In this vis, we pick a prompt and see which features score highest on that prompt, according to a variety of possible metrics.

Let's first do the minimal thing, which just involves the same steps as above, except we call `save_prompt_centric_vis` on our `sae_vis_data` object instead. This also requires passing in our prompt, plus optional `seq_pos` and `metric` arguments, which determine which dropdown options are selected by default when the page loads.
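
`seq_pos` is an index into the tokenized prompt. The sketch below shows the lookup pattern on a hand-written token list. The list is illustrative only (the real tokens come from `model.tokenizer.tokenize(prompt)` and may be split differently), `find_token_position` is a hypothetical helper, and `Ġ` is how byte-level BPE tokenizers in the GPT-2 style display a leading space.

```python
def find_token_position(str_tokens, target):
    """Return the first index of `target`, with a readable error if it is absent."""
    try:
        return str_tokens.index(target)
    except ValueError:
        raise ValueError(f"{target!r} not found in tokens: {str_tokens}") from None


# Hand-written, illustrative tokenization; a real tokenizer may split differently.
str_tokens = ["'", "first", "_", "name", "':", "Ġ('", "django", ".", "db"]
seq_pos = find_token_position(str_tokens, "Ġ('")
print(seq_pos)  # 5
```

Wrapping `.index` this way is optional; the benefit is an error message that shows the actual tokens when the target string isn't a token in its own right.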

In [None]:
torch.cuda.empty_cache()
import gc
gc.collect()

# Specify the hook point you're using, the features you're analyzing, and the batch size for gathering activations
sae_vis_config = SaeVisConfig(
    hook_point = utils.get_act_name("post", 0),
    features = range(32),
    batch_size = 2048,
    verbose = True,
)

# Gather the feature data
sae_vis_data = SaeVisData.create(
    encoder = encoder,
    encoder_B = encoder_B,
    model = model,
    tokens = all_tokens, # type: ignore
    cfg = sae_vis_config,
)

In [None]:
prompt = "'first_name': ('django.db.models.fields"

seq_pos = model.tokenizer.tokenize(prompt).index("Ġ('") # type: ignore
metric = 'act-quantiles'

filename = "_prompt_vis_demo.html"

sae_vis_data.save_prompt_centric_vis(
    prompt = prompt,
    filename = filename,
    seq_pos = seq_pos, # optional argument, to determine the default option when the page loads
    metric = metric, # optional argument, to determine the default option when the page loads
)
if COLAB: print(f"Download the HTML file and open it in your browser: {filename}")
else: webbrowser.open(filename);

Just like the feature-centric vis, we can also customize the view. The feature-centric view was customized by the `sae_vis_data.cfg.feature_centric_layout` object, which defaulted to:

```python
SaeVisLayoutConfig(
    columns = [
        Column(FeatureTablesConfig()),
        Column(ActsHistogramConfig(), LogitsTableConfig(), LogitsHistogramConfig()),
        Column(SequencesConfig()),
    ]
)
```

By contrast, the prompt-centric view is customized by the `sae_vis_data.cfg.prompt_centric_layout` object, which defaults to:

```python
SaeVisLayoutConfig(
    columns = [
        Column(PromptConfig(), ActsHistogramConfig(), LogitsTableConfig(n_rows=5), SequencesConfig(n_quantiles=0), width=450),
    ],
)
```

You can create your own of these just like for the feature-centric view, although there are 2 extra conditions:

1. **You must have only one column** (since in this vis, we dedicate one column to each of the top features).
2. **The prompt-centric vis must not require more data than the feature-centric vis.** This is because the data contained in `sae_vis_data` is determined by the feature-centric vis layout, and when generating the prompt-centric vis we can't use data that's not there. As a simple example, if we used `LogitsTableConfig(n_rows=10)` in our feature-centric config, we couldn't use `LogitsTableConfig(n_rows=15)` when generating our prompt-centric vis.

Both these conditions will be checked when you initialize your config objects, before you've gathered any data.
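
The two conditions can be sketched as a small check. This is a toy model, not the library's validation code: `ToyLayout` is a hypothetical stand-in that tracks only the column count and the number of logits-table rows, whereas the real config objects carry many more settings.

```python
from dataclasses import dataclass


@dataclass
class ToyLayout:
    """Hypothetical stand-in for a layout: column count and logits-table rows only."""
    n_columns: int
    logits_table_rows: int


def check_prompt_layout(feature_layout, prompt_layout):
    """Return a list of violated conditions (empty means the pair is compatible)."""
    problems = []
    if prompt_layout.n_columns != 1:
        problems.append("prompt-centric layout must have exactly one column")
    if prompt_layout.logits_table_rows > feature_layout.logits_table_rows:
        problems.append("prompt-centric layout needs more logits-table rows than were gathered")
    return problems


feature_layout = ToyLayout(n_columns=3, logits_table_rows=10)
ok = check_prompt_layout(feature_layout, ToyLayout(n_columns=1, logits_table_rows=5))
bad = check_prompt_layout(feature_layout, ToyLayout(n_columns=1, logits_table_rows=15))
print(ok)        # []
print(len(bad))  # 1
```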

Here's an example customized prompt-centric layout. Note, this example shows how you can manually change `sae_vis_data.cfg.prompt_centric_layout` without re-generating the data (and the same is true for the feature data), provided you're not trying to use more data than you've already gathered.

In [None]:
sae_vis_data.cfg.prompt_centric_layout = SaeVisLayoutConfig(
    columns = [
        Column(PromptConfig(), LogitsTableConfig(), LogitsHistogramConfig(), ActsHistogramConfig(), SequencesConfig(n_quantiles=10), width=550),
    ],
    height = 1200,
)

filename = "_prompt_vis_demo_custom.html"

sae_vis_data.save_prompt_centric_vis(
    prompt = prompt,
    filename = filename,
    seq_pos = seq_pos,
    metric = metric,
)

if COLAB: print(f"Download the HTML file and open it in your browser: {filename}")
else: webbrowser.open(filename);

# Saving data as JSON

You can save your data as a JSON file, as follows. Note that we don't actually save much storage space by doing this, because the HTML is already stored very compactly (essentially the JSON is dumped directly into the HTML page; most of the extra size comes from the JavaScript functions which populate the empty HTML elements with the data).
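
To see why the two files end up similar in size, here is a minimal sketch of the "JSON dumped into an HTML page" idea using only the standard library. The template, the `render_html` helper and the placeholder JavaScript are all hypothetical; they are not the template `sae_vis` actually uses.

```python
import json

# Hypothetical page template: a script tag holding the data, plus a stand-in
# for the JavaScript that would populate the page. Not sae_vis's real template.
TEMPLATE = """<html><body><div id="vis"></div>
<script>const DATA = {data};</script>
<script>{js}</script>
</body></html>"""


def render_html(data, js):
    """Embed `data` as compact JSON inside the HTML template."""
    return TEMPLATE.format(data=json.dumps(data, separators=(",", ":")), js=js)


data = {"feature_8": {"acts": [0.1, 0.0, 2.3] * 100}}
js = "/* rendering code would go here */"

json_str = json.dumps(data, separators=(",", ":"))
html_str = render_html(data, js)

# The HTML is the JSON plus a fixed overhead from the template and the JS.
print(len(html_str) - len(json_str))
```

In this toy setup the overhead is a constant that doesn't grow with the data, so for large datasets the two sizes converge.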

In [None]:
json_filepath = "_feature_vis_demo.json"
html_filepath = "_feature_vis_demo.html"

# Save
t0 = time.time()
sae_vis_data.save_json(filename=json_filepath)
print(f"Saved in {time.time() - t0:.2f} seconds")

# Load back in (supplying our config, model & encoders which aren't saved)
t0 = time.time()
sae_vis_data_loaded = SaeVisData.load_json(
    filename=json_filepath,
    cfg=sae_vis_data.cfg,
    model=model,
    encoder=encoder,
    encoder_B=encoder_B,
)
assert isinstance(sae_vis_data_loaded, SaeVisData)
print(f"Loaded in {time.time() - t0:.2f} seconds\n")

# Check we can still use it
sae_vis_data_loaded.save_feature_centric_vis(html_filepath, feature_idx=8)
webbrowser.open(html_filepath);

# Print out sizes, to see how much we save using json (answer = not much, because the HTML is already quite efficient!)
print(f"Size of JSON: {os.path.getsize(json_filepath) / 1e6:.3f} MB")
print(f"Size of HTML: {os.path.getsize(html_filepath) / 1e6:.3f} MB")

Same thing for prompt-centric view: