In [50]:
import plotly.graph_objects as go
import plotly.express as px
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name
import sys
sys.path.append('.')
from sae_training.sparse_autoencoder import SparseAutoencoder
import torch
from torch.utils.data import DataLoader
import ipywidgets as widgets
from IPython.display import display
from datasets import load_dataset
from tqdm import tqdm
from functools import partial

In [51]:
model = HookedTransformer.from_pretrained("gelu-1l")

Loaded pretrained model gelu-1l into HookedTransformer


# Pick a prompt, any prompt

In [52]:
prompt_widget = widgets.Textarea(
    value='#include<iostream>\n#',
    placeholder='Type something',
    description='Prompt:',
    disabled=False
)
display(prompt_widget)


Textarea(value='#include<iostream>\n#', description='Prompt:', placeholder='Type something')

In [53]:
prompt = prompt_widget.value
str_tokens = model.to_str_tokens(prompt)
print("Tokenization: ", str_tokens)

Tokenization:  ['<|BOS|>', '#', 'include', '<', 'i', 'ostream', '>', '\n', '#']


In [54]:
seq_pos_widget = widgets.IntSlider(
    value=len(str_tokens)-1,
    min=0,
    max=len(str_tokens)-1,
    step=1,
    description='Token position:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)
display(seq_pos_widget)

IntSlider(value=8, continuous_update=False, description='Token position:', max=8)

In [55]:
seq_pos = seq_pos_widget.value
print("Selected token: ", str_tokens[seq_pos])

Selected token:  #


In [56]:
target_logit_widget = widgets.Text(
    value='include',
    placeholder='Type something',
    description='Target logit:',
    disabled=False
)
display(target_logit_widget)

Text(value='include', description='Target logit:', placeholder='Type something')

In [57]:
target_logit = target_logit_widget.value

# Clean run

In [58]:
hooks = ['resid_pre', 'attn_out', 'resid_mid', 'mlp_out', 'resid_post']
saes = {}
for hook in hooks:
    saes[hook] = SparseAutoencoder.load_from_pretrained(f'./weights/saes32/final_sparse_autoencoder_gelu-1l_blocks.0.hook_{hook}_16384.pt')
    saes[hook].eval()


In [122]:
def get_sae_acts(cache) -> tuple[dict[str, torch.Tensor], list[list[int]]]:
    sae_activations = {}
    for hook in hooks:
        sae_activations[hook] = saes[hook](cache[hook, 0])[1][0, seq_pos, :]
    sae_active_features = [torch.nonzero(act).squeeze(-1).tolist()
                           for act in sae_activations.values()]
    
    return sae_activations, sae_active_features

def plot_sae_activations(cache, title_suffix=""):
    _, sae_active_features = get_sae_acts(cache)

    fig = go.Figure()

    for i, (active_features, hook) in enumerate(zip(sae_active_features, hooks)):
        sorted_features = sorted(active_features)
        positions = list(range(len(sorted_features)))

        # Generate hover text for each feature
        hover_texts = [f"Feature {feature}" for feature in sorted_features]
        
        fig.add_trace(go.Scatter(
            x=positions,
            y=[i] * len(sorted_features),
            mode='markers+text',
            marker=dict(size=20, color="green"),
            text=sorted_features,  # Original text displayed below each marker
            hoverinfo='text',  # Use custom text for hover info
            hovertext=hover_texts,  # Custom hover text for each point
            textposition="bottom center",
            name=hook,
        ))

    # Customize the layout
    fig.update_layout(
        title='Sparse Autoencoder Activations ' + title_suffix,
        xaxis=dict(
            title='',
            showticklabels=False
        ),
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(sae_active_features))),
            ticktext=hooks,
            autorange="reversed"
        ),
        showlegend=False
    )

    fig.show()

clean_logits, clean_cache = model.run_with_cache(prompt)
clean_sae_acts, clean_sae_active_features = get_sae_acts(clean_cache)
plot_sae_activations(clean_cache, "(Clean)")

In [60]:
def logits_plot(logits, title_suffix=""):
    logits_at_pos = torch.softmax(logits[0, seq_pos, :].detach(), dim=0)
    probs, idxs = logits_at_pos.topk(10)
    labels = [f"'{model.to_str_tokens(i)[0]}'" for i in idxs]

    other_prob = 1 - probs.sum().item()
    probs = probs.tolist() + [other_prob]
    labels.append("Other")

    # Create a bar chart
    fig = px.bar(x=labels, y=probs,
                labels={'x': 'Predictions', 'y': 'Probability'},
                title="Top 10 Predictions of the Language Model " + title_suffix)

    fig.show()
logits_plot(clean_logits, "(Clean)")

# Ablations

## Are both attn and MLP necessary for this behavior?
Note that I'm looking just at this one example rn, but the logit is so strong here that I think it's enough signal to get started with.

Note to self: come back here later and verify that this generalizes.

### Sanity check: ablating attn_out completely makes the model unable to predict "include"
Doing zero-ablation for now, note that this might mean the results are wrong
because this is taking the model out of distribution in a pretty major way.

In [61]:
def full_zero_ablate_hook(acts, hook, seq_pos=seq_pos):
    acts[:, seq_pos] = 0
    return acts

with model.hooks(fwd_hooks=[(get_act_name("attn_out", 0), full_zero_ablate_hook)]):
    no_attn_logits, no_attn_cache = model.run_with_cache(prompt)
    plot_sae_activations(no_attn_cache, "(Zero-ablated attn layer)")
    logits_plot(no_attn_logits, "(Zero-ablated attn layer)")

### I'm curious, what happens if I ablate mlp_out?
I think it should be possible to implement this behavior using only embed, attn,
and unembed. Is that actually what's happening?

In [62]:
with model.hooks(fwd_hooks=[(get_act_name("mlp_out", 0), full_zero_ablate_hook)]):
    no_attn_logits, no_attn_cache = model.run_with_cache(prompt)
    plot_sae_activations(no_attn_cache, "(Zero-ablated MLP layer)")
    logits_plot(no_attn_logits, "(Zero-ablated MLP layer)")

### Mean-ablations of attn_out and mlp_out
The attn_out ablation definitely has to fuck up the model; I'm not sure about
mlp_out so wanna test this out.

In [63]:
dataset = iter(load_dataset("NeelNanda/c4-tokenized-2b", split="train", streaming=True))
tokens = []
for _ in range(10):
    tokens.append(torch.tensor(next(dataset)["tokens"],
                               dtype=torch.long,
                               device="mps",
                               requires_grad=False))
tokens = torch.cat(tokens, dim=0)
print("Tokens shape: ", tokens.shape)
tokens_dataloader = DataLoader(tokens, batch_size=16, shuffle=False)

Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Tokens shape:  torch.Size([10240])


In [64]:
attn_out_acts = []
mlp_out_acts = []
for batch in tqdm(tokens_dataloader):
    _, cache = model.run_with_cache(batch)
    attn_out_acts.append(cache["attn_out", 0].squeeze())
    mlp_out_acts.append(cache["mlp_out", 0].squeeze())
attn_out_acts = torch.cat(attn_out_acts, dim=0)
mlp_out_acts = torch.cat(mlp_out_acts, dim=0)
print("Attn out acts shape: ", attn_out_acts.shape)
print("MLP out acts shape: ", mlp_out_acts.shape)
mean_attn_out_acts = attn_out_acts.mean(dim=0)
mean_mlp_out_acts = mlp_out_acts.mean(dim=0)

100%|██████████| 640/640 [00:05<00:00, 126.36it/s]

Attn out acts shape:  torch.Size([10240, 512])
MLP out acts shape:  torch.Size([10240, 512])





In [93]:
def full_mean_ablate_hook(acts, hook, seq_pos=seq_pos):
    if 'attn_out' not in hook.name and 'mlp_out' not in hook.name:
        raise NotImplementedError("Hook name not supported yet")
    acts[:, seq_pos] = mean_attn_out_acts if 'attn_out' in hook.name else mean_mlp_out_acts
    return acts

with model.hooks(fwd_hooks=[(get_act_name("attn_out", 0), full_mean_ablate_hook)]):
    no_attn_logits, no_attn_cache = model.run_with_cache(prompt)
    plot_sae_activations(no_attn_cache, "(Mean-ablated attn layer)")
    logits_plot(no_attn_logits, "(Mean-ablated attn layer)")

In [94]:
with model.hooks(fwd_hooks=[(get_act_name("mlp_out", 0), full_mean_ablate_hook)]):
    no_attn_logits, no_attn_cache = model.run_with_cache(prompt)
    plot_sae_activations(no_attn_cache, "(Mean-ablated MLP layer)")
    logits_plot(no_attn_logits, "(Mean-ablated MLP layer)")

Nice! So now I can be reasonably confident that the MLP is actually necessary for implementing this behavior.

## Ablate each attn_out SAE feature and see what happens to the "include" logit
### Naive version — do this the naive way and ablate SAE features one at a time.

Reasoning for doing this on attn_out: we know that the model can't predict "include"
based on the bigram frequency alone, so it must be moving info from the previous
occurrence of "include". This means that ablating attn should destroy the model's
ability to perform the task.

In [67]:
get_act_name("mlp_out", 0)

'blocks.0.hook_mlp_out'

In [97]:
def ablate_sae_feature(acts, hook, sae_feature_indices=[]):
    hook_name = hook.name[14:] # remove the "block.0.hook_" from the name
    sae = saes[hook_name]
    for feature in sae_feature_indices:
        current_act_strength = (acts[:, seq_pos] - sae.b_dec) @ sae.W_enc[:, feature]
        #!TODO deal with the bias subtraction
        acts[:, seq_pos] -= current_act_strength * sae.W_enc[:, feature]
    return acts

target_logit_idx = model.tokenizer.vocab[target_logit]
logit_strengths = [] # Currently measuring probability
for attn_sae_feature in clean_sae_active_features[1]:
    hook = partial(ablate_sae_feature, sae_feature_indices=[attn_sae_feature])
    ablated_attn_sae_logits = model.run_with_hooks(prompt, fwd_hooks=[(get_act_name("attn_out", 0), hook)])
    probs = torch.softmax(ablated_attn_sae_logits[0, seq_pos, :], dim=0)
    logit_strengths.append(probs[target_logit_idx].item())

# plot the logit strengths
fig = go.Figure()
fig.add_trace(go.Bar(x=[f"Feature {i}" for i in clean_sae_active_features[1]],
                     y=logit_strengths))
fig.update_layout(title="Probability Assigned To Target Logit of Ablated Attention SAE Features",  
                    xaxis_title="Feature",
                    yaxis_title="Target Logit Probability")

fig.show()

Sanity check — is this only happening because feature 8220 was firing much more strongly?

In [111]:
def sae_activation_strengths(cache, hook):
    sae_acts, sae_active_features = get_sae_acts(cache)
    act_values = sae_acts[hook][sae_active_features[1]].cpu().detach().numpy()
    # show as bar chart (w labels)
    fig = go.Figure()
    fig.add_trace(go.Bar(x=[f"Feature {i}" for i in sae_active_features[1]],
                        y=act_values))
    fig.update_layout(title="Activation Strength of Attention SAE Features",
                        xaxis_title="Feature",
                        yaxis_title="Activation Strength")

    fig.show()
sae_activation_strengths(clean_cache, "attn_out")

In [121]:
features_to_ablate = [8220, 7931]
hook = partial(ablate_sae_feature, sae_feature_indices=features_to_ablate)
with model.hooks(fwd_hooks=[(get_act_name("attn_out", 0), hook)]):
    logits, cache = model.run_with_cache(prompt)
    plot_sae_activations(cache, f"(Ablated attn SAE feature {features_to_ablate}")
    logits_plot(logits, f"(Ablated attn SAE feature {features_to_ablate}")

Hm, that's annoying, apparently ablating the important stuff takes the model so far out of distribution that the SAEs in the subsequent hooks just turn into a hot mess.

I wonder if there's a way to ablate this that doesn't take things out of distribution?

## How to do SAE feature ablations well

### What happens when I do ablation the naive way
This is what I've been doing so far — literally just subtract the encoder direction of the SAE feature from the activations.

What does this actually do? If I do this ablation and then run it through the SAE, what happens?