#### Note to Future Confused Lucy: I moved a lot of the functions from here to lucys_utils.py

In [1]:
import plotly.graph_objects as go
import plotly.express as px
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name
import sys
sys.path.append('..')
from sae_training.sparse_autoencoder import SparseAutoencoder
import torch
from torch.utils.data import DataLoader
import ipywidgets as widgets
from IPython.display import display
from datasets import load_dataset
from tqdm import tqdm
from functools import partial
import pandas as pd

In [2]:
model = HookedTransformer.from_pretrained("gelu-1l")

Loaded pretrained model gelu-1l into HookedTransformer


# Pick a prompt, any prompt

In [3]:
# Editable text box holding the prompt; the next cell reads its current value.
prompt_widget = widgets.Textarea(
    description='Prompt:',
    placeholder='Type something',
    value='#include<iostream>\n#',
    disabled=False,
)
display(prompt_widget)


Textarea(value='#include<iostream>\n#', description='Prompt:', placeholder='Type something')

In [4]:
# Snapshot the widget text; later edits to the widget are only picked up when this cell is re-run.
prompt = prompt_widget.value
# String form of each token, used below to size the position slider and to label the chosen position.
str_tokens = model.to_str_tokens(prompt)
print("Tokenization: ", str_tokens)

Tokenization:  ['<|BOS|>', '#', 'include', '<', 'i', 'ostream', '>', '\n', '#']


In [5]:
# Slider selecting the token position to analyse; it starts on the last token of the prompt.
last_token_idx = len(str_tokens) - 1
seq_pos_widget = widgets.IntSlider(
    description='Token position:',
    min=0,
    max=last_token_idx,
    value=last_token_idx,
    step=1,
    orientation='horizontal',
    continuous_update=False,
    disabled=False,
    readout=True,
    readout_format='d',
)
display(seq_pos_widget)

IntSlider(value=8, continuous_update=False, description='Token position:', max=8)

In [6]:
# Global token position read by the hooks and plotting helpers defined further down.
seq_pos = seq_pos_widget.value
print("Selected token: ", str_tokens[seq_pos])

Selected token:  #


In [7]:
# Text box naming the token whose predicted probability is tracked through the ablations.
target_logit_widget = widgets.Text(
    description='Target logit:',
    placeholder='Type something',
    value='include',
    disabled=False,
)
display(target_logit_widget)

Text(value='include', description='Target logit:', placeholder='Type something')

In [8]:
# Token string (not an id); it is looked up in the tokenizer vocab by later cells.
target_logit = target_logit_widget.value

# Clean run

In [9]:
# One pretrained sparse autoencoder per layer-0 hook point, keyed by the short hook name.
hooks = ['resid_pre', 'attn_out', 'resid_mid', 'mlp_out', 'resid_post']
saes = {}
for hook in hooks:
    checkpoint_path = f'../weights/saes32/final_sparse_autoencoder_gelu-1l_blocks.0.hook_{hook}_16384.pt'
    saes[hook] = SparseAutoencoder.load_from_pretrained(checkpoint_path)
    # The SAEs are only used for analysis here, so switch them to eval mode.
    saes[hook].eval()


In [10]:
def get_sae_acts(cache) -> tuple[dict[str, torch.Tensor], list[list[int]]]:
    """Run each SAE on its layer-0 activation and keep the row at `seq_pos`.

    For every name in the global `hooks`, the matching SAE from `saes` is
    applied to `cache[hook, 0]`; element [1] of its output is sliced at
    batch index 0 and position `seq_pos`.

    Returns:
        A dict mapping hook name to that sliced vector, and a list (in
        `hooks` order) of the indices whose value in that vector is non-zero.
    """
    sae_activations = {
        hook: saes[hook](cache[hook, 0])[1][0, seq_pos, :]
        for hook in hooks
    }

    sae_active_features = []
    for act in sae_activations.values():
        sae_active_features.append(torch.nonzero(act).squeeze(-1).tolist())

    return sae_activations, sae_active_features

def plot_sae_activations(cache, cache2=None, title_suffix=""):
    """Scatter-plot which SAE features are active at `seq_pos`, one row per hook.

    With only `cache`, every active feature is drawn as a green circle. When
    `cache2` is given as well, the two runs are compared hook by hook:
    green circle = active in both, red star = active only in `cache`,
    blue diamond = active only in `cache2`; a legend explains the markers.

    Args:
        cache: activation cache of the reference run (passed to `get_sae_acts`).
        cache2: optional activation cache of a second run to compare against.
        title_suffix: text appended to the figure title.
    """
    comparing = cache2 is not None
    _, sae_active_features = get_sae_acts(cache)
    sae_active_features2 = get_sae_acts(cache2)[1] if comparing else []

    # TODO change the plot to show the strength of the activations

    fig = go.Figure()

    for row, (active_features, hook) in enumerate(zip(sae_active_features, hooks)):
        # Sets give O(1) membership tests when classifying each feature below.
        in_first = set(active_features)
        in_second = set(sae_active_features2[row]) if comparing else set()
        sorted_features = sorted(in_first | in_second)

        colors = []
        shapes = []
        for feature in sorted_features:
            if not comparing or (feature in in_first and feature in in_second):
                colors.append("green")  # Active in both (or no comparison requested)
                shapes.append("circle")
            elif feature in in_first:
                colors.append("red")  # Only in cache
                shapes.append("star")
            else:
                colors.append("blue")  # Only in cache2
                shapes.append("diamond")

        # Features are laid out left to right in ascending index order.
        fig.add_trace(go.Scatter(
            x=list(range(len(sorted_features))),
            y=[row] * len(sorted_features),
            mode='markers+text',
            marker=dict(
                size=20,
                color=colors,
                symbol=shapes
            ),
            text=sorted_features,  # Feature index displayed below each marker
            textfont=dict(size=7),
            hoverinfo='text',  # Use custom text for hover info
            hovertext=[f"Feature {feature}" for feature in sorted_features],
            textposition="bottom center",
            name=hook,
            showlegend=False,
        ))

    if comparing:
        # Dummy traces that exist only to populate the legend; symbols must match the markers used above.
        fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(color='green', symbol='circle'), name='Active in both'))
        fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(color='red', symbol='star'), name='Only in original cache'))
        fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(color='blue', symbol='diamond'), name='Only in modified cache'))

    # Customize the layout
    fig.update_layout(
        title='Sparse Autoencoder Activations ' + title_suffix,
        xaxis=dict(
            title='',
            showticklabels=False
        ),
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(sae_active_features))),
            ticktext=hooks,
            autorange="reversed"
        ),
        showlegend=comparing,
    )

    fig.show()

# Reference ("clean") forward pass: the ablation cells below compare against these logits and activations.
clean_logits, clean_cache = model.run_with_cache(prompt)
# Per-hook SAE activations at `seq_pos` and the indices of the non-zero features (same order as `hooks`).
clean_sae_acts, clean_sae_active_features = get_sae_acts(clean_cache)
plot_sae_activations(clean_cache, title_suffix="(Clean)")

In [11]:
def logits_plot(logits, title_suffix=""):
    """Bar-chart the ten most probable next tokens at `seq_pos`, plus an "Other" bar.

    The row `logits[0, seq_pos, :]` is soft-maxed; the ten largest
    probabilities are plotted with their token strings as labels, and the
    remaining probability mass is shown as a final "Other" bar.
    """
    next_token_probs = torch.softmax(logits[0, seq_pos, :].detach(), dim=0)
    top_probs, top_idxs = next_token_probs.topk(10)

    bar_labels = [f"'{model.to_str_tokens(i)[0]}'" for i in top_idxs] + ["Other"]
    # Whatever the top ten do not account for goes into the "Other" bar.
    bar_heights = top_probs.tolist() + [1 - top_probs.sum().item()]

    fig = px.bar(
        x=bar_labels,
        y=bar_heights,
        labels={'x': 'Predictions', 'y': 'Probability'},
        title="Top 10 Predictions of the Language Model " + title_suffix,
    )
    fig.show()
# Baseline prediction distribution for the unmodified model.
logits_plot(clean_logits, "(Clean)")

## Are both attn and MLP necessary for this behavior?
Note that I'm looking just at this one example rn, but the logit is so strong here that I think it's enough signal to get started with.

Note to self: come back here later and verify that this generalizes.

# Ablations

### Sanity check: ablating attn_out completely makes the model unable to predict "include"
Doing zero-ablation for now, note that this might mean the results are wrong
because this is taking the model out of distribution in a pretty major way.

In [12]:
def full_zero_ablate_hook(acts, hook, seq_pos=seq_pos):
    """Forward hook that overwrites the activation at `seq_pos` with zeros.

    The default for `seq_pos` is the value the global `seq_pos` had when this
    cell was executed. `acts` is modified in place and returned; `hook` is
    unused.
    """
    acts[:, seq_pos] = 0.0
    return acts

# The hook only needs to be installed for the forward pass; plotting happens once it has been removed again.
with model.hooks(fwd_hooks=[(get_act_name("attn_out", 0), full_zero_ablate_hook)]):
    no_attn_logits, no_attn_cache = model.run_with_cache(prompt)

plot_sae_activations(clean_cache, no_attn_cache, "(Zero-ablated attn layer)")
logits_plot(no_attn_logits, "(Zero-ablated attn layer)")

### I'm curious, what happens if I ablate mlp_out?
I think it should be possible to implement this behavior using only embed, attn,
and unembed. Is that actually what's happening?

In [13]:
# Zero-ablate mlp_out at `seq_pos`. The results get MLP-specific names so they cannot be
# mistaken for (or silently overwrite) the attention-ablation results of the previous cell.
with model.hooks(fwd_hooks=[(get_act_name("mlp_out", 0), full_zero_ablate_hook)]):
    no_mlp_logits, no_mlp_cache = model.run_with_cache(prompt)
    plot_sae_activations(clean_cache, no_mlp_cache, "(Zero-ablated MLP layer)")
    logits_plot(no_mlp_logits, "(Zero-ablated MLP layer)")

### Mean-ablations of attn_out and mlp_out
The attn_out ablation definitely has to fuck up the model; I'm not sure about
mlp_out so wanna test this out.

In [14]:
# Stream a handful of pre-tokenized C4 sequences; they are used to estimate mean activations for mean-ablation.
dataset = iter(load_dataset("NeelNanda/c4-tokenized-2b", split="train", streaming=True))
tokens = []
for _ in range(10):
    tokens.append(torch.tensor(next(dataset)["tokens"],
                               dtype=torch.long,
                               # Follow the model's device rather than hard-coding "mps",
                               # which raises on machines without Apple's MPS backend.
                               device=model.cfg.device,
                               requires_grad=False))
# NOTE: concatenating yields a single 1-D token stream, so the DataLoader below
# hands the model 16-token slices of that stream rather than whole sequences.
tokens = torch.cat(tokens, dim=0)
print("Tokens shape: ", tokens.shape)
tokens_dataloader = DataLoader(tokens, batch_size=16, shuffle=False)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Tokens shape:  torch.Size([10240])


# WARNING FROM FUTURE LUCY: There's a bug here
I'm concatenating instead of stacking the sequences. That means that
there will be times when I'm running the model on a "prompt" which is actually a frankenstein of two prompts that got merged because of this fuck-up. Some of the results from this file may be skewed because of this.

In [15]:
# Collect the layer-0 activations of every sampled token (one row per token) for each hook
# point, then average over tokens to get the vectors used for mean-ablation.
hook_names = ["resid_pre", "attn_out", "resid_mid", "mlp_out", "resid_post"]
collected_acts = {name: [] for name in hook_names}
for batch in tqdm(tokens_dataloader):
    _, cache = model.run_with_cache(batch)
    for name in hook_names:
        act = cache[name, 0]
        # Flatten all leading dims into one token axis. A bare .squeeze() only works when the
        # batch dim happens to be 1 and would keep a (batch, pos, d_model) layout otherwise,
        # which would make the mean below a per-position mean instead of a per-token mean.
        collected_acts[name].append(act.reshape(-1, act.shape[-1]))

acts_by_hook = {name: torch.cat(chunks, dim=0) for name, chunks in collected_acts.items()}
resid_pre_acts = acts_by_hook["resid_pre"]
attn_out_acts = acts_by_hook["attn_out"]
resid_mid_acts = acts_by_hook["resid_mid"]
mlp_out_acts = acts_by_hook["mlp_out"]
resid_post_acts = acts_by_hook["resid_post"]
print("Attn out acts shape: ", attn_out_acts.shape)
print("MLP out acts shape: ", mlp_out_acts.shape)

# Mean over the token axis.
mean_resid_pre_acts = resid_pre_acts.mean(dim=0)
mean_attn_out_acts = attn_out_acts.mean(dim=0)
mean_resid_mid_acts = resid_mid_acts.mean(dim=0)
mean_mlp_out_acts = mlp_out_acts.mean(dim=0)
mean_resid_post_acts = resid_post_acts.mean(dim=0)

100%|██████████| 640/640 [00:05<00:00, 127.31it/s]


Attn out acts shape:  torch.Size([10240, 512])
MLP out acts shape:  torch.Size([10240, 512])


In [16]:
def full_mean_ablate_hook(acts, hook, seq_pos=seq_pos):
    """Forward hook that replaces the activation at `seq_pos` with a dataset mean.

    Hooks whose name contains 'attn_out' get `mean_attn_out_acts`, hooks whose
    name contains 'mlp_out' get `mean_mlp_out_acts`. `acts` is modified in
    place and returned.

    Raises:
        NotImplementedError: if the hook name contains neither substring.
    """
    if 'attn_out' in hook.name:
        replacement = mean_attn_out_acts
    elif 'mlp_out' in hook.name:
        replacement = mean_mlp_out_acts
    else:
        raise NotImplementedError("Hook name not supported yet")

    acts[:, seq_pos] = replacement
    return acts

# The hook is only needed during the forward pass; plots are drawn after it has been removed.
with model.hooks(fwd_hooks=[(get_act_name("attn_out", 0), full_mean_ablate_hook)]):
    no_attn_logits, no_attn_cache = model.run_with_cache(prompt)

plot_sae_activations(clean_cache, no_attn_cache, "(Mean-ablated attn layer)")
logits_plot(no_attn_logits, "(Mean-ablated attn layer)")

In [17]:
# Mean-ablate mlp_out at `seq_pos`. The results get MLP-specific names so they cannot be
# mistaken for (or silently overwrite) the attention-ablation results of the previous cell.
with model.hooks(fwd_hooks=[(get_act_name("mlp_out", 0), full_mean_ablate_hook)]):
    no_mlp_logits, no_mlp_cache = model.run_with_cache(prompt)
    plot_sae_activations(clean_cache, no_mlp_cache, "(Mean-ablated MLP layer)")
    logits_plot(no_mlp_logits, "(Mean-ablated MLP layer)")

Nice! So now I can be reasonably confident that the MLP is actually necessary for implementing this behavior.

## Ablate each attn_out SAE feature and see what happens to the "include" logit
### Naive version — do this the naive way and ablate SAE features one at a time.

Reasoning for doing this on attn_out: we know that the model can't predict "include"
based on the bigram frequency alone, so it must be moving info from the previous
occurrence of "include". This means that ablating attn should destroy the model's
ability to perform the task.

In [18]:
# Scratch check of the full hook-name format; the ablation hooks below strip its prefix to find the SAE.
get_act_name("mlp_out", 0)

'blocks.0.hook_mlp_out'

In [19]:
def ablate_sae_feature_enc(acts, hook, sae_feature_indices=()):
    """Forward hook: remove SAE features from `acts` at `seq_pos` along their encoder directions.

    For each requested feature f the activation is computed as
    relu((x - b_dec) @ W_enc[:, f] + b_enc[f]) and that multiple of
    W_enc[:, f] is subtracted from the activation at `seq_pos`. Features are
    handled one after another, so each activation is measured on the tensor
    as already modified by the earlier ones.

    Args:
        acts: activation tensor, indexed as acts[:, seq_pos]; modified in place.
        hook: hook point; the part of `hook.name` after "hook_" selects the SAE in `saes`.
        sae_feature_indices: iterable of SAE feature indices to ablate.

    Returns:
        `acts`, after in-place modification.
    """
    # "blocks.0.hook_attn_out" -> "attn_out" (independent of how many digits the layer index has)
    hook_name = hook.name.split("hook_", 1)[-1]
    sae = saes[hook_name]
    for feature in sae_feature_indices:
        enc_direction = sae.W_enc[:, feature]
        # Include the encoder bias and the ReLU so that an inactive feature has strength 0
        # and is left alone instead of having a negative multiple "removed".
        pre_act = (acts[:, seq_pos] - sae.b_dec) @ enc_direction + sae.b_enc[feature]
        current_act_strength = torch.relu(pre_act)
        # unsqueeze so one strength per batch row broadcasts against the direction vector
        acts[:, seq_pos] -= current_act_strength.unsqueeze(-1) * enc_direction
    return acts

# Ablate each attn_out SAE feature that fired on the clean run, one at a time,
# and record the probability the model then assigns to the target token.
target_logit_idx = model.tokenizer.vocab[target_logit]
attn_out_features = clean_sae_active_features[1]  # index 1 == 'attn_out' in `hooks`

logit_strengths = []  # despite the name, these are probabilities
for attn_sae_feature in attn_out_features:
    hook = partial(ablate_sae_feature_enc, sae_feature_indices=[attn_sae_feature])
    ablated_attn_sae_logits = model.run_with_hooks(
        prompt,
        fwd_hooks=[(get_act_name("attn_out", 0), hook)],
    )
    probs = torch.softmax(ablated_attn_sae_logits[0, seq_pos, :], dim=0)
    logit_strengths.append(probs[target_logit_idx].item())

# One bar per ablated feature.
fig = go.Figure(data=[
    go.Bar(x=[f"Feature {i}" for i in attn_out_features], y=logit_strengths)
])
fig.update_layout(
    title="Probability Assigned To Target Logit of Ablated Attention SAE Features",
    xaxis_title="Feature",
    yaxis_title="Target Logit Probability",
)
fig.show()

Sanity check — is this only happening because feature 8220 was firing much more strongly?

In [20]:
def sae_activation_strengths(cache, hook):
    """Bar-chart the non-zero SAE feature activations at `seq_pos` for one hook.

    Args:
        cache: activation cache passed to `get_sae_acts`.
        hook: short hook name (a key of the dict returned by `get_sae_acts`).
    """
    sae_acts, _ = get_sae_acts(cache)
    hook_acts = sae_acts[hook]
    # Derive the active features from the requested hook itself instead of always
    # using the attn_out entry of the active-feature list.
    active_features = torch.nonzero(hook_acts).squeeze(-1)
    act_values = hook_acts[active_features].cpu().detach().numpy()

    # show as bar chart (w labels)
    fig = go.Figure()
    fig.add_trace(go.Bar(x=[f"Feature {i}" for i in active_features.tolist()],
                        y=act_values))
    fig.update_layout(title=f"Activation Strength of {hook} SAE Features",
                        xaxis_title="Feature",
                        yaxis_title="Activation Strength")

    fig.show()
# Activation strengths of the attn_out SAE features on the clean run.
sae_activation_strengths(clean_cache, "attn_out")

In [21]:
# Encoder-direction ablation of a single attn_out SAE feature, compared against the clean run.
features_to_ablate = [8220]
hook = partial(ablate_sae_feature_enc, sae_feature_indices=features_to_ablate)
with model.hooks(fwd_hooks=[(get_act_name("attn_out", 0), hook)]):
    logits, cache = model.run_with_cache(prompt)
    # Titles now close the parenthesis they open.
    plot_sae_activations(clean_cache, cache, f"(Ablated attn SAE feature {features_to_ablate})")
    logits_plot(logits, f"(Ablated attn SAE feature {features_to_ablate})")

Hm, that's annoying, apparently ablating the important stuff takes the model so far out of distribution that the SAEs in the subsequent hooks just turn into a hot mess.

I wonder if there's a way to ablate this that doesn't take things out of distribution?

## How to do SAE feature ablations well

### What happens when I do ablation the naive way
This is what I've been doing so far — literally just subtract the encoder direction of the SAE feature from the activations.

What does this actually do? If I do this ablation and then run it through the SAE, what happens?

In [22]:
def plot_act_sae(activation, title_suffix=""):
    """Bar-chart the non-zero attn_out SAE feature activations of a single activation vector.

    `activation` is given a leading batch dimension, passed through
    `saes['attn_out']`, and element [1] of the SAE output (first row) is
    plotted. Prints a notice and returns without plotting when no entry is
    non-zero.
    """
    sae_acts = saes['attn_out'](activation.unsqueeze(0))[1][0]
    # squeeze(-1), not squeeze(): with exactly one active feature a bare squeeze()
    # yields a 0-d tensor, on which len()/iteration fail.
    features = torch.nonzero(sae_acts).squeeze(-1)
    if features.numel() == 0:
        print("\n\nNo active features!\n\n")
        return
    sae_features_acts = sae_acts[features]
    px.bar(x=[f"Feature {i}" for i in features.tolist()],
        y=sae_features_acts.tolist(),
        title="Activation Strength of Attention SAE Features " + title_suffix,
        labels={'x': 'Feature', 'y': 'Activation Strength'}).show()
# Note: indexes the last position (-1), not `seq_pos`; the two agree while the slider is left on its default (last token).
plot_act_sae(clean_cache['attn_out', 0][0, -1], "(Clean)")

In [23]:
# Subtract each feature's encoder column (as-is, not scaled by the feature's activation) from the
# clean last-position attn_out activation and re-plot the SAE activations.
for feature_to_ablate in [8220, 2275, 10191, 13158]:
    plot_act_sae(clean_cache['attn_out', 0][0, -1] - saes['attn_out'].W_enc[:, feature_to_ablate],
                f"(Encoder-ablated feature {feature_to_ablate})")



No active features!




In [24]:
# Same experiment as the previous cell, but subtracting each feature's decoder row (as-is, not
# scaled by the feature's activation). The title now names the direction that is actually used.
for feature_to_ablate in [8220, 2275, 10191, 13158]:
    plot_act_sae(clean_cache['attn_out', 0][0, -1] - saes['attn_out'].W_dec[feature_to_ablate],
                f"(Decoder-ablated feature {feature_to_ablate})")

Yep, it does kinda look like subtracting the decoder direction is just better.

Let's do some actual statistical testing here since this is important — maybe run this on that small dataset I fetched earlier
and see how well this works. I can also do testing on synthetic (random) data later.

The metric I care about is the mean absolute difference between the SAE feature activations on the clean vs. the ablated activation, with the caveat that we ignore the feature being ablated (its target value is set to 0). An additional metric is whether the feature we're ablating goes exactly to 0.

First let's do this with the attn_out SAE.

In [25]:
def get_sae_hidden_acts(sae, act):
    """Return the SAE feature activations relu((act - b_dec) @ W_enc + b_enc).

    The ReLU matters for callers that treat non-zero entries as "active
    features": without it almost every pre-activation is non-zero.

    Args:
        sae: object exposing `W_enc`, `b_enc` and `b_dec`.
        act: activation tensor whose last dimension matches the rows of `W_enc`.
    """
    hidden_pre = ((act - sae.b_dec) @ sae.W_enc) + sae.b_enc
    return torch.relu(hidden_pre)

def ablation_error(act, target):
    """Mean absolute element-wise difference between `act` and `target`, as a Python float."""
    abs_diff = torch.abs(act - target)
    return torch.mean(abs_diff).item()

def _plot_ablation_histograms(stats_df, hook, metric, trace_label, axis_label):
    """Overlay the encoder- and decoder-ablation histograms of one metric."""
    fig = go.Figure()
    fig.add_trace(go.Histogram(x=stats_df[f'enc_{metric}'],
                               name=f"Encoder {trace_label}",
                               xbins=dict(size=1e-4)))
    fig.add_trace(go.Histogram(x=stats_df[f'dec_{metric}'],
                               name=f"Decoder {trace_label}",
                               xbins=dict(size=1e-4)))
    fig.update_layout(barmode='overlay',
                    title=f"{axis_label} of {hook} SAE Features",
                    xaxis_title=axis_label,
                    yaxis_title="Frequency")
    fig.update_traces(opacity=0.75)
    fig.show()


def test_ablation_types(hook, n_acts=None):
    """Compare encoder- vs decoder-direction ablation of single SAE features.

    For each stored activation of `hook` (optionally a random subset of
    `n_acts`), up to 20 features with a non-zero SAE activation are ablated
    one at a time by subtracting `activation * direction`, once with the
    encoder column and once with the decoder row. Two metrics are recorded:
    the mean absolute difference to the clean SAE activations with the
    ablated feature zeroed, and the magnitude the ablated feature still has.
    Means, standard deviations and two overlaid histograms are shown.

    Args:
        hook: short hook name, a key of both `saes` and `acts_by_hook`.
        n_acts: if truthy, the number of randomly chosen activations to test.
    """
    sae = saes[hook]
    acts = acts_by_hook[hook]
    if n_acts:
        acts = acts[torch.randperm(len(acts))[:n_acts]]

    stats = []
    for act in tqdm(acts):
        sae_acts = get_sae_hidden_acts(sae, act)
        features = torch.nonzero(sae_acts).squeeze(-1)

        # randomly select 20 features to ablate (if there are more than 20)
        if len(features) > 20:
            features = features[torch.randperm(len(features))[:20]]

        for feature in features:
            # The target zeroes this feature and leaves the rest untouched, so the amount
            # removed has to be the feature's own activation, not a fixed unit of the direction.
            strength = sae_acts[feature]
            target_sae_acts = sae_acts.clone()
            target_sae_acts[feature] = 0

            # ablate it with W_enc
            enc_ablated_act = act - strength * sae.W_enc[:, feature]
            enc_ablated_sae_acts = get_sae_hidden_acts(sae, enc_ablated_act)

            # ablate it with W_dec
            dec_ablated_act = act - strength * sae.W_dec[feature]
            dec_ablated_sae_acts = get_sae_hidden_acts(sae, dec_ablated_act)

            stats.append({
                'enc_ablation_error': ablation_error(enc_ablated_sae_acts, target_sae_acts),
                'dec_ablation_error': ablation_error(dec_ablated_sae_acts, target_sae_acts),
                'enc_fail_to_ablate': enc_ablated_sae_acts[feature].abs().item(),
                'dec_fail_to_ablate': dec_ablated_sae_acts[feature].abs().item(),
            })

    if not stats:
        # An empty DataFrame has none of the columns the plots below index.
        print("\nNo features were ablated; nothing to report.")
        return

    stats_df = pd.DataFrame(stats)
    print("\nMean:")
    print(stats_df.mean())
    print("\n\nStandard Deviation:")
    print(stats_df.std())

    _plot_ablation_histograms(stats_df, hook, 'ablation_error',
                              "Ablation Error", "Ablation Error")
    _plot_ablation_histograms(stats_df, hook, 'fail_to_ablate',
                              "Ablation Failure To Ablate", "Failure To Ablate")


#### Note to Future Confused Lucy: I commented out the code that actually runs the function
It was taking a while to run and the outputs were too big, which was making the Jupyter VS Code extension glitch.

In [26]:
# test_ablation_types('resid_pre', 1000)

In [27]:
# test_ablation_types('attn_out', 1000)

In [28]:
# test_ablation_types('resid_mid', 1000)

In [29]:
# test_ablation_types('mlp_out', 1000)

In [30]:
# test_ablation_types('resid_post', 1000)

Conclusion: ablating with W_dec is just strictly better than with W_enc

## Re-doing the experiments above with W_dec ablation

### Ablate SAE features one at a time, see what happens

In [31]:
def ablate_sae_feature_dec(acts, hook, sae_feature_indices=()):
    """Forward hook: remove SAE features from `acts` at `seq_pos` along their decoder directions.

    For each requested feature f the activation is computed as
    relu((x - b_dec) @ W_enc[:, f] + b_enc[f]) and that multiple of
    W_dec[f] is subtracted from the activation at `seq_pos`. Features are
    handled one after another, so each activation is measured on the tensor
    as already modified by the earlier ones.

    Args:
        acts: activation tensor, indexed as acts[:, seq_pos]; modified in place.
        hook: hook point; the part of `hook.name` after "hook_" selects the SAE in `saes`.
        sae_feature_indices: iterable of SAE feature indices to ablate.

    Returns:
        `acts`, after in-place modification.
    """
    # "blocks.0.hook_attn_out" -> "attn_out" (independent of how many digits the layer index has)
    hook_name = hook.name.split("hook_", 1)[-1]
    sae = saes[hook_name]
    for feature in sae_feature_indices:
        # Include the encoder bias and the ReLU so that an inactive feature has strength 0
        # and is left alone instead of having a negative multiple "removed".
        pre_act = (acts[:, seq_pos] - sae.b_dec) @ sae.W_enc[:, feature] + sae.b_enc[feature]
        current_act_strength = torch.relu(pre_act)
        # unsqueeze so one strength per batch row broadcasts against the direction vector
        acts[:, seq_pos] -= current_act_strength.unsqueeze(-1) * sae.W_dec[feature]
    return acts

# Decoder-direction version of the sweep: ablate each clean-run attn_out SAE feature
# on its own and record the probability assigned to the target token.
target_logit_idx = model.tokenizer.vocab[target_logit]
attn_out_features = clean_sae_active_features[1]  # index 1 == 'attn_out' in `hooks`

logit_strengths = []  # despite the name, these are probabilities
for attn_sae_feature in attn_out_features:
    hook = partial(ablate_sae_feature_dec, sae_feature_indices=[attn_sae_feature])
    ablated_attn_sae_logits = model.run_with_hooks(
        prompt,
        fwd_hooks=[(get_act_name("attn_out", 0), hook)],
    )
    probs = torch.softmax(ablated_attn_sae_logits[0, seq_pos, :], dim=0)
    logit_strengths.append(probs[target_logit_idx].item())

# One bar per ablated feature.
fig = go.Figure(data=[
    go.Bar(x=[f"Feature {i}" for i in attn_out_features], y=logit_strengths)
])
fig.update_layout(
    title="Probability Assigned To Target Logit After Ablating Individual Attention SAE Features",
    xaxis_title="Feature",
    yaxis_title="Target Logit Probability",
)
fig.show()

Let's take the top 4 features that this gives us (ie all the features where if you ablate them, the model's confidence in the prediction drops below 0.8) and see what happens when we ablate all 4 of them at the same time.

In [32]:
# Ablate the four most influential attn_out SAE features together (decoder direction).
features_to_ablate = [8220, 3326, 2689, 13158]
hook = partial(ablate_sae_feature_dec, sae_feature_indices=features_to_ablate)
with model.hooks(fwd_hooks=[(get_act_name("attn_out", 0), hook)]):
    logits, cache = model.run_with_cache(prompt)
    # Titles now close the parenthesis they open.
    plot_sae_activations(clean_cache, cache, f"(Ablated attn SAE feature {features_to_ablate})")
    logits_plot(logits, f"(Ablated attn SAE feature {features_to_ablate})")

There's still some signal there, and that's pretty annoying. 

Let's try something different, what if I start off ablating resid_post SAE features instead?

In [33]:
# Ablate each clean-run resid_post SAE feature on its own and record the probability
# the model then assigns to the target token.
resid_post_features = clean_sae_active_features[-1]  # last entry == 'resid_post' in `hooks`

logit_strengths = []  # despite the name, these are probabilities
for resid_post_sae_feature in resid_post_features:
    hook = partial(ablate_sae_feature_dec, sae_feature_indices=[resid_post_sae_feature])
    ablated_attn_sae_logits = model.run_with_hooks(
        prompt,
        fwd_hooks=[(get_act_name("resid_post", 0), hook)],
    )
    probs = torch.softmax(ablated_attn_sae_logits[0, seq_pos, :], dim=0)
    logit_strengths.append(probs[target_logit_idx].item())

# One bar per ablated feature.
fig = go.Figure(data=[
    go.Bar(x=[f"Feature {i}" for i in resid_post_features], y=logit_strengths)
])
fig.update_layout(
    title="Probability Assigned To Target Logit After Ablating Individual resid_post SAE Features",
    xaxis_title="Feature",
    yaxis_title="Target Logit Probability",
)
fig.show()

Let's ablate like the top 5 and see what happens.

In [34]:
# Ablate together the resid_post features whose individual ablation lowered the target probability the most.
how_many_features = 5
# Sorting (probability, feature) pairs ascending puts the most damaging ablations first.
top_features = sorted(zip(logit_strengths, clean_sae_active_features[-1]))[:how_many_features]
top_resid_post_features = [f for _, f in top_features]
hook = partial(ablate_sae_feature_dec, sae_feature_indices=top_resid_post_features)
with model.hooks(fwd_hooks=[(get_act_name("resid_post", 0), hook)]):
    logits, cache = model.run_with_cache(prompt)
    # Label the plots with the features ablated here, not the stale `features_to_ablate`
    # list left over from the earlier attn_out cell; also close the parenthesis.
    plot_sae_activations(clean_cache, cache, f"(Ablated resid_post SAE feature {top_resid_post_features})")
    logits_plot(logits, f"(Ablated resid_post SAE feature {top_resid_post_features})")

## Call w Arthur todo
TODO: Rescale to take into account the MLP yelling louder than attn — try rescaling attn_out so that it has the same norm as on the clean pass
- Also try hard-coding the LN scaling factor (to make the scaling work the same way as in the clean pass)

TODO: check if L0 still explodes downstream with W_dec ablation


TODO: From Neel re ablating mlp_out destroys downstream SAEs: Interesting, what about resample ablation? And if you subtract the attention output or embedding from the logits, what happens?

TODO: From Neel re fixing LayerNorm coeff: It'd be interesting to plot histograms of what the LN scaling factors actually are






## Ablating attn_out features with W_dec

Let's try ablating the top k features in attn_out (ranked by how much ablating that feature individually decreases the probability the model assigns to "include" being the next token).

In [35]:
# Step 1: per-feature sweep over the clean-run attn_out SAE features (index 1 in `hooks`).
logit_strengths = []  # despite the name, these are probabilities
for attn_sae_feature in clean_sae_active_features[1]:
    hook = partial(ablate_sae_feature_dec, sae_feature_indices=[attn_sae_feature])
    ablated_attn_sae_logits = model.run_with_hooks(
        prompt,
        fwd_hooks=[(get_act_name("attn_out", 0), hook)],
    )
    probs = torch.softmax(ablated_attn_sae_logits[0, seq_pos, :], dim=0)
    logit_strengths.append(probs[target_logit_idx].item())

# Step 2: ablate together the features whose individual ablation hurt the target probability most.
how_many_features = 3
ranked_by_probability = sorted(zip(logit_strengths, clean_sae_active_features[1]))
top_features = ranked_by_probability[:how_many_features]
sae_feature_indices = [f for _, f in top_features]
hook = partial(ablate_sae_feature_dec, sae_feature_indices=sae_feature_indices)
with model.hooks(fwd_hooks=[(get_act_name("attn_out", 0), hook)]):
    logits, cache = model.run_with_cache(prompt)
    plot_sae_activations(clean_cache, cache, f"(Ablated attn_out SAE feature {sae_feature_indices})")
    logits_plot(logits, f"(Ablated attn_out SAE feature {sae_feature_indices})")

Hm, interesting. Basically all the features in resid_post got turned off by doing this, but the logits still look like this.

Maybe let's try also ablating the remaining features in resid_post?

In [36]:
# Ablate the top attn_out features and, additionally, the resid_post features that stayed active.
how_many_features = 3
top_features = sorted(zip(logit_strengths, clean_sae_active_features[1]))[:how_many_features]
sae_feature_indices=[f for _, f in top_features]
hook = partial(ablate_sae_feature_dec, sae_feature_indices=sae_feature_indices)
# Hard-coded resid_post feature indices to ablate on top of the attn_out features.
remaining_resid_post_sae_features = [9676, 9995, 10393, 12600, 12901]
hook2 = partial(ablate_sae_feature_dec, sae_feature_indices=remaining_resid_post_sae_features)
with model.hooks(fwd_hooks=[(get_act_name("attn_out", 0), hook),
                            (get_act_name("resid_post", 0), hook2)]):
    logits, cache = model.run_with_cache(prompt)
    # Titles now close the parenthesis they open.
    plot_sae_activations(clean_cache, cache, f"(Ablated attn_out SAE feature {sae_feature_indices} and resid_post SAE feature {remaining_resid_post_sae_features})")
    logits_plot(logits, f"(Ablated attn_out SAE feature {sae_feature_indices} and resid_post SAE feature {remaining_resid_post_sae_features})")

Okay cool. Let's try ablating those remaining resid_post features one at a time and see how much they matter individually.

In [37]:
# With the attn_out ablation (`hook`, defined in the previous cell) kept in place, ablate each
# remaining resid_post feature on its own and record the target-token probability.
feature_strengths = []
for feature in remaining_resid_post_sae_features:
    hook2 = partial(ablate_sae_feature_dec, sae_feature_indices=[feature])
    fwd_hooks = [
        (get_act_name("attn_out", 0), hook),
        (get_act_name("resid_post", 0), hook2),
    ]
    logits = model.run_with_hooks(prompt, fwd_hooks=fwd_hooks)
    probs = torch.softmax(logits[0, seq_pos, :], dim=0)
    feature_strengths.append(probs[target_logit_idx].item())

# One bar per individually ablated resid_post feature.
fig = go.Figure(data=[
    go.Bar(x=[f"Feature {i}" for i in remaining_resid_post_sae_features], y=feature_strengths)
])
fig.update_layout(
    title="Probability Assigned To Target Logit After Ablating Individual resid_post SAE Features",
    xaxis_title="Feature",
    yaxis_title="Target Logit Probability",
)
fig.show()

Hmm... fucks sake, so everything just looks distributed and messy. What happens if I ablate 3 of these?

In [38]:
# Keep the attn_out ablation (`hook`, from an earlier cell) and additionally ablate three resid_post features.
resid_post_to_ablate = [9676, 9995, 12600]
hook2 = partial(ablate_sae_feature_dec, sae_feature_indices=resid_post_to_ablate)
with model.hooks(fwd_hooks=[(get_act_name("attn_out", 0), hook),
                            (get_act_name("resid_post", 0), hook2)]):
    logits, cache = model.run_with_cache(prompt)
    # Titles now close the parenthesis they open.
    plot_sae_activations(clean_cache, cache, f"(Ablated attn_out SAE feature {sae_feature_indices} and resid_post SAE feature {resid_post_to_ablate})")
    logits_plot(logits, f"(Ablated attn_out SAE feature {sae_feature_indices} and resid_post SAE feature {resid_post_to_ablate})")

Okay. Everything looks pretty distributed and cursed, and idk how to deal with this.

# TAKEAWAYS
- Ablating the entirety of attn or MLP nukes performance, both with zero-ablation and mean-ablation
  - This led me to thinking about how to ablate single SAE features in a way that keeps the norm where it should be; currently it seems like the LNs might be messing things up here.
- Doing ablation with W_dec works way better than with W_enc
- You can ablate 3 SAE features in attn_out (chosen by their impact on the "include" logit) and that basically turns off all the SAE features in resid_post, but "include" still gets predicted. You also need to ablate most of the remaining resid_post SAE features to get it to stop predicting "include", but at that point you're basically ablating everything so there's no point.