When the German neurons are ablated, the model makes predictions on certain tokens that result in extremely high losses. Investigating these tokens, and the correct "next token" prediction, shows that many form German bigrams. The most prevalent are:

20 ('id', 'ig')
7 ('rt', 'ige')
5 (' Vert', 'rag')
5 ('he', 'ben')
4 ('ät', 'ig')
4 ('nd', 'liche')

We should run these bigrams with and without the German neurons ablated, then investigate which neurons have the most different activations at the position of the first bigram token. We'll do this using the raw logit difference.

Q: should we use single token prompts with the first token in each bigrams, or should we create prompts that end in the bigram?
A: we will create prompts the end in the bigram because models behave inconsistently for the first few tokens in each prompt, and we don't want this behaviour affecting our results.

Q: how will we create the prompts?
A: look at the word the token was used in, use GPT to generate prompts that end in this word.

20 ('id', 'ig') - Verteidigung, Verteidiger, Rechtsverteidigung
7 ('rt', 'ige') - auswärtige, sofortige, neuartigen
5 (' Vert', 'rag') - Vertragsbediensteten, vertraglichen
5 ('he', 'ben') - hervorheben, entheben
4 ('ät', 'ig') - tätig, Tätigkeit, Berufstätigen, bestätigt, gewalttätigen
4 ('nd', 'liche') - gründliche, selbstverständlichen, unmißverständlicher, ländlichen, unmissverständlichere

## Generate data with GPT4

In [68]:
prompts = {}
words = ['Verteidigung', 'auswärtige', 'Vertragsbediensteten', 'hervorheben', 'tätig', 'gründliche']

# 'Verteidigung'
prompts[('id', 'ig')] = [
    'Ich verbringe viel Zeit mit dem Studium der Theorie und Praxis der Verteidigung',
    'In seiner Rede betonte der Minister die Notwendigkeit einer starken nationalen Verteidigung'
    'Es ist wichtig, dass wir ein geeignetes Budget für die Verteidigung',
    'Sein Fokus liegt auf der Verbesserung seiner Techniken in der Verteidigung',
    'Das Angriffsspiel ist wichtig, aber wir dürfen die Bedeutung der Verteidigung',
    'Als Anwalt hat sie viele Jahre Erfahrung in der Verteidigung',
    'Die Verteidigung',
    'In der Militärstrategie ist die beste Angriffstaktik oft eine gute Verteidigung',
    'Der Anwalt führte eine starke und überzeugende Verteidigung',
    'Die Regierung hat die Stärkung der Verteidigung',]

# 'auswärtige'
prompts[('rt', 'ige')] = [
    'Meine Tätigkeit erfordert viele Reisen, daher bin ich oft auswärtige',
    'Er ist als diplomatischer Berater für alle auswärtige',
    'Der Minister für auswärtige',
    'Sie ist Expertin für auswärtige',
    'Die Behörde für auswärtige',
    'Es ist wichtig, sich über auswärtige',
    'Wir sollten uns auf die auswärtige',
    'In seiner Rolle überwacht er auswärtige',
    'Der Diplomat hat eine lange Karriere in auswärtige',
    'Die Universität bietet einen Studiengang in auswärtige']

# 'Vertragsbediensteten'
prompts[(' Vert', 'rag')] = [
    'Nach seiner Ausbildung begann er seine Karriere als einer der Vertragsbediensteten',
    'Das Unternehmen hat eine Reihe von Vertragsbediensteten',
    'Die Rechte und Pflichten der Vertragsbediensteten',
    'Die Bezahlung der Vertragsbediensteten',
    'Wegen des hohen Arbeitsaufkommens werden zusätzliche Vertragsbediensteten',
    'Der Status der Vertragsbediensteten',
    'Die Gesundheits- und Sicherheitsvorschriften gelten auch für die Vertragsbediensteten',
    'Alle Vertragsbediensteten müssen eine Verschwiegenheitserklärung',
    'Die Firma plant, das Team der Vertragsbediensteten',
    'Die Schulung neuer Vertragsbediensteten']

# 'hervorheben'
prompts[('he', 'ben')] = [
    'In Ihrem Lebenslauf sollten Sie Ihre besonderen Fähigkeiten und Erfahrungen hervorheben',
    'Die hellen Farben im Bild sollen die Dynamik und Energie der Szene hervorheben',
    'Bei der Präsentation sollten Sie die Hauptpunkte hervorheben',
    'Die Wissenschaftler wollen die Bedeutung ihrer Forschungsergebnisse hervorheben',
    'Die Autorin nutzte Metaphern, um die Emotionen ihrer Charaktere hervorheben',
    'Mit diesem Marketingstrategieplan wollen wir die Einzigartigkeit unseres Produkts hervorheben',
    'Es ist wichtig, in der Debatte die Fakten zu hervorheben',
    'In seinem Vortrag versuchte der Redner, die Relevanz des Themas für das Publikum hervorheben',
    'Beim Design des Hauses wurde besonderer Wert darauf gelegt, die natürlichen Materialien hervorheben',
    'Im Interview konnte sie ihre umfangreichen Kenntnisse und Erfahrungen hervorheben']

# 'tätig'
prompts[('ät', 'ig')] = [
    'Nach seinem Studium war er viele Jahre in der Marketingbranche tätig',
    'Sie ist als Freiwillige in einer gemeinnützigen Organisation tätig',
    'Ich bin seit über zehn Jahren als Lehrer tätig',
    'Er ist hauptsächlich in der Beratung von Start-up-Unternehmen tätig',
    'Als Journalistin war sie vor allem im politischen Bereich tätig',
    'Mein Bruder ist als Softwareentwickler tätig',
    'In ihrer Freizeit ist sie in verschiedenen sozialen Projekten tätig',
    'Nach seinem Ruhestand ist er ehrenamtlich in der Gemeinde tätig',
    'Sie ist als Autorin tätig und hat bereits mehrere Bücher veröffentlicht',
    'Als Anwalt ist er vor allem in den Bereichen Strafrecht und Zivilrecht tätig']

# 'gründliche'
prompts[('nd', 'liche')] = [
    'Bevor wir mit dem Projekt fortfahren, benötigen wir eine gründliche',
    'Der Erfolg der Operation hängt von einer gründlichen',
    'Das Gesetz erfordert eine gründliche',
    'Vor dem Kauf eines Gebrauchtwagens sollte man eine gründliche',
    'Die Ermittlungen in dem Fall erfordern eine gründliche',
    'Die Studie liefert eine gründliche',
    'Das Projektteam hat eine gründliche',
    'Vor dem Abschluss des Geschäfts wird eine gründliche',
    'Die Wartung des Systems erfordert eine gründliche',
    'Die Durchführung einer gründliche']

Next we need to run through each set of prompts. For each prompt we tokenize, then traverse backwards through the model until we find the last token of our bigrams.
We concatenate any tokens after this one. This gives us our final dataset.

Next, we do a forward pass with the German neurons ablated and unablated, and save the cache. There's an existing method to do this.
We select the MLP activations at the second to last position. 
We average the MLP activations.
We compare the average difference in neuron activation at that position, and select the neurons with the largest average difference.

## Setup

In [48]:
import torch
import numpy as np
from transformer_lens import HookedTransformer
import plotly.express as px

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"

from haystack_utils import load_txt_data, get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)
model = HookedTransformer.from_pretrained("pythia-70m-v0", fold_ln=True, device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m-v0 into HookedTransformer


## Process dataset and calculate MLP mean activations

In [80]:
# Remove any tokens after the second bigram token

from collections import defaultdict

processed_prompts = defaultdict(list)

for key, value in prompts.items():
    bigram_tokens = model.to_tokens(key, prepend_bos=False)
    new_prompts = []
    for prompt in value:
        prompt_tokens = model.to_tokens(prompt, prepend_bos=False)[0] # pos
        i = prompt_tokens.shape[0] - 1
        while (prompt_tokens[i] != bigram_tokens[1] or prompt_tokens[i - 1] != bigram_tokens[0]) and (i > 1):
            i -= 1
        prompt_tokens = prompt_tokens[:i + 1]
        prompt_string = model.to_string(prompt_tokens)
        new_prompts.append(prompt_string)

    processed_prompts[key] = new_prompts

prompts = processed_prompts

In [81]:
haystack_utils.clean_cache()

from collections import defaultdict

german_data = load_txt_data("wmt_german_large.txt")[:500]
english_data = load_txt_data("kde4_english.txt")[:500]

german_mean_low_activations = defaultdict(torch.Tensor, {
    3: get_mlp_activations(english_data, 3, model, mean=True),  # [2048]
    4: get_mlp_activations(english_data, 4, model, mean=True),
    5: get_mlp_activations(english_data, 5, model, mean=True)
})
german_mean_high_activations = defaultdict(torch.Tensor, {
    3: get_mlp_activations(german_data, 3, model, mean=True),  # [2048]
    4: get_mlp_activations(german_data, 4, model, mean=True),
    5: get_mlp_activations(german_data, 5, model, mean=True)
})

our_german_neurons = [(3, 669), (5, 1336), (4, 482), (5, 1039), (4, 326)]
german_neurons = [(4, 482), (5, 1039), (5, 407), (5, 1516), (5, 1336), (4, 326), (5, 250), (3, 669)]

our_german_neurons_by_layer = defaultdict(list)
for item in our_german_neurons:
    our_german_neurons_by_layer[item[0]].append(item[1])
german_neurons_by_layer = defaultdict(list)
for item in german_neurons:
    german_neurons_by_layer[item[0]].append(item[1])

wmt_german_large.txt: Loaded 2459 examples with 800 to 2000 characters each.
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

## Experiment

In [32]:
def ablate_german_hook(value, hook):
    layer = hook.layer()
    german_neurons_for_layer = our_german_neurons_by_layer[layer]
    value[:, :, german_neurons_for_layer] = german_mean_low_activations[layer][german_neurons_for_layer].cuda() * 2.2
    return value

def ablate_german_hook_more_neurons(value, hook):
    layer = hook.layer()
    german_neurons_for_layer = german_neurons_by_layer[layer]
    value[:, :, german_neurons_for_layer] = german_mean_low_activations[layer][german_neurons_for_layer].cuda() * 2.2
    return value

mlp_pattern = lambda name: name.endswith("mlp.hook_post")

for key, value in prompts.items():
    print(key, ":", value[0])

    loss = haystack_utils.get_average_loss(value, model, batch_size=1, crop_context=-1, positionwise=True)
    ablated_loss = haystack_utils.get_average_loss(value, model, batch_size=1, crop_context=-1, fwd_hooks=[(mlp_pattern, ablate_german_hook)], positionwise=True)
    print(max([i - j for i, j in zip(ablated_loss, loss)]))

Verteidigung : Ich verbringe viel Zeit mit dem Studium der Theorie und Praxis der Verteidigung


0.1740102767944336
auswärtige : Meine Tätigkeit erfordert viele Reisen, daher bin ich oft auswärtige
0.16482985019683838
Vertragsbediensteten : Nach seiner Ausbildung begann er seine Karriere als einer der Vertragsbediensteten
0.17531204223632812
hervorheben : In Ihrem Lebenslauf sollten Sie Ihre besonderen Fähigkeiten und Erfahrungen hervorheben
0.1657717227935791
tätig : Nach seinem Studium war er viele Jahre in der Marketingbranche tätig
0.07474422454833984
gründliche : Bevor wir mit dem Projekt fortfahren, benötigen wir eine gründliche
0.09512472152709961


### Direct Effect, Indirect Effect

- Run model with and without ablating the German neurons, save both clean and ablated activations
- Run model again without ablation
- Simulate the effect of individual ablated components
- To simulate ablating a component:
    - Before the final layernorm, subtract the cached activation the component from the unablated run
    - Then add the activation of the ablated run
- Then we can compute the effect of running a component with corrupted activations without letting its output affect later components
- However, the cached ablated activations of later components will still be influenced by earlier components
    - [ ] Check if we can improve this by patching in ablated MLP and unablated earlier component residual stream contributions, then running the later component
    - [ ] Check patching library

In [87]:
def multi_neuron_position_DLA(
    prompts: list[str],
    model: HookedTransformer,
    mean_neuron_activations_by_layer,
    neurons = [(3, 609)],
    patched_component=8,
    position=-1,
    crop_context: None | tuple[int, int]=None,
):
    """
    Get the indirect effect of the German neurons being ablated via a single component. 
    Should be 0 before the component containing the earliest German neuron.

    Takes a list of neuron tuples like [layer, neuron_index]. Uses loss at specified token position

    How: for each prompt, get the original and ablated caches and decompose the residual stream into 
    its components. Then do a forward pass, hook into the residual stream at a single component, and 
    swap its original output activations for the ablated output activations.
    """
    neurons_by_layer = defaultdict(list)
    for item in neurons:
        neurons_by_layer[item[0]].append(item[1])

    mlp_pattern = lambda name: name.endswith("mlp.hook_post")
    neurons = torch.LongTensor(neurons)
    def ablate_neuron_hook(value, hook):
        layer = hook.layer()
        neurons_for_layer = neurons_by_layer[layer]
        value[:, :, neurons_for_layer] =  mean_neuron_activations_by_layer[layer][neurons_for_layer].cuda()
        return value

    original_losses = []
    patched_losses = []
    for prompt in prompts:
        if crop_context is not None:
            tokens = model.to_tokens(prompt)[:, crop_context[0]:crop_context[1]].cuda()
        else:
            tokens = model.to_tokens(prompt).cuda()

        original_loss, original_cache = model.run_with_cache(tokens, return_type="loss", loss_per_token=True)
        with model.hooks(fwd_hooks=[(mlp_pattern, ablate_neuron_hook)]):
            ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")

        # component, batch, pos, residual
        original_per_layer_residual, original_labels = original_cache.decompose_resid(layer=-1, return_labels=True, apply_ln=False)
        ablated_per_layer_residual, ablated_labels = ablated_cache.decompose_resid(layer=-1, return_labels=True, apply_ln=False)

        # ['embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out']
        def swap_cache_hook(value, hook):
            # Batch, pos, residual
            value -= original_per_layer_residual[patched_component]
            value += ablated_per_layer_residual[patched_component]
        
        with model.hooks(fwd_hooks=[(f'blocks.5.hook_resid_post', swap_cache_hook)]):
            patched_loss = model(tokens, return_type="loss", loss_per_token=True)

        original_losses.append(original_loss[0, position].item())
        patched_losses.append(patched_loss[0, position].item())


    print(f"Original loss: {np.mean(original_losses):.2f}, patched loss: {np.mean(patched_losses):.2f} (+{((np.mean(patched_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return np.mean(original_losses), np.mean(patched_losses)

In [85]:
def multi_neuron_DLA(
    prompts: list[str],
    model: HookedTransformer,
    mean_neuron_activations_by_layer,
    neurons = [(3, 609)],
    patched_component=8,
    crop_context: None | tuple[int, int]=None,
):
    """
    Get the indirect effect of the German neurons being ablated via a single component. 
    Should be 0 before the component containing the earliest German neuron.

    Takes a list of neuron tuples like [layer, neuron_index]. Uses average loss.

    How: for each prompt, get the original and ablated caches and decompose the residual stream into 
    its components. Then do a forward pass, hook into the residual stream at a single component, and 
    swap its original output activations for the ablated output activations.
    """
    neurons_by_layer = defaultdict(list)
    for item in neurons:
        neurons_by_layer[item[0]].append(item[1])

    mlp_pattern = lambda name: name.endswith("mlp.hook_post")
    neurons = torch.LongTensor(neurons)
    def ablate_neuron_hook(value, hook):
        layer = hook.layer()
        neurons_for_layer = neurons_by_layer[layer]
        value[:, :, neurons_for_layer] =  mean_neuron_activations_by_layer[layer][neurons_for_layer].cuda()
        return value

    original_losses = []
    patched_losses = []
    for prompt in prompts:
        if crop_context is not None:
            tokens = model.to_tokens(prompt)[:, crop_context[0]:crop_context[1]].cuda()
        else:
            tokens = model.to_tokens(prompt).cuda()

        original_loss, original_cache = model.run_with_cache(tokens, return_type="loss")
        with model.hooks(fwd_hooks=[(mlp_pattern, ablate_neuron_hook)]):
            ablated_loss, ablated_cache = model.run_with_cache(tokens, return_type="loss")

        # component, batch, pos, residual
        original_per_layer_residual, original_labels = original_cache.decompose_resid(layer=-1, return_labels=True, apply_ln=False)
        ablated_per_layer_residual, ablated_labels = ablated_cache.decompose_resid(layer=-1, return_labels=True, apply_ln=False)

        # ['embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out']
        def swap_cache_hook(value, hook):
            # Batch, pos, residual
            value -= original_per_layer_residual[patched_component]
            value += ablated_per_layer_residual[patched_component]
        
        with model.hooks(fwd_hooks=[(f'blocks.5.hook_resid_post', swap_cache_hook)]):
            patched_loss = model(tokens, return_type="loss")

        original_losses.append(original_loss.item())
        patched_losses.append(patched_loss.item())


    print(f"Original loss: {np.mean(original_losses):.2f}, patched loss: {np.mean(patched_losses):.2f} (+{((np.mean(patched_losses) - np.mean(original_losses)) / np.mean(original_losses))*100:.2f}%)")
    return np.mean(original_losses), np.mean(patched_losses)

def line(x, xlabel="", ylabel="", title="", xticks=None, width=800, hover_data=None):
    fig = px.line(x, title=title)
    fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel, width=width)
    if xticks != None:
        fig.update_layout(
            xaxis = dict(
            tickmode = 'array',
            tickvals = [i for i in range(len(xticks))],
            ticktext = xticks
            )
        )
    if hover_data != None:
        fig.update(data=[{'customdata': hover_data, 'hovertemplate': "Loss: %{y:.4f} (+%{customdata:.2f}%)"}])
    fig.show()


In [88]:
# Indirect effect of ablating the context neurons
# Logit attribution of later components when ablating the context neuron
# Not sure how clean this is - e.g. layer 5 MLP will get the accumulated effects of all previous layers from ablating the context neuron
component_names = ['embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out', '2_attn_out', '2_mlp_out', '3_attn_out', '3_mlp_out', '4_attn_out', '4_mlp_out', '5_attn_out', '5_mlp_out']

for key, value in prompts.items():
    # Go through components after MLP3, which contains the first German contextual neuron
    components = []
    losses = []
    for later_component in range(8, 13):
        print(f"Component: {component_names[later_component]}")
        layer = 3
        original_loss, patched_loss = multi_neuron_position_DLA(value, model, german_mean_low_activations, neurons=german_neurons, patched_component=later_component, crop_context=(0, 500))
        if len(losses) == 0:
            components.append("Original loss")
            losses.append(original_loss)
        components.append(component_names[later_component])
        losses.append(patched_loss)

    percent_increase = ((np.array(losses) - losses[0]) / losses[0]) * 100

    line(losses, xlabel="Component", ylabel="Loss", title=f"Loss of individual patching individual components when ablating German neurons <br> on prompts with the bigram {key}", xticks=components, width=800, hover_data=percent_increase.tolist())

Component: 3_mlp_out


Original loss: 0.23, patched loss: 0.23 (+-0.03%)
Component: 4_attn_out
Original loss: 0.23, patched loss: 0.22 (+-2.05%)
Component: 4_mlp_out
Original loss: 0.23, patched loss: 0.22 (+-1.74%)
Component: 5_attn_out
Original loss: 0.23, patched loss: 0.23 (+0.96%)
Component: 5_mlp_out
Original loss: 0.23, patched loss: 0.26 (+15.12%)


Component: 3_mlp_out
Original loss: 1.08, patched loss: 1.09 (+1.04%)
Component: 4_attn_out
Original loss: 1.08, patched loss: 1.08 (+0.41%)
Component: 4_mlp_out
Original loss: 1.08, patched loss: 1.06 (+-1.47%)
Component: 5_attn_out
Original loss: 1.08, patched loss: 1.09 (+1.03%)
Component: 5_mlp_out
Original loss: 1.08, patched loss: 1.09 (+1.30%)


Component: 3_mlp_out
Original loss: 6.87, patched loss: 6.87 (+0.00%)
Component: 4_attn_out
Original loss: 6.87, patched loss: 6.88 (+0.04%)
Component: 4_mlp_out
Original loss: 6.87, patched loss: 6.97 (+1.46%)
Component: 5_attn_out
Original loss: 6.87, patched loss: 6.91 (+0.50%)
Component: 5_mlp_out
Original loss: 6.87, patched loss: 6.83 (+-0.69%)


Component: 3_mlp_out
Original loss: 0.30, patched loss: 0.30 (+0.07%)
Component: 4_attn_out
Original loss: 0.30, patched loss: 0.30 (+-0.15%)
Component: 4_mlp_out
Original loss: 0.30, patched loss: 0.29 (+-2.18%)
Component: 5_attn_out
Original loss: 0.30, patched loss: 0.29 (+-2.07%)
Component: 5_mlp_out
Original loss: 0.30, patched loss: 0.32 (+6.68%)


Component: 3_mlp_out
Original loss: 0.59, patched loss: 0.59 (+-0.23%)
Component: 4_attn_out
Original loss: 0.59, patched loss: 0.59 (+-1.25%)
Component: 4_mlp_out
Original loss: 0.59, patched loss: 0.60 (+0.53%)
Component: 5_attn_out
Original loss: 0.59, patched loss: 0.59 (+-1.06%)
Component: 5_mlp_out
Original loss: 0.59, patched loss: 0.61 (+2.61%)


Component: 3_mlp_out
Original loss: 5.85, patched loss: 5.85 (+-0.05%)
Component: 4_attn_out
Original loss: 5.85, patched loss: 5.86 (+0.06%)
Component: 4_mlp_out
Original loss: 5.85, patched loss: 5.95 (+1.59%)
Component: 5_attn_out
Original loss: 5.85, patched loss: 5.88 (+0.47%)
Component: 5_mlp_out
Original loss: 5.85, patched loss: 5.81 (+-0.68%)
