In [64]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go
from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import haystack_utils

%reload_ext autoreload
%autoreload 2

In [6]:
model = HookedTransformer.from_pretrained("pythia-70m-v0", fold_ln=True, device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m-v0 into HookedTransformer


In [7]:
kde_french = load_txt_data("kde4_french.txt")
kde_english = load_txt_data("kde4_english.txt")

kde4_french.txt: Loaded 1007 examples with 505 to 5345 characters each.
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.


In [8]:
french_activations = get_mlp_activations(kde_french, 3, model, num_prompts=100, mean=True)
english_activations = get_mlp_activations(kde_english, 3, model, num_prompts=100, mean=True)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [42]:
def evaluate_model(prompts: list[str], model: HookedTransformer, batch_size, crop_context=None):
    losses = []
    for batch_index in range(0, len(prompts), batch_size):
        if batch_index+batch_size >= len(prompts):
            batch = prompts[batch_index:]
        else:
            batch = prompts[batch_index:batch_index+batch_size]
        tokens = model.to_tokens(batch)
        if crop_context is not None:
            tokens = tokens[:, crop_context[0]:crop_context[1]]
        loss = model(tokens, return_type="loss")
        losses.append(loss.item())
    return losses

examples = kde_french[:20]
print("Context lengths:", [model.to_tokens(example).shape[1] for example in examples])

losses_batched = evaluate_model(examples, model, batch_size=4)
losses_single = evaluate_model(examples, model, batch_size=1)
print(f"Mean loss (batch size = 4): {np.mean(losses_batched):.4f}")
print(f"Mean loss (batch size = 1): {np.mean(losses_single):.4f}")

losses_batched = evaluate_model(examples, model, batch_size=4, crop_context=(10, 150))
losses_single = evaluate_model(examples, model, batch_size=1, crop_context=(10, 150))
print(f"Mean loss (batch size = 4, crop = (10:150)): {np.mean(losses_batched):.4f}")
print(f"Mean loss (batch size = 1, crop = (10:150)): {np.mean(losses_single):.4f}")


Context lengths: [368, 326, 195, 177, 173, 529, 339, 198, 213, 168, 230, 198, 184, 134, 215, 188, 793, 180, 190, 197]
Mean loss (batch size = 4): 5.2905
Mean loss (batch size = 1): 3.5187
Mean loss (batch size = 4, crop = (10:150)): 3.6561
Mean loss (batch size = 1, crop = (10:150)): 3.6297


In [43]:
def get_ablated_mlp_difference(prompts: list[str], model: HookedTransformer, neurons: list[int], layer_to_ablate: int, layer_to_cache: int, mean_neuron_activations: Float[Tensor, "d_mlp"]):
    original_losses = []
    ablated_losses = []
    mean_differences = []
    for prompt in tqdm(prompts):
        neurons = torch.LongTensor(neurons)

        def ablate_neuron_hook(value, hook):
            value[:, :, neurons] = mean_neuron_activations[neurons]
            return value
        
        tokens = model.to_tokens(prompt)
        original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")

        with model.hooks(fwd_hooks=[(f'blocks.{layer_to_ablate}.mlp.hook_post', ablate_neuron_hook)]):
            ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")

        original_losses.append(original_loss.item())
        ablated_losses.append(ablated_loss.item())

        block_name = f'blocks.{layer_to_cache}.mlp.hook_post'
        original_activations = original_cache[block_name][:, 1:]
        ablated_activations = ablated_cache[block_name][:, 1:]
        mean_difference = original_activations.mean((0, 1)) - ablated_activations.mean((0, 1))
        mean_differences.append(mean_difference)
        
        
    print(f"Original loss: {np.mean(original_losses):.2f}, ablated loss: {np.mean(ablated_losses):.2f} (+{((np.mean(ablated_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return torch.stack(mean_differences).mean(0)

In [91]:
layer_to_cache = 5
difference = get_ablated_mlp_difference(kde_french, model, neurons=[609], layer_to_ablate=3, layer_to_cache=layer_to_cache, mean_neuron_activations=english_activations)
sorted_differences, sorted_neurons = torch.topk(difference.abs(), len(difference), largest=True)


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, ablated loss: 3.81 (+6.23%)


In [96]:
px.histogram(difference.cpu().numpy(), title=f"Difference in layer {layer_to_cache} neuron activations between original and ablated context neuron", width=800)

In [97]:
def imshow(tensor, renderer=None, label_neurons=False, **kwargs):
    preset_kwargs = {
        "color_continuous_midpoint": 0.0,
        "color_continuous_scale": "RdBu"
    }

    fig = px.imshow(utils.to_numpy(tensor), **{**preset_kwargs, **kwargs})
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    if label_neurons:
        fig.update(data=[{'customdata': np.arange(len(tensor.flatten())).reshape(tensor.shape[0], -1), 'hovertemplate': "Difference: %{z:.4f}<br>Neuron: %{customdata}"}])
    fig.show(renderer=renderer)

imshow(difference.view(32, -1), label_neurons=True, title="""Difference in activations between original and ablated model at MLP layer 5 <br> rearranged from a 1D vector into a grid""", width=800)

In [98]:
def line(x, xlabel="", ylabel="", title="", xticks=None, width=800):
    fig = px.line(x, title=title)
    fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel, width=width)
    if xticks != None:
        fig.update_layout(
            xaxis = dict(
            tickmode = 'array',
            tickvals = [i for i in range(len(xticks))],
            ticktext = xticks
            )
        )
    return fig

# Plotting all differences seems to break Jupyter
line(sorted_differences.cpu().numpy()[:100], xlabel="Neuron", ylabel="Absolute difference", xticks=sorted_neurons.cpu().tolist()[:100], title="Top 100 absolute neuron differences", width=1400)

## DLA for the French neuron

In [56]:
def DLA(prompts: list[str], model: HookedTransformer, mean_neuron_activations, neurons = [609], layer_to_ablate=3, patched_component=8):
    # TODO think about layer normalization
    original_losses = []
    patched_losses = []
    for prompt in tqdm(prompts):

        neurons = torch.LongTensor(neurons)
        def ablate_neuron_hook(value, hook):
            value[:, :, neurons] = mean_neuron_activations[neurons]
            return value
        
        tokens = model.to_tokens(prompt)
        original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")

        with model.hooks(fwd_hooks=[(f'blocks.{layer_to_ablate}.mlp.hook_post', ablate_neuron_hook)]):
            ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")

        # component, batch, pos, residual
        # TODO figure out if we need layer norm here
        original_per_layer_residual, original_labels = original_cache.decompose_resid(layer=-1, return_labels=True, apply_ln=False)
        ablated_per_layer_residual, ablated_labels = ablated_cache.decompose_resid(layer=-1, return_labels=True, apply_ln=False)

        # ['embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out']
        def swap_cache_hook(value, hook):
            # Batch, pos, residual
            value -= original_per_layer_residual[patched_component]
            value += ablated_per_layer_residual[patched_component]
        
        with model.hooks(fwd_hooks=[(f'blocks.5.hook_resid_post', swap_cache_hook)]):
            patched_loss = model(tokens, return_type="loss")

        original_losses.append(original_loss.item())
        patched_losses.append(patched_loss.item())


    print(f"Original loss: {np.mean(original_losses):.2f}, patched loss: {np.mean(patched_losses):.2f} (+{((np.mean(patched_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return np.mean(original_losses), np.mean(patched_losses)

_, _ = DLA(kde_french, model, english_activations)

  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.65 (+1.72%)


In [57]:
component_names = ['embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out']
for later_component in range(9, 13):#
    print(f"Component: {component_names[later_component]}")
    _, _ = DLA(kde_french, model, english_activations, patched_component=later_component)

Component: 4_attn_out


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.60 (+0.41%)
Component: 4_mlp_out


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.63 (+1.30%)
Component: 5_attn_out


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.61 (+0.76%)
Component: 5_mlp_out


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.73 (+4.00%)


In [58]:
total_effect = 5.97
with_ln = 1.5 + 0.29 + 1.29 + 0.73 + 3.87
without_ln = 1.72 + 0.41 + 1.3 + 0.76 + 4
print(with_ln, without_ln, total_effect)

7.68 8.19 5.97


- Total effect of ablating L3N609: 5.97% increase in loss
- Direct effect of ablating L3N609: 1.5% increase in loss
- Added direct effects of all later components and ~8%
- In total we have a loss increase of 6% when ablating the French neuron
- The direct effect of ablating the French neuron is 1.5%
- The French neuron must directly boost relevant words

- Does it make sense that the direct loss attribution of individual components sums to a higher total loss than the total loss of ablating the neuron
- Theory: components make similar mistakes so that ablating all of them leads to fewer loss

## Check individual contributions of L5 neurons

In [59]:
def get_neuron_logit_contribution(cache, model, answer_tokens, layer):
    neuron_directions = cache.get_neuron_results(layer, neuron_slice=utils.Slice(input_slice=None), pos_slice=utils.Slice(input_slice=None))
    neuron_directions = einops.rearrange(neuron_directions, 'batch pos neuron residual -> neuron batch pos residual')
    # Apply ln, get rid of last position
    scaled_neuron_directions = cache.apply_ln_to_stack(neuron_directions)[:, 0, :-1, :] # neuron pos embed
    # Unembed of correct answer tokens
    correct_token_directions = model.W_U[:, answer_tokens].squeeze(1) # embed pos
    # Neuron attribution to correct answer token by position
    unembedded = einops.einsum(scaled_neuron_directions, correct_token_directions, 'neuron pos residual, residual pos -> neuron pos') # neuron pos
    return unembedded

In [60]:
def get_answer_token_logit_difference(prompts: list[str], model: HookedTransformer, mean_neuron_activations, neurons = [609], layer_to_ablate=3, layer_to_compare=5, max_context_length=300):
    # TODO think about layer normalization
    differences = []
    for prompt in tqdm(prompts):
        model.reset_hooks()
        tokens = model.to_tokens(prompt)[:, :500]
        original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")
        answer_tokens = tokens[:, 1:]

        # Shape batch pos residual
        mlp_post = original_cache[f'blocks.{layer_to_compare}.hook_mlp_out']
        # Shape batch pos-1 residual
        normalized_mlp_post = original_cache.apply_ln_to_stack(mlp_post)[:,:-1]
        correct_token_directions = model.W_U[:, answer_tokens].squeeze(1) # embed pos
        # Shape position
        original_unembedded = einops.einsum(normalized_mlp_post, correct_token_directions, 'batch pos residual, residual pos -> batch pos').squeeze(0)

        def ablate_neuron_hook(value, hook):
            value[:, :, neurons] = mean_neuron_activations[neurons]
            return value
    
        with model.hooks(fwd_hooks=[(f'blocks.{layer_to_ablate}.mlp.hook_post', ablate_neuron_hook)]):
            ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")
        
        mlp_post = ablated_cache[f'blocks.{layer_to_compare}.hook_mlp_out']
        normalized_mlp_post = ablated_cache.apply_ln_to_stack(mlp_post)[:,:-1]
        ablated_unembedded = einops.einsum(normalized_mlp_post, correct_token_directions, 'batch pos residual, residual pos -> batch pos').squeeze(0)

        # Shape: pos
        difference = (original_unembedded - ablated_unembedded).detach().cpu().mean().item()
        differences.append(difference)
    print("Mean difference:", np.mean(differences))


get_answer_token_logit_difference(kde_french, model, english_activations)

  0%|          | 0/1007 [00:00<?, ?it/s]

Mean difference: 0.3675940532882214


In [61]:
def MLP_attribution(prompts: list[str], model: HookedTransformer, mean_neuron_activations, neurons = [609], layer_to_ablate=3, layer_to_compare=5, max_context_length=300):
    # TODO think about layer normalization
    differences = torch.zeros(model.cfg.d_mlp)
    for prompt in tqdm(prompts):
        model.reset_hooks()
        tokens = model.to_tokens(prompt)[:, :500]
        original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")
        answer_tokens = tokens[:, 1:]
        
        # Shape neuron pos
        original_unembedded = get_neuron_logit_contribution(original_cache, model, answer_tokens, layer=layer_to_compare)
        
        def ablate_neuron_hook(value, hook):
            value[:, :, neurons] = mean_neuron_activations[neurons]
            return value
    
        with model.hooks(fwd_hooks=[(f'blocks.{layer_to_ablate}.mlp.hook_post', ablate_neuron_hook)]):
            ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")

        # Shape: neuron pos
        ablated_unembedded = get_neuron_logit_contribution(ablated_cache, model, answer_tokens, layer=layer_to_compare)
        # Positive diff -> ablated has lower activation on correct token
        # Shape: neuron
        difference = (original_unembedded - ablated_unembedded).mean(1).detach().cpu()
        differences += difference
    
    mean_difference = differences / len(prompts)
    print("Total activation difference on correct token:", mean_difference.sum().item())
    sorted_differences, sorted_neurons = torch.topk(mean_difference, len(mean_difference), largest=True)
    return sorted_differences, sorted_neurons

haystack_utils.clean_cache()
layer_to_compare=5
sorted_differences, sorted_neurons = MLP_attribution(kde_french, model, english_activations, layer_to_compare=layer_to_compare)

  0%|          | 0/1007 [00:00<?, ?it/s]

Total activation difference on correct token: 0.3723316788673401


In [67]:
line(sorted_differences.cpu().numpy()[:30], xlabel="Neuron", ylabel="Logit difference on correct token", xticks=sorted_neurons.cpu().tolist()[:30], title=f"Top positive neuron logit differences on correct tokens on layer {layer_to_compare}", width=600)
line(sorted_differences.cpu().numpy()[-30:], xlabel="Neuron", ylabel="Logit difference on correct token", xticks=sorted_neurons.cpu().tolist()[-30:], title=f"Top negative neuron logit differences on correct tokens on layer {layer_to_compare}", width=600)

### Test effect of patching top logit difference neurons on loss

In [87]:
def get_loss_patched_mlp_neurons(prompts: list[str], model: HookedTransformer, mean_neuron_activations, patch_neurons, patch_layer=5, ablate_neurons=(609), ablate_layer=3):
    original_losses = []
    patched_losses = []
    for prompt in tqdm(prompts):
        
        original_loss, _, _, ablated_cache = haystack_utils.get_caches_single_prompt(prompt, model, mean_neuron_activations, ablate_neurons, ablate_layer)
        
        tokens = model.to_tokens(prompt)
        def patch_hook(value, hook):
            # Batch, pos, d_mlp
            value[:, :, patch_neurons] = ablated_cache[f'blocks.{patch_layer}.mlp.hook_post'][:, :, patch_neurons]
        
        with model.hooks(fwd_hooks=[(f'blocks.{patch_layer}.mlp.hook_post', patch_hook)]):
            patched_loss = model(tokens, return_type="loss")

        original_losses.append(original_loss)
        patched_losses.append(patched_loss.item())

    print(f"Original loss: {np.mean(original_losses):.2f}, patched loss: {np.mean(patched_losses):.2f} (+{((np.mean(patched_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return np.mean(original_losses), np.mean(patched_losses)

In [88]:
# Sanity check: ablate whole layer - should lead to loss increase of 4%
_, _ = get_loss_patched_mlp_neurons(kde_french, model, english_activations, patch_neurons=[i for i in range(model.cfg.d_mlp)], patch_layer=5, ablate_neurons=[609], ablate_layer=3)

  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.73 (+4.00%)


In [89]:
# Ablate top neurons
top_neurons = sorted_neurons.cpu().tolist()[:5]
_, _ = get_loss_patched_mlp_neurons(kde_french, model, english_activations, patch_neurons=top_neurons, patch_layer=5, ablate_neurons=[609], ablate_layer=3)

  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.59 (+0.18%)


In [90]:
# Sanity check: ablate bottom neurons - should lead to a much lower increase in loss than top neurons
worst_neurons = sorted_neurons.cpu().tolist()[-5:]
_, _ = get_loss_patched_mlp_neurons(kde_french, model, english_activations, patch_neurons=worst_neurons, patch_layer=5, ablate_neurons=[609], ablate_layer=3)

  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.60 (+0.37%)


- Patching all neurons leads to the expected increase in loss of 4% (in line with our layer ablation results)
- Patching the top 5 neurons from our analysis should result in a high-ish loss increase but we get 0.18%
- Patching the bottom 5 neurons should result in a very low increase in loss (technically it should be negative looking at our attribution curve) but we get 0.37%

## DLA for affected later neurons

In [12]:
# def get_attention_pattern_activations(original_cache: ActivationCache, ablated_cache: ActivationCache, layer: int):
#     block_name = f'blocks.{layer}.attn.hook_pattern'
#     original_activations = original_cache[block_name]
#     ablated_activations = ablated_cache[block_name]

#     difference = original_activations.mean((0)) - ablated_activations.mean((0))

#     for pattern in difference:
#         imshow(pattern.cpu(), title="Difference between attention patterns with and without French neuron set to mean \"non-French data\" value <br> French dataset <br> Blue means the activation is more prevalent when French neuron enabled")
        

#get_attention_pattern_activations(original_cache, ablated_cache, layer=4)
#get_attention_pattern_activations(original_cache, ablated_cache, layer=5)