### Setup

In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter
import pickle
import os

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# Load pythia-160m into TransformerLens with LayerNorm folded into adjacent weights and
# centered writing / unembedding weights (the processing flags passed below).
model = HookedTransformer.from_pretrained("EleutherAI/pythia-160m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

def get_neuron_hook(layer, neuron, act_value):
    """Build a (hook_name, hook_fn) pair that overwrites one MLP neuron.

    The hook targets `blocks.{layer}.mlp.hook_post` and writes `act_value` (a scalar, or a tensor
    broadcastable to value[:, :, neuron]) into index `neuron` at every batch entry and position.
    It edits the activation tensor in place and returns it.
    """
    hook_name = f'blocks.{layer}.mlp.hook_post'

    def clamp_neuron(activations, hook):
        activations[:, :, neuron] = act_value
        return activations

    return hook_name, clamp_neuron

# Only the first 200 prompts of each Europarl file are kept (each file holds 2000, per the load log).
german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


In [3]:
# Candidate German context neurons as [layer, neuron, reported F1]. The second loop appends each
# neuron's mean German activation and then its mean English activation, so every row ends up with
# 5 entries: [layer, neuron, f1, mean_german, mean_english].
german_neurons_with_f1 = [
    [5, 2649, 1.0],
    [8,	2994, 1.0],
    [11, 2911, 0.99],
    [10, 1129, 0.97],
    [6, 1838, 0.65],
    [7, 1594, 0.65],
    [11, 1819, 0.61],
    [11, 2014, 0.56],
    [10, 753, 0.54],
    [11, 205, 0.48],
]

# Per-token (mean=False) MLP activations for every layer that hosts a candidate neuron.
english_activations = {}
german_activations = {}
for layer in set([layer for layer, *_ in german_neurons_with_f1]):
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)

# NOTE: mutates the rows in place; re-running is safe only because this cell rebuilds the list above.
for item in german_neurons_with_f1:
    layer, neuron, f1 = item
    item.append(german_activations[layer][:, neuron].mean(0))
    item.append(english_activations[layer][:, neuron].mean(0))

# Token ids to exclude, and the 100 most common German tokens used to build random prompts later.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)
common_tokens = haystack_utils.get_common_tokens(german_data, model, all_ignore, k=100)

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

### Check classification accuracy of German neurons

In [4]:
def run_single_neuron_lr(layer, neuron, num_samples=5000, german_activations=german_activations, english_activations=english_activations, random_state=None):
    """Fit a one-feature logistic regression that separates German from English via one neuron.

    Takes up to `num_samples` activations of `neuron` in `layer` from each language. German rows
    are labelled 1 and English rows 0. 20% of the rows are held out for testing.

    Args:
        random_state: forwarded to `train_test_split` for a reproducible split (default: random).

    Returns:
        (train_acc, test_acc, f1), with F1 computed on the held-out split.
    """
    german_acts = german_activations[layer][:num_samples, neuron]
    english_acts = english_activations[layer][:num_samples, neuron]
    A = torch.concat([german_acts, english_acts]).view(-1, 1).cpu().numpy()
    # Size the labels from the actual slices. A language with fewer than num_samples activations
    # would otherwise leave more labels than feature rows, and train_test_split would fail.
    y = torch.concat([torch.ones(len(german_acts)), torch.zeros(len(english_acts))]).numpy()
    A_train, A_test, y_train, y_test = train_test_split(A, y, test_size=0.2, random_state=random_state)
    lr_model = LogisticRegression()
    lr_model.fit(A_train, y_train)
    test_acc = lr_model.score(A_test, y_test)
    train_acc = lr_model.score(A_train, y_train)
    f1 = sklearn.metrics.f1_score(y_test, lr_model.predict(A_test))
    return train_acc, test_acc, f1
    
def get_neuron_accuracy(layer, neuron, german_activations=german_activations, english_activations=english_activations, plot=False, print_f1s=True):
    """Probe how well a single MLP neuron separates German from English text.

    Optionally plots the two activation histograms. It then fits the single-neuron logistic
    regression (`run_single_neuron_lr`) and optionally prints F1, accuracies and mean activations.
    Returns the held-out F1 score.
    """
    english_neuron_acts = english_activations[layer][:, neuron]
    german_neuron_acts = german_activations[layer][:, neuron]
    mean_english_activation = english_neuron_acts.mean()
    mean_german_activation = german_neuron_acts.mean()

    if plot:
        haystack_utils.two_histogram(english_neuron_acts, german_neuron_acts, "English", "German", "Activation", "Frequency", f"L{layer}N{neuron} activations on English vs German text")

    train_acc, test_acc, f1 = run_single_neuron_lr(layer, neuron, german_activations=german_activations, english_activations=english_activations)
    if not print_f1s:
        return f1

    print(f"\nL{layer}N{neuron}: F1={f1:.2f}, Train acc={train_acc:.2f}, and test acc={test_acc:.2f}")
    print(f"Mean activation English={mean_english_activation:.2f}, German={mean_german_activation:.2f}")
    return f1

In [5]:
# Recompute the single-neuron probe F1 for each candidate German context neuron.
# The hard-coded F1 in german_neurons_with_f1 is unpacked but not used here.
f1s = []
for layer, neuron, reported_f1, _, _ in german_neurons_with_f1:
    f1s.append(get_neuron_accuracy(layer, neuron, print_f1s=False))

german_neuron_names = [f"L{layer}N{neuron}" for layer, neuron, _, _, _ in german_neurons_with_f1]
haystack_utils.line(f1s, xlabel="", ylabel="F1 score of sparse probe", title="Sparse probe performance on individual German neurons", xticks=german_neuron_names, show_legend=False)

In [6]:
# Probe F1 per neuron when every other listed neuron on a *different* layer is clamped to its mean
# English activation (row element 5) while the German activations are recorded.
f1s = []
for layer, neuron, _, _, _ in german_neurons_with_f1:
    # The comprehension's `neuron` is local to it (Python 3 scoping), so the outer `neuron` is untouched.
    # Neurons on the same layer as the probed one are not clamped.
    deactivate_other_neurons_fwd_hooks=[get_neuron_hook(l, neuron, deact_val) for l, neuron, _, _, deact_val in german_neurons_with_f1 if l != layer]
    with model.hooks(deactivate_other_neurons_fwd_hooks):
        modified_german_acts = {layer: haystack_utils.get_mlp_activations(german_data, layer, model, mean=False)}

    f1s.append(get_neuron_accuracy(layer, neuron, german_activations=modified_german_acts, 
                                    english_activations=english_activations, print_f1s=False, plot=False))

german_neuron_names = [f"L{layer}N{neuron}" for layer, neuron, *_ in german_neurons_with_f1]
haystack_utils.line(f1s, xlabel="", ylabel="F1 score of sparse probe", title="Sparse probe performance on individual German neurons", 
                    xticks=german_neuron_names, show_legend=False)

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
# Loss impact of a set of ablation hooks on German text
def ablation_effect(fwd_hooks, data=None, batch_size=50, num_batches=4):
    """Print the mean LM loss without and with `fwd_hooks`, and the relative % increase.

    Args:
        fwd_hooks: (hook_name, hook_fn) pairs, applied only during the ablated forward passes.
        data: prompts to evaluate; defaults to the global `german_data`.
        batch_size: prompts per forward pass.
        num_batches: number of batches taken from the start of `data`. The default of 4 x 50 covers
            the 200 loaded prompts. Pass None to use every batch. Shorter data yields fewer batches
            instead of forward passes on empty slices.

    Losses are averaged per batch and then across batches.
    NOTE(review): batches of unequal-length prompts are padded by the model. Confirm whether pad
    positions contribute to the returned loss in the installed TransformerLens version.
    """
    if data is None:
        data = german_data
    batch_starts = range(0, len(data), batch_size)
    if num_batches is not None:
        batch_starts = batch_starts[:num_batches]
    if len(batch_starts) == 0:
        raise ValueError("ablation_effect needs at least one prompt to evaluate")

    original_losses = []
    ablated_losses = []
    for start in batch_starts:
        batch = data[start:start + batch_size]
        original_losses.append(model(batch, return_type='loss').cpu())
        with model.hooks(fwd_hooks):
            ablated_losses.append(model(batch, return_type='loss').cpu())

    original_loss = sum(original_losses) / len(original_losses)
    ablated_loss = sum(ablated_losses) / len(ablated_losses)

    print(original_loss, ablated_loss)
    # Two decimal places (`:2f` would mean "minimum width 2" with six decimals).
    print(f'{(ablated_loss - original_loss) / original_loss * 100:.2f}% loss increase')

In [None]:
first_mode, second_mode, third_mode = 0, 3.5, 5.5
def trimodal_hook(value, hook):
    """Snap L8N2994 (index 2994 of blocks.8.mlp.hook_post) to its nearest mode.

    Each activation is replaced by whichever of first_mode / second_mode / third_mode is closest.
    Mutates and returns `value`.
    """
    neuron_act = value[:, :, 2994]

    # Stay on the activation's own device; a hard-coded .cuda() breaks CPU-only runs.
    diffs = torch.abs(torch.stack([neuron_act - first_mode, neuron_act - second_mode, neuron_act - third_mode]))
    _, min_indices = torch.min(diffs, dim=0)

    value[:, :, 2994] = torch.where(min_indices == 0, first_mode, torch.where(min_indices == 1, second_mode, third_mode))

    return value

ablation_effect([("blocks.8.mlp.hook_post", trimodal_hook)])

first_mode, second_mode, third_mode = 0, 3.5, 5.5
def trimodal_hook(value, hook):
    """Quantize L8N2994 to two levels.

    Activations nearest first_mode are set to first_mode. Activations nearest either upper mode
    (second_mode or third_mode) are set to 4.5. Mutates and returns `value`.
    """
    neuron_act = value[:, :, 2994]

    # Stay on the activation's own device; a hard-coded .cuda() breaks CPU-only runs.
    diffs = torch.abs(torch.stack([neuron_act - first_mode, neuron_act - second_mode, neuron_act - third_mode]))
    _, min_indices = torch.min(diffs, dim=0)

    value[:, :, 2994] = torch.where(min_indices == 0, first_mode, 4.5)

    return value

ablation_effect([("blocks.8.mlp.hook_post", trimodal_hook)])

tensor(3.6835) tensor(3.6888)
0.141645% loss increase
tensor(3.6835) tensor(3.6959)
0.336715% loss increase


RuntimeError: The size of tensor a (679) must match the size of tensor b (3072) at non-singleton dimension 2

In [None]:
first_mode, second_mode, third_mode = 0, 3.5, 5.5
def trimodal_hook(value, hook):
    """Collapse the two upper modes of L8N2994 into a single value.

    Activations nearest first_mode are left unchanged. Activations nearest second_mode or
    third_mode are set to 4.5. Mutates and returns `value`.
    """
    neuron_act = value[:, :, 2994]
    # Stay on the activation's own device; a hard-coded .cuda() breaks CPU-only runs.
    diffs = torch.abs(torch.stack([neuron_act - first_mode, neuron_act - second_mode, neuron_act - third_mode]))
    _, min_indices = torch.min(diffs, dim=0)

    value[:, :, 2994] = torch.where(min_indices == 0, neuron_act, 4.5)

    return value

ablation_effect([("blocks.8.mlp.hook_post", trimodal_hook)])

tensor(3.6835) tensor(3.6933)
0.264106% loss increase


In [None]:
first_mode, second_mode, third_mode = 0, 3.5, 5.5
def trimodal_hook(value, hook):
    """Swap the two upper modes of L8N2994.

    Activations nearest first_mode are left unchanged. Activations nearest second_mode are set to
    5.5, and those nearest third_mode are set to 3.5. Mutates and returns `value`.
    """
    neuron_act = value[:, :, 2994]
    # Stay on the activation's own device; a hard-coded .cuda() breaks CPU-only runs.
    diffs = torch.abs(torch.stack([neuron_act - first_mode, neuron_act - second_mode, neuron_act - third_mode]))
    _, min_indices = torch.min(diffs, dim=0)

    value[:, :, 2994] = torch.where(min_indices == 0, neuron_act, torch.where(min_indices == 1, 5.5, 3.5))

    return value

ablation_effect([("blocks.8.mlp.hook_post", trimodal_hook)])

tensor(3.6835) tensor(3.7111)
0.748032% loss increase


In [None]:
# Loss impact of clamping all listed context neurons at once to their mean English activation
# (row element 5), then of clamping each neuron on its own.
print("Full ablation:")
deactivate_all_neurons_fwd_hooks=[get_neuron_hook(layer, neuron, deact_val) for layer, neuron, _, _, deact_val in german_neurons_with_f1]
ablation_effect(deactivate_all_neurons_fwd_hooks)

for layer, neuron, f1, *_ in german_neurons_with_f1:
    print(f"Ablate L{layer}N{neuron} context neuron with f1 of {f1}:")
    ablation_effect([get_neuron_hook(layer, neuron, english_activations[layer][:, neuron].mean())])

All other context neurons are in the output path of the L5 context neuron (L8 and L11 less so, the rest more strongly).

Most circuits are in the output path of the L8 context neuron.

In [None]:
# Hooks that pin the L5N2649 and L8N2994 context neurons to their mean English activation
# ("deactivate") or to their mean German activation ("activate"). Later cells pass them to haystack_utils.
deactivate_context_hooks = [get_neuron_hook(5, 2649, english_activations[5][:, 2649].mean()), get_neuron_hook(8, 2994, english_activations[8][:, 2994].mean())]
activate_context_hooks = [get_neuron_hook(5, 2649, german_activations[5][:, 2649].mean()), get_neuron_hook(8, 2994,  german_activations[8][:, 2994].mean())]

# print('L5 and L8 ablations respectively')
# ablation_effect([get_neuron_hook(5, 2649, english_activations[5][:, 2649].mean())])
# ablation_effect([get_neuron_hook(8, 2994, english_activations[8][:, 2994].mean())])

# print('L5 and L8 ablated together')
# ablation_effect(deactivate_context_hooks)

# print('L5 and L8 activated together - maybe loss increases because one of the neurons is really trimodal so a single activation value removes lots of information')
# ablation_effect(activate_context_hooks)

# print('Sanity check - ablating the first context neuron + one other random context neuron doesn\'t non-linearly increases loss like the L5 + L8 neuron combo')
# ablation_effect([get_neuron_hook(5, 2649, english_activations[5][:, 2649].mean()), get_neuron_hook(6, 1838, english_activations[6][:, 1838].mean())])

Ablating them together causes the majority of the loss increase, even though the sum of their individual ablation loss increases is smaller than this.

Perhaps some circuits rely on both neurons and act somewhat like AND gates, although the two context neurons are not independent of each other.

### Trimodal context neuron in L8

In [None]:
# Flattened per-token activations of L5N2649 and L8N2994 on the German and English prompts.
german_acts_5 = []
german_acts_8 = []
for prompt in german_data:
    _, cache = model.run_with_cache(prompt)
    german_acts_5 += cache['post', 5][:, :, 2649].flatten().tolist()
    german_acts_8 += cache['post', 8][:, :, 2994].flatten().tolist()

english_acts_5 = []
english_acts_8 = []
for prompt in english_data:
    _, cache = model.run_with_cache(prompt)
    english_acts_5 += cache['post', 5][:, :, 2649].flatten().tolist()
    english_acts_8 += cache['post', 8][:, :, 2994].flatten().tolist()

1 - 3
6 - 8

-2

In [None]:
# Print the candidate context neurons in LxNy notation.
for layer, neuron, *_ in german_neurons_with_f1:
    print(f'L{layer}N{neuron}')


5 2649
8 2994
11 2911
10 1129
6 1838
7 1594
11 1819
11 2014
10 753
11 205


In [None]:
# Activations of every listed context neuron over the first 20 German prompts.
# A single forward pass per prompt serves all neurons.
acts_dict = {f"L{layer}N{neuron}": [] for layer, neuron, *_ in german_neurons_with_f1}
for prompt in german_data[:20]:
    _, cache = model.run_with_cache(prompt)
    for layer, neuron, *_ in german_neurons_with_f1:
        acts_dict[f"L{layer}N{neuron}"] += cache['post', layer][:, :, neuron].flatten().tolist()

# Plot each neuron under the key it was stored with, since hand-typed keys such as "L10N2994" do
# not exist and raise KeyError. A cell only renders its last expression, so show every figure explicitly.
for name, acts in acts_dict.items():
    px.histogram(acts, title=name).show()


In [None]:
# L8N2994 activations on German text (the neuron this section studies as trimodal).
px.histogram(german_acts_8)

In [None]:
# Same neuron on English text, for comparison.
px.histogram(english_acts_8)

In [None]:
# L8N2994 activations on French text, to compare with the German and English histograms.
french_data = haystack_utils.load_json_data('data/french_data.json')

french_acts_8 = []
for prompt in french_data:
    _, cache = model.run_with_cache(prompt)
    french_acts_8 += cache['post', 8][:, :, 2994].flatten().tolist()

px.histogram(french_acts_8)

data/french_data.json: Loaded 1000 examples with 200 to 2988 characters each.


#### Coloring tokens by which mode they're closest to

In [None]:
def custom_print_strings_as_html(strings: list[str], color_values: list[float], max_value: float=None, additional_measures: list[list[float]] | None = None, additional_measure_names: list[str] | None = None):
    """ Magic GPT function that prints a string as HTML and colors it according to a list of color values. Color values are normalized to the max value preserving the sign.
    """

    def normalize(values, max_value=None, min_value=None):
        if max_value is None:
            max_value = max(values)
        if min_value is None:
            min_value = min(values)
        min_value = abs(min_value)
        normalized = [(value / max_value if value > 0 else value / min_value) for value in values]
        return normalized
    
    html = "<div>"
    
    # Normalize color values
    normalized_values = normalize(color_values, max_value, max_value)


    cmap = cmap=plt.cm.Pastel1

    for i in range(len(strings)):
        normalized_color = normalized_values[i].cpu()
        
        # Use colormap to get RGB values
        r, g, b, _ = cmap(normalized_color)

        # Scale RGB values to 0-255
        red, green, blue = [int(255*v) for v in (r, g, b)]
        
        # Calculate luminance to determine if text should be black
        luminance = (0.299 * red + 0.587 * green + 0.114 * blue) / 255
        
        # Determine text color based on background luminance
        text_color = "black" if luminance > 0.5 else "white"

        visible_string = re.sub(r'\s+', '_', strings[i])
        
        html += f'<span style="background-color: rgb({red}, {green}, {blue}); color: {text_color}; padding: 2px;" '
        html += f'title="Difference: {color_values[i]:.4f}' 
        if additional_measure_names is not None:
            for j in range(len(additional_measure_names)):
                html += f', {additional_measure_names[j]}: {additional_measures[j][i]:.4f}'
        html += f'">{visible_string}</span>'
    html += '</div>'

    # Print the HTML in Jupyter Notebook
    display(HTML(html))


In [None]:
def trimodal_interest_measure(l8_n2994_acts, modes=(0, 3.5, 5.5), tolerance=0.3):
    """Label each activation with the index of its nearest mode.

    Returns an integer tensor shaped like the input. Each entry holds the index into `modes` of
    the closest mode, or len(modes) (3 with the defaults) when the activation is more than
    `tolerance` away from every mode. Ties go to the earlier mode.
    """
    # Distances stay on the input's device; a hard-coded .cuda() broke CPU-only runs.
    diffs = torch.stack([(l8_n2994_acts - mode).abs() for mode in modes])

    min_values, min_indices = torch.min(diffs, dim=0)
    min_indices[min_values > tolerance] = len(modes)
    return min_indices

def get_token_modes(prompt: str):
    """Tokenize `prompt` and map each position's L8N2994 activation to a display value.

    Returns (str_tokens, values). `values` is a float32 tensor where 0 = nearest the 0 mode
    (not German), 0.4 = first German mode, 0.8 = second German mode, and 1 = not close to any mode.
    """
    str_tokens = model.to_str_tokens(model.to_tokens(prompt))
    _, cache = model.run_with_cache(prompt)
    mode_ids = trimodal_interest_measure(cache['post', 8][0, :, 2994])

    # Lookup table indexed by mode id: not German, first German mode, second German mode, no mode.
    display_values = torch.tensor([0.0, 0.4, 0.8, 1.0], dtype=torch.float32, device=mode_ids.device)
    return str_tokens, display_values[mode_ids]

# Color the first 10 German prompts by the L8N2994 mode code from get_token_modes (0 / 0.4 / 0.8 / 1).
for prompt in german_data[:10]:
    str_token_prompt, pos_wise_diff = get_token_modes(prompt)
    custom_print_strings_as_html(str_token_prompt, pos_wise_diff, max_value=2)

IndexError: index 5 is out of bounds for dimension 0 with size 1

In [None]:
# new word detection
# collect stats on whether blue occurs before a space or punctuation more than purple
# For each token coded 0.4 (first German mode) or 0.8 (second German mode), count it, and also count
# whether the *next* token starts with ' ', '.' or ',' (i.e. the current token ends a word).

counter = Counter()
for prompt in german_data[:100]:
    tokens = model.to_tokens(prompt)[0]
    str_token_prompt, pos_wise_diff = get_token_modes(prompt)
    for i in range(tokens.shape[0] - 1):
        if pos_wise_diff[i] == 0.4:
            counter["0.4"] += 1
            next_token_str = model.to_single_str_token(tokens[i + 1].item())
            if next_token_str.startswith(' ') or next_token_str.startswith('.') or next_token_str.startswith(','):
                counter["0.4_positive"] += 1
        if pos_wise_diff[i] == 0.8:
            counter["0.8"] += 1
            next_token_str = model.to_single_str_token(tokens[i + 1].item())
            if next_token_str.startswith(' ') or next_token_str.startswith('.') or next_token_str.startswith(','):
                counter["0.8_positive"] += 1

print(counter)
# Fraction of each mode's tokens that end a word; max(0.01, ...) avoids division by zero.
print(counter["0.4_positive"] / max(0.01, counter["0.4"]))
print(counter["0.8_positive"] / max(0.01, counter["0.8"]))



Counter({'0.4': 7224, '0.8': 6417, '0.8_positive': 5456, '0.4_positive': 1039})
0.14382613510520487
0.8502415458937198


### Utils

In [None]:
def mlp_effects_german(prompt, index):
    """Path-patching losses for the L5/L8 German context neurons, split by downstream MLP.

    Customised to the L5 and L8 context neurons through the global deactivate/activate context
    hooks. Returns six values from haystack_utils.get_direct_effect at position `index`:
    original, ablated and direct effect (all L6/L7/L9-L11 attention and MLP outputs deactivated,
    MLP5 and MLP8 activated), followed by the last value of runs where only MLP9, MLP10 or MLP11
    is activated.
    """
    downstream = [f"blocks.{layer}.hook_{component}_out" for layer in [6, 7, 9, 10, 11] for component in ['mlp', 'attn']]

    def patched(deactivated, activated):
        return haystack_utils.get_direct_effect(
            prompt, model, pos=index,
            context_ablation_hooks=deactivate_context_hooks,
            context_activation_hooks=activate_context_hooks,
            deactivated_components=deactivated,
            activated_components=activated)

    original, ablated, direct_effect, _ = patched(tuple(downstream), ("blocks.5.hook_mlp_out", "blocks.8.hook_mlp_out"))
    results = [original, ablated, direct_effect]
    for layer in [9, 10, 11]:
        kept = f"blocks.{layer}.hook_mlp_out"
        _, _, _, only_kept_loss = patched(tuple(c for c in downstream if c != kept), (kept,))
        results.append(only_kept_loss)
    return results

def attn_effects_german(prompt, index):
    """Per-layer attention path-patching losses for the L5/L8 German context neurons.

    For each of layers 9, 10 and 11, this calls haystack_utils.get_direct_effect with that layer's
    attention output as the only activated component. Every other L6/L7/L9-L11 attention and MLP
    output is deactivated. Returns the last value of each call, in layer order.
    """
    downstream_components = [f"blocks.{layer}.hook_{component}_out" for layer in [6, 7, 9, 10, 11] for component in ['mlp', 'attn']]

    data = []
    for layer in [9, 10, 11]:
        activated = f"blocks.{layer}.hook_attn_out"
        # Exclude the attention output being activated from the deactivated set, mirroring how
        # mlp_effects_german excludes the MLP it activates. Excluding hook_mlp_out here would list
        # the activated component as deactivated too.
        _, _, _, activated_component_loss = haystack_utils.get_direct_effect(
            prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
            deactivated_components=tuple(component for component in downstream_components if component != activated),
            activated_components=(activated,))
        data.append(activated_component_loss)
    return data

def component_analysis(end_strings: list[str] | str):
    """Bar-plot mlp_effects_german losses on random prompts that end with each given string.

    For each end string, prints its tokenization and builds prompts with
    haystack_utils.generate_random_prompts(end_string, model, common_tokens, 400, length=20).
    It then plots the mean original / ablated / direct effect / MLP9 / MLP10 / MLP11 losses
    at the final position.
    """
    for end_string in ([end_strings] if isinstance(end_strings, str) else end_strings):
        print(model.to_str_tokens(end_string))
        random_prompts = haystack_utils.generate_random_prompts(end_string, model, common_tokens, 400, length=20)
        losses = mlp_effects_german(random_prompts, -1)

        bar_names = ['original', 'ablated', 'direct effect'] + [f'MLP{layer}' for layer in [9, 10, 11]]
        haystack_utils.plot_barplot(
            [[loss.cpu().flatten().mean().item()] for loss in losses],
            names=bar_names,
            title=f'Loss increases from ablating various MLP components for end string \"{end_string}\"')
        
def interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss):
    """Per-token score that mixes the context-ablation loss increase with the MLP11 shortfall.

    score = 0.5 * (ablated - original) - max(only_activated - original, 0). The score is zeroed
    where the original loss exceeds 6 or where ablation lowered the loss.
    `context_and_activated_loss` is accepted for signature compatibility but not used.
    Inputs are not modified.
    """
    context_effect = ablated_loss - original_loss
    mlp11_shortfall = torch.clamp(only_activated_loss - original_loss, min=0)
    score = 0.5 * context_effect - mlp11_shortfall
    ignored = (original_loss > 6) | (original_loss > ablated_loss)
    return torch.where(ignored, torch.zeros_like(score), score)

def print_prompt(prompt: str):
    """Render `prompt` as colored HTML, one span per token, scored by `interest_measure`.

    Losses come from haystack_utils.get_direct_effect with the L5/L8 context hooks. Layer 9-11
    attention and MLP9/MLP10 are deactivated and MLP11 is activated. All four loss series are
    attached as hover measures. The first string token is dropped (presumably BOS, which has no
    loss — confirm against get_direct_effect's output length).
    """
    str_tokens = model.to_str_tokens(model.to_tokens(prompt))
    losses = haystack_utils.get_direct_effect(
        prompt, model, pos=None,
        context_ablation_hooks=deactivate_context_hooks,
        context_activation_hooks=activate_context_hooks,
        deactivated_components=("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
        activated_components=("blocks.11.hook_mlp_out",))
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = losses

    scores = interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss).flatten().cpu().tolist()

    loss_names = ["original_loss", "ablated_loss", "context_and_activated_loss", "only_activated_loss"]
    loss_lists = [loss.flatten().cpu().tolist() for loss in (original_loss, ablated_loss, context_and_activated_loss, only_activated_loss)]
    haystack_utils.clean_print_strings_as_html(str_tokens[1:], scores, additional_measures=loss_lists, additional_measure_names=loss_names)

def get_mlp11_decrease_measure(losses: list[tuple[Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"]]]):
    """Return, per loss tuple, the largest `interest_measure` value across positions (as a float)."""
    return [
        interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss).max().item()
        for original_loss, ablated_loss, context_and_activated_loss, only_activated_loss in losses
    ]

def print_counter(token: str, data: list[str], next_tokens_count = 1):
    """Print a Counter of what follows the first occurrence of `token` in each prompt.

    `token` must be a single model token. For each prompt that contains it, the token plus the
    `next_tokens_count` tokens after it are joined into one string and counted. Prompts without
    the token, or where it sits too close to the end, are skipped.
    """
    counter = Counter()
    token_index = model.to_single_token(token)
    for prompt in data:
        tokens = model.to_tokens(prompt)[0]
        try:
            index = tokens.tolist().index(token_index)
        except ValueError:
            # Token absent from this prompt. Only this case is skipped; other errors propagate.
            continue
        if index + next_tokens_count < len(tokens):
            next_tokens = tokens[index : index + next_tokens_count + 1]
            next_tokens_str = "".join(model.to_str_tokens(next_tokens))
            counter.update([next_tokens_str])
    print(counter)

def get_prompts_with_token(target_completion: str, data: list[str]):
    """Cut each prompt right after the first occurrence of the substring `target_completion`.

    Prompts that do not contain the substring are dropped. Order is preserved.
    """
    cut_points = ((text, text.find(target_completion)) for text in data)
    return [text[:position + len(target_completion)] for text, position in cut_points if position != -1]

def left_pad(prompts, model, pad_token_id=0):
    """Tokenize each prompt and left-pad it to the batch's padded width; returns a stacked tensor.

    The target width is taken from tokenizing all prompts together. Padding uses `pad_token_id`
    (default 0) and is created on the same device and dtype as the tokens, so this also runs on CPU.
    """
    if len(prompts) == 0:
        raise ValueError("left_pad needs at least one prompt")
    target_length = model.to_tokens(prompts).shape[1]

    results = []
    for prompt in prompts:
        tokens = model.to_tokens(prompt)[0]
        padding = torch.full((target_length - tokens.shape[0],), pad_token_id, dtype=tokens.dtype, device=tokens.device)
        results.append(torch.cat([padding, tokens]))

    return torch.stack(results)

In [None]:
def mlp_language_logprob_diffs(prompt, index, german_token, french_token):
    """French-minus-German log-prob gap for the L5/L8 context neurons, per MLP path.

    Meant for French text, so the hook roles are reversed. The "ablation" hooks are
    `activate_context_hooks`, which switch the German context on, and no activation hooks are used.
    Returns six tensors, each `logprobs[:, french_token] - logprobs[:, german_token]`, in order:
    original, ablated, direct effect, then MLP9-only, MLP10-only and MLP11-only.
    """
    downstream = [f"blocks.{layer}.hook_{component}_out" for layer in [6, 7, 9, 10, 11] for component in ['mlp', 'attn']]

    def patched_logprobs(deactivated, activated):
        return haystack_utils.get_direct_effect(
            prompt, model, pos=index,
            context_ablation_hooks=activate_context_hooks, context_activation_hooks=[],
            deactivated_components=deactivated, activated_components=activated,
            return_type='logprobs')

    original_logprobs, ablated_logprobs, direct_effect_logprobs, _ = patched_logprobs(
        tuple(downstream), ("blocks.5.hook_mlp_out", "blocks.8.hook_mlp_out"))
    all_logprobs = [original_logprobs, ablated_logprobs, direct_effect_logprobs]
    for layer in [9, 10, 11]:
        kept = f"blocks.{layer}.hook_mlp_out"
        _, _, _, only_kept_logprobs = patched_logprobs(tuple(c for c in downstream if c != kept), (kept,))
        all_logprobs.append(only_kept_logprobs)

    return [logprobs[:, french_token] - logprobs[:, german_token] for logprobs in all_logprobs]

### ord -> n

In [None]:
# Full English Europarl set (all 2000 prompts) and a Stack Exchange dataset from the Hugging Face Hub.
english_data_long = haystack_utils.load_json_data("data/english_europarl.json")
stack_exchange_data = load_dataset('habedi/stack-exchange-dataset', split='train')

data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


Downloading readme:   0%|          | 0.00/20.0 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/48.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/39.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.9M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# What follows the token 'ord' (first occurrence per prompt, plus one token) in each language.
print_counter('ord', english_data_long)
print_counter('ord', german_data)
print_counter('ord', french_data)

Counter({'ordic': 7, 'ord\n': 4, 'ordance': 1, 'ordana': 1, 'ord,': 1, 'ord.': 1, 'ordination': 1})
Counter({'ordn': 59, 'ordnet': 20, 'ordung': 2, 'ordischen': 1})
Counter({'ordre': 97, 'ord,': 10, 'ord de': 9, 'ordina': 6, 'ord vous': 3, 'ordique': 3, 'ord CE': 3, 'ord commercial': 3, 'ordé': 3, 'ordée': 3, 'ord rem': 2, 'ord entre': 2, 'ord sur': 2, 'ordance': 1, 'ord vot': 1, 'ordés': 1, 'ord.': 1, 'ord d': 1, 'ordana': 1, 'ord-': 1, 'ord inter': 1, 'ord UE': 1, 'ord le': 1, 'ord faire': 1, 'ord don': 1, 'ord voter': 1, 'ord à': 1, 'ord les': 1, 'ord Br': 1})


We can see that the most common 'ord' completions are 'ordre' in French, 'ordic' in English, and 'ordn' in German. Does an MLP downstream of the context neurons hold these bigram statistics?

First we'll build a dataset of German completions and see how things work.
Then we can build a dataset of the French completions, manually activate the German context neurons, and see how much the logprob of 'ord' -> 'n' improves.

In [None]:
# German prompts cut right after the first 'ord' substring.
german_ord = get_prompts_with_token('ord', german_data)
print(len(german_ord))
print(german_ord[6])

120
Fischereitätigkeiten von Fischereifahrzeugen der Gemeinschaft außerhalb der Gemeinschaftsgewässer und Zugang von Drittlandschiffen zu Gemeinschaftsgewässern (Aussprache) 
Die Präsidentin
Als nächster Punkt folgt der Bericht von Philippe Morillon im Namen des Fischereiausschusses über den Vorschlag für eine Verord


In [None]:
# Try trigram because ordn on its own is hard for the model to predict
random_ord = haystack_utils.generate_random_prompts(' Tagesordn', model, common_tokens, 400, length=20)

# original / ablated / direct effect / MLP9-only / MLP10-only / MLP11-only losses
data = mlp_effects_german(random_ord, -1)
# Extra bar with MLP9 and MLP11 patched in together. It is evaluated on the same prompts as the
# other bars (not a 10-prompt subset), so the bar-plot means are comparable.
_, _, _, only_activated_loss_mlp_9_11 = haystack_utils.get_direct_effect(
                random_ord, model, pos=-1, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.10.hook_mlp_out"),
                activated_components = ("blocks.11.hook_mlp_out", "blocks.9.hook_mlp_out",))
data.append(only_activated_loss_mlp_9_11)

In [None]:
# Mean loss per condition on the ' Tagesordn' prompts; the last bar patches in MLP9 and MLP11 together.
haystack_utils.plot_barplot([[item.cpu().flatten().mean().item()] for item in data],
                                names=['original', 'ablated', 'direct effect'] + [f'{i}{j}' for j in [9, 10, 11] for i in ["MLP"]] + ["MLP9 + MLP11"],
                                title='Loss increases from ablating various MLP components')

In [None]:
def replace_token_loss(prompts, replace_index, num_replacements=10):
    """Final-position loss after substituting one token with common German tokens.

    For every prompt, the token at `replace_index` is replaced in turn by each of the first
    `num_replacements` entries of `common_tokens`. This is deterministic, not random. Every variant
    runs with `activate_context_hooks` applied, and the loss at the last position is collected.
    Returns a flat list ordered prompt-major, replacement-minor.
    """
    variants = []
    for prompt in prompts:
        tokens = model.to_tokens(prompt)
        for k in range(num_replacements):
            variant = tokens.clone()
            variant[0, replace_index] = common_tokens[k]
            variants.append(variant)

    # One hook context for all variants; hooks are removed when the block exits.
    with model.hooks(fwd_hooks=activate_context_hooks):
        return [model(variant, return_type="loss", loss_per_token=True).flatten()[-1].item() for variant in variants]


In [None]:
# Slow
def replace_token_loss_slow(prompts, fwd_hooks=activate_context_hooks):
    """Final-position loss when each interior token of a single prompt is replaced.

    `prompts` is a single prompt string. The baseline and every replacement run use `fwd_hooks`
    (default: activate_context_hooks). Previously the replacements ignored `fwd_hooks` and always
    used activate_context_hooks. Each position from 1 up to, but excluding, the last is replaced in
    turn by each of the first 20 `common_tokens`.

    Returns (losses, names). losses[0] is the baseline loss repeated 20 times, labelled "Original".
    Each later entry holds the 20 replacement losses for one position, labelled by the replaced token.
    """
    num_replacements = 20
    tokens = model.to_tokens(prompts)

    losses = []
    names = []
    with model.hooks(fwd_hooks):
        original_loss = model(prompts, return_type="loss", loss_per_token=True).flatten()[-1].item()
        losses.append([original_loss] * num_replacements)
        names.append("Original")
        for pos in tqdm(range(1, tokens.shape[1] - 1)):
            position_losses = []
            for k in range(num_replacements):
                new_tokens = tokens.clone()
                new_tokens[0, pos] = common_tokens[k]
                position_losses.append(model(new_tokens, return_type="loss", loss_per_token=True).flatten()[-1].item())
            losses.append(position_losses)
            names.append(model.to_str_tokens(tokens[0, pos])[0])
    return losses, names

# losses, names = replace_token_loss_slow(german_ord[:10][-15:])
# haystack_utils.plot_barplot(losses, names, ylabel="Loss", xlabel="Replaced token", title=f"Average loss when replacing single token with top 20 German unigrams")


In [None]:
prompts = random_ord
neuron_diffs_filename = 'data/pythia_160m/neuron_diffs.pkl'
# One row per MLP11 neuron. The width comes from the model config: pythia-160m has d_mlp = 3072,
# so a hard-coded 2048 skips the last third of the layer.
d_mlp = model.cfg.d_mlp

with model.hooks(deactivate_context_hooks):
    _, ablated_cache = model.run_with_cache(prompts)

# Path-patching setup shared by every call below: only MLP11 is patched in.
path_patching_kwargs = dict(
    context_ablation_hooks=deactivate_context_hooks,
    deactivated_components=("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
    activated_components=("blocks.11.hook_mlp_out",),
)

diffs = None
if os.path.exists(neuron_diffs_filename):
    # NOTE: pickle.load can execute arbitrary code; only load cache files this notebook wrote itself.
    with open(neuron_diffs_filename, 'rb') as f:
        diffs = pickle.load(f)
    if tuple(diffs.shape) != (d_mlp, prompts.shape[0]):
        # Stale cache from a different neuron count or prompt set: recompute.
        diffs = None

if diffs is None:
    diffs = torch.zeros(d_mlp, prompts.shape[0])
    # Loss with path patched MLP11 neurons
    _, _, _, baseline_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_activation_hooks=activate_context_hooks, **path_patching_kwargs)
    for neuron in tqdm(range(d_mlp)):
        ablate_single_neuron_hook = get_neuron_hook(11, neuron, ablated_cache['blocks.11.mlp.hook_post'][:, :, neuron])
        # Same patching, but with this one neuron reset to its context-ablated value
        _, _, _, only_deactivated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_activation_hooks=activate_context_hooks + [ablate_single_neuron_hook], **path_patching_kwargs)
        diffs[neuron] = only_deactivated_loss - baseline_loss

    os.makedirs(os.path.dirname(neuron_diffs_filename), exist_ok=True)
    with open(neuron_diffs_filename, 'wb') as f:
        pickle.dump(diffs, f)

print(diffs.mean())

tensor(0.0005)


In [None]:
# Mean loss change per MLP11 neuron (averaged over prompts), sorted ascending; sorted_indices maps back to neuron ids.
sorted_means, sorted_indices = torch.sort(diffs.mean(1))
sorted_means = sorted_means.tolist()
haystack_utils.line(sorted_means, xlabel="Sorted neurons", ylabel="Loss change", title="Loss change from ablating MLP11 neuron")

def get_individual_neuron_ablation_losses(top_neurons_count=10, sorted_indices=sorted_indices, deactivate_context_hooks=deactivate_context_hooks):
    """Compare final-token loss when the top / bottom MLP11 neurons are removed from path patching.

    The last ``top_neurons_count`` entries of ``sorted_indices`` form the "top" group and the
    first ``top_neurons_count`` form the "bottom" group. With the default argument,
    ``sorted_indices`` is the ascending sort of the per-neuron mean loss change. Each group is
    set to its batch-mean activation from a forward pass with ``deactivate_context_hooks``
    applied, on top of the MLP11-only path patching setup.

    Reads the globals ``model``, ``prompts`` and ``activate_context_hooks``.

    Args:
        top_neurons_count: size of each group, clamped to ``[0, len(sorted_indices)]``.
        sorted_indices: MLP11 neuron indices ordered from "bottom" to "top".
        deactivate_context_hooks: hooks that deactivate the context neurons.

    Returns:
        (values, names, short_names): five losses (each converted with ``.tolist()``) for the
        original run, the ablated run, MLP11 path patching, and path patching with the top /
        bottom neurons ablated, plus long and short labels for plotting.
    """
    num_neurons = len(sorted_indices)
    k = max(0, min(top_neurons_count, num_neurons))
    # Slice from an explicit start: sorted_indices[-0:] would select *every* neuron when k == 0.
    top_neurons = sorted_indices[num_neurons - k:]
    bottom_neurons = sorted_indices[:k]

    with model.hooks(deactivate_context_hooks):
        _, ablated_cache = model.run_with_cache(prompts, return_type="loss")
    ablated_post = ablated_cache['blocks.11.mlp.hook_post']

    def mean_ablation_hooks(neurons):
        """Hooks setting ``neurons`` in MLP11 to their batch-mean context-ablated activation."""
        if len(neurons) == 0:
            return []
        return [get_neuron_hook(11, neurons, ablated_post[:, :, neurons].mean(0))]

    deactivated_components = ("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out")
    activated_components = ("blocks.11.hook_mlp_out",)

    def path_patched_losses(extra_activation_hooks):
        """Run get_direct_effect with MLP11-only path patching plus ``extra_activation_hooks``."""
        return haystack_utils.get_direct_effect(
            prompts, model, pos=-1,
            context_ablation_hooks=deactivate_context_hooks,
            context_activation_hooks=activate_context_hooks + extra_activation_hooks,
            deactivated_components=deactivated_components,
            activated_components=activated_components)

    original_loss, ablated_loss, _, all_MLP11_loss = path_patched_losses([])
    _, _, _, top_MLP11_ablated_loss = path_patched_losses(mean_ablation_hooks(top_neurons))
    _, _, _, bottom_MLP11_ablated_loss = path_patched_losses(mean_ablation_hooks(bottom_neurons))

    names = ["Original", "Ablated", "MLP11 path patched", f"MLP11 path patched + Top {top_neurons_count} MLP11 neurons ablated", f"MLP11 path patched + Bottom {top_neurons_count} MLP11 neurons ablated"]
    short_names = ["Original", "Ablated", "MLP11 path patched", "Top MLP11 removed", "Bottom MLP11 removed"]
    values = [original_loss.tolist(), ablated_loss.tolist(), all_MLP11_loss.tolist(), top_MLP11_ablated_loss.tolist(), bottom_MLP11_ablated_loss.tolist()]
    return values, names, short_names


# Bar chart of final-token loss with the top / bottom MLP11 neurons removed from path patching
values, names, short_names = get_individual_neuron_ablation_losses()
haystack_utils.plot_barplot(
    values,
    names,
    short_names=short_names,
    ylabel="Loss",
    title="Average last token loss when removing top / bottom neurons from path patching",
)

#### Try a French-language prompt

In [None]:
# Collect French prompts that contain the 'ordre' completion (via get_prompts_with_token).
# The first of these prompts is used below to test activating the German context neurons.
french_ord = get_prompts_with_token('ordre', french_data)
print(len(french_ord))  # number of matching prompts

113


In [None]:
# Boost the two German context neurons (L5 N2649, L8 N2994) to twice their mean
# activation on German text.
# NOTE(review): this rebinds the global `activate_context_hooks`. Later cells that pass
# `activate_context_hooks` (e.g. the MLP11 high-loss prompt loop) will get these 2x hooks
# after this cell runs. Consider a distinct name.
activate_context_hooks = [
    get_neuron_hook(layer, neuron, german_activations[layer][:, neuron].mean() * 2)
    for layer, neuron in ((5, 2649), (8, 2994))
]

french_completion = model.to_single_token('re')
german_completion = model.to_single_token('n')

# The boosted hooks are passed as the "ablation" hooks, so the ablated run is the
# German-context-activated run on a French prompt.
original_logprobs, ablated_logprobs, _, _ = haystack_utils.get_direct_effect(
    french_ord[0],
    model,
    pos=None,
    context_ablation_hooks=activate_context_hooks,
    context_activation_hooks=[],
    return_type='logprobs',
)

# Position -2 is the prediction of the final token.
# NOTE(review): the printed values are positive, so these look like logits rather than
# log-probs. Confirm what get_direct_effect returns for return_type='logprobs'.
original_final = original_logprobs[-2]
ablated_final = ablated_logprobs[-2]
print(original_final[german_completion], original_final[french_completion])
print(original_final[german_completion] - original_final[french_completion])
print(ablated_final[german_completion] - ablated_final[french_completion])

tensor(21.9077, device='cuda:0') tensor(41.5833, device='cuda:0')
tensor(-19.6756, device='cuda:0')
tensor(-4.3473, device='cuda:0')


In [None]:
# Per-component log-prob difference between the German ('n') and French ('re') completions
# at position -2 of the first French 'ordre' prompt.
data = mlp_language_logprob_diffs(french_ord[0], -2, model.to_single_token('n'), model.to_single_token('re'))
bar_values = [[item.cpu().item()] for item in data]
bar_names = ['original', 'ablated', 'direct effect'] + [f'MLP{layer}' for layer in (9, 10, 11)]
haystack_utils.plot_barplot(
    bar_values,
    names=bar_names,
    title="Log prob difference between French and German tokens when activating German context",
)

### DLA & component-level path patching

In [None]:
# logit_attr_original, labels = haystack_utils.DLA(german_data, model)

# # Patch in disabled context neurons and plot the direct logit attribution difference for each component
# with model.hooks(fwd_hooks=deactivate_context_hooks):
#     logit_attr_ablated, _ = haystack_utils.DLA(german_data, model)

# logit_diffs = (logit_attr_original - logit_attr_ablated).mean(0)
# # The small differences accumulated before the ablation are due to the final layer norm scale being affected by the L3 hook.
# haystack_utils.line(logit_diffs.cpu().numpy(), xlabel="Correct logit", ylabel="", title="(Original DLA - Ablated DLA) per component", xticks=labels)

The direct loss increases are small until layer 9. This suggests that the direct effects of the ablations are minor.
Most of the difference comes from indirect effects of the context neurons in layers 5 and 8, which act through components from layer 9 onward.

### High loss prompts - MLP11

In [None]:
# Get general and MLP11 specific losses
german_losses = []
for prompt in tqdm(german_data):
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = haystack_utils.get_direct_effect(
        prompt, model, pos=None, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks, 
        deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
        activated_components = ("blocks.11.hook_mlp_out",))
    german_losses.append((original_loss, ablated_loss, context_and_activated_loss, only_activated_loss))

  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
# Rank German prompts by get_mlp11_decrease_measure (largest first) and show the top 10.
measure = get_mlp11_decrease_measure(german_losses)
sorted_measure = sorted(enumerate(measure), key=lambda pair: pair[1], reverse=True)
# Distinct loop names keep `measure` bound to the full list after this cell.
# The previous loop variable shadowed it with a single float.
for prompt_index, prompt_measure in sorted_measure[:10]:
    print(prompt_measure)
    print_prompt(german_data[prompt_index])

3.3175816535949707


3.0792105197906494


2.634366989135742


2.284700393676758


2.220391035079956


2.18497896194458


2.1799113750457764


2.137683391571045


2.089397668838501


2.0570228099823


### Other tokens

In [None]:
# print_counter('f', german_data)
# component_analysis([' Verfasser'])

In [None]:
# Count the continuations that follow 'ing' in the German data. Per the output, they are
# mostly word fragments such as 'ingungen'. Then analyse ' beding', which tokenizes as
# ' bed' + 'ing'.
print_counter('ing', german_data)
component_analysis([' beding'])

Counter({'inges': 5, 'ingere': 4, 'ingele': 3, 'ingangs': 3, 'ingungen': 3, 'ingear': 2, 'ingeb': 2, 'ingest': 2, 'ingesch': 2, 'inglied': 2, 'ingare': 1, 'ing\n': 1, 'ingten': 1, 'ing in': 1, 'inget': 1, 'ingre': 1, 'ing,': 1, 'ingt': 1, 'ing.': 1, 'ingeg': 1, 'ingel': 1, 'ingef': 1})
['<|endoftext|>', ' bed', 'ing']


In [None]:
# Are there other English completions of ' tox' (e.g. ' tox' + 'oplasma') that tokenize
# cleanly? Compare the continuations of ' tox' across the English, German and French data.
for tox_dataset in (english_data_long, german_data, french_data):
    print_counter(' tox', tox_dataset)


Counter()
Counter({' toxik': 1})
Counter({' toxique': 1})


In [None]:
# Continuations of 'le' (three following tokens) in the English, German and French data
for le_dataset in (english_data_long, german_data, french_data):
    print_counter('le', le_dataset, next_tokens_count=3)

Counter({'le\nMember of': 9, 'le, on behalf': 6, 'leagues, I': 5, 'le Stauner': 3, 'leagues, we': 3, 'le, Adam B': 2, 'le Durant\n': 2, 'lemming\n': 2, 'lean, on': 2, 'le Schmidt, on': 2, 'le Schmidt on behalf': 2, 'le\nMr President': 2, 'le - and also': 1, 'le Rivasi\n': 1, 'le Zimmer and': 1, 'le and Mr Rap': 1, 'le, who was': 1, 'leagues, allow': 1, 'leagues, ladies': 1, "le'.\nThe": 1, 'leagues, you': 1, 'leagues, please': 1, 'le report, in': 1, 'lean sense.': 1, 'le\n(DE': 1, 'lean, Hann': 1, 'le the ships because': 1, 'leagues, under': 1, 'leva, on': 1, 'le will take the': 1, 'leagues, let': 1, 'le\n(ES': 1, 'le Albertini\n': 1, 'le Schmidt, Maria': 1, 'le clause under Article': 1, 'le, who is': 1, 'leagues, there': 1, 'le, Cresc': 1, 'le to tell Mrs': 1, 'leagues, today': 1, 'le, a lawyer': 1, 'leagues, for': 1, 'leknas U': 1, 'leicher\nMr': 1, 'leagues - for': 1, 'le Schmidt\nMr': 1, 'le will present the': 1, 'le Albertini in': 1, 'le and Kyri': 1, 'le over. I': 1, 'lehem and t

In [None]:
component_analysis([' legen'])  # ' legen' tokenizes as ' le' + 'gen'

['<|endoftext|>', ' le', 'gen']


In [None]:
component_analysis([' legen'])  # NOTE(review): identical to the previous cell (same input and output); candidate for removal

['<|endoftext|>', ' le', 'gen']


In [None]:
# ' Bräsidentin' is ' Präsidentin' with P changed to B. Original note: skip the trigram,
# because changing P to B doesn't break the circuit.
component_analysis([' die Bräsidentin'])

['<|endoftext|>', ' die', ' Br', 'äsident', 'in']


In [None]:
component_analysis([' Präsidentin', ' Anerkennung', ' legen'])  # analyse three German words; each word's tokenization is printed

['<|endoftext|>', ' Pr', 'äsident', 'in']


['<|endoftext|>', ' An', 'erk', 'enn', 'ung']


['<|endoftext|>', ' le', 'gen']


### Path patch components

In [None]:
def get_prompt_and_token(loss_increase_threshold=3, max_original_loss=3):
    """Find the first German prompt with a token whose loss jumps when the context neurons are ablated.

    Scans the global ``german_data`` in order. For each prompt, it compares per-token loss with
    and without ``deactivate_context_hooks`` and takes the position with the largest increase.

    Args:
        loss_increase_threshold: the (ablated - original) loss increase must exceed this.
        max_original_loss: the original loss at that position must be below this, so that
            only tokens the unablated model already predicts well are selected.

    Returns:
        (prompt, index): the first qualifying prompt and the position (a Python int) of its
        largest loss increase in the per-token loss vector, or ('', -1) if none qualifies.
    """
    for prompt in german_data:
        # Only the loss is needed, so skip building an activation cache.
        original_loss = model(prompt, return_type='loss', loss_per_token=True)
        with model.hooks(deactivate_context_hooks):
            ablated_loss = model(prompt, return_type='loss', loss_per_token=True)
        # A single prompt gives a batch of size 1; take its row.
        loss_increase = (ablated_loss - original_loss)[0]
        if loss_increase.numel() == 0:
            continue  # prompt too short to have any next-token loss
        index = int(loss_increase.argmax())
        if loss_increase[index] > loss_increase_threshold and original_loss[0, index] < max_original_loss:
            return prompt, index
    return '', -1
        
# prompt, index = get_prompt_and_token()
# data = mlp_effects_german(prompt, index)
# haystack_utils.plot_barplot([[item] for item in data],
#                             names=['original', 'ablated', 'direct effect'] + [f'{i}{j}' for j in [9, 10, 11] for i in ["MLP"]])