In [26]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import numpy as np
import plotly.express as px 
# from torchmetrics.regression import KendallRankCorrCoef, SpearmanCorrCoef
from collections import defaultdict

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [27]:
haystack_utils.clean_cache()
model = HookedTransformer.from_pretrained("EleutherAI/pythia-160m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


In [28]:
# print(model.to_tokens("swim"), model.to_tokens("swam"))

In [111]:
german_neurons_with_f1 = [
    [5, 2649, 1.0],
    [8,	2994, 1.0],
    [11, 2911, 0.99],
    [10, 1129, 0.97],
    [6, 1838, 0.65],
    [7, 1594, 0.65],
    [11, 1819, 0.61],
    [11, 2014, 0.56],
    [10, 753, 0.54],
    [11, 205, 0.48],
]

important_german_neurons = defaultdict(list)
for layer, neuron, f1 in german_neurons_with_f1:
    # if f1 > 0.9:
    important_german_neurons[layer].append(neuron)

english_activations = {}
german_activations = {}
for layer in set([layer for layer, _, _ in german_neurons_with_f1]):
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)


5


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

6


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

7


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

8


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

10


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

11


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [117]:
mean_context_neuron_acts_active = defaultdict(list)
mean_context_neuron_acts_inactive = defaultdict(list)
for layer, neuron, _ in german_neurons_with_f1:
    mean_context_neuron_acts_active[layer].append((neuron, german_activations[layer][:, neuron].mean(0)))
    mean_context_neuron_acts_inactive[layer].append((neuron, english_activations[layer][:, neuron].mean(0)))

def get_deactivate_neurons_hook(layer):
    def deactivate_neurons_hook(value, hook):
        neurons, acts = zip(*mean_context_neuron_acts_inactive[layer])
        value[:, :, neurons] = torch.tensor(acts).cuda()
        return value
    return deactivate_neurons_hook
deactivate_neurons_fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', get_deactivate_neurons_hook(layer)) for layer in important_german_neurons.keys()]

def get_activate_neurons_hook(layer):
    def activate_neurons_hook(value, hook):
        neurons, acts = zip(*mean_context_neuron_acts_inactive[layer])
        value[:, :, neurons] = torch.tensor(acts).cuda()
        return value
    return activate_neurons_hook
activate_neurons_fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', get_activate_neurons_hook(layer)) for layer in important_german_neurons.keys()]

all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

### Check classification accuracy of German neurons

In [118]:
def run_single_neuron_lr(layer, neuron, num_samples=5000, german_activations=german_activations, english_activations=english_activations):
    # Check accuracy of logistic regression
    A = torch.concat([german_activations[layer][:num_samples, neuron], english_activations[layer][:num_samples, neuron]]).view(-1, 1).cpu().numpy()
    y = torch.concat([torch.ones(num_samples), torch.zeros(num_samples)]).cpu().numpy()
    A_train, A_test, y_train, y_test = train_test_split(A, y, test_size=0.2)
    lr_model = LogisticRegression()
    lr_model.fit(A_train, y_train)
    test_acc = lr_model.score(A_test, y_test)
    train_acc = lr_model.score(A_train, y_train)
    f1 = sklearn.metrics.f1_score(y_test, lr_model.predict(A_test))
    return train_acc, test_acc, f1
    
def get_neuron_accuracy(layer, neuron, german_activations=german_activations, english_activations=english_activations, plot=False, print_f1s=True):
    mean_english_activation = english_activations[layer][:,neuron].mean()
    mean_german_activation = german_activations[layer][:,neuron].mean()
    
    if plot:
        haystack_utils.two_histogram(english_activations[layer][:,neuron], german_activations[layer][:,neuron], "English", "German", "Activation", "Frequency", f"L{layer}N{neuron} activations on English vs German text")
    train_acc, test_acc, f1 = run_single_neuron_lr(layer, neuron, german_activations=german_activations, english_activations=english_activations)
    if print_f1s:
        print(f"\nL{layer}N{neuron}: F1={f1:.2f}, Train acc={train_acc:.2f}, and test acc={test_acc:.2f}")
        print(f"Mean activation English={mean_english_activation:.2f}, German={mean_german_activation:.2f}")
    return f1

In [113]:
f1s = []
for layer, neuron, reported_f1 in german_neurons_with_f1:
    f1s.append(get_neuron_accuracy(layer, neuron, print_f1s=False))

german_neuron_names = [f"L{layer}N{neuron}" for layer, neuron, _ in german_neurons_with_f1]
haystack_utils.line(f1s, xlabel="", ylabel="F1 score of sparse probe", title="Sparse probe performance on individual German neurons", xticks=german_neuron_names, show_legend=False)

In [119]:
f1s = []
for layer, neuron, _ in german_neurons_with_f1:
    deactivate_other_neurons_fwd_hooks=[(f'blocks.{l}.mlp.hook_post', get_deactivate_neurons_hook(l)) for l in important_german_neurons.keys() if l != layer]
    with model.hooks(deactivate_other_neurons_fwd_hooks):
        modified_german_acts = {layer: haystack_utils.get_mlp_activations(german_data, layer, model, mean=False)}

    f1s.append(get_neuron_accuracy(layer, neuron, german_activations=modified_german_acts, 
                                    english_activations=english_activations, print_f1s=False, plot=False))

german_neuron_names = [f"L{layer}N{neuron}" for layer, neuron, _ in german_neurons_with_f1]
haystack_utils.line(f1s, xlabel="", ylabel="F1 score of sparse probe", title="Sparse probe performance on individual German neurons", 
                    xticks=german_neuron_names, show_legend=False)

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [167]:
# Full ablation accuracy
def ablation_effect(fwd_hooks):
    original_losses = []
    ablated_losses = []
    batch_size = 50
    for i in range(4):
        original_losses.append(model(german_data[i * batch_size:i * batch_size + 50], return_type='loss').cpu())
        with model.hooks(fwd_hooks):
            ablated_losses.append(model(german_data[i * batch_size:i * batch_size + 50], return_type='loss').cpu())

    original_loss = sum(original_losses) / len(original_losses)
    ablated_loss = sum(ablated_losses) / len(ablated_losses)

    print(original_loss, ablated_loss)
    print(f'{(ablated_loss - original_loss) / original_loss * 100:2f}% loss increase')

In [35]:
print("Full ablation:")
ablation_effect(deactivate_neurons_fwd_hooks)

def get_neuron_hook(layer, neuron, inactive_value):
    def deactivate_neuron_hook(value, hook):
        value[:, :, neuron] = inactive_value
        return value
    return [(f'blocks.{layer}.mlp.hook_post', deactivate_neuron_hook)]

for layer, neuron, f1 in german_neurons_with_f1:
    print(f"Ablate L{layer}N{neuron} context neuron with f1 of {f1}:")
    ablation_effect(get_neuron_hook(layer, neuron, english_activations[layer][:, neuron].mean()))

Full ablation:
tensor(3.6835, device='cuda:0') tensor(4.0660, device='cuda:0')
10.382535% loss increase
Ablate L5N2649 context neuron with f1 of 1.0:
tensor(3.6835, device='cuda:0') tensor(3.7163, device='cuda:0')
0.890389% loss increase
Ablate L8N2994 context neuron with f1 of 1.0:
tensor(3.6835, device='cuda:0') tensor(3.8847, device='cuda:0')
5.460662% loss increase
Ablate L11N2911 context neuron with f1 of 0.99:
tensor(3.6835, device='cuda:0') tensor(3.6756, device='cuda:0')
-0.215380% loss increase
Ablate L10N1129 context neuron with f1 of 0.97:
tensor(3.6835, device='cuda:0') tensor(3.6798, device='cuda:0')
-0.101768% loss increase
Ablate L6N1838 context neuron with f1 of 0.65:
tensor(3.6835, device='cuda:0') tensor(3.6940, device='cuda:0')
0.283271% loss increase
Ablate L7N1594 context neuron with f1 of 0.65:
tensor(3.6835, device='cuda:0') tensor(3.6920, device='cuda:0')
0.228546% loss increase
Ablate L11N1819 context neuron with f1 of 0.61:
tensor(3.6835, device='cuda:0') tens

All context neurons are in the output path of the L5 context neuron (L8 and L11 less and the rest more).

Most circuits are in the output path of the L8 context neuron.

In [176]:
print(german_activations[8][:, 2994].mean())
ablation_effect(get_neuron_hook(8, 2994, german_activations[8][:, 2994].mean()))
ablation_effect(get_neuron_hook(5, 2649, german_activations[5][:, 2649].mean()))

tensor(4.1120, device='cuda:0')
3.683535575866699 3.7638607025146484
2.180653% loss increase
3.683535575866699 3.7241556644439697
1.102747% loss increase


In [175]:
deactivate_context_hooks = get_neuron_hook(5, 2649, english_activations[5][:, 2649].mean()) + get_neuron_hook(8, 2994, english_activations[8][:, 2994].mean())
activate_context_hooks = get_neuron_hook(5, 2649, german_activations[5][:, 2649].mean()) + get_neuron_hook(8, 2994,  german_activations[8][:, 2994].mean())

ablation_effect(get_neuron_hook(5, 2649, english_activations[5][:, 2649].mean()))
ablation_effect(get_neuron_hook(8, 2994, english_activations[8][:, 2994].mean()))
ablation_effect(deactivate_context_hooks)

# Assume due to bimodal activation distribution being significant
ablation_effect(activate_context_hooks)



3.683535575866699 3.7163333892822266
0.890389% loss increase
3.683535575866699 3.884680986404419
5.460662% loss increase
3.683535575866699 4.020592212677002
9.150356% loss increase
3.683535575866699 3.8088128566741943
3.401006% loss increase


In [97]:
german_acts_5 = []
german_acts_8 = []
for prompt in german_data:
    _, cache = model.run_with_cache(prompt)
    german_acts_5 += cache['post', 5][:, :, 2649].flatten().tolist()
    german_acts_8 += cache['post', 8][:, :, 2994].flatten().tolist()

english_acts_5 = []
english_acts_8 = []
for prompt in english_data:
    _, cache = model.run_with_cache(prompt)
    english_acts_5 += cache['post', 5][:, :, 2649].flatten().tolist()
    english_acts_8 += cache['post', 8][:, :, 2994].flatten().tolist()

# 1 - 3
# 6 - 8


In [98]:
px.histogram(german_acts_8)

In [68]:
# px.histogram(english_acts_8)

Ablating them together causes the majority of the loss, even through their individual ablation loss increases sum to less than this.

Perhaps there are circuits which rely on both - somewhat AND gates, although the context neurons are dependent.

### Trimodal context neuron in L8

In [151]:
from datasets import load_dataset
english_data_long = haystack_utils.load_json_data("data/english_europarl.json")
stack_exchange_data = load_dataset('habedi/stack-exchange-dataset', split='train')

data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3505, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_3962/308917268.py", line 3, in <module>
    stack_exchange_data = load_dataset('habedi/stack-exchange-dataset', split='train')
  File "/opt/conda/lib/python3.10/site-packages/datasets/load.py", line 2106, in load_dataset
    builder_instance = load_dataset_builder(
  File "/opt/conda/lib/python3.10/site-packages/datasets/load.py", line 1829, in load_dataset_builder
    builder_instance: DatasetBuilder = builder_cls(
TypeError: 'NoneType' object is not callable

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2102, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/opt/conda/lib/python3.10/site-packages/IP

In [182]:
# record_prompts = []
# for prompt in stack_exchange_data:
#     prompt = prompt['body']
#     tokens = model.to_str_tokens(model.to_tokens(prompt))
#     if 'ord' in tokens:
#         ord_index = tokens.index('ord')
#         if ord_index > 0: print(tokens[ord_index - 1])
#         if ord_index > 0 and 'f' in tokens[ord_index - 1]:
#             record_prompts.append(prompt, ord_index)
#             print(True)
#             print(prompt[prompt.index('ord') - 10 : prompt.index('ord') + 50])
            

print(model.to_single_str_token(47))
ford_data = []
ord_token = model.to_single_token('ord')
# f_token = model.to_single_token('f')
# print(f_token)
from collections import Counter
counter = Counter()
for prompt in english_data_long:
    tokens = model.to_tokens(prompt)[0]
    try: 
        index = tokens.tolist().index(ord_token)
    except:
        continue
    if index + 1 < len(tokens):
        counter.update(model.to_single_str_token(tokens[index + 1].item()))
print(counter)

counter = Counter()
for prompt in german_data:
    tokens = model.to_tokens(prompt)[0]
    try: 
        index = tokens.tolist().index(ord_token)
    except:
        continue
    if index + 1 < len(tokens):
        counter.update(model.to_single_str_token(tokens[index + 1].item()))

print(counter)

N
Counter({'i': 9, 'c': 8, 'a': 4, 'n': 4, '\n': 4, 'e': 1, ',': 1, '.': 1, 't': 1, 'o': 1})
Counter({'n': 82, 'e': 21, 't': 20, 'u': 2, 'g': 2, 'i': 1, 's': 1, 'c': 1, 'h': 1})


### DLA & component-level path patching

In [43]:
logit_attr_original, labels = haystack_utils.DLA(german_data, model)

# Patch in disabled context neurons and plot the direct logit attribution difference for each component
with model.hooks(fwd_hooks=deactivate_context_hooks):
    logit_attr_ablated, _ = haystack_utils.DLA(german_data, model)

logit_diffs = (logit_attr_original - logit_attr_ablated).mean(0)
# The small differences accumulated before the ablation are due to the final layer norm scale being affected by the L3 hook.
haystack_utils.line(logit_diffs.cpu().numpy(), xlabel="Correct logit", ylabel="", title="(Original DLA - Ablated DLA) per component", xticks=labels)

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

The direct loss increases are relatively small until layer 9, implying that the direct effects of the ablations are more minor and that most 
differences are indirect effects of the context neurons in layers 5 and 8 starting in layer 9.

### High loss prompts - MLP11

In [71]:
def interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss):
    loss_diff = (ablated_loss - original_loss) # Loss increase from context neuron
    mlp_11_power = (only_activated_loss - original_loss) # Loss increase from MLP5
    mlp_11_power[mlp_11_power < 0] = 0
    combined = 0.5 * loss_diff - mlp_11_power
    combined[original_loss > 6] = 0
    combined[original_loss > ablated_loss] = 0
    return combined

def print_prompt(prompt: str):
    str_token_prompt = model.to_str_tokens(model.to_tokens(prompt))
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = haystack_utils.get_direct_effect(
        prompt, model, pos=None, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
        deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
        activated_components = ("blocks.11.hook_mlp_out",))

    pos_wise_diff = interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss).flatten().cpu().tolist()

    loss_list = [loss.flatten().cpu().tolist() for loss in [original_loss, ablated_loss, context_and_activated_loss, only_activated_loss]]
    loss_names = ["original_loss", "ablated_loss", "context_and_activated_loss", "only_activated_loss"]
    haystack_utils.clean_print_strings_as_html(str_token_prompt[1:], pos_wise_diff, max_value=2.5, additional_measures=loss_list, additional_measure_names=loss_names)

def get_mlp11_decrease_measure(losses: list[tuple[Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"]]]):
    measure = []
    for original_loss, ablated_loss, context_and_activated_loss, only_activated_loss in losses:
        combined = interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss)
        measure.append(combined.max().item())
    return measure

average_loss_plot = haystack_utils.get_average_loss_plot_method(activate_context_hooks, deactivate_context_hooks, "MLP11",
                                                                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
                                                                activated_components = ("blocks.11.hook_mlp_out"))


In [45]:
german_losses = []
for prompt in tqdm(german_data):
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = haystack_utils.get_direct_effect(
        prompt, model, pos=None, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks, 
        deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
        activated_components = ("blocks.11.hook_mlp_out",))
    german_losses.append((original_loss, ablated_loss, context_and_activated_loss, only_activated_loss))

  0%|          | 0/200 [00:00<?, ?it/s]

In [153]:
measure = get_mlp11_decrease_measure(german_losses)
index = [i for i in range(len(measure))]

sorted_measure = list(zip(index, measure))
sorted_measure.sort(key=lambda x: x[1], reverse=True)
for i, measure in sorted_measure[:2]:
    print(measure)
    print_prompt(german_data[i])

2.2552361488342285


2.0187695026397705


### Patch patch components

In [154]:
def get_prompt_and_token():
    for prompt in german_data:
        original_loss, _ = model.run_with_cache(prompt, return_type='loss', loss_per_token=True)
        with model.hooks(deactivate_context_hooks):
            ablated_loss, _ = model.run_with_cache(prompt, return_type='loss', loss_per_token=True)
        value, index = torch.max(ablated_loss - original_loss, dim=1)
        if value > 3 and original_loss[0, index] < 3:
            return prompt, index
    return '', -1
        
prompt, index = get_prompt_and_token()

In [163]:
def plot_effects(prompt, index):
        original, ablated, direct_effect, _ = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
                deactivated_components =("blocks.6.hook_attn_out", "blocks.7.hook_attn_out", "blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.6.hook_mlp_out", "blocks.7.hook_mlp_out", "blocks.9.hook_mlp_out", "blocks.11.hook_mlp_out", "blocks.10.hook_mlp_out"),
                activated_components = ("blocks.5.hook_mlp_out", "blocks.8.hook_mlp_out",))
        
        _, _, _, only_activated_loss_mlp_9 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.11.hook_mlp_out", "blocks.10.hook_mlp_out"),
                activated_components = ("blocks.9.hook_mlp_out",))

        _, _, _, only_activated_loss_attn_9 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
                deactivated_components =("blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out", "blocks.11.hook_mlp_out"),
                activated_components = ("blocks.9.hook_attn_out",))

        _, _, _, only_activated_loss_mlp_10 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.11.hook_mlp_out"),
                activated_components = ("blocks.10.hook_mlp_out",))

        _, _, _, only_activated_loss_attn_10 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
                deactivated_components =("blocks.9.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out", "blocks.11.hook_mlp_out"),
                activated_components = ("blocks.10.hook_attn_out",))

        _, _, _, only_activated_loss_mlp_11 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
                activated_components = ("blocks.11.hook_mlp_out",))

        _, _, _, only_activated_loss_attn_11 = haystack_utils.get_direct_effect(
                prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
                deactivated_components =("blocks.9.hook_attn_out", "blocks.10.hook_attn_out","blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out", "blocks.11.hook_mlp_out", ),
                activated_components = ("blocks.11.hook_attn_out",))

        data = [original, ablated, direct_effect, only_activated_loss_mlp_9, only_activated_loss_attn_9, only_activated_loss_mlp_10, only_activated_loss_attn_10, only_activated_loss_mlp_11, only_activated_loss_attn_11]
        # print(original_loss_mlp_9.mean(), ablated_loss_mlp_9.mean(), only_activated_loss_mlp_9.mean())
        # print(original_loss_mlp_10.mean(), ablated_loss_mlp_10.mean(), only_activated_loss_mlp_10.mean())
        haystack_utils.plot_barplot([[item] for item in data],
                                names=['original', 'ablated', 'direct effect'] + [f'{i}{j}' for j in [9, 10, 11] for i in ["MLP", "Attn"]])

plot_effects(prompt, index)

### Ford circuit

In [165]:
prompts = haystack_utils.load_json_data('data/F_prompts.json')

data/F_prompts.json: Loaded 118 examples with 180 to 5360 characters each.


In [None]:
def pad_left(prompts):
    tokens = model.to_tokens(prompts)
    target_length = tokens.shape[1]

    result = 