In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import numpy as np
import plotly.express as px
from torchmetrics.regression import KendallRankCorrCoef, SpearmanCorrCoef
from collections import defaultdict

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# Render plotly figures inline in the notebook.
pio.renderers.default = "notebook_connected+notebook"
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Inference-only notebook: switch gradient tracking off globally.
# NOTE(review): both calls below toggle the same global grad mode, so one of
# them looks redundant — confirm before removing.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# Clear cached state before (re)loading the model.
haystack_utils.clean_cache()

# Pythia-160m wrapped as a HookedTransformer, placed on the selected device.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-160m",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    device=device,
)

# Only the first 200 examples of each Europarl file are used in this notebook.
german_data = haystack_utils.load_json_data("data/german_europarl.json")
german_data = german_data[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")
english_data = english_data[:200]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


In [3]:
# print(model.to_tokens("swim"), model.to_tokens("swam"))

tensor([[   0, 2140,  303]], device='cuda:0') tensor([[   0, 2140,  312]], device='cuda:0')


In [79]:
# Candidate German context neurons as [layer, neuron index, previously reported F1],
# listed in descending F1 order.
german_neurons_with_f1 = [
    [5, 2649, 1.0],
    [8, 2994, 1.0],
    [11, 2911, 0.99],
    [10, 1129, 0.97],
    [6, 1838, 0.65],
    [7, 1594, 0.65],
    [11, 1819, 0.61],
    [11, 2014, 0.56],
    [10, 753, 0.54],
    [11, 205, 0.48],
]

# layer -> neuron indices, keeping only neurons whose F1 exceeds 0.9.
important_german_neurons = defaultdict(list)
for layer, neuron, f1 in german_neurons_with_f1:
    if not f1 > 0.9:
        continue
    important_german_neurons[layer].append(neuron)

# MLP activations for every layer that hosts at least one candidate neuron
# (mean=False is passed, so activations are not averaged by the helper).
english_activations, german_activations = {}, {}
for layer in {entry_layer for entry_layer, _, _ in german_neurons_with_f1}:
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [80]:
# layer -> list of (neuron index, mean activation) pairs.
# "active" means are taken over the German activations, "inactive" means over
# the English activations.
mean_context_neuron_acts_active = defaultdict(list)
mean_context_neuron_acts_inactive = defaultdict(list)
for layer, neuron, _ in german_neurons_with_f1:
    german_mean = german_activations[layer][:, neuron].mean(0)
    english_mean = english_activations[layer][:, neuron].mean(0)
    mean_context_neuron_acts_active[layer].append((neuron, german_mean))
    mean_context_neuron_acts_inactive[layer].append((neuron, english_mean))

def _make_mean_patch_hook(layer, mean_acts_by_layer):
    """Build a hook that overwrites a layer's context neurons with fixed values.

    Args:
        layer: Key into ``mean_acts_by_layer``.
        mean_acts_by_layer: Mapping of layer -> list of ``(neuron, mean_act)``
            pairs, e.g. ``mean_context_neuron_acts_inactive``.

    Returns:
        A ``hook(value, hook)`` function that writes each neuron's mean
        activation into ``value[:, :, neuron]`` and returns ``value``.
    """
    neurons, acts = zip(*mean_acts_by_layer[layer])
    neurons = list(neurons)
    acts = torch.stack([torch.as_tensor(act) for act in acts])

    def patch_hook(value, hook):
        # Follow the activation's own device/dtype instead of assuming CUDA,
        # so the hook also works with the notebook's CPU fallback.
        value[:, :, neurons] = acts.to(device=value.device, dtype=value.dtype)
        return value
    return patch_hook


def get_deactivate_neurons_hook(layer):
    """Hook that sets the layer's context neurons to their mean English activation."""
    return _make_mean_patch_hook(layer, mean_context_neuron_acts_inactive)

# NOTE(review): hooks are registered only for layers in important_german_neurons,
# but each hook patches every listed neuron in that layer (including those with
# F1 <= 0.9) — confirm this is intended.
deactivate_neurons_fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', get_deactivate_neurons_hook(layer)) for layer in important_german_neurons.keys()]


def get_activate_neurons_hook(layer):
    """Hook that sets the layer's context neurons to their mean German activation.

    The previous version read the *inactive* (English) means here, which made
    it identical to the deactivate hook.
    """
    return _make_mean_patch_hook(layer, mean_context_neuron_acts_active)

activate_neurons_fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', get_activate_neurons_hook(layer)) for layer in important_german_neurons.keys()]

# NOTE(review): presumably the sets of token ids to ignore / keep in later
# analyses — confirm against haystack_utils.get_weird_tokens. Plotting is disabled.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

## Check classification accuracy of German neurons

In [81]:
def run_single_neuron_lr(layer, neuron, num_samples=5000, german_activations=german_activations, english_activations=english_activations, random_state=None):
    """Fit a one-feature logistic regression separating German from English.

    Args:
        layer: Key into the activation dicts.
        neuron: Column index of the neuron inside that layer's activations.
        num_samples: Maximum number of rows taken from each language.
        german_activations: Mapping layer -> activation tensor (label 1).
        english_activations: Mapping layer -> activation tensor (label 0).
        random_state: Optional seed forwarded to ``train_test_split`` so the
            split (and therefore the scores) can be reproduced.

    Returns:
        Tuple ``(train_acc, test_acc, f1)`` where ``f1`` is computed on the
        held-out 20% split.
    """
    german_acts = german_activations[layer][:num_samples, neuron]
    english_acts = english_activations[layer][:num_samples, neuron]

    A = torch.concat([german_acts, english_acts]).view(-1, 1).cpu().numpy()
    # Size the labels from the rows actually sliced: a language with fewer than
    # num_samples rows would otherwise leave A and y with different lengths.
    y = torch.concat([torch.ones(german_acts.shape[0]), torch.zeros(english_acts.shape[0])]).cpu().numpy()

    A_train, A_test, y_train, y_test = train_test_split(A, y, test_size=0.2, random_state=random_state)
    lr_model = LogisticRegression()
    lr_model.fit(A_train, y_train)
    test_acc = lr_model.score(A_test, y_test)
    train_acc = lr_model.score(A_train, y_train)
    f1 = sklearn.metrics.f1_score(y_test, lr_model.predict(A_test))
    return train_acc, test_acc, f1
    
def get_neuron_accuracy(layer, neuron, german_activations=german_activations, english_activations=english_activations, plot=False, print_f1s=True):
    """Score how well a single neuron separates German from English text.

    Args:
        layer: Key into the activation dicts.
        neuron: Column index of the neuron inside that layer's activations.
        german_activations: Mapping layer -> German activation tensor.
        english_activations: Mapping layer -> English activation tensor.
        plot: If True, draw a histogram of both activation distributions.
        print_f1s: If True, print the probe scores and mean activations.

    Returns:
        The F1 score of the single-neuron logistic-regression probe.
    """
    english_acts = english_activations[layer][:, neuron]
    german_acts = german_activations[layer][:, neuron]
    mean_english_activation = english_acts.mean()
    mean_german_activation = german_acts.mean()

    if plot:
        haystack_utils.two_histogram(english_acts, german_acts, "English", "German", "Activation", "Frequency", f"L{layer}N{neuron} activations on English vs German text")
    # Forward the activations given to this function; previously the probe
    # silently fell back to run_single_neuron_lr's own default arguments.
    train_acc, test_acc, f1 = run_single_neuron_lr(
        layer,
        neuron,
        german_activations=german_activations,
        english_activations=english_activations,
    )
    if print_f1s:
        print(f"\nL{layer}N{neuron}: F1={f1:.2f}, Train acc={train_acc:.2f}, and test acc={test_acc:.2f}")
        print(f"Mean activation English={mean_english_activation:.2f}, German={mean_german_activation:.2f}")
    return f1

In [97]:
# Probe every candidate neuron on the unmodified activations and collect its F1
# together with a matching "L<layer>N<neuron>" tick label.
german_neuron_names = []
f1s = []
for layer, neuron, reported_f1 in german_neurons_with_f1:
    german_neuron_names.append(f"L{layer}N{neuron}")
    f1s.append(get_neuron_accuracy(layer, neuron, print_f1s=False))

haystack_utils.line(
    f1s,
    xlabel="",
    ylabel="F1 score of sparse probe",
    title="Sparse probe performance on individual German neurons",
    xticks=german_neuron_names,
    show_legend=False,
)

In [83]:
# Snapshot of each context neuron's original German activations, keyed by
# "<layer>_<neuron>". The .clone() is essential: tensor[:, neuron] is only a
# view, so without it the in-place overwrites below would also corrupt the
# snapshot and "restoring" a neuron would be a no-op.
# NOTE: run this cell on freshly computed german_activations; re-running it
# after the loop below would snapshot already-modified values.
enabled_context_neuron_acts = {str(layer) + '_' + str(neuron): german_activations[layer][:, neuron].clone() for layer, neuron, _ in german_neurons_with_f1}

def disable_other_context_neurons(german_activations, current_layer, current_neuron):
    """Leave one context neuron German-active and overwrite all the others.

    Modifies ``german_activations`` in place: the current neuron's column is
    restored from ``enabled_context_neuron_acts`` and every other listed
    neuron's column is overwritten with its English activations.

    Args:
        german_activations: Mapping layer -> German activation tensor.
        current_layer: Layer of the neuron to keep active.
        current_neuron: Index of the neuron to keep active.

    Returns:
        The same (mutated) ``german_activations`` mapping.
    """
    for layer, neuron, f1 in german_neurons_with_f1:
        if layer == current_layer and neuron == current_neuron:
            german_activations[layer][:, neuron] = enabled_context_neuron_acts[str(layer) + '_' + str(neuron)]
        else:
            # Not a perfect ablation: only the rows both languages have are
            # replaced, so any extra German rows keep their German values.
            # Slicing both sides also avoids a shape mismatch when German has
            # fewer rows than English.
            num_rows = min(german_activations[layer].shape[0], english_activations[layer].shape[0])
            german_activations[layer][:num_rows, neuron] = english_activations[layer][:num_rows, neuron]
    return german_activations


# Probe each candidate neuron again, this time with every other context neuron
# overwritten by its English activations.
german_neuron_names = []
f1s = []
for layer, neuron, _ in german_neurons_with_f1:
    german_neuron_names.append(f"L{layer}N{neuron}")
    german_activations = disable_other_context_neurons(german_activations, layer, neuron)
    isolated_f1 = get_neuron_accuracy(
        layer,
        neuron,
        german_activations=german_activations,
        english_activations=english_activations,
        print_f1s=False,
        plot=False,
    )
    f1s.append(isolated_f1)

haystack_utils.line(
    f1s,
    xlabel="",
    ylabel="F1 score of sparse probe",
    title="Sparse probe performance on individual German neurons",
    xticks=german_neuron_names,
    show_legend=False,
)

In [99]:
# Full ablation accuracy
def ablation_effect(fwd_hooks, batch_size=50, num_batches=4):
    """Print the German-text loss with and without the given forward hooks.

    Runs ``model`` on the first ``num_batches * batch_size`` entries of
    ``german_data``, once unmodified and once inside ``model.hooks(fwd_hooks)``,
    then prints both mean losses and the relative increase.

    Args:
        fwd_hooks: List of ``(hook_name, hook_fn)`` pairs passed to ``model.hooks``.
        batch_size: Number of examples per forward pass.
        num_batches: Number of consecutive batches to evaluate.
    """
    original_losses = []
    ablated_losses = []
    for i in range(num_batches):
        # Derive both slice bounds from batch_size (the upper bound used to be
        # a separate literal 50 that would drift if batch_size changed).
        batch = german_data[i * batch_size:(i + 1) * batch_size]
        original_losses.append(model(batch, return_type='loss'))
        with model.hooks(fwd_hooks):
            ablated_losses.append(model(batch, return_type='loss'))

    # Unweighted mean of the per-batch losses.
    original_loss = sum(original_losses) / len(original_losses)
    ablated_loss = sum(ablated_losses) / len(ablated_losses)

    print(original_loss, ablated_loss)
    # NOTE(review): ':2f' is a minimum-width spec, so six decimals are printed;
    # '.2f' may have been intended. Left unchanged to keep the output format.
    print(f'{(ablated_loss - original_loss) / original_loss * 100:2f}% loss increase')

In [100]:
# Baseline: apply all hooks in deactivate_neurons_fwd_hooks at once.
print("Full ablation:")
ablation_effect(deactivate_neurons_fwd_hooks)

def get_deactivate_neuron_hook(layer, neuron):
    """Build a one-element hook list that pins a single MLP neuron.

    The neuron's post-MLP activation is overwritten with its mean activation
    over ``english_activations`` (computed once, when the hook is created).

    Returns:
        ``[(hook_name, hook_fn)]`` so hook lists can be concatenated with ``+``.
    """
    hook_name = f'blocks.{layer}.mlp.hook_post'
    english_mean = english_activations[layer][:, neuron].mean()

    def deactivate_neuron_hook(value, hook):
        value[:, :, neuron] = english_mean
        return value

    return [(hook_name, deactivate_neuron_hook)]

# Ablate one context neuron at a time and report the resulting loss change.
for layer, neuron, f1 in german_neurons_with_f1:
    single_neuron_hooks = get_deactivate_neuron_hook(layer, neuron)
    print(f"Ablate L{layer}N{neuron} context neuron with f1 of {f1}:")
    ablation_effect(single_neuron_hooks)

Full ablation:
tensor(3.6835, device='cuda:0') tensor(4.0660, device='cuda:0')
10.382569% loss increase
Ablate L5N2649 context neuron with f1 of 1.0:
tensor(3.6835, device='cuda:0') tensor(3.7163, device='cuda:0')
0.890422% loss increase
Ablate L8N2994 context neuron with f1 of 1.0:
tensor(3.6835, device='cuda:0') tensor(3.8847, device='cuda:0')
5.460695% loss increase
Ablate L11N2911 context neuron with f1 of 0.99:
tensor(3.6835, device='cuda:0') tensor(3.6756, device='cuda:0')
-0.215374% loss increase
Ablate L10N1129 context neuron with f1 of 0.97:
tensor(3.6835, device='cuda:0') tensor(3.6798, device='cuda:0')
-0.101768% loss increase
Ablate L6N1838 context neuron with f1 of 0.65:
tensor(3.6835, device='cuda:0') tensor(3.6940, device='cuda:0')
0.283284% loss increase
Ablate L7N1594 context neuron with f1 of 0.65:
tensor(3.6835, device='cuda:0') tensor(3.6920, device='cuda:0')
0.228552% loss increase
Ablate L11N1819 context neuron with f1 of 0.61:
tensor(3.6835, device='cuda:0') tens

All context neurons and especially the L8 context neuron are in the output path of the L5 context neuron.
Most circuits are in the output path of the L8 context neuron.

In [98]:
# Ablate L5N2649 and L8N2994 together by concatenating their hook lists.
ablation_effect(get_deactivate_neuron_hook(5, 2649) + get_deactivate_neuron_hook(8, 2994))

9.150404% loss increase


Ablating them together accounts for most of the full-ablation loss increase, even though their individual ablation loss increases sum to less than this.

Perhaps there are circuits which rely on both neurons, somewhat like AND gates, although the two context neurons are not independent of each other.

In [None]:
# Reusable hook list that pins both L5N2649 and L8N2994 to their mean English activation.
deactivate_context_hooks = get_deactivate_neuron_hook(5, 2649) + get_deactivate_neuron_hook(8, 2994)