In [2]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from typing import Callable, List, Tuple
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer in VS Code/Jupyter notebooks than in Colab; "notebook_connected" loads plotly.js from a CDN
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
# This notebook only runs inference, so disable autograd globally.
# torch.set_grad_enabled is the same object as torch.autograd.set_grad_enabled,
# so a single call is sufficient.
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

%reload_ext autoreload
# Mode 2: reload all (non-excluded) modules before executing each cell, so edits to
# local helpers such as haystack_utils are picked up without restarting the kernel.
%autoreload 2

In [4]:
# Candidate language-associated neurons, one list per language.
# Each entry appears to be a (layer, neuron_index) pair: the pair (3, 669) in german_neurons
# matches LAYER_TO_ABLATE = 3 / NEURONS_TO_ABLATE = [669] below — TODO confirm.
# How these neurons were selected is not shown in this notebook.
english_neurons = [(5, 395), (5, 166), (5, 908), (5, 285), (3, 862), (5, 73), (4, 896), (5, 348), (5, 297), (3, 1204)]
german_neurons = [(4, 482), (5, 1039), (5, 407), (5, 1516), (5, 1336), (4, 326), (5, 250), (3, 669)]
french_neurons = [(5, 112), (4, 1080), (5, 1293), (5, 455), (5, 5), (5, 1901), (5, 486), (4, 975)]

# Pythia-70m with TransformerLens weight processing:
# LayerNorm folded into adjacent weights, plus writing weights and unembed centred.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)

# Text corpora used throughout: KDE4 (English) and WMT (German) examples.
english_data = haystack_utils.load_txt_data("kde4_english.txt")
german_data = haystack_utils.load_txt_data("wmt_german_large.txt")

# Cache MLP activations for layers 3-5 on the first 200 examples of each corpus.
# mean=False asks get_mlp_activations for un-averaged activations (exact layout: TODO confirm).
english_activations, german_activations = {}, {}
for layer in (3, 4, 5):
    english_activations[layer] = get_mlp_activations(english_data[:200], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:200], layer, model, mean=False)

# Ablation target; (3, 669) is also the last entry of german_neurons.
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
# Mean activation of the target neuron(s) over the cached German ("active") and
# English ("inactive") samples; .mean() reduces over every element of the slice.
# The forward hook defined next patches the English (inactive) value in.
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    """Forward hook that mean-ablates NEURONS_TO_ABLATE to their English (inactive) mean.

    Writes MEAN_ACTIVATION_INACTIVE into the last axis at NEURONS_TO_ABLATE, in place,
    and returns the same tensor. The globals are looked up on every call.
    The parameter must be named `hook`: TransformerLens passes the HookPoint by keyword.
    """
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value

deactivate_neurons_fwd_hooks = [
    (f"blocks.{LAYER_TO_ABLATE}.mlp.hook_post", deactivate_neurons_hook),
]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.
wmt_german_large.txt: Loaded 2459 examples with 800 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
def get_pos_loss_diff(prompt: str, model: HookedTransformer, fwd_hooks: List[Tuple[str, Callable]], plot_hist=False):
    """Per-position change in next-token loss caused by running `prompt` with `fwd_hooks`.

    Args:
        prompt: Text to evaluate; tokenized with `model.to_tokens`.
        model: Model that is run twice: once clean, once with the hooks attached.
        fwd_hooks: (hook_name, hook_fn) pairs forwarded to `model.run_with_hooks`.
        plot_hist: If True, also display a histogram of the per-position differences.

    Returns:
        Flattened tensor of `ablated_loss - original_loss`, one value per predicted
        position. Positive values mean the hooks made the prediction worse.
        The tensor is not moved to CPU.
    """
    tokens = model.to_tokens(prompt)
    original_loss = model(tokens, return_type="loss", loss_per_token=True)
    ablated_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=fwd_hooks, loss_per_token=True)

    # Positive difference = loss increase due to ablation
    loss_difference = (ablated_loss - original_loss).flatten()

    if plot_hist:
        # loss_difference is already 1-D, so no extra flatten is needed here.
        fig = px.histogram(loss_difference.cpu().numpy(), title="Loss difference due to ablation per position")
        fig.show()
    return loss_difference

def get_high_loss_prompts(prompts: List[str], model: HookedTransformer, fwd_hooks: List[Tuple[str, Callable]]) -> Tuple[List[float], List[float]]:
    """Summarise, per prompt, how much `fwd_hooks` increase the next-token loss.

    Args:
        prompts: Texts to evaluate, one at a time.
        model: Model passed to `get_pos_loss_diff`.
        fwd_hooks: (hook_name, hook_fn) pairs applied in the ablated run.

    Returns:
        (max_diffs, average_diffs): two lists aligned with `prompts`. For each prompt they hold
        the largest and the mean per-position loss increase, as Python floats.
    """
    max_diffs = []
    average_diffs = []
    for prompt in tqdm(prompts):
        loss_difference = get_pos_loss_diff(prompt, model, fwd_hooks)
        max_diffs.append(loss_difference.max().item())
        average_diffs.append(loss_difference.mean().item())
    return max_diffs, average_diffs

# Score every German prompt under the layer-3 neuron ablation (one clean + one ablated run each).
max_diffs, average_diffs = get_high_loss_prompts(
    german_data,
    model,
    fwd_hooks=deactivate_neurons_fwd_hooks,
)


In [20]:
# Distribution over prompts of the mean per-token loss increase caused by the ablation.
px.histogram(
    average_diffs,
    title="Average loss difference per prompt",
    width=1000,
).update_layout(
    xaxis_title="Mean loss increase per token (ablated - clean)",
    yaxis_title="Number of prompts",
    showlegend=False,
)

In [21]:
# Distribution over prompts of the single largest per-token loss increase caused by the ablation.
px.histogram(
    max_diffs,
    title="Maximum loss difference on a single token per prompt",
    width=1000,
).update_layout(
    xaxis_title="Max loss increase on one token (ablated - clean)",
    yaxis_title="Number of prompts",
    showlegend=False,
)

In [76]:
# Get prompts where the ablation raises the loss on at least one token by more than `threshold`
threshold = 5
# Enumerate max_diffs directly so indices always match the scores that were computed.
high_max_loss_prompts = [i for i, max_diff in enumerate(max_diffs) if max_diff > threshold]
print(high_max_loss_prompts)

[12, 13, 60, 67, 69, 82, 87, 102, 120, 125, 126, 133, 139, 146, 150, 157, 164, 195, 202, 217, 219, 227, 237, 238, 243, 247, 250, 251, 257, 260, 263, 269, 270, 294, 329, 337, 352, 365, 368, 391, 399, 408, 409, 419, 430, 441, 457, 463, 475, 478, 516, 522, 529, 531, 534, 541, 557, 577, 579, 585, 589, 615, 616, 617, 618, 620, 626, 630, 641, 646, 647, 664, 684, 688, 705, 707, 708, 711, 712, 713, 714, 716, 719, 723, 724, 731, 734, 743, 760, 771, 773, 774, 791, 795, 805, 813, 847, 860, 861, 869, 882, 919, 923, 928, 938, 939, 941, 946, 949, 961, 1017, 1025, 1037, 1039, 1085, 1091, 1094, 1104, 1109, 1112, 1117, 1122, 1123, 1148, 1159, 1162, 1165, 1166, 1171, 1178, 1185, 1201, 1210, 1212, 1213, 1217, 1229, 1232, 1234, 1388, 1404, 1471, 1483, 1563, 1569, 1602, 1606, 1686, 1760, 1800, 1847, 1877, 1899, 1998, 2029, 2064, 2153, 2200, 2213, 2251, 2273, 2298, 2327, 2427]


In [77]:
# Get prompts whose mean per-token loss increase under the ablation exceeds `threshold`
threshold = 0.5
# Enumerate average_diffs directly so indices always match the scores that were computed.
high_average_loss_prompts = [i for i, average_diff in enumerate(average_diffs) if average_diff > threshold]
print(high_average_loss_prompts)

[46, 69, 74, 115, 144, 164, 217, 233, 245, 280, 295, 307, 436, 447, 452, 503, 578, 664, 721, 869, 1110, 1148, 1150, 1155, 1158, 1185, 1211, 2224]


In [71]:
def show_token_loss(prompt: str, model: HookedTransformer, fwd_hooks: List[Tuple[str, Callable]], max_value=None):
    """Render `prompt` as HTML, colouring each token by the loss increase caused by `fwd_hooks`.

    Args:
        prompt: Text to visualise.
        model: Model used for the clean and hooked runs.
        fwd_hooks: (hook_name, hook_fn) pairs to apply in the hooked run.
        max_value: Optional cap for the colour scale, forwarded to `print_strings_as_html`.
    """
    # Use the caller-supplied hooks rather than the global deactivate_neurons_fwd_hooks.
    pos_wise_loss = get_pos_loss_diff(prompt, model, fwd_hooks, plot_hist=False)
    str_token_prompt = model.to_str_tokens(model.to_tokens(prompt))
    # Per-token loss is for next-token prediction, so there is one value fewer than there are tokens.
    # Dropping the final token aligns the two: the value at position i belongs to the token
    # that predicts token i + 1.
    haystack_utils.print_strings_as_html(str_token_prompt[:-1], pos_wise_loss.flatten().cpu().tolist(), max_value=max_value)

In [81]:
# Token-level view of the first five prompts with a large single-token loss increase
# (colour scale capped at 5).
for prompt_idx in high_max_loss_prompts[:5]:
    prompt = german_data[prompt_idx]
    show_token_loss(prompt, model, fwd_hooks=deactivate_neurons_fwd_hooks, max_value=5)

In [82]:
# Token-level view of the first five prompts with a large mean loss increase
# (colour scale capped at 5).
for prompt_idx in high_average_loss_prompts[:5]:
    prompt = german_data[prompt_idx]
    show_token_loss(prompt, model, fwd_hooks=deactivate_neurons_fwd_hooks, max_value=5)