## Setup

In [1]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from typing import List, Tuple
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
english_neurons = [(5, 395), (5, 166), (5, 908), (5, 285), (3, 862), (5, 73), (4, 896), (5, 348), (5, 297), (3, 1204)]
german_neurons = [(4, 482), (5, 1039), (5, 407), (5, 1516), (5, 1336), (4, 326), (5, 250), (3, 669)]
french_neurons = [(5, 112), (4, 1080), (5, 1293), (5, 455), (5, 5), (5, 1901), (5, 486), (4, 975)]

model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

english_data = haystack_utils.load_txt_data("kde4_english.txt")
german_data = haystack_utils.load_txt_data("wmt_german_large.txt")

english_activations = {}
german_activations = {}
for layer in range(3, 6):
    english_activations[layer] = get_mlp_activations(english_data[:200], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:200], layer, model, mean=False)

LOG_PROB_THRESHOLD = -7
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

    
def activate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value
activate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)


libgomp: Thread creation failed: Resource temporarily unavailable
/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598:  arrow::fs::FinalizeS3 was not called even though S3 was initialized.  This could lead to a segmentation fault at exit


: 

In [3]:
def get_pos_loss_diff(prompt: str, model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]], plot_hist=False, use_activate_hook=False, debug_log=True):
    tokens = model.to_tokens(prompt)
    if use_activate_hook:
        original_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=activate_neurons_hook, loss_per_token=True)
    else:
        original_loss = model(tokens, return_type="loss", loss_per_token=True)
    ablated_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=deactivate_neurons_hook, loss_per_token=True)
    
    # Positive difference = loss increase due to ablation
    loss_difference = (ablated_loss - original_loss).flatten()

    if debug_log:
        print(f"Unablated loss: {original_loss.flatten()}")
        print(f"Ablated loss: {ablated_loss.flatten()}")
        print(f"Loss difference: {loss_difference}")

    if plot_hist:
        fig = px.histogram(loss_difference.flatten().cpu().numpy(), title="Loss difference due to ablation per position")
        fig.show()
    return loss_difference

def get_high_loss_prompts(prompts: list[str], model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]]):
    max_diffs = []
    average_diffs = []
    for prompt in tqdm(prompts):
        loss_difference = get_pos_loss_diff(prompt, model, activate_neurons_hook, deactivate_neurons_hook)
        max_diffs.append(loss_difference.max().item())
        average_diffs.append(loss_difference.mean().item())
    return max_diffs, average_diffs


In [4]:
def get_top_differences_at_position(prompt: str, model: HookedTransformer, position: int, top_k=20, mode="full"):
    tokens = model.to_tokens(prompt)
    str_tokens = model.to_str_tokens(tokens)
    # Logprobs instead of logits
    original_logits = model(tokens, return_type="logits")
    if mode=="direct":
        to_freeze = ["blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out"]
        ablated_logits = haystack_utils.get_frozen_logits(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=to_freeze)
    elif mode=="indirect":
        to_freeze = ["blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out"]
        ablated_logits = haystack_utils.get_ablated_logits(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=to_freeze)
    else:
        assert mode=="full"
        ablated_logits = model.run_with_hooks(tokens, return_type="logits", fwd_hooks=deactivate_neurons_fwd_hooks)
    original_logprob = original_logits.log_softmax(dim=-1)
    ablated_logprob = ablated_logits.log_softmax(dim=-1)

    # Positive difference = the German neuron makes the token more likely
    # Negative difference = the German neuron makes the token less likely
    logprob_differences = original_logprob - ablated_logprob
    logit_differences = original_logits - ablated_logits

    print("Prompt:", prompt)
    print(f"Differences for predicting: {str_tokens[position]} -> {str_tokens[position+1]}")

    low_log_prob = torch.argwhere(((original_logprob[0, position, :] <= LOG_PROB_THRESHOLD) & (ablated_logprob[0, position, :] <= LOG_PROB_THRESHOLD))).flatten()
    ignore_tokens = torch.cat([low_log_prob, all_ignore]).unique()
    
    top_original_logprobs, top_original_idx = haystack_utils.top_k_with_exclude(original_logprob[0, position, :].flatten(), top_k, exclude=ignore_tokens)
    top_original_ablated_logprobs = ablated_logprob[0, position, top_original_idx]
    top_ablated_logprobs, top_ablated_idx = haystack_utils.top_k_with_exclude(ablated_logprob[0, position, :].flatten(), top_k, exclude=ignore_tokens)
    top_ablated_original_logprobs = original_logprob[0, position, top_ablated_idx]

    top_original_values = logprob_differences[0, position, top_original_idx]
    top_ablated_values = logprob_differences[0, position, top_ablated_idx]
    top_original_logit_diff = logit_differences[0, position, top_original_idx]
    top_ablated_logit_diff = logit_differences[0, position, top_ablated_idx]
    print("Top predictions with German neuron active (unablated)")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_original_idx), top_original_values.cpu().tolist(), max_value=5, original_log_probs=top_original_logprobs.cpu().tolist(), ablated_log_probs=top_original_ablated_logprobs.cpu().tolist(), logit_difference=top_original_logit_diff.cpu().tolist())
    print("Top predictions with German neuron disabled (ablated)")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_ablated_idx), top_ablated_values.cpu().tolist(), max_value=5, original_log_probs=top_ablated_original_logprobs.cpu().tolist(), ablated_log_probs=top_ablated_logprobs.cpu().tolist(), logit_difference=top_ablated_logit_diff.cpu().tolist())

    top_boosts, top_boosted_idx = haystack_utils.top_k_with_exclude(logprob_differences[:, position, :].flatten(), top_k, exclude=ignore_tokens)
    top_boost_original_logprob = original_logprob[0, position, top_boosted_idx]
    top_boost_ablated_logprob = ablated_logprob[0, position, top_boosted_idx]
    top_reduced, top_reduced_idx = haystack_utils.top_k_with_exclude(logprob_differences[:, position, :].flatten(), top_k, largest=False, exclude=ignore_tokens)
    top_reduced_original_logprob = original_logprob[0, position, top_reduced_idx]
    top_reduced_ablated_logprob = ablated_logprob[0, position, top_reduced_idx]
    print("Top boosted tokens by German neuron")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_boosted_idx), top_boosts.cpu().tolist(), max_value=5, original_log_probs=top_boost_original_logprob.cpu().tolist(), ablated_log_probs=top_boost_ablated_logprob.cpu().tolist())
    print("Top reduced tokens by German neuron")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_reduced_idx), top_reduced.cpu().tolist(), max_value=5, original_log_probs=top_reduced_original_logprob.cpu().tolist(), ablated_log_probs=top_reduced_ablated_logprob.cpu().tolist())

In [5]:
def show_token_loss(prompt: str, model: HookedTransformer, max_value=None, mode="full", freeze_act_names=("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out")):
    
    original_loss, total_effect_loss_change, direct_effect_loss_change, indirect_effect_loss_change = haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names, debug_log=False)
    if mode== "indirect":
        pos_wise_loss = indirect_effect_loss_change
        #pos_wise_loss = haystack_utils.get_frozen_loss_difference_measure(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names)
    elif mode == "direct":
        pos_wise_loss = direct_effect_loss_change
        #pos_wise_loss = haystack_utils.get_ablated_loss_difference_measure(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names)
    else:
        assert mode =="full"
        pos_wise_loss = total_effect_loss_change
        #pos_wise_loss = get_pos_loss_diff(prompt, model, activate_neurons_fwd_hooks, deactivate_neurons_fwd_hooks, plot_hist=False)
    str_token_prompt = model.to_str_tokens(model.to_tokens(prompt))
    haystack_utils.print_strings_as_html(str_token_prompt[1:], pos_wise_loss.flatten().cpu().tolist(), max_value=max_value)

def print_predictions(prompt, pos, k=20):
    print("\nFull model predictions")
    get_top_differences_at_position(prompt, model, pos, k, mode="full")
    print("\nIndirect predictions (leave German neuron active, patch corrupted activations to later components)")
    get_top_differences_at_position(prompt, model, pos, k, mode="indirect")
    print("\nDirect predictions (ablate German neuron, patch clean activations to later components)")
    get_top_differences_at_position(prompt, model, pos, k, mode="direct")

In [114]:
def pos_wise_mlp_effect_on_single_prompt(prompt: str, k = 20, log=False, top_neurons=None, answer_pos=None):
    
    if answer_pos is not None:
        pos = answer_pos
        assert pos != 0, "First answer position = 1"
        if pos < 0:
            pos = model.to_tokens(prompt).shape[1]-1+pos
        elif pos > 0:
            pos -= 1
    else:
        pos = None

    if (top_neurons is not None) and (len(top_neurons) < k):
            print(f"Warning: Only {len(top_neurons)} neurons given for k={k}.")

    original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss = haystack_utils.get_mlp5_attribution_without_mlp4(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, pos=pos)
    if top_neurons is None:
        differences = haystack_utils.MLP_attribution(prompt, model, fwd_hooks=deactivate_neurons_fwd_hooks, layer_to_compare=5, pos=pos)
        # Shape (pos, k)
        top_diff, top_diff_neurons = torch.topk(differences, k, largest=True)
        #print(top_diff_neurons)
    else:
        top_diff_neurons = torch.LongTensor(top_neurons)
    
    # TODO Frozen loss does not match for single position and pos=None
    if pos is None:
        _, _, frozen_loss = haystack_utils.get_neuron_loss_attribution(prompt, model, top_diff_neurons[:, :k], ablation_hooks=deactivate_neurons_fwd_hooks, pos=pos)
    else:
        _, _, frozen_loss = haystack_utils.get_neuron_loss_attribution(prompt, model, top_diff_neurons[:k], ablation_hooks=deactivate_neurons_fwd_hooks, pos=pos)
    
    ablation_loss_increase = total_effect_loss - original_loss
    frozen_loss_decrease = total_effect_loss - frozen_loss

    if log and (pos is not None):
        print(f"\n{prompt}")
        print(f"Original loss: {original_loss:.4f}")
        print(f"Total effect loss: {total_effect_loss:.4f}")#
        print(f"Direct effect loss of MLP3 and MLP5 (restoring MLP4 and attention): {direct_mlp3_mlp5_loss:.4f}")
        print(f"Direct effect loss of MLP3 (restoring MLP4 and MLP5 and attention): {direct_mlp3_loss:.4f}")
        print(f"Total effect loss when freezing top MLP5 neurons: {frozen_loss:.4f}")
    elif log:
        print(f"\n{prompt}")
        print(f"Original loss: {original_loss}")
        print(f"Total effect loss: {total_effect_loss}")#
        print(f"Direct effect loss of MLP3 and MLP5 (restoring MLP4 and attention): {direct_mlp3_mlp5_loss}")
        print(f"Direct effect loss of MLP3 (restoring MLP4 and MLP5 and attention): {direct_mlp3_loss}")
        print(f"Total effect loss when freezing top MLP5 neurons: {frozen_loss}")
    
    return original_loss, total_effect_loss, direct_mlp3_mlp5_loss, direct_mlp3_loss, frozen_loss

In [115]:
pos_wise_mlp_effect_on_single_prompt("Ich möchte nochmals meine Ansicht", k=20, log=True, answer_pos=None)


Ich möchte nochmals meine Ansicht
Original loss: tensor([10.5044,  5.3246,  0.0203,  4.4463,  3.9146,  0.4182,  6.2384,  2.9002,
         9.2651,  4.7154], device='cuda:0')
Total effect loss: tensor([10.4859,  6.1975,  0.0941,  5.4040,  4.3693,  1.0097,  6.9663,  4.0273,
        10.5419, 12.9027], device='cuda:0')
Direct effect loss of MLP3 and MLP5 (restoring MLP4 and attention): tensor([10.4888,  5.9200,  0.0217,  4.7280,  3.5013,  0.6890,  6.1393,  2.8607,
         9.1133,  7.9633], device='cuda:0')
Direct effect loss of MLP3 (restoring MLP4 and MLP5 and attention): tensor([10.4970,  5.8691,  0.0374,  5.2219,  2.9650,  0.2496,  6.1509,  2.6764,
         9.3340,  3.3191], device='cuda:0')
Total effect loss when freezing top MLP5 neurons: tensor([10.4843,  5.7237,  0.0467,  4.6787,  2.7740,  0.0540,  5.3758,  1.9933,
         9.5463,  7.1096], device='cuda:0')


(tensor([10.5044,  5.3246,  0.0203,  4.4463,  3.9146,  0.4182,  6.2384,  2.9002,
          9.2651,  4.7154], device='cuda:0'),
 tensor([10.4859,  6.1975,  0.0941,  5.4040,  4.3693,  1.0097,  6.9663,  4.0273,
         10.5419, 12.9027], device='cuda:0'),
 tensor([10.4888,  5.9200,  0.0217,  4.7280,  3.5013,  0.6890,  6.1393,  2.8607,
          9.1133,  7.9633], device='cuda:0'),
 tensor([10.4970,  5.8691,  0.0374,  5.2219,  2.9650,  0.2496,  6.1509,  2.6764,
          9.3340,  3.3191], device='cuda:0'),
 tensor([10.4843,  5.7237,  0.0467,  4.6787,  2.7740,  0.0540,  5.3758,  1.9933,
          9.5463,  7.1096], device='cuda:0'))

In [117]:
pos_wise_mlp_effect_on_single_prompt("Ich möchte nochmals meine Ansicht", k=20, log=True, answer_pos=-1)


Ich möchte nochmals meine Ansicht
Original loss: 4.7154
Total effect loss: 12.9027
Direct effect loss of MLP3 and MLP5 (restoring MLP4 and attention): 7.9633
Direct effect loss of MLP3 (restoring MLP4 and MLP5 and attention): 3.3191
Total effect loss when freezing top MLP5 neurons: 7.1096


(4.715377330780029,
 12.902650833129883,
 7.963266372680664,
 3.3190646171569824,
 7.10962438583374)

## Look for interesting examples

In [10]:
def get_interesting_loss_prompts(prompts: list[str], model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]]):
    max_diffs = []
    average_diffs = []
    for prompt in tqdm(prompts):
        original_loss, total_effect_loss_change, direct_loss, indirect_loss = haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, debug_log=False)
        #indirect_loss = haystack_utils.get_frozen_loss_difference_measure(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks)
        #direct_loss = haystack_utils.get_ablated_loss_difference_measure(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks)
        loss_difference = abs(direct_loss - indirect_loss)
        loss_difference[indirect_loss < 1] = 0
        loss_difference[direct_loss > 1] = 0
        max_diffs.append(loss_difference.max().item())
        average_diffs.append(loss_difference.mean().item())
    return max_diffs, average_diffs

n_examples = 1000
max_diffs, average_diffs = get_interesting_loss_prompts(german_data, model, activate_neurons_fwd_hooks, deactivate_neurons_fwd_hooks)

loss_data_tuple = [(diff, example) for diff, example in zip(max_diffs, german_data[:n_examples])]
loss_data_tuple.sort(key=lambda x: x[0], reverse=True)
loss_data_tuple[:2]

  0%|          | 0/2459 [00:00<?, ?it/s]

[(11.32781982421875,
  'Ich möchte nochmals meine Ansicht wiederholen, dass die Menschenrechte als ein selbstverständlicher Teil fest in die Außenpolitik der EU verankert werden müssen, und ich glaube, dass die neue institutionelle Struktur der EU, besonders im Hinblick auf den EAD und die zu ihm gehörige Abteilung, eine Gelegenheit bietet, den Zusammenhalt und die Effizienz der EU auf diesem Gebiet zu stärken; ich fordere die Vizepräsidentin/Hohe Vertreterin dringend dazu auf, die Initiative zu ergreifen und durch bilaterale Beziehungen zu Drittländern und eine aktive Teilnahme an internationalen Foren dafür zu sorgen, dass sich die Drittländer für den Schutz der Menschenrechte engagieren und sich gegen Verletzungen der Menschenrechte wehren, anstatt vor den nötigen Maßnahmen zurückzuschrecken, wenn sie missachtet werden; angesichts des wachsenden Ausmaßes der schweren Verletzungen der Glaubensfreiheit, rufe ich die Kommission dazu auf, eine gründliche Evaluierung vorzunehmen und die 

In [11]:
def show_all_loss_types(prompt):
    print("Full model loss")
    show_token_loss(prompt, model, max_value=5, mode="full")
    print("Indirect loss")
    show_token_loss(prompt, model, max_value=5, mode="indirect")
    print("Direct loss")
    show_token_loss(prompt, model, max_value=5, mode="direct")

In [12]:
for _ , prompt in loss_data_tuple[0:2]:
    print("")
    show_token_loss(prompt, model, max_value=5, mode="full")
    show_token_loss(prompt, model, max_value=5, mode="indirect")
    show_token_loss(prompt, model, max_value=5, mode="direct")







In [13]:
prompt = "Ich möchte nochmals meine Ansicht"
# Check loss MLP5 loss increase when patching clean activations to MLP4
show_all_loss_types(prompt)
get_mlp5_attribution_without_mlp4(prompt)

Full model loss


Indirect loss


Direct loss


NameError: name 'get_mlp5_attribution_without_mlp4' is not defined

In [14]:
pos_wise_mlp_effect_on_single_prompt("Ich möchte nochmals meine Ansicht", k=20, log=True, pos=-1)

torch.Size([512])
torch.Size([512])

Ich möchte nochmals meine Ansicht
MLP 5 attribution
Original loss: 4.7154
Total effect loss: 12.9027
Direct effect loss of MLP3 and MLP5 (restoring MLP4 and attention): 7.9633
Direct effect loss of MLP3 (restoring MLP4 and MLP5 and attention): 3.3191


TypeError: unsupported format string passed to Tensor.__format__