### Setup

In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# Plotly renderer for Jupyter (multiple renderer names are combined with '+').
pio.renderers.default = "notebook_connected+notebook"

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Inference-only notebook: switch autograd off globally. One call is enough —
# torch.set_grad_enabled and torch.autograd.set_grad_enabled are the same API.
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# NOTE(review): presumably clears cached (GPU) memory from earlier runs — see haystack_utils.
haystack_utils.clean_cache()

# Pythia-160m with TransformerLens weight preprocessing (fold_ln + centering flags).
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-160m",
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)

# Only the first 200 entries per language are used for the activation statistics below.
german_data, english_data = (
    haystack_utils.load_json_data(path)[:200]
    for path in ("data/german_europarl.json", "data/english_europarl.json")
)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


In [3]:
# print(model.to_tokens("swim"), model.to_tokens("swam"))

In [4]:
# (layer, neuron, reported F1) for each German-text neuron under study.
german_neurons_with_f1 = [
    [5, 2649, 1.0],
    [8, 2994, 1.0],
    [11, 2911, 0.99],
    [10, 1129, 0.97],
    [6, 1838, 0.65],
    [7, 1594, 0.65],
    [11, 1819, 0.61],
    [11, 2014, 0.56],
    [10, 753, 0.54],
    [11, 205, 0.48],
]

# layer -> neuron indices (in list order). Every listed neuron is kept; no F1 cut-off.
important_german_neurons = defaultdict(list)
for layer, neuron, f1 in german_neurons_with_f1:
    important_german_neurons[layer] += [neuron]

# Un-averaged (mean=False) MLP activations on English and German text, one entry
# per layer that holds at least one of the neurons above.
german_neuron_layers = {layer for layer, _, _ in german_neurons_with_f1}
english_activations = {}
german_activations = {}
for layer in german_neuron_layers:
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [5]:
# layer -> [(neuron, mean activation)] for every German neuron: "active" means are taken
# over German text, "inactive" means over English text. The patching hooks below read these.
mean_context_neuron_acts_active = defaultdict(list)
mean_context_neuron_acts_inactive = defaultdict(list)
for layer, neuron, _ in german_neurons_with_f1:
    german_mean = german_activations[layer][:, neuron].mean(0)
    english_mean = english_activations[layer][:, neuron].mean(0)
    mean_context_neuron_acts_active[layer].append((neuron, german_mean))
    mean_context_neuron_acts_inactive[layer].append((neuron, english_mean))

def get_deactivate_neurons_hook(layer):
    """Hook factory: pin this layer's German neurons to their mean English activation.

    The (neuron, mean) pairs are read from ``mean_context_neuron_acts_inactive[layer]``
    each time the hook runs.

    Returns:
        ``deactivate_neurons_hook(value, hook)``, which writes the means into
        ``value[:, :, neurons]`` in place (broadcast over the first two axes) and
        returns ``value``.
    """
    def deactivate_neurons_hook(value, hook):
        neurons, acts = zip(*mean_context_neuron_acts_inactive[layer])
        # Stack the means and match the activation's device/dtype instead of
        # hard-coding .cuda() (which fails on CPU-only machines).
        patch = torch.stack([torch.as_tensor(act) for act in acts]).to(device=value.device, dtype=value.dtype)
        value[:, :, list(neurons)] = patch
        return value
    return deactivate_neurons_hook
deactivate_neurons_fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', get_deactivate_neurons_hook(layer)) for layer in important_german_neurons.keys()]

def get_activate_neurons_hook(layer):
    """Hook factory: pin this layer's German neurons to their mean *German* activation.

    The (neuron, mean) pairs are read from ``mean_context_neuron_acts_active[layer]``
    each time the hook runs. (Previously this read the *inactive* English means, which
    made it identical to the deactivation hook.)

    Returns:
        ``activate_neurons_hook(value, hook)``, which writes the means into
        ``value[:, :, neurons]`` in place (broadcast over the first two axes) and
        returns ``value``.
    """
    def activate_neurons_hook(value, hook):
        neurons, acts = zip(*mean_context_neuron_acts_active[layer])
        # Match the activation's device/dtype instead of hard-coding .cuda().
        patch = torch.stack([torch.as_tensor(act) for act in acts]).to(device=value.device, dtype=value.dtype)
        value[:, :, list(neurons)] = patch
        return value
    return activate_neurons_hook
activate_neurons_fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', get_activate_neurons_hook(layer)) for layer in important_german_neurons.keys()]

# Token sets from haystack_utils.get_weird_tokens (see that helper for how tokens are
# classified). `all_ignore` is passed to haystack_utils.get_common_tokens further down;
# `not_ignore` is not used in the cells shown here.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

### Check classification accuracy of German neurons

In [6]:
def run_single_neuron_lr(layer, neuron, num_samples=5000, german_activations=german_activations, english_activations=english_activations, random_state=None):
    """Fit a one-feature logistic-regression probe separating German (1) from English (0).

    The single feature is the activation of ``neuron`` in MLP layer ``layer``.

    Args:
        layer: MLP layer index.
        neuron: Neuron index within that layer.
        num_samples: Maximum number of rows taken from each language.
        german_activations: ``layer -> 2-D (samples, neurons)`` activations on German text
            (default bound to the notebook global at definition time).
        english_activations: Same for English text.
        random_state: Seed passed to ``train_test_split``; None keeps the previous unseeded behaviour.

    Returns:
        ``(train_acc, test_acc, f1)``; F1 is computed on the held-out 20% split.
    """
    from sklearn.metrics import f1_score

    german_acts = german_activations[layer][:num_samples, neuron]
    english_acts = english_activations[layer][:num_samples, neuron]
    A = torch.concat([german_acts, english_acts]).view(-1, 1).cpu().numpy()
    # Label by the number of rows actually taken: a language with fewer than
    # `num_samples` rows would otherwise give X and y of different lengths.
    y = torch.concat([torch.ones(len(german_acts)), torch.zeros(len(english_acts))]).cpu().numpy()

    A_train, A_test, y_train, y_test = train_test_split(A, y, test_size=0.2, random_state=random_state)
    lr_model = LogisticRegression()
    lr_model.fit(A_train, y_train)
    test_acc = lr_model.score(A_test, y_test)
    train_acc = lr_model.score(A_train, y_train)
    f1 = f1_score(y_test, lr_model.predict(A_test))
    return train_acc, test_acc, f1
    
def get_neuron_accuracy(layer, neuron, german_activations=german_activations, english_activations=english_activations, plot=False, print_f1s=True):
    """Probe how well one neuron separates German from English activations.

    Optionally plots both activation histograms, and optionally prints the probe
    scores alongside each language's mean activation.

    Returns:
        The probe's test-set F1 score from ``run_single_neuron_lr``.
    """
    english_neuron_acts = english_activations[layer][:, neuron]
    german_neuron_acts = german_activations[layer][:, neuron]

    if plot:
        haystack_utils.two_histogram(english_neuron_acts, german_neuron_acts, "English", "German", "Activation", "Frequency", f"L{layer}N{neuron} activations on English vs German text")

    train_acc, test_acc, f1 = run_single_neuron_lr(layer, neuron, german_activations=german_activations, english_activations=english_activations)
    if not print_f1s:
        return f1

    print(f"\nL{layer}N{neuron}: F1={f1:.2f}, Train acc={train_acc:.2f}, and test acc={test_acc:.2f}")
    print(f"Mean activation English={english_neuron_acts.mean():.2f}, German={german_neuron_acts.mean():.2f}")
    return f1

In [7]:
# f1s = []
# for layer, neuron, reported_f1 in german_neurons_with_f1:
#     f1s.append(get_neuron_accuracy(layer, neuron, print_f1s=False))

# german_neuron_names = [f"L{layer}N{neuron}" for layer, neuron, _ in german_neurons_with_f1]
# haystack_utils.line(f1s, xlabel="", ylabel="F1 score of sparse probe", title="Sparse probe performance on individual German neurons", xticks=german_neuron_names, show_legend=False)

In [8]:
# f1s = []
# for layer, neuron, _ in german_neurons_with_f1:
#     deactivate_other_neurons_fwd_hooks=[(f'blocks.{l}.mlp.hook_post', get_deactivate_neurons_hook(l)) for l in important_german_neurons.keys() if l != layer]
#     with model.hooks(deactivate_other_neurons_fwd_hooks):
#         modified_german_acts = {layer: haystack_utils.get_mlp_activations(german_data, layer, model, mean=False)}

#     f1s.append(get_neuron_accuracy(layer, neuron, german_activations=modified_german_acts, 
#                                     english_activations=english_activations, print_f1s=False, plot=False))

# german_neuron_names = [f"L{layer}N{neuron}" for layer, neuron, _ in german_neurons_with_f1]
# haystack_utils.line(f1s, xlabel="", ylabel="F1 score of sparse probe", title="Sparse probe performance on individual German neurons", 
#                     xticks=german_neuron_names, show_legend=False)

In [9]:
# Effect of an ablation on the German language-modelling loss
def ablation_effect(fwd_hooks, batch_size=50, data=None):
    """Print mean loss with and without ``fwd_hooks``, plus the relative increase in percent.

    Args:
        fwd_hooks: TransformerLens ``(hook_name, fn)`` pairs active during the ablated run.
        batch_size: Prompts per forward pass.
        data: Prompts to evaluate; defaults to the global ``german_data``.

    Raises:
        ValueError: if there are no prompts to evaluate.
    """
    prompts = german_data if data is None else data
    if len(prompts) == 0:
        raise ValueError("ablation_effect needs at least one prompt")

    original_losses = []
    ablated_losses = []
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start:start + batch_size]
        original_losses.append(model(batch, return_type='loss').cpu())
        with model.hooks(fwd_hooks):
            ablated_losses.append(model(batch, return_type='loss').cpu())

    # Unweighted mean over batches (a shorter final batch counts as much as a full one).
    original_loss = sum(original_losses) / len(original_losses)
    ablated_loss = sum(ablated_losses) / len(ablated_losses)

    print(original_loss, ablated_loss)
    # `.2f` = two decimals (`2f` was a field width and printed six decimals).
    print(f'{(ablated_loss - original_loss) / original_loss * 100:.2f}% loss increase')

In [10]:
# print("Full ablation:")
# ablation_effect(deactivate_neurons_fwd_hooks)

def get_neuron_hook(layer, neuron, inactive_value):
    """Build a single-entry fwd_hooks list that pins one MLP neuron to a fixed value.

    Args:
        layer: Block whose ``mlp.hook_post`` activation is patched.
        neuron: Index along the last axis to overwrite.
        inactive_value: Value written at every batch element and position.

    Returns:
        ``[(hook_name, hook_fn)]``; ``hook_fn`` edits the activation in place and returns it.
    """
    def _pin_neuron(value, hook):
        value[:, :, neuron] = inactive_value
        return value

    hook_name = f'blocks.{layer}.mlp.hook_post'
    return [(hook_name, _pin_neuron)]

# for layer, neuron, f1 in german_neurons_with_f1:
#     print(f"Ablate L{layer}N{neuron} context neuron with f1 of {f1}:")
#     ablation_effect(get_neuron_hook(layer, neuron, english_activations[layer][:, neuron].mean()))

All other context neurons are in the output path of the L5 context neuron (L8 and L11 are affected less by it, the rest more).

Most circuits are in the output path of the L8 context neuron.

In [11]:
# Patch the two F1=1.0 German context neurons (L5N2649, L8N2994) to their mean activation
# on English text ("deactivate") or on German text ("activate").
deactivate_context_hooks = []
activate_context_hooks = []
for ctx_layer, ctx_neuron in ((5, 2649), (8, 2994)):
    deactivate_context_hooks += get_neuron_hook(ctx_layer, ctx_neuron, english_activations[ctx_layer][:, ctx_neuron].mean())
    activate_context_hooks += get_neuron_hook(ctx_layer, ctx_neuron, german_activations[ctx_layer][:, ctx_neuron].mean())

# print('L5 and L8 ablations respectively')
# ablation_effect(get_neuron_hook(5, 2649, english_activations[5][:, 2649].mean()))
# ablation_effect(get_neuron_hook(8, 2994, english_activations[8][:, 2994].mean()))

# print('L5 and L8 ablated together')
# ablation_effect(deactivate_context_hooks)

# print('L5 and L8 activated together - maybe loss increases because one of the neurons is really trimodal so a single activation value removes lots of information')
# ablation_effect(activate_context_hooks)

# print('Sanity check - ablating the first context neuron + one other random context neuron doesn\'t non-linearly increases loss like the L5 + L8 neuron combo')
# ablation_effect(get_neuron_hook(5, 2649, english_activations[5][:, 2649].mean()) + get_neuron_hook(6, 1838, english_activations[6][:, 1838].mean()))

Ablating them together causes the majority of the loss increase, even though their individual ablation loss increases sum to less than this.

Perhaps there are circuits that rely on both neurons and act somewhat like AND gates, although the two context neurons are not independent of each other.

### Trimodal context neuron in L8

In [12]:
# german_acts_5 = []
# german_acts_8 = []
# for prompt in german_data:
#     _, cache = model.run_with_cache(prompt)
#     german_acts_5 += cache['post', 5][:, :, 2649].flatten().tolist()
#     german_acts_8 += cache['post', 8][:, :, 2994].flatten().tolist()

# english_acts_5 = []
# english_acts_8 = []
# for prompt in english_data:
#     _, cache = model.run_with_cache(prompt)
#     english_acts_5 += cache['post', 5][:, :, 2649].flatten().tolist()
#     english_acts_8 += cache['post', 8][:, :, 2994].flatten().tolist()

# 1 - 3
# 6 - 8

In [13]:
# px.histogram(german_acts_8)

In [14]:
# px.histogram(english_acts_8)

In [15]:
# Full French corpus (not sliced); used for the 'ord' completion statistics below.
french_data = haystack_utils.load_json_data('data/french_data.json')

# french_acts_8 = []
# for prompt in french_data:
#     _, cache = model.run_with_cache(prompt)
#     french_acts_8 += cache['post', 8][:, :, 2994].flatten().tolist()

# px.histogram(french_acts_8)

data/french_data.json: Loaded 1000 examples with 200 to 2988 characters each.


#### Coloring tokens by which mode they're closest to

In [16]:
def custom_print_strings_as_html(strings: list[str], color_values: list[float], max_value: float=None, additional_measures: list[list[float]] | None = None, additional_measure_names: list[str] | None = None):
    """Display ``strings`` as inline HTML spans coloured by a signed value per string.

    Positive values are divided by ``max_value`` and negative values by ``|max_value|``
    (when ``max_value`` is None: by the max and |min| of ``color_values``), giving roughly
    [-1, 1]. That range is shifted onto [0, 1] for the diverging PiYG colormap, so 0 maps
    to the colormap's midpoint, negatives towards one end and positives towards the other.

    Args:
        strings: Token strings; whitespace runs are shown as ``_`` and HTML is escaped.
        color_values: One value per string (floats or 0-d tensors); also shown on hover.
        max_value: Optional common scale for positive and negative values.
        additional_measures: Optional extra per-string value lists, shown on hover.
        additional_measure_names: Labels for ``additional_measures`` (same order).
    """
    from html import escape

    def normalize(values, max_value=None, min_value=None):
        # Sign-preserving scaling; a zero scale yields 0 instead of dividing by zero.
        if max_value is None:
            max_value = max(values)
        if min_value is None:
            min_value = min(values)
        min_value = abs(min_value)
        normalized = []
        for value in values:
            scale = max_value if value > 0 else min_value
            normalized.append(value / scale if scale else 0.0)
        return normalized

    normalized_values = normalize(color_values, max_value, max_value)
    cmap = plt.cm.PiYG

    spans = []
    for i in range(len(strings)):
        # Colormaps expect inputs in [0, 1]; anything below 0 clips to the lowest colour,
        # which made all negative values indistinguishable. Shift [-1, 1] -> [0, 1].
        r, g, b, _ = cmap((float(normalized_values[i]) + 1) / 2)
        red, green, blue = [int(255 * v) for v in (r, g, b)]

        # Pick black or white text depending on background luminance.
        luminance = (0.299 * red + 0.587 * green + 0.114 * blue) / 255
        text_color = "black" if luminance > 0.5 else "white"

        visible_string = escape(re.sub(r'\s+', '_', strings[i]))

        title = f'Difference: {float(color_values[i]):.4f}'
        if additional_measure_names is not None:
            for j in range(len(additional_measure_names)):
                title += f', {additional_measure_names[j]}: {float(additional_measures[j][i]):.4f}'

        spans.append(
            f'<span style="background-color: rgb({red}, {green}, {blue}); color: {text_color}; padding: 2px;" '
            f'title="{escape(title)}">{visible_string}</span>'
        )

    # Render in the notebook output cell.
    display(HTML("<div>" + "".join(spans) + "</div>"))


In [17]:
def trimodal_interest_measure(l8_n2994_acts, modes=(0, 3.5, 5.5), tolerance=0.3):
    """Assign each activation to the nearest of a set of activation modes.

    Args:
        l8_n2994_acts: Float tensor of activations (any shape).
        modes: Mode centres; the defaults are the three values previously hard-coded
            for L8N2994.
        tolerance: Largest ``|activation - mode|`` that still counts as belonging to a mode.

    Returns:
        Long tensor with the input's shape and device: the index of the nearest mode,
        or ``len(modes)`` (3 with the defaults) when no mode is within ``tolerance``.
    """
    # (n_modes, *acts.shape) absolute distances, kept on the input's device
    # (previously forced onto CUDA, which fails on CPU-only machines).
    diffs = torch.stack([torch.abs(l8_n2994_acts - mode) for mode in modes])
    min_values, min_indices = torch.min(diffs, dim=0)
    # Activations not close to any mode get the extra "no mode" index.
    min_indices[min_values > tolerance] = len(modes)
    return min_indices

def trimodal_print_prompt(prompt: str):
    """Render ``prompt`` token by token, coloured by the nearest L8N2994 activation mode."""
    str_token_prompt = model.to_str_tokens(model.to_tokens(prompt))
    _, cache = model.run_with_cache(prompt)
    mode_index = trimodal_interest_measure(cache['post', 8][0, :, 2994])

    # Colour value for each index returned by trimodal_interest_measure:
    #   0 -> 0.0   not German
    #   1 -> 0.4   first German mode
    #   2 -> -0.4  second German mode
    #   3 -> 1.1   not close to any mode
    mode_colors = torch.tensor([0.0, 0.4, -0.4, 1.1], dtype=torch.float32, device=mode_index.device)
    pos_wise_diff = mode_colors[mode_index]

    # Drop position 0 (presumably the BOS token) when rendering.
    custom_print_strings_as_html(str_token_prompt[1:], pos_wise_diff[1:], max_value=2)

for prompt in german_data[:10]:
    trimodal_print_prompt(prompt)

# new word detection


### Utils

In [94]:
def component_effects_german(prompt, index):
    """Effects of the L5/L8 German context neurons on ``prompt`` at position ``index``.

    Calls ``haystack_utils.get_direct_effect`` seven times with the module-level
    ``deactivate_context_hooks`` / ``activate_context_hooks``:

    1. L5 and L8 MLP outputs activated; attention and MLP outputs of layers 6, 7, 9, 10
       and 11 deactivated. Supplies (original, ablated, direct_effect).
    2-7. Exactly one of MLP9, Attn9, MLP10, Attn10, MLP11, Attn11 activated and the other
       layer 9-11 components deactivated. Each supplies its fourth return value.

    NOTE(review): blocks.8.hook_attn_out is in neither tuple of run 1 — confirm intended.

    Returns:
        ``[original, ablated, direct_effect, mlp9, attn9, mlp10, attn10, mlp11, attn11]``.
    """
    hook_kwargs = dict(context_ablation_hooks=deactivate_context_hooks,
                       context_activation_hooks=activate_context_hooks)

    original, ablated, direct_effect, _ = haystack_utils.get_direct_effect(
        prompt, model, pos=index, **hook_kwargs,
        deactivated_components=("blocks.6.hook_attn_out", "blocks.7.hook_attn_out", "blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.6.hook_mlp_out", "blocks.7.hook_mlp_out", "blocks.9.hook_mlp_out", "blocks.11.hook_mlp_out", "blocks.10.hook_mlp_out"),
        activated_components=("blocks.5.hook_mlp_out", "blocks.8.hook_mlp_out",))

    # (deactivated, activated) for each single-component run, in return order.
    single_component_runs = (
        # MLP9
        (("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.11.hook_mlp_out", "blocks.10.hook_mlp_out"),
         ("blocks.9.hook_mlp_out",)),
        # Attn9
        (("blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out", "blocks.11.hook_mlp_out"),
         ("blocks.9.hook_attn_out",)),
        # MLP10
        (("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.11.hook_mlp_out"),
         ("blocks.10.hook_mlp_out",)),
        # Attn10
        (("blocks.9.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out", "blocks.11.hook_mlp_out"),
         ("blocks.10.hook_attn_out",)),
        # MLP11
        (("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
         ("blocks.11.hook_mlp_out",)),
        # Attn11
        (("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out", "blocks.11.hook_mlp_out"),
         ("blocks.11.hook_attn_out",)),
    )

    results = [original, ablated, direct_effect]
    for deactivated, activated in single_component_runs:
        _, _, _, only_activated_loss = haystack_utils.get_direct_effect(
            prompt, model, pos=index, **hook_kwargs,
            deactivated_components=deactivated,
            activated_components=activated)
        results.append(only_activated_loss)
    return results


In [None]:
def left_pad(prompts, model):
    """Tokenize each prompt and left-pad with token id 0 so all rows share one length.

    Args:
        prompts: Non-empty sequence of prompt strings.
        model: Object whose ``to_tokens(str)`` returns a ``(1, seq)`` token tensor
            (e.g. a HookedTransformer).

    Returns:
        ``(len(prompts), longest_seq)`` tensor on the tokens' device and dtype, with each
        prompt right-aligned so its last token sits in the final column.

    Raises:
        ValueError: if ``prompts`` is empty.
    """
    if len(prompts) == 0:
        raise ValueError("left_pad needs at least one prompt")

    # Tokenize each prompt once; the longest one fixes the target width.
    token_rows = [model.to_tokens(prompt)[0] for prompt in prompts]
    target_length = max(row.shape[0] for row in token_rows)

    # new_zeros keeps the tokens' device/dtype (was hard-coded to CUDA / Python int).
    return torch.stack([
        torch.cat([row.new_zeros(target_length - row.shape[0]), row])
        for row in token_rows
    ])

### ord -> n

In [18]:
# Full English Europarl set (the `english_data` used earlier keeps only the first 200 entries).
english_data_long = haystack_utils.load_json_data("data/english_europarl.json")
# Stack Exchange train split via HF `datasets`. NOTE(review): not used in the cells shown here.
stack_exchange_data = load_dataset(path='habedi/stack-exchange-dataset', split='train')

data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


In [19]:
def print_counter(token: str, data: list[str], prev_tokens_count = 1):
    """Print a Counter of the token strings that follow ``token`` in ``data``.

    Only the *first* occurrence of ``token`` in each prompt is used. That token and the
    next ``prev_tokens_count`` tokens are joined into one string and counted; despite the
    name, these are the tokens *after* ``token``. Prompts without ``token``, or where it
    is too close to the end, are skipped.

    Args:
        token: String that must encode to a single token (``model.to_single_token``).
        data: Prompts to scan.
        prev_tokens_count: Number of following tokens to include.
    """
    counter = Counter()
    target_token = model.to_single_token(token)
    for prompt in data:
        tokens = model.to_tokens(prompt)[0]
        try:
            index = tokens.tolist().index(target_token)
        except ValueError:
            # Token absent from this prompt. Catch only this case; a bare `except`
            # would also hide real errors and even KeyboardInterrupt.
            continue
        if index + prev_tokens_count < len(tokens):
            next_tokens = tokens[index : index + prev_tokens_count + 1]
            counter.update(["".join(model.to_str_tokens(next_tokens))])
    print(counter)

print_counter('ord', english_data_long)
print_counter('ord', german_data)
print_counter('ord', french_data)

Counter({'ordic': 7, 'ord\n': 4, 'ordance': 1, 'ordana': 1, 'ord,': 1, 'ord.': 1, 'ordination': 1})
Counter({'ordn': 59, 'ordnet': 20, 'ordung': 2, 'ordischen': 1})
Counter({'ordre': 97, 'ord,': 10, 'ord de': 9, 'ordina': 6, 'ord vous': 3, 'ordique': 3, 'ord CE': 3, 'ord commercial': 3, 'ordé': 3, 'ordée': 3, 'ord rem': 2, 'ord entre': 2, 'ord sur': 2, 'ordance': 1, 'ord vot': 1, 'ordés': 1, 'ord.': 1, 'ord d': 1, 'ordana': 1, 'ord-': 1, 'ord inter': 1, 'ord UE': 1, 'ord le': 1, 'ord faire': 1, 'ord don': 1, 'ord voter': 1, 'ord à': 1, 'ord les': 1, 'ord Br': 1})


We can see that the most common 'ord' completions are 'ordre' in French, 'ordic' in English, and 'ordn' in German. Does a post-context neuron MLP hold these bigram statistics? 

First we'll build a dataset of German completions and see how things work. 
Then we can build a dataset of the French completions, manually activate the German context neurons, and see how much the log-prob of 'ord' -> 'n' improves.

In [57]:
# German prompts cut right after their first 'ordn' (e.g. "...Tagesordn").
target_completion = 'ordn'
german_ord = [
    head + found
    for head, found, _ in (prompt.partition(target_completion) for prompt in german_data)
    if found
]

print(len(german_ord))

81


In [87]:
# Spot-check one truncated example.
print(german_ord[3])

Tagesordn


In [89]:
# Top-k (k=100) common German tokens, with `all_ignore` passed through to the helper;
# used below as filler tokens for haystack_utils.generate_random_prompts.
common_tokens = haystack_utils.get_common_tokens(german_data, model, all_ignore, k=100)

  0%|          | 0/200 [00:00<?, ?it/s]

In [90]:
# Try trigram because ordn on its own is hard for the model to predict
random_ord = haystack_utils.generate_random_prompts('Tagesordn', model, common_tokens, 400, length=20)

# Measure component effects at the final position, 10 random prompts at a time.
# Each chunk gets its own slice; component_effects_german takes no `plot` argument.
chunk_size = 10
chunk_effects = []
for start in range(0, len(random_ord), chunk_size):
    chunk_effects.append(component_effects_german(random_ord[start:start + chunk_size], -1))

# One tensor per returned measure, stacked across chunks.
component_losses = [torch.stack([effects[i] for effects in chunk_effects]) for i in range(len(chunk_effects[0]))]
haystack_utils.plot_barplot([item.cpu().flatten().tolist() for item in torch.stack(component_losses)],
                                names=['original', 'ablated', 'direct effect'] + [f'{i}{j}' for j in [9, 10, 11] for i in ["MLP", "Attn"]])

# Clearly MLP11

In [28]:
# French prompts cut right after their first 'ordre'.
target_completion = 'ordre'
french_ord = []
for prompt in french_data:
    head, found, _ = prompt.partition(target_completion)
    if found:
        french_ord.append(head + found)

print(len(french_ord))

113


In [55]:
# Push French text towards German by pinning both context neurons to *twice* their mean
# German activation. Kept under its own name: re-assigning `activate_context_hooks` here
# would silently change every later cell that reads it (e.g. the MLP11 loss analysis)
# when the notebook is run top to bottom.
french_activate_context_hooks = get_neuron_hook(5, 2649, german_activations[5][:, 2649].mean() * 2) + get_neuron_hook(8, 2994,  german_activations[8][:, 2994].mean() * 2)

french_completion = model.to_single_token('re')
german_completion = model.to_single_token('n')

# The German hooks are passed as the *ablation* hooks, so `ablated_logprobs` comes from
# the run with the German context switched on.
original_logprobs, ablated_logprobs, _, _ = haystack_utils.get_direct_effect(
    french_ord[0], model, pos=None, context_ablation_hooks=french_activate_context_hooks, context_activation_hooks=[], return_type='logprobs')

# Position -2 predicts the prompt's last token (assumes "...ordre" ends in 'ord' + 're').
print(original_logprobs[-2, german_completion])
print(original_logprobs[-2, german_completion] - original_logprobs[-2, french_completion])
print(ablated_logprobs[-2, german_completion] - ablated_logprobs[-2, french_completion])

tensor(-19.6804, device='cuda:0')
tensor(-19.6758, device='cuda:0')
tensor(-4.3472, device='cuda:0')


Figure out which components implement the 'ord' circuit.

In [109]:
# Reverses the German context activation hooks for use on French text - the ablation hook activates the German context
def language_logprob_diffs(prompt, index, german_token, french_token):
        logprobs_original, logprobs_ablated, direct_effect, _ = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=activate_context_hooks, context_activation_hooks=[],
                deactivated_components =("blocks.6.hook_attn_out", "blocks.7.hook_attn_out", "blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.6.hook_mlp_out", "blocks.7.hook_mlp_out", "blocks.9.hook_mlp_out", "blocks.11.hook_mlp_out", "blocks.10.hook_mlp_out"),
                activated_components = ("blocks.5.hook_mlp_out", "blocks.8.hook_mlp_out",), return_type='logprobs')
        
        _, _, _, logprobs_mlp_9 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=activate_context_hooks, context_activation_hooks=[],
                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.11.hook_mlp_out", "blocks.10.hook_mlp_out"),
                activated_components = ("blocks.9.hook_mlp_out",), return_type='logprobs')

        _, _, _, logprobs_attn_9 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=activate_context_hooks, context_activation_hooks=[],
                deactivated_components =("blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out", "blocks.11.hook_mlp_out"),
                activated_components = ("blocks.9.hook_attn_out",), return_type='logprobs')

        _, _, _, logprobs_mlp_10 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=activate_context_hooks, context_activation_hooks=[],
                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.11.hook_mlp_out"),
                activated_components = ("blocks.10.hook_mlp_out",), return_type='logprobs')

        _, _, _, logprobs_attn_10 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=activate_context_hooks, context_activation_hooks=[],
                deactivated_components =("blocks.9.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out", "blocks.11.hook_mlp_out"),
                activated_components = ("blocks.10.hook_attn_out",), return_type='logprobs')

        _, _, _, logprobs_mlp_11 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=activate_context_hooks, context_activation_hooks=[],
                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
                activated_components = ("blocks.11.hook_mlp_out",), return_type='logprobs')

        _, _, _, logprobs_attn_11 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=activate_context_hooks, context_activation_hooks=[],
                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out","blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out", "blocks.11.hook_mlp_out", ),
                activated_components = ("blocks.11.hook_attn_out",), return_type='logprobs')

        logprobs_original_diff = logprobs_original[french_token] - logprobs_original[german_token]
        ablated_diff = logprobs_ablated[french_token] - logprobs_ablated[german_token]
        direct_effect_diff = direct_effect[french_token] - direct_effect[german_token]
        logprobs_mlp_9_diff = logprobs_mlp_9[french_token] - logprobs_mlp_9[german_token]
        logprobs_attn_9_diff = logprobs_attn_9[french_token] - logprobs_attn_9[german_token]
        logprobs_mlp_10_diff = logprobs_mlp_10[french_token] - logprobs_mlp_10[german_token]
        logprobs_attn_10_diff = logprobs_attn_10[french_token] - logprobs_attn_10[german_token]
        logprobs_mlp_11_diff = logprobs_mlp_11[french_token] - logprobs_mlp_11[german_token]
        logprobs_attn_11_diff = logprobs_attn_11[french_token] - logprobs_attn_11[german_token]

        data = [logprobs_original_diff, ablated_diff, direct_effect_diff, logprobs_mlp_9_diff, logprobs_attn_9_diff, 
                logprobs_mlp_10_diff, logprobs_attn_10_diff, logprobs_mlp_11_diff, logprobs_attn_11_diff]
        return data

data = language_logprob_diffs(french_ord[0], -2, model.to_single_token('n'), model.to_single_token('re'))
haystack_utils.plot_barplot([[item.cpu()] for item in data],
                                names=['original', 'ablated', 'direct effect'] + [f'{i}{j}' for j in [9, 10, 11] for i in ["MLP", "Attn"]],
                                title="Log prob difference between French and German tokens when activating German context")



### DLA & component-level path patching

In [20]:
# logit_attr_original, labels = haystack_utils.DLA(german_data, model)

# # Patch in disabled context neurons and plot the direct logit attribution difference for each component
# with model.hooks(fwd_hooks=deactivate_context_hooks):
#     logit_attr_ablated, _ = haystack_utils.DLA(german_data, model)

# logit_diffs = (logit_attr_original - logit_attr_ablated).mean(0)
# # The small differences accumulated before the ablation are due to the final layer norm scale being affected by the L3 hook.
# haystack_utils.line(logit_diffs.cpu().numpy(), xlabel="Correct logit", ylabel="", title="(Original DLA - Ablated DLA) per component", xticks=labels)

The direct loss increases are relatively small until layer 9. This implies that the direct effects of the ablations are minor, and that most of the difference comes from indirect effects of the layer 5 and 8 context neurons acting through components from layer 9 onwards.

### High loss prompts - MLP11

In [21]:
def interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss):
    """Per-token score mixing the context-ablation loss increase with MLP11's shortfall.

    score = 0.5 * (ablated_loss - original_loss) - max(only_activated_loss - original_loss, 0)

    ``only_activated_loss`` comes from the run where, among the layer 9-11 components,
    only MLP11 is activated (see the callers' ``get_direct_effect`` arguments).
    Tokens with ``original_loss > 6``, or whose loss drops under ablation, score 0.
    ``context_and_activated_loss`` is accepted for call-site symmetry but not used.
    """
    context_effect = ablated_loss - original_loss
    # Loss above the original when only MLP11 gets the active context, floored at 0.
    mlp11_shortfall = (only_activated_loss - original_loss).clamp(min=0)
    score = 0.5 * context_effect - mlp11_shortfall
    # Ignore tokens the model already predicts badly, and tokens helped by the ablation.
    ignore = (original_loss > 6) | (original_loss > ablated_loss)
    return score.masked_fill(ignore, 0)

def print_prompt(prompt: str):
    """Render ``prompt`` as coloured HTML tokens scored by ``interest_measure``.

    Uses the MLP11 configuration (layer 9-11 components deactivated except MLP11). Each
    token's hover text also lists the four per-token losses.
    """
    str_token_prompt = model.to_str_tokens(model.to_tokens(prompt))
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = haystack_utils.get_direct_effect(
        prompt, model, pos=None,
        context_ablation_hooks=deactivate_context_hooks,
        context_activation_hooks=activate_context_hooks,
        deactivated_components=("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
        activated_components=("blocks.11.hook_mlp_out",))
    losses = (original_loss, ablated_loss, context_and_activated_loss, only_activated_loss)

    scores = interest_measure(*losses).flatten().cpu().tolist()

    loss_names = ["original_loss", "ablated_loss", "context_and_activated_loss", "only_activated_loss"]
    loss_list = [loss.flatten().cpu().tolist() for loss in losses]
    haystack_utils.clean_print_strings_as_html(str_token_prompt[1:], scores, additional_measures=loss_list, additional_measure_names=loss_names)

def get_mlp11_decrease_measure(losses: list[tuple[Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"]]]):
    """For each prompt's 4-tuple of per-token losses, the maximum ``interest_measure`` value."""
    return [
        interest_measure(original, ablated, context_and_activated, only_activated).max().item()
        for original, ablated, context_and_activated, only_activated in losses
    ]

# Plotting helper for the MLP11 configuration: layers 9-11 attention and MLP9/MLP10
# deactivated, MLP11 activated. `activated_components` must be a 1-tuple, as at every
# other call site; without the trailing comma it is just a string.
average_loss_plot = haystack_utils.get_average_loss_plot_method(activate_context_hooks, deactivate_context_hooks, "MLP11",
                                                                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
                                                                activated_components = ("blocks.11.hook_mlp_out",))


In [22]:
# Per-prompt (original, ablated, context_and_activated, only_activated) per-token losses
# for the MLP11 configuration; consumed by get_mlp11_decrease_measure below.
mlp11_deactivated = ("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out")
german_losses = []
for prompt in tqdm(german_data):
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = haystack_utils.get_direct_effect(
        prompt, model, pos=None,
        context_ablation_hooks=deactivate_context_hooks,
        context_activation_hooks=activate_context_hooks,
        deactivated_components=mlp11_deactivated,
        activated_components=("blocks.11.hook_mlp_out",))
    german_losses += [(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss)]

  0%|          | 0/200 [00:00<?, ?it/s]

In [23]:
measure = get_mlp11_decrease_measure(german_losses)

# Prompt indices ranked by their best token's MLP11 interest score, highest first.
sorted_measure = sorted(enumerate(measure), key=lambda pair: pair[1], reverse=True)
# Separate loop names so `measure` still holds the full list afterwards
# (the loop previously rebound `measure` to a single float).
for prompt_index, prompt_score in sorted_measure[:2]:
    print(prompt_score)
    print_prompt(german_data[prompt_index])

2.767449140548706


2.019975185394287


### Path patching components

In [24]:
def get_prompt_and_token(min_loss_increase=3, max_original_loss=3, data=None):
    """Find the first prompt containing a token that the context-neuron ablation hurts badly.

    For each prompt in ``data`` (default: ``german_data``), in order, pick the position with
    the largest per-token loss increase under ``deactivate_context_hooks``. Return the first
    prompt where that increase exceeds ``min_loss_increase`` and the original loss at that
    position is below ``max_original_loss``.

    Returns:
        ``(prompt, index)`` with ``index`` a 1-element tensor holding the loss position,
        or ``('', -1)`` if no prompt qualifies.
    """
    prompts = german_data if data is None else data
    for prompt in prompts:
        # Plain forward passes: only the loss is needed, so don't cache every activation.
        original_loss = model(prompt, return_type='loss', loss_per_token=True)
        with model.hooks(deactivate_context_hooks):
            ablated_loss = model(prompt, return_type='loss', loss_per_token=True)
        value, index = torch.max(ablated_loss - original_loss, dim=1)
        if value > min_loss_increase and original_loss[0, index] < max_original_loss:
            return prompt, index
    return '', -1

prompt, index = get_prompt_and_token()
assert prompt, "no German prompt met the loss-increase thresholds"

In [96]:

# Component effects at the token selected above; tensors are moved to CPU before plotting,
# as in the other bar plots in this notebook.
component_data = component_effects_german(prompt, index)
haystack_utils.plot_barplot([[item.cpu()] for item in component_data],
                            names=['original', 'ablated', 'direct effect'] + [f'{i}{j}' for j in [9, 10, 11] for i in ["MLP", "Attn"]])