<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Hooked_SAE_Transformer_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# HookedSAETransformer Demo

HookedSAETransformer is a lightweight extension of HookedTransformer that allows you to "splice in" Sparse Autoencoders. This makes it easy to do exploratory analysis such as: running inference with SAEs attached, caching SAE feature activations, and intervening on SAE activations with hooks.

I (Connor Kissane) implemented this to accelerate research on [Attention SAEs](https://www.lesswrong.com/posts/DtdzGwFh9dCfsekZZ/sparse-autoencoders-work-on-attention-layer-outputs) based on suggestions from Arthur Conmy and Neel Nanda, and found that it was well worth the time and effort. I hope other researchers will also find the library useful! This notebook demonstrates how it works and how to use it.

# Setup

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/ckkissane/TransformerLens@hooked-sae-transformer
  
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited, without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [None]:
import torch
import transformer_lens.utils as utils

import plotly.express as px
import tqdm
from functools import partial
import einops
import plotly.graph_objects as go

update_layout_set = {
    "xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis",
     "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid",
     "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"
}

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, return_fig=False, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    fig = px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs)
    if return_fig:
        return fig
    fig.show(renderer)

from typing import List
def show_avg_logit_diffs(x_axis: List[str], per_prompt_logit_diffs: List[torch.Tensor]):
    # Bar heights are per-intervention means; error bars are standard deviations across prompts
    y_data = [per_prompt_logit_diff.mean().item() for per_prompt_logit_diff in per_prompt_logit_diffs]
    error_y_data = [per_prompt_logit_diff.std().item() for per_prompt_logit_diff in per_prompt_logit_diffs]

    fig = go.Figure(data=[go.Bar(
        x=x_axis,
        y=y_data,
        error_y=dict(
            type='data',  # specifies that the actual values are given
            array=error_y_data,  # the magnitudes of the errors
            visible=True  # make error bars visible
        ),
    )])

    # Customize layout
    fig.update_layout(title_text='Logit Diff after Interventions',
                    xaxis_title_text='Intervention',
                    yaxis_title_text='Logit diff',
                    plot_bgcolor='white')

    # Show the figure
    fig.show()

In [None]:
if torch.cuda.is_available():
    device = "cuda"
# elif torch.backends.mps.is_available():
#     device = "mps"
else: 
    device = "cpu"
torch.set_grad_enabled(False)

# Loading and Running Models

Just like a [HookedTransformer](https://TransformerLensOrg.github.io/TransformerLens/generated/demos/Main_Demo.html#Loading-and-Running-Models), we can load in any model that's supported in TransformerLens with `HookedSAETransformer.from_pretrained(MODEL_NAME)`. In this demo we'll use GPT-2 small.

In [None]:
from sae_lens import HookedSAETransformer
model: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2-small").to(device)

By default, HookedSAETransformer behaves exactly like a HookedTransformer. We'll explore the main features of HookedSAETransformer on the classic IOI task, so let's first sanity check that GPT-2 small can do the IOI task without any SAEs attached:

In [None]:
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
names = [
    (" John", " Mary",),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]
# List of prompts
prompts = []
# List of answers, in the format (correct, incorrect)
answers = []
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []
for i in range(len(prompt_format)):
    for j in range(2):
        answers.append((names[i][j], names[i][1 - j]))
        answer_tokens.append(
            (
                model.to_single_token(answers[-1][0]),
                model.to_single_token(answers[-1][1]),
            )
        )
        # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.
        prompts.append(prompt_format[i].format(answers[-1][1]))
answer_tokens = torch.tensor(answer_tokens).to(device)
print(prompts)
print(answers)

In [None]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()
    
tokens = model.to_tokens(prompts, prepend_bos=True)
original_logits, cache = model.run_with_cache(tokens)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print(f"Original average logit diff: {original_average_logit_diff}")
original_per_prompt_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
print(f"Original per prompt logit diff: {original_per_prompt_logit_diff}")

# HookedSAEs

In order to use the key features of HookedSAETransformer, we first need to load in SAEs.

HookedSAE is an SAE class we've implemented to have TransformerLens hooks around the SAE activations. While we will use it out of the box, it is designed to be hackable: you can copy and paste the HookedSAE class into a notebook and completely change the architecture / hook names, and as long as it reconstructs the activations, it should still work.

You can initialize a HookedSAE with a HookedSAEConfig:
```
cfg = HookedSAEConfig(
    d_sae=...,      # int: the size of the dictionary
    d_in=...,       # int: the dimension of the input activations for the SAE
    hook_name=...,  # str: the hook name of the activation the SAE was trained on (eg. "blocks.0.attn.hook_z")
)
hooked_sae = HookedSAE(cfg)
```

Note you'll likely have to write some basic conversion code to match configs / state dicts to the HookedSAE when loading in an open sourced SAE (eg from HuggingFace). We'll use our GPT-2 Small [Attention SAEs](https://www.alignmentforum.org/posts/FSTRedtjuHa4Gfdbr/attention-saes-scale-to-gpt-2-small) to demonstrate. For convenience, we'll load in all of our attention SAEs from HuggingFace, convert them to HookedSAEs, and store them in a dictionary that maps each hook_name (str) to the corresponding HookedSAE.

<details>

Later we'll show how to add HookedSAEs to the HookedSAETransformer (replacing model activations with their SAE reconstructions). When you add a HookedSAE, HookedSAETransformer just treats it as a black box that takes some activation as an input, and outputs a tensor of the same shape.

With this in mind, the HookedSAE is designed to be simple and hackable. Think of it as a convenient default class that you can copy and edit. As long as it takes a TransformerLens activation as input, and outputs a tensor of the same shape, you should be able to add it to your HookedSAETransformer.

You probably don't even need to use the HookedSAE class, although it's recommended. The SAE can be any PyTorch module that takes in some activation at hook_name and outputs a tensor of the same shape. The two assumptions that HookedSAETransformer makes when adding SAEs are:
1. The SAE class has a cfg attribute, sae.cfg.hook_name (str), for the activation that the SAE was trained to reconstruct (in TransformerLens notation e.g. 'blocks.0.attn.hook_z')
2. The SAE takes that activation as input, and outputs a tensor of the same shape.

The main benefit of HookedSAE is that it's a subclass of HookedRootModule, so we can add hooks to SAE activations. This makes it easy to leverage existing TL functionality like run_with_cache and run_with_hooks with SAEs.

</details>

In [None]:
from sae_lens import SAE

hook_name_to_sae = {}
for layer in tqdm.tqdm(range(12)):
    sae, cfg_dict, _ = SAE.from_pretrained(
        "gpt2-small-hook-z-kk",
        f"blocks.{layer}.hook_z",
        device=device,
    )
    hook_name_to_sae[sae.cfg.hook_name] = sae
    

print(hook_name_to_sae.keys())

# Run with SAEs

The key feature of HookedSAETransformer is being able to "splice in" SAEs, replacing model activations with their SAE reconstructions. 

To run a forward pass with SAEs attached use `model.run_with_saes(tokens, saes=saes)`, where saes is a list of HookedSAEs that you want to add for just this forward pass. These will be reset immediately after the forward pass, returning the model to its original state.

I expect this to be particularly useful for evaluating SAEs (eg [Gurnee](https://www.alignmentforum.org/posts/rZPiuFxESMxCDHe4B/sae-reconstruction-errors-are-empirically-pathological)), including evaluating how SAE reconstructions affect the model's ability to perform certain tasks (eg [Makelov et al.](https://openreview.net/forum?id=MHIX9H8aYF&referrer=%5Bthe%20profile%20of%20Neel%20Nanda%5D(%2Fprofile%3Fid%3D~Neel_Nanda1))).

To demonstrate, let's use `run_with_saes` to evaluate many combinations of SAEs on different cross sections of the IOI circuit.

<details>

Under the hood, TransformerLens already wraps activations with a HookPoint object. HookPoint is a dummy PyTorch module that acts as an identity function by default, and is only used to access the activation with PyTorch hooks. When you call `run_with_saes`, HookedSAETransformer temporarily replaces these HookPoints with the given HookedSAEs, which take the activation as input and replace it with the HookedSAE output (the reconstructed activation) during the forward pass.

Since HookedSAE is a subclass of HookedRootModule, we also are able to add PyTorch hooks to the corresponding SAE activations, as we'll use later.

</details>
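As a rough illustration of the temporary splicing described above, here is a toy stand-in that uses no TransformerLens code at all. Everything in it (`ToyModel`, `spliced`, `toy_sae`) is hypothetical and only mimics the idea: a hook point is an identity function by default, gets swapped for a lossy "SAE" for one forward pass, and is then restored.

```python
from contextlib import contextmanager


class ToyModel:
    """Hypothetical stand-in for a hooked model with one named hook point."""

    def __init__(self):
        # Identity by default, mirroring how a HookPoint passes activations through.
        self.hook_points = {"blocks.0.attn.hook_z": lambda act: act}

    def forward(self, x):
        act = [2.0 * v for v in x]  # pretend this is an attention output
        act = self.hook_points["blocks.0.attn.hook_z"](act)
        return sum(act)


@contextmanager
def spliced(model, hook_name, replacement):
    """Temporarily replace a hook point, restoring it even if the forward pass fails."""
    original = model.hook_points[hook_name]
    model.hook_points[hook_name] = replacement
    try:
        yield model
    finally:
        model.hook_points[hook_name] = original


def toy_sae(act):
    # A deliberately lossy "reconstruction": same shape out as in.
    return [round(v) for v in act]


model = ToyModel()
clean = model.forward([0.3, 0.6])
with spliced(model, "blocks.0.attn.hook_z", toy_sae):
    with_sae = model.forward([0.3, 0.6])
restored = model.forward([0.3, 0.6])

print(clean, with_sae, restored)
```

The output with the toy SAE spliced in differs from the clean output, and the model goes back to its original behaviour once the `with` block exits, which is the property that makes this kind of splicing safe for one-off evaluations.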

In [None]:
all_layers = [[0, 3], [2, 4], [5,6], [7, 8], [9, 10, 11]]
x_axis = ['Clean Baseline']
per_prompt_logit_diffs = [
    original_per_prompt_logit_diff, 
]

for layers in all_layers:
    hooked_saes = [hook_name_to_sae[utils.get_act_name('z', layer)] for layer in layers]
    logits_with_saes = model.run_with_saes(tokens, saes=hooked_saes)
    average_logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)
    per_prompt_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)
    
    x_axis.append(f"With SAEs L{layers}")
    per_prompt_logit_diffs.append(per_prompt_diff_with_saes)

show_avg_logit_diffs(x_axis, per_prompt_logit_diffs)

## Run with cache (with SAEs)

We often want to see what SAE features are active on a given prompt. With HookedSAETransformer, you can cache HookedSAE activations (and all the other standard activations) with `logits, cache = model.run_with_cache_with_saes(tokens, saes=saes)`. Just as `run_with_saes` is a wrapper around the standard forward pass, `run_with_cache_with_saes` is a wrapper around `run_with_cache`, and will also only add these SAEs for one forward pass before returning the model to its original state.

To access SAE activations from the cache, the corresponding hook names will generally be the HookedTransformer hook_name (eg blocks.5.attn.hook_z) + the HookedSAE hook name preceded by a period (eg .hook_sae_acts_post).
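To make the naming rule concrete, here is a small sketch of how such a cache key is assembled. The helper `sae_cache_key` is hypothetical (it is not part of TransformerLens), and it hard-codes the attention `hook_z` site used in this demo rather than calling `utils.get_act_name`.

```python
def sae_cache_key(layer: int, sae_hook: str = "hook_sae_acts_post") -> str:
    """Hypothetical helper: join the model hook name and the SAE hook name with a period."""
    base = f"blocks.{layer}.attn.hook_z"  # the activation the SAE was trained on
    return f"{base}.{sae_hook}"


key = sae_cache_key(5)
print(key)
```

In the notebook itself the same key is built as `utils.get_act_name('z', layer) + ".hook_sae_acts_post"`.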

`run_with_cache_with_saes` makes it easy to explore which SAE features are active across any input. Let's explore the active features at the S2 position for our L5 Attention SAE across all of our IOI prompts:

In [None]:
layer, s2_pos = 5, 10
saes = [hook_name_to_sae[utils.get_act_name('z', layer)]]
_, cache = model.run_with_cache_with_saes(tokens, saes=saes)
sae_acts = cache[utils.get_act_name('z', layer) + ".hook_sae_acts_post"][:, s2_pos, :]
live_feature_mask = sae_acts > 0
live_feature_union = live_feature_mask.any(dim=0)

imshow(
    sae_acts[:, live_feature_union],
    title = "Activations of Live SAE features at L5 S2 position per prompt",
    xaxis="Feature Id", yaxis="Prompt",
    x=list(map(str, live_feature_union.nonzero().flatten().tolist())),
)

We could then interpret some of the commonly activating features, like 7515, using [neuronpedia](https://www.neuronpedia.org/gpt2-small/5-att-kk/7515).

## Run with Hooks (with SAEs)

HookedSAETransformer also allows you to intervene on SAE activations with `model.run_with_hooks_with_saes(tokens, saes=saes, fwd_hooks=fwd_hooks)`. This works exactly like the standard TransformerLens `run_with_hooks`, with the added benefit that we can now intervene on SAE activations from the HookedSAEs that we splice in. Along the same lines as `run_with_saes` and `run_with_cache_with_saes`, this will only temporarily add SAEs before returning the model to its original state.

I expect this to be useful when doing circuit analysis with SAEs. To demonstrate, let's zero ablate individual layer 5 attention SAE features to localize causally important features.

In [None]:
def ablate_sae_feature(sae_acts, hook, pos, feature_id):
    if pos is None:
        sae_acts[:, :, feature_id] = 0.
    else:
        sae_acts[:, pos, feature_id] = 0.
    return sae_acts

layer = 5
sae = hook_name_to_sae[utils.get_act_name('z', layer)]

logits_with_saes = model.run_with_saes(tokens, saes=sae)
clean_sae_baseline_per_prompt = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)

all_live_features = torch.arange(sae.cfg.d_sae)[live_feature_union.cpu()]

causal_effects = torch.zeros((len(prompts), all_live_features.shape[0]))
fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}


abl_layer, abl_pos  = 5, 10
for feature_id in tqdm.tqdm(all_live_features):
    feature_id = feature_id.item()
    abl_feature_logits = model.run_with_hooks_with_saes(
        tokens,
        saes=sae,
        fwd_hooks=[(utils.get_act_name('z', abl_layer) + ".hook_sae_acts_post", partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id))]
    ) # [batch, seq, vocab]
    
    abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]
    causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - clean_sae_baseline_per_prompt


imshow(causal_effects, title=f"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}", xaxis="Feature Idx", yaxis="Prompt Idx", x=list(map(str, all_live_features.tolist())))

Although it's not super clean, we see a few features stand out, where ablating them causes a nontrivial drop in logit diff on multiple prompts: 7515 and 27535 for BABA prompts, with 44256 for ABBA prompts.

# Add SAEs

While the `run_with_saes` family of methods are great for evaluating SAEs and exploratory analysis, you may want to permanently attach SAEs to your model. You can attach SAEs to any activation with `model.add_sae(sae)`, where sae is a HookedSAE. 

When you add an SAE, it gets stored in `model.acts_to_saes`, a dictionary that maps the activation name to the HookedSAE that is attached. The main benefit of permanently adding SAEs is that we can now just run the model like a normal HookedTransformer (with `forward`, `run_with_cache`, `run_with_hooks`), but some activations will be replaced with the reconstructed activations from the corresponding SAEs.

I expect this to be most useful when you've already identified a good set of SAEs that you want to use for interpretability, and don't feel like passing in a massive list of saes for every forward pass.

In [None]:
print("Attached SAEs before add_sae", model.acts_to_saes)
layer = 5
sae = hook_name_to_sae[utils.get_act_name('z', layer)]
model.add_sae(sae)
print("Attached SAEs after add_sae", model.acts_to_saes)

Now we can just call the standard HookedTransformer forward, and the sae that we added will automatically be spliced in.

In [None]:
logits_with_saes = model(tokens)
assert not torch.allclose(original_logits, logits_with_saes, atol=1e-4)

average_logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)
print(f"Average logit diff with SAEs: {average_logit_diff_with_saes}")
per_prompt_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)

## Run with cache

Similarly, we can also use `logits, cache = model.run_with_cache(tokens)` directly to cache SAE activations:

In [None]:
layer = 5
_, cache = model.run_with_cache(tokens)
s2_pos = 10
sae_acts = cache[utils.get_act_name('z', layer) + ".hook_sae_acts_post"][:, s2_pos, :]

live_feature_mask = sae_acts > 0
live_feature_union = live_feature_mask.any(dim=0)

imshow(
    sae_acts[:, live_feature_union],
    title = "Activations of Live SAE features at L5 S2 position per prompt",
    xaxis="Feature Id", yaxis="Prompt",
    x=list(map(str, live_feature_union.nonzero().flatten().tolist())),
)

## Run with hooks

Finally, we can also use `run_with_hooks` and intervene on the added SAE's activations. To show a more complicated intervention, we'll try path patching each feature through the S-inhibition heads' value vectors.

In [None]:
model.set_use_split_qkv_input(True)

In [None]:
def path_patch_v_input(v_input, hook, feature_dirs, pos, head_index):
    v_input[:, pos, head_index, :] = v_input[:, pos, head_index, :] - feature_dirs
    return v_input


s_inhib_heads = [(7, 3), (7, 9), (8,6), (8,10)]

results = torch.zeros(tokens.shape[0], all_live_features.shape[0])

W_O_cat = einops.rearrange(
    model.W_O,
    "n_layers n_heads d_head d_model -> n_layers (n_heads d_head) d_model"
)

for feature_id in tqdm.tqdm(all_live_features):
    feature_id = feature_id.item()
    feature_acts = cache[utils.get_act_name('z', abl_layer) + ".hook_sae_acts_post"][:, abl_pos, feature_id] # [batch]
    feature_dirs = (feature_acts.unsqueeze(-1) * sae.W_dec[feature_id]) @ W_O_cat[abl_layer]
    hook_fns = [
        (utils.get_act_name('v_input', layer), partial(path_patch_v_input, feature_dirs=feature_dirs, pos=abl_pos, head_index=head)) for (layer, head) in s_inhib_heads
    ]
    path_patched_logits = model.run_with_hooks(
        tokens,
        return_type="logits",
        fwd_hooks=hook_fns
    )

    path_patched_logit_diff = logits_to_ave_logit_diff(path_patched_logits, answer_tokens, per_prompt=True)
    results[:, fid_to_idx[feature_id]] = path_patched_logit_diff - clean_sae_baseline_per_prompt

imshow(
    results, 
    title="Change in logit diff when path patching features from S-inhibition heads' values, per prompt",
    xaxis="Feature Id", yaxis="Prompt Idx", x=list(map(str, all_live_features.tolist()))
)

# Reset SAEs

One major footgun is forgetting about an SAE that you previously attached with `add_sae`. Similar to TransformerLens `reset_hooks`, you can always reset SAEs you've added with `model.reset_saes()`. You can also pass in a list of activation names to only reset a subset of attached SAEs.

In [None]:
print("Attached SAEs before reset_saes:", model.acts_to_saes)
model.reset_saes()
print("Attached SAEs after reset_saes:", model.acts_to_saes)

Note that the HookedSAETransformer API is generally designed to closely match the TransformerLens hooks API.

# Error Nodes

Recent exciting work from [Marks et al.](https://arxiv.org/abs/2403.19647v2) demonstrated the use of "error nodes" in SAE circuit analysis. The idea is that for some input activation x, SAE(x) = x_reconstruct is an approximation of x, but we can define an error_term such that x = x_reconstruct + error_term.

This seems useful: instead of replacing x with x_reconstruct, which might break everything and make our circuit analysis janky, we can just re-write x as a function of the SAE features, bias, and error term, which gives us access to all of the SAE features but without breaking performance. 

Additionally, we can compare interventions on SAE features to the same intervention on the error term to get a better sense of how much the SAE features have actually captured.

To use error terms with HookedSAEs, you can set `hooked_sae.cfg.use_error_term = True`, or initialize it to True in the config. Note HookedSAEConfig sets this to False by default.
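The arithmetic behind the error term can be sketched with a tiny hand-built SAE in plain Python. This is a toy under stated assumptions: the weights are made up, there are no biases, and the error term is assumed to be computed once from the unmodified reconstruction and then held fixed while features are intervened on.

```python
def relu(v):
    return max(0.0, v)


def encode(x, W_enc):
    # One activation per feature: ReLU of the dot product with that feature's encoder row.
    return [relu(sum(w * xi for w, xi in zip(row, x))) for row in W_enc]


def decode(acts, W_dec):
    # Weighted sum of decoder rows, one weight per feature activation.
    d_in = len(W_dec[0])
    return [sum(a * row[j] for a, row in zip(acts, W_dec)) for j in range(d_in)]


# Made-up toy weights: 2 features, 3 input dimensions (the third is not captured).
W_enc = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
W_dec = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
x = [1.0, 2.0, -1.0]

acts = encode(x, W_enc)
x_reconstruct = decode(acts, W_dec)
error_term = [xi - ri for xi, ri in zip(x, x_reconstruct)]

# With the error term added back, the SAE output equals the original activation.
output = [r + e for r, e in zip(x_reconstruct, error_term)]

# Ablating feature 1 still changes the output, while the error term stays fixed.
ablated_acts = list(acts)
ablated_acts[1] = 0.0
ablated_output = [r + e for r, e in zip(decode(ablated_acts, W_dec), error_term)]

print(error_term, output, ablated_output)
```

The third input dimension is invisible to the toy SAE, so it lives entirely in the error term; ablating the error term instead of a feature would remove exactly that part.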

In [None]:
import copy
l5_sae = hook_name_to_sae[utils.get_act_name('z', 5)]
l5_sae_with_error = copy.deepcopy(l5_sae)
l5_sae_with_error.use_error_term=True
model.add_sae(l5_sae_with_error)
print("Attached SAEs after adding l5_sae_with_error:", model.acts_to_saes)

Now the output of each attached SAE will be SAE(x) + error_term = x. We can sanity check this by confirming that running with SAEs produces the same logits as running without SAEs.

In [None]:
logits_with_saes = model(tokens)
logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)

assert torch.allclose(logits_with_saes, original_logits, atol=1)

Now we can compare ablations of each feature to ablating the error node. We'll start by ablating each feature on each prompt, and then the error nodes. We'll append the effects from ablating error nodes to the rightmost column on the heatmap:

In [None]:
def ablate_sae_feature(sae_acts, hook, pos, feature_id):
    if pos is None:
        sae_acts[:, :, feature_id] = 0.
    else:
        sae_acts[:, pos, feature_id] = 0.
    return sae_acts

layer = 5
hooked_encoder = model.acts_to_saes[utils.get_act_name('z', layer)]
all_live_features = torch.arange(hooked_encoder.cfg.d_sae)[live_feature_union.cpu()]

causal_effects = torch.zeros((len(prompts), all_live_features.shape[0]))
fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}


abl_layer, abl_pos  = 5, 10
for feature_id in tqdm.tqdm(all_live_features):
    feature_id = feature_id.item()
    abl_feature_logits = model.run_with_hooks(
        tokens,
        return_type="logits",
        fwd_hooks=[(utils.get_act_name('z', abl_layer) + ".hook_sae_acts_post", partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id))]
    ) # [batch, seq, vocab]
    
    abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]
    causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - original_per_prompt_logit_diff

def ablate_sae_error(sae_error, hook, pos):
    # Zero the SAE error term in place, either everywhere or at a single position
    if pos is None:
        sae_error[:] = 0.
    else:
        sae_error[:, pos, ...] = 0.
    return sae_error


abl_error_logits = model.run_with_hooks(
    tokens,
    return_type="logits",
    fwd_hooks=[(utils.get_act_name('z', abl_layer) + ".hook_sae_error", partial(ablate_sae_error, pos=abl_pos))]
) # [batch, seq, vocab]

abl_error_logit_diff = logits_to_ave_logit_diff(abl_error_logits, answer_tokens, per_prompt=True) # [batch]
error_abl_effect = abl_error_logit_diff - original_per_prompt_logit_diff


causal_effects_with_error = torch.cat([causal_effects, error_abl_effect.unsqueeze(-1).cpu()], dim=-1)
imshow(causal_effects_with_error, title=f"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}", xaxis="Feature Idx", yaxis="Prompt Idx", x=list(map(str, all_live_features.tolist()))+["error"])

We can see that on some prompts, ablating the error term (rightmost column) does have a nontrivial effect on the logit diff, although I don't see a clear pattern. It seems useful to include this term when doing causal interventions to get a better sense of how much the SAE features are actually explaining.

# Attribution patching 


Both [Anthropic](https://transformer-circuits.pub/2024/march-update/index.html#feature-heads) and [Marks et al.](https://arxiv.org/abs/2403.19647v2) also demonstrated the use of gradient-based attribution techniques as a substitute for activation patching on SAE features. The key idea is that patching / ablations (as we did above) can be slow, as they require a new forward pass for each patch. This seems especially problematic when dealing with SAEs with tens of thousands of features per activation. They find that gradient-based attribution techniques like [attribution patching](https://www.neelnanda.io/mechanistic-interpretability/attribution-patching) are good approximations, allowing for more efficient and scalable circuit analysis with SAEs.

With `HookedSAETransformer`, added SAEs are automatically spliced into the computational graph, allowing us to implement this easily. Let's implement attribution patching for every L5 SAE feature to find causally relevant SAE features with just one forward and one backward pass.
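Before running this on the real model, it may help to see the approximation on a toy metric. The sketch below is hypothetical and uses a hand-written analytic gradient in place of autograd; it compares the attribution estimate `grad * (0 - act)` for zero ablation against the exact change in the metric.

```python
def metric(acts):
    # Toy metric: linear in feature 0, quadratic in feature 1.
    return 3.0 * acts[0] + acts[1] ** 2


def grad_metric(acts):
    # Analytic gradient of the toy metric (stands in for a backward pass).
    return [3.0, 2.0 * acts[1]]


acts = [2.0, 0.5]
grads = grad_metric(acts)

# First-order estimate of the effect of zero-ablating each feature.
attribution = [g * (0.0 - a) for g, a in zip(grads, acts)]


def ablation_effect(i):
    # Exact effect: rerun the metric with feature i set to zero.
    ablated = list(acts)
    ablated[i] = 0.0
    return metric(ablated) - metric(acts)


exact = [ablation_effect(i) for i in range(len(acts))]

print("attribution:", attribution)
print("exact:      ", exact)
```

The estimate is exact for the feature that enters the metric linearly and off for the quadratic one, which is the usual caveat: attribution patching is a first-order approximation, so it is worth spot-checking against real ablations.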

In [None]:
torch.set_grad_enabled(True)

In [None]:
from transformer_lens import ActivationCache
filter_sae_acts = lambda name: ("hook_sae_acts_post" in name)
def get_cache_fwd_and_bwd(model, tokens, metric):
    model.reset_hooks()
    cache = {}
    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()
    model.add_hook(filter_sae_acts, forward_cache_hook, "fwd")

    grad_cache = {}
    def backward_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()
    model.add_hook(filter_sae_acts, backward_cache_hook, "bwd")

    value = metric(model(tokens))
    print(value)
    value.backward()
    model.reset_hooks()
    return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)


BASELINE = original_per_prompt_logit_diff
def ioi_metric(logits, answer_tokens=answer_tokens):
    return (logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=True) - BASELINE).sum()

clean_tokens = tokens.clone()
clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)
print("Clean Value:", clean_value)
print("Clean Activations Cached:", len(clean_cache))
print("Clean Gradients Cached:", len(clean_grad_cache))

In [None]:
def attr_patch_sae_acts(
        clean_cache: ActivationCache, 
        clean_grad_cache: ActivationCache,
        site: str, layer: int
    ):
    clean_sae_acts_post = clean_cache[utils.get_act_name(site, layer) + ".hook_sae_acts_post"] 
    clean_grad_sae_acts_post = clean_grad_cache[utils.get_act_name(site, layer) + ".hook_sae_acts_post"] 
    sae_act_attr = clean_grad_sae_acts_post * (0 - clean_sae_acts_post)
    return sae_act_attr

site = "z"
layer = 5
sae_act_attr = attr_patch_sae_acts(clean_cache, clean_grad_cache, site, layer)

imshow(
    sae_act_attr[:, s2_pos, all_live_features],
    title="attribution patching",
    xaxis="Feature Idx", yaxis="Prompt Idx", x=list(map(str, all_live_features.tolist())))

In [None]:
fig = scatter(
    y=sae_act_attr[:, s2_pos, all_live_features].flatten(), 
    x=causal_effects.flatten(),
    title="Attribution vs Activation Patching Per SAE feature (L5 S2 Pos, all prompts)",
    xaxis="Activation Patch",
    yaxis="Attribution Patch",
    return_fig=True
)
fig.add_shape(
    type='line',
    x0=causal_effects.min().item(),
    y0=causal_effects.min().item(),
    x1=causal_effects.max().item(),
    y1=causal_effects.max().item(),
    line=dict(
        color='gray',
        width=1,
        dash='dot'
    )
)
fig.show()