<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Hooked_SAE_Transformer_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# HookedSAETransformer Demo

HookedSAETransformer is a lightweight extension of HookedTransformer that allows you to "splice in" Sparse Autoencoders. This makes it easy to do exploratory analysis such as: running inference with SAEs attached, caching SAE feature activations, and intervening on SAE activations with hooks.

I (Connor Kissane) implemented this to accelerate research on [Attention SAEs](https://www.lesswrong.com/posts/DtdzGwFh9dCfsekZZ/sparse-autoencoders-work-on-attention-layer-outputs) based on suggestions from Arthur Conmy and Neel Nanda, and found that it was well worth the time and effort. I hope other researchers will also find the library useful! This notebook demonstrates how it works and how to use it.
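To make "splicing in" concrete, here is a minimal pure-Python sketch of a toy ReLU SAE (encode with ReLU, then decode) whose reconstruction replaces an activation vector. The weights, sizes, and the helper names `sae_encode`, `sae_decode`, and `spliced_forward` are hypothetical. Real SAEs may use a different activation function or bias conventions, so treat this as an illustration of the idea rather than the library's implementation.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major: Matrix[i][j]


def matvec(vec: Vector, mat: Matrix) -> Vector:
    """Compute vec @ mat for a row vector `vec` and matrix `mat`."""
    return [sum(vec[i] * mat[i][j] for i in range(len(vec))) for j in range(len(mat[0]))]


def sae_encode(x: Vector, W_enc: Matrix, b_enc: Vector) -> Vector:
    """Feature activations: ReLU(x @ W_enc + b_enc)."""
    pre = matvec(x, W_enc)
    return [max(0.0, p + b) for p, b in zip(pre, b_enc)]


def sae_decode(acts: Vector, W_dec: Matrix, b_dec: Vector) -> Vector:
    """Reconstruction: acts @ W_dec + b_dec."""
    return [r + b for r, b in zip(matvec(acts, W_dec), b_dec)]


def spliced_forward(x: Vector, W_enc: Matrix, b_enc: Vector, W_dec: Matrix, b_dec: Vector) -> Vector:
    """'Splicing in' an SAE: downstream layers see the reconstruction instead of x."""
    return sae_decode(sae_encode(x, W_enc, b_enc), W_dec, b_dec)


# Toy SAE with d_model = 2 and d_sae = 3 (hypothetical weights)
W_enc = [[1.0, 0.0, 1.0],
         [0.0, 1.0, -1.0]]
b_enc = [0.0, 0.0, 0.0]
W_dec = [[1.0, 0.0],
         [0.0, 1.0],
         [0.5, -0.5]]
b_dec = [0.0, 0.0]

x = [2.0, 1.0]
acts = sae_encode(x, W_enc, b_enc)
x_hat = spliced_forward(x, W_enc, b_enc, W_dec, b_dec)
print(acts, x_hat)  # [2.0, 1.0, 1.0] [2.5, 0.5]
```

The reconstruction `[2.5, 0.5]` differs from `x = [2.0, 1.0]`. That gap is what the error term in the "Error Nodes" section accounts for.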

# Setup

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/jbloomAus/SAELens

except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload edited library code (e.g. HookedTransformer) without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Jupyter notebook - intended for development only!




In [2]:
import torch
import transformer_lens.utils as utils

import plotly.express as px
import tqdm
from functools import partial
import einops
import plotly.graph_objects as go

update_layout_set = {
    "xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis",
     "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid",
     "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"
}

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, return_fig=False, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    fig = px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs)
    if return_fig:
        return fig
    fig.show(renderer)

from typing import List

def show_avg_logit_diffs(x_axis: List[str], per_prompt_logit_diffs: List[torch.Tensor]):
    """Bar chart of the mean logit diff per intervention, with standard-deviation error bars."""
    y_data = [per_prompt_logit_diff.mean().item() for per_prompt_logit_diff in per_prompt_logit_diffs]
    error_y_data = [per_prompt_logit_diff.std().item() for per_prompt_logit_diff in per_prompt_logit_diffs] 

    fig = go.Figure(data=[go.Bar(
        x=x_axis,
        y=y_data,
        error_y=dict(
            type='data',  # specifies that the actual values are given
            array=error_y_data,  # the magnitudes of the errors
            visible=True  # make error bars visible
        ),
    )])

    # Customize layout
    fig.update_layout(title_text=f'Logit Diff after Interventions',
                    xaxis_title_text='Intervention',
                    yaxis_title_text='Logit diff',
                    plot_bgcolor='white')

    # Show the figure
    fig.show()

In [3]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else: 
    device = "cpu"
torch.set_grad_enabled(False)


# Loading and Running Models

Just like a [HookedTransformer](https://TransformerLensOrg.github.io/TransformerLens/generated/demos/Main_Demo.html#Loading-and-Running-Models), we can load in any model that's supported in TransformerLens with `HookedSAETransformer.from_pretrained(MODEL_NAME)`. In this demo we'll use Gemma 2 2B (`google/gemma-2-2b`).

In [4]:
from huggingface_hub import login
login()  # Gemma 2 is a gated model on Hugging Face; this prompts for your access token (never hardcode it)


In [5]:
from sae_lens import HookedSAETransformer
model: HookedSAETransformer = HookedSAETransformer.from_pretrained("google/gemma-2-2b").to(device)






Loaded pretrained model google/gemma-2-2b into HookedTransformer
Moving model to device:  mps


In [45]:
tokens = model.to_tokens("<bos> did it come from and", prepend_bos=False)
tokens2 = model.to_tokens("<bos> male bonding based on practical jokes and humour, thanks largely to the influence of Lee Majors. He remembers Lee as being very athletically inclined and always eager to perform as many of the show’s stunts as he possibly could himself. Richard Anderson he remembers as someone who was obsessed about topping up his tan between takes, but he’s impressed about how Anderson developed the relatively thankless role of Oscar Goldman and made him into such an iconic part of the show. He talks about coordinating his role as Rudy Wells in both “The Six Million Dollar Man” and “The Bionic Woman”, and sometimes getting confused about which lines he was", prepend_bos=False)
tokens3 = model.to_tokens("<bos> large parts warmer than -", prepend_bos=False)
model.to_str_tokens(tokens), model.to_str_tokens(tokens3)

(['<bos>', ' did', ' it', ' come', ' from', ' and'],
 ['<bos>', ' large', ' parts', ' warmer', ' than', ' -'])

In [23]:
from sae_lens import SAE

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res-canonical", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "layer_20/width_65k/canonical", # won't always be a hook point
    device = device
)

In [24]:
model.add_sae(sae)
print("Attached SAEs after add_sae", model.acts_to_saes)

Attached SAEs after add_sae {'blocks.20.hook_resid_post': SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)}


In [46]:
layer = 20
_, cache = model.run_with_cache(tokens3)
sae_acts = cache[utils.get_act_name('resid_post', layer) + ".hook_sae_acts_post"][:, 1, :]

live_feature_mask = sae_acts > 0
live_feature_union = live_feature_mask.any(dim=0)


imshow(
    sae_acts[:, live_feature_union],
    title = f"Activations of live L{layer} SAE features at position 1 per prompt",
    xaxis="Feature Id", yaxis="Prompt",
    x=list(map(str, live_feature_union.nonzero().flatten().tolist())),
)
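The mask logic above, `sae_acts > 0` followed by `.any(dim=0)`, marks a feature as live if it fires on at least one prompt. A plain-Python sketch of the same reduction over a small, made-up activation matrix (rows are prompts, columns are features); the helper name `live_feature_ids` is hypothetical:

```python
from typing import List


def live_feature_ids(acts: List[List[float]]) -> List[int]:
    """Return ids of features that are active (> 0) on at least one prompt.

    Mirrors `(sae_acts > 0).any(dim=0).nonzero()` for a [n_prompts, d_sae] matrix.
    """
    n_features = len(acts[0])
    return [f for f in range(n_features) if any(row[f] > 0 for row in acts)]


# Hypothetical activations: 2 prompts x 5 features
acts = [
    [0.0, 1.2, 0.0, 0.0, 3.4],
    [0.0, 0.0, 0.0, 0.7, 2.1],
]
print(live_feature_ids(acts))  # [1, 3, 4]
```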

In [47]:
sae_acts[0, 20190:20200]

tensor([ 0.0000,  0.0000, 37.9348,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000], device='mps:0')

In [32]:
sae_acts_cnt = (sae_acts[0, :] > 0).sum()
sae_acts_cnt

tensor(118, device='mps:0')
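The count above is the L0 norm of the feature vector at that position: 118 of the SAE's roughly 65k features are active. A small sketch that computes L0 together with the strongest features, using a hypothetical sparse vector; with tensors one would typically call `torch.topk` on `sae_acts[0]` instead. The helper `l0_and_top_features` is hypothetical.

```python
from typing import List, Tuple


def l0_and_top_features(acts: List[float], k: int) -> Tuple[int, List[Tuple[int, float]]]:
    """Return (number of active features, top-k (feature_id, activation) pairs)."""
    active = [(i, a) for i, a in enumerate(acts) if a > 0]
    top = sorted(active, key=lambda pair: pair[1], reverse=True)[:k]
    return len(active), top


# Hypothetical feature activations at one token position
acts = [0.0, 0.0, 37.9, 0.0, 1.5, 0.0, 4.2]
l0, top = l0_and_top_features(acts, k=2)
print(l0, top)  # 3 [(2, 37.9), (6, 4.2)]
```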

## Run with hooks

We can also use `run_with_hooks` to intervene on the attached SAE's activations. As a more complicated intervention, we'll path patch each feature out of the S-inhibition heads' value inputs.

In [None]:
model.set_use_split_qkv_input(True)

In [None]:
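def path_patch_v_input(v_input, hook, feature_dirs, pos, head_index):
    # Subtract the feature's residual-stream contribution from one head's value input at `pos`
    v_input[:, pos, head_index, :] = v_input[:, pos, head_index, :] - feature_dirs
    return v_input


# NOTE: this cell assumes IOI-task variables that are not defined earlier in this notebook
# (answer_tokens, logits_to_ave_logit_diff, all_live_features, fid_to_idx, abl_layer, abl_pos,
# clean_sae_baseline_per_prompt), a cache from a run with an SAE attached at the layer-`abl_layer`
# `z` hook, and `sae` being that attention SAE rather than the layer-20 residual SAE loaded above.
s_inhib_heads = [(7, 3), (7, 9), (8, 6), (8, 10)]  # (layer, head) S-inhibition heads in GPT-2 small

results = torch.zeros(tokens.shape[0], all_live_features.shape[0])

# Stack the per-head W_O so a concatenated z vector maps straight to the residual stream
W_O_cat = einops.rearrange(
    model.W_O,
    "n_layers n_heads d_head d_model -> n_layers (n_heads d_head) d_model"
)

for feature_id in tqdm.tqdm(all_live_features):
    feature_id = feature_id.item()
    feature_acts = cache[utils.get_act_name('z', abl_layer) + ".hook_sae_acts_post"][:, abl_pos, feature_id] # [batch]
    # Feature's residual-stream contribution: (activation * decoder row) mapped through W_O
    feature_dirs = (feature_acts.unsqueeze(-1) * sae.W_dec[feature_id]) @ W_O_cat[abl_layer]
    hook_fns = [
        (utils.get_act_name('v_input', layer), partial(path_patch_v_input, feature_dirs=feature_dirs, pos=abl_pos, head_index=head)) for (layer, head) in s_inhib_heads
    ]
    path_patched_logits = model.run_with_hooks(
        tokens,
        return_type="logits",
        fwd_hooks=hook_fns
    )

    path_patched_logit_diff = logits_to_ave_logit_diff(path_patched_logits, answer_tokens, per_prompt=True)
    results[:, fid_to_idx[feature_id]] = path_patched_logit_diff - clean_sae_baseline_per_prompt

imshow(
    results,
    title="Change in logit diff when path patching features from S-inhibition heads' values, per prompt",
    xaxis="Feature Id", yaxis="Prompt Idx", x=list(map(str, all_live_features.tolist()))
)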
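The key quantity in that cell is a feature's contribution to the residual stream, `feature_acts * W_dec[feature_id]` mapped through the concatenated output matrix `W_O_cat`, which is then subtracted from the value input of specific heads only. Restricting the subtraction to those heads is what makes it path patching rather than an ordinary ablation. A pure-Python sketch with made-up toy dimensions and weights; the helper names are hypothetical:

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(vec: Vector, mat: Matrix) -> Vector:
    """Row vector times matrix: vec @ mat."""
    return [sum(vec[i] * mat[i][j] for i in range(len(vec))) for j in range(len(mat[0]))]


def feature_resid_direction(act: float, w_dec_row: Vector, W_O_cat: Matrix) -> Vector:
    """Residual-stream contribution of one z-space SAE feature: (act * W_dec[f]) @ W_O_cat."""
    return matvec([act * w for w in w_dec_row], W_O_cat)


def path_patch(head_input: Vector, feature_dir: Vector) -> Vector:
    """Remove the feature's contribution from one downstream head's input only."""
    return [h - d for h, d in zip(head_input, feature_dir)]


# Hypothetical sizes: concatenated z has 2 dims, d_model = 3
W_O_cat = [[1.0, 0.0, 2.0],
           [0.0, 1.0, -1.0]]
w_dec_row = [0.5, 1.0]
direction = feature_resid_direction(act=2.0, w_dec_row=w_dec_row, W_O_cat=W_O_cat)
patched = path_patch([3.0, 3.0, 3.0], direction)
print(direction, patched)  # [1.0, 2.0, 0.0] [2.0, 1.0, 3.0]
```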

# Reset SAEs

One major footgun is forgetting about an SAE that you previously attached with `add_sae`. Similar to TransformerLens `reset_hooks`, you can always reset SAEs you've added with `model.reset_saes()`. You can also pass in a list of activation names to only reset a subset of attached SAEs.

In [None]:
print("Attached SAEs before reset_saes:", model.acts_to_saes)
model.reset_saes()
print("Attached SAEs after reset_saes:", model.acts_to_saes)
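The bookkeeping behind `reset_saes`, including resetting only a subset by activation name, can be modeled with a plain dictionary keyed by hook name. The class below is a hypothetical toy, not SAELens code; with the real model you would pass the list of hook names to `model.reset_saes(...)` as described above.

```python
from typing import Dict, List, Optional, Union


class ToySAERegistry:
    """Hypothetical stand-in for the `acts_to_saes` bookkeeping (not SAELens code)."""

    def __init__(self) -> None:
        self.acts_to_saes: Dict[str, str] = {}

    def add_sae(self, hook_name: str, sae_name: str) -> None:
        # Record which SAE is attached at this hook (in this toy, a later call overwrites the entry)
        self.acts_to_saes[hook_name] = sae_name

    def reset_saes(self, act_names: Optional[Union[str, List[str]]] = None) -> None:
        # No argument: detach everything; otherwise detach only the named hooks
        if act_names is None:
            self.acts_to_saes.clear()
            return
        if isinstance(act_names, str):
            act_names = [act_names]
        for name in act_names:
            self.acts_to_saes.pop(name, None)


registry = ToySAERegistry()
registry.add_sae("blocks.5.attn.hook_z", "l5_sae")
registry.add_sae("blocks.20.hook_resid_post", "l20_sae")
registry.reset_saes(["blocks.5.attn.hook_z"])
print(registry.acts_to_saes)  # {'blocks.20.hook_resid_post': 'l20_sae'}
```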

Note that the HookedSAETransformer API is generally designed to closely match the TransformerLens hooks API.

# Error Nodes

Recent exciting work from [Marks et al.](https://arxiv.org/abs/2403.19647v2) demonstrated the use of "error nodes" in SAE circuit analysis. The idea is that for some input activation `x`, `SAE(x) = x_reconstruct` is an approximation of `x`, but we can define an `error_term` such that `x = x_reconstruct + error_term`.
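In symbols (our notation, assuming a standard SAE decoder): write $f_i(x)$ for the activation of feature $i$, $d_i$ for its decoder direction (row $i$ of `W_dec`), and $b_{\mathrm{dec}}$ for the decoder bias. Then

$$
\hat{x} = \sum_i f_i(x)\, d_i + b_{\mathrm{dec}}, \qquad \varepsilon(x) = x - \hat{x}, \qquad x = \sum_i f_i(x)\, d_i + b_{\mathrm{dec}} + \varepsilon(x),
$$

where $\hat{x}$ is `x_reconstruct` and $\varepsilon(x)$ is `error_term`. The last identity holds exactly by construction, however good or bad the SAE's reconstruction is.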

This seems useful: instead of replacing x with x_reconstruct, which might break everything and make our circuit analysis janky, we can just rewrite x as a function of the SAE features, bias, and error term. This gives us access to all of the SAE features without breaking performance.

Additionally, we can compare interventions on SAE features to the same intervention on the error term to get a better sense of how much the SAE features have actually captured.

To use error terms with SAEs, you can set `sae.use_error_term = True`. Note this is set to False by default.

In [None]:
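import copy
# Assumes `hook_name_to_sae`, a dict mapping hook names to pretrained SAEs that includes a
# layer-5 attention (`z`) SAE; it is not defined earlier in this notebook.
l5_sae = hook_name_to_sae[utils.get_act_name('z', 5)]
l5_sae_with_error = copy.deepcopy(l5_sae)  # copy so the original SAE keeps use_error_term=False
l5_sae_with_error.use_error_term = True
model.add_sae(l5_sae_with_error)
print("Attached SAEs after adding l5_sae_with_error:", model.acts_to_saes)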

Now the output of each attached SAE will be SAE(x) + error_term = x. We can sanity check this by confirming that running with SAEs produces the same logits as running without them.

In [None]:
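# Assumes `original_logits` (logits from a run without SAEs) and `answer_tokens` from the IOI setup
logits_with_saes = model(tokens)
logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)

assert torch.allclose(logits_with_saes, original_logits, atol=1)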
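Why the error term makes the spliced model match the original: whatever the SAE reconstructs, adding `error_term = x - x_reconstruct` back restores `x`, up to floating-point rounding (one reason to compare logits with a tolerance rather than exact equality). A pure-Python sketch with a hypothetical, deliberately imperfect reconstruction; `splice_with_error` is an illustrative helper, not a library function:

```python
from typing import List, Tuple

Vector = List[float]


def splice_with_error(x: Vector, x_reconstruct: Vector) -> Tuple[Vector, Vector]:
    """Return (error_term, output) where output = x_reconstruct + error_term."""
    error_term = [xi - ri for xi, ri in zip(x, x_reconstruct)]
    output = [ri + ei for ri, ei in zip(x_reconstruct, error_term)]
    return error_term, output


x = [2.0, 1.0]
x_reconstruct = [2.5, 0.5]  # hypothetical, imperfect SAE reconstruction
error_term, output = splice_with_error(x, x_reconstruct)
print(error_term, output)  # [-0.5, 0.5] [2.0, 1.0]
```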

Now we can compare ablations of each feature to ablating the error node. We'll start by ablating each feature on each prompt, and then the error nodes. We'll append the effects from ablating error nodes to the rightmost column on the heatmap:

In [None]:
def ablate_sae_feature(sae_acts, hook, pos, feature_id):
    if pos is None:
        sae_acts[:, :, feature_id] = 0.
    else:
        sae_acts[:, pos, feature_id] = 0.
    return sae_acts

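layer = 5
# Assumes IOI-task variables not defined earlier in this notebook (prompts, answer_tokens,
# logits_to_ave_logit_diff, original_per_prompt_logit_diff) and a `live_feature_union` mask over
# this layer-5 SAE's features (the mask computed above belongs to the layer-20 SAE).
hooked_encoder = model.acts_to_saes[utils.get_act_name('z', layer)]
all_live_features = torch.arange(hooked_encoder.cfg.d_sae)[live_feature_union.cpu()]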

causal_effects = torch.zeros((len(prompts), all_live_features.shape[0]))
fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}


abl_layer, abl_pos  = 5, 10
for feature_id in tqdm.tqdm(all_live_features):
    feature_id = feature_id.item()
    abl_feature_logits = model.run_with_hooks(
        tokens,
        return_type="logits",
        fwd_hooks=[(utils.get_act_name('z', abl_layer) + ".hook_sae_acts_post", partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id))]
    ) # [batch, seq, vocab]
    
    abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]
    causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - original_per_prompt_logit_diff

def ablate_sae_error(sae_error, hook, pos):
    # Zero the SAE error term at `pos`, or at every position if pos is None
    if pos is None:
        sae_error[:] = 0.
    else:
        sae_error[:, pos, ...] = 0.
    return sae_error


abl_error_logits = model.run_with_hooks(
    tokens,
    return_type="logits",
    fwd_hooks=[(utils.get_act_name('z', abl_layer) + ".hook_sae_error", partial(ablate_sae_error, pos=abl_pos))]
) # [batch, seq, vocab]

abl_error_logit_diff = logits_to_ave_logit_diff(abl_error_logits, answer_tokens, per_prompt=True) # [batch]
error_abl_effect = abl_error_logit_diff - original_per_prompt_logit_diff


causal_effects_with_error = torch.cat([causal_effects, error_abl_effect.unsqueeze(-1).cpu()], dim=-1)
imshow(causal_effects_with_error, title=f"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}", xaxis="Feature Idx", yaxis="Prompt Idx", x=list(map(str, all_live_features.tolist()))+["error"])
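In a purely linear toy setting the comparison in this heatmap is easy to reason about: zero-ablating feature $i$ shifts a linear metric by $-f_i\,(w \cdot d_i)$, and zeroing the error term shifts it by $-(w \cdot \text{error\_term})$. The sketch below computes both, with the error column last as in the heatmap. The readout vector `readout` is a hypothetical stand-in for a logit-difference direction, and all numbers are made up; in a real model the computation downstream of the SAE is nonlinear, so effects will not decompose this cleanly.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def compose(acts: Vector, W_dec: List[Vector], b_dec: Vector, error: Vector) -> Vector:
    """x = sum_i acts[i] * W_dec[i] + b_dec + error."""
    x = [b + e for b, e in zip(b_dec, error)]
    for a, row in zip(acts, W_dec):
        x = [xi + a * ri for xi, ri in zip(x, row)]
    return x


def ablation_effects(acts: Vector, W_dec: List[Vector], b_dec: Vector, error: Vector, readout: Vector) -> Vector:
    """Change in a linear metric when zero-ablating each feature, then the error term (last entry)."""
    clean = dot(compose(acts, W_dec, b_dec, error), readout)
    effects = []
    for i in range(len(acts)):
        ablated = [0.0 if j == i else a for j, a in enumerate(acts)]
        effects.append(dot(compose(ablated, W_dec, b_dec, error), readout) - clean)
    zero_error = [0.0] * len(error)
    effects.append(dot(compose(acts, W_dec, b_dec, zero_error), readout) - clean)
    return effects


acts = [2.0, 1.0, 1.0]
W_dec = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]]
b_dec = [0.0, 0.0]
error = [-0.5, 0.5]
readout = [1.0, -1.0]  # hypothetical metric direction
print(ablation_effects(acts, W_dec, b_dec, error, readout))  # [-2.0, 1.0, -1.0, 1.0]
```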

We can see that on some prompts, ablating the error term (rightmost column) does have a non-trivial effect on the logit diff, although I don't see a clear pattern. It seems useful to include this term when doing causal interventions to get a better sense of how much the SAE features are actually explaining.

# Attribution patching 


Both [Anthropic](https://transformer-circuits.pub/2024/march-update/index.html#feature-heads) and [Marks et al.](https://arxiv.org/abs/2403.19647v2) also demonstrated the use of gradient-based attribution techniques as a substitute for activation patching on SAE features. The key idea is that patching and ablation (as we did above) can be slow, since each patch requires a new forward pass. This seems especially problematic when dealing with SAEs with tens of thousands of features per activation. They find that gradient-based attribution techniques like [attribution patching](https://www.neelnanda.io/mechanistic-interpretability/attribution-patching) are good approximations, allowing for more efficient and scalable circuit analysis with SAEs.

With `HookedSAETransformer`, added SAEs are automatically spliced into the computational graph, allowing us to implement this easily. Let's implement attribution patching for every L5 SAE feature to find causally relevant SAE features with just one forward and one backward pass.
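The approximation behind this is a first-order Taylor expansion of the metric $M$ around the clean activations. For a feature activation $a_i$, the effect of zero-ablating it ($a_i \to 0$) is estimated as

$$
M(a_i \to 0) - M(a) \approx \frac{\partial M}{\partial a_i}\,(0 - a_i),
$$

which is the product `clean_grad_sae_acts_post * (0 - clean_sae_acts_post)` computed in `attr_patch_sae_acts`. A single backward pass yields $\partial M / \partial a_i$ for every feature and position at once, which is where the speedup over one forward pass per ablation comes from.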

In [None]:
torch.set_grad_enabled(True)

In [None]:
from transformer_lens import ActivationCache
filter_sae_acts = lambda name: ("hook_sae_acts_post" in name)
def get_cache_fwd_and_bwd(model, tokens, metric):
    model.reset_hooks()
    cache = {}
    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()
    model.add_hook(filter_sae_acts, forward_cache_hook, "fwd")

    grad_cache = {}
    def backward_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()
    model.add_hook(filter_sae_acts, backward_cache_hook, "bwd")

    value = metric(model(tokens))
    value.backward()  # populates grad_cache through the backward hooks
    model.reset_hooks()
    return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)


BASELINE = original_per_prompt_logit_diff
def ioi_metric(logits, answer_tokens=answer_tokens):
    return (logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=True) - BASELINE).sum()

clean_tokens = tokens.clone()
clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)
print("Clean Value:", clean_value)
print("Clean Activations Cached:", len(clean_cache))
print("Clean Gradients Cached:", len(clean_grad_cache))

In [None]:
def attr_patch_sae_acts(
        clean_cache: ActivationCache, 
        clean_grad_cache: ActivationCache,
        site: str, layer: int
    ):
    clean_sae_acts_post = clean_cache[utils.get_act_name(site, layer) + ".hook_sae_acts_post"] 
    clean_grad_sae_acts_post = clean_grad_cache[utils.get_act_name(site, layer) + ".hook_sae_acts_post"] 
    sae_act_attr = clean_grad_sae_acts_post * (0 - clean_sae_acts_post)
    return sae_act_attr

site = "z"
layer = 5
sae_act_attr = attr_patch_sae_acts(clean_cache, clean_grad_cache, site, layer)

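# `s2_pos` (the position of the second subject token, S2, in the IOI prompts) is assumed, not defined above
imshow(
    sae_act_attr[:, s2_pos, all_live_features],
    title="Attribution patching estimates per L5 SAE feature at S2, per prompt",
    xaxis="Feature Idx", yaxis="Prompt Idx", x=list(map(str, all_live_features.tolist())))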

In [None]:
fig = scatter(
    y=sae_act_attr[:, s2_pos, all_live_features].flatten(), 
    x=causal_effects.flatten(),
    title="Attribution vs Activation Patching Per SAE feature (L5 S2 Pos, all prompts)",
    xaxis="Activation Patch",
    yaxis="Attribution Patch",
    return_fig=True
)
fig.add_shape(
    type='line',
    x0=causal_effects.min(),
    y0=causal_effects.min(),
    x1=causal_effects.max(),
    y1=causal_effects.max(),
    line=dict(
        color='gray',
        width=1,
        dash='dot'
    )
)
fig.show()
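The scatter plot compares the two estimates across real features. To see why they can disagree, here is a pure-Python toy with a made-up nonlinear metric `toy_metric`. A central-difference gradient stands in for the backward pass, so nothing here is HookedSAETransformer code. Attribution is `grad * (0 - a)`, as in `attr_patch_sae_acts`, while the true effect comes from actually zeroing each activation.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def numerical_grad(metric: Callable[[Vector], float], a: Vector, h: float = 1e-5) -> Vector:
    """Central-difference gradient of `metric` at `a` (stand-in for the backward pass)."""
    grad = []
    for i in range(len(a)):
        up = list(a)
        up[i] += h
        down = list(a)
        down[i] -= h
        grad.append((metric(up) - metric(down)) / (2 * h))
    return grad


def attribution_vs_ablation(metric: Callable[[Vector], float], a: Vector) -> Tuple[Vector, Vector]:
    """Return (attribution estimates, true zero-ablation effects) for each feature."""
    grad = numerical_grad(metric, a)
    attributions = [g * (0.0 - ai) for g, ai in zip(grad, a)]
    clean = metric(a)
    ablations = []
    for i in range(len(a)):
        ablated = list(a)
        ablated[i] = 0.0
        ablations.append(metric(ablated) - clean)
    return attributions, ablations


def toy_metric(a: Vector) -> float:
    """Hypothetical nonlinear metric: a linear term plus a quadratic interaction."""
    c = [1.0, -2.0, 0.5]
    s = sum(a)
    return sum(ci * ai for ci, ai in zip(c, a)) + 0.5 * s * s


acts = [0.1, 0.2, 3.0]
attr, abl = attribution_vs_ablation(toy_metric, acts)
for i, (x, y) in enumerate(zip(attr, abl)):
    print(f"feature {i}: attribution {x:+.3f}, ablation {y:+.3f}")
```

Here the two estimates agree to within 0.02 for the small activations (0.1 and 0.2) but differ by 4.5 for the large one (3.0): the gap is the quadratic term $\tfrac{1}{2}a_i^2$ that linearization drops. This is one reason it is worth spot-checking attribution estimates with real ablations for strongly active features.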