In [1]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go
from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# Load the pretrained "pythia-70m-v0" checkpoint as a HookedTransformer on `device`.
# Every later cell reads this notebook-global `model`.
model = HookedTransformer.from_pretrained("pythia-70m-v0", fold_ln=True, device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m-v0 into HookedTransformer


In [3]:
# Load the French and English KDE4 text samples with the project helper.
# (The captured output below reports 1007 examples per file.)
kde_french = load_txt_data("kde4_french.txt")
kde_english = load_txt_data("kde4_english.txt")

kde4_french.txt: Loaded 1007 examples with 505 to 5345 characters each.
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.


In [4]:
# Layer-3 MLP activations for 100 prompts of each corpus.
# NOTE(review): presumably `mean=True` reduces to one value per neuron (the result is
# later indexed by neuron id as `english_activations[neurons]`) — confirm in haystack_utils.
french_activations = get_mlp_activations(kde_french, 3, model, num_prompts=100, mean=True)
english_activations = get_mlp_activations(kde_english, 3, model, num_prompts=100, mean=True)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [5]:
def evaluate_model(prompts: list[str], model: HookedTransformer, batch_size, crop_context=None):
    """Return the model loss of each consecutive batch of prompts.

    Args:
        prompts: Texts to score.
        model: Object used both for tokenisation (`to_tokens`) and the forward pass.
        batch_size: Number of prompts tokenised and scored together.
        crop_context: If given, only the first `crop_context` token positions of
            every tokenised batch are kept.

    Returns:
        One float per batch (the `.item()` of the `return_type="loss"` output),
        in prompt order. The final batch may hold fewer than `batch_size` prompts.
    """
    batch_losses = []
    for start in range(0, len(prompts), batch_size):
        # Slicing clamps at the end of the list, so the short final batch needs no special case.
        token_batch = model.to_tokens(prompts[start:start + batch_size])
        if crop_context is not None:
            token_batch = token_batch[:, :crop_context]
        batch_losses.append(model(token_batch, return_type="loss").item())
    return batch_losses

# Compare batched and single-prompt evaluation on a small French sample, first on the
# full prompts and then with every prompt cropped to its first 150 token positions.
examples = kde_french[:20]
print("Context lengths:", [model.to_tokens(example).shape[1] for example in examples])

losses_batched = evaluate_model(examples, model, batch_size=4)
losses_single = evaluate_model(examples, model, batch_size=1)
print(f"Mean loss (batch size = 4): {np.mean(losses_batched):.4f}")
print(f"Mean loss (batch size = 1): {np.mean(losses_single):.4f}")

# evaluate_model crops with tokens[:, :150], i.e. positions 0..149, so label it as such.
losses_batched = evaluate_model(examples, model, batch_size=4, crop_context=150)
losses_single = evaluate_model(examples, model, batch_size=1, crop_context=150)
print(f"Mean loss (batch size = 4, crop = (0:150)): {np.mean(losses_batched):.4f}")
print(f"Mean loss (batch size = 1, crop = (0:150)): {np.mean(losses_single):.4f}")


Context lengths: [368, 326, 195, 177, 173, 529, 339, 198, 213, 168, 230, 198, 184, 134, 215, 188, 793, 180, 190, 197]


OutOfMemoryError: CUDA out of memory. Tried to allocate 610.00 MiB (GPU 0; 23.65 GiB total capacity; 923.55 MiB already allocated; 596.56 MiB free; 1.19 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
def get_ablated_mlp_difference(prompts: list[str], model: HookedTransformer, neurons: list[int], layer_to_ablate: int, layer_to_cache: int, mean_neuron_activations: Float[Tensor, "d_mlp"]):
    """Average change in one MLP layer's activations when earlier neurons are mean-ablated.

    For every prompt the model is run twice: once unmodified, and once with the
    `neurons` of `blocks.{layer_to_ablate}.mlp.hook_post` overwritten by the
    matching entries of `mean_neuron_activations`. The `hook_post` activations
    of `layer_to_cache` are averaged over batch and position (skipping the
    first position) in both runs and subtracted (original - ablated).

    Args:
        prompts: Non-empty list of texts; each is tokenised and run on its own.
        model: The hooked model.
        neurons: Indices of the neurons to overwrite in `layer_to_ablate`.
        layer_to_ablate: Layer whose MLP post-activations are overwritten.
        layer_to_cache: Layer whose MLP post-activations are compared.
        mean_neuron_activations: Replacement values, indexed by neuron id.

    Returns:
        The per-neuron activation difference of `layer_to_cache`, averaged
        over prompts. Also prints the mean original and ablated loss.
    """
    # Loop-invariant setup: build the index tensor, hook and hook names once.
    neurons = torch.LongTensor(neurons)
    ablate_hook_name = f'blocks.{layer_to_ablate}.mlp.hook_post'
    block_name = f'blocks.{layer_to_cache}.mlp.hook_post'

    def ablate_neuron_hook(value, hook):
        value[:, :, neurons] = mean_neuron_activations[neurons]
        return value

    original_losses = []
    ablated_losses = []
    mean_differences = []
    for prompt in tqdm(prompts):
        tokens = model.to_tokens(prompt)
        # Only `block_name` is read from the caches, so cache nothing else.
        original_loss, original_cache = model.run_with_cache(tokens, return_type="loss", names_filter=block_name)

        with model.hooks(fwd_hooks=[(ablate_hook_name, ablate_neuron_hook)]):
            ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss", names_filter=block_name)

        original_losses.append(original_loss.item())
        ablated_losses.append(ablated_loss.item())

        # Drop position 0 before averaging over batch and position.
        original_activations = original_cache[block_name][:, 1:]
        ablated_activations = ablated_cache[block_name][:, 1:]
        mean_difference = original_activations.mean((0, 1)) - ablated_activations.mean((0, 1))
        mean_differences.append(mean_difference)

    print(f"Original loss: {np.mean(original_losses):.2f}, ablated loss: {np.mean(ablated_losses):.2f} (+{((np.mean(ablated_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return torch.stack(mean_differences).mean(0)

In [None]:
# Compare layer-5 MLP activations on French text with and without neuron 609 of
# layer 3 overwritten by its value from `english_activations`.
layer_to_cache = 5
difference = get_ablated_mlp_difference(kde_french, model, neurons=[609], layer_to_ablate=3, layer_to_cache=layer_to_cache, mean_neuron_activations=english_activations)
# topk with k = len(difference) returns every neuron, ordered by descending absolute difference.
sorted_differences, sorted_neurons = torch.topk(difference.abs(), len(difference), largest=True)


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, ablated loss: 3.81 (+6.23%)


In [None]:
# Distribution of the signed per-neuron differences (original - ablated) computed above.
px.histogram(difference.cpu().numpy(), title=f"Difference in layer {layer_to_cache} neuron activations between original and ablated context neuron", width=800)

In [None]:
def imshow(tensor, renderer=None, label_neurons=False, **kwargs):
    """Show `tensor` as a Plotly heatmap with hidden axes.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy`.
        renderer: Forwarded to `fig.show`.
        label_neurons: If true, attach each cell's flat (row-major) index as
            hover data and show it as "Neuron" next to the value.
        **kwargs: Extra `px.imshow` arguments; they override the defaults
            (diverging "RdBu" scale centred on 0.0).
    """
    plot_kwargs = dict(color_continuous_midpoint=0.0, color_continuous_scale="RdBu")
    plot_kwargs.update(kwargs)

    fig = px.imshow(utils.to_numpy(tensor), **plot_kwargs)
    for hide_axis in (fig.update_xaxes, fig.update_yaxes):
        hide_axis(visible=False)

    if label_neurons:
        # Grid of flat indices with the same number of rows as the plotted tensor.
        flat_index_grid = np.arange(len(tensor.flatten())).reshape(tensor.shape[0], -1)
        hover_spec = {'customdata': flat_index_grid, 'hovertemplate': "Difference: %{z:.4f}<br>Neuron: %{customdata}"}
        fig.update(data=[hover_spec])
    fig.show(renderer=renderer)

# Reshape the 1D difference vector into 32 rows so it can be shown as a heatmap.
# The title uses `layer_to_cache` so it stays correct if the cached layer changes.
imshow(difference.view(32, -1), label_neurons=True, title=f"Difference in activations between original and ablated model at MLP layer {layer_to_cache} <br> rearranged from a 1D vector into a grid", width=800)

In [None]:
def line(x, xlabel="", ylabel="", title="", xticks=None, width=800, hover_data=None):
    """Show a Plotly line chart of `x`.

    Note: this definition replaces the `line` names imported at the top of the
    notebook for all later cells.

    Args:
        x: Values passed to `px.line`.
        xlabel: X-axis title.
        ylabel: Y-axis title.
        title: Figure title.
        xticks: Optional tick labels; label `i` is placed at x position `i`.
        width: Figure width in pixels.
        hover_data: Optional per-point values shown in the hover text as a
            percentage next to the y value.
    """
    fig = px.line(x, title=title)
    fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel, width=width)
    # Identity checks: `!= None` is evaluated element-wise on NumPy arrays / tensors
    # and then raises when used as an `if` condition.
    if xticks is not None:
        fig.update_layout(
            xaxis=dict(
                tickmode='array',
                tickvals=list(range(len(xticks))),
                ticktext=xticks,
            )
        )
    if hover_data is not None:
        fig.update(data=[{'customdata': hover_data, 'hovertemplate': "Loss: %{y:.4f} (+%{customdata:.2f}%)"}])
    fig.show()

# Plot only the 100 largest absolute differences, labelled by neuron index
# (the author notes that plotting all of them seems to break Jupyter).
line(sorted_differences.cpu().numpy()[:100], xlabel="Neuron", ylabel="Absolute difference", xticks=sorted_neurons.cpu().tolist()[:100], title=f"Top absolute neuron differences in layer {layer_to_cache}", width=1400)

In [None]:
def get_loss_patched_mlp_neurons(prompts: list[str], model: HookedTransformer, mean_neuron_activations, patch_neurons, patch_layer=5, ablate_neurons=(609,), ablate_layer=3, crop_context: None | tuple[int, int]=None):
    """Loss when selected MLP neurons are patched with their value from an ablated run.

    For each prompt, `haystack_utils.get_caches_single_prompt` supplies the
    original loss and the cache of a run with `ablate_neurons` in
    `ablate_layer` ablated. The model is then run again with only
    `patch_neurons` of `blocks.{patch_layer}.mlp.hook_post` overwritten by
    their values from that ablated cache.

    Args:
        prompts: Texts, processed one at a time.
        model: The hooked model.
        mean_neuron_activations: Forwarded to the helper as the ablation values.
        patch_neurons: Indices of the neurons in `patch_layer` to overwrite.
        patch_layer: Layer whose MLP post-activations are patched.
        ablate_neurons: Neuron indices ablated by the helper. A bare int is
            accepted and treated as a single index.
        ablate_layer: Layer ablated by the helper.
        crop_context: Optional `(start, stop)` token slice applied to each prompt.

    Returns:
        `(mean original loss, mean patched loss)` over all prompts; the same
        numbers are printed together with the relative increase.
    """
    # The previous default `(609)` was the int 609, not a one-element sequence.
    # Normalise ints and tuples to a list, matching how call sites pass `[609]`.
    if isinstance(ablate_neurons, int):
        ablate_neurons = [ablate_neurons]
    elif isinstance(ablate_neurons, tuple):
        ablate_neurons = list(ablate_neurons)

    patch_hook_name = f'blocks.{patch_layer}.mlp.hook_post'
    original_losses = []
    patched_losses = []
    for prompt in tqdm(prompts):
        # NOTE(review): `original_loss` is appended without `.item()`, so the helper
        # presumably already returns a Python float — confirm in haystack_utils.
        original_loss, _, _, ablated_cache = haystack_utils.get_caches_single_prompt(prompt, model, mean_neuron_activations, ablate_neurons, ablate_layer, crop_context=crop_context)

        if crop_context is not None:
            tokens = model.to_tokens(prompt)[:, crop_context[0]:crop_context[1]]
        else:
            tokens = model.to_tokens(prompt)

        def patch_hook(value, hook):
            # Batch, pos, d_mlp
            value[:, :, patch_neurons] = ablated_cache[patch_hook_name][:, :, patch_neurons]
            # Return the tensor explicitly, like the other hooks in this notebook.
            return value

        with model.hooks(fwd_hooks=[(patch_hook_name, patch_hook)]):
            patched_loss = model(tokens, return_type="loss")

        original_losses.append(original_loss)
        patched_losses.append(patched_loss.item())

    print(f"Original loss: {np.mean(original_losses):.2f}, patched loss: {np.mean(patched_losses):.2f} (+{((np.mean(patched_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return np.mean(original_losses), np.mean(patched_losses)

In [None]:
# Patch the 5 layer-5 neurons with the largest absolute activation difference and measure the loss.
patch_neurons = sorted_neurons.cpu().tolist()[:5]
print("Patching neurons:", patch_neurons)
_, _ = get_loss_patched_mlp_neurons(kde_french, model, english_activations, patch_neurons=patch_neurons, patch_layer=5, ablate_neurons=[609], ablate_layer=3, crop_context=(0, 500))

Patching neurons: [273, 670, 395, 8, 164]


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.60 (+0.35%)


- Patching neurons with high activation difference has a small effect on the loss
- Activation doesn't necessarily correspond to changed log probs of the correct token

In [None]:
# Same experiment with the 20 neurons that have the largest absolute activation difference.
patch_neurons = sorted_neurons.cpu().tolist()[:20]
print("Patching neurons:", patch_neurons)
_, _ = get_loss_patched_mlp_neurons(kde_french, model, english_activations, patch_neurons=patch_neurons, patch_layer=5, ablate_neurons=[609], ablate_layer=3, crop_context=(0, 500))

Patching neurons: [273, 670, 395, 8, 164, 1209, 751, 389, 1661, 929, 1138, 1353, 245, 1731, 394, 584, 1479, 1042, 2023, 1874]


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.62 (+0.99%)


## Logit attribution for all layers

- Run model with and without ablating the French neuron, save both clean and ablated activations
- Run model again without ablation
- Simulate the effect of individual ablated components
- To simulate ablating a component:
    - Before the final layernorm, subtract the component's cached activation from the unablated run
    - Then add the activation of the ablated run
- This lets us compute the effect of running a component with corrupted activations without letting its output affect later components
- However, the cached ablated activations of later components will still be influenced by earlier components 

In [None]:
def DLA(prompts: list[str], model: HookedTransformer, mean_neuron_activations, neurons = [609], layer_to_ablate=3, patched_component=8, crop_context: None | tuple[int, int]=None):
    """Loss effect of swapping one component's residual contribution for its ablated-run value.

    For each prompt the model is run unmodified and with `neurons` of
    `blocks.{layer_to_ablate}.mlp.hook_post` overwritten by
    `mean_neuron_activations`. Both caches are decomposed into per-component
    residual contributions. A third, otherwise clean run then subtracts the
    original contribution of `patched_component` from the last block's
    `hook_resid_post` and adds the ablated one.

    Args:
        prompts: Texts, processed one at a time.
        model: The hooked model.
        mean_neuron_activations: Replacement values, indexed by neuron id.
        neurons: Neuron indices to overwrite in `layer_to_ablate`.
        layer_to_ablate: Layer whose MLP post-activations are overwritten.
        patched_component: Index into the stack returned by `decompose_resid`.
        crop_context: Optional `(start, stop)` token slice applied to each prompt.

    Returns:
        `(mean original loss, mean patched loss)` over all prompts; the same
        numbers are printed together with the relative increase.
    """
    # TODO think about layer normalization
    # Loop-invariant setup. Rebinding `neurons` leaves the default list untouched.
    neurons = torch.LongTensor(neurons)
    ablate_hook_name = f'blocks.{layer_to_ablate}.mlp.hook_post'
    # Patch the residual stream after the model's last block instead of hard-coding block 5.
    final_resid_name = f'blocks.{model.cfg.n_layers - 1}.hook_resid_post'

    def ablate_neuron_hook(value, hook):
        value[:, :, neurons] = mean_neuron_activations[neurons]
        return value

    original_losses = []
    patched_losses = []
    for prompt in tqdm(prompts):
        if crop_context is not None:
            tokens = model.to_tokens(prompt)[:, crop_context[0]:crop_context[1]]
        else:
            tokens = model.to_tokens(prompt)
        original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")

        with model.hooks(fwd_hooks=[(ablate_hook_name, ablate_neuron_hook)]):
            _, ablated_cache = model.run_with_cache(tokens, return_type="loss")

        # component, batch, pos, residual
        # TODO figure out if we need layer norm here
        original_per_layer_residual = original_cache.decompose_resid(layer=-1, return_labels=False, apply_ln=False)
        ablated_per_layer_residual = ablated_cache.decompose_resid(layer=-1, return_labels=False, apply_ln=False)

        # Component order for the 6-layer model used in this notebook:
        # ['embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out']
        def swap_cache_hook(value, hook):
            # Batch, pos, residual
            value -= original_per_layer_residual[patched_component]
            value += ablated_per_layer_residual[patched_component]
            # Return the tensor explicitly, like the ablation hook above.
            return value

        with model.hooks(fwd_hooks=[(final_resid_name, swap_cache_hook)]):
            patched_loss = model(tokens, return_type="loss")

        original_losses.append(original_loss.item())
        patched_losses.append(patched_loss.item())

    print(f"Original loss: {np.mean(original_losses):.2f}, patched loss: {np.mean(patched_losses):.2f} (+{((np.mean(patched_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return np.mean(original_losses), np.mean(patched_losses)


In [None]:
# Layer 3 MLP logit attribution = direct effect of ablating the context neuron
# Logit attribution of later components when ablating the context neuron
# Not sure how clean this is - e.g. layer 5 MLP will get the accumulated effects of all previous layers from ablating the context neuron
# The names below are listed in the same order as the component indices passed to DLA
# (index 8 is '3_mlp_out', index 12 is '5_mlp_out').
component_names = ['embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out']

# `components` / `losses` are parallel lists: the first entry is the unpatched baseline,
# followed by one entry per patched component. They are plotted in a later cell.
components = []
losses = []
for later_component in range(8, 13):
    print(f"Component: {component_names[later_component]}")
    original_loss, patched_loss = DLA(kde_french, model, english_activations, patched_component=later_component, crop_context=(0, 500))
    # Record the baseline once, on the first iteration only.
    if len(losses) == 0:
        components.append("Original loss")
        losses.append(original_loss)
    components.append(component_names[later_component])
    losses.append(patched_loss)

Component: 3_mlp_out


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.65 (+1.72%)
Component: 4_attn_out


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.60 (+0.41%)
Component: 4_mlp_out


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.64 (+1.30%)
Component: 5_attn_out


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.62 (+0.76%)
Component: 5_mlp_out


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.73 (+3.99%)


In [None]:
# Relative loss increase of every entry against the baseline stored in losses[0]; shown in the hover text.
percent_increase = ((np.array(losses) - losses[0]) / losses[0]) * 100
line(losses, xlabel="Component", ylabel="Loss", title="Loss of individual patching individual components when ablating L3N609", xticks=components, width=800, hover_data=percent_increase.tolist())

- Total effect of ablating L3N609: 5.97% increase in loss
- Direct effect of ablating L3N609: 1.5% increase in loss
- Summed direct effects of all later components: ~8%
- The French neuron must directly boost relevant words

- Does it make sense that the direct loss attributions of the individual components sum to a larger loss increase than ablating the neuron itself?
- Yes?: 
    - Components make similar mistakes, so ablating all of them together increases the loss by less than the sum of the individual increases
    - Later components receive the residual stream input of accumulated mistakes (not clean path patching)

## Check individual contributions of L5 neurons

Motivation: The output of MLP layer 5 caused the biggest increase in loss in the previous analysis. We want to find which neurons specifically are responsible. 

In [None]:
# Sanity check, compare output directions from the residual stream

def get_answer_token_logit_difference(prompts: list[str], model: HookedTransformer, mean_neuron_activations, neurons = [609], layer_to_ablate=3, layer_to_compare=5, crop_context: None | tuple[int, int]=None):
    """Print how much one MLP layer's output favours the correct next token, clean vs. ablated.

    For each prompt, the `hook_mlp_out` of `layer_to_compare` is passed through
    `apply_ln_to_stack` of its own cache, the last position is dropped, and the
    result is projected onto the unembedding direction of each position's next
    token. This is done for a clean run and for a run in which `neurons` of
    `blocks.{layer_to_ablate}.mlp.hook_post` are overwritten by
    `mean_neuron_activations`. The per-prompt mean of (clean - ablated) is
    collected and the mean over prompts is printed. Nothing is returned.

    Args:
        prompts: Texts, processed one at a time.
        model: The hooked model; its hooks are reset before every prompt.
        mean_neuron_activations: Replacement values, indexed by neuron id.
        neurons: Neuron indices to overwrite in `layer_to_ablate`.
        layer_to_ablate: Layer whose MLP post-activations are overwritten.
        layer_to_compare: Layer whose MLP output is projected.
        crop_context: Optional `(start, stop)` token slice applied to each prompt.
    """
    # TODO think about layer normalization
    mlp_out_name = f'blocks.{layer_to_compare}.hook_mlp_out'

    def ablate_neuron_hook(value, hook):
        value[:, :, neurons] = mean_neuron_activations[neurons]
        return value

    ablation_hooks = [(f'blocks.{layer_to_ablate}.mlp.hook_post', ablate_neuron_hook)]

    def correct_token_logits(cache, unembed_directions):
        # batch pos residual -> drop the last position, which has no next token
        ln_scaled_mlp_out = cache.apply_ln_to_stack(cache[mlp_out_name])[:, :-1]
        projected = einops.einsum(ln_scaled_mlp_out, unembed_directions, 'batch pos residual, residual pos -> batch pos')
        return projected.squeeze(0)

    per_prompt_differences = []
    for prompt in tqdm(prompts):
        model.reset_hooks()
        tokens = model.to_tokens(prompt)
        if crop_context is not None:
            tokens = tokens[:, crop_context[0]:crop_context[1]]

        _, clean_cache = model.run_with_cache(tokens, return_type="loss")
        # Unembedding columns of the token that follows each position: residual pos
        next_token_directions = model.W_U[:, tokens[:, 1:]].squeeze(1)
        clean_logits = correct_token_logits(clean_cache, next_token_directions)

        with model.hooks(fwd_hooks=ablation_hooks):
            _, ablated_cache = model.run_with_cache(tokens, return_type="loss")
        ablated_logits = correct_token_logits(ablated_cache, next_token_directions)

        per_prompt_differences.append((clean_logits - ablated_logits).detach().cpu().mean().item())
    print("Mean difference:", np.mean(per_prompt_differences))

# Run the sanity check on the French prompts with defaults (neuron 609 in layer 3, compare layer 5),
# using token positions 10..499 of each prompt.
get_answer_token_logit_difference(kde_french, model, english_activations, crop_context=(10, 500))

  0%|          | 0/1007 [00:00<?, ?it/s]

Mean difference: 0.379232195358554


In [None]:
def get_neuron_logit_contribution(cache: ActivationCache, model: HookedTransformer, answer_tokens: Int[Tensor, "batch pos"], layer: int) -> Float[Tensor, "neuron pos"]:
    """Per-neuron direct contribution of one MLP layer to the correct-token logit.

    Expects a cache from a single example (batch size 1); it does not work on
    batched examples. Only the first batch element is used.

    A neuron's write to the residual stream at a position is
    `activation * W_out[neuron]`, so its contribution to the correct-token
    logit is `activation * (W_out[neuron] . W_U[:, answer])`. Contracting
    W_out with the unembedding directions first gives the same values as
    projecting the per-neuron residual outputs, without materialising a
    `[pos, neuron, residual]` tensor.

    No layer norm is applied: LN acts on the sum of the MLP output directions,
    not on individual neuron directions.

    Args:
        cache: Activation cache holding `blocks.{layer}.mlp.hook_post`.
        model: Model providing `blocks[layer].mlp.W_out` and `W_U`.
        answer_tokens: Correct next tokens, shape (1, pos - 1); entry `p` is the
            token that follows position `p`.
        layer: MLP layer to attribute.

    Returns:
        Tensor of shape (neuron, pos - 1).
    """
    # Post-activations of the single example without the last position: pos neuron
    neuron_activations = cache[f'blocks.{layer}.mlp.hook_post'][0, :-1]
    # Unembed of correct answer tokens: residual pos
    correct_token_directions = model.W_U[:, answer_tokens].squeeze(1)
    # Logit written to each answer token per unit of neuron activation: neuron pos
    logit_per_unit_activation = model.blocks[layer].mlp.W_out @ correct_token_directions
    # Neuron attribution to the correct answer token by position: neuron pos
    return neuron_activations.T * logit_per_unit_activation

In [None]:
def MLP_attribution(prompts: list[str], model: HookedTransformer, mean_neuron_activations, neurons = [609], layer_to_ablate=3, layer_to_compare=5, crop_context: None | tuple[int, int]=None):
    """Rank neurons of one MLP layer by how their correct-token logit contribution changes under ablation.

    For each prompt, `get_neuron_logit_contribution` is evaluated on a clean
    cache and on the cache of a run in which `neurons` of
    `blocks.{layer_to_ablate}.mlp.hook_post` are overwritten by
    `mean_neuron_activations`. The difference (clean - ablated) is averaged
    over positions and then over prompts.

    Args:
        prompts: Texts, processed one at a time.
        model: The hooked model; its hooks are reset before every prompt.
        mean_neuron_activations: Replacement values, indexed by neuron id.
        neurons: Neuron indices to overwrite in `layer_to_ablate`.
        layer_to_ablate: Layer whose MLP post-activations are overwritten.
        layer_to_compare: Layer whose neurons are attributed.
        crop_context: Optional `(start, stop)` token slice applied to each prompt.

    Returns:
        `(sorted_differences, sorted_neurons)`: the mean differences in
        descending order and the matching neuron indices. The sum of the mean
        differences is printed.
    """
    # TODO think about layer normalization
    def ablate_neuron_hook(value, hook):
        value[:, :, neurons] = mean_neuron_activations[neurons]
        return value

    ablation_hooks = [(f'blocks.{layer_to_ablate}.mlp.hook_post', ablate_neuron_hook)]

    summed_differences = torch.zeros(model.cfg.d_mlp)
    for prompt in tqdm(prompts):
        model.reset_hooks()
        tokens = model.to_tokens(prompt)
        if crop_context is not None:
            tokens = tokens[:, crop_context[0]:crop_context[1]]
        answer_tokens = tokens[:, 1:]

        _, clean_cache = model.run_with_cache(tokens, return_type="loss")
        # Shape: neuron pos
        clean_contribution = get_neuron_logit_contribution(clean_cache, model, answer_tokens, layer=layer_to_compare)

        with model.hooks(fwd_hooks=ablation_hooks):
            _, ablated_cache = model.run_with_cache(tokens, return_type="loss")
        # Shape: neuron pos
        ablated_contribution = get_neuron_logit_contribution(ablated_cache, model, answer_tokens, layer=layer_to_compare)

        # Positive value -> the ablated run contributes less to the correct token.
        # Averaging over positions leaves one value per neuron.
        summed_differences += (clean_contribution - ablated_contribution).mean(1).detach().cpu()

    mean_difference = summed_differences / len(prompts)
    print("Total activation difference on correct token:", mean_difference.sum().item())
    sorted_differences, sorted_neurons = torch.topk(mean_difference, len(mean_difference), largest=True)
    return sorted_differences, sorted_neurons

# Call the project's cache-cleaning helper before the run, then rank layer-5 neurons.
# Note: this rebinds `sorted_differences` / `sorted_neurons`, which earlier cells filled
# with the activation-difference ranking.
haystack_utils.clean_cache()
layer_to_compare=5
sorted_differences, sorted_neurons = MLP_attribution(kde_french, model, english_activations, layer_to_compare=layer_to_compare, crop_context=(0, 500))

  0%|          | 0/1007 [00:00<?, ?it/s]

Total activation difference on correct token: 0.3919598460197449


In [None]:
# The ranking is in descending order: the first 30 entries are the largest differences
# and the last 30 are the smallest.
line(sorted_differences.cpu().numpy()[:30], xlabel="Neuron", ylabel="Logit difference on correct token", xticks=sorted_neurons.cpu().tolist()[:30], title=f"Top positive neuron logit differences on correct tokens on layer {layer_to_compare}", width=800)
line(sorted_differences.cpu().numpy()[-30:], xlabel="Neuron", ylabel="Logit difference on correct token", xticks=sorted_neurons.cpu().tolist()[-30:], title=f"Top negative neuron logit differences on correct tokens on layer {layer_to_compare}", width=800)

### Test effect of patching top logit difference neurons on loss

In [None]:
# Sanity check: patch every layer-5 MLP neuron with its value from the ablated run -
# should lead to a loss increase of about 4%, as when swapping the whole 5_mlp_out component.
_, _ = get_loss_patched_mlp_neurons(kde_french, model, english_activations, patch_neurons=[i for i in range(model.cfg.d_mlp)], patch_layer=5, ablate_neurons=[609], ablate_layer=3, crop_context=(0, 500))

  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.73 (+3.99%)


In [None]:
# Patch the 5 neurons with the largest positive logit difference (start of the descending ranking).
top_neurons = sorted_neurons.cpu().tolist()[:5]
print("Patched:", top_neurons)
_, _ = get_loss_patched_mlp_neurons(kde_french, model, english_activations, patch_neurons=top_neurons, patch_layer=5, ablate_neurons=[609], ablate_layer=3, crop_context=(0, 500))

Patched: [395, 670, 584, 1622, 1138]


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.60 (+0.17%)


In [None]:
# Sanity check: patch the 5 neurons at the end of the descending ranking - should lead to a
# much lower increase in loss than the top neurons.
worst_neurons = sorted_neurons.cpu().tolist()[-5:]
print("Patched:", worst_neurons)
_, _ = get_loss_patched_mlp_neurons(kde_french, model, english_activations, patch_neurons=worst_neurons, patch_layer=5, ablate_neurons=[609], ablate_layer=3, crop_context=(0, 500))

Patched: [1444, 651, 2023, 1257, 389]


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.60 (+0.37%)


In [None]:
# Patch the 200 neurons at both ends of the ranking: the first 100 and the last 100.
# Note: this rebinds the notebook-global `neurons`.
neurons = sorted_neurons.cpu().tolist()[:100] + sorted_neurons.cpu().tolist()[-100:]
print("Patched:", neurons)
_, _ = get_loss_patched_mlp_neurons(kde_french, model, english_activations, patch_neurons=neurons, patch_layer=5, ablate_neurons=[609], ablate_layer=3, crop_context=(0, 500))

Patched: [395, 670, 584, 1622, 1138, 273, 493, 1825, 1487, 767, 1658, 4, 1661, 1020, 347, 1479, 1672, 315, 1283, 1644, 1991, 1874, 759, 472, 586, 44, 453, 1702, 1844, 703, 338, 1104, 164, 68, 1875, 1915, 1626, 1884, 1140, 731, 1859, 1177, 405, 1773, 1752, 499, 1366, 1594, 549, 1592, 541, 138, 968, 1546, 91, 1524, 1761, 831, 314, 1550, 1042, 96, 724, 1225, 1980, 917, 755, 1733, 1660, 962, 1968, 1262, 447, 1308, 1441, 1712, 860, 1547, 1533, 254, 32, 548, 2013, 1320, 1518, 1463, 1095, 1250, 879, 1362, 904, 166, 393, 921, 542, 1813, 220, 486, 141, 276, 1763, 1638, 728, 1313, 1685, 1334, 505, 984, 1823, 209, 1987, 149, 346, 1118, 892, 142, 5, 1631, 1579, 421, 198, 930, 1666, 1964, 1744, 1412, 599, 11, 1599, 709, 498, 1955, 1446, 1073, 1999, 1706, 133, 796, 1628, 1462, 358, 1102, 547, 1888, 1510, 1324, 1293, 1315, 451, 532, 966, 360, 377, 955, 1382, 983, 978, 606, 1342, 10, 608, 155, 1328, 1306, 1336, 249, 1456, 798, 245, 291, 1586, 1934, 1536, 869, 1716, 1731, 469, 1475, 1027, 100, 414, 183

  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.59, patched loss: 3.66 (+1.91%)


- Patching all neurons leads to the expected increase in loss of 4% (in line with our layer ablation results)
- Patching the top 5 neurons from our analysis should result in a high-ish loss increase but we get 0.17%
- Patching the bottom 5 neurons should result in a very low increase in loss (technically it should be negative looking at our attribution curve) but we get 0.37%

## DLA for affected later neurons

## Attention Heads - L4 and L5

Plot the difference in attention patterns on French data with and without L3N609 disabled

In [None]:
neurons = [609]


def ablate_neuron_hook(value, hook):
    """Overwrite the notebook-global `neurons` of a hooked activation.

    The replacement values are the matching entries of the notebook-global
    `english_activations`. Both globals are looked up when the hook runs.
    The activation is modified in place and returned.
    """
    replacement = english_activations[neurons]
    value[:, :, neurons] = replacement
    return value


def get_attention_pattern_activations(original_cache: ActivationCache, ablated_cache: ActivationCache, layer: int, n_pos=-1):
    """Plot the change in one layer's attention patterns between two cached runs.

    The `hook_pattern` activations of `layer` are averaged over their first
    dimension in each cache and subtracted (original - ablated). One heatmap is
    shown for every entry along the first dimension of that difference.

    Args:
        original_cache: Cache of the unmodified run.
        ablated_cache: Cache of the ablated run.
        layer: Attention layer to compare.
        n_pos: Number of leading positions shown on both axes; -1 shows
            `difference.shape[1]` positions.
    """
    pattern_key = f'blocks.{layer}.attn.hook_pattern'
    mean_original = original_cache[pattern_key].mean((0))
    mean_ablated = ablated_cache[pattern_key].mean((0))
    pattern_difference = mean_original - mean_ablated

    shown_positions = pattern_difference.shape[1] if n_pos == -1 else n_pos
    plot_title = "Difference between attention patterns with and without French neuron set to mean \"non-French data\" value <br> French dataset <br> Blue means the activation is more prevalent when French neuron enabled"

    for single_pattern in pattern_difference:
        imshow(single_pattern[:shown_positions, :shown_positions].cpu(), title=plot_title)

In [None]:
# Cache a clean run and a run with the layer-3 ablation hook on the first 5 French prompts
# (tokenised together as one batch), then plot the layer-4 attention pattern differences
# for the first 200 positions.
tokens = model.to_tokens(kde_french[:5])
original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")
with model.hooks(fwd_hooks=[(f'blocks.3.mlp.hook_post', ablate_neuron_hook)]):
    ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")

get_attention_pattern_activations(original_cache, ablated_cache, layer=4, n_pos=200)

In [None]:
# Same comparison for layer 5, reusing the caches computed in the previous cell.
get_attention_pattern_activations(original_cache, ablated_cache, layer=5, n_pos=200)

These heads all output into the unembed directly. We should see what their output values align with in the unembed.

Theory: the attention heads that attend to the BOS token when French neuron is enabled are language specific heads.
How to disprove: if we ablate these heads on French text and the loss increases significantly, the heads are relevant to French text.

BOS when French neuron enabled: L4H1, L4H2, L4H3
BOS when French neuron disabled: L4H0

L5H1 is some kind of previous tokens head that attends to different recent tokens in French vs. English. POS dependent?
L5H2 is one of a couple of heads with vertical French activation stripes.
L5H7 shows a clear pattern of more self-attention when L3N609 is enabled and more previous token attention when it's disabled

In [None]:
from collections import defaultdict

# Heads (by layer) that attend to the BOS token when the French neuron is enabled / disabled.
bos_french = defaultdict(list, {
    4: [1, 2, 3]
})
bos_non_french = defaultdict(list, {
    4: [0]
})

# Ablate French BOS heads and see if the loss increases much on French text

act_label_l4 = f'blocks.4.attn.hook_pattern'
def disable_head_hook(value, hook):
    """Zero the attention patterns of the layer-4 heads listed in `bos_french`.

    `value` is laid out as [batch, head, query_pos, key_pos], so the heads are
    selected on dim 1. Indexing dim 0 would select batch entries and fail for
    single-prompt batches. The activation is modified in place and returned.
    """
    # TODO: zero-ablation for now; consider mean head values on English text (or all text) instead.
    value[:, bos_french[4], :, :] = 0.0
    return value


def compare_loss_with_ablated(data, ablate_hook):
    """Print the mean loss with and without `ablate_hook` on layer-4 attention patterns.

    Uses the notebook-global `model` and `act_label_l4`. Each sample is
    tokenised and scored on its own, once unmodified and once with
    `ablate_hook` attached to `act_label_l4`.

    Args:
        data: Non-empty sequence of texts.
        ablate_hook: Forward hook `(value, hook) -> value`.

    Raises:
        ZeroDivisionError: If `data` is empty.
    """
    # Accumulate Python floats so no tensors are kept alive across iterations.
    total_original_loss = 0.0
    total_ablated_loss = 0.0

    for sample in data:
        tokens = model.to_tokens(sample)
        total_original_loss += model(tokens, return_type="loss").item()
        total_ablated_loss += model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(act_label_l4, ablate_hook)]).item()

    average_original_loss = total_original_loss / len(data)
    average_ablated_loss = total_ablated_loss / len(data)

    print(f"Full model loss: {average_original_loss:.6f}")
    # The hook ablates attention heads, so the label no longer mentions an MLP layer.
    print(f"Ablated loss: {average_ablated_loss:.6f}")
    print(f"% increase: {(average_ablated_loss - average_original_loss) / average_original_loss * 100:.6f}")


# Loss increase from disabling the French-BOS heads on non-French (English) text.
# If it is high, the heads matter beyond French text (although it could also be
# due to French loanwords).
# Author's observation: it's very low.
# `non_french` was never defined in this notebook; `kde_english` is the non-French corpus loaded above.
compare_loss_with_ablated(kde_english, disable_head_hook)