When the German neurons are ablated, the model makes predictions on certain tokens that result in extremely high losses. 
Investigating these tokens, and the correct "next token" prediction, shows that many form German bigrams. The most prevalent are:

20 ('id', 'ig')
7 ('rt', 'ige')
5 (' Vert', 'rag')
5 ('he', 'ben')
4 ('ät', 'ig')
4 ('nd', 'liche')

We should run these bigrams with and without the German neurons ablated, then investigate which neurons have the most 
different activations at the position of the first bigram token. We'll do this using the raw logit difference.

Q: should we use single token prompts with the first token in each bigrams, or should we create prompts that end in the bigram?
A: we will create prompts the end in the bigram because models behave inconsistently for the first few tokens in each prompt, 
and we don't want this behaviour affecting our results.

Q: how will we create the prompts?
A: look at the word the token was used in, use GPT to generate prompts that end in this word.

20 ('id', 'ig') - Verteidigung, Verteidiger, Rechtsverteidigung
7 ('rt', 'ige') - auswärtige, sofortige, neuartigen
5 (' Vert', 'rag') - Vertragsbediensteten, vertraglichen
5 ('he', 'ben') - hervorheben, entheben
4 ('ät', 'ig') - tätig, Tätigkeit, Berufstätigen, bestätigt, gewalttätigen
4 ('nd', 'liche') - gründliche, selbstverständlichen, unmißverständlicher, ländlichen, unmissverständlichere

## Generate data with GPT4

In [31]:
prompts = {}
words = ['Verteidigung', 'auswärtige', 'Vertragsbediensteten', 'hervorheben', 'tätig', 'gründliche']

# 'Verteidigung'
prompts[('id', 'ig')] = [
    'Ich verbringe viel Zeit mit dem Studium der Theorie und Praxis der Verteidigung',
    'In seiner Rede betonte der Minister die Notwendigkeit einer starken nationalen Verteidigung'
    'Es ist wichtig, dass wir ein geeignetes Budget für die Verteidigung',
    'Sein Fokus liegt auf der Verbesserung seiner Techniken in der Verteidigung',
    'Das Angriffsspiel ist wichtig, aber wir dürfen die Bedeutung der Verteidigung',
    'Als Anwalt hat sie viele Jahre Erfahrung in der Verteidigung',
    'Die Verteidigung',
    'In der Militärstrategie ist die beste Angriffstaktik oft eine gute Verteidigung',
    'Der Anwalt führte eine starke und überzeugende Verteidigung',
    'Die Regierung hat die Stärkung der Verteidigung',]

# 'auswärtige'
prompts[('rt', 'ige')] = [
    'Meine Tätigkeit erfordert viele Reisen, daher bin ich oft auswärtige',
    'Er ist als diplomatischer Berater für alle auswärtige',
    'Der Minister für auswärtige',
    'Sie ist Expertin für auswärtige',
    'Die Behörde für auswärtige',
    'Es ist wichtig, sich über auswärtige',
    'Wir sollten uns auf die auswärtige',
    'In seiner Rolle überwacht er auswärtige',
    'Der Diplomat hat eine lange Karriere in auswärtige',
    'Die Universität bietet einen Studiengang in auswärtige']

# 'Vertragsbediensteten'
prompts[(' Vert', 'rag')] = [
    'Nach seiner Ausbildung begann er seine Karriere als einer der Vertragsbediensteten',
    'Das Unternehmen hat eine Reihe von Vertragsbediensteten',
    'Die Rechte und Pflichten der Vertragsbediensteten',
    'Die Bezahlung der Vertragsbediensteten',
    'Wegen des hohen Arbeitsaufkommens werden zusätzliche Vertragsbediensteten',
    'Der Status der Vertragsbediensteten',
    'Die Gesundheits- und Sicherheitsvorschriften gelten auch für die Vertragsbediensteten',
    'Alle Vertragsbediensteten müssen eine Verschwiegenheitserklärung',
    'Die Firma plant, das Team der Vertragsbediensteten',
    'Die Schulung neuer Vertragsbediensteten']

# 'hervorheben'
prompts[('he', 'ben')] = [
    'In Ihrem Lebenslauf sollten Sie Ihre besonderen Fähigkeiten und Erfahrungen hervorheben',
    'Die hellen Farben im Bild sollen die Dynamik und Energie der Szene hervorheben',
    'Bei der Präsentation sollten Sie die Hauptpunkte hervorheben',
    'Die Wissenschaftler wollen die Bedeutung ihrer Forschungsergebnisse hervorheben',
    'Die Autorin nutzte Metaphern, um die Emotionen ihrer Charaktere hervorheben',
    'Mit diesem Marketingstrategieplan wollen wir die Einzigartigkeit unseres Produkts hervorheben',
    'Es ist wichtig, in der Debatte die Fakten zu hervorheben',
    'In seinem Vortrag versuchte der Redner, die Relevanz des Themas für das Publikum hervorheben',
    'Beim Design des Hauses wurde besonderer Wert darauf gelegt, die natürlichen Materialien hervorheben',
    'Im Interview konnte sie ihre umfangreichen Kenntnisse und Erfahrungen hervorheben']

# 'tätig'
prompts[('ät', 'ig')] = [
    'Nach seinem Studium war er viele Jahre in der Marketingbranche tätig',
    'Sie ist als Freiwillige in einer gemeinnützigen Organisation tätig',
    'Ich bin seit über zehn Jahren als Lehrer tätig',
    'Er ist hauptsächlich in der Beratung von Start-up-Unternehmen tätig',
    'Als Journalistin war sie vor allem im politischen Bereich tätig',
    'Mein Bruder ist als Softwareentwickler tätig',
    'In ihrer Freizeit ist sie in verschiedenen sozialen Projekten tätig',
    'Nach seinem Ruhestand ist er ehrenamtlich in der Gemeinde tätig',
    'Sie ist als Autorin tätig und hat bereits mehrere Bücher veröffentlicht',
    'Als Anwalt ist er vor allem in den Bereichen Strafrecht und Zivilrecht tätig']

# 'gründliche'
prompts[('nd', 'liche')] = [
    'Bevor wir mit dem Projekt fortfahren, benötigen wir eine gründliche',
    'Der Erfolg der Operation hängt von einer gründlichen',
    'Das Gesetz erfordert eine gründliche',
    'Vor dem Kauf eines Gebrauchtwagens sollte man eine gründliche',
    'Die Ermittlungen in dem Fall erfordern eine gründliche',
    'Die Studie liefert eine gründliche',
    'Das Projektteam hat eine gründliche',
    'Vor dem Abschluss des Geschäfts wird eine gründliche',
    'Die Wartung des Systems erfordert eine gründliche',
    'Die Durchführung einer gründliche']

Next we need to run through each set of prompts. For each prompt we tokenize, then traverse backwards through the model until we find the last token of our bigrams.
We concatenate any tokens after this one. This gives us our final dataset.

Next, we do a forward pass with the German neurons ablated and unablated, and save the cache. There's an existing method to do this.
We select the MLP activations at the second to last position. 
We average the MLP activations.
We compare the average difference in neuron activation at that position, and select the neurons with the largest average difference.

## Setup

In [32]:
from collections import defaultdict

import torch
import numpy as np
from transformer_lens import HookedTransformer

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "colab+vscode"

from haystack_utils import load_txt_data, get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [33]:
haystack_utils.clean_cache()
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m", fold_ln=True, device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer


## Remove any tokens after the second bigram token from the dataset and calculate MLP mean activations

In [34]:
# Remove any tokens after the second bigram token frmo the dataset

processed_prompts = defaultdict(list)

for key, value in prompts.items():
    bigram_tokens = model.to_tokens(key, prepend_bos=False)
    new_prompts = []
    for prompt in value:
        prompt_tokens = model.to_tokens(prompt, prepend_bos=False)[0] # pos
        i = prompt_tokens.shape[0] - 1
        while (prompt_tokens[i] != bigram_tokens[1] or prompt_tokens[i - 1] != bigram_tokens[0]) and (i > 1):
            i -= 1
        prompt_tokens = prompt_tokens[:i + 1]
        prompt_string = model.to_string(prompt_tokens)
        new_prompts.append(prompt_string)

    processed_prompts[key] = new_prompts

prompts = processed_prompts

In [35]:
# Calculate MLP mean activations
haystack_utils.clean_cache()

german_data = load_txt_data("wmt_german_large.txt")[:500]
english_data = load_txt_data("kde4_english.txt")[:500]

german_mean_low_activations = defaultdict(torch.Tensor, {
    3: get_mlp_activations(english_data, 3, model, mean=True),  # [2048]
    4: get_mlp_activations(english_data, 4, model, mean=True),
    5: get_mlp_activations(english_data, 5, model, mean=True)
})
german_mean_high_activations = defaultdict(torch.Tensor, {
    3: get_mlp_activations(german_data, 3, model, mean=True),  # [2048]
    4: get_mlp_activations(german_data, 4, model, mean=True),
    5: get_mlp_activations(german_data, 5, model, mean=True)
})

our_german_neurons = [(3, 669), (5, 1336), (4, 482), (5, 1039), (4, 326)]
german_neurons = [(4, 482), (5, 1039), (5, 407), (5, 1516), (5, 1336), (4, 326), (5, 250), (3, 669)]

our_german_neurons_by_layer = defaultdict(list)
for item in our_german_neurons:
    our_german_neurons_by_layer[item[0]].append(item[1])
german_neurons_by_layer = defaultdict(list)
for item in german_neurons:
    german_neurons_by_layer[item[0]].append(item[1])

wmt_german_large.txt: Loaded 2459 examples with 800 to 2000 characters each.
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

In [61]:
def disable_german_hook(value, hook):
    layer = hook.layer()
    german_neurons_for_layer = our_german_neurons_by_layer[layer]
    value[:, :, german_neurons_for_layer] = german_mean_low_activations[layer][german_neurons_for_layer].cuda()
    return value

def enable_german_hook(value, hook):
    layer = hook.layer()
    german_neurons_for_layer = our_german_neurons_by_layer[layer]
    value[:, :, german_neurons_for_layer] = german_mean_high_activations[layer][german_neurons_for_layer].cuda() * 2.2
    return value

def disable_german_l3_hook(value, hook):
    layer = hook.layer()
    if layer == 3:
        value[:, :, 669] = german_mean_low_activations[3][669].cuda() * 2.2
    return value

def enable_german_l3_hook(value, hook):
    layer = hook.layer()
    if layer == 3:
        value[:, :, 669] = german_mean_high_activations[3][669].cuda() * 2.2
    return value

## Experiment

### Highest token loss difference from ablating German neurons

In [62]:
# This block isn't super relevant to the experiment, more of a sanity check that the second bigram token has a high loss difference. 
# The difference is negative which means the German-enabled loss is lower than the disabled loss, which is what we want.
# The difference is also very high, which is good.

mlp_pattern = lambda name: name.endswith("mlp.hook_post")

for key, value in prompts.items():
    enabled_loss = haystack_utils.get_average_loss([key[0] + key[1]], model, crop_context=-1, fwd_hooks=[(mlp_pattern, enable_german_l3_hook)], positionwise=True)
    disabled_loss = haystack_utils.get_average_loss([key[0] + key[1]], model, crop_context=-1, fwd_hooks=[(mlp_pattern, disable_german_l3_hook)], positionwise=True)
    print([i - j for i, j in zip(enabled_loss[:3], disabled_loss[:3])])

model.reset_hooks()
print(haystack_utils.generate_text("A4-0409/98 von Herrn Roubatis im Namen des Ausschusses für auswärt", model))

[-0.31719112396240234, -3.4623517990112305, 0]
[-0.08849430084228516, -3.9318342208862305, 0]
[-0.882472038269043, -4.8858349323272705, 0]
[-0.9507732391357422, -1.5768184661865234, 0]
[-2.221306800842285, -1.8587722778320312, 0]
[-1.0019140243530273, -5.573101043701172, 0]
A4-0409/98 von Herrn Roubatis im Namen des Ausschusses für auswärtche Familie, die in der Wahlkampf in der Wahlkampf


In [45]:
# When the problematic bigrams are in a sentence both models have high loss and similar loss.
# We could benefit from left padding the prompts here.
for key, value in prompts.items():
    print(key, ":", value[0])
    loss = haystack_utils.get_average_loss(value, model, crop_context=-1, positionwise=True)
    ablated_loss = haystack_utils.get_average_loss(value, model, crop_context=-1, fwd_hooks=[(mlp_pattern, disable_german_hook)], positionwise=True)
    print("ablated alone:", ablated_loss)
    print([i - j for i, j in zip(ablated_loss, loss)])

('id', 'ig') : Ich verbringe viel Zeit mit dem Studium der Theorie und Praxis der Verteidig
ablated alone: [7.766988754272461, 8.227001190185547, 6.402140140533447, 5.804161071777344, 6.938735485076904, 8.42430591583252, 5.7868242263793945, 6.2592597007751465, 5.115482330322266, 5.266501426696777, 4.365352153778076, 3.514495849609375, 4.273314476013184, 4.550295352935791, 3.104933977127075, 3.764209747314453, 4.727024555206299, 4.996598243713379, 3.6796176433563232, 4.27886438369751, 4.008338928222656, 3.0583372116088867, 4.943498611450195, 4.0826826095581055, 5.406358242034912, 4.53716516494751, 5.925273418426514, 2.180640935897827, 0.222773939371109, 0.7666769027709961, 1.1788671016693115, 3.48811674118042, 4.325319290161133, 4.384984016418457, 3.8103933334350586, 2.0329039096832275, 2.2362844944000244, 9.289022445678711, 2.6558756828308105, 2.277245044708252, 5.820948600769043, 0.2030738890171051, 0.003481400664895773, 0.10046202689409256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

### Direct Effect, Indirect Effect

- Run model with and without ablating the German neurons, save both clean and ablated activations
- Run model again without ablation
- Simulate the effect of individual ablated components
- To simulate ablating a component:
    - Before the final layernorm, subtract the cached activation the component from the unablated run
    - Then add the activation of the ablated run
- Then we can compute the effect of running a component with corrupted activations without letting its output affect later components
- However, the cached ablated activations of later components will still be influenced by earlier components
    - [ ] Check if we can improve this by patching in ablated MLP and unablated earlier component residual stream contributions, then running the later component
    - [ ] Check patching library

### Loss Component Breakdown - overview, not core

In [51]:
# Indirect effect of ablating the context neuron
# Logit attribution of later components when ablating the context neuron
# Not sure how clean this is - e.g. layer 5 MLP will get the accumulated effects of all previous layers from ablating the context neuron
component_names = ['embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out']

def disable_german_l3_hook(value, hook):
    value[:, :, 669] = german_mean_low_activations[3][669].cuda()
    return value
disable_german_l3_fwd_hooks = [(f'blocks.{3}.mlp.hook_post', disable_german_l3_hook)]

# def ablate_german_neurons_hook(value, hook):
#     layer = hook.layer()
#     neurons_for_layer = german_neurons_by_layer[layer]
#     value[:, :, neurons_for_layer] = german_mean_low_activations[layer][neurons_for_layer].cuda()
#     return value
# mlp_pattern = lambda name: name.endswith("mlp.hook_post")
# fwd_hooks=[(mlp_pattern, ablate_german_neurons_hook)]

for key, value in prompts.items():
    # Go through components after MLP3, which contains the first German contextual neuron
    components = []
    losses = []
    for later_component in range(8, 13):
        # print(f"Component: {component_names[later_component]}")
        layer = 3
        original_loss, patched_loss = haystack_utils.get_direct_loss_increase_for_component(value, model, disable_german_l3_fwd_hooks, patched_component=later_component, crop_context_end=500, disable_progress_bar=True)
        if len(losses) == 0:
            components.append("Original loss")
            losses.append(original_loss)
        components.append(component_names[later_component])
        losses.append(patched_loss)

    percent_increase = ((np.array(losses) - losses[0]) / losses[0]) * 100

    haystack_utils.line(losses, xlabel="Component", ylabel="Loss", title=f"Loss of patching individual components with German context neuron disabled <br>on prompts with the bigram {key}", xticks=components, width=800, hover_data=percent_increase.tolist())

Original loss: 4.34, patched loss: 4.61 (+6.07%)
Original loss: 4.34, patched loss: 4.36 (+0.51%)
Original loss: 4.34, patched loss: 4.48 (+3.09%)
Original loss: 4.34, patched loss: 4.38 (+0.89%)
Original loss: 4.34, patched loss: 4.65 (+7.20%)


Original loss: 4.97, patched loss: 5.42 (+9.17%)
Original loss: 4.97, patched loss: 4.99 (+0.51%)
Original loss: 4.97, patched loss: 5.17 (+4.15%)
Original loss: 4.97, patched loss: 5.03 (+1.21%)
Original loss: 4.97, patched loss: 5.11 (+2.86%)


Original loss: 7.55, patched loss: 7.61 (+0.86%)
Original loss: 7.55, patched loss: 7.55 (+-0.01%)
Original loss: 7.55, patched loss: 7.62 (+0.98%)
Original loss: 7.55, patched loss: 7.56 (+0.17%)
Original loss: 7.55, patched loss: 7.63 (+1.09%)


Original loss: 4.51, patched loss: 4.76 (+5.57%)
Original loss: 4.51, patched loss: 4.53 (+0.56%)
Original loss: 4.51, patched loss: 4.62 (+2.52%)
Original loss: 4.51, patched loss: 4.58 (+1.53%)
Original loss: 4.51, patched loss: 4.92 (+9.24%)


Original loss: 5.08, patched loss: 5.32 (+4.75%)
Original loss: 5.08, patched loss: 5.09 (+0.29%)
Original loss: 5.08, patched loss: 5.23 (+2.96%)
Original loss: 5.08, patched loss: 5.12 (+0.78%)
Original loss: 5.08, patched loss: 5.31 (+4.63%)


Original loss: 7.16, patched loss: 7.19 (+0.43%)
Original loss: 7.16, patched loss: 7.16 (+-0.01%)
Original loss: 7.16, patched loss: 7.20 (+0.54%)
Original loss: 7.16, patched loss: 7.15 (+-0.10%)
Original loss: 7.16, patched loss: 7.24 (+1.10%)


## Get individual neuron activation differences for first token in bigram

In [59]:
from typing import List, Tuple

# Modified from downstream components notebook to process a single position in each prompt
def get_ablated_mlp_difference_at_pos(
        prompts: List[str], 
        model: HookedTransformer, 
        fwd_hooks: List[Tuple],
        layer_to_cache: int, 
        position=-1,
        print_mean_loss=False
):
    """Difference with ablation for one component, not doing any fancy direct/indirect effect logic"""
    block_name = f'blocks.{layer_to_cache}.mlp.hook_post'
    original_losses = []
    ablated_losses = []
    mean_differences = []
    for prompt in prompts:
        original_loss, ablated_loss, original_cache, ablated_cache = haystack_utils.get_caches_single_prompt(
            prompt, model, fwd_hooks=disable_german_l3_fwd_hooks, return_type="loss")

        original_activations = original_cache[block_name][:, position] # batch pos d_mlp
        ablated_activations = ablated_cache[block_name][:, position]
        mean_difference = original_activations.mean((0)) - ablated_activations.mean((0))
        mean_differences.append(mean_difference)
        if print_mean_loss:
            original_losses.append(original_loss)
            ablated_losses.append(ablated_loss)
        
    if print_mean_loss:
        print(f"Original loss: {np.mean(original_losses):.2f}, ablated loss: {np.mean(ablated_losses):.2f} (+{((np.mean(ablated_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return torch.stack(mean_differences).mean(0)

In [60]:
def ablate_german_neurons_hook(value, hook):
    if hook.layer() == 3:
        value[:, :, 669] = german_mean_low_activations[layer][669].cuda()
    return value
mlp_pattern = lambda name: name.endswith("mlp.hook_post")
fwd_hooks=[(mlp_pattern, ablate_german_neurons_hook)]

for key, value in prompts.items():
    for layer_to_cache in [4, 5]:
        difference = get_ablated_mlp_difference_at_pos(value, model, fwd_hooks=fwd_hooks, layer_to_cache=layer_to_cache, position=-2)
        sorted_differences, sorted_neurons = torch.topk(difference.abs(), len(difference), largest=True)
        haystack_utils.line(sorted_differences.cpu().numpy()[:100], xlabel="Neuron", ylabel="Absolute difference", xticks=sorted_neurons.cpu().tolist()[:100], title=f"Top absolute neuron differences in layer {layer_to_cache} for first token of bigram {key}", width=1400)

Some of these are just the context neurons we already know about.

Some like  (5, 297) are unknown tokens with massive activation differences.

### Repeat with logit diff directions

- Find a plausible English completion for each first token and use a logit difference metric similar to IOI instead of the absolute logit for our answer token. 
Then we have a metric that is decomposable into individual component contributions like the logits, but isn't sensitive to scaling of the entire logit vector.
  - Just run the model with the first token and take the second one as the English/non-German bigram?
  - logit diff = x(token 2)W_U - x(english alternative token)W_U

The main problem with this is, there's no plausible reason why

In [41]:
# # Find plausible English completions for our bigrams


# def get_plausible_completions():
#     for key in prompts.keys():
        

## Zoom in on 5,297 with ät ig

In [68]:
import einops
from transformer_lens import ActivationCache, utils
from jaxtyping import Int, Float
from torch import Tensor

# downstream components notebook
def get_loss_patched_mlp_neurons(prompts: list[str], model: HookedTransformer, fwd_hooks: List[Tuple], patch_neurons, patch_layer=5, 
                                 crop_context_end: None | int=None):
    """Print increase in patched loss from original loss as a percentage of original loss"""
    original_losses = []
    patched_losses = []
    for prompt in prompts:  
        original_loss, _, _, ablated_cache = haystack_utils.get_caches_single_prompt(prompt, model, fwd_hooks, crop_context_end=crop_context_end)
        
        if crop_context_end is not None:
            tokens = model.to_tokens(prompt)[:, :crop_context_end]
        else:
            tokens = model.to_tokens(prompt)
            
        def patch_hook(value, hook):
            # Batch, pos, d_mlp
            value[:, :, patch_neurons] = ablated_cache[f'blocks.{patch_layer}.mlp.hook_post'][:, :, patch_neurons]
        
        with model.hooks(fwd_hooks=[(f'blocks.{patch_layer}.mlp.hook_post', patch_hook)]):
            patched_loss = model(tokens, return_type="loss")
        
        original_losses.append(original_loss)
        patched_losses.append(patched_loss.item())

    print(f"Original loss: {np.mean(original_losses):.2f}, patched loss: {np.mean(patched_losses):.2f} (+{((np.mean(patched_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return np.mean(original_losses), np.mean(patched_losses)

def get_neuron_logit_contribution(cache: ActivationCache, model: HookedTransformer, answer_tokens: Int[Tensor, "batch pos"], layer: int) -> Float[Tensor, "neuron pos"]:
    # Expects cache from a single example, won't work on batched examples
    # Get per neuron output of MLP layer
    neuron_directions = cache.get_neuron_results(layer, neuron_slice=utils.Slice(input_slice=None), pos_slice=utils.Slice(input_slice=None))
    neuron_directions = einops.rearrange(neuron_directions, 'batch pos neuron residual -> neuron batch pos residual')
    # We need to apply the final layer norm because the unembed operation is applied after the final layer norm, so the answer token
    # directions are in the same space as the final layer norm output
    # LN leads to finding top tokens with slightly higher loss attribution
    scaled_neuron_directions = cache.apply_ln_to_stack(neuron_directions)[:, 0, :-1, :] # [neuron pos embed]
    # Unembed of correct answer tokens
    correct_token_directions = model.W_U[:, answer_tokens].squeeze(1) # [embed pos] # [d_model answer_tokens]
    # Neuron attribution to correct answer token by position
    unembedded = einops.einsum(scaled_neuron_directions, correct_token_directions, 'neuron pos residual, residual pos -> neuron pos') # [neuron pos]
    return unembedded

# The start of sorted_differences contains neurons that increase the logits of the answer tokens the most,
# and the end contains neurons that decrease the logits the most.

def MLP_pos_attribution(prompts: list[str], model: HookedTransformer, fwd_hooks, layer_to_compare=5, crop_context_end: None | int=None, position=-1):
    """ 
    Gets the difference between how aligned neuron outputs are with the correct token in the original and ablated models.
    A positive difference here means the direct effect of an ablated cache neuron is lower on the correct token.
    Works for a single layer.
    
    The differences are averaged over each position in each prompt, and then over each prompt but it works as an approximation and can be used for comparisons.
    """
    differences = torch.zeros(model.cfg.d_mlp)
    for prompt in prompts:
        # Get answer tokens
        if crop_context_end is not None:
            tokens = model.to_tokens(prompt)[:, :crop_context_end]
        else:
            tokens = model.to_tokens(prompt)
        answer_tokens = tokens[:, 1:]

        # Get difference between ablated and unablated neurons' contribution to answer logit
        _, _, original_cache, ablated_cache = haystack_utils.get_caches_single_prompt(
            prompt, model, fwd_hooks, crop_context_end=crop_context_end)
        
        original_unembedded = get_neuron_logit_contribution(original_cache, model, answer_tokens, layer=layer_to_compare) # [neuron pos]
        ablated_unembedded = get_neuron_logit_contribution(ablated_cache, model, answer_tokens, layer=layer_to_compare)
        differences += (original_unembedded - ablated_unembedded)[:, position].detach().cpu() # [neuron]
    
    mean_difference = differences / len(prompts)
    print(mean_difference.shape)
    print("Mean activation difference on correct token summed over all neurons:", mean_difference.sum().item())
    sorted_differences, sorted_neurons = torch.topk(mean_difference, len(mean_difference), largest=True)
    return sorted_differences, sorted_neurons

In [71]:
for key, value in prompts.items():
    for layer_to_cache in [4, 5]:
        sorted_differences, sorted_neurons = MLP_pos_attribution(value, model, fwd_hooks=fwd_hooks, layer_to_compare=layer_to_cache, crop_context_end=500)
        haystack_utils.line(sorted_differences.cpu().numpy()[:100], xlabel="Neuron", ylabel="Absolute difference", xticks=sorted_neurons.cpu().tolist()[:100], title=f"Top absolute neuron differences in contribution to answer logit in layer {layer_to_cache} for first token of bigram {key}", width=1400)

torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.2734977602958679


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.83008873462677


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.460050493478775


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.623856782913208


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.1051379069685936


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.09195093810558319


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.33168449997901917


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.7938476204872131


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.34148818254470825


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.4814188480377197


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.07425827533006668


torch.Size([2048])
Mean activation difference on correct token summed over all neurons: 0.08535687625408173


In [75]:
def disable_5_297(value, hook):
    layer = hook.layer()
    if layer == 5:
        value[:, :, 297] = german_mean_low_activations[5][297].cuda() * 2.2
    return value

mlp_pattern = lambda name: name.endswith("mlp.hook_post")

# for key, value in prompts.items():
print(model("Nach seinem Studium war er viele Jahre in der Marketingbranche tätig", return_type="loss", loss_per_token=True)[-3:])

with model.hooks(fwd_hooks=[(mlp_pattern, disable_5_297)]):
    print(model("Nach seinem Studium war er viele Jahre in der Marketingbranche tätig", return_type="loss", loss_per_token=True)[-3:])
# enabled_loss = haystack_utils.get_average_loss(["ätig"], model, positionwise=True)
# disabled_loss = haystack_utils.get_average_loss(["ätig"], model, fwd_hooks=[(mlp_pattern, disable_5_297)], positionwise=True)
# print(enabled_loss[:3], disabled_loss[:3])
# print([i - j for i, j in zip(enabled_loss[:3], disabled_loss[:3])])

model.reset_hooks()
print(haystack_utils.generate_text("A4-0409/98 von Herrn Roubatis im Namen des Ausschusses für auswärt", model))

tensor([[ 5.6681,  5.9985, 14.7690,  0.6080,  8.6174,  3.4352,  7.3726,  7.2873,
         10.9054,  2.5063,  6.1046,  0.1818,  3.1512,  2.6386, 12.2585,  9.3313,
          1.9656,  6.5056,  6.3966,  6.1895]], device='cuda:0')
tensor([[ 5.7913,  6.1337, 15.1876,  1.0885,  9.4587,  1.5320,  7.0472,  7.6913,
         11.5994,  3.2800,  6.6716,  0.1313,  3.2489,  3.3172, 12.8690,  8.8057,
          4.3142,  6.1308,  5.7763,  6.6426]], device='cuda:0')
A4-0409/98 von Herrn Roubatis im Namen des Ausschusses für auswärtche Familie, die in der Wahlkampf in der Wahlkampf
