# Imports & Installs

In [1]:
from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

from IPython.display import display, HTML
import torch
from datasets import load_dataset
import pickle
import webbrowser
import os
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from sae_vis.model_fns import AutoEncoder, DemoTransformer, DemoTransformerConfig
from sae_vis.data_fetching_fns import get_feature_data, get_prompt_data
from sae_vis.data_storing_fns import FeatureVizParams, MultiFeatureData, MultiPromptData
from sae_vis.utils_fns import create_vocab_dict, tokenize_and_concatenate

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

torch.set_grad_enabled(False);

  _torch_pytree._register_pytree_node(


# Setup

## Autoencoders

<!-- We're being a bit lazy here, and slicing our autoencoder so that we only take the first 2048 features (i.e. `dict_mult = 1`) rather than all 16384 features. This is literally just to avoid OOMs; you can increase the `DICT_MULT` parameter up to 8 if you'd like. -->

We set up our autoencoder here. You can use your own autoencoder, as long as it has the same parameters `W_enc`, `W_dec`, `b_enc` and `b_dec` (used in the same way) and has a `cfg` attribute which itself is a dataclass with attributes `d_mlp` and `dict_mult`. The forward pass method doesn't matter; we only ever use the weights directly in this codebase.

In [2]:
encoder = AutoEncoder.load_from_hf(version="run1")
encoder_B = AutoEncoder.load_from_hf(version="run2")

for k, v in encoder.named_parameters():
    print(f"{k}: {tuple(v.shape)}")

W_enc: (2048, 16384)
W_dec: (16384, 2048)
b_enc: (16384,)
b_dec: (2048,)


## Models

The code below loads in our GELU-1l transformer model. You can create your transformer model any way you like; all that matters is that:

* Your model has a `forward` method which takes `tokens` and returns a tuple of `(logits, residual, post_activations)`.
* This forward method has a parameter `return_logits`, which is by default `True`, and when `False` it only returns `(residual, post_activations)`.

Provided this is the case, all other code here (including calculating the effect of ablating certain features) doesn't rely on any specific implementation details of the model.

If you're trying to use a particular model, we recommend **creating a wrapper class around your model which has an altered `forward` method** to match the required behaviour. In the case of this notebook, to make it clear that a `HookedTransformer` model is not necessary, we're using a `DemoTransformer` model (code in this repository), which is a very minimal version of the `HookedTransformer` model lacking the features like hooks, caches, etc.

In [3]:
from transformer_lens import HookedTransformer
model = HookedTransformer.from_pretrained("gelu-1l")
tokenizer = model.tokenizer

path = "C:/Users/calsm/Documents/AI Alignment/hf/gelu-1l-sae"

# save tokenizer as pkl
with open(path + "/tokenizer.pkl", "wb") as f:
    pickle.dump(tokenizer, f)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loaded pretrained model gelu-1l into HookedTransformer


FileNotFoundError: [Errno 2] No such file or directory: 'C:/Users/calsm/Documents/AI Alignment/hf/gelu-1l-sae/tokenizer.pkl'

In [4]:
# Load our state dict from HuggingFace 
REPO_ID = "callummcdougall/gelu-1l"
tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.pkl")
with open(tokenizer_path, "rb") as f:
    tokenizer = pickle.load(f)

# Load our state dict from HuggingFace 
weights_path = hf_hub_download(repo_id=REPO_ID, filename="gelu-1l-state-dict.pt")
state_dict = torch.load(weights_path, map_location=device)

# Create config object for our tokenizer
# (see model_fns.py for an explanation of this, and to understand the architecture)
cfg = DemoTransformerConfig(
    act_fn = 'gelu',
    d_head = 64,
    d_mlp = 2048,
    d_model = 512,
    d_vocab = 48262,
    n_ctx = 1024,
    n_heads = 8,
    n_layers = 1,
    device = device,
    dtype = torch.float32,
    normalization_type ='LNPre',
)

# Create our model, and load in the state dict
model = DemoTransformer(cfg, tokenizer)
_ = model.load_state_dict(state_dict)

## Data

Obviously you can replace this code with your own data loading code. You should eventually have a 2D tensor of token ids.

In [5]:
SEQ_LEN = 128

data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN)
tokenized_data = tokenized_data.shuffle(42)
all_tokens: torch.Tensor = tokenized_data["tokens"]

print(all_tokens.shape)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


torch.Size([215402, 128])


# Creating visualisations #1 (feature-centric)

First, we have a dataclass which contains all the relevant hyperparameters for creating our visualization. 

In [6]:
feature_viz_params = FeatureVizParams()
feature_viz_params.help()

Next, we actually get the feature data. On an A100 (e.g. Colab Pro+) this should take less than a minute.

In [7]:
feature_data = get_feature_data(
    encoder = encoder,
    encoder_B = encoder_B,
    model = model,
    tokens = all_tokens,
    fvp = feature_viz_params,
)

Now, we generate the HTML. **The `webbrowser` command will not work for you in Colab; you'll need to manually download & open the HTML file from your Colab file storage.**

In [None]:
test_idx = 8
filepath = "feature_viz_demo.html"

In [None]:
html_str = feature_data.feature_data_dict[test_idx].get_html()

display(HTML(html_str))

with open(filepath, "w") as f:
    f.write(html_str)

result = webbrowser.open(filepath)

You can also generate smaller plots. If you don't care about getting the sequences in the activation quantiles (which are the things that take the most time to generate), you can pass `n_groups=0` into the `FeatureVizParams` dataclass. This roughly halves the time taken to generate the visualisation.

In [None]:
feature_viz_params = FeatureVizParams(n_groups=0)

feature_data = get_feature_data(
    encoder = encoder,
    encoder_B = encoder_B,
    model = model,
    tokens = all_tokens,
    fvp = feature_viz_params,
)

In [None]:
html_str = feature_data.feature_data_dict[test_idx].get_html(width=320)
display(HTML(html_str))
with open(filepath, "w") as f:
    f.write(html_str)
result = webbrowser.open(filepath)

And if you want to be even more minimal, you can remove the tables on the left hand side. When you do this, the middle column will be rearranged by default, to make everything more compact. This visualization also shows what `border = False` looks like.

In [None]:
feature_viz_params = FeatureVizParams(
    n_groups = 0,
    first_group_size = 15,
    include_left_tables = False,
    border = False,
)

feature_data = get_feature_data(
    encoder = encoder,
    encoder_B = encoder_B,
    model = model,
    tokens = all_tokens,
    fvp = feature_viz_params,
)

In [None]:
html_str = feature_data[test_idx].get_html()
display(HTML(html_str))
with open(filepath, "w") as f:
    f.write(html_str)
result = webbrowser.open(filepath)

# Creating visualisations #2 (prompt-centric)

First we create our vocab dict, via a helper function which allows us to get nice HTML representations of our tokens (rather than things which mess up our HTML, e.g. actual line breaks). You should do this on your model's tokenizer, since this `vocab_dict` will be used in subsequent functions. I've only worked with the GPT2 tokenizer, so if this code fails in some way for a different tokenizer, please let me know!

In [None]:
vocab_dict = create_vocab_dict(model.tokenizer)

Next, we pick a prompt and generate the data for it. The `get_prompt_data` function requires `feature_data` as input, because it needs things like the max-activating sequences for this feature. Note, we're using the `feature_data` object with `n_groups=0` and `include_left_tables=False` - this is because we don't actually need these for the prompt-centric visualization. If you're only trying to generate the prompt-centric view, it's a good idea to have these parameters set to these values, because it will speed up the process.

We don't have an extra dataclass like `FeatureVizParams` to wrap our arguments in, because there are very few. Some of them (e.g. `first_group_size`) are inherited from the `FeatureVizParams` object which was used to generate the `feature_data` which is supplied. The only important argument we need to use is `num_top_features`, which is the max number of top-scoring features which are displayed for any given prompt & metric. There's also the argument `verbose` (default False) which controls whether progress bars are printed.

In [None]:
prompt = "'first_name': ('django.db.models.fields"

str_toks = model.tokenizer.tokenize(prompt)
print(str_toks)

prompt_data = get_prompt_data(
    encoder = encoder,
    model = model,
    prompt = prompt,
    feature_data = feature_data,
    num_top_features = 10,
)

Lastly, from this data we create our visualization. We've chosen to examine the `"loss_effect"` on the `django` token, i.e. showing the features whose contributions most reduce the loss on this token.

In [None]:
str_score = "loss_effect"
seq_pos = str_toks.index("django")

html_str = prompt_data.get_html(seq_pos, str_score, vocab_dict)

display(HTML(html_str))

filepath = "prompt_viz_demo.html"
with open(filepath, "w") as f:
    f.write(html_str)

result = webbrowser.open(filepath)

Alternatively, you can use the `"act_size"` or `"act_quantile"` metrics (we recommend the latter) on the `Ġ('` token, i.e. the token immediately before `django`. Remember, we have to include this `Ġ` character at the front of the token (which represents the space character), although this will depend on what tokenizer your model is using.

In [None]:
str_score = "act_quantile"
seq_pos = str_toks.index("Ġ('")

html_str = prompt_data.get_html(seq_pos, str_score, vocab_dict)

display(HTML(html_str))

filepath = "user_prompt.html"
with open(filepath, "w") as f:
    f.write(html_str)

result = webbrowser.open(filepath)

# Saving data

Obviously the HTML strings can be saved, either as strings or as regular HTML files. If you want something more compact, you can pickle the dataclasses:

In [None]:
# Save
with open("feature_data.pkl", "wb") as f:
    pickle.dump(feature_data, f)

# Load
with open("feature_data.pkl", "rb") as f:
    feature_data: MultiFeatureData = pickle.load(f)

# Delete
os.remove("feature_data.pkl")

# Visualize the loaded data, to check it works
html_str = feature_data[test_idx].get_html()
display(HTML(html_str))

And for the prompt-centric visualisation:

In [None]:
# Save
with open("prompt_data.pkl", "wb") as f:
    pickle.dump(prompt_data, f)

# Load
with open("prompt_data.pkl", "rb") as f:
    prompt_data: MultiPromptData = pickle.load(f)

# Delete
os.remove("prompt_data.pkl")

# Visualize the loaded data, to check it works
html_str = prompt_data.get_html(seq_pos, str_score, vocab_dict)
display(HTML(html_str))