In [1]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go
from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
haystack_utils.clean_cache()

In [2]:
model = HookedTransformer.from_pretrained("pythia-70m-v0", fold_ln=True, device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m-v0 into HookedTransformer


In [3]:
kde_french = load_txt_data("kde4_french.txt")
kde_english = load_txt_data("kde4_english.txt")

kde4_french.txt: Loaded 1007 examples with 505 to 5345 characters each.
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.


In [4]:
french_activations = get_mlp_activations(kde_french, 3, model, num_prompts=100, mean=True)
english_activations = get_mlp_activations(kde_english, 3, model, num_prompts=100, mean=True)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [6]:
french_activations[609], english_activations[609]

(tensor(2.4772, device='cuda:0'), tensor(-0.0573, device='cuda:0'))

In [7]:
def evaluate_model(prompts: list[str], model: HookedTransformer, batch_size=16):

    losses = []
    for batch_index in tqdm(range(0, len(prompts), batch_size)):
        if batch_index+batch_size >= len(prompts):
            batch = prompts[batch_index:]
        else:
            batch = prompts[batch_index:batch_index+batch_size]
        tokens = model.to_tokens(batch)
        loss = model(tokens, return_type="loss")
        losses.append(loss.item())
    print(f"Mean loss: {np.mean(losses):.4f}")
    return losses

In [8]:
french_losses = evaluate_model(kde_french, model, batch_size=1)
px.histogram(french_losses)

  0%|          | 0/1007 [00:00<?, ?it/s]

Mean loss: 3.5839


In [9]:
evaluate_model(kde_english, model, batch_size=1)

  0%|          | 0/1007 [00:00<?, ?it/s]

Mean loss: 3.8768


[3.4576416015625,
 3.7964305877685547,
 3.4455814361572266,
 3.656831741333008,
 4.156656265258789,
 4.164270401000977,
 4.104758262634277,
 3.6803717613220215,
 3.7099764347076416,
 3.766713857650757,
 3.441838026046753,
 4.447780609130859,
 4.224307060241699,
 4.4418487548828125,
 3.6823410987854004,
 3.6193292140960693,
 3.58012056350708,
 2.9261581897735596,
 4.331971168518066,
 3.5921971797943115,
 3.3816676139831543,
 4.635315418243408,
 4.198049545288086,
 3.3489811420440674,
 4.332030296325684,
 3.0417544841766357,
 4.1146697998046875,
 3.659585952758789,
 3.8155677318573,
 4.205376625061035,
 3.7922539710998535,
 3.398430585861206,
 3.349548816680908,
 3.7119569778442383,
 4.395923614501953,
 3.5259158611297607,
 3.7977607250213623,
 3.7684326171875,
 3.8227269649505615,
 4.265439987182617,
 4.286070823669434,
 3.8989200592041016,
 4.826107501983643,
 4.1443400382995605,
 4.887277126312256,
 4.597630977630615,
 3.6181068420410156,
 4.283036708831787,
 3.617647409439087,
 3.735

In [10]:
def get_ablated_mlp_difference(prompts: list[str], model: HookedTransformer, neurons: list[int], layer_to_ablate: int, layer_to_cache: int, mean_neuron_activations: Float[Tensor, "d_mlp"]):
    original_losses = []
    ablated_losses = []
    mean_differences = []
    for prompt in tqdm(prompts):
        neurons = torch.LongTensor(neurons)

        def ablate_neuron_hook(value, hook):
            value[:, :, neurons] = mean_neuron_activations[neurons]
            return value
        
        tokens = model.to_tokens(prompt)
        original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")

        with model.hooks(fwd_hooks=[(f'blocks.{layer_to_ablate}.mlp.hook_post', ablate_neuron_hook)]):
            ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")

        original_losses.append(original_loss.item())
        ablated_losses.append(ablated_loss.item())

        block_name = f'blocks.{layer_to_cache}.mlp.hook_post'
        original_activations = original_cache[block_name][:, 1:]
        ablated_activations = ablated_cache[block_name][:, 1:]
        mean_difference = original_activations.mean((0, 1)) - ablated_activations.mean((0, 1))
        mean_differences.append(mean_difference)
        
        
    print(f"Original loss: {np.mean(original_losses):.2f}, ablated loss: {np.mean(ablated_losses):.2f} (+{((np.mean(ablated_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return torch.stack(mean_differences).mean(0)

In [11]:
difference = get_ablated_mlp_difference(kde_french[:10], model, neurons=[609], layer_to_ablate=3, layer_to_cache=5, mean_neuron_activations=english_activations)


  0%|          | 0/10 [00:00<?, ?it/s]

Original loss: 3.49, ablated loss: 3.70 (+5.97%)


In [12]:
px.histogram(difference.cpu().numpy(), title="Difference in neuron activations between original and ablated model", width=800)

In [13]:
# difference = get_mlp_activation_difference(original_cache, ablated_cache, layer=4)
# px.histogram(difference.cpu().numpy(), title="Difference in activations between original and ablated model", width=800)

In [14]:
def imshow(tensor, renderer=None, label_neurons=False, **kwargs):
    preset_kwargs = {
        "color_continuous_midpoint": 0.0,
        "color_continuous_scale": "RdBu"
    }

    fig = px.imshow(utils.to_numpy(tensor), **{**preset_kwargs, **kwargs})
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    if label_neurons:
        fig.update(data=[{'customdata': np.arange(len(tensor.flatten())).reshape(tensor.shape[0], -1), 'hovertemplate': "Difference: %{z:.4f}<br>Neuron: %{customdata}"}])
    fig.show(renderer=renderer)

imshow(difference.view(32, -1), label_neurons=True, title="""Difference in activations between original and ablated model at MLP layer 5 <br> rearranged from a 1D vector into a grid""", width=800)

In [15]:
sorted_differences, sorted_neurons = torch.topk(difference.abs(), len(difference), largest=True)

In [16]:
def line(x, xlabel="", ylabel="", title="", xticks=None, width=800):
    fig = px.line(x, title=title)
    fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel, width=width)
    if xticks != None:
        fig.update_layout(
            xaxis = dict(
            tickmode = 'array',
            tickvals = [i for i in range(len(xticks))],
            ticktext = xticks
            )
        )
    fig.show()

# Plotting all differences seems to break Jupyter
line(sorted_differences.cpu().numpy()[:100], xlabel="Neuron", ylabel="Absolute difference", xticks=sorted_neurons.cpu().tolist()[:100], title="Top absolute neuron differences", width=1400)

# DLA for the French neuron

In [28]:
def DLA(prompts: list[str], model: HookedTransformer, mean_neuron_activations, neurons = [609], layer_to_ablate=3, patched_component=8):
    # TODO think about layer normalization
    original_losses = []
    patched_losses = []
    for prompt in tqdm(prompts):

        neurons = torch.LongTensor(neurons)
        def ablate_neuron_hook(value, hook):
            value[:, :, neurons] = mean_neuron_activations[neurons]
            return value
        
        tokens = model.to_tokens(prompt)
        original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")

        with model.hooks(fwd_hooks=[(f'blocks.{layer_to_ablate}.mlp.hook_post', ablate_neuron_hook)]):
            ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")

        # component, batch, pos, residual
        # TODO figure out if we need layer norm here
        original_per_layer_residual, original_labels = original_cache.decompose_resid(layer=-1, return_labels=True, apply_ln=False)
        ablated_per_layer_residual, ablated_labels = ablated_cache.decompose_resid(layer=-1, return_labels=True, apply_ln=False)

        # embed l0A l0M l1A l1M l2A l2M l3A l3M
        def swap_cache_hook(value, hook):
            # Batch, pos, residual
            value -= original_per_layer_residual[patched_component]
            value += ablated_per_layer_residual[patched_component]
        
        with model.hooks(fwd_hooks=[(f'blocks.5.hook_resid_post', swap_cache_hook)]):
            patched_loss = model(tokens, return_type="loss")

        original_losses.append(original_loss.item())
        patched_losses.append(patched_loss.item())


    print(f"Original loss: {np.mean(original_losses):.2f}, patched loss: {np.mean(patched_losses):.2f} (+{((np.mean(patched_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return np.mean(original_losses), np.mean(patched_losses)

DLA(kde_french, model, english_activations)

  0%|          | 0/1007 [00:00<?, ?it/s]

['embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out']
Original loss: 3.30, patched loss: 3.39 (+2.62%)


(3.303405523300171, 3.3899996280670166)

In [None]:
labels = ['embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out']

In [26]:
for later_component in range(9, 13):#
    print(f"Later component: {later_component}")
    _, _ = DLA(kde_french, model, english_activations, patched_component=later_component)

Later component: 9


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.60 (+0.41%)
Later component: 10


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.63 (+1.30%)
Later component: 11


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.61 (+0.76%)
Later component: 12


  0%|          | 0/1007 [00:00<?, ?it/s]

Original loss: 3.58, patched loss: 3.73 (+4.00%)


In [29]:
total_effect = 5.97
with_ln = 1.5 + 0.29 + 1.29 + 0.73 + 3.87
without_ln = 1.72 + 0.41 + 1.3 + 0.76 + 4
print(with_ln, without_ln, total_effect)

7.68 8.19 5.97


- Total effect of ablating L3N609: 5.97% increase in loss
- Direct effect of ablating L3N609: 1.5% increase in loss
- Added direct effects of all later components and ~8%
- In total we have a loss increase of 6% when ablating the French neuron
- The direct effect of ablating the French neuron is 1.5%
- The French neuron must directly boost relevant words

- Does it make sense that the direct loss attribution of individual components sums to a higher total loss than the total loss of ablating the neuron
- Theory: components make similar mistakes so that ablating all of them leads to fewer loss

## Check individual contributions of L5 neurons

In [5]:
def get_neuron_logit_contribution(cache, model, answer_tokens, layer):
    neuron_directions = cache.get_neuron_results(layer, neuron_slice=utils.Slice(input_slice=None), pos_slice=utils.Slice(input_slice=None))
    neuron_directions = einops.rearrange(neuron_directions, 'batch pos neuron residual -> neuron batch pos residual')
    # Apply ln, get rid of last position
    scaled_neuron_directions = cache.apply_ln_to_stack(neuron_directions)[:, 0, :-1, :] # neuron pos embed
    # Unembed of correct answer tokens
    correct_token_directions = model.W_U[:, answer_tokens].squeeze(1) # embed pos
    # Neuron attribution to correct answer token by position
    unembedded = einops.einsum(scaled_neuron_directions, correct_token_directions, 'neuron pos residual, residual pos -> neuron pos') # neuron pos
    return unembedded

In [6]:
def MLP_attribution(prompts: list[str], model: HookedTransformer, mean_neuron_activations, neurons = [609], layer_to_ablate=3, layer_to_compare=5):
    # TODO think about layer normalization
    with torch.no_grad():
        differences = torch.zeros(model.cfg.d_mlp)
        for prompt in tqdm(prompts):
            tokens = model.to_tokens(prompt)
            original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")
            answer_tokens = tokens[:, 1:]
            
            original_unembedded = get_neuron_logit_contribution(original_cache, model, answer_tokens, layer=layer_to_compare)

            def ablate_neuron_hook(value, hook):
                value[:, :, neurons] = mean_neuron_activations[neurons]
                return value
        
            with model.hooks(fwd_hooks=[(f'blocks.{layer_to_ablate}.mlp.hook_post', ablate_neuron_hook)]):
                ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")

            ablated_unembedded = get_neuron_logit_contribution(ablated_cache, model, answer_tokens, layer=layer_to_compare)

            # Positive diff -> ablated has lower activation on correct token
            difference = (original_unembedded - ablated_unembedded).mean(1).cpu()
            differences += difference
    
    mean_difference = differences / len(prompts)#torch.stack(differences).mean(0)
    print("Total mean activation difference on correct token:", mean_difference.mean().item())
    sorted_differences, sorted_neurons = torch.topk(mean_difference, len(mean_difference), largest=True)
    line(sorted_differences.cpu().numpy()[:30], xlabel="Neuron", ylabel="Absolute difference", xticks=sorted_neurons.cpu().tolist()[:30], title=f"Top neuron correct logit differences on correct tokens on layer {layer_to_compare}", width=1400)
    line(sorted_differences.cpu().numpy()[-30:], xlabel="Neuron", ylabel="Absolute difference", xticks=sorted_neurons.cpu().tolist()[-30:], title=f"Top negative neuron correct logit differences on correct tokens on layer {layer_to_compare}", width=1400)
        
        

haystack_utils.clean_cache()
MLP_attribution(kde_french, model, english_activations)

  0%|          | 0/1007 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 6.18 GiB (GPU 0; 23.69 GiB total capacity; 15.16 GiB already allocated; 544.69 MiB free; 20.50 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
def get_log_probs(logits, tokens):
    # Logits have shape [batch, position, d_vocab]
    # Tokens have shape [batch, position]
    log_probs = logits.log_softmax(dim=-1)
    # Get logprobs the first seq_len-1 predictions (so we can compare them with the actual next tokens)
    log_probs_for_predicted_tokens = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)

    return log_probs_for_predicted_tokens

In [61]:
model.W_U.shape

torch.Size([512, 50304])

In [50]:
2048*5

10240

In [27]:
# neurons @ mlp_out - ln - @ w_u
einsum.einsum('batch pos d_model, d_model d_vocab -> \
        batch pos d_vocab', normalized_resid_final, model.W_U) + model.b_U

7.68 8.19 5.97


7.68

## DLA for affected later neurons

In [12]:
# def get_attention_pattern_activations(original_cache: ActivationCache, ablated_cache: ActivationCache, layer: int):
#     block_name = f'blocks.{layer}.attn.hook_pattern'
#     original_activations = original_cache[block_name]
#     ablated_activations = ablated_cache[block_name]

#     difference = original_activations.mean((0)) - ablated_activations.mean((0))

#     for pattern in difference:
#         imshow(pattern.cpu(), title="Difference between attention patterns with and without French neuron set to mean \"non-French data\" value <br> French dataset <br> Blue means the activation is more prevalent when French neuron enabled")
        

#get_attention_pattern_activations(original_cache, ablated_cache, layer=4)
#get_attention_pattern_activations(original_cache, ablated_cache, layer=5)