In [14]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import plotly.express as px
from torchmetrics.regression import KendallRankCorrCoef

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [4]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]


english_activations = {}
german_activations = {}
for layer in range(3, 4):
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)

LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

def activate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value
activate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Downloading (…)lve/main/config.json: 100%|██████████| 567/567 [00:00<00:00, 2.06MB/s]
Downloading model.safetensors: 100%|██████████| 166M/166M [00:02<00:00, 67.0MB/s] 
Downloading (…)okenizer_config.json: 100%|██████████| 396/396 [00:00<00:00, 907kB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 2.11M/2.11M [00:00<00:00, 7.93MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 99.0/99.0 [00:00<00:00, 235kB/s]
Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


100%|██████████| 200/200 [00:03<00:00, 53.71it/s]
100%|██████████| 200/200 [00:02<00:00, 86.84it/s]


In [5]:
# Get top common german tokens excluding punctuation
token_counts = torch.zeros(model.cfg.d_vocab).cuda()
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    for token in tokens[0]:
        token_counts[token.item()] += 1

punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
token_counts[punctuation_tokens] = 0
token_counts[all_ignore] = 0

top_counts, top_tokens = torch.topk(token_counts, 100)

def get_random_selection(tensor, n=12):
    # Hacky replacement for np.random.choice
    return tensor[torch.randperm(len(tensor))[:n]]

def generate_random_prompts(end_string, n=50, length=12):
    # Generate a batch of random prompts ending with a specific ngram
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    prompts = []
    for i in range(n):
        prompt = get_random_selection(top_tokens[:max(50, length)], n=length).cuda()
        prompt = torch.cat([prompt, end_tokens])
        prompts.append(prompt)
    prompts = torch.stack(prompts)
    return prompts

prompts = generate_random_prompts(" Vorschlägen", n=100, length=20)

100%|██████████| 200/200 [00:02<00:00, 78.29it/s]

[' der', 'en', ' die', ' und', 'ung', 'ä', ' in', ' den', ' des', ' zu', 'ch', 'n', 'st', 're', 'z', ' von', ' für', 'äsident', ' Pr', 'ischen', 't', 'ü', 'icht', 'gen', ' ist', ' auf', ' dass', 'ge', 'ig', ' im', 'in', ' über', 'g', ' das', 'te', ' er', 'men', ' w', 'es', ' an', 'ß', ' wir', ' eine', 'f', ' W', 'hen', 'w', ' Europ', ' ich', 'ungen', 'ren', 'le', ' dem', 'ten', ' ein', 'e', ' Z', ' Ver', 'der', ' B', ' mit', ' dies', 'h', ' nicht', 'ungs', 's', ' G', ' z', 'it', ' Herr', ' es', 'l', ' S', 'ich', 'lich', ' An', 'heit', 'ie', ' Er', ' zur', ' V', ' ver', 'u', 'hr', 'chaft', 'Der', ' Ich', ' Ab', ' haben', 'i', 'ant', 'chte', ' mö', 'er', ' K', 'igen', ' Ber', 'ür', ' Fra', 'em']





In [9]:
with model.hooks(deactivate_neurons_fwd_hooks):
    _, ablated_cache = model.run_with_cache(prompts)

def get_ablate_neurons_hook(neuron: int | list[int], ablated_cache, layer=5):
    def ablate_neurons_hook(value, hook):
        value[:, :, neuron] = ablated_cache[f'blocks.{layer}.mlp.hook_post'][:, :, neuron]
        return value
    return [(f'blocks.{layer}.mlp.hook_post', ablate_neurons_hook)]

diffs = torch.zeros(2048, prompts.shape[0])
# taus = torch.zeros
# Loss with path patched MLP5 neurons
_, _, _, baseline_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
for neuron in tqdm(range(2048)):
    ablate_single_neuron_hook = get_ablate_neurons_hook(neuron, ablated_cache)
    # Loss with path patched MLP5 neurons but a single neuron changed back to original ablated value
    _, _, _, only_deactivated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_single_neuron_hook)
    diffs[neuron] = only_deactivated_loss - baseline_loss

print(diffs.mean())

100%|██████████| 2048/2048 [02:46<00:00, 12.31it/s]

tensor(0.0007)





In [10]:
sorted_means, indices = torch.sort(diffs.mean(1))
sorted_means = sorted_means.tolist()
# haystack_utils.line(sorted_means, xlabel="Sorted neurons", ylabel="Loss change", title="Loss change from ablating MLP5 neuron") # xticks=indices
sorted_top_neuron_indices = indices

In [13]:
prompt_strs = [model.tokenizer.decode(prompts[i].tolist()) for i in range(prompts.shape[0])]

with model.hooks(activate_neurons_fwd_hooks):
    enabled_acts_gen = get_mlp_activations(prompt_strs, 5, model, mean=False, pos=-2)

100%|██████████| 100/100 [00:01<00:00, 71.10it/s]


### Top neuron differences between prompts

In [38]:
import numpy as np

In [39]:
# Are the same top neurons activated regardless of prior prompt? 

# If we look at top neuron activations, there's some similarities between the lists but they're not entirely consistent between prompts
# For 100 prompts, take each prompt's top neurons and take the kendall's tau between each pair of prompt top neurons to get the average ordinal similarity

neuron_cutoff = 20
num_prompts = diffs.shape[1]

top_neurons_by_prompt = torch.zeros(num_prompts, neuron_cutoff)
bottom_neurons_by_prompt = torch.zeros(num_prompts, neuron_cutoff)
for prompt_index in range(num_prompts):
    sorted_neurons, indices = torch.sort(diffs[:, prompt_index])
    sorted_neurons = sorted_neurons.tolist()
    top_neurons_by_prompt[prompt_index] = indices[-neuron_cutoff:]
    bottom_neurons_by_prompt[prompt_index] = indices[:neuron_cutoff]

kendall = KendallRankCorrCoef()
taus = torch.zeros(num_prompts * num_prompts - num_prompts)
for i in range(num_prompts):
    for j in range(num_prompts):
        if i == j:
            continue
        taus[i * (num_prompts - 1) + j] = kendall(top_neurons_by_prompt[i], top_neurons_by_prompt[j])

In [41]:
print(taus.mean().item())

px.histogram(taus.cpu().numpy(), title="taus")

0.04469430819153786


In [37]:
# Baseline tau to compare: random list of neurons?
baseline_taus = torch.zeros(30)
for i in range(30):
    baseline_taus[i] = kendall(torch.from_numpy(np.random.randint(0, 2048, (neuron_cutoff,))), torch.from_numpy(np.random.randint(0, 2048, (neuron_cutoff,))))

print(baseline_taus.mean().item())

-0.004108896013349295


In [34]:


print(kendall(top_neurons_by_prompt[i], top_neurons_by_prompt[j]))

0.04469430819153786
