# HookedSAETransformer Demo

This tutorial is adapted from one written by [Connor Kissane](https://github.com/ckkissane), with edits reflecting changes to the code that was submitted as a [PR](https://github.com/neelnanda-io/TransformerLens/pull/536) to [TransformerLens](https://github.com/neelnanda-io/TransformerLens). HookedSAETransformer is implemented in TransformerLens, but we have duplicated and adapted the code so that the same functionality exists in SAE Lens, where it is easier to use alongside SAE analysis code.

----

HookedSAETransformer is a lightweight extension of HookedTransformer that allows you to attach Sparse Autoencoders to activations. This makes it easy to do exploratory analysis such as running inference with SAEs attached, caching SAE feature activations, and intervening on SAE activations with hooks.

I (Connor Kissane) implemented this to accelerate research on [Attention SAEs](https://www.lesswrong.com/posts/DtdzGwFh9dCfsekZZ/sparse-autoencoders-work-on-attention-layer-outputs) based on helpful suggestions from Neel Nanda and Arthur Conmy, and found that it was well worth the time and effort. With research on dictionary learning progressing rapidly, I'm not confident that this is the ultimate design, but I hope other researchers will find the library useful. This notebook demonstrates how it works and how to use it.

# Setup

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install saelens

except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload edited code (e.g. HookedTransformer) without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Jupyter notebook - intended for development only!




In [2]:
import torch
import transformer_lens.utils as utils

import plotly.express as px
import tqdm
from functools import partial
import einops

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
    
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x2af7b0f50>

# Loading and Running Models

Just like a [HookedTransformer](https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Loading-and-Running-Models), we can load any model that's supported in TransformerLens with `HookedSAETransformer.from_pretrained(MODEL_NAME)`. In this demo we'll use GPT-2 small.

In [3]:
from sae_lens.analysis.hooked_sae_transformer import HookedSAETransformer
model: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2-small").to(device)

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  mps


We haven't attached any SAEs, so the model will behave exactly like a HookedTransformer. We'll explore the main features of HookedSAETransformer on the classic IOI task, so let's first sanity check that GPT2-small can do the IOI task without any SAEs attached:

In [4]:
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
names = [
    (" John", " Mary",),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]
# List of prompts
prompts = []
# List of answers, in the format (correct, incorrect)
answers = []
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []
for i in range(len(prompt_format)):
    for j in range(2):
        answers.append((names[i][j], names[i][1 - j]))
        answer_tokens.append(
            (
                model.to_single_token(answers[-1][0]),
                model.to_single_token(answers[-1][1]),
            )
        )
        # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.
        prompts.append(prompt_format[i].format(answers[-1][1]))
answer_tokens = torch.tensor(answer_tokens).to(device)
print(prompts)
print(answers)

['When John and Mary went to the shops, Mary gave the bag to', 'When John and Mary went to the shops, John gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']
[(' John', ' Mary'), (' Mary', ' John'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]


In [5]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()

tokens = model.to_tokens(prompts, prepend_bos=True)
original_logits, cache = model.run_with_cache(tokens)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print(f"Original average logit diff: {original_average_logit_diff}")
original_per_prompt_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
print(f"Original per prompt logit diff: {original_per_prompt_logit_diff}")

Original average logit diff: 3.5518882274627686
Original per prompt logit diff: tensor([3.2016, 3.3367, 2.7095, 3.7975, 1.7204, 5.2812, 2.6008, 5.7674],
       device='mps:0')
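
To make the `gather` step in the metric concrete, here is a minimal sketch on a toy tensor. The shapes and values are made up for illustration (the function name `logit_diff_per_prompt` is ours); only the indexing logic mirrors `logits_to_ave_logit_diff` above.

```python
import torch

def logit_diff_per_prompt(logits, answer_tokens):
    # Keep only the logits at the final position: [batch, vocab]
    final_logits = logits[:, -1, :]
    # Pick out the (correct, incorrect) logits for each prompt: [batch, 2]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    return answer_logits[:, 0] - answer_logits[:, 1]

# Toy logits with batch=2, seq=3, vocab=5 (illustrative values only)
logits = torch.zeros(2, 3, 5)
logits[0, -1] = torch.tensor([0.0, 4.0, 1.0, 0.0, 0.0])
logits[1, -1] = torch.tensor([2.0, 0.0, 0.0, 0.5, 0.0])

# One (correct_token, incorrect_token) pair per prompt
answer_tokens = torch.tensor([[1, 2], [3, 0]])

diffs = logit_diff_per_prompt(logits, answer_tokens)
print(diffs.tolist())       # [3.0, -1.5]
print(diffs.mean().item())  # 0.75
```

A positive value means the model prefers the correct name; the second toy prompt shows what a failure looks like.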


# Attach SAEs

The key feature of HookedSAETransformer is being able to attach an SAE to any activation with `model.attach_sae(sae)`, where `sae` is a HookedSAE.

Let's load a GPT-2 small residual stream SAE and attach it to the model.

In [6]:
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# Download only the SAE for this hook point for now; we'll load the rest later.
sparse_autoencoders, sparsities = get_gpt2_res_jb_saes("blocks.7.hook_resid_pre", device=device)
specific_hook_point = "blocks.7.hook_resid_pre"
sparse_autoencoder = sparse_autoencoders[specific_hook_point]
model.attach_sae(sparse_autoencoder)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  1.69it/s]


When you attach an SAE, it gets stored in `model.acts_to_saes`, a dictionary that maps the activation name to the HookedSAE that is attached. Now that these SAEs are attached, we can just run the model like a normal HookedTransformer, but the activations will be replaced with the reconstructed activations from the corresponding SAEs!

This is useful for testing how good SAEs are on different tasks: here let's check the IOI logit diff with the L7 residual stream SAE attached, and see that we still recover about 96% of the logit diff relative to zero ablation.

In [7]:
logits_with_saes, cache = model.run_with_cache(tokens)
average_logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)
print(f"Average logit diff with SAEs: {average_logit_diff_with_saes}")
per_prompt_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)

Average logit diff with SAEs: 3.4094748497009277


In [8]:
# you can see the hooks that are now part of the cache here:
display([k for k in cache.keys() if "blocks.7.hook_resid_pre" in k])

['blocks.7.hook_resid_pre.hook_sae_input',
 'blocks.7.hook_resid_pre.hook_sae_acts_pre',
 'blocks.7.hook_resid_pre.hook_sae_acts_post',
 'blocks.7.hook_resid_pre.hook_sae_output']
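
The four hook names above line up with the stages of an SAE forward pass. The sketch below assumes a standard ReLU SAE with randomly initialised toy weights; the actual HookedSAE may differ in details such as bias handling or normalisation, so treat this as an illustration of the stages rather than the library's implementation.

```python
import torch

torch.manual_seed(0)
d_in, d_sae = 4, 8

# Randomly initialised toy parameters (a real SAE would be trained)
W_enc = torch.randn(d_in, d_sae)
b_enc = torch.zeros(d_sae)
W_dec = torch.randn(d_sae, d_in)
b_dec = torch.zeros(d_in)

def toy_sae_forward(x):
    # Record each intermediate value under the matching hook name
    stages = {}
    stages["hook_sae_input"] = x
    stages["hook_sae_acts_pre"] = (x - b_dec) @ W_enc + b_enc
    stages["hook_sae_acts_post"] = torch.relu(stages["hook_sae_acts_pre"])
    stages["hook_sae_output"] = stages["hook_sae_acts_post"] @ W_dec + b_dec
    return stages

x = torch.randn(2, 3, d_in)  # [batch, pos, d_model]
stages = toy_sae_forward(x)
for name, value in stages.items():
    print(name, tuple(value.shape))
```

The feature activations we inspect later live at `hook_sae_acts_post`, and the reconstruction that replaces the original activation lives at `hook_sae_output`.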

In [9]:
def zero_ablate_resid(resid, hook, pos=None):
    if pos is None:
        resid[:] = 0.
    else:
        resid[:, pos, :] = 0.
    return resid

layers = [7]
act_names = [utils.get_act_name('resid_pre', layer) + ".hook_sae_output" for layer in layers]
zero_abl_logits = model.run_with_hooks(
    tokens,
    return_type="logits",
    fwd_hooks=[
        (act_name, zero_ablate_resid) for act_name in act_names
    ]
)

per_prompt_zero_abl_logit_diff = logits_to_ave_logit_diff(zero_abl_logits, answer_tokens, per_prompt=True)
avg_zero_abl_logit_diff = logits_to_ave_logit_diff(zero_abl_logits, answer_tokens)
print(f"Zero ablated logit diff: {avg_zero_abl_logit_diff}")

Zero ablated logit diff: 0.0


In [10]:
average_logit_diff_recovered = (average_logit_diff_with_saes - avg_zero_abl_logit_diff) / (original_average_logit_diff - avg_zero_abl_logit_diff)
print(f"average_logit_diff_recovered", average_logit_diff_recovered.item())

average_logit_diff_recovered 0.9599049091339111
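
The recovered fraction is simple arithmetic, so it can be checked by hand. A small sketch using the three averages printed above (the helper name `fraction_recovered` is ours, not part of any library):

```python
def fraction_recovered(with_sae: float, clean: float, ablated: float) -> float:
    # 1.0 means the SAE run matches the clean model; 0.0 means it is no better than ablation
    return (with_sae - ablated) / (clean - ablated)

recovered = fraction_recovered(
    with_sae=3.4094748497009277,  # average logit diff with the SAE attached
    clean=3.5518882274627686,     # original average logit diff
    ablated=0.0,                  # zero-ablated logit diff
)
print(f"{recovered:.4f}")  # 0.9599
```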


In [11]:
from typing import List
import plotly.graph_objects as go
def show_avg_logit_diffs(x_axis: List[str], per_prompt_logit_diffs: List[torch.Tensor]):
    # Bar height is the mean logit diff over prompts; error bars show the standard deviation
    y_data = [per_prompt_logit_diff.mean().item() for per_prompt_logit_diff in per_prompt_logit_diffs]
    error_y_data = [per_prompt_logit_diff.std().item() for per_prompt_logit_diff in per_prompt_logit_diffs]

    fig = go.Figure(data=[go.Bar(
        x=x_axis,
        y=y_data,
        error_y=dict(
            type='data',  # specifies that the actual values are given
            array=error_y_data,  # the magnitudes of the errors
            visible=True  # make error bars visible
        ),
    )])

    # Customize layout
    fig.update_layout(title_text=f'Logit Diff after Interventions',
                    xaxis_title_text='Intervention',
                    yaxis_title_text='Logit diff',
                    plot_bgcolor='white')

    # Show the figure
    fig.show()

x_axis = ['Clean Baseline', "Zero Abl [L7]", "With SAEs [L7]"]
per_prompt_logit_diffs = [
    original_per_prompt_logit_diff,
    per_prompt_zero_abl_logit_diff,
    per_prompt_diff_with_saes]

show_avg_logit_diffs(x_axis, per_prompt_logit_diffs)

## "Attach" vs "Turn on" SAEs

Eventually you may not want to run with SAEs that you previously attached. You can deactivate them with `model.turn_saes_off(act_names)`, where `act_names` is the list of activations whose SAEs you want to deactivate. You can also turn all of them off with `model.turn_saes_off()`, which is often simpler in practice:

In [12]:
print("SAEs turned on before:", model.get_saes_status())
model.turn_saes_off()
print("SAEs turned on after:", model.get_saes_status())

SAEs turned on before: {'blocks.7.hook_resid_pre': True}
SAEs turned on after: {'blocks.7.hook_resid_pre': False}


When you turn off an SAE it is still 'attached', so you can turn it back on with `model.turn_saes_on(act_names)`. You can also just turn on all of the SAEs you've attached with `model.turn_saes_on()`.

When you attach an SAE with `model.attach_sae(sae)`, it automatically gets turned on by default. You can also attach SAEs without turning them on with `model.attach_sae(sae, turn_on=False)`. This is useful if you want to attach a bunch of SAEs upfront and turn on different ones later. Let's attach GPT-2 small residual stream SAEs on every layer, but without turning any on, so we can easily use them later.

In [13]:
sparse_autoencoders, sparsities = get_gpt2_res_jb_saes(device=device)
print("Attached SAEs before:", model.acts_to_saes.keys())
for sae in sparse_autoencoders.values():
    # Attach without turning on, so we can choose which SAEs to use later
    model.attach_sae(sae, turn_on=False)
print("Attached SAEs after:", model.acts_to_saes.keys())

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:07<00:00,  1.68it/s]


Attached SAEs before: dict_keys(['blocks.7.hook_resid_pre'])
Attached SAEs after: dict_keys(['blocks.7.hook_resid_pre', 'blocks.0.hook_resid_pre', 'blocks.1.hook_resid_pre', 'blocks.2.hook_resid_pre', 'blocks.3.hook_resid_pre', 'blocks.4.hook_resid_pre', 'blocks.5.hook_resid_pre', 'blocks.6.hook_resid_pre', 'blocks.8.hook_resid_pre', 'blocks.9.hook_resid_pre', 'blocks.10.hook_resid_pre', 'blocks.11.hook_resid_pre', 'blocks.11.hook_resid_post'])


# Run with SAEs

Sometimes we want to rapidly run with different combinations of SAEs, in which case keeping track of what SAEs are on / off can be a hassle.

If you know that you only want to temporarily run with some SAEs turned on for one forward pass, you can use `model.run_with_saes(tokens, act_names=act_names)`. You should pass in a list of strings for the activations corresponding to the SAEs that you want to turn on for this forward pass. Note if you already have some SAEs turned on, this will turn them off. It will also turn off all attached SAEs after running, creating a clean slate.
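
The clean-slate behaviour described above can be modelled with a few lines of bookkeeping. The class below is a hypothetical toy: `ToySAESwitchboard` and its method names are invented for illustration and are not part of SAE Lens or TransformerLens. It only tracks which attached SAEs are on and runs no model.

```python
from typing import Dict, Iterable, List, Optional


class ToySAESwitchboard:
    """Hypothetical stand-in that only tracks which attached SAEs are on."""

    def __init__(self) -> None:
        self.status: Dict[str, bool] = {}

    def attach(self, act_name: str, turn_on: bool = True) -> None:
        self.status[act_name] = turn_on

    def _set(self, act_names: Optional[Iterable[str]], value: bool) -> None:
        # None means "apply to every attached SAE"
        names = list(self.status) if act_names is None else list(act_names)
        for name in names:
            self.status[name] = value

    def turn_on(self, act_names: Optional[Iterable[str]] = None) -> None:
        self._set(act_names, True)

    def turn_off(self, act_names: Optional[Iterable[str]] = None) -> None:
        self._set(act_names, False)

    def run_with(self, act_names: Iterable[str]) -> List[str]:
        # Turn everything off, enable only the requested SAEs, "run", then reset
        self.turn_off()
        self.turn_on(act_names)
        active = sorted(name for name, on in self.status.items() if on)
        self.turn_off()
        return active


board = ToySAESwitchboard()
board.attach("blocks.5.hook_resid_pre")                 # on by default
board.attach("blocks.6.hook_resid_pre", turn_on=False)  # attached but off
active = board.run_with(["blocks.6.hook_resid_pre"])
print(active)        # ['blocks.6.hook_resid_pre']
print(board.status)  # every SAE is off again after the run
```

Note that the L5 SAE, which was on beforehand, is off during the run and stays off afterwards, matching the description above.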

To demonstrate, let's use `run_with_saes` to evaluate many combinations of SAEs on different cross sections of the IOI circuit.

In [14]:
all_layers = [[i] for i in range(12)] # load each SAE individually
# all_layers = [[0, 1], [2, 4], [5], [5,6], [7, 8], [9, 10, 11]] # or load combinations of SAEs
x_axis = ['Clean Baseline']
per_prompt_logit_diffs = [
    original_per_prompt_logit_diff,
]

for layers in all_layers:
    act_names = [utils.get_act_name('resid_pre', layer) for layer in layers]
    logits_with_saes = model.run_with_saes(tokens, act_names=act_names)
    average_logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)
    per_prompt_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)

    x_axis.append(f"With SAEs L{layers}")
    per_prompt_logit_diffs.append(per_prompt_diff_with_saes)

show_avg_logit_diffs(x_axis, per_prompt_logit_diffs)

We generally see that running with residual stream SAEs preserves most of the logit difference, but some SAEs (such as L5 or L10) do noticeably worse.

# Run with Cache

We often want to see which features are active on a given prompt. With HookedSAETransformer, activations from the HookedSAEs that you attach will automatically be cached with `model.run_with_cache(tokens)`. The cache key is generally the HookedTransformer hook name (e.g. `blocks.5.hook_resid_pre`) followed by a period and the HookedSAE hook name (e.g. `hook_sae_acts_post`).

`run_with_cache` makes it easy to explore which SAE features are active on any input. Let's look at some of the top features at the S2 position for our L5 SAE across all IOI examples:

In [15]:
layers = [5, 6]
act_names = [utils.get_act_name('resid_pre', layer) for layer in layers]
model.turn_saes_on(act_names)

layer = 5
_, cache = model.run_with_cache(tokens)
s2_pos = 10
sae_acts = cache[utils.get_act_name('resid_pre', layer) + ".hook_sae_acts_post"][:, s2_pos, :]

live_feature_mask = sae_acts > 0

live_feature_union = live_feature_mask.any(dim=0)

px.imshow(
    sae_acts[:, live_feature_union].detach().cpu(),
    title = f"Activations of Live SAE features at L{layer} S2 position per prompt",
    # xaxis="Feature Id", yaxis="Prompt",
    x=list(map(str, live_feature_union.nonzero().flatten().tolist())),
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
)

These results would be more interesting if we were using attn z SAEs, as in the original version of this tutorial.

In [16]:
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
# Top 8 features by total activation across prompts at the S2 position
vals, inds = torch.topk(sae_acts.detach().cpu().sum(dim=0), 8)

get_neuronpedia_quick_list(
    inds.tolist(), 
    model = "gpt2-small",
    dataset="res-jb",
    layer = layer,
    name = "IOI S2 Names")



'https://neuronpedia.org/quick-list/?name=IOI%20S2%20Names&features=%5B%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-res-jb%22%2C%20%22index%22%3A%20%2217855%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-res-jb%22%2C%20%22index%22%3A%20%22722%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-res-jb%22%2C%20%22index%22%3A%20%223153%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-res-jb%22%2C%20%22index%22%3A%20%223977%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-res-jb%22%2C%20%22index%22%3A%20%2212738%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-res-jb%22%2C%20%22index%22%3A%20%2210955%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-res-jb%22%2C%20%22index%22%3A%20%224000%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-res-jb%22%2C%20%22index%22%3A%20%2211300%2

## Run with cache with SAEs

Similar to `run_with_saes`, we can use `model.run_with_cache_with_saes(tokens, act_names=act_names)` to run with cache while only the SAEs for the specified activation names are turned on. Like `run_with_saes`, this creates a clean slate by turning off all attached SAEs both before and after running.

In [17]:
layer, s2_pos = 6, 10
act_names = [utils.get_act_name('resid_pre', layer)]
_, cache = model.run_with_cache_with_saes(tokens, act_names=act_names)
sae_acts = cache[utils.get_act_name('resid_pre', layer) + ".hook_sae_acts_post"][:, s2_pos, :]
live_feature_mask = sae_acts > 0
live_feature_union = live_feature_mask.any(dim=0)

px.imshow(
    sae_acts[:, live_feature_union].detach().cpu(),
    title = f"Activations of Live SAE features at L{layer} S2 position per prompt",
    # xaxis="Feature Id", yaxis="Prompt",
    x=list(map(str, live_feature_union.nonzero().flatten().tolist())),
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
)
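
The live-feature selection used in the two cells above is easier to see on a toy tensor. The values below are made up; only the masking logic is the same.

```python
import torch

# Toy SAE activations at one position: [prompt, feature]
sae_acts = torch.tensor([
    [0.0, 1.5, 0.0, 0.0],
    [0.0, 0.0, 0.0, 2.0],
    [0.0, 0.7, 0.0, 0.0],
])

# A feature is "live" on a prompt if it fires at all
live_feature_mask = sae_acts > 0
# Keep features that are live on at least one prompt
live_feature_union = live_feature_mask.any(dim=0)

live_ids = live_feature_union.nonzero().flatten().tolist()
print(live_ids)  # [1, 3]
print(tuple(sae_acts[:, live_feature_union].shape))  # (3, 2)
```

Restricting the heatmap to this union keeps the plot readable, since most SAE features are zero on any given prompt.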

# Run with Hooks

Often we want to intervene on activations such as SAE features. For this we use `run_with_hooks`, which works exactly like HookedTransformer's `run_with_hooks`, with the added benefit that we can now intervene on SAE activations from the HookedSAEs. Let's use this to ablate SAE features and see which are the most causally relevant.
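
Before running this on the model, it may help to see what a feature-ablation hook does to a tensor in isolation. This sketch calls the hook function directly on toy activations; passing `hook=None` is a placeholder of ours, since in a real run the library supplies the hook point.

```python
from functools import partial

import torch

def ablate_feature(sae_acts, hook, pos, feature_id):
    # Zero one feature, either at every position or at a single position
    if pos is None:
        sae_acts[:, :, feature_id] = 0.0
    else:
        sae_acts[:, pos, feature_id] = 0.0
    return sae_acts

acts = torch.ones(2, 3, 4)  # toy activations: [batch, pos, feature]

# partial() fixes the extra arguments so the result has the (activation, hook) signature
hook_fn = partial(ablate_feature, pos=1, feature_id=2)
out = hook_fn(acts.clone(), hook=None)

print(out[0])
```

Only feature 2 at position 1 is zeroed in each batch element; everything else is left untouched.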

In [18]:
LAYER = 6
model.turn_saes_off()
model.get_saes_status()
model.turn_saes_on(
    [f'blocks.{i}.hook_resid_pre' for i in [LAYER]]
)
_, cache = model.run_with_cache(tokens)
model.get_saes_status()

logits_with_saes = model(tokens)
clean_sae_baseline_avg = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)
clean_sae_baseline_per_prompt = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)

In [19]:
from tqdm import tqdm

def ablate_sae_feature(sae_acts, hook, pos, feature_id):
    if pos is None:
        sae_acts[:, :, feature_id] = 0.
    else:
        sae_acts[:, pos, feature_id] = 0.
    return sae_acts

hooked_encoder = model.acts_to_saes[utils.get_act_name('resid_pre', LAYER)]
all_live_features = torch.arange(hooked_encoder.cfg.d_sae)[live_feature_union.cpu()]

In [20]:
causal_effects = torch.zeros((len(prompts), all_live_features.shape[0]))
fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}


abl_layer, abl_pos  = LAYER, 10
for feature_id in tqdm(all_live_features):
    feature_id = feature_id.item()
    abl_feature_logits = model.run_with_hooks(
        tokens,
        return_type="logits",
        fwd_hooks=[
            (utils.get_act_name('resid_pre', abl_layer) + ".hook_sae_acts_post", 
             partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id)
             )
        ]
    ) # [batch, seq, vocab]
    
    abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]
    del abl_feature_logits
    if device == "mps":
        torch.mps.empty_cache()
    causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - clean_sae_baseline_per_prompt

  3%|▎         | 4/151 [00:00<00:08, 17.26it/s]

100%|██████████| 151/151 [00:08<00:00, 17.61it/s]


In [21]:
fig = px.imshow(
    causal_effects.detach().cpu(),
    title=f"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}",
    x=list(map(str, all_live_features.tolist())),
    color_continuous_midpoint=0,
    color_continuous_scale="RdBu",
)

# label x and y axis 
fig.update_xaxes(title_text='Feature Id')
fig.update_yaxes(title_text='Prompt')
fig.show()

# NOTE: Path Patching Section removed for now. 

# Error Nodes

Recent exciting work from [Marks et al.](https://arxiv.org/abs/2403.19647v2) demonstrated the use of "error nodes" in SAE circuit analysis. The idea is that for some input activation x, SAE(x) = x_reconstruct is an approximation of x, but we can write x = x_reconstruct + error_term.

This seems useful: instead of replacing x with x_reconstruct, which might break everything and make our circuit analysis janky, we can just re-write x as a function of the SAE features, bias, and error term, which gives us access to all of the SAE features but without breaking performance.

Additionally, we can compare interventions on SAE features to the same intervention on the error term to get a better sense of how much the SAE features have actually captured.

To use error terms with HookedSAEs, you can set `hooked_sae.cfg.use_error_term = True`, or initialize it to True in the config. Note that HookedSAEConfig currently sets this to False by default.

In [22]:
model.turn_saes_off()
layers = [6]
act_names = [utils.get_act_name('resid_pre', layer) for layer in layers]
for act_name in act_names:
    print(model.acts_to_saes[act_name].use_error_term)
    model.acts_to_saes[act_name].use_error_term = True
    print(model.acts_to_saes[act_name].use_error_term)

False
True


Now the output of each attached SAE will be SAE(x) + error_term = x. We can sanity check this by confirming that running with SAEs produces the same logits as running without SAEs.
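
The identity x = x_reconstruct + error_term can be checked on a toy SAE. This is a sketch with random weights and no biases, assuming the error term is simply the residual between input and reconstruction; the library's exact computation may differ.

```python
import torch

torch.manual_seed(0)
d_in, d_sae = 4, 8
W_enc = torch.randn(d_in, d_sae)
W_dec = torch.randn(d_sae, d_in)

x = torch.randn(2, 3, d_in)  # [batch, pos, d_model]
feature_acts = torch.relu(x @ W_enc)
x_reconstruct = feature_acts @ W_dec

# The error term is whatever the reconstruction misses
sae_error = x - x_reconstruct
sae_out = x_reconstruct + sae_error
print(torch.allclose(sae_out, x, atol=1e-4))

# Ablating feature 0 while keeping the error term fixed changes the
# output by exactly that feature's contribution, and nothing else
feature_acts_abl = feature_acts.clone()
feature_acts_abl[..., 0] = 0.0
sae_out_abl = feature_acts_abl @ W_dec + sae_error
removed = x - sae_out_abl
print(torch.allclose(removed, feature_acts[..., :1] * W_dec[0], atol=1e-3))
```

This is why interventions with the error term enabled are measured against the original model rather than against the SAE-reconstructed baseline.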

In [23]:
logits_with_saes, cache = model.run_with_cache_with_saes(tokens, act_names=act_names)
logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)
assert torch.allclose(logits_with_saes, original_logits, atol=1e-4), "Logits should be the same as the original"
assert 'blocks.6.hook_resid_pre.hook_sae_acts_post' in cache.keys(), "The hook should be in the cache since we ran with the SAE"

Now we can compare ablations of each feature to ablating the error node. We'll start by ablating each feature on each prompt, and then the error nodes. We'll append the effects from ablating error nodes to the rightmost column on the heatmap:

In [24]:
model.get_saes_status()

{'blocks.7.hook_resid_pre': False,
 'blocks.0.hook_resid_pre': False,
 'blocks.1.hook_resid_pre': False,
 'blocks.2.hook_resid_pre': False,
 'blocks.3.hook_resid_pre': False,
 'blocks.4.hook_resid_pre': False,
 'blocks.5.hook_resid_pre': False,
 'blocks.6.hook_resid_pre': False,
 'blocks.8.hook_resid_pre': False,
 'blocks.9.hook_resid_pre': False,
 'blocks.10.hook_resid_pre': False,
 'blocks.11.hook_resid_pre': False,
 'blocks.11.hook_resid_post': False}

In [25]:
def ablate_sae_feature(sae_acts, hook, pos, feature_id):
    if pos is None:
        sae_acts[:, :, feature_id] = 0.
    else:
        sae_acts[:, pos, feature_id] = 0.
    return sae_acts

layer = 6
act_name = utils.get_act_name('resid_pre', layer)
model.acts_to_saes[act_name].use_error_term = True
hooked_encoder = model.acts_to_saes[act_name]
all_live_features = torch.arange(hooked_encoder.cfg.d_sae)[live_feature_union.cpu()]
model.turn_saes_on([act_name])

causal_effects = torch.zeros((len(prompts), all_live_features.shape[0]))
fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}

abl_layer, abl_pos  = 6, 10
for feature_id in tqdm(all_live_features):
    feature_id = feature_id.item()
    abl_feature_logits = model.run_with_hooks(
        tokens,
        return_type="logits",
        fwd_hooks=[(utils.get_act_name('resid_pre', abl_layer) + ".hook_sae_acts_post", partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id))]
    ) # [batch, seq, vocab]

    abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]
    causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - original_per_prompt_logit_diff

def ablate_sae_error(sae_error, hook, pos):
    # Zero the error term in place, either everywhere or at a single position
    if pos is None:
        sae_error[:] = 0.
    else:
        sae_error[:, pos, ...] = 0.
    return sae_error


abl_error_logits = model.run_with_hooks(
    tokens,
    return_type="logits",
    fwd_hooks=[(utils.get_act_name('resid_pre', abl_layer) + ".hook_sae_error", partial(ablate_sae_error, pos=abl_pos))]
) # [batch, seq, vocab]

abl_error_logit_diff = logits_to_ave_logit_diff(abl_error_logits, answer_tokens, per_prompt=True) # [batch]
error_abl_effect = abl_error_logit_diff - original_per_prompt_logit_diff


causal_effects_with_error = torch.cat([causal_effects, error_abl_effect.unsqueeze(-1).cpu()], dim=-1)

  0%|          | 0/151 [00:00<?, ?it/s]

100%|██████████| 151/151 [00:08<00:00, 18.55it/s]


In [26]:
px.imshow(
    causal_effects_with_error,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    title=f"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}",
    x=list(map(str, all_live_features.tolist()))+["error"],
    labels= {"x":"Feature Idx", "y":"Prompt Idx"}
)

We can see that on some prompts, ablating the error term (rightmost column) does have a non-trivial effect on the logit diff, although I don't see a clear pattern. It seems useful to include this term when doing causal interventions to get a sense of how much the SAE features are actually explaining.

# Attribution patching


Both [Anthropic](https://transformer-circuits.pub/2024/march-update/index.html#feature-heads) and [Marks et al.](https://arxiv.org/abs/2403.19647v2) also demonstrated the use of gradient-based attribution techniques as a substitute for activation patching on SAE features. The key idea is that patching / ablation (as we did above) can be slow, as it requires a new forward pass for each patch. This seems especially problematic when dealing with SAEs with tens of thousands of features per activation.

They find that gradient-based attribution techniques like [attribution patching](https://www.neelnanda.io/mechanistic-interpretability/attribution-patching) are good approximations, allowing for more efficient and scalable circuit analysis with SAEs.
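
To see why a single backward pass can stand in for many ablations, here is a small self-contained sketch on toy functions (no model involved; the helper name and the two metrics are ours). The attribution for zero-ablating activation `i` is `grad_i * (0 - act_i)`, a first-order approximation that is exact when the metric is linear in the activations.

```python
import torch

def attribution_and_ablation(metric, acts):
    # One forward and one backward pass give every attribution at once
    acts = acts.clone().requires_grad_(True)
    value = metric(acts)
    value.backward()
    attribution = (acts.grad * (0 - acts)).detach()

    # Ground truth: one extra forward pass per ablated activation
    true_effect = torch.zeros_like(attribution)
    with torch.no_grad():
        for i in range(acts.shape[0]):
            ablated = acts.clone()
            ablated[i] = 0.0
            true_effect[i] = metric(ablated) - value
    return attribution, true_effect

torch.manual_seed(0)
acts = torch.rand(5) + 0.1  # toy "feature activations"
w = torch.randn(5)

# Linear metric: attribution matches ablation up to float error
lin_attr, lin_true = attribution_and_ablation(lambda a: (a * w).sum(), acts)
print(torch.allclose(lin_attr, lin_true, atol=1e-5))

# Nonlinear metric: attribution is only an approximation
nl_attr, nl_true = attribution_and_ablation(lambda a: torch.tanh((a * w).sum()), acts)
print(nl_attr)
print(nl_true)
```

The loop over activations is the cost that attribution patching avoids: it grows with the number of features, while the gradient-based estimate does not.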

Let's use `HookedSAETransformer` to implement attribution patching for every SAE feature in L6 to find causally relevant SAE features in just one forward and one backward pass.

In [31]:
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x38e435dd0>

In [46]:
act_name = utils.get_act_name('resid_pre', layer)
model.turn_saes_on(act_name)
model.acts_to_saes[act_name].use_error_term = False
model.get_saes_status()

{'blocks.7.hook_resid_pre': False,
 'blocks.0.hook_resid_pre': False,
 'blocks.1.hook_resid_pre': False,
 'blocks.2.hook_resid_pre': False,
 'blocks.3.hook_resid_pre': False,
 'blocks.4.hook_resid_pre': False,
 'blocks.5.hook_resid_pre': False,
 'blocks.6.hook_resid_pre': True,
 'blocks.8.hook_resid_pre': False,
 'blocks.9.hook_resid_pre': False,
 'blocks.10.hook_resid_pre': False,
 'blocks.11.hook_resid_pre': False,
 'blocks.11.hook_resid_post': False}

In [49]:
from transformer_lens import ActivationCache
filter_sae_acts = lambda name: ("hook_sae_acts_post" in name)
def get_cache_fwd_and_bwd(model, tokens, metric):
    model.reset_hooks()
    cache = {}
    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()
    model.add_hook(filter_sae_acts, forward_cache_hook, "fwd")

    grad_cache = {}
    def backward_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()
    model.add_hook(filter_sae_acts, backward_cache_hook, "bwd")

    value = metric(model(tokens))
    print(value)
    value.backward()
    model.reset_hooks()
    return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)


BASELINE = original_per_prompt_logit_diff
def ioi_metric(logits, answer_tokens=answer_tokens):
    return (logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=True) - BASELINE).sum()

clean_tokens = tokens.clone()
clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)
print("Clean Value:", clean_value)
print("Clean Activations Cached:", len(clean_cache)) # This should be 1, since only one SAE is turned on
print("Clean Gradients Cached:", len(clean_grad_cache)) # This should be 1, since only one SAE is turned on

tensor(-9.2088, device='mps:0', grad_fn=<SumBackward0>)
Clean Value: -9.208788871765137
Clean Activations Cached: 1
Clean Gradients Cached: 1


In [50]:
def attr_patch_sae_acts(
        clean_cache: ActivationCache,
        clean_grad_cache: ActivationCache,
        site: str, layer: int
    ):
    clean_sae_acts_post = clean_cache[utils.get_act_name(site, layer) + ".hook_sae_acts_post"]
    clean_grad_sae_acts_post = clean_grad_cache[utils.get_act_name(site, layer) + ".hook_sae_acts_post"]
    sae_act_attr = clean_grad_sae_acts_post * (0 - clean_sae_acts_post)
    return sae_act_attr

site = "resid_pre"
layer = 6
sae_act_attr = attr_patch_sae_acts(clean_cache, clean_grad_cache, site, layer)

In [54]:
px.imshow(
    sae_act_attr[:, s2_pos, all_live_features].cpu(),
    color_continuous_midpoint=0,
    color_continuous_scale='RdBu',
    title="attribution patching",
    labels = {"x":"Feature Idx", "y":"Prompt Idx"},
    x=list(map(str, all_live_features.tolist()))
)

In [57]:
fig = px.scatter(
    y=sae_act_attr[:, s2_pos, all_live_features].flatten().cpu(),
    x=causal_effects.flatten().cpu(),
    title=f"Attribution vs Activation Patching Per SAE feature (L{layer} S2 Pos, all prompts)",
    labels = {"x":"Activation Patch", "y":"Attribution Patch"}
)
fig.add_shape(
    type='line',
    x0=causal_effects.min().item(),
    y0=causal_effects.min().item(),
    x1=causal_effects.max().item(),
    y1=causal_effects.max().item(),
    line=dict(
        color='gray',
        width=1,
        dash='dot'
    )
)
fig.show()