## Setup

In [1]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from typing import List, Tuple
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Inference-only notebook: switch autograd off globally.
# NOTE(review): the two calls below look redundant with each other; one should suffice - confirm before removing.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# (layer, neuron) index pairs, grouped by language.
# NOTE(review): these three lists are not referenced again in the visible part of the notebook.
english_neurons = [(5, 395), (5, 166), (5, 908), (5, 285), (3, 862), (5, 73), (4, 896), (5, 348), (5, 297), (3, 1204)]
german_neurons = [(4, 482), (5, 1039), (5, 407), (5, 1516), (5, 1336), (4, 326), (5, 250), (3, 669)]
french_neurons = [(5, 112), (4, 1080), (5, 1293), (5, 455), (5, 5), (5, 1901), (5, 486), (4, 975)]

# Load Pythia-70m through TransformerLens with the weight-processing options below.
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

# Text prompts; paths are relative to the notebook's working directory.
english_data = haystack_utils.load_txt_data("kde4_english.txt")
german_data = haystack_utils.load_txt_data("wmt_german_large.txt")

# MLP activations for layers 3, 4 and 5 on the first 200 prompts of each dataset, keyed by layer index.
english_activations = {}
german_activations = {}
for layer in range(3, 6):
    english_activations[layer] = get_mlp_activations(english_data[:200], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:200], layer, model, mean=False)

# Tokens whose log-probability is <= this value in both the clean and the ablated run
# are excluded from the top-k tables in get_top_differences_at_position.
LOG_PROB_THRESHOLD = -7
# The neuron being ablated: layer 3, neuron 669 (the (3, 669) entry of german_neurons above).
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
# Mean activation of the selected neuron(s) on German vs. English text; these are the values
# written by the activate/deactivate hooks defined below.
# assumes the activations are indexed (token, neuron) - TODO confirm against get_mlp_activations
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    """Forward hook: overwrite NEURONS_TO_ABLATE with MEAN_ACTIVATION_INACTIVE (mean-ablation).

    The assignment mutates `value` in place and the same tensor is returned.
    `hook` is not used.
    """
    # The indexing requires `value` to have three dimensions, with the neuron index last.
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
# (hook name, hook function) pairs in the form passed as `fwd_hooks` elsewhere in this notebook.
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

    
def activate_neurons_hook(value, hook):
    """Forward hook: overwrite NEURONS_TO_ABLATE with MEAN_ACTIVATION_ACTIVE.

    Counterpart of deactivate_neurons_hook; mutates `value` in place and
    returns it. `hook` is not used.
    """
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value
# (hook name, hook function) pairs in the form passed as `fwd_hooks` elsewhere in this notebook.
activate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

# `all_ignore` is later concatenated into the set of token ids excluded from top-k tables
# (see get_top_differences_at_position); `not_ignore` is not used in the visible cells.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.
wmt_german_large.txt: Loaded 2459 examples with 800 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [3]:
def get_pos_loss_diff(prompt: str, model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]], plot_hist=False, use_activate_hook=False, debug_log=True):
    """Return the per-token loss change caused by the deactivation hooks on one prompt.

    Args:
        prompt: Text to tokenize and evaluate.
        model: Model used for the clean and the hooked runs.
        activate_neurons_hook: `fwd_hooks` list for the baseline run; only
            used when `use_activate_hook` is true. (Annotated as
            (str, HookPoint) pairs; the hook lists defined in this notebook
            pair a hook name with a hook function.)
        deactivate_neurons_hook: `fwd_hooks` list for the ablated run.
        plot_hist: If true, show a histogram of the loss differences.
        use_activate_hook: If true, the baseline is a run with
            `activate_neurons_hook` instead of a plain forward pass.
        debug_log: If true, print both loss tensors and their difference.

    Returns:
        Flattened tensor of `ablated_loss - baseline_loss`; positive values
        mean the ablation increased the loss.
    """
    tokens = model.to_tokens(prompt)
    per_token_loss = dict(return_type="loss", loss_per_token=True)

    # Baseline: plain forward pass unless the caller asks for the "activated" run.
    if not use_activate_hook:
        original_loss = model(tokens, **per_token_loss)
    else:
        original_loss = model.run_with_hooks(tokens, fwd_hooks=activate_neurons_hook, **per_token_loss)

    ablated_loss = model.run_with_hooks(tokens, fwd_hooks=deactivate_neurons_hook, **per_token_loss)

    # Positive entries = loss went up because of the ablation.
    loss_difference = (ablated_loss - original_loss).flatten()

    if debug_log:
        print(f"Unablated loss: {original_loss.flatten()}")
        print(f"Ablated loss: {ablated_loss.flatten()}")
        print(f"Loss difference: {loss_difference}")

    if plot_hist:
        histogram_values = loss_difference.cpu().numpy()
        px.histogram(histogram_values, title="Loss difference due to ablation per position").show()

    return loss_difference

def get_high_loss_prompts(prompts: list[str], model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]]):
    """Summarise the ablation loss difference of every prompt.

    Args:
        prompts: Prompts to evaluate one at a time.
        model: Model forwarded to `get_pos_loss_diff`.
        activate_neurons_hook: Forwarded to `get_pos_loss_diff` (unused there
            unless `use_activate_hook` is set, which this function does not do).
        deactivate_neurons_hook: `fwd_hooks` list used for the ablated run.

    Returns:
        `(max_diffs, average_diffs)`: two lists of Python floats, aligned
        with `prompts`, holding the maximum and the mean per-token loss
        difference of each prompt.
    """
    max_diffs = []
    average_diffs = []
    for prompt in tqdm(prompts):
        # debug_log=False: the default would print three full tensors per prompt
        # and bury the progress bar when scanning a whole dataset.
        loss_difference = get_pos_loss_diff(prompt, model, activate_neurons_hook, deactivate_neurons_hook, debug_log=False)
        max_diffs.append(loss_difference.max().item())
        average_diffs.append(loss_difference.mean().item())
    return max_diffs, average_diffs


In [4]:
def get_top_differences_at_position(prompt: str, model: HookedTransformer, position: int, top_k=20, mode="full"):
    """Show how ablating the context neuron changes the next-token prediction at one position.

    Compares a clean run with a run that uses the module-level
    `deactivate_neurons_fwd_hooks` and prints four token tables via
    `haystack_utils.print_strings_as_html`: the clean top-k, the ablated
    top-k, and the tokens whose log-probability the ablation lowers / raises
    the most. Token ids in the module-level `all_ignore`, and ids whose
    log-probability is <= `LOG_PROB_THRESHOLD` in both runs, are excluded.

    Args:
        prompt: Text to analyse (tokenized as a single sequence).
        model: Model to run.
        position: Index of the prompt token whose next-token distribution is
            inspected. Negative indices count from the end.
        top_k: Number of tokens per table.
        mode: "full" runs the model with the ablation hooks; "direct" uses
            `haystack_utils.get_frozen_logits` and "indirect" uses
            `haystack_utils.get_ablated_logits`, both with the layer 4/5
            attention and MLP output hook names as `freeze_act_names`.
            Any other value fails an assertion.

    Returns:
        None. All results are printed / displayed.
    """
    tokens = model.to_tokens(prompt)
    str_tokens = model.to_str_tokens(tokens)
    # Raw logits; both runs are converted to log-probabilities below.
    original_logits = model(tokens, return_type="logits")
    # Hook names handed to the haystack_utils helpers in the "direct" and "indirect" modes.
    to_freeze = ["blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out"]
    if mode=="direct":
        ablated_logits = haystack_utils.get_frozen_logits(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=to_freeze)
    elif mode=="indirect":
        ablated_logits = haystack_utils.get_ablated_logits(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=to_freeze)
    else:
        assert mode=="full"
        ablated_logits = model.run_with_hooks(tokens, return_type="logits", fwd_hooks=deactivate_neurons_fwd_hooks)
    original_logprob = original_logits.log_softmax(dim=-1)
    ablated_logprob = ablated_logits.log_softmax(dim=-1)

    # Positive difference = the German neuron makes the token more likely
    # Negative difference = the German neuron makes the token less likely
    logprob_differences = original_logprob - ablated_logprob
    logit_differences = original_logits - ablated_logits

    # Resolve the token that follows `position` for the label. A plain
    # `position + 1` wraps around to the first token when position == -1 and
    # raises IndexError at the last non-negative position.
    n_tokens = len(str_tokens)
    next_index = (position if position >= 0 else n_tokens + position) + 1
    next_token = str_tokens[next_index] if 0 <= next_index < n_tokens else "<next token not in prompt>"

    print("Prompt:", prompt)
    print(f"Differences for predicting: {str_tokens[position]} -> {next_token}")

    # Tokens that are very unlikely in BOTH runs are noise; drop them together with `all_ignore`.
    low_log_prob = torch.argwhere(((original_logprob[0, position, :] <= LOG_PROB_THRESHOLD) & (ablated_logprob[0, position, :] <= LOG_PROB_THRESHOLD))).flatten()
    ignore_tokens = torch.cat([low_log_prob, all_ignore]).unique()

    # Top-k tokens of each run, plus the other run's log-probability for the same tokens.
    top_original_logprobs, top_original_idx = haystack_utils.top_k_with_exclude(original_logprob[0, position, :].flatten(), top_k, exclude=ignore_tokens)
    top_original_ablated_logprobs = ablated_logprob[0, position, top_original_idx]
    top_ablated_logprobs, top_ablated_idx = haystack_utils.top_k_with_exclude(ablated_logprob[0, position, :].flatten(), top_k, exclude=ignore_tokens)
    top_ablated_original_logprobs = original_logprob[0, position, top_ablated_idx]

    top_original_values = logprob_differences[0, position, top_original_idx]
    top_ablated_values = logprob_differences[0, position, top_ablated_idx]
    top_original_logit_diff = logit_differences[0, position, top_original_idx]
    top_ablated_logit_diff = logit_differences[0, position, top_ablated_idx]
    print("Top predictions with German neuron active (unablated)")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_original_idx), top_original_values.cpu().tolist(), max_value=5, original_log_probs=top_original_logprobs.cpu().tolist(), ablated_log_probs=top_original_ablated_logprobs.cpu().tolist(), logit_difference=top_original_logit_diff.cpu().tolist())
    print("Top predictions with German neuron disabled (ablated)")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_ablated_idx), top_ablated_values.cpu().tolist(), max_value=5, original_log_probs=top_ablated_original_logprobs.cpu().tolist(), ablated_log_probs=top_ablated_logprobs.cpu().tolist(), logit_difference=top_ablated_logit_diff.cpu().tolist())

    # Tokens with the largest / smallest log-probability difference. The prompt is a
    # single sequence, so batch index 0 is used here like everywhere else in this function.
    top_boosts, top_boosted_idx = haystack_utils.top_k_with_exclude(logprob_differences[0, position, :].flatten(), top_k, exclude=ignore_tokens)
    top_boost_original_logprob = original_logprob[0, position, top_boosted_idx]
    top_boost_ablated_logprob = ablated_logprob[0, position, top_boosted_idx]
    top_reduced, top_reduced_idx = haystack_utils.top_k_with_exclude(logprob_differences[0, position, :].flatten(), top_k, largest=False, exclude=ignore_tokens)
    top_reduced_original_logprob = original_logprob[0, position, top_reduced_idx]
    top_reduced_ablated_logprob = ablated_logprob[0, position, top_reduced_idx]
    print("Top boosted tokens by German neuron")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_boosted_idx), top_boosts.cpu().tolist(), max_value=5, original_log_probs=top_boost_original_logprob.cpu().tolist(), ablated_log_probs=top_boost_ablated_logprob.cpu().tolist())
    print("Top reduced tokens by German neuron")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_reduced_idx), top_reduced.cpu().tolist(), max_value=5, original_log_probs=top_reduced_original_logprob.cpu().tolist(), ablated_log_probs=top_reduced_ablated_logprob.cpu().tolist())

In [5]:
def show_token_loss(prompt: str, model: HookedTransformer, max_value=None, mode="full", freeze_act_names=("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out"), custom_loss_change=None):
    """Render the prompt's tokens coloured by a per-position loss value.

    Args:
        prompt: Text to display.
        model: Model used for tokenization and, unless `mode="custom"`, for
            `haystack_utils.split_effects`.
        max_value: Forwarded to `haystack_utils.print_strings_as_html`.
        mode: Which per-position value to show:
            "full" - total effect loss change from `split_effects`;
            "direct" - direct effect loss change;
            "indirect" - indirect effect loss change;
            "custom" - the tensor passed as `custom_loss_change`.
        freeze_act_names: Hook names forwarded to `split_effects`.
        custom_loss_change: Per-position values to display when
            `mode="custom"`; must offer `.flatten().cpu().tolist()`.

    Raises:
        ValueError: `mode="custom"` without `custom_loss_change`.
        AssertionError: unknown `mode`.
    """
    if mode == "custom":
        # The caller supplies the values, so the split_effects computation
        # (whose results would be discarded) is skipped entirely.
        if custom_loss_change is None:
            raise ValueError('mode="custom" requires custom_loss_change to be provided')
        pos_wise_loss = custom_loss_change
    else:
        assert mode in ("full", "direct", "indirect")
        _, total_effect_loss_change, direct_effect_loss_change, indirect_effect_loss_change = haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names, debug_log=False)
        pos_wise_loss = {
            "full": total_effect_loss_change,
            "direct": direct_effect_loss_change,
            "indirect": indirect_effect_loss_change,
        }[mode]
    str_token_prompt = model.to_str_tokens(model.to_tokens(prompt))
    # Drop the first string token so the remaining tokens line up with the per-token values.
    haystack_utils.print_strings_as_html(str_token_prompt[1:], pos_wise_loss.flatten().cpu().tolist(), max_value=max_value)

def print_predictions(prompt, pos, k=20):
    """Print the top-k prediction tables at `pos` for the full, indirect and direct modes.

    Delegates to `get_top_differences_at_position` with the module-level
    `model`, once per mode, printing a heading before each call.
    """
    sections = (
        ("full", "\nFull model predictions"),
        ("indirect", "\nIndirect predictions (leave German neuron active, patch corrupted activations to later components)"),
        ("direct", "\nDirect predictions (ablate German neuron, patch clean activations to later components)"),
    )
    for mode_name, heading in sections:
        print(heading)
        get_top_differences_at_position(prompt, model, pos, k, mode=mode_name)

## Look for interesting examples with new loss difference sorting

In [7]:
# All inputs are position-wise increases in loss from ablating the German context neuron,
# measured along different paths:
#   total_effect_loss      - everything
#   direct_mlp3_mlp5_loss  - MLP3 + MLP5
#   direct_mlp3_loss       - MLP3 only
#   frozen_loss            - top MLP5 neurons
def custom_diff_weight(total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss):
    """Express three loss measurements relative to the direct MLP3 loss.

    Returns:
        `(indirect_diff, mlp_4_diff, mlp_5_diff)` where
        indirect_diff = total_effect_loss - direct_mlp3_loss
            (high when restoring MLP4 and MLP5 fixes the problem),
        mlp_4_diff = direct_mlp3_mlp5_loss - direct_mlp3_loss
            (high when only restoring MLP4 is not enough),
        mlp_5_diff = frozen_loss - direct_mlp3_loss
            (high when only restoring MLP5 is not enough).
    """
    baseline = direct_mlp3_loss
    return tuple(
        measured - baseline
        for measured in (total_effect_loss, direct_mlp3_mlp5_loss, frozen_loss)
    )

In [None]:
# Same arithmetic as custom_diff_weight; kept as a separate copy with the author's
# alternative reading of what each difference means.
# All inputs are position-wise increases in loss from ablating the German context neuron:
# total effect == everything, direct MLP3+5 == MLP3+5, direct3 == MLP3, frozen == top MLP5 neurons.
def lq_custom_diff_weight(total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss):
    """Return `(indirect_diff, mlp_4_diff, mlp_5_diff)`, each measured against `direct_mlp3_loss`.

    Author's interpretation of the three numbers:
        indirect_diff - more positive: MLP4+5 enabled together matter more than MLP3.
        mlp_4_diff    - more positive: MLP5 is more necessary; a small value
                        suggests MLP4 is important in its own right.
        mlp_5_diff    - more positive: the top MLP5 neurons matter more than
                        the direct MLP3 effect.
    """
    def relative_to_mlp3(loss):
        return loss - direct_mlp3_loss

    indirect_diff = relative_to_mlp3(total_effect_loss)
    mlp_4_diff = relative_to_mlp3(direct_mlp3_mlp5_loss)
    mlp_5_diff = relative_to_mlp3(frozen_loss)
    return (indirect_diff, mlp_4_diff, mlp_5_diff)

In [8]:
def get_interesting_loss_prompts(prompts: list[str], model: HookedTransformer):
    """Collect, per prompt, the maximum of each difference returned by `custom_diff_weight`.

    Every prompt is evaluated with
    `haystack_utils.pos_wise_mlp_effect_on_single_prompt` (k=100, all
    positions, using the module-level `deactivate_neurons_fwd_hooks`).

    Returns:
        Three lists aligned with `prompts`: the per-prompt maxima of
        indirect_diff, mlp_4_diff and mlp_5_diff, as Python floats.
    """
    maxima_per_metric = ([], [], [])
    for text in tqdm(prompts):
        # The first returned value (the original loss) is not needed here.
        _, total_loss, mlp3_mlp5_loss, mlp3_loss, top_mlp5_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(text, model, deactivate_neurons_fwd_hooks, k=100, log=False, answer_pos=None)
        diffs = custom_diff_weight(total_loss, mlp3_mlp5_loss, mlp3_loss, top_mlp5_loss)
        for bucket, diff in zip(maxima_per_metric, diffs):
            bucket.append(diff.max().item())
    return maxima_per_metric

max_indirect_diff, max_mlp_4_diff, max_mlp_5_diff = get_interesting_loss_prompts(german_data, model)

  0%|          | 0/2459 [00:00<?, ?it/s]

In [9]:
# Rank the German prompts by their largest mlp_5_diff (top-MLP5-neuron loss minus direct MLP3 loss),
# highest first - these are the examples where patching some of MLP5 is not enough
# (so maybe they rely on MLP4).
loss_data_tuple = sorted(
    zip(max_mlp_5_diff, german_data),
    key=lambda pair: pair[0],
    reverse=True,
)
loss_data_tuple[:2]

[(2.4673328399658203,
  'Zur Zielsetzung, die das Vereinigte Königreich verfolgt - die völlige Aufhebung des Embargos - möchte ich unterstreichen, daß der Rat in seinen Schlußfolgerungen vom 29. und 30. April klar gesagt hat - und hier beziehe ich mich auf Absatz 6 dieser Schlußfolgerungen -, daß die Gesamtheit der vom Vereinigten Königreich ergriffenen Maßnahmen, deren Umsetzung sowie Kontrolle durch die Kommission, nämlich das von den britischen Behörden angekündigte selektive Schlachtprogramm, die sich als notwendig erwiesenen Zusatzmaßnahmen sowie schließlich die ebenso notwendige Bekräftigung, sich bei zukünftigen Beschlußfassungen immer mehr auf solide wissenschaftliche Aussagen zu stützen, Anhaltspunkte darstellen, die zu dem Prozeß gehören, der eine allmähliche Aufhebung des Ausfuhrverbots ermöglichen dürfte.'),
 (2.238445281982422,
  'Der Europäische Rat stellte fest, dass weitere Schritte in Richtung einer EU-Mitgliedschaft erstens auf der Grundlage der Debatte über die Erwei

In [None]:
# Visual scan of ranks 100-109 of loss_data_tuple. Note that this loop rebinds the
# notebook-level names `prompt`, `original_loss`, ..., `mlp_4_diff`, `mlp_5_diff`.
# k=50 is used here, whereas the ranking in get_interesting_loss_prompts used k=100.
for _ , prompt in loss_data_tuple[100:110]:
    print(f"\n{prompt}")
    original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=50, log=False, answer_pos=None)    
    # The same differences that custom_diff_weight computes, recomputed inline.
    mlp_5_diff = frozen_loss - direct_mlp3_loss
    mlp_4_diff = direct_mlp3_mlp5_loss - direct_mlp3_loss
    # Token-level views: first mlp_4_diff, then mlp_5_diff.
    show_token_loss(prompt, model, max_value=5, mode="custom", custom_loss_change=mlp_4_diff)
    show_token_loss(prompt, model, max_value=5, mode="custom", custom_loss_change=mlp_5_diff)
    # Scan for examples with reasonable overall loss (the original loss is shown negated).
    show_token_loss(prompt, model, max_value=10, mode="custom", custom_loss_change=original_loss*-1)
    # Alternative views, currently disabled:
    #show_token_loss(prompt, model, max_value=5, mode="full")
    #show_token_loss(prompt, model, max_value=5, mode="indirect")
    #show_token_loss(prompt, model, max_value=5, mode="direct")

In [11]:
prompt = 'Abschließend möchte ich also sagen, wir sollten vorbehaltlich folgender Forderungen - 1. Vorlage eines kompletten Informationsdossiers vor jeder Marktzulassung'
# Value passed as `k` to pos_wise_mlp_effect_on_single_prompt; the plot title describes it as the number of top MLP5 neurons.
k = 200
# Compare the loss at the final position (answer_pos=-1) across the patching conditions.
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1)

# One single-element list per condition; `names` labels them in the same order.
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

## The only example I found where both actually seem to matter (but k=500 fixes it) :(

In [68]:
prompt = 'Bericht (A4-0220/97) von Frau Pery im Namen des Ausschusses für Fischerei über den Vorschlag für einen Beschluß des Rates über den Abschluß des Abkommens in Form eines Briefwechsels zur vorübergehenden Verlängerung des Protokolls zum Abkommen zwischen der Europäischen Gemeinschaft und der Regierung der Republik Senegal über die Fischerei vor der senegalesischen Küste vom 2. Oktober 1996 bis 1. November 1996 (KOM(96)0611 - C4-0032/97-96/0287(CNS)); -Bericht (A4-0224/97) von Frau Péry im Namen des Ausschusses für Fischerei über den Vorschlag für eine Verordnung des Rates über den Abschluß des Protokolls zur Festlegung der Fischereirechte und des finanziellen Ausgleichs nach dem Abkommen zwischen der Regierung der Republik Senegal und der Europäischen Gemeinschaft über die Fischerei vor der senegalesischen Küste für die Zeit vom 1. Mai 1997 bis zum 30. April 2001 (KOM(97)0324 - C4-0322/97-97/0179(CNS)); -Bericht (A4-0229/97) von Herrn Gallagher über den Vorschlag für eine Verordnung (EG) des Rates über den Abschluß des Pro'
# Larger k than in the previous example (500 instead of 200).
k = 500
# Compare the loss at the final position (answer_pos=-1) across the patching conditions.
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1)

# One single-element list per condition; `names` labels them in the same order.
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

## Something cool? MLP4 doing things when "." is followed by a lowercase word
- On "und" between dates ("12. und 13. Dezember")
- After "usw" (=etc) ("... usw. zu fragen")
- Unfortunately the model is very bad at this in general (except for first example)

In [39]:
prompt = 'Ein einzigartiges Erlebnis ist es, gemeinsam über die Strände von Japan zu'
# NOTE(review): 2048 presumably covers every MLP5 neuron of this model - confirm against the model config.
k = 2048
# Compare the loss at the final position (answer_pos=-1) across the patching conditions.
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1)
# One single-element list per condition; `names` labels them in the same order.
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

In [24]:
def evaluate_prompts_on_pos(prompts, pos=-1, k=100):
    """Bar-plot the losses at one position for several prompts under each patching condition.

    For every prompt,
    `haystack_utils.pos_wise_mlp_effect_on_single_prompt` is called with the
    module-level `model` and `deactivate_neurons_fwd_hooks`; its five
    results are grouped per condition and handed to
    `haystack_utils.plot_barplot`.

    Args:
        prompts: Iterable of prompt strings.
        pos: Position forwarded as `answer_pos`; also shown in the plot title.
        k: Forwarded as `k`; the title describes it as the number of top
            MLP5 neurons.

    Returns:
        None. The plot is produced by `haystack_utils.plot_barplot`.
    """
    original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses = [], [], [], [], []
    for prompt in prompts:
        original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=pos)
        original_losses.append(original_loss)
        total_effect_losses.append(total_effect_loss)
        direct_mlp3_mlp5_losses.append(direct_mlp3_mlp5_loss)
        direct_mlp3_losses.append(direct_mlp3_loss)
        frozen_losses.append(frozen_loss)
    # Labels are matched to the loss lists below by order.
    names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
    # The title reports the position that was actually evaluated rather than a hard-coded -1.
    haystack_utils.plot_barplot([original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses], names, title=f"Losses on pos={pos} for different frozen components (MLP5: top {k} neurons)")


In [40]:
# Minimal variations of one sentence: "Straßen" vs. "Strände" crossed with three place names,
# all ending in "zu" so that the final position is evaluated in the same context.
prompts = [
    'Ein einzigartiges Erlebnis ist es, gemeinsam über die Straßen von Japan zu',
    'Ein einzigartiges Erlebnis ist es, gemeinsam über die Strände von Japan zu',
    'Ein einzigartiges Erlebnis ist es, gemeinsam über die Straßen von Mallorca zu',
    'Ein einzigartiges Erlebnis ist es, gemeinsam über die Strände von Mallorca zu',
    'Ein einzigartiges Erlebnis ist es, gemeinsam über die Straßen von Santa Giulia zu',
    'Ein einzigartiges Erlebnis ist es, gemeinsam über die Strände von Santa Giulia zu',
]

# Losses at the final position for every variant, using k=2048.
evaluate_prompts_on_pos(prompts, pos=-1, k=2048)

## L4 neuron contribution

In [41]:
prompt = 'Ein einzigartiges Erlebnis ist es, gemeinsam über die Strände von Japan zu'
# Per-neuron differences for layer 4 at the final position, as computed by haystack_utils.MLP_attribution.
l4_differences = haystack_utils.MLP_attribution(prompt, model, deactivate_neurons_fwd_hooks, layer_to_compare=4, pos=-1)
# Largest 2048 differences first.
# NOTE(review): presumably 2048 is the full layer width, making this a sort of all neurons - confirm.
top_diff, top_diff_neurons = torch.topk(l4_differences, k=2048, largest=True)
#haystack_utils.line(top_diff, xticks=top_diff_neurons)
# Last expression of the cell: its return value is the recorded cell output.
haystack_utils.get_neuron_loss_attribution(prompt, model, top_diff_neurons, ablation_hooks=deactivate_neurons_fwd_hooks, pos=-1, layer=4)

(3.251843214035034, 4.827531814575195, 4.333996772766113)

In [45]:
# MLP4 doesn't help, check attention
# Relies on `prompt` as set by an earlier cell.
# Only the layer-5 attention output is frozen here.
freeze_act_names=("blocks.5.hook_attn_out",)  # earlier variants also listed "blocks.4.hook_attn_out" / "blocks.4.hook_mlp_out"
# The third returned value is bound to `frozen_loss`; the fourth is discarded.
original_loss, total_effect_loss, frozen_loss, _= haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names, debug_log=False, return_absolute=True)
# NOTE(review): `:-4f` parses as sign '-', width 4, type 'f' (default six decimals, as in the
# recorded output); `.4f` may have been intended.
print(f"Frozen loss: {frozen_loss[0, -1].item():-4f}")

Frozen loss: 3.563227


In [13]:
prompt = 'Ein einzigartiges Erlebnis ist es, mit staunendem Herzen Hand in Hand das Feuerwerk vom 14. Juli von der Rondinara Bucht aus zu verfolgen, gemeinsam über die Strände von Santa Giulia zu'
# Value passed as `k`; the plot title describes it as the number of top MLP5 neurons.
k = 100
# Compare the loss at the final position (answer_pos=-1) across the patching conditions.
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1)

# One single-element list per condition; `names` labels them in the same order.
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

In [14]:
# Prompt ends in "29. und" - the "und between dates" case from the section heading.
prompt = 'Zur Zielsetzung, die das Vereinigte Königreich verfolgt - die völlige Aufhebung des Embargos - möchte ich unterstreichen, daß der Rat in seinen Schlußfolgerungen vom 29. und'
# Value passed as `k`; the plot title describes it as the number of top MLP5 neurons.
k = 100
# Compare the loss at the final position (answer_pos=-1) across the patching conditions.
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1)

# One single-element list per condition; `names` labels them in the same order.
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

In [46]:
# Prompt ends in "14. und" - another "und between dates" case.
prompt = 'Der Europäische Rat stellte fest, dass weitere Schritte in Richtung einer EU-Mitgliedschaft erstens auf der Grundlage der Debatte über die Erweiterungsstrategie gemäß den Schlussfolgerungen des Rates vom Dezember 2005 geprüft werden müssten, die in den "erneuerten Konsens" über die Erweiterung, welcher auf dem Europäischen Rat am 14. und'
# Value passed as `k`; the plot title describes it as the number of top MLP5 neurons.
k = 100
# Compare the loss at the final position (answer_pos=-1) across the patching conditions.
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1)

# One single-element list per condition; `names` labels them in the same order.
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

In [58]:
# Relies on `prompt` as set by an earlier cell.
# Both the layer-5 and the layer-4 attention outputs are frozen here.
freeze_act_names=("blocks.5.hook_attn_out", "blocks.4.hook_attn_out",)  # an earlier variant also listed "blocks.4.hook_mlp_out"
# The third returned value is bound to `frozen_loss`; the fourth is discarded.
original_loss, total_effect_loss, frozen_loss, _= haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names, debug_log=False, return_absolute=True)
# NOTE(review): `:-4f` parses as sign '-', width 4, type 'f'; `.4f` may have been intended.
print(f"Frozen loss: {frozen_loss[0, -1].item():-4f}")
# Open question: on this prompt, do the two attention layers work together, or does each
# contribute something useful on its own that simply adds up?

Frozen loss: 8.219750


In [16]:
# Prompt ends in "19. und" - another "und between dates" case.
prompt = 'Blicken wir auf die politisch-soziale Situation, in der die eingangs wiedergegebene Frage, deren Formulierung ich einem Buch des in Buenos Aires tätigen Colectivo Situaciones entnehme, ebenso wie ihre vorläufig nur knapp umrissenen Implikationen ihre konkrete Verortung finden: Sie betrifft die insbesondere an den Tagen des 19. und'

# Value passed as `k`; the plot title describes it as the number of top MLP5 neurons.
k = 100
# Compare the loss at the final position (answer_pos=-1) across the patching conditions.
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1)

# One single-element list per condition; `names` labels them in the same order.
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

In [17]:
# Prompt ends in "usw. zu" - the "after usw" case from the section heading.
prompt = 'Deshalb ist es falsch, den Schutz der biologischen Vielfalt strikt auf den zweiten Pfeiler der gemeinsamen Agrarpolitik - und insbesondere auf die Agrarumweltprogramme zur Förderung der Extensivierung, des ökologischen Landbaus, zur Erhaltung einheimischer Rassen, zum Schutz der natürlichen Lebensräume usw. zu'
# Value passed as `k`; the plot title describes it as the number of top MLP5 neurons.
k = 100
# Compare the loss at the final position (answer_pos=-1) across the patching conditions.
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1)

# One single-element list per condition; `names` labels them in the same order.
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

## Clean examples of Layer 5 doing things

In [18]:
# Prompt is cut off in the middle of a word ("vorzubere").
prompt = 'Zum Abschluss erlauben Sie mir bitte, all jenen Abgeordneten zu danken, die anwesend sind und denen, die an dieser Aussprache teilnehmen werden, aber wegen des Vulkanausbruchs nicht hier sein können, die viel Arbeit geleistet haben, um den Vorschlag für eine Entscheidung vorzubere'
# Small k: only 20 is passed for the "top MLP5 neurons" condition.
k = 20
# Compare the loss at the final position (answer_pos=-1) across the patching conditions.
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1)

# One single-element list per condition; `names` labels them in the same order.
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

## Superposition examples
- The model needs more MLP5 neurons to restore the loss

In [49]:
# Same prompt as the k=200 cell near the top of the previous section, now with k=20.
# `prompt` as set here is also read by the cells that follow.
prompt = 'Abschließend möchte ich also sagen, wir sollten vorbehaltlich folgender Forderungen - 1. Vorlage eines kompletten Informationsdossiers vor jeder Marktzulassung'
k = 20
# Compare the loss at the final position (answer_pos=-1) across the patching conditions.
original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss = haystack_utils.pos_wise_mlp_effect_on_single_prompt(prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1)

# One single-element list per condition; `names` labels them in the same order.
list_answers = [[x] for x in [original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss]]
names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

In [7]:
# Look up the token ids of the prompt's final word; the last id in the recorded output (1947)
# is the answer token used in the next cell.
model.to_tokens(" Marktzulassung")

tensor([[    0,  4744, 21239,   335,   515,  1947]], device='cuda:0')

- Total ablated loss: 5.7
- Freezing MLP4 leads to loss of 2.9 (loss decrease of 2.8)
    - L4 attention can't be super interesting
- Patching frozen MLP4 activations into both L5 components leads to loss of 4.6803 (indirect effect of MLP4) (loss decrease of 1)
- Activating MLP4 and deactivating both L5 components leads to loss of 4.0831 (direct effect of MLP4) (loss decrease of 1.6)

In [77]:
# Final token id of " Marktzulassung" (see the tokenisation output above).
# Relies on `prompt` as set by an earlier cell.
answer_tokens = torch.LongTensor([1947])
# Clean run; the cache is kept so that clean activations can be patched back in below.
original_loss, original_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

# Run with the context neuron ablated.
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    ablated_loss, ablated_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

# Remove the effects of ablating at MLP3 from the components after MLP3
def freeze_neurons_hook(value, hook: HookPoint):
    """Replace the hooked activation with the clean-run value stored in `original_cache`."""
    value = original_cache[hook.name]
    return value

# Ablate the context neuron but restore the clean MLP4 output.
activate_4_hooks = [(f"blocks.4.hook_mlp_out", freeze_neurons_hook)]
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks+activate_4_hooks):
    frozen_loss, frozen_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

# Per-neuron layer-5 contribution to the answer token in both runs.
# NOTE(review): pos=-2 is presumably the input position that predicts the final token
# (whose loss is loss[0, -1]) - confirm against get_neuron_logit_contribution.
frozen_unembedded = haystack_utils.get_neuron_logit_contribution(frozen_cache, model, answer_tokens, layer=5, pos=-2) # [neuron]
ablated_unembedded = haystack_utils.get_neuron_logit_contribution(ablated_cache, model, answer_tokens, layer=5, pos=-2)
differences = (ablated_unembedded - frozen_unembedded).detach().cpu() # [neuron]

# Loss on the final token: clean, ablated, ablated with MLP4 restored.
print(original_loss[0, -1], ablated_loss[0, -1], frozen_loss[0, -1])
# The 30 neurons with the largest (ablated - frozen) contribution difference.
top_diff, top_diff_neurons = torch.topk(differences, k=30, largest=True)
haystack_utils.line(top_diff, xticks=top_diff_neurons)

# How much of the ablation loss is recovered by restoring the MLP4 output.
print("Overall L4 improvement", ablated_loss[0, -1]-frozen_loss[0, -1])

tensor(1.4814, device='cuda:0') tensor(5.7540, device='cuda:0') tensor(2.9608, device='cuda:0')


Overall L4 improvement tensor(2.7932, device='cuda:0')


In [78]:
# Patch layer-5 MLP activations from the MLP4-restored run (`frozen_cache`) into an
# otherwise ablated run, to see how much of the improvement comes from MLP5 reading MLP4.
def activate_5_hook(value, hook: HookPoint):
    """Replace the hooked activation with the entry of `frozen_cache` for the same hook."""
    # Disabled variant kept for reference: start from ablated_cache[hook.name] and
    # copy in only the top_diff_neurons slice of frozen_cache[hook.name].
    return frozen_cache[hook.name]  # [batch pos neuron]


activate_5_hooks = [("blocks.5.mlp.hook_post", activate_5_hook)]
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks + activate_5_hooks):
    only_mlp_5_activated_loss, new_frozen_cache = model.run_with_cache(
        prompt, return_type="loss", loss_per_token=True
    )

print(only_mlp_5_activated_loss[0, -1])
print("Only MLP5 improvement", ablated_loss[0, -1] - only_mlp_5_activated_loss[0, -1])

tensor(4.9066, device='cuda:0')
Only MLP5 improvement tensor(0.8474, device='cuda:0')


In [79]:
# Direct effect of MLP4: restore the clean MLP4 output (activate_4_hooks) while
# pinning the layer-5 attention output and MLP pre-activations to their values
# from the ablated run, so layer 5 cannot read the restored MLP4 output.
def deactivate_5_hook(value, hook: HookPoint):
    """Replace the hooked activation with the entry of `ablated_cache` for the same hook."""
    return ablated_cache[hook.name]

deactivate_5_hooks = [("blocks.5.hook_attn_out", deactivate_5_hook), ("blocks.5.mlp.hook_pre", deactivate_5_hook)]
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks+deactivate_5_hooks+activate_4_hooks):
    only_mlp_4_activated_loss, new_frozen_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

print(only_mlp_4_activated_loss[0, -1])
# This value is the MLP4-only improvement; the label previously said "MLP5",
# copied from the preceding cell.
print("Only MLP4 improvement",  ablated_loss[0, -1]-only_mlp_4_activated_loss[0, -1])

tensor(4.0831, device='cuda:0')
Only MLP5 improvement tensor(1.6709, device='cuda:0')


Open questions:
- Which MLP4 components give us the improvement of 1.67 by writing directly into the unembed space?
- Which MLP5 components give us the improvement of 0.84 by reading MLP4 outputs?
- Which MLP4 components write the information used by MLP5?

In [69]:
def get_loss_contribution_per_neuron(prompt, model, pos=-1, layer=5, plot_top_k=30):
    """Plot the MLP neurons whose individual restoration yields the lowest ablated loss.

    For each neuron index in ``range(model.cfg.d_mlp)``, the model is run with the
    notebook-level ``deactivate_neurons_fwd_hooks`` applied while that one neuron's
    ``blocks.{layer}.mlp.hook_post`` activation is overwritten with its value from a
    run without those hooks. The per-token loss at ``pos`` is recorded, and the
    ``plot_top_k`` neurons with the smallest loss are printed and plotted.

    Args:
        prompt: Input passed to ``model`` / ``model.run_with_cache``.
        model: Model exposing ``run_with_cache``, ``hooks`` and ``cfg.d_mlp``.
        pos: Index into the per-token loss at which the loss is read.
        layer: MLP layer whose neurons are restored one at a time.
        plot_top_k: Number of lowest-loss neurons to show (clamped to ``d_mlp``).

    Returns:
        None. The selected neuron indices are printed and a line plot is drawn.
    """
    # The reference activations do not depend on which neuron is restored, so they
    # are cached once instead of being recomputed on every loop iteration.
    _, original_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)
    hook_name = f"blocks.{layer}.mlp.hook_post"

    loss_per_neuron = []
    for neuron in range(model.cfg.d_mlp):
        # Undo the ablation for this single neuron only.
        def freeze_neurons_hook(value, hook: HookPoint):
            value[:, :, neuron] = original_cache[hook.name][:, :, neuron]  # [batch pos neuron]
            return value

        freeze_original_hooks = [(hook_name, freeze_neurons_hook)]
        with model.hooks(fwd_hooks=freeze_original_hooks+deactivate_neurons_fwd_hooks):
            ablated_with_original_frozen_loss = model(prompt, return_type="loss", loss_per_token=True)

        # Use the requested position (previously hard-coded to -1, ignoring `pos`).
        loss_per_neuron.append(ablated_with_original_frozen_loss[0, pos].item())

    # Smallest loss first: restoring these neurons recovers the most performance.
    top_k = min(plot_top_k, len(loss_per_neuron))
    top_values, top_neurons = torch.topk(torch.Tensor(loss_per_neuron), k=top_k, largest=False)
    print(top_neurons)
    haystack_utils.line(top_values, xlabel="Neuron", ylabel="Total loss of restoring the neuron", width=1000, xticks=top_neurons)

In [75]:
# Direct loss attribution for the neurons in MLP4, measured while deactivate_5_hooks
# pins the hooked layer-5 activations to their values from the ablated run.
with model.hooks(fwd_hooks=deactivate_5_hooks):
    get_loss_contribution_per_neuron(prompt, model, layer=4, pos=-1)

l4_differences = haystack_utils.MLP_attribution(
    prompt, model, deactivate_neurons_fwd_hooks, layer_to_compare=4, pos=-1
)

# Hard-coded MLP4 neuron indices. Disabled alternatives kept for reference:
#   top_diff, top_diff_neurons = torch.topk(l4_differences, k=50, largest=True)
#   top_diff_neurons = torch.LongTensor([i for i in range(model.cfg.d_mlp)])
#   haystack_utils.line(top_diff, xticks=top_diff_neurons)
top_diff_neurons = torch.LongTensor([
    689, 1612, 1992, 660, 102, 592, 307, 360, 602, 1872,
    1915, 1444, 921, 196, 1447, 683, 755, 298, 897, 2001,
    1935, 1508, 595, 1213, 1122, 2030, 1519, 814, 1090, 896,
])

# Loss when the selected MLP4 neurons are restored together, again with layer 5 pinned.
# Disabled alternative: get_loss_contribution_per_neuron(prompt, model, top_diff_neurons, layer=4, pos=-1)
with model.hooks(fwd_hooks=deactivate_5_hooks):
    original_loss, ablated_loss, frozen_loss = haystack_utils.get_neuron_loss_attribution(
        prompt,
        model,
        top_diff_neurons[:30],
        ablation_hooks=deactivate_neurons_fwd_hooks,
        pos=-1,
        layer=4,
    )
print(original_loss, ablated_loss, frozen_loss)

3.5167388916015625 5.753985404968262 3.746394395828247


In [82]:
def get_loss_contribution_per_neuron_on_frozen(prompt, model, pos=-1, layer=5, plot_top_k=30):
    """Plot the neurons whose MLP4-restored activation, patched in alone, gives the lowest loss.

    A reference cache is built by running the model with the notebook-level
    ``deactivate_neurons_fwd_hooks`` applied while ``blocks.4.hook_mlp_out`` is
    overwritten with its value from a run without those hooks. Then, for each neuron
    index in ``range(model.cfg.d_mlp)``, the ablated model is run with only that
    neuron's ``blocks.{layer}.mlp.hook_post`` activation taken from the reference
    cache. The per-token loss at ``pos`` is recorded, and the ``plot_top_k`` neurons
    with the smallest loss are printed and plotted.

    Args:
        prompt: Input passed to ``model`` / ``model.run_with_cache``.
        model: Model exposing ``run_with_cache``, ``hooks`` and ``cfg.d_mlp``.
        pos: Index into the per-token loss at which the loss is read.
        layer: MLP layer whose neurons are patched one at a time.
        plot_top_k: Number of lowest-loss neurons to show (clamped to ``d_mlp``).

    Returns:
        None. The selected neuron indices are printed and a line plot is drawn.
    """
    _, original_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

    # Distinct names for the two hooks; the original reused one name for both.
    def restore_mlp4_hook(value, hook: HookPoint):
        return original_cache[hook.name]

    activate_4_hooks = [("blocks.4.hook_mlp_out", restore_mlp4_hook)]
    with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks+activate_4_hooks):
        _, frozen_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

    hook_name = f"blocks.{layer}.mlp.hook_post"
    loss_per_neuron = []
    for neuron in range(model.cfg.d_mlp):
        # Patch in this single neuron's activation from the MLP4-restored run.
        def patch_single_neuron_hook(value, hook: HookPoint):
            value[:, :, neuron] = frozen_cache[hook.name][:, :, neuron]  # [batch pos neuron]
            return value

        freeze_original_hooks = [(hook_name, patch_single_neuron_hook)]
        with model.hooks(fwd_hooks=freeze_original_hooks+deactivate_neurons_fwd_hooks):
            ablated_with_original_frozen_loss = model(prompt, return_type="loss", loss_per_token=True)

        # Use the requested position (previously hard-coded to -1, ignoring `pos`).
        loss_per_neuron.append(ablated_with_original_frozen_loss[0, pos].item())

    # Smallest loss first: patching these neurons recovers the most performance.
    top_k = min(plot_top_k, len(loss_per_neuron))
    top_values, top_neurons = torch.topk(torch.Tensor(loss_per_neuron), k=top_k, largest=False)
    print(top_neurons)
    haystack_utils.line(top_values, xlabel="Neuron", ylabel="Total loss of restoring the neuron", width=1000, xticks=top_neurons)

In [83]:
# Direct loss attribution for the neurons in MLP5: prints and plots the 30 neurons
# (default plot_top_k) with the lowest loss at pos=-1.
get_loss_contribution_per_neuron_on_frozen(prompt, model, layer=5, pos=-1)

tensor([ 925, 1354, 1404,  100,  918, 1418,  348, 1425, 1266, 2029,  834,  177,
        1550,  141, 1215, 1330, 1927, 1737,  670,  380, 1881, 1140, 1119,  358,
        1001,  216,  960,  957, 1842, 1250])


In [95]:
pos = -1
# Neuron indices copied from the output of the per-neuron attribution cell above.
top_mlp5_neurons = torch.LongTensor([
    925, 1354, 1404, 100, 918, 1418, 348, 1425, 1266, 2029,
    834, 177, 1550, 141, 1215, 1330, 1927, 1737, 670, 380,
    1881, 1140, 1119, 358, 1001, 216, 960, 957, 1842, 1250,
])
top_mlp5_neurons = top_mlp5_neurons[:30]

original_loss, original_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)


# 1. Get the frozen cache: ablation hooks on, MLP4 output taken from the clean run.
def freeze_neurons_hook(value, hook: HookPoint):
    """Replace the hooked activation with the entry of `original_cache` for the same hook."""
    return original_cache[hook.name]


activate_4_hooks = [("blocks.4.hook_mlp_out", freeze_neurons_hook)]
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks + activate_4_hooks):
    frozen_loss, frozen_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)
# Disabled alternative: get_loss_contribution_per_neuron(prompt, model, top_diff_neurons, layer=4, pos=-1)

# 2. Fully ablated run, for comparison.
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    ablated_loss, ablated_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)


# 3. Ablated run in which only the selected MLP5 neurons take their frozen-cache values.
def freeze_neurons_l5_hook(value, hook: HookPoint):
    """Overwrite the selected neurons with their `frozen_cache` activations."""
    value[:, :, top_mlp5_neurons] = frozen_cache[hook.name][:, :, top_mlp5_neurons]  # [batch pos neuron]
    return value


freeze_original_hooks = [("blocks.5.mlp.hook_post", freeze_neurons_l5_hook)]
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks + freeze_original_hooks):
    ablated_with_original_frozen_loss = model(prompt, return_type="loss", loss_per_token=True)

# Displayed as the cell result: (original, ablated, ablated + selected MLP5 neurons) loss at `pos`.
(
    original_loss[0, pos].item(),
    ablated_loss[0, pos].item(),
    ablated_with_original_frozen_loss[0, pos].item(),
)

(1.4813501834869385, 5.753985404968262, 3.802215337753296)

## Other prompts

In [62]:
prompt = "Der Vorschlag zur Änderung dieser Verordnung über die Anwendung des dem Vertrag"
k = 100
# Losses at answer_pos=-1 under each patching condition; the bar labels in
# `names` below describe the five returned values in order.
(
    original_loss,
    total_effect_loss,
    direct_mlp3_mlp5_loss,
    direct_mlp3_loss,
    frozen_loss,
) = haystack_utils.pos_wise_mlp_effect_on_single_prompt(
    prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1
)

names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
# plot_barplot is given one single-element list per bar.
list_answers = [
    [loss]
    for loss in (original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss)
]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

In [63]:
prompt = "Ich möchte nochmals meine Ansicht"
k = 100
# Losses at answer_pos=-1 under each patching condition; the bar labels in
# `names` below describe the five returned values in order.
(
    original_loss,
    total_effect_loss,
    direct_mlp3_mlp5_loss,
    direct_mlp3_loss,
    frozen_loss,
) = haystack_utils.pos_wise_mlp_effect_on_single_prompt(
    prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1
)

names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
# plot_barplot is given one single-element list per bar.
list_answers = [
    [loss]
    for loss in (original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss)
]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

In [22]:
prompt = "   .– Herr Präsident, als Vorsitzender des Ausschusses für Recht und Binnenmarkt sei mir die Behauptung gestattet, dass ich niemals so stolz wie heute Abend war, Vorsitzender eines Ausschusses zu sein, der dem Ausschuss für die Freiheiten und Rechte der Bürger, Justiz und innere Angelegenheiten eine objektive und sehr umfassende Stellungnahme sowie detaillierte Vorschläge vorgelegt hat, in denen er äußerst bewusst und vollkommen sachlich, ohne Widerspruch und sehr rigoros ein Bemühen um Diversifizierung des Eigentums und der Kontrolle des Fernsehens"
k = 20
# Losses at answer_pos=-1 under each patching condition; the bar labels in
# `names` below describe the five returned values in order.
(
    original_loss,
    total_effect_loss,
    direct_mlp3_mlp5_loss,
    direct_mlp3_loss,
    frozen_loss,
) = haystack_utils.pos_wise_mlp_effect_on_single_prompt(
    prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1
)

names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
# plot_barplot is given one single-element list per bar.
list_answers = [
    [loss]
    for loss in (original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss)
]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

In [23]:
prompt = 'Deshalb ist es für uns als Konföderale Fraktion der Vereinigten Europäischen Linken/Nordische Grüne Linke von grundlegender Bedeutung, Kommission und Rat mit der vom Parlament anzunehmenden Entschließung'
k = 20
# Losses at answer_pos=-1 under each patching condition; the bar labels in
# `names` below describe the five returned values in order.
(
    original_loss,
    total_effect_loss,
    direct_mlp3_mlp5_loss,
    direct_mlp3_loss,
    frozen_loss,
) = haystack_utils.pos_wise_mlp_effect_on_single_prompt(
    prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1
)

names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
# plot_barplot is given one single-element list per bar.
list_answers = [
    [loss]
    for loss in (original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss)
]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")

## Looking for MLP3-MLP4 Circuit

In [24]:
prompt = 'Ein einzigartiges Erlebnis ist es, mit staunendem Herzen Hand in Hand das Feuerwerk vom 14. Juli von der Rondinara Bucht aus zu verfolgen, gemeinsam über die Strände von Santa Giulia zu'
k = 100
# Losses at answer_pos=-1 under each patching condition; the bar labels in
# `names` below describe the five returned values in order.
(
    original_loss,
    total_effect_loss,
    direct_mlp3_mlp5_loss,
    direct_mlp3_loss,
    frozen_loss,
) = haystack_utils.pos_wise_mlp_effect_on_single_prompt(
    prompt, model, deactivate_neurons_fwd_hooks, k=k, log=False, answer_pos=-1
)

names = ["Original loss", "Ablated Loss", "Ablated loss (restoring MLP4)", "Ablated loss (restoring MLP4 + MLP5)", "Ablated loss (restoring top MLP5 neurons)"]
# plot_barplot is given one single-element list per bar.
list_answers = [
    [loss]
    for loss in (original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss)
]
haystack_utils.plot_barplot(list_answers, names, title=f"Losses on pos={-1} for different frozen components (MLP5: top {k} neurons)")