# How to fix the norm and keep things in-distribution after you ablate SAE features
Whenever I ablate an SAE feature (particularly in attn_out and mlp_out), the ablation changes the norm of that activation, which pushes the model off-distribution.
I want to spend a little bit of time doing large-scale experiments to figure out what a good way to deal with this is.

## Distribution of LN scale factors
First, let's verify that my intuitions are actually correct here: what does the distribution of LN scale factors look like in general (i.e. on-distribution)?

In [4]:
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name
import torch
from torch.utils.data import DataLoader
import ipywidgets as widgets
from IPython.display import display
from datasets import load_dataset
from tqdm import tqdm
from functools import partial
import pandas as pd
import lucys_utils as lu

In [5]:
# The one-layer GELU model under study, plus the SAEs trained on its block-0 hooks.
model = HookedTransformer.from_pretrained("gelu-1l")
# NOTE(review): HOOK_GOES_HERE looks like a placeholder that lu.load_saes fills in
# per hook point — confirm against lucys_utils.
saes = lu.load_saes(
    "../weights/saes32/"
    "final_sparse_autoencoder_gelu-1l_blocks.0.hook_HOOK_GOES_HERE_16384.pt"
)

Loaded pretrained model gelu-1l into HookedTransformer


In [None]:
def get_ln_scales(model, n_tokens=10_000):
    """Collect every LayerNorm scale value the model computes over a token sample.

    Streams ``n_tokens`` tokens from ``lu.gib_tokens(..., dataloader=True)`` through
    ``model.run_with_cache`` and caches only hooks whose name contains ``'scale'``.

    Args:
        model: a TransformerLens ``HookedTransformer``.
        n_tokens: how many tokens to request from ``lu.gib_tokens``.

    Returns:
        dict mapping each scale hook name to a flat 1-D CPU tensor holding every
        value cached at that hook, concatenated across all batches.
    """
    dataloader = lu.gib_tokens(n_tokens=n_tokens, dataloader=True)

    ln_scales = {}
    for batch in tqdm(dataloader):
        _, cache = model.run_with_cache(batch, names_filter=lambda name: 'scale' in name)
        for key in cache.keys():
            # Flatten each batch before collecting it. Squeezing instead would drop the
            # batch dim for a size-1 batch, and torch.cat would then fail on the
            # mismatched ranks.
            ln_scales.setdefault(key, []).append(cache[key].reshape(-1).cpu())
    return {k: torch.cat(v) for k, v in ln_scales.items()}
ln_scales = get_ln_scales(model, n_tokens=100_000)

# Overlay one probability-normalised histogram per LN scale hook. The opacity is set
# once here; previously each trace got 0.5 and was then overridden to 0.75.
fig = go.Figure()
for key, values in ln_scales.items():
    fig.add_trace(go.Histogram(x=values.numpy(), histnorm='probability', name=key, opacity=0.75))
fig.update_layout(barmode='overlay',
                  title_text='Layer Normalization Scales (gelu-1l)',
                  xaxis_title_text='Scale',
                  yaxis_title_text='Probability')
fig.show()
# Drop the per-token scale tensors before the next cell allocates more.
del ln_scales

In [None]:
ln_scales = get_ln_scales(HookedTransformer.from_pretrained('gpt2-small'), n_tokens=1_000)
# Summarise each hook's LN scale distribution: one row per hook, with the mean and
# the standard deviation as columns.
ln_stats = pd.DataFrame.from_dict(
    {name: [scales.mean().item(), scales.std().item()] for name, scales in ln_scales.items()},
    orient='index',
    columns=['mean', 'std'],
)
ln_stats

## What actually happens to norms when you do SAE feature ablation

Let's start doing this by ablating features one-by-one in attn_out and seeing the impact this has on norms downstream.

I'll start off with a single prompt, just to get the hang of what's going on, and then generalize from there. What I'm actually trying to figure out here is basically: "when I ablate a small set of SAE features, what happens to the norms downstream?"

In [6]:
# A single toy prompt for eyeballing per-position norm changes.
prompt = "the quick brown fox jumps over the lazy dog"
str_tokens = model.to_str_tokens(prompt)


def names_filter(name):
    """Keep residual-stream, *_out and LayerNorm ('ln') hooks in the cache."""
    return any(part in name for part in ('resid', 'out', 'ln'))


_, clean_cache = model.run_with_cache(prompt, names_filter=names_filter)

def get_all_norms(cache, return_dict: bool = False) -> dict[str, torch.Tensor]:
    if return_dict:
        return {k: v.norm(dim=-1).squeeze().cpu()
                for k, v in cache.items()
                # if 'ln' not in k
                }
    
    return torch.stack([v.norm(dim=-1).squeeze().cpu()
                        for k, v in cache.items()
                        # if 'ln' not in k
                        ])
    
clean_norms_dict = get_all_norms(clean_cache, True)
hook_names = list(clean_norms_dict.keys())
clean_norms = torch.stack(list(clean_norms_dict.values()))

lu.plot_sae(saes, clean_cache, 'attn_out')


sae_acts, sae_active_features = lu.get_all_sae_acts(saes, clean_cache)
# NOTE(review): index 1 is assumed to hold the attn_out SAE's active features —
# confirm against lu.get_all_sae_acts.
feature_to_ablate = sae_active_features[1][:7]
print(f"Features to ablate: {feature_to_ablate}")

with model.hooks(fwd_hooks=[(get_act_name('attn_out', 0),
                             lu.ablate_sae_feature(saes, feature_to_ablate))]):
    _, corrupted_cache = model.run_with_cache(prompt, names_filter=names_filter)
    corrupted_norms = get_all_norms(corrupted_cache)


# Symmetric relative difference: (corrupted - clean) / mean(clean, corrupted).
# epsilon keeps the division finite when both norms are zero.
epsilon = 1e-9
mean_norms = (clean_norms + corrupted_norms) / 2
relative_difference = (corrupted_norms - clean_norms) / (mean_norms + epsilon)

# Hover text: one row per hook, one entry per token position.
hover_text = [
    [
        f'Layer: {hook}<br><br>'
        f'Clean: {clean:.3f}'
        f'<br>Corrupted: {corrupted:.3f}'
        f'<br>Difference: {corrupted - clean:.3f}'
        f'<br>Ratio: {corrupted / (clean + epsilon):.3f}'
        for clean, corrupted in zip(clean_row.tolist(), corrupted_row.tolist())
    ]
    for hook, clean_row, corrupted_row in zip(hook_names, clean_norms, corrupted_norms)
]

# zmid=0 centres the diverging RdBu scale on "no change". Without it, the neutral
# colour sits at the midpoint of the data range rather than at zero.
fig = go.Figure(data=go.Heatmap(
    z=relative_difference.numpy(),
    x=[f'{i}: "{str_tokens[i]}"' for i in range(clean_norms.shape[1])],
    y=hook_names,
    hoverinfo="text",
    text=hover_text,
    colorbar_title="Relative<br>Difference",
    colorscale="RdBu",
    zmid=0))

fig.update_layout(title='Differences in norms between clean and corrupted activations',
                  yaxis_autorange='reversed')
fig.show()

Features to ablate: [1484, 3229, 3639, 3766, 3858, 3885, 4514]


Okay, so what seems to happen is what I would've predicted: ablating something in attn_out reduces its norm, which in turn reduces the norm of resid_mid by a smaller amount. What happens afterwards is kind of surprising. Pretty often the norm of mlp_out and resid_post actually increases as a result of the ablation, but this seems to depend heavily on the example.

Let's do this at scale now.

In [15]:
def _attn_out_ablation_norm_ratios(prompt, n_features, names_filter):
    """Ablate a random subset of active attn_out SAE features on one prompt.

    Args:
        prompt: token tensor for a single prompt.
        n_features: how many active features to ablate. ``None`` draws the count
            uniformly from ``1..n_active`` inclusive.
        names_filter: hook-name filter passed to ``model.run_with_cache``.

    Returns:
        The corrupted/clean norm ratio at the final position for every cached hook
        after the first (resid_pre). Returns None when the attn_out SAE has no
        active feature at the final position.
    """
    _, clean_cache = model.run_with_cache(prompt, names_filter=names_filter)
    # Row 0 is resid_pre. It sits before the intervention, so its ratio is always 1
    # and is skipped.
    clean_norms = get_all_norms(clean_cache)[1:, -1]

    sae_acts = lu.get_sae_acts(saes['attn_out'], clean_cache['attn_out', 0][0, -1])
    # reshape(-1) + flatten() keeps this 1-D even when exactly one feature is
    # active. A plain .squeeze() would collapse that case to a 0-d tensor.
    active_feats = torch.nonzero(sae_acts.reshape(-1)).flatten()
    n_active = active_feats.shape[0]
    if n_active == 0:
        return None

    if n_features is None:
        # randint's upper bound is exclusive; +1 lets the full active set be drawn.
        n_features = torch.randint(1, n_active + 1, (1,)).item()
    features_to_ablate = active_feats[torch.randperm(n_active)[:n_features]]

    with model.hooks(fwd_hooks=[(get_act_name('attn_out', 0),
                                 lu.ablate_sae_feature(saes, features_to_ablate))]):
        _, corrupted_cache = model.run_with_cache(prompt, names_filter=names_filter)
        corrupted_norms = get_all_norms(corrupted_cache)[1:, -1]

    return corrupted_norms / (clean_norms + 1e-9)


def _plot_norm_ratio_histograms(results, hook_names, title_suffix=""):
    """Plot one histogram subplot per hook of the norm ratios, dropping ratios >= 2."""
    fig = make_subplots(rows=len(hook_names), cols=1, subplot_titles=hook_names)
    for row, (name, ratios) in enumerate(zip(hook_names, results.T), start=1):
        fig.add_trace(go.Histogram(x=ratios[ratios < 2].numpy(),
                                   histnorm='probability',
                                   nbinsx=300,
                                   name=name),
                      row=row, col=1)

    fig.update_layout(height=300 * len(hook_names),
                      title_text='Ratio between norms of corrupted and clean activations ' +
                          f'(removed outliers above 2{title_suffix})')
    # Label every subplot's axes, not only the first one.
    fig.update_xaxes(title_text='Norm ratio (corrupted / clean)')
    fig.update_yaxes(title_text='Probability')
    fig.show()


def test_post_ablation_norms(n_tokens=10_000, n_features: int | None = None, seq_len: int = 32):
    """Measure how ablating attn_out SAE features changes downstream activation norms.

    The sampled tokens are split into prompts of ``seq_len`` tokens. For each prompt,
    a subset of the attn_out SAE features active at the final position is ablated.
    The function then records, at that position, the ratio of corrupted to clean norm
    for attn_out, resid_mid, mlp_out and resid_post. It prints per-hook means and
    standard deviations and shows one histogram per hook.

    Args:
        n_tokens: number of tokens to sample via ``lu.gib_tokens``.
        n_features: number of features to ablate per prompt. ``None`` picks a random
            count between 1 and the number of active features for each prompt.
        seq_len: tokens per prompt. ``n_tokens`` should be divisible by it.
    """
    tokens = lu.gib_tokens(n_tokens=n_tokens).reshape(-1, seq_len)

    def names_filter(name):
        # Cache only residual-stream and *_out hooks. Defined once, outside the loop.
        return 'resid' in name or 'out' in name

    results = []
    for prompt in tqdm(tokens):
        ratios = _attn_out_ablation_norm_ratios(prompt, n_features, names_filter)
        if ratios is not None:
            results.append(ratios)

    if not results:
        print("No prompt had an active attn_out SAE feature at its final position; nothing to plot.")
        return
    results = torch.stack(results)

    print(f"Means: {results.mean(dim=0)}")
    print(f"Stds: {results.std(dim=0)}")

    hook_names = ['attn_out', 'resid_mid', 'mlp_out', 'resid_post']
    title_suffix = f" n_features={n_features}" if n_features is not None else ""
    _plot_norm_ratio_histograms(results, hook_names, title_suffix)
test_post_ablation_norms()

Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 10.09it/s]
100%|██████████| 320/320 [00:26<00:00, 12.08it/s]

Means: tensor([0.8753, 0.9013, 1.0244, 1.0251])
Stds: tensor([0.0827, 0.0683, 0.1317, 0.1250])





Let's try the same thing, but this time force it to only ablate 3 features at a time.

In [16]:
# Same experiment, but always ablate exactly 3 active attn_out features per prompt.
test_post_ablation_norms(n_tokens=10_000, n_features=3)

Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 10.01it/s]
100%|██████████| 320/320 [00:11<00:00, 28.32it/s]

Means: tensor([0.9765, 0.9803, 1.0018, 1.0017])
Stds: tensor([0.0386, 0.0338, 0.0512, 0.0507])





Conclusion: attn_out and resid_mid basically always end up with a lower norm (as expected). mlp_out and resid_post sometimes end up larger and sometimes smaller, and their mean ratio stays close to 1, so it's basically a coin toss.

What I want to try next is to see what happens if you literally just rescale attn_out by a similar amount. What's the impact on the downstream hook-point norms, on the downstream SAE activations, and on the logits? And how does that impact compare to ablation?