## Setup

In [1]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from typing import List, Tuple
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
# This notebook only runs inference, so autograd is disabled globally.
# `torch.set_grad_enabled` is the same object as `torch.autograd.set_grad_enabled`,
# so a single call is enough.
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

# Reload edited modules (e.g. haystack_utils) automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [2]:
# Candidate language-selective MLP neurons, as (layer, neuron index) pairs.
english_neurons = [(5, 395), (5, 166), (5, 908), (5, 285), (3, 862), (5, 73), (4, 896), (5, 348), (5, 297), (3, 1204)]
german_neurons = [(4, 482), (5, 1039), (5, 407), (5, 1516), (5, 1336), (4, 326), (5, 250), (3, 669)]
french_neurons = [(5, 112), (4, 1080), (5, 1293), (5, 455), (5, 5), (5, 1901), (5, 486), (4, 975)]

model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

english_data = haystack_utils.load_txt_data("kde4_english.txt")
german_data = haystack_utils.load_txt_data("wmt_german_large.txt")

# MLP activations for layers 3-5 on the first 200 prompts of each corpus
# (mean=False: presumably per-token rather than averaged -- TODO confirm in haystack_utils).
english_activations, german_activations = {}, {}
for layer in range(3, 6):
    english_activations[layer] = get_mlp_activations(english_data[:200], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:200], layer, model, mean=False)

# Tokens whose log-prob is at or below this in both the clean and ablated runs are hidden
# from the top-k prediction tables.
LOG_PROB_THRESHOLD = -7
# Ablation target: neuron 669 of layer 3 (also listed in german_neurons above).
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
# Patch values: the target neuron's mean activation on German text ("active")
# and on English text ("inactive").
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    """Forward hook: set NEURONS_TO_ABLATE to their mean English-text activation (neuron "off")."""
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value


def activate_neurons_hook(value, hook):
    """Forward hook: set NEURONS_TO_ABLATE to their mean German-text activation (neuron "on")."""
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value


# Both hooks patch the post-nonlinearity MLP activations of the ablated layer.
deactivate_neurons_fwd_hooks = [(f"blocks.{LAYER_TO_ABLATE}.mlp.hook_post", deactivate_neurons_hook)]
activate_neurons_fwd_hooks = [(f"blocks.{LAYER_TO_ABLATE}.mlp.hook_post", activate_neurons_hook)]

# `all_ignore` is used later to exclude tokens from the top-k prediction tables.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Downloading (…)lve/main/config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/166M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.
wmt_german_large.txt: Loaded 2459 examples with 800 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [23]:
def get_pos_loss_diff(prompt: str, model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]], plot_hist=False, use_activate_hook=False, debug_log=True):
    """Per-position loss change caused by running `model` with `deactivate_neurons_hook`.

    The baseline is the plain forward pass, or the pass with `activate_neurons_hook` when
    `use_activate_hook` is True.

    Returns:
        Flat tensor of (ablated - baseline) per-token losses; a positive entry means the
        ablation increased the loss at that position.
    """
    tokens = model.to_tokens(prompt)
    baseline_loss = (
        model.run_with_hooks(tokens, return_type="loss", fwd_hooks=activate_neurons_hook, loss_per_token=True)
        if use_activate_hook
        else model(tokens, return_type="loss", loss_per_token=True)
    )
    ablated_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=deactivate_neurons_hook, loss_per_token=True)

    # Positive difference = loss increase due to ablation
    loss_difference = (ablated_loss - baseline_loss).flatten()

    if debug_log:
        print(f"Unablated loss: {baseline_loss.flatten()}")
        print(f"Ablated loss: {ablated_loss.flatten()}")
        print(f"Loss difference: {loss_difference}")

    if plot_hist:
        px.histogram(loss_difference.flatten().cpu().numpy(), title="Loss difference due to ablation per position").show()
    return loss_difference

def get_high_loss_prompts(prompts: list[str], model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]]):
    """Summarise how much the ablation increases the loss on each prompt.

    For every prompt, computes the per-position loss change (ablated minus unablated) with
    `get_pos_loss_diff` and records its maximum and mean.

    Returns:
        (max_diffs, average_diffs): two lists of floats aligned with `prompts`.
    """
    max_diffs = []
    average_diffs = []
    for prompt in tqdm(prompts):
        # debug_log=False: by default get_pos_loss_diff prints three full loss tensors per
        # call, which floods the cell output when looping over a dataset.
        loss_difference = get_pos_loss_diff(prompt, model, activate_neurons_hook, deactivate_neurons_hook, debug_log=False)
        max_diffs.append(loss_difference.max().item())
        average_diffs.append(loss_difference.mean().item())
    return max_diffs, average_diffs


In [11]:
def get_top_differences_at_position(prompt: str, model: HookedTransformer, position: int, top_k=20, mode="full"):
    """Print how ablating the neuron changes the next-token prediction at one position.

    Compares the clean run with a run ablated via `deactivate_neurons_fwd_hooks` and prints
    four HTML tables:
    - the top-k tokens under the clean run;
    - the top-k tokens under the ablated run;
    - the top-k tokens the neuron boosts most, by log-prob;
    - the top-k tokens the neuron suppresses most, by log-prob.

    Tokens with log-prob <= LOG_PROB_THRESHOLD in both runs, and tokens in `all_ignore`,
    are excluded from every table.

    Args:
        prompt: Text to analyse.
        model: Model to run.
        position: Token index whose next-token prediction is inspected. Negative values
            count from the end.
        top_k: Number of tokens per table.
        mode: Which ablated logits to compare against.
            "full": plain ablation.
            "direct": haystack_utils.get_frozen_logits.
            "indirect": haystack_utils.get_ablated_logits.
            Both "direct" and "indirect" freeze the layer-4/5 attention and MLP outputs.

    Raises:
        IndexError: if `position` is outside the tokenized prompt.
        AssertionError: if `mode` is not one of the three values above.
    """
    tokens = model.to_tokens(prompt)
    str_tokens = model.to_str_tokens(tokens)
    # Normalise negative indices (IndexError if out of range) so that `position + 1` below
    # names the real next token instead of wrapping around to BOS when position == -1.
    position = range(len(str_tokens))[position]

    original_logits = model(tokens, return_type="logits")
    # Downstream activations frozen by the "direct" / "indirect" decompositions.
    to_freeze = ["blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out"]
    if mode == "direct":
        ablated_logits = haystack_utils.get_frozen_logits(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=to_freeze)
    elif mode == "indirect":
        ablated_logits = haystack_utils.get_ablated_logits(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=to_freeze)
    else:
        assert mode == "full"
        ablated_logits = model.run_with_hooks(tokens, return_type="logits", fwd_hooks=deactivate_neurons_fwd_hooks)
    # Rankings are done in log-prob space; raw logit differences are only displayed.
    original_logprob = original_logits.log_softmax(dim=-1)
    ablated_logprob = ablated_logits.log_softmax(dim=-1)

    # Positive difference = the German neuron makes the token more likely
    # Negative difference = the German neuron makes the token less likely
    logprob_differences = original_logprob - ablated_logprob
    logit_differences = original_logits - ablated_logits

    # Single prompt -> batch index 0; each slice is one vocabulary-sized vector at `position`.
    original_lp = original_logprob[0, position, :]
    ablated_lp = ablated_logprob[0, position, :]
    logprob_diff = logprob_differences[0, position, :]
    logit_diff = logit_differences[0, position, :]

    next_token = str_tokens[position + 1] if position + 1 < len(str_tokens) else "<end of prompt>"
    print("Prompt:", prompt)
    print(f"Differences for predicting: {str_tokens[position]} -> {next_token}")

    # Hide tokens that are implausible under both runs, plus the globally ignored tokens.
    low_log_prob = torch.argwhere((original_lp <= LOG_PROB_THRESHOLD) & (ablated_lp <= LOG_PROB_THRESHOLD)).flatten()
    ignore_tokens = torch.cat([low_log_prob, all_ignore]).unique()

    top_original_logprobs, top_original_idx = haystack_utils.top_k_with_exclude(original_lp.flatten(), top_k, exclude=ignore_tokens)
    top_ablated_logprobs, top_ablated_idx = haystack_utils.top_k_with_exclude(ablated_lp.flatten(), top_k, exclude=ignore_tokens)

    print("Top predictions with German neuron active (unablated)")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_original_idx), logprob_diff[top_original_idx].cpu().tolist(), max_value=5, original_log_probs=top_original_logprobs.cpu().tolist(), ablated_log_probs=ablated_lp[top_original_idx].cpu().tolist(), logit_difference=logit_diff[top_original_idx].cpu().tolist())
    print("Top predictions with German neuron disabled (ablated)")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_ablated_idx), logprob_diff[top_ablated_idx].cpu().tolist(), max_value=5, original_log_probs=original_lp[top_ablated_idx].cpu().tolist(), ablated_log_probs=top_ablated_logprobs.cpu().tolist(), logit_difference=logit_diff[top_ablated_idx].cpu().tolist())

    top_boosts, top_boosted_idx = haystack_utils.top_k_with_exclude(logprob_diff.flatten(), top_k, exclude=ignore_tokens)
    top_reduced, top_reduced_idx = haystack_utils.top_k_with_exclude(logprob_diff.flatten(), top_k, largest=False, exclude=ignore_tokens)
    print("Top boosted tokens by German neuron")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_boosted_idx), top_boosts.cpu().tolist(), max_value=5, original_log_probs=original_lp[top_boosted_idx].cpu().tolist(), ablated_log_probs=ablated_lp[top_boosted_idx].cpu().tolist())
    print("Top reduced tokens by German neuron")
    haystack_utils.print_strings_as_html(model.to_str_tokens(top_reduced_idx), top_reduced.cpu().tolist(), max_value=5, original_log_probs=original_lp[top_reduced_idx].cpu().tolist(), ablated_log_probs=ablated_lp[top_reduced_idx].cpu().tolist())

In [3]:
def show_token_loss(prompt: str, model: HookedTransformer, max_value=None, mode="full", freeze_act_names=("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out")):
    """Render `prompt` as HTML, colouring tokens by the ablation-induced loss change.

    `mode` chooses which output of haystack_utils.split_effects is shown: "full" (total
    effect), "direct" or "indirect". The values are paired with every token except the first
    (presumably each token is coloured by the loss of predicting it -- TODO confirm against
    split_effects).
    """
    _, total_effect, direct_effect, indirect_effect = haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names, debug_log=False)
    effect_by_mode = {"indirect": indirect_effect, "direct": direct_effect}
    if mode in effect_by_mode:
        pos_wise_loss = effect_by_mode[mode]
    else:
        assert mode == "full"
        pos_wise_loss = total_effect
    str_tokens = model.to_str_tokens(model.to_tokens(prompt))
    haystack_utils.print_strings_as_html(str_tokens[1:], pos_wise_loss.flatten().cpu().tolist(), max_value=max_value)

def print_predictions(prompt, pos, k=20):
    """Run get_top_differences_at_position on the global `model` in full, indirect and direct modes."""
    sections = (
        ("\nFull model predictions", "full"),
        ("\nIndirect predictions (leave German neuron active, patch corrupted activations to later components)", "indirect"),
        ("\nDirect predictions (ablate German neuron, patch clean activations to later components)", "direct"),
    )
    for header, section_mode in sections:
        print(header)
        get_top_differences_at_position(prompt, model, pos, k, mode=section_mode)

## Look for interesting examples

In [6]:
def get_interesting_loss_prompts(prompts: list[str], model: HookedTransformer, activate_neurons_hook: List[Tuple[str, HookPoint]], deactivate_neurons_hook: List[Tuple[str, HookPoint]]):
    """Score prompts by positions where the ablation effect is large indirectly but small directly.

    For each prompt, haystack_utils.split_effects splits the loss change from ablation into
    direct and indirect parts. The per-position score is |direct - indirect|, kept only where
    the indirect loss change is >= 1 and the direct loss change is <= 1 (zero elsewhere).

    Args:
        prompts: Texts to score.
        model: Model to run.
        activate_neurons_hook: Unused; kept for signature parity with get_high_loss_prompts.
        deactivate_neurons_hook: Forward hooks that perform the ablation.

    Returns:
        (max_diffs, average_diffs): per-prompt max and mean of the masked score.
    """
    max_diffs = []
    average_diffs = []
    for prompt in tqdm(prompts):
        # Ablate with the hooks supplied by the caller, not a notebook-global.
        _, _, direct_loss, indirect_loss = haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_hook, debug_log=False)
        loss_difference = abs(direct_loss - indirect_loss)
        # Keep only positions with a large indirect but small direct effect.
        loss_difference[indirect_loss < 1] = 0
        loss_difference[direct_loss > 1] = 0
        max_diffs.append(loss_difference.max().item())
        average_diffs.append(loss_difference.mean().item())
    return max_diffs, average_diffs

# Rank the first n_examples German prompts by their largest masked |direct - indirect| loss change.
# Only these prompts are ranked below, so only these are scored.
n_examples = 1000
max_diffs, average_diffs = get_interesting_loss_prompts(german_data[:n_examples], model, activate_neurons_fwd_hooks, deactivate_neurons_fwd_hooks)

loss_data_tuple = sorted(zip(max_diffs, german_data[:n_examples]), key=lambda item: item[0], reverse=True)
loss_data_tuple[:2]

  0%|          | 0/2459 [00:00<?, ?it/s]

[(11.32781982421875,
  'Ich möchte nochmals meine Ansicht wiederholen, dass die Menschenrechte als ein selbstverständlicher Teil fest in die Außenpolitik der EU verankert werden müssen, und ich glaube, dass die neue institutionelle Struktur der EU, besonders im Hinblick auf den EAD und die zu ihm gehörige Abteilung, eine Gelegenheit bietet, den Zusammenhalt und die Effizienz der EU auf diesem Gebiet zu stärken; ich fordere die Vizepräsidentin/Hohe Vertreterin dringend dazu auf, die Initiative zu ergreifen und durch bilaterale Beziehungen zu Drittländern und eine aktive Teilnahme an internationalen Foren dafür zu sorgen, dass sich die Drittländer für den Schutz der Menschenrechte engagieren und sich gegen Verletzungen der Menschenrechte wehren, anstatt vor den nötigen Maßnahmen zurückzuschrecken, wenn sie missachtet werden; angesichts des wachsenden Ausmaßes der schweren Verletzungen der Glaubensfreiheit, rufe ich die Kommission dazu auf, eine gründliche Evaluierung vorzunehmen und die 

In [7]:
def show_all_loss_types(prompt):
    """Render the per-token loss change of `prompt` for the full, indirect and direct effects."""
    for label, loss_mode in (("Full model loss", "full"), ("Indirect loss", "indirect"), ("Direct loss", "direct")):
        print(label)
        show_token_loss(prompt, model, max_value=5, mode=loss_mode)

In [12]:
# Show the full / indirect / direct loss views for the two top-ranked prompts.
# show_all_loss_types prints a label above each view so the three plots can be told apart.
for _, prompt in loss_data_tuple[:2]:
    print("")
    show_all_loss_types(prompt)







## Vorschläge

In [44]:
# Case study: short German phrase; the following cells analyse its final position (pos = -1).
prompt = "zu den Vorschlägen"
show_all_loss_types(prompt=prompt)

Full model loss


Indirect loss


Direct loss


In [24]:
# Does MLP5's contribution depend on MLP4? Compare the direct effect at the last position with
# MLP5 left live (MLP4 and layer-4/5 attention frozen) against also freezing MLP5.
# return_absolute=True makes split_effects return absolute losses rather than loss changes.
pos = -1
freeze_configs = [
    (
        ("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out"),
        "Direct effect loss of MLP3 and MLP5 (restoring MLP4 and attention)",
    ),
    (
        ("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out"),
        "Direct effect loss of MLP3 (restoring MLP4 and MLP5 and attention)",
    ),
]
for config_idx, (freeze_act_names, description) in enumerate(freeze_configs):
    original_loss, total_effect_loss_change, direct_effect_loss_change, indirect_effect_loss_change = haystack_utils.split_effects(prompt, model, ablation_hooks=deactivate_neurons_fwd_hooks, freeze_act_names=freeze_act_names, debug_log=False, return_absolute=True)
    if config_idx == 0:
        # The original and total-effect losses do not depend on the frozen set; print them once.
        print(f"Original loss on pos {pos}: {original_loss[0, pos].item():.4f}")
        print(f"Total effect loss: {total_effect_loss_change[0, pos].item():.4f}")
    print(f"{description}: {direct_effect_loss_change[0, pos].item():.4f}")


Original loss on pos -1: 3.8393
Total effect loss: 7.5884
Direct effect loss of MLP3 and MLP5 (restoring MLP4 and attention): 8.2419
Direct effect loss of MLP3 (restoring MLP4 and MLP5 and attention): 3.2200


In [33]:
def compare_activations(prompts: str, model: HookedTransformer, layer=5, pos=-1):
    """Compare one layer's MLP pre-activations at `pos` with the target neuron forced on vs. off.

    `prompts` is a single prompt string (the parameter name is kept for compatibility). It is
    run twice:
    - "original": with `activate_neurons_fwd_hooks`, which sets the neuron to its mean German
      activation;
    - "ablated": with `deactivate_neurons_fwd_hooks`, which sets it to its mean English activation.

    `blocks.{layer}.mlp.hook_pre` is read at batch 0, position `pos`.

    Returns:
        (activation_difference, original_activation, ablated_activation), where
        activation_difference = original_activation - ablated_activation.
    """
    # Tokenize the argument rather than the notebook-global `prompt`.
    tokens = model.to_tokens(prompts)

    with model.hooks(fwd_hooks=activate_neurons_fwd_hooks):
        _, original_cache = model.run_with_cache(tokens)

    with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
        _, ablated_cache = model.run_with_cache(tokens)

    act_label = f"blocks.{layer}.mlp.hook_pre"
    original_activation = original_cache[act_label][0, pos, :]
    ablated_activation = ablated_cache[act_label][0, pos, :]
    activation_difference = original_activation - ablated_activation
    return activation_difference, original_activation, ablated_activation
    

In [34]:
# MLP5 pre-activations at the last position, with the layer-3 neuron patched to its mean German
# activation ("original") vs. its mean English activation ("ablated"); histogram of the
# per-neuron difference (original - ablated).
activation_differences, original_activation, ablated_activation = compare_activations(prompt, model, layer=5, pos=-1)
px.histogram(activation_differences.cpu().numpy(), nbins=100, title="Activation differences between MLP5")

In [37]:
# Compare the distributions of the two MLP5 pre-activation vectors from the previous cell.
haystack_utils.two_histogram(original_activation, ablated_activation, title="Activation distributions of MLP5")

In [42]:
# Per-neuron alignment of MLP5's output weights with the unembedding of the token "gen".
# Presumably "gen" is the final sub-token of "Vorschlägen" -- TODO confirm with to_str_tokens.
# Unlike get_neuron_logit_contribution, no final layer-norm scaling is applied here.
token = model.to_single_token("gen")
token_direction = model.W_U[:, token]
mlp_outputs = model.W_out[5, :, :]
dot_res = einops.einsum(mlp_outputs, token_direction, "d_mlp res, res -> d_mlp")


In [43]:
# Distribution over MLP5 neurons of the W_out · W_U["gen"] alignment computed in the previous cell.
px.histogram(dot_res.cpu().numpy(), width=1000, title='MLP5 output weights projected onto the "gen" unembedding').update_layout(xaxis_title="W_out[5] row · W_U[:, gen]", yaxis_title="Number of neurons", showlegend=False)

In [49]:
def get_neuron_logit_contribution(cache: ActivationCache, model: HookedTransformer, answer_tokens: Int[Tensor, "batch"], layer: int, pos: int) -> Float[Tensor, "neuron"]:
    """Per-neuron direct contribution of one MLP layer to the logit of the answer token.

    Each neuron's residual-stream output (cache.get_neuron_results) is taken at position
    `pos - 1`, the position whose prediction is the token at `pos`. It is then scaled by the
    cached final layer norm and projected onto the unembedding of `answer_tokens`.

    Expects a cache from a single prompt (batch size 1) and a single answer token; batched
    inputs are not supported.

    Args:
        cache: Activation cache of one forward pass.
        model: Model that produced the cache (used for W_U).
        answer_tokens: Token id(s) at `pos`, shape [batch] with batch == 1.
        layer: MLP layer whose neurons are attributed.
        pos: Position of the answer token; negative values count from the end.

    Returns:
        Tensor with one logit contribution per neuron.

    Raises:
        ValueError: if pos == 0 (no earlier position predicts the first token).
    """
    if pos == 0:
        # pos - 1 would silently wrap around to the last position.
        raise ValueError("pos must not be 0: the first token has no predicting position")
    # Per-neuron output of the MLP layer: [batch pos neuron residual]
    neuron_directions = cache.get_neuron_results(layer, neuron_slice=utils.Slice(input_slice=None), pos_slice=utils.Slice(input_slice=None))
    neuron_directions = einops.rearrange(neuron_directions, 'batch pos neuron residual -> neuron batch pos residual')
    # The unembedding is applied after the final layer norm, so scale the neuron outputs into
    # that space before projecting (LN leads to finding top tokens with slightly higher loss attribution).
    scaled_neuron_directions = cache.apply_ln_to_stack(neuron_directions)[:, 0, pos - 1, :]  # [neuron residual]
    # Unembedding of the answer token: W_U[:, [tok]] is [residual, 1] -> squeeze to [residual]
    correct_token_directions = model.W_U[:, answer_tokens].squeeze(1)
    return einops.einsum(scaled_neuron_directions, correct_token_directions, 'neuron residual, residual -> neuron')

def MLP_attribution(prompt: str, model: HookedTransformer, fwd_hooks, layer_to_compare=5, pos=-1):
    """Change in each MLP neuron's direct attribution to the token at `pos` caused by `fwd_hooks`.

    Attribution is get_neuron_logit_contribution on layer `layer_to_compare`, computed on the
    original and ablated caches from haystack_utils.get_caches_single_prompt.

    Returns:
        Detached CPU tensor with one value per neuron: original minus ablated attribution.
    """
    answer_tokens = model.to_tokens(prompt)[:, pos]
    _, _, original_cache, ablated_cache = haystack_utils.get_caches_single_prompt(prompt, model, fwd_hooks)
    original_unembedded, ablated_unembedded = (
        get_neuron_logit_contribution(run_cache, model, answer_tokens, layer=layer_to_compare, pos=pos)
        for run_cache in (original_cache, ablated_cache)
    )
    return (original_unembedded - ablated_unembedded).detach().cpu()
    
# Presumably frees cached GPU memory before the attribution runs -- see haystack_utils.clean_cache.
haystack_utils.clean_cache()
# Per-neuron change (original - ablated) in MLP5's direct attribution to the last token of `prompt`.
layer_to_compare=5
differences = MLP_attribution(prompt, model, fwd_hooks=deactivate_neurons_fwd_hooks, layer_to_compare=layer_to_compare, pos=-1)

In [50]:
# Distribution over neurons of the attribution change computed in the previous cell.
px.histogram(differences.cpu().numpy(), width=1000, title=f"MLP{layer_to_compare} neurons: change in direct attribution to the answer token (original - ablated)").update_layout(xaxis_title="Logit attribution difference", yaxis_title="Number of neurons", showlegend=False)

In [52]:
# The 50 neurons whose attribution to the answer token drops most under ablation
# (largest original - ablated difference first), plotted with their neuron indices as x ticks.
top_diff, top_diff_neurons = torch.topk(differences, 50, largest=True)
haystack_utils.line(top_diff.cpu().numpy(), xlabel="Top positive difference neurons", ylabel="Difference in attribution to correct answer token", width=1000, xticks=top_diff_neurons.cpu().tolist())

## Freeze top neurons to see if loss increase was due to them

In [60]:
# Neurons to restore: the top 50 by attribution difference. NOTE: despite its name, `top_5`
# holds 50 neurons; the name is kept because later cells refer to it.
top_5 = top_diff_neurons[:50]


def freeze_neurons_hook(value, hook: HookPoint):
    """Overwrite the `top_5` neurons with their clean-run activations (value is [batch pos neuron])."""
    value[:, :, top_5] = original_cache[hook.name][:, :, top_5]
    return value


# Clean run and ablated run: per-token losses plus activation caches.
original_loss, original_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    ablated_loss, ablated_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

# Ablate the layer-3 neuron but keep the selected MLP5 neurons at their clean values, to see
# whether the loss increase is carried by these neurons.
freeze_original_hooks = [("blocks.5.mlp.hook_post", freeze_neurons_hook)]
with model.hooks(fwd_hooks=freeze_original_hooks + deactivate_neurons_fwd_hooks):
    ablated_with_original_frozen_loss = model(prompt, return_type="loss", loss_per_token=True)

for per_token_loss in (original_loss, ablated_loss, ablated_with_original_frozen_loss):
    print(per_token_loss)

tensor([[15.9090,  9.2650,  5.8355,  6.6877,  5.1798,  3.8393]],
       device='cuda:0')
tensor([[15.8943,  9.3858,  6.8106,  8.5760,  8.1114,  7.5884]],
       device='cuda:0')
tensor([[15.8967,  9.3769,  6.8582,  8.7013,  7.2039,  0.9723]],
       device='cuda:0')


In [61]:
# Restore one top-attribution MLP5 neuron at a time (while the layer-3 neuron is ablated) and
# record the loss at the final position.
# The clean and ablated runs do not depend on the neuron, so they are computed once, outside the loop.
original_loss, original_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    ablated_loss, ablated_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

loss_per_neuron = []
for neuron in top_5:
    def freeze_neurons_hook(value, hook: HookPoint):
        # value is [batch pos neuron]; overwrite this neuron with its clean-run activation
        value[:, :, neuron] = original_cache[hook.name][:, :, neuron]
        return value

    freeze_original_hooks = [("blocks.5.mlp.hook_post", freeze_neurons_hook)]
    with model.hooks(fwd_hooks=freeze_original_hooks + deactivate_neurons_fwd_hooks):
        ablated_with_original_frozen_loss = model(prompt, return_type="loss", loss_per_token=True)

    loss_per_neuron.append(ablated_with_original_frozen_loss[0, -1].item())

haystack_utils.line(loss_per_neuron, xlabel="Neuron", ylabel="Total loss of restoring the neuron", width=1000, xticks=top_5.cpu().tolist())

In [59]:
# A hand-picked subset of MLP5 neurons restored together (presumably chosen from the
# per-neuron plot above -- TODO record the selection criterion).
new_top = [838, 1026, 1709, 84, 852, 959, 1043]


def freeze_neurons_hook(value, hook: HookPoint):
    """Overwrite the `new_top` neurons with their clean-run activations (value is [batch pos neuron])."""
    value[:, :, new_top] = original_cache[hook.name][:, :, new_top]
    return value


# Clean run and ablated run: per-token losses plus activation caches.
original_loss, original_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    ablated_loss, ablated_cache = model.run_with_cache(prompt, return_type="loss", loss_per_token=True)

# Ablate the layer-3 neuron but keep only the hand-picked MLP5 neurons at their clean values.
freeze_original_hooks = [("blocks.5.mlp.hook_post", freeze_neurons_hook)]
with model.hooks(fwd_hooks=freeze_original_hooks + deactivate_neurons_fwd_hooks):
    ablated_with_original_frozen_loss = model(prompt, return_type="loss", loss_per_token=True)

for per_token_loss in (original_loss, ablated_loss, ablated_with_original_frozen_loss):
    print(per_token_loss)

tensor([[15.9090,  9.2650,  5.8355,  6.6877,  5.1798,  3.8393]],
       device='cuda:0')
tensor([[15.8943,  9.3858,  6.8106,  8.5760,  8.1114,  7.5884]],
       device='cuda:0')
tensor([[15.8947,  9.3865,  6.8033,  8.6222,  7.9965,  4.7644]],
       device='cuda:0')
