In [1]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line, two_histogram
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
english_neurons = [(5, 395), (5, 166), (5, 908), (5, 285), (3, 862), (5, 73), (4, 896), (5, 348), (5, 297), (3, 1204)]
german_neurons = [(4, 482), (5, 1039), (5, 407), (5, 1516), (5, 1336), (4, 326), (5, 250), (3, 669)]
french_neurons = [(5, 112), (4, 1080), (5, 1293), (5, 455), (5, 5), (5, 1901), (5, 486), (4, 975)]

model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer


In [3]:
english_data = haystack_utils.load_txt_data("kde4_english.txt")
german_data = haystack_utils.load_txt_data("wmt_german_large.txt")

kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.
wmt_german_large.txt: Loaded 2459 examples with 800 to 2000 characters each.


## Check accuracy of German neurons

Sanity check: reproduce sparse probe results on the German neuron with Pythia-v1

In [5]:
english_activations = {}
german_activations = {}
for layer in range(3, 6):
    english_activations[layer] = get_mlp_activations(english_data[:200], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:200], layer, model, mean=False)

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [6]:
def run_single_neuron_lr(layer, neuron, num_samples=5000):
    # Check accuracy of logistic regression
    A = torch.concat([german_activations[layer][:num_samples, neuron], english_activations[layer][:num_samples, neuron]]).view(-1, 1).cpu().numpy()
    y = torch.concat([torch.ones(num_samples), torch.zeros(num_samples)]).cpu().numpy()
    A_train, A_test, y_train, y_test = train_test_split(A, y, test_size=0.2)
    lr_model = LogisticRegression()
    lr_model.fit(A_train, y_train)
    test_acc = lr_model.score(A_test, y_test)
    train_acc = lr_model.score(A_train, y_train)
    f1 = sklearn.metrics.f1_score(y_test, lr_model.predict(A_test))
    return train_acc, test_acc, f1
    

def get_neuron_accuracy(layer, neuron, plot=False):
    mean_english_activation = english_activations[layer][:,neuron].mean()
    mean_german_activation = german_activations[layer][:,neuron].mean()
    
    if plot:
        two_histogram(english_activations[layer][:,neuron], german_activations[layer][:,neuron], "English", "German", "Activation", "Frequency", f"L{layer}N{neuron} activations on English vs German text")
    train_acc, test_acc, f1 = run_single_neuron_lr(layer, neuron)
    print(f"\nL{layer}N{neuron}: F1={f1:.2f}, Train acc={train_acc:.2f}, and test acc={test_acc:.2f}")
    print(f"Mean activation English={mean_english_activation:.2f}, German={mean_german_activation:.2f}")


In [7]:
for layer, neuron in german_neurons:
    get_neuron_accuracy(layer, neuron)


L4N482: F1=0.90, Train acc=0.91, and test acc=0.91
Mean activation English=-0.07, German=1.21

L5N1039: F1=0.85, Train acc=0.84, and test acc=0.83
Mean activation English=1.02, German=-0.06

L5N407: F1=0.66, Train acc=0.63, and test acc=0.67
Mean activation English=5.23, German=3.70

L5N1516: F1=0.76, Train acc=0.76, and test acc=0.77
Mean activation English=2.31, German=1.02

L5N1336: F1=0.96, Train acc=0.96, and test acc=0.96
Mean activation English=-0.06, German=1.40

L4N326: F1=0.83, Train acc=0.83, and test acc=0.84
Mean activation English=0.03, German=0.81

L5N250: F1=0.76, Train acc=0.77, and test acc=0.77
Mean activation English=-0.00, German=-0.04

L3N669: F1=0.99, Train acc=0.99, and test acc=0.99
Mean activation English=-0.07, German=3.82


## Check loss increases for all German neurons

In [8]:
mean_original_loss, mean_ablated_loss, percent_increase = haystack_utils.get_ablated_performance(german_data[:1000], model, LAYER_TO_ABLATE,  NEURONS_TO_ABLATE, MEAN_ACTIVATION_INACTIVE)
print(f"Mean original loss={mean_original_loss:.2f}, mean ablated loss={mean_ablated_loss:.2f}, percent increase={percent_increase:.2f}")

NameError: name 'LAYER_TO_ABLATE' is not defined

In [None]:
for layer, neuron in tqdm(german_neurons):
    mean_activation_inactive = english_activations[layer][:, neuron].mean()
    mean_original_loss, mean_ablated_loss, percent_increase = haystack_utils.get_ablated_performance(german_data[:1000], model, layer, [neuron], mean_activation_inactive, display_tqdm=False)
    print(f"L{layer}N{neuron}: Loss original={mean_original_loss:.2f}, ablated={mean_ablated_loss:.2f} (+{percent_increase:.2f}%)")

  0%|          | 0/8 [00:00<?, ?it/s]

L4N482: Loss original=3.18, ablated=3.20 (+0.60%)
L5N1039: Loss original=3.18, ablated=3.18 (+0.16%)
L5N407: Loss original=3.18, ablated=3.19 (+0.22%)
L5N1516: Loss original=3.18, ablated=3.18 (+0.09%)
L5N1336: Loss original=3.18, ablated=3.20 (+0.59%)
L4N326: Loss original=3.18, ablated=3.18 (+0.12%)
L5N250: Loss original=3.18, ablated=3.18 (+0.01%)
L3N669: Loss original=3.18, ablated=3.55 (+11.57%)


In [None]:
# Sanity check: also calculate loss on English
print("English loss impact (should be close to 0)")
for layer, neuron in tqdm(german_neurons):
    mean_activation_inactive = english_activations[layer][:, neuron].mean()
    mean_original_loss, mean_ablated_loss, percent_increase = haystack_utils.get_ablated_performance(english_data[:1000], model, layer, [neuron], mean_activation_inactive, display_tqdm=False)
    print(f"L{layer}N{neuron}: Loss original={mean_original_loss:.2f}, ablated={mean_ablated_loss:.2f} (+{percent_increase:.2f}%)")

English loss impact (should be close to 0)


  0%|          | 0/8 [00:00<?, ?it/s]

L4N482: Loss original=3.96, ablated=3.96 (+0.02%)
L5N1039: Loss original=3.96, ablated=3.96 (+0.09%)
L5N407: Loss original=3.96, ablated=3.96 (+0.02%)
L5N1516: Loss original=3.96, ablated=3.95 (+-0.01%)
L5N1336: Loss original=3.96, ablated=3.96 (+-0.00%)
L4N326: Loss original=3.96, ablated=3.96 (+-0.00%)
L5N250: Loss original=3.96, ablated=3.95 (+-0.02%)
L3N669: Loss original=3.96, ablated=3.96 (+0.00%)


## Check loss breakdown by component for L3N669

In [9]:
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

In [11]:
def DLA(prompts: list[str], model: HookedTransformer):
    logit_attributions = []
    for prompt in tqdm(prompts):
        tokens = model.to_tokens(prompt)
        answers = tokens[:, 1:]
        tokens = tokens[:, :-1]
        answer_residual_directions = model.tokens_to_residual_directions(answers)
        _, cache = model.run_with_cache(tokens)
        accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=False, pos_slice=None, return_labels=True)
        # Component batch pos d_model
        scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer = -1, pos_slice=None)
        logit_attribution = einsum(scaled_residual_stack, answer_residual_directions, "component batch pos d_model, batch pos d_model -> component") / answers.shape[1]
        logit_attributions.append(logit_attribution)
    
    logit_attributions = torch.stack(logit_attributions)
    return logit_attributions, labels

def ablate_neuron_hook(value, hook):
        value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
        return value

ABLATE_HOOK=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', ablate_neuron_hook)]


In [12]:
logit_attr_original, labels = DLA(german_data[:1000], model)

with model.hooks(fwd_hooks=ABLATE_HOOK):
    logit_attr_ablated, _ = DLA(german_data[:1000], model)

logit_diffs = (logit_attr_original - logit_attr_ablated).mean(0)


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [13]:
haystack_utils.line(logit_diffs.cpu().numpy(), xlabel="Correct logit", ylabel="", title="(Original DLA - Ablated DLA) per component", xticks=labels)

In [14]:
_, cache = model.run_with_cache(german_data[0])
cache

ActivationCache with keys ['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'bloc