In [10]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from typing import List, Tuple
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [11]:
news_data = haystack_utils.load_txt_data("german_news.csv")
print(len(news_data))

german_news.csv: Loaded 9246 examples with 0 to 22124 characters each.
9246


In [12]:
# Build the German news corpus: drop everything up to and including the first
# ";" of each row, keep only bodies longer than 500 characters, and cap each
# kept body at 2000 characters.
# NOTE(review): the last row of news_data is skipped — presumably it has no ";"
# separator (str.index would raise); confirm against the CSV.
german_news_data = []
for example in tqdm(news_data[:-1]):
    index = example.index(";")
    example = example[index + 1:]
    if len(example) <= 500:
        continue
    # Slicing past the end is safe, so no explicit min() is needed.
    german_news_data.append(example[:2000])

print(len(german_news_data))

  0%|          | 0/9245 [00:00<?, ?it/s]

8979


In [13]:
# Load Pythia-70m into TransformerLens with LayerNorm folding and weight centering enabled.
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

# Parallel Europarl corpora used as the "German context" and "English context" reference sets.
german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")


# MLP activations for layers 3, 4 and 5 on the first 200 examples of each language,
# keyed by layer index.
# NOTE(review): with mean=False these are presumably per-token activations of shape
# (tokens, d_mlp) — confirm in haystack_utils.get_mlp_activations.
english_activations = {}
german_activations = {}
for layer in range(3, 6):
    english_activations[layer] = get_mlp_activations(english_data[:200], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:200], layer, model, mean=False)


# Configuration for the single-neuron context ablation used throughout this notebook.
# NOTE(review): LOG_PROB_THRESHOLD is not referenced anywhere in the visible cells.
LOG_PROB_THRESHOLD = -7
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
# Scalar means (over every selected element) of the chosen neurons on German vs. English text;
# used as the "context on" and "context off" replacement values in the hooks below.
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    """Overwrite NEURONS_TO_ABLATE with MEAN_ACTIVATION_INACTIVE (their mean on English text).

    `value` is modified in place at every batch element and position, then returned.
    The globals are looked up when the hook runs, not when it is defined.
    """
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
# Hook list that attaches the hook above to the post-activation MLP output of LAYER_TO_ABLATE.
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

def activate_neurons_hook(value, hook):
    """Overwrite NEURONS_TO_ABLATE with MEAN_ACTIVATION_ACTIVE (their mean on German text).

    `value` is modified in place at every batch element and position, then returned.
    The globals are looked up when the hook runs, not when it is defined.
    """
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value
# Hook list that attaches the hook above to the post-activation MLP output of LAYER_TO_ABLATE.
activate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

# Token-id sets returned by haystack_utils.get_weird_tokens; `all_ignore` is later used to
# zero out entries of the unigram count table.
# NOTE(review): the exact criterion for a "weird" token lives in haystack_utils — confirm there.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [14]:

# Neuron index lists for the multi-neuron ablations in layers 3 and 4.
# `deprecated_german_neurons_l4` is the layer-4 list without neuron 1903.
german_neurons_l3 = [669, 1204]
german_neurons_l4 = [482, 326, 1903]
deprecated_german_neurons_l4 = [482, 326]

# Every hook below overwrites the listed neurons, in place, at every batch element and position.
# NOTE(review): `.mean()` is called without a dim, so all listed neurons are set to ONE pooled
# scalar (mean over tokens and over the listed neurons), not to a per-neuron mean — confirm
# this is intended.

def deprecated_deactivate_neurons_hook_l4(value, hook):
    """Set the deprecated layer-4 neuron set to its pooled mean activation on English text."""
    value[:, :, deprecated_german_neurons_l4] = english_activations[4][:, deprecated_german_neurons_l4].mean()
    return value
deprecated_deactivate_neurons_fwd_hooks_l4=[(f'blocks.{4}.mlp.hook_post', deprecated_deactivate_neurons_hook_l4)]

def deactivate_neurons_hook_l4(value, hook):
    """Set the layer-4 neuron set to its pooled mean activation on English text."""
    value[:, :, german_neurons_l4] = english_activations[4][:, german_neurons_l4].mean()
    return value
deactivate_neurons_fwd_hooks_l4=[(f'blocks.{4}.mlp.hook_post', deactivate_neurons_hook_l4)]

def activate_neurons_hook_l4(value, hook):
    """Set the layer-4 neuron set to its pooled mean activation on German text."""
    value[:, :, german_neurons_l4] = german_activations[4][:, german_neurons_l4].mean()
    return value
activate_neurons_fwd_hooks_l4=[(f'blocks.{4}.mlp.hook_post', activate_neurons_hook_l4)]

def deprecated_activate_neurons_hook_l4(value, hook):
    """Set the deprecated layer-4 neuron set to its pooled mean activation on German text."""
    value[:, :, deprecated_german_neurons_l4] = german_activations[4][:, deprecated_german_neurons_l4].mean()
    return value
deprecated_activate_neurons_fwd_hooks_l4=[(f'blocks.{4}.mlp.hook_post', deprecated_activate_neurons_hook_l4)]

def deactivate_neurons_hook_l3(value, hook):
    """Set the layer-3 neuron set to its pooled mean activation on English text."""
    value[:, :, german_neurons_l3] = english_activations[3][:, german_neurons_l3].mean()
    return value
deactivate_neurons_fwd_hooks_l3=[(f'blocks.{3}.mlp.hook_post', deactivate_neurons_hook_l3)]
# Combined list: ablate the layer-3 and layer-4 neuron sets in the same forward pass.
deactivate_neurons_fwd_hooks_l3_l4=deactivate_neurons_fwd_hooks_l3+deactivate_neurons_fwd_hooks_l4

def activate_neurons_hook_l3(value, hook):
    """Set the layer-3 neuron set to its pooled mean activation on German text."""
    value[:, :, german_neurons_l3] = german_activations[3][:, german_neurons_l3].mean()
    return value
activate_neurons_fwd_hooks_l3=[(f'blocks.{3}.mlp.hook_post', activate_neurons_hook_l3)]
# Combined list: activate the layer-3 and layer-4 neuron sets in the same forward pass.
activate_neurons_fwd_hooks_l3_l4=activate_neurons_fwd_hooks_l3+activate_neurons_fwd_hooks_l4


In [15]:
#german_data = german_news_data
print(len(german_data))

2000


# Direct MLP5 effects

In [16]:
# Collect the four per-position loss tensors returned by get_direct_effect for each
# German prompt, using the single-neuron (layer 3) ablation/activation hooks.
german_losses = []
for prompt in tqdm(german_data[:2000]):
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = (
        haystack_utils.get_direct_effect(
            prompt,
            model,
            pos=None,
            context_ablation_hooks=deactivate_neurons_fwd_hooks,
            context_activation_hooks=activate_neurons_fwd_hooks,
        )
    )
    loss_record = (original_loss, ablated_loss, context_and_activated_loss, only_activated_loss)
    german_losses.append(loss_record)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [17]:
def interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss):
    """Score each position by how interesting the context-neuron ablation is there.

    The score is ``0.5 * (ablated_loss - original_loss)`` minus the non-negative part of
    ``only_activated_loss - original_loss``. It is high where ablation hurts a lot but
    the "only activated" run stays close to the original loss.

    Positions are zeroed when the original loss exceeds 6 or when ablation did not
    increase the loss. ``context_and_activated_loss`` is accepted for signature
    symmetry but not used. The inputs are not modified.
    """
    ablation_increase = ablated_loss - original_loss
    # Only penalise loss *increases* in the "only activated" run; decreases count as zero.
    activated_increase = (only_activated_loss - original_loss).clamp(min=0)
    score = 0.5 * ablation_increase - activated_increase

    uninteresting = (original_loss > 6) | (original_loss > ablated_loss)
    score[uninteresting] = 0
    return score

In [18]:
def get_mlp5_decrease_measure(losses: list[tuple[Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"]]]):
    """Return, for each prompt, the maximum per-position interest score.

    `losses` holds one 4-tuple of per-position loss tensors per prompt, in the order
    (original, ablated, context_and_activated, only_activated). The result is a list
    of Python floats in the same order as `losses`.
    """
    return [
        interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss).max().item()
        for original_loss, ablated_loss, context_and_activated_loss, only_activated_loss in losses
    ]

# Rank prompts by their best per-position interest score, highest first.
# `sorted_measure` is a list of (prompt_index, score) pairs.
measure = get_mlp5_decrease_measure(german_losses)
index = list(range(len(measure)))

sorted_measure = sorted(zip(index, measure), key=lambda pair: pair[1], reverse=True)

In [19]:
def print_prompt(prompt: str):
    """Render `prompt` as HTML, coloured by the per-position interest score.

    Runs get_direct_effect with the single-neuron hooks and passes the four raw loss
    curves along as additional measures. The first string token is dropped so the
    tokens line up with the per-position losses.
    """
    str_tokens = model.to_str_tokens(model.to_tokens(prompt))
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = haystack_utils.get_direct_effect(
        prompt,
        model,
        pos=None,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks,
    )

    position_scores = interest_measure(
        original_loss, ablated_loss, context_and_activated_loss, only_activated_loss
    ).flatten().cpu().tolist()

    # Insertion order of the dict fixes the pairing between names and value lists.
    named_losses = {
        "original_loss": original_loss,
        "ablated_loss": ablated_loss,
        "context_and_activated_loss": context_and_activated_loss,
        "only_activated_loss": only_activated_loss,
    }
    loss_values = [tensor.flatten().cpu().tolist() for tensor in named_losses.values()]
    haystack_utils.clean_print_strings_as_html(
        str_tokens[1:],
        position_scores,
        max_value=5,
        additional_measures=loss_values,
        additional_measure_names=list(named_losses),
    )

In [20]:
average_loss_plot = haystack_utils.get_average_loss_plot_method(activate_neurons_fwd_hooks_l3_l4, deactivate_neurons_fwd_hooks_l3_l4, "MLP5")

## Looking for ngrams

In [21]:
# Show the two prompts with the highest interest score.
# The score is bound to a throwaway name so this loop does not overwrite the global
# `measure` list (one score per prompt) with a single float.
for i, _score in sorted_measure[:2]:
    print_prompt(german_data[i])

In [22]:
prompts = [
    "Der Handwerker wird einen maßgeschneiderten Schrank fertigen.",
    "Sie fertigt kunstvolle Keramikobjekte in ihrer Werkstatt.",
    "Der Architekt plant das Haus und der Bauunternehmer fertigt es.",
    "Das Unternehmen fertigt hochwertige Elektronikprodukte.",
    "Er hat eine Skulptur aus Metall fertiggestellt.",
    "Der Bäcker wird morgen früh das Brot fertigbacken.",
    "Die Designerin fertigt einzigartige Kleider für ihre Kunden.",
    "Der Ingenieur fertigt einen detaillierten Bauplan für das Projekt.",
    "Die Fabrik fertigt täglich tausende von Autos.",
    "Die Schneiderei fertigt individuelle Kleidungsstücke nach Maß."
]


for prompt in prompts:
    print_prompt(prompt)

## häufig

In [23]:
prompts = [
    "In der Sommerzeit kommt es häufig",
    "Der Zug hat häufig",
    "Eine gesunde Ernährung ist häufig",
    "Bei stressigen Situationen treten häufig",
    "Im Winter erkranken Menschen häufig",
    "In vielen Unternehmen werden häufig",
    "Eine gute Kommunikation ist häufig",
    "Um Verkehrsstaus zu vermeiden, fahren viele Pendler häufig"
]

for prompt in prompts:
    print_prompt(prompt)

In [24]:
average_loss_plot(prompts, model, token="hä-u-fig")

## schließt

In [25]:
prompts = [
    "Bob, nachdem er alle Optionen gründlich erwogen hat, schließt",
    "Der Ausschuss hat die Angelegenheit gründlich erwogen und schließt",
    "Die Kommission hat eine wichtige Entscheidung getroffen und schließt",
    "Zusammen haben sie die Angelegenheit gründlich erwogen und schließt",
    "Die Polizei schließt"
]

for prompt in prompts:
    print_prompt(prompt)


In [26]:
average_loss_plot(prompts, model, token="sch-lie-ß-t")

## beweglich

In [27]:
prompts = [
    "Zum Glück hat sich Klaus gut von seinem Unfall erholt und blieb geistig äußerst beweglich",
    "Das neue Design des Roboteres macht die Gliedmaßen besonders weit beweglich",
    "Die Roboterarme sind flexibel und beweglich",
    "Der Künstler hat eine Skulptur geschaffen, die optisch sehr beweglich",
    "Die Yoga-Übungen machen den Körper geschmeidig und beweglich",
    "Die Schubladen des Schranks sind dank der Rollen besonders beweglich",
    "Das neue Mobiltelefon ist leicht und beweglich",
    "Die Puppen sind sehr beweglich",
    "Die Beine vom Tisch sind beweglich",
    "Die Schmetterlinge fliegen anmutig und beweglich"
]



for prompt in prompts:
    print_prompt(prompt)

In [28]:
average_loss_plot(prompts, model, token="be-we-glich")

## Ansicht (not a good example)

In [29]:
prompts = [
    "Ich finde ihre Ansicht",
    "Ich teile seine Ansicht",
    "Ich teile deine Ansicht",
    "Die politische Ansicht",
    "Sie teilt die Ansicht",
    "Ich stimme mit deiner Ansicht"
]

for prompt in prompts:
    print_prompt(prompt)

In [30]:
average_loss_plot(prompts, model, token="Ans-icht")

## Vorschlägen

In [31]:
# Next example ("nächster"): prompts ending in "Vorschlägen"
prompts = ["Ich habe noch einige Fragen zu Ihren Vorschlägen",
    "Ich stimme den Vorschlägen",
    "Kannst du bitte mit deinen Vorschlägen",
    "Laut den Vorschlägen",
    "Die Diskussion wurde mit vielen interessanten Vorschlägen",
    "Sie schrieb einen Brief mit ihren Vorschlägen",
    "Wir werden nach deinen Vorschlägen",
    "Sind Sie mit diesen Vorschlägen",
    "Gemäß den Vorschlägen",
    "Der Kunde war nicht einverstanden mit unseren Vorschlägen",
    "Sie zeigte uns ein Dokument mit ihren Vorschlägen",
    "Das Team arbeitet an den Vorschlägen",
    "Meinen Vorschlägen",
    "Der Ausschuss wird nach den Vorschlägen",
    "Sie war unzufrieden mit den Vorschlägen",
    "Unsere Agentur kam mit neuen Vorschlägen",
    "Haben Sie Änderungen zu den Vorschlägen",
    "Gemäß Ihren Vorschlägen",
    "Ich schrieb einen Bericht mit meinen Vorschlägen",
    "Nach den Vorschlägen",
    "Ich werde gemäß Ihren Vorschlägen",
    "Der Manager war unzufrieden mit den Vorschlägen",
    "Mit diesen Vorschlägen",
    "Der Ausschuss stimmte den Vorschlägen",
    "Der Leiter war sehr zufrieden mit den Vorschlägen",
    "Nach den aktuellen Vorschlägen",
    "Mit ihren innovativen Vorschlägen",
    "Mit einigen Verbesserungen zu Ihren Vorschlägen",
    "Das Team kam mit neuen Vorschlägen",
    "Sie kam mit großartigen Vorschlägen",
    "Ich werde mit meinen Vorschlägen",
    "Sie zeigte sich zufrieden mit den Vorschlägen",
    "Der Direktor war beeindruckt von den Vorschlägen",
    "Die Organisation hat nach Ihren Vorschlägen",
    "Wir haben ein Dokument mit den Vorschlägen",
    "Mit einigen Vorschlägen",
    "Der Lehrer war sehr zufrieden mit den Vorschlägen",
    "Nach Ihren Vorschlägen",
    "In Übereinstimmung mit den Vorschlägen",
    "Sie stimmte den Vorschlägen",
    "Mit diesen neuen Vorschlägen",
    "Ich bin sehr beeindruckt von Ihren Vorschlägen",
    "Das Unternehmen hat nach unseren Vorschlägen",
    "Die Jury war beeindruckt von den Vorschlägen",
    "Die Verwaltung hat nach Ihren Vorschlägen",
    "Das Publikum war sehr zufrieden mit den Vorschlägen",
]


for prompt in prompts[:5]:
    print_prompt(prompt)

In [32]:
average_loss_plot(prompts, model, token="Real prompts ending in [V-orsch-lä-gen]")

## Check previous tokens actually matter

### Generate some likely tokens from german data

In [33]:
# Unigram frequency table over the German corpus: token_counts[t] is the number of times
# token id t occurs across all tokenized examples (including whatever to_tokens prepends).
# torch.bincount tallies one example at a time, replacing a Python loop that called
# .item() per token; the resulting counts are identical.
token_counts = torch.zeros(model.cfg.d_vocab)
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    # Move ids to the CPU to match token_counts; minlength pads the histogram to d_vocab.
    token_counts += torch.bincount(tokens[0].cpu(), minlength=model.cfg.d_vocab)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [34]:
# Exclude the "weird" tokens flagged earlier from the frequency ranking.
token_counts[all_ignore] = 0

# Also exclude punctuation, with and without a leading space, plus a few whitespace/special strings.
punctuation = [".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
# `[:, 1]` keeps only column 1 of the token matrix for each string.
# NOTE(review): presumably column 0 is a prepended BOS token, so strings that tokenize to
# several tokens only have their first token excluded — confirm.
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
token_counts[punctuation_tokens] = 0

# The 100 most frequent remaining token ids; `top_tokens` is reused by later cells as the
# pool of common German tokens.
top_counts, top_tokens = torch.topk(token_counts, 100)
print(model.to_str_tokens(top_tokens[:100]))

[' der', '\n', 'en', ' die', ' und', 'ung', 'ä', ' in', ' den', ' des', 'ch', 'st', ' zu', 're', ' für', 'äsident', ' Pr', 'n', 'z', 'ischen', ' von', 'ü', 't', 'icht', 'in', 'ge', 'gen', 'te', ' ist', ' auf', 'ig', ' über', ' dass', ' im', 'f', ' er', 'es', ' das', 'men', 'g', 'ß', ' Europ', ' w', 'w', 'le', 'ten', ' eine', ' wir', ' ein', ' an', 'hen', 'ren', 'e', ' ich', 'ungen', ' W', ' Ver', ' B', ' dem', ' mit', ' dies', ' nicht', ' Z', 'h', ' z', 's', 'it', 'hr', ' es', ' zur', ' An', ' Herr', 'ich', 'heit', 'b', 'lich', 'l', ' ver', ' S', ' G', 'i', 'Der', ' V', 'der', ' Ab', 'u', 'ie', 'ungs', 'chte', 'chaft', 'igen', ' werden', 'uss', 'ord', 'em', ' Ber', 'ür', ' haben', 'et', ' um']


In [35]:
def replace_token(prompts, replace_index, token="", num_replacements=20):
    """Plot average losses after swapping one token position for common German tokens.

    For every prompt, the token at `replace_index` is replaced in turn by each of the
    first `num_replacements` entries of `top_tokens`. Replacements equal to the token
    already there are skipped (and "skipping" is printed). All variants are passed to
    `average_loss_plot`, with `token` as its label.
    """
    variants = []
    for prompt in prompts:
        tokens = model.to_tokens(prompt)
        for rank in range(num_replacements):
            candidate = top_tokens[rank]
            if tokens[0, replace_index].item() == candidate.item():
                print("skipping")
                continue
            variant = tokens.clone()
            variant[0, replace_index] = candidate
            variants.append(variant)
    average_loss_plot(variants, model, token=token)

def replace_token_loss(prompts, replace_index, num_replacements=10):
    """Return final-position losses after swapping one token for common German tokens.

    For every prompt, the token at `replace_index` is replaced in turn by each of the
    first `num_replacements` entries of `top_tokens`. Unlike `replace_token`, no
    replacement is skipped, even if it equals the token already there. Each variant is
    run with the layer-3 and layer-4 activation hooks, and the last per-token loss is
    collected.

    Returns a flat list of floats, `num_replacements` per prompt, in prompt order.
    """
    variants = []
    for prompt in prompts:
        tokens = model.to_tokens(prompt)
        for rank in range(num_replacements):
            variant = tokens.clone()
            variant[0, replace_index] = top_tokens[rank]
            variants.append(variant)

    final_losses = []
    for variant in variants:
        with model.hooks(fwd_hooks=activate_neurons_fwd_hooks_l3_l4):
            per_token_loss = model(variant, return_type="loss", loss_per_token=True)
        final_losses.append(per_token_loss.flatten()[-1].item())
    return final_losses

In [36]:
prompts = ["Ich habe noch einige Fragen zu Ihren Vorschlägen",
    "Ich stimme den Vorschlägen",
    "Kannst du bitte mit deinen Vorschlägen",
    "Laut den Vorschlägen",
    "Die Diskussion wurde mit vielen interessanten Vorschlägen",
    "Sie schrieb einen Brief mit ihren Vorschlägen",
    "Wir werden nach deinen Vorschlägen",
    "Sind Sie mit diesen Vorschlägen",
    "Gemäß den Vorschlägen",
    "Der Kunde war nicht einverstanden mit unseren Vorschlägen",
    "Sie zeigte uns ein Dokument mit ihren Vorschlägen",
    "Das Team arbeitet an den Vorschlägen",
    "Meinen Vorschlägen",
    "Der Ausschuss wird nach den Vorschlägen",
    "Sie war unzufrieden mit den Vorschlägen",
    "Unsere Agentur kam mit neuen Vorschlägen",
    "Haben Sie Änderungen zu den Vorschlägen",
    "Gemäß Ihren Vorschlägen",
    "Ich schrieb einen Bericht mit meinen Vorschlägen",
    "Nach den Vorschlägen",
    "Ich werde gemäß Ihren Vorschlägen",
    "Der Manager war unzufrieden mit den Vorschlägen",
    "Mit diesen Vorschlägen",
    "Der Ausschuss stimmte den Vorschlägen",
    "Der Leiter war sehr zufrieden mit den Vorschlägen",
    "Nach den aktuellen Vorschlägen",
    "Mit ihren innovativen Vorschlägen",
    "Mit einigen Verbesserungen zu Ihren Vorschlägen",
    "Das Team kam mit neuen Vorschlägen",
    "Sie kam mit großartigen Vorschlägen",
    "Ich werde mit meinen Vorschlägen",
    "Sie zeigte sich zufrieden mit den Vorschlägen",
    "Der Direktor war beeindruckt von den Vorschlägen",
    "Die Organisation hat nach Ihren Vorschlägen",
    "Wir haben ein Dokument mit den Vorschlägen",
    "Mit einigen Vorschlägen",
    "Der Lehrer war sehr zufrieden mit den Vorschlägen",
    "Nach Ihren Vorschlägen",
    "In Übereinstimmung mit den Vorschlägen",
    "Sie stimmte den Vorschlägen",
    "Mit diesen neuen Vorschlägen",
    "Ich bin sehr beeindruckt von Ihren Vorschlägen",
    "Das Unternehmen hat nach unseren Vorschlägen",
    "Die Jury war beeindruckt von den Vorschlägen",
    "Die Verwaltung hat nach Ihren Vorschlägen",
    "Das Publikum war sehr zufrieden mit den Vorschlägen",
]

In [37]:
average_loss_plot(prompts, model, token="V-orsch-lä-gen")

In [38]:
# Replace second to last token
replace_token(prompts, -2, token="V-orsch-X-gen")

KeyboardInterrupt: 

In [None]:
# Replace third to last token
replace_token(prompts, -3, token="V-X-lä-gen")

In [None]:
# Replace V token
replace_token(prompts, -4, token="X-orsch-lä-gen")

In [None]:
# Replace first token
replace_token(prompts, 1, token="X .... V-orsch-lä-gen")

In [None]:
# For one prompt: compare the final-token loss of the unmodified prompt with the losses
# obtained when each earlier position is replaced by the 20 most common German tokens.
prompt = "Ich habe noch einige Fragen zu Ihren Vorschlägen"
tokens = model.to_tokens(prompt)

losses = []
names = []

# Baseline: final-position loss with the layer-3 and layer-4 neuron sets forced to their German means.
with model.hooks(fwd_hooks=activate_neurons_fwd_hooks_l3_l4):
    original_loss = model(prompt, return_type="loss", loss_per_token=True).flatten()[-1].item()
# Repeat the baseline 20 times so its bar has as many samples as the replacement bars.
losses.append([original_loss]*20)
names.append("Original")
# Skips index 0 and the final token, whose own loss is what is being measured.
# NOTE(review): index 0 is presumably the BOS token — confirm.
# This loop also leaves `pos` bound at module level.
for pos in range(1, tokens.shape[1]-1):
    loss = replace_token_loss([prompt], pos, num_replacements=20)
    losses.append(loss)
    names.append(model.to_str_tokens(tokens[0, pos])[0])
haystack_utils.plot_barplot(losses, names, ylabel="Loss", xlabel="Replaced token", title=f"Average loss when replacing single token with top 20 German unigrams")

# Compare original loss with 'V' in per-token chart
#replace_token([prompt], -4, token="X-orsch-lä-gen on single prompt")


In [None]:
# Compare with per-token chart
replace_token([prompt], -4, token="X-orsch-lä-gen on single prompt")

In [None]:
new_prompts = ["Ich habe noch einige Fragen zu Ihren Vorschlägen"]

average_loss_plot(new_prompts, model, token="V-orsch-lä-gen")

In [None]:
new_prompts = [
    #" der die unden in enorschlägen"
    "Ich habe noch einige Fragen zu Ihren Vorsch dergen"
]

average_loss_plot(new_prompts, model, token="V-orsch-lä-gen")

## Weird behavior? Why is the loss so low...

In [None]:
prompt = "Kannst du bitte mit deinen Vorschlägen"
with model.hooks(fwd_hooks=activate_neurons_fwd_hooks+activate_neurons_fwd_hooks_l4):
    original_loss, cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

In [None]:
logit_attr_original, labels = haystack_utils.DLA([prompt], model)

  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
logit_attr_ablated , _ = haystack_utils.DLA(["Kannst du bitte mit deinenenorschlägen"], model)

  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
logit_diffs = (logit_attr_original - logit_attr_ablated).mean(0)
haystack_utils.line(logit_diffs.cpu().numpy(), xlabel="Correct logit", ylabel="", title="(Original DLA - Ablated DLA) per component", xticks=labels)

In [None]:
# Final-token loss for the original prompt and two corrupted variants, first with the model
# untouched and then with the layer-3 and layer-4 neuron sets forced to their German means.
loss_cases = [
    ("Kannst du bitte mit deinen Vorschlägen", "Loss on original prompt"),
    ("Kannst du bitte mit deinenenorschlägen", "Loss when replacing ' V' with common german unigram 'en'"),
    (" der die unden in enorschlägen", "Loss when replacing all early tokens with common german unigrams"),
]

print("Prompt: Kannst du bitte mit deinen Vorschlägen")
print("\nGerman neurons left as is")
for prompt, description in loss_cases:
    loss = model(prompt, return_type="loss", loss_per_token=True).flatten()[-1]
    print(f"{description}: {loss:.2f}")

print("\nGerman neurons manually activated")
with model.hooks(fwd_hooks=activate_neurons_fwd_hooks_l3_l4):
    for prompt, description in loss_cases:
        loss = model(prompt, return_type="loss", loss_per_token=True).flatten()[-1]
        print(f"{description}: {loss:.2f}")

Prompt: Kannst du bitte mit deinen Vorschlägen

German neurons left as is
Loss on original prompt: 3.22
Loss when replacing ' V' with common german unigram 'en': 0.48
Loss when replacing all early tokens with common german unigrams: 0.35

German neurons manually activated
Loss on original prompt: 3.94
Loss when replacing ' V' with common german unigram 'en': 0.66
Loss when replacing all early tokens with common german unigrams: 0.39


In [None]:
prompt = " this that is die Vorschlägen"
loss = model(prompt, return_type="loss", loss_per_token=True).flatten()[-1]
print(loss)

tensor(2.4376, device='cuda:0')


In [None]:
prompt = " der die unden in enorschlägen"
loss = model(prompt, return_type="loss", loss_per_token=True).flatten()[-1]
print(loss)

tensor(0.3482, device='cuda:0')


In [None]:
prompt = " der die unden inen enorschlägen"
prob = model(prompt, return_type="logits").log_softmax(-1)
answer_token = model.to_single_token("gen")
print(prob[0, -2, answer_token])

tensor(-0.8951, device='cuda:0')


In [None]:
# Log-probability the model assigns to each actual next token of `prompt`.
tokens = model.to_tokens(prompt)
pred_log_probs = model(prompt, return_type="logits").log_softmax(-1)

# Position p predicts token p+1: drop the last prediction and the first token, then pick
# the log-prob of the realised next token at every position.
# Uses `pred_log_probs` from this cell rather than the stale `prob` left by an earlier cell,
# so the cell no longer depends on hidden kernel state.
log_probs_for_predicted_tokens = pred_log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
log_probs_for_predicted_tokens

tensor([[-15.2951,  -6.8866, -13.0358,  -9.9525,  -7.5594,  -8.8534, -14.7382,
          -6.0190,  -0.8951]], device='cuda:0')

In [None]:
# Compare the final-token loss of the original prompt with a variant whose beginning is
# changed from "Kannst" to " derannst". No hooks are active for the `loss` values printed here.
prompt = "Kannst du bitte mit deinen Vorschlägen"
loss = model(prompt, return_type="loss", loss_per_token=True).flatten()[-1]
print_prompt(prompt)
print(loss)
prompt = " derannst du bitte mit deinen Vorschlägen"
loss = model(prompt, return_type="loss", loss_per_token=True).flatten()[-1]
print_prompt(prompt)
print(loss)

tensor(3.2218, device='cuda:0')


tensor(2.2444, device='cuda:0')


In [None]:
# Per-token chart averaged over all prompts, stepping backwards through positions to line up the vorschlagen tokens for prompts of different lengths
def get_new_prompts_from_tokens(prompts: list[torch.Tensor], replace_index, num_replacements=10):
    """Build token-replacement variants from already-tokenized prompts.

    Each tensor in `prompts` is indexed directly with `replace_index` (no batch
    dimension is added). For every prompt, one clone is produced per entry in the
    first `num_replacements` of `top_tokens`, with that position overwritten. The
    input tensors are left untouched.

    Returns a flat list of tensors, `num_replacements` per prompt, in prompt order.
    """
    variants = []
    for tokens in prompts:
        for rank in range(num_replacements):
            variant = tokens.clone()
            variant[replace_index] = top_tokens[rank]
            variants.append(variant)
    return variants

def get_new_prompts(prompts, replace_index, num_replacements=10):
    """Build token-replacement variants of string prompts, skipping prompts that are too short.

    Each prompt is tokenized with `model.to_tokens`. If `replace_index` (positive or
    negative) falls outside the tokenized prompt, that prompt contributes nothing.
    Otherwise one clone is produced per entry in the first `num_replacements` of
    `top_tokens`, with the token at `replace_index` overwritten.

    Returns a flat list of (1, seq_len) token tensors.
    """
    new_prompts = []
    for prompt in prompts:
        tokens = model.to_tokens(prompt)
        seq_len = tokens.shape[1]
        # The bounds check uses the `replace_index` argument. It previously read the
        # notebook-level loop variable `pos`, which only worked when the caller's loop
        # variable happened to have that name and the same value.
        if replace_index >= seq_len or replace_index < -seq_len:
            continue
        for i in range(num_replacements):
            new_tokens = tokens.clone()
            new_tokens[0, replace_index] = top_tokens[i]
            new_prompts.append(new_tokens)
    return new_prompts

def replace_token_loss_average(prompts, replace_index, num_replacements=10):
    """Average final-position loss over all token-replacement variants of `prompts`.

    Variants come from `get_new_prompts`, which drops prompts too short to have a
    token at `replace_index`. Each variant is run under `activate_neurons_fwd_hooks`
    and its last per-token loss is collected.

    Returns the mean loss as a float, or NaN when no prompt is long enough to yield a
    variant (this previously raised ZeroDivisionError).

    NOTE(review): this uses `activate_neurons_fwd_hooks` (the single layer-3 neuron
    list), while `replace_token_loss` uses `activate_neurons_fwd_hooks_l3_l4` — confirm
    which hook set is intended before comparing the two.
    """
    new_prompts = get_new_prompts(prompts, replace_index, num_replacements)
    if not new_prompts:
        return float("nan")

    losses = []
    for prompt in new_prompts:
        with model.hooks(fwd_hooks=activate_neurons_fwd_hooks):
            final_loss = model(prompt, return_type="loss", loss_per_token=True).flatten()[-1].item()
        losses.append(final_loss)

    return sum(losses) / len(losses)

# losses = []
# names = []
# original_losses = []
# for prompt in prompts:
#     with model.hooks(fwd_hooks=activate_neurons_fwd_hooks+activate_neurons_fwd_hooks_l4):
#         original_losses.append(model(prompt, return_type="loss", loss_per_token=True).flatten()[-1].item())
# original_loss = sum(original_losses) / len(original_losses)    
# losses.append([original_loss])
# names.append("Original")

# for pos in range(tokens.shape[1]-2, 1, -1):
#     loss = replace_token_loss(prompts, pos, num_replacements=20)
#     losses.append([loss])
#     names.append(str(pos))

# haystack_utils.plot_barplot(losses, names, ylabel="Loss", title=f"Average loss")

In [None]:
# Width of the batch-tokenized prompt matrix, used as the number of positions to sweep.
# NOTE(review): presumably this is the longest prompt's token count including BOS — confirm.
max_len = model.to_tokens(prompts).shape[1]

losses = []
names = []
original_losses = []
# Baseline: mean final-token loss over the unmodified prompts, with the layer-3 and layer-4
# neuron sets forced to their German means.
for prompt in prompts:
    with model.hooks(fwd_hooks=activate_neurons_fwd_hooks_l3_l4):
        original_losses.append(model(prompt, return_type="loss", loss_per_token=True).flatten()[-1].item())
original_loss = sum(original_losses) / len(original_losses)
losses.append([original_loss])
names.append("Original")

# Positions are counted from the end (-2, -3, ...) so the final tokens of prompts with
# different lengths line up; shorter prompts drop out at the more negative positions.
# NOTE(review): the baseline above runs under activate_neurons_fwd_hooks_l3_l4, but
# replace_token_loss_average runs its forward passes under activate_neurons_fwd_hooks,
# so the bars are not measured under the same intervention — confirm this is intended.
for pos in range(-2, -max_len-1, -1):
    loss = replace_token_loss_average(prompts, pos, num_replacements=20)
    losses.append([loss])
    names.append(str(pos))

haystack_utils.plot_barplot(losses, names, ylabel="Loss", xlabel="Position of replaced token", title=f"Average loss of replacing a single token with common german tokens across multiple prompts")

## Investigate MLP5 on random common replacement

In [None]:
top_tokens

tensor([ 1784,   187,   257,  3150,  3807,  1947,  1392,   275,  1850,   711,
          348,   296, 10736,   250, 13417, 49560,  2604,    79,    91, 16050,
         8449,  3090,    85, 11014,   249,   463,  1541,   442, 10863, 12606,
          304, 20150, 18117,   516,    71,  2827,   265,  9527,  3767,    72,
        10278, 13124,   259,    88,   282,  1866, 15827, 19129,  9416,   271,
          864,   445,    70, 18119, 17079,   411,  7188,   378,  1471,  4784,
         9778, 13014,  1503,    73,  1182,    84,   262,  6285,  1578, 24499,
          743, 27779,   469, 23867,    67, 14671,    77,  2336,   322,   443,
           74, 19932,   657,   491,  3506,    86,   466, 25190, 39997, 19696,
         3855, 19748,  1316,   636,   358,  6193,  8824, 25433,   292,  5111])

In [None]:
def get_random_selection(tensor, n=12):
    """Return up to `n` elements of `tensor`, drawn at random without replacement.

    A random permutation of the positions is cut to its first `n` entries, so if
    `n` exceeds `len(tensor)` the whole tensor is returned in shuffled order.
    """
    shuffled_positions = torch.randperm(len(tensor))
    chosen_positions = shuffled_positions[:n]
    return tensor[chosen_positions]

def generate_prompts(end_tokens, n=10, length=12):
    """Return `n` tokenized prompts: a random prefix of common tokens followed by `end_tokens`.

    Each prefix is drawn without replacement from the 50 most frequent tokens in
    `top_tokens`, so at most 50 prefix tokens are produced even if `length` is larger.
    The prefix is moved to the device of `end_tokens` instead of unconditionally to
    CUDA, so the function also works on the CPU fallback.

    Returns a list of 1-D token tensors.
    """
    prompts = []
    for _ in range(n):
        prefix = get_random_selection(top_tokens[:50], n=length).to(end_tokens.device)
        prompts.append(torch.cat([prefix, end_tokens]))
    return prompts

def loss_analysis_random_prompts(end_string, n=50, length=12, single_ablation_mode=True):
    """Bar-plot last-token losses on random common-token prompts ending in `end_string`.

    `n` prompts are generated, each `length` random common tokens followed by the tokens
    of `end_string` (its first token is dropped). For each prompt, get_direct_effect
    is evaluated at the last position and the original, ablated and only-activated losses
    are collected.

    `single_ablation_mode=True` uses the single-neuron layer-3 hooks; otherwise the
    combined layer-3 and layer-4 hook lists are used.
    """
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    prompts = generate_prompts(end_tokens, n=n, length=length)

    # The two modes differ only in which hook lists are passed on.
    if single_ablation_mode:
        ablation_hooks = deactivate_neurons_fwd_hooks
        activation_hooks = activate_neurons_fwd_hooks
    else:
        ablation_hooks = deactivate_neurons_fwd_hooks_l3_l4
        activation_hooks = activate_neurons_fwd_hooks_l3_l4

    original_losses, ablated_losses, only_activated_losses = [], [], []
    for prompt in prompts:
        original_loss, ablated_loss, _, only_activated_loss = haystack_utils.get_direct_effect(
            prompt,
            model,
            pos=-1,
            context_ablation_hooks=ablation_hooks,
            context_activation_hooks=activation_hooks,
        )
        original_losses.append(original_loss)
        ablated_losses.append(ablated_loss)
        only_activated_losses.append(only_activated_loss)

    names = ["Original", "Ablated", "MLP5 path patched"]
    haystack_utils.plot_barplot([original_losses, ablated_losses, only_activated_losses], names, ylabel="Loss", title=f"Average last token loss on {length} random tokens ending in '{model.to_str_tokens(end_tokens)}'")


In [None]:
loss_analysis_random_prompts(" Vorschlägen", n=100, length=20)

In [None]:
loss_analysis_random_prompts("orschlägen", n=100, length=20)

In [None]:
loss_analysis_random_prompts("lägen", n=100, length=20)

In [None]:
loss_analysis_random_prompts("gen", n=100, length=20)

## Other bigrams

In [None]:
loss_analysis_random_prompts(" häufig", n=100, length=20)

In [None]:
loss_analysis_random_prompts("ufig", n=100, length=20)

In [None]:
loss_analysis_random_prompts(" schließt", n=100, length=20)

In [None]:
loss_analysis_random_prompts("ließt", n=100, length=20, single_ablation_mode=False)

In [None]:
loss_analysis_random_prompts("ßt", n=100, length=20)

In [None]:
loss_analysis_random_prompts(" beweglich", n=100, length=20)

In [None]:
loss_analysis_random_prompts("weglich", n=100, length=20)

- no clean cutoff for how many important neurons we select
- should we recover total average loss?
- calculate (ablated loss - patched loss) and aim for 80% of that using important neurons?
- add neurons until improvement per neuron is "low"


- Ask Neel about choosing arbitrary cut-offs for neurons
- Go with cut-off for now: divide (ablated loss - patched loss) / 2048 to get "uniform contribution to loss value" for a neuron. If it contributes less than that, don't include it. Add an error term for noise in neuron contributions
- Average over many prompts

In [None]:
# Reload the same model with the same options as the initial load earlier in the notebook.
# NOTE(review): presumably done to get a clean model state — confirm it is still needed.
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)
# Bare expression: it is not the cell's last statement, so nothing is displayed for it.
prompts[1]

print(model.to_tokens(prompts[1]))
# Cache all activations of an un-hooked forward pass on the second prompt.
# Despite the name, no ablation hooks are active for this run.
_, ablated_cache = model.run_with_cache(prompts[1])

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
tensor([[    0, 39980,  4308,  1405,  1850,   657, 34267, 42824,  1541]],
       device='cuda:0')


In [None]:
def get_deactivate_neuron_fwd_hooks(neuron: int | list[int], ablated_cache):
    def deactivate_neurons_hook(value, hook):
        value[:, :, neuron] = ablated_cache['blocks.5.mlp.hook_post'][:, :, neuron].mean(0)
        return value
    return [('blocks.5.mlp.hook_post', deactivate_neurons_hook)]

In [None]:
# For each MLP5 neuron, measure how much the path-patched last-token loss rises
# when that single neuron is set back to its value from the ablated run.
# Only a subset of prompts is used: every prompt needs one get_direct_effect
# call per neuron.
attribution_prompts = prompts[:15]
n_neurons = model.cfg.d_mlp

# Size the prompt axis by the prompts actually evaluated. Sizing it by
# len(prompts) leaves all-zero columns that dilute any later mean over prompts.
diffs = torch.zeros(n_neurons, len(attribution_prompts))
for i, prompt in enumerate(attribution_prompts):
    # Activations of this prompt under the context ablation; these supply the
    # per-neuron replacement values below.
    with model.hooks(deactivate_neurons_fwd_hooks):
        _, ablated_cache = model.run_with_cache(prompt)

    # Baseline: fourth return value of get_direct_effect with all neurons activated.
    _, _, _, baseline_loss = haystack_utils.get_direct_effect(prompt, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
    for neuron in range(n_neurons):
        # Appended after the activation hooks so it overrides them for this neuron.
        deactivate_neuron_fwd_hooks = get_deactivate_neuron_fwd_hooks(neuron, ablated_cache)
        _, _, _, only_activated_loss = haystack_utils.get_direct_effect(prompt, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+deactivate_neuron_fwd_hooks)
        diffs[neuron, i] = only_activated_loss - baseline_loss
    

RuntimeError: The expanded size of the tensor (1) must match the existing size (10) at non-singleton dimension 0.  Target sizes: [1, 24].  Tensor sizes: [10, 24]

In [None]:
# Average each neuron's loss increase over the prompt axis.
means = diffs.mean(dim=1)

# Ascending order; `indices[k]` is the neuron (row of `diffs`) at sorted position k.
sorted_means, indices = means.sort()
sorted_means = sorted_means.tolist()

haystack_utils.line(
    sorted_means,
    xlabel="Neuron",
    ylabel="Loss increase",
    title="Loss increase when deactivating single neuron",
)  # xticks=indices

In [None]:
# Visually choose cutoff
# `indices` comes from an ascending torch.sort of the per-neuron mean loss
# increase, so its tail holds the neurons whose deactivation raises loss most.
# NOTE(review): a later cell reassigns `indices` to a 2-D sort result; re-running
# this cell after that one would slice the wrong tensor.
top_neurons_count = 100
top_neuron_indices = indices[-top_neurons_count:]

In [None]:
# Compare last-token loss across four conditions for every prompt:
# unmodified, context-ablated, MLP5 path patched, and MLP5 path patched with
# the selected important neurons set back to their ablated-run values.
original_losses, ablated_losses, only_activated_losses, important_activated_losses = [], [], [], []
# The fourth bar comes from runs where the important neurons are *deactivated*,
# so the label says so (it matches the plot title below).
names = ["Original", "Ablated", "MLP5 path patched", "MLP5 path patched without important neurons"]
for prompt in prompts:
    # Activations under the context ablation supply the replacement values.
    with model.hooks(deactivate_neurons_fwd_hooks):
        _, ablated_cache = model.run_with_cache(prompt)
    deactivate_neuron_fwd_hooks = get_deactivate_neuron_fwd_hooks(top_neuron_indices, ablated_cache)

    # The third return value is not used in this comparison.
    original_loss, ablated_loss, _, only_activated_loss = haystack_utils.get_direct_effect(prompt, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
    # Same run, with the important neurons overridden after the activation hooks.
    _, _, _, important_neurons_deactivated_loss = haystack_utils.get_direct_effect(prompt, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+deactivate_neuron_fwd_hooks)
    original_losses.append(original_loss)
    ablated_losses.append(ablated_loss)
    only_activated_losses.append(only_activated_loss)
    # NOTE: despite the list's name, these are losses with the important neurons deactivated.
    important_activated_losses.append(important_neurons_deactivated_loss)

haystack_utils.plot_barplot([original_losses, ablated_losses, only_activated_losses, important_activated_losses], names, ylabel="Loss", title="Average last token loss without important neurons")

How does switching out previous tokens for other tokens affect the value of these neurons? Are there clusters which correspond to particular token positions?

Method:
- Switch out for average of most common German unigrams
- Not looking at average difference in activation per neuron because there's no way to tell if the difference is good or bad
- Will look at average difference in loss when disabling the neuron on regular text vs. switched text
- Possible alternative (not done here): switch out for totally random tokens

Implementation:
Patch in neuron activation from run where token is swapped out to run where token is normal. Average over many swapped tokens.
- return cache
- do three runs: (1) with the token at the chosen position swapped, (2) on the original prompt with the neuron under test patched in from the swapped run, and (3) on the original, unswapped prompt as the baseline for comparison

In [None]:
end_string = " Vorschlägen"
n = 100
length = 20
num_replacements = 10
# Positions (counted from the end) of the context tokens that are swapped out, one at a time.
replace_positions = [-2, -3, -4]

# Drop the first token that to_tokens puts in front of the string.
end_tokens = model.to_tokens(end_string).flatten()[1:]
# random unigrams ending in end_string
prompts = generate_prompts(end_tokens, n=n, length=length)

# Prepend token id 0 to every prompt. Tensors are placed on the notebook-wide
# `device` rather than via .cuda(), so the cell also runs without a GPU.
leading_token = torch.zeros(1, dtype=torch.long, device=device)
prompts = [torch.cat([leading_token, prompt.to(device)]) for prompt in prompts]

# diffs[neuron, swapped position, prompt]: loss change when the neuron takes its
# value from the token-swapped runs instead of from the original prompt.
diffs = torch.zeros(top_neurons_count, len(replace_positions), len(prompts))
for replace_index, replace_pos in enumerate(replace_positions):
    for prompt_index, prompt in enumerate(prompts):
        # Baseline for this prompt: path-patched loss with nothing swapped.
        _, _, _, only_activated_loss = haystack_utils.get_direct_effect(prompt, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
        new_prompts = get_new_prompts_from_tokens([prompt], replace_pos, num_replacements=num_replacements)
        new_tokens = torch.stack(new_prompts)

        # Cache from the batch of swapped prompts; the hook built below averages
        # it over dim 0 before writing the neuron's value.
        patched_cache = haystack_utils.get_patched_cache(new_tokens, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
        for neuron_index, neuron in enumerate(top_neuron_indices):
            patched_neuron_fwd_hooks = get_deactivate_neuron_fwd_hooks(neuron, patched_cache)
            _, _, _, only_activated_with_corrupted_neuron_loss = haystack_utils.get_direct_effect(prompt, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+patched_neuron_fwd_hooks)
            diffs[neuron_index, replace_index, prompt_index] = only_activated_with_corrupted_neuron_loss - only_activated_loss

In [None]:
# Average over the prompt axis (dim 2), keeping one value per
# (neuron, swapped position) pair.
mean_diffs = torch.mean(diffs, dim=2)
mean_diffs.shape

torch.Size([100, 3])

In [None]:
# Sort every swapped-position column independently, in ascending order.
sorted_means, indices = mean_diffs.sort(dim=0)


def _plot_swap_column(column, title):
    """Line-plot one sorted column of `mean_diffs`, ticked with its sort indices."""
    return haystack_utils.line(
        sorted_means[:, column].tolist(),
        xlabel="Neuron",
        ylabel="Loss increase",
        title=title,
        xticks=indices[:, column].tolist(),
        width=1200,
    )


_plot_swap_column(0, "Loss increase when switching out a single token lä")
_plot_swap_column(1, "Loss increase when switching out a single token orsch")
_plot_swap_column(2, "Loss increase when switching out a single token V")

Identify which neurons are destructive interference vs. a mixture of constructive and destructive
- Identify important neuron output weights that barely/don't write to the correct token

Distribution of cosine sims

In [None]:
# Compare the MLP5 output weights of each selected neuron with the residual
# direction of the answer token "gen".
cosine_sim = torch.nn.CosineSimilarity(dim=1)
answer_residual_direction = model.tokens_to_residual_directions("gen")
neuron_weights = model.state_dict()['blocks.5.mlp.W_out'][top_neuron_indices]

# unsqueeze(0) adds a leading axis so the single direction broadcasts against
# every row of neuron_weights.
cosine_sims = cosine_sim(neuron_weights, answer_residual_direction.unsqueeze(0))


fig = go.Figure()
fig.add_trace(go.Histogram(x=cosine_sims.cpu().numpy(), histnorm='percent', name=''))
fig.update_layout(
    width=800,
    title_text="Cosine Similarities",
    # Histogram: x is the binned quantity, y is the share of neurons per bin
    # (histnorm='percent').
    xaxis_title_text="Cosine similarity",
    yaxis_title_text="Percent of neurons",
    bargap=0.2, # gap between bars of adjacent location coordinates
    bargroupgap=0.1 # gap between bars of the same location coordinates
)
fig.show()

torch.Size([100, 512])