## Setup

In [11]:
import torch
import numpy as np
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from typing import List, Tuple
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
import plotly.express as px
from tqdm.auto import tqdm
import plotly.graph_objects as go

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [12]:
haystack_utils.clean_cache()
english_neurons = [(5, 395), (5, 166), (5, 908), (5, 285), (3, 862), (5, 73), (4, 896), (5, 348), (5, 297), (3, 1204)]
german_neurons = [(4, 482), (5, 1039), (5, 407), (5, 1516), (5, 1336), (4, 326), (5, 250), (3, 669)]
french_neurons = [(5, 112), (4, 1080), (5, 1293), (5, 455), (5, 5), (5, 1901), (5, 486), (4, 975)]

model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

english_data = haystack_utils.load_txt_data("kde4_english.txt")
german_data = haystack_utils.load_txt_data("wmt_german_large.txt")

english_activations = {}
german_activations = {}
for layer in range(3, 6):
    english_activations[layer] = get_mlp_activations(english_data[:200], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:200], layer, model, mean=False)


LOG_PROB_THRESHOLD = -7
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
L4_NEURONS_TO_ABLATE = [482, 326]
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

def deactivate_neurons_l4_hook(value, hook):
    value[:, :, L4_NEURONS_TO_ABLATE] = english_activations[4][:, L4_NEURONS_TO_ABLATE].mean(0)
    return value
    
deactivate_l3_l4_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook), (f'blocks.4.mlp.hook_post', deactivate_neurons_l4_hook)]

def activate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value
activate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.
wmt_german_large.txt: Loaded 2459 examples with 800 to 2000 characters each.


100%|██████████| 200/200 [00:02<00:00, 73.58it/s]
100%|██████████| 200/200 [00:02<00:00, 73.28it/s]
100%|██████████| 200/200 [00:02<00:00, 73.60it/s]
100%|██████████| 200/200 [00:02<00:00, 73.14it/s]
100%|██████████| 200/200 [00:02<00:00, 69.83it/s]
100%|██████████| 200/200 [00:02<00:00, 73.43it/s]


In [13]:
def get_pos_loss_diff(prompt: str, model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]], plot_hist=False, use_activate_hook=False, debug_log=True):
    tokens = model.to_tokens(prompt)
    if use_activate_hook:
        original_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=activate_neurons_hook, loss_per_token=True)
    else:
        original_loss = model(tokens, return_type="loss", loss_per_token=True)
    ablated_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=deactivate_neurons_hook, loss_per_token=True)
    
    # Positive difference = loss increase due to ablation
    loss_difference = (ablated_loss - original_loss).flatten()

    if debug_log:
        print(f"Unablated loss: {original_loss.flatten()}")
        print(f"Ablated loss: {ablated_loss.flatten()}")
        print(f"Loss difference: {loss_difference}")

    if plot_hist:
        fig = px.histogram(loss_difference.flatten().cpu().numpy(), title="Loss difference due to ablation per position")
        fig.show()
    return loss_difference

def get_high_loss_prompts(prompts: list[str], model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]]):
    max_diffs = []
    average_diffs = []
    for prompt in tqdm(prompts):
        loss_difference = get_pos_loss_diff(prompt, model, activate_neurons_hook, deactivate_neurons_hook)
        max_diffs.append(loss_difference.max().item())
        average_diffs.append(loss_difference.mean().item())
    return max_diffs, average_diffs


In [14]:
def get_top_differences_at_position(prompt: str, model: HookedTransformer, position: int, top_k=20, mode="full"):
    tokens = model.to_tokens(prompt)
    str_tokens = model.to_str_tokens(tokens)
    # Logprobs instead of logits
    original_logits = model(tokens, return_type="logits")
    if mode=="direct":
        to_freeze = ["blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out"]
        ablated_logits = haystack_utils.get_frozen_logits(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=to_freeze)
    elif mode=="indirect":
        to_freeze = ["blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out"]
        ablated_logits = haystack_utils.get_ablated_logits(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=to_freeze)
    else:
        assert mode=="full"
        ablated_logits = model.run_with_hooks(tokens, return_type="logits", fwd_hooks=deactivate_neurons_fwd_hooks)
    original_logprob = original_logits.log_softmax(dim=-1)
    ablated_logprob = ablated_logits.log_softmax(dim=-1)

    # Positive difference = the German neuron makes the token more likely
    # Negative difference = the German neuron makes the token less likely
    logprob_differences = original_logprob - ablated_logprob
    logit_differences = original_logits - ablated_logits

    print("Prompt:", prompt)
    print(f"Differences for predicting: {str_tokens[position]} -> {str_tokens[position+1]}")

    low_log_prob = torch.argwhere(((original_logprob[0, position, :] <= LOG_PROB_THRESHOLD) & (ablated_logprob[0, position, :] <= LOG_PROB_THRESHOLD))).flatten()
    ignore_tokens = torch.cat([low_log_prob, all_ignore]).unique()
    
    top_original_logprobs, top_original_idx = haystack_utils.top_k_with_exclude(original_logprob[0, position, :].flatten(), top_k, exclude=ignore_tokens)
    top_original_ablated_logprobs = ablated_logprob[0, position, top_original_idx]
    top_ablated_logprobs, top_ablated_idx = haystack_utils.top_k_with_exclude(ablated_logprob[0, position, :].flatten(), top_k, exclude=ignore_tokens)
    top_ablated_original_logprobs = original_logprob[0, position, top_ablated_idx]

    top_original_values = logprob_differences[0, position, top_original_idx]
    top_ablated_values = logprob_differences[0, position, top_ablated_idx]
    top_original_logit_diff = logit_differences[0, position, top_original_idx]
    top_ablated_logit_diff = logit_differences[0, position, top_ablated_idx]
    print("Top predictions with German neuron active (unablated)")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_original_idx), top_original_values.cpu().tolist(), max_value=5, original_log_probs=top_original_logprobs.cpu().tolist(), ablated_log_probs=top_original_ablated_logprobs.cpu().tolist(), logit_difference=top_original_logit_diff.cpu().tolist())
    print("Top predictions with German neuron disabled (ablated)")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_ablated_idx), top_ablated_values.cpu().tolist(), max_value=5, original_log_probs=top_ablated_original_logprobs.cpu().tolist(), ablated_log_probs=top_ablated_logprobs.cpu().tolist(), logit_difference=top_ablated_logit_diff.cpu().tolist())

    top_boosts, top_boosted_idx = haystack_utils.top_k_with_exclude(logprob_differences[:, position, :].flatten(), top_k, exclude=ignore_tokens)
    top_boost_original_logprob = original_logprob[0, position, top_boosted_idx]
    top_boost_ablated_logprob = ablated_logprob[0, position, top_boosted_idx]
    top_reduced, top_reduced_idx = haystack_utils.top_k_with_exclude(logprob_differences[:, position, :].flatten(), top_k, largest=False, exclude=ignore_tokens)
    top_reduced_original_logprob = original_logprob[0, position, top_reduced_idx]
    top_reduced_ablated_logprob = ablated_logprob[0, position, top_reduced_idx]
    print("Top boosted tokens by German neuron")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_boosted_idx), top_boosts.cpu().tolist(), max_value=5, original_log_probs=top_boost_original_logprob.cpu().tolist(), ablated_log_probs=top_boost_ablated_logprob.cpu().tolist())
    print("Top reduced tokens by German neuron")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_reduced_idx), top_reduced.cpu().tolist(), max_value=5, original_log_probs=top_reduced_original_logprob.cpu().tolist(), ablated_log_probs=top_reduced_ablated_logprob.cpu().tolist())

In [15]:
def show_token_loss(prompt: str, model: HookedTransformer, max_value=None, mode="full", freeze_act_names=("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out")):
    
    original_loss, total_effect_loss_change, direct_effect_loss_change, indirect_effect_loss_change = haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names, debug_log=False)
    if mode== "indirect":
        pos_wise_loss = indirect_effect_loss_change
        #pos_wise_loss = haystack_utils.get_frozen_loss_difference_measure(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names)
    elif mode == "direct":
        pos_wise_loss = direct_effect_loss_change
        #pos_wise_loss = haystack_utils.get_ablated_loss_difference_measure(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names)
    else:
        assert mode =="full"
        pos_wise_loss = total_effect_loss_change
        #pos_wise_loss = get_pos_loss_diff(prompt, model, activate_neurons_fwd_hooks, deactivate_neurons_fwd_hooks, plot_hist=False)
    str_token_prompt = model.to_str_tokens(model.to_tokens(prompt))
    haystack_utils.print_strings_as_html(str_token_prompt[1:], pos_wise_loss.flatten().cpu().tolist(), max_value=max_value)

def print_predictions(prompt, pos, k=20):
    print("\nFull model predictions")
    get_top_differences_at_position(prompt, model, pos, k, mode="full")
    print("\nIndirect predictions (leave German neuron active, patch corrupted activations to later components)")
    get_top_differences_at_position(prompt, model, pos, k, mode="indirect")
    print("\nDirect predictions (ablate German neuron, patch clean activations to later components)")
    get_top_differences_at_position(prompt, model, pos, k, mode="direct")

In [16]:
def get_mlp5_attribution_without_mlp4(prompt, pos = -1):
    # Freeze everything except for MLP5 to see if MLP5 depends on MLP4
    
    freeze_act_names=("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out")
    original_loss, total_effect_loss, direct_mlp3_mlp5_loss, _= haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names, debug_log=False, return_absolute=True)
    freeze_act_names=("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out")
    _, _, direct_mlp3_loss, _ = haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names, debug_log=False, return_absolute=True)
    return original_loss[0, pos].item(), total_effect_loss[0, pos].item(), direct_mlp3_mlp5_loss[0, pos].item(), direct_mlp3_loss[0, pos].item()

def compare_activations(prompt: str, model: HookedTransformer, layer=5, pos=-1):
    tokens = model.to_tokens(prompt)
    
    with model.hooks(fwd_hooks=activate_neurons_fwd_hooks):
        original_logits, original_cache = model.run_with_cache(tokens)

    with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
        ablated_logits, ablated_cache = model.run_with_cache(tokens)
    
    act_label = f"blocks.{layer}.mlp.hook_pre"
    
    original_activation, ablated_activation = original_cache[act_label][0, pos-1, :], ablated_cache[act_label][0, pos-1, :]
    activation_difference = original_activation - ablated_activation
    #activation_difference = einops.rearrange(activation_difference, "batch pos d_mlp -> (batch pos) d_mlp")
    return activation_difference, original_activation, ablated_activation

def get_neuron_logit_contribution(cache: ActivationCache, model: HookedTransformer, answer_tokens: Int[Tensor, "batch pos"], layer: int, pos:int) -> Float[Tensor, "neuron pos"]:
    # Expects cache from a single example, won't work on batched examples
    # Get per neuron output of MLP layer
    neuron_directions = cache.get_neuron_results(layer, neuron_slice=utils.Slice(input_slice=None), pos_slice=utils.Slice(input_slice=None))
    neuron_directions = einops.rearrange(neuron_directions, 'batch pos neuron residual -> neuron batch pos residual')
    # We need to apply the final layer norm because the unembed operation is applied after the final layer norm, so the answer token
    # directions are in the same space as the final layer norm output
    # LN leads to finding top tokens with slightly higher loss attribution
    scaled_neuron_directions = cache.apply_ln_to_stack(neuron_directions)[:, 0, pos-1, :] # [neuron embed]
    # Unembed of correct answer tokens
    correct_token_directions = model.W_U[:, answer_tokens].squeeze(1) # [embed] 
    # Neuron attribution to correct answer token by position
    unembedded = einops.einsum(scaled_neuron_directions, correct_token_directions, 'neuron residual, residual -> neuron')
    return unembedded

def MLP_attribution(prompt: str, model: HookedTransformer, fwd_hooks, layer_to_compare=5, pos=-1):
    
    tokens = model.to_tokens(prompt)
    answer_tokens = tokens[:, pos]
    # Get difference between ablated and unablated neurons' contribution to answer logit
    _, _, original_cache, ablated_cache = haystack_utils.get_caches_single_prompt(
        prompt, model, fwd_hooks)
    original_unembedded = get_neuron_logit_contribution(original_cache, model, answer_tokens, layer=layer_to_compare, pos=pos) # [neuron]
    ablated_unembedded = get_neuron_logit_contribution(ablated_cache, model, answer_tokens, layer=layer_to_compare, pos=pos)
    differences = (original_unembedded - ablated_unembedded).detach().cpu() # [neuron]
    return differences


def get_neuron_loss_attribution(prompt, model, neurons, pos=-1):
    original_loss, original_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)
    with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
        ablated_loss, ablated_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

    # Remove the effects of ablating at MLP3 from the components after MLP3
    def freeze_neurons_hook(value, hook: HookPoint):
        #print("Freeze", neurons, ablated_cache[hook.name][0, pos-1, neurons], "=", value[0, pos-1, neurons], "to", original_cache[hook.name][0, pos-1, neurons])
        value[:, :, neurons] = original_cache[hook.name][:, :, neurons] # [batch pos neuron
        return value      

    freeze_original_hooks = [("blocks.5.mlp.hook_post", freeze_neurons_hook)]
    with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks+freeze_original_hooks):
        ablated_with_original_frozen_loss = model(prompt, return_type="loss", loss_per_token=True)
    #print(ablated_loss[0, :], ablated_with_original_frozen_loss[0, :])
    return original_loss[0, pos].item(), ablated_loss[0, pos].item(), ablated_with_original_frozen_loss[0, pos].item()

In [17]:
def top_mlp_effect_on_prompts(prompts: list[str], k = 20, log=False, top_neurons=None, pos=-1):
    
    top_occurences = torch.zeros(model.cfg.d_vocab)
    percent_explained_per_prompt = []
    original_losses = []
    total_effect_losses = []
    direct_mlp3_mlp5_losses = []
    direct_mlp3_losses = []
    frozen_losses = []

    if (top_neurons is not None) and (len(top_neurons) < k):
            print(f"Warning: Only {len(top_neurons)} neurons given for k={k}.")

    for prompt in prompts:
        original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss = get_mlp5_attribution_without_mlp4(prompt, pos=pos)
        if top_neurons is None:
            differences = MLP_attribution(prompt, model, fwd_hooks=deactivate_neurons_fwd_hooks, layer_to_compare=5, pos=pos)
            top_diff, top_diff_neurons = torch.topk(differences, k, largest=True)
            top_occurences[top_diff_neurons] += 1
        else:
            top_diff_neurons = torch.LongTensor(top_neurons)

        
        
        _, _, frozen_loss = get_neuron_loss_attribution(prompt, model, top_diff_neurons[:k], pos=pos)
        
        ablation_loss_increase = total_effect_loss - original_loss
        frozen_loss_decrease = total_effect_loss - frozen_loss
        percent_explained_by_mlp5 = frozen_loss_decrease / ablation_loss_increase
        percent_explained_per_prompt.append(percent_explained_by_mlp5)
        original_losses.append(original_loss)
        total_effect_losses.append(total_effect_loss)
        direct_mlp3_mlp5_losses.append(direct_mlp3_mlp5_loss)
        direct_mlp3_losses.append(direct_mlp3_loss)
        frozen_losses.append(frozen_loss)

        if log:
            print(f"\n{prompt}")
            print("MLP 5 attribution")
            print(f"Original loss: {original_loss:.4f}")
            print(f"Total effect loss: {total_effect_loss:.4f}")#
            print(f"Direct effect loss of MLP3 and MLP5 (restoring MLP4 and attention): {direct_mlp3_mlp5_loss:.4f}")
            print(f"Direct effect loss of MLP3 (restoring MLP4 and MLP5 and attention): {direct_mlp3_loss:.4f}")
            print(f"Total effect loss when freezing top MLP5 neurons: {frozen_loss:.4f}")
    
    return percent_explained_per_prompt, original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses, top_occurences/len(prompts)

## Look for interesting examples

In [18]:
def get_interesting_loss_prompts(prompts: list[str], model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]]):
    max_diffs = []
    average_diffs = []
    for prompt in tqdm(prompts):
        original_loss, total_effect_loss_change, direct_loss_change, indirect_loss_change = haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, debug_log=False)
        #indirect_loss = haystack_utils.get_frozen_loss_difference_measure(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks)
        #direct_loss = haystack_utils.get_ablated_loss_difference_measure(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks)
        loss_difference = abs(direct_loss_change - indirect_loss_change)
        loss_difference[indirect_loss_change < 1] = 0
        loss_difference[direct_loss_change > 1] = 0
        max_diffs.append(loss_difference.max().item())
        average_diffs.append(loss_difference.mean().item())
    return max_diffs, average_diffs

n_examples = 500
max_diffs, average_diffs = get_interesting_loss_prompts(german_data, model, activate_neurons_fwd_hooks, deactivate_neurons_fwd_hooks)

loss_data_tuple = [(diff, example) for diff, example in zip(max_diffs, german_data[:n_examples])]
loss_data_tuple.sort(key=lambda x: x[0], reverse=True)
loss_data_tuple[:2]

  0%|          | 4/2459 [00:00<02:06, 19.42it/s]

100%|██████████| 2459/2459 [02:10<00:00, 18.78it/s]


[(8.79670238494873,
  'Änderungsantrag 16, Änderungsanträge 14, 15, 26, 29, 30 und 75 können wir nicht akzeptieren - aus den bereits genannten Gründen, d. h. sie sind aufgrund positiver Entwicklungen im Rat zu diesen Fragen überflüssig; Änderungsanträge 53, 55 und 76 zu den Vertragsbediensteten - weil die Kommission, wie ich bereits ausführte, Änderungsanträge 52 und 54 des Ausschusses für Recht und Binnenmarkt unterstützt; Änderungsantrag 32 zur Anwendung des Artikels 50 auf die Besoldungsgruppe AD 12 - wir halten dies für unpassend; Änderungsantrag 63 zur sexuellen Belästigung, da die Kommission den von den Berichterstattern in Änderungsantrag 17 vorgelegten Text unterstützt, und schließlich Änderungsantrag 64 zur Auslandszulage, weil die Kommission an der Auffassung festhält, dass die Zulage für im Ausland tätige Bedienstete objektiv gerechtfertigt und sogar objektiv notwendig ist.'),
 (8.11987018585205,
  'Meines Erachtens kann die Europäische Union, unter Einhaltung des Subsidiari

In [19]:
def show_all_loss_types(prompt):
    print("Full model effect on loss")
    show_token_loss(prompt, model, max_value=5, mode="full")
    print("Indirect effect on loss")
    show_token_loss(prompt, model, max_value=5, mode="indirect")
    print("Direct effect on loss")
    show_token_loss(prompt, model, max_value=5, mode="direct")

In [20]:
for _ , prompt in loss_data_tuple[0:2]:
    print("")
    show_token_loss(prompt, model, max_value=5, mode="full")
    show_token_loss(prompt, model, max_value=5, mode="indirect")
    show_token_loss(prompt, model, max_value=5, mode="direct")







## Ansicht

In [21]:
prompt = "Ich möchte nochmals meine Ansicht"
# Check loss MLP5 loss increase when patching clean activations to MLP4
show_all_loss_types(prompt)
get_mlp5_attribution_without_mlp4(prompt)

Full model effect on loss


Indirect effect on loss


Direct effect on loss


(4.715382099151611, 12.902793884277344, 7.96336555480957, 3.319126605987549)

In [22]:
#prompts =  ["Obwohl er gute Argumente eingebracht hat bin ich der Ansicht"]
prompts = ["Neben anderen teile ich diese Ansicht"]
prompts = ["Erstens hat er im Gegensatz zum Rat die Richtlinie auf zivilrechtliche Klagen aufgrund von strafbaren Handlungen erweitert, denn unserer Ansicht", 
           "Modalitäten durchzuführen: einen Europäischen Rat, um gut vorbereitete Beschlüsse zu fassen, mit einer klaren Agenda, in dem die Beschlüsse meiner Ansicht"]

In [23]:
prompts = ["Die eingebrachten Gesetzentwürfe sind wichtig, denn seiner Ansicht", 
           "Die eingebrachten Gesetzentwürfe sind wichtig, denn unserer Ansicht", 
           "Die eingebrachten Gesetzentwürfe sind wichtig, denn ihrer Ansicht", 
           "Die eingebrachten Gesetzentwürfe sind wichtig, denn deiner Ansicht", 
           "Die eingebrachten Gesetzentwürfe sind wichtig, denn meiner Ansicht", 
           "Die eingebrachten Gesetzentwürfe sind wichtig, denn eurer Ansicht"]

In [49]:
percent_explained_per_prompt, original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses, top_occurences = top_mlp_effect_on_prompts(prompts, log=False)
data = [original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top 20 MLP5 neurons)"]

haystack_utils.plot_barplot(data, names, title="Loss increase from disabling German context neuron and patching in clean downstream components", width=1400)

## Vorschläge

In [25]:
prompt = "zu den Vorschlägen"
top_mlp5_tokens = torch.LongTensor([838, 1026, 1709,  822, 1716,  905,  925,   84, 1414, 1506, 1227,  216, 852, 1765, 1456,  785,  959, 1043, 1514, 1751]) # Generated on "zu den Vorschlägen"
top_common_neurons = [1026, 1414,  905, 1709,  822, 1506,  838, 84, 1765,  216, 925] # Generated across all prompts
top_common_neurons = [ 822,  905,  838, 1414, 1709, 1026, 1506,  925, 1765,  216,   84,  959, 1716,  545,  509] # Top 15 across all prompts

In [26]:
prompts1 = ["Ich habe noch einige Fragen zu Ihren Vorschlägen",
    "Ich stimme den Vorschlägen",
    "Kannst du bitte mit deinen Vorschlägen",
    "Laut den Vorschlägen",
    "Die Diskussion wurde mit vielen interessanten Vorschlägen",
    "Sie schrieb einen Brief mit ihren Vorschlägen",
    "Wir werden nach deinen Vorschlägen",
    "Sind Sie mit diesen Vorschlägen",
    "Gemäß den Vorschlägen",
    "Der Kunde war nicht einverstanden mit unseren Vorschlägen",
    "Sie zeigte uns ein Dokument mit ihren Vorschlägen",
    "Das Team arbeitet an den Vorschlägen",
    "Meinen Vorschlägen"]

prompts2 = [
    "Der Ausschuss wird nach den Vorschlägen",
    "Sie war unzufrieden mit den Vorschlägen",
    "Unsere Agentur kam mit neuen Vorschlägen",
    "Haben Sie Änderungen zu den Vorschlägen",
    "Gemäß Ihren Vorschlägen",
    "Ich schrieb einen Bericht mit meinen Vorschlägen",
    "Nach den Vorschlägen",
    "Ich werde gemäß Ihren Vorschlägen",
    "Der Manager war unzufrieden mit den Vorschlägen",
    "Mit diesen Vorschlägen",
    "Der Ausschuss stimmte den Vorschlägen",
    "Der Leiter war sehr zufrieden mit den Vorschlägen",
    "Nach den aktuellen Vorschlägen",
    "Mit ihren innovativen Vorschlägen",
    "Mit einigen Verbesserungen zu Ihren Vorschlägen",
    "Das Team kam mit neuen Vorschlägen",
    "Sie kam mit großartigen Vorschlägen",
    "Ich werde mit meinen Vorschlägen",
    "Sie zeigte sich zufrieden mit den Vorschlägen",
    "Der Direktor war beeindruckt von den Vorschlägen",
    "Die Organisation hat nach Ihren Vorschlägen",
    "Wir haben ein Dokument mit den Vorschlägen",
    "Mit einigen Vorschlägen",
    "Der Lehrer war sehr zufrieden mit den Vorschlägen",
    "Nach Ihren Vorschlägen",
    "In Übereinstimmung mit den Vorschlägen",
    "Sie stimmte den Vorschlägen",
    "Mit diesen neuen Vorschlägen",
    "Ich bin sehr beeindruckt von Ihren Vorschlägen",
    "Das Unternehmen hat nach unseren Vorschlägen",
    "Die Jury war beeindruckt von den Vorschlägen",
    "Die Verwaltung hat nach Ihren Vorschlägen",
    "Das Publikum war sehr zufrieden mit den Vorschlägen",
]

prompts = prompts1 + prompts2

In [47]:
# Check loss MLP5 loss increase when patching clean activations to MLP4
print("Breakdown for prompt:", prompts1[0][-20:])
show_all_loss_types(prompts1[0][-20:])
get_mlp5_attribution_without_mlp4(prompts1[0][-20:])

Breakdown for prompt: zu Ihren Vorschlägen
Full model effect on loss


Indirect effect on loss


Direct effect on loss


(5.763085842132568, 8.876919746398926, 8.932577133178711, 5.141571044921875)

In [28]:
percent_explained_per_prompt, original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses, top_occurences = top_mlp_effect_on_prompts(prompts, log=False)
data = [original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]

haystack_utils.plot_barplot(data, names, width=1400)

In [29]:
percent_explained_per_prompt, original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses, _ = top_mlp_effect_on_prompts(prompts, log=False, top_neurons=top_common_neurons)
data = [original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]

haystack_utils.plot_barplot(data, names, width=1400)




In [30]:
torch.topk(top_occurences, 15, largest=True)

torch.return_types.topk(
values=tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9783, 0.9565, 0.9348, 0.9130,
        0.9130, 0.9130, 0.8043, 0.7826, 0.6087, 0.5217]),
indices=tensor([ 822,  905,  838, 1414, 1709, 1026, 1506,  925, 1765,  216,   84,  959,
        1716,  545,  509]))

In [31]:
top_common_neurons[:5]

[822, 905, 838, 1414, 1709]

## Neuron 822

In [32]:
original_activations = []
ablated_activations = []
for prompt in prompts:
    original_loss, ablated_loss, original_cache, ablated_cache = haystack_utils.get_caches_single_prompt(prompt, model, fwd_hooks=deactivate_neurons_fwd_hooks)
    neurons = torch.LongTensor(top_common_neurons)
    original_activations.append(original_cache["blocks.5.mlp.hook_post"][0, -2, neurons])
    ablated_activations.append(ablated_cache["blocks.5.mlp.hook_post"][0, -2, neurons])

original_activations = torch.stack(original_activations)
ablated_activations = torch.stack(ablated_activations) 

- Neurons that get activated and have a positive effect on loss: 84, 1709, 1414, 905
- Neuron that get deactivated to not do the bad thing anymore: 822
- Neurons that get reduced to do a bad thing less: 838, 1026

In [33]:
neuron_index = 1
neuron = top_common_neurons[neuron_index]
#assert neuron == 822
mean_neuron_activation_original = original_activations[:, neuron_index].mean()
mean_neuron_activation_ablated = ablated_activations[:, neuron_index].mean()
print(neuron, mean_neuron_activation_original, mean_neuron_activation_ablated)
haystack_utils.two_histogram(original_activations[:, neuron_index], ablated_activations[:, neuron_index], "Original activation", "Ablated activation", n_bins=1000)

905 tensor(1.1434, device='cuda:0') tensor(0.6177, device='cuda:0')


In [34]:
answer_token = model.to_single_token(["gen"])
neuron_direction = model.W_out[5, neuron]

neuron_effect_original = neuron_direction * mean_neuron_activation_original
neuron_effect_ablated = neuron_direction * mean_neuron_activation_ablated
unembed_dir = model.W_U[:, answer_token]

print("Correct answer token on original activation")
print(neuron_effect_original @ unembed_dir)
print("Correct answer token on ablated activation")
print(neuron_effect_ablated @ unembed_dir)

Correct answer token on original activation
tensor(1.0301, device='cuda:0')
Correct answer token on ablated activation
tensor(0.5565, device='cuda:0')


In [35]:
neuron_input = model.W_in[5, :, neuron]
context_output = model.W_out[LAYER_TO_ABLATE, NEURONS_TO_ABLATE, :].flatten()
print("Cosine similarity between neuron input and context output")
cos = torch.nn.CosineSimilarity(dim=0)
print(cos(neuron_input, context_output))

Cosine similarity between neuron input and context output
tensor(0.0859, device='cuda:0')


In [36]:
neuron_wise_sim = []
for other_neuron in range(model.cfg.d_mlp):
    neuron_wise_sim.append(cos(model.W_in[5, :, other_neuron], model.W_out[LAYER_TO_ABLATE, NEURONS_TO_ABLATE, :].flatten()).item())

px.histogram(neuron_wise_sim, title="Cosine similarity between neuron input and context output", width=800)

In [37]:
percent_explained_per_prompt, original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses, _ = top_mlp_effect_on_prompts(prompts, log=False, top_neurons=[neuron])
data = [original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]

haystack_utils.plot_barplot(data, names, width=1400)



- 822 has a highly negative cosine sim with the context neuron (outlier in the general histogram)
- 822 has it's activations shifted into the negative range by the context neuron
- Average negative 822 activation boosts the correct token directly

## Other

In [38]:
def get_loss_contribution_per_neuron(prompt, model, neurons, pos=-1):
    loss_per_neuron = []
    for neuron in neurons:
        original_loss, original_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)
        with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
            ablated_loss, ablated_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

        # # Add the effects of ablating at MLP3 to the components after MLP3
        # def freeze_ablated_hook(value, hook: HookPoint):
        #     value = ablated_cache[hook.name]
        #     return value             
        # freeze_ablated_hooks = [(freeze_act_name, freeze_ablated_hook) for freeze_act_name in freeze_act_names]
        # with model.hooks(fwd_hooks=freeze_ablated_hooks):
        #     original_with_frozen_ablated = model(prompt, return_type="loss", loss_per_token=True)

        # Remove the effects of ablating at MLP3 from the components after MLP3
        def freeze_neurons_hook(value, hook: HookPoint):
            value[:, :, neuron] = original_cache[hook.name][:, :, neuron] # [batch pos neuron
            return value      


        freeze_original_hooks = [("blocks.5.mlp.hook_post", freeze_neurons_hook)]
        with model.hooks(fwd_hooks=freeze_original_hooks+deactivate_neurons_fwd_hooks):
            ablated_with_original_frozen_loss = model(prompt, return_type="loss", loss_per_token=True)
        
        loss_per_neuron.append(ablated_with_original_frozen_loss[0, -1].item())

    haystack_utils.line(loss_per_neuron, xlabel="Neuron", ylabel="Total loss of restoring the neuron", width=1000, xticks=top_common_neurons)

get_loss_contribution_per_neuron(prompts[0], model, top_common_neurons)

### MLP3+4

In [39]:
# all losses are positionwise. All losses are increases in loss from ablating German context neuron for different paths.
def mlp_34_metric(ablated_loss, MLP3_4_restored_loss, MLP3_5_restored_loss, MLP3_4_5_restored_loss):
    # loss with 3 and 4 restored should be less than loss with 3 restored or loss with 4 restored. Moreover, the difference between it and the fully ablated loss
    # should be larger than the sum difference of the differences with just 3 or 4 restored

    # looking for prompts where (MLP3_restored_loss - MLP4_5_restored_loss) > (MLP3_restored_loss - MLP4_restored_loss) + (MLP3_restored_loss - MLP5_restored_loss)
    # == effect of MLP4+5 German circuit > effect of MLP4 German circuits + effect of MLP5 German circuits
    return (ablated_loss - MLP3_4_5_restored_loss) - ((ablated_loss - MLP3_4_restored_loss) + (ablated_loss - MLP3_5_restored_loss))

In [40]:
haystack_utils.clean_cache()

In [41]:
def get_interesting_loss_prompts(prompts: list[str], model: HookedTransformer):
    max_mlp_34 = []
    for prompt in tqdm(prompts):
        original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss, direct_mlp3_mlp4_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=100, log=False, answer_pos=None)
        mlp_34_metric_item = mlp_34_metric(total_effect_loss, direct_mlp3_mlp4_loss, direct_mlp3_mlp5_loss, original_loss)
        max_mlp_34.append(mlp_34_metric_item.max().item())
    return max_mlp_34

max_mlp_34 = get_interesting_loss_prompts(german_data, model)

  0%|          | 0/2459 [00:00<?, ?it/s]

 55%|█████▌    | 1364/2459 [05:38<04:31,  4.03it/s]


KeyboardInterrupt: 

In [None]:
# Filter for examples with high difference in top MLP5 neurons and MLP4+MLP5 - these are the examples where patching some MLP5 is not enough (so maybe they rely on MLP4)
loss_data_tuple = [(diff, example) for diff, example in zip(max_mlp_34, german_data[:len(max_mlp_34)])]
loss_data_tuple.sort(key=lambda x: x[0], reverse=True)
loss_data_tuple[:2]

[(38.321533203125,
  'Netizens beschwerte sich gestern auf dem Gebäude gegenüber von Taiwan Cement, North Zhongshan Road, die Polizei habe nicht davon ausgehen, nicht nur der die Schlange lange Klinge, sondern auch, um mit Gewalt nehmen die Hände der Menschen, die nationale Flagge, Snow Lion Fahne und Banner der Vereinten Nationen hat eine Fraktur Netizen Finger, die Lake Branch und zur Festnahme von drei Freunden, die sich mit Gewalt, um die Sicherheit Battalion, wenn der Anwalt kam bei der Szene, die Polizei sprach persönlich Netizens nicht als Beleidigung, nur wissen wollen, die Namen von Freunden, und dann die Großen Rechtsanwalt kam bei der Szene, Personal Polizei nach wie vor darauf, sie würden eher das Gesetz nicht erlaubt, um die medizinische Behandlung, und dann hatte eine Entschuldigung zu "schützen" Benutzer, die alle den Weg zum Parlament durch den Tracking-Klinik und Krankenhaus, die darauf bestehen, sie können nicht frei gehen Netizen Netizen Beleidigende, die angebliche 

In [None]:
# The effect here is not consistent between prompts

prompt = 'Netizens beschwerte sich gestern auf dem Gebäude gegenüber von Taiwan Cement, North Zhongshan Road, die Polizei habe nicht davon ausgehen, nicht nur der die Schlange lange Klinge, sondern auch, um mit Gewalt nehmen die Hände der Menschen, die nationale Flagge, Snow Lion Fahne und Banner der Vereinten Nationen hat eine Fraktur Netizen Finger, die Lake Branch und zur Festnahme von drei Freunden, die sich mit Gewalt, um die Sicherheit Battalion, wenn der Anwalt kam bei der Szene, die Polizei sprach persönlich Netizens nicht als Beleidigung, nur wissen wollen, die Namen von Freunden, und dann die Großen Rechtsanwalt kam bei der Szene, Personal Polizei nach wie vor darauf, sie würden eher das Gesetz nicht erlaubt, um die medizinische Behandlung, und dann hatte eine Entschuldigung zu "schützen" Benutzer, die alle den Weg zum Parlament durch den Tracking-Klinik und Krankenhaus, die darauf bestehen, sie können nicht frei gehen Netizen Netizen Beleidigende, die angebliche Vorurteile User Tatsu Freiheit 3 Stunden, bis洪建益,张茂楠, Jane余晏usw.'

# ia zu
print(model.tokenizer.decode(571))
print(model.tokenizer.decode(10736))
# prompts = [
#     "Ich bin dabei, eine Studie über das Bakterium zu",
#     "Sie plant, eine Expedition in die Antarktis zu",
#     "Er hat eine Dokumentation über die Galaxia zu",
#     "Sie wollen Informationen über die Anämie zu",
#     "Wir müssen eine Präsentation über die Phobie zu",
#     "Es ist schwer, eine Diagnose zur Dyslexie zu",
#     "Ich bin gerade dabei, eine Abhandlung zur Hypothermie zu",
#     "Er plant eine Reise zur Kolonie in Britannia zu",
#     "Wir denken darüber nach, eine Forschung zur Euphorie zu",
#     "Sie ist dabei, ein Projekt zur Dysgraphie zu",
#     "Er überlegt, eine Therapie gegen die Amnesie zu",
#     "Ich habe vor, eine Konferenz zum Thema Dysphoria zu"
# ]
# german_sentences = [
#     "Ich plane, eine Reise nach Santa Giulia zu",
#     "Sie wollen Informationen über die Geschichte von Santa Giulia zu",
#     "Er ist dabei, eine Dokumentation über Santa Giulia zu",
#     "Wir überlegen, eine Exkursion nach Santa Giulia zu",
#     "Sie denkt darüber nach, ein Forschungsprojekt in Santa Giulia zu",
#     "Ich habe vor, eine Veranstaltung in Santa Giulia zu",
#     "Er plant, seinen nächsten Urlaub in Santa Giulia zu",
#     "Sie träumt davon, eines Tages nach Santa Giulia zu",
#     "Wir bereiten gerade eine Präsentation über Santa Giulia zu",
#     "Ich schreibe gerade einen Bericht über Santa Giulia zu"
# ]
# giulia = [
#     "Ich plane, eine Reise nach Giulia zu",
#     "Sie wollen Informationen über die Geschichte von Giulia zu",
#     "Er ist dabei, eine Dokumentation über Giulia zu",
#     "Wir überlegen, eine Exkursion nach Giulia zu",
#     "Sie denkt darüber nach, ein Forschungsprojekt in Giulia zu",
#     "Ich habe vor, eine Veranstaltung in Giulia zu",
#     "Er plant, seinen nächsten Urlaub in Giulia zu",
#     "Sie träumt davon, eines Tages nach Giulia zu",
#     "Wir bereiten gerade eine Präsentation über Giulia zu",
#     "Ich schreibe gerade einen Bericht über Giulia zu"
# ]


pos=-1

print("")
show_token_loss(prompt, model, max_value=5, mode="full")
show_token_loss(prompt, model, max_value=5, mode="indirect")
show_token_loss(prompt, model, max_value=5, mode="direct")

k = 100
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]

# Check loss MLP5 loss increase when patching clean activations to MLP4
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss, _ = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=pos)
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]

haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")
# for prompt in giulia:
#     original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=pos)
#     list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
#     names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
#     haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

ia
 zu



NameError: name 'show_token_loss' is not defined