In [2]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import plotly.express as px
from torchmetrics.regression import KendallRankCorrCoef, SpearmanCorrCoef

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [3]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]


english_activations = {}
german_activations = {}
for layer in range(3, 4):
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)

LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

def activate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value
activate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [4]:
# Get top common german tokens excluding punctuation
token_counts = torch.zeros(model.cfg.d_vocab).cuda()
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    for token in tokens[0]:
        token_counts[token.item()] += 1

punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
token_counts[punctuation_tokens] = 0
token_counts[all_ignore] = 0

top_counts, top_tokens = torch.topk(token_counts, 100)

def get_random_selection(tensor, n=12):
    # Hacky replacement for np.random.choice
    return tensor[torch.randperm(len(tensor))[:n]]

def generate_random_prompts(end_string, n=50, length=12):
    # Generate a batch of random prompts ending with a specific ngram
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    prompts = []
    for i in range(n):
        prompt = get_random_selection(top_tokens[:max(50, length)], n=length).cuda()
        prompt = torch.cat([prompt, end_tokens])
        prompts.append(prompt)
    prompts = torch.stack(prompts)
    return prompts

prompts = generate_random_prompts(" Vorschlägen", n=100, length=20)

  0%|          | 0/200 [00:00<?, ?it/s]

In [5]:
with model.hooks(deactivate_neurons_fwd_hooks):
    _, ablated_cache = model.run_with_cache(prompts)

def get_ablate_neurons_hook(neuron: int | list[int], ablated_cache, layer=5):
    def ablate_neurons_hook(value, hook):
        value[:, :, neuron] = ablated_cache[f'blocks.{layer}.mlp.hook_post'][:, :, neuron]
        return value
    return [(f'blocks.{layer}.mlp.hook_post', ablate_neurons_hook)]

diffs = torch.zeros(2048, prompts.shape[0])
# Loss with path patched MLP5 neurons
_, _, _, baseline_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
for neuron in tqdm(range(2048)):
    ablate_single_neuron_hook = get_ablate_neurons_hook(neuron, ablated_cache)
    # Loss with path patched MLP5 neurons but a single neuron changed back to original ablated value
    _, _, _, only_deactivated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_single_neuron_hook)
    diffs[neuron] = only_deactivated_loss - baseline_loss

  0%|          | 0/2048 [00:00<?, ?it/s]

### Top neuron differences between prompts

In [None]:
average_losses, average_indices = torch.sort(diffs.mean(1))
print(average_indices.shape)

torch.Size([2048])


In [None]:
# Are the same top neurons activated regardless of prior prompt? 

# If we look at top neuron activations, there's some similarities between the lists but they're not entirely consistent between prompts
# For 100 prompts, take each prompt's top neurons and take the kendall's tau between each pair of prompt top neurons to get the average ordinal similarity

neuron_cutoff = 50
num_prompts = diffs.shape[1]

In [None]:
top_neurons_by_prompt = torch.zeros(num_prompts, neuron_cutoff)
bottom_neurons_by_prompt = torch.zeros(num_prompts, neuron_cutoff)
for prompt_index in range(num_prompts):
    losses, indices = torch.sort(diffs[:, prompt_index])
    top_neurons_by_prompt[prompt_index] = indices[-neuron_cutoff:]
    bottom_neurons_by_prompt[prompt_index] = indices[:neuron_cutoff]

kendall = KendallRankCorrCoef()

taus = torch.zeros(num_prompts)
for i in range(num_prompts):
    taus[i] = kendall(top_neurons_by_prompt[i], average_indices[-neuron_cutoff:])

In [None]:
# there's a fair amount of overlap in the very top neurons
for prompt_index in range(5):
    _, indices = torch.sort(diffs[:, prompt_index])
    print(indices[-5:])

tensor([  84,  255,  216,  905, 1268])
tensor([1510,  395, 1268,   84,  255])
tensor([ 395, 1510,  213,   84,  255])
tensor([1709,  255, 1510,  213,   84])
tensor([ 213,  255, 1709,   84,  905])


In [None]:
print(taus.mean().item())
print(taus.shape)
px.histogram(taus.cpu().numpy(), title="taus")

0.024457141757011414
torch.Size([100])


If the orderings were independent we would see the results clustered around 0. If the orderings are similar, as I expected, we would see a cluster around a positive mean.

What we see is a mean that's very slightly positive, but very close to 0, and a very wide spread compared with the baseline/random measurements.

Is the true mean 0, and there's a lot of noise injected by some process? 
Or is the true mean 0.04 - if this is the case why is there so much variance?

I suspect the first option is correct

In [None]:
# Baseline tau to compare: random list of neurons?
baseline_taus = torch.zeros(100)
for i in range(30):
    baseline_taus[i] = kendall(torch.from_numpy(np.random.randint(0, 2048, (neuron_cutoff,))), torch.from_numpy(np.random.randint(0, 2048, (neuron_cutoff,))))

px.histogram(baseline_taus.cpu().numpy(), title="taus")

In [None]:
neuron_cutoff = 2048

# Take the average neuron helpfulness ranking. Get the average loss for each neuron (this is our model). For a single prompt, take the average helpfulness of these
# neurons. Compare whether the helpfulness of the neurons in the prompts increases monotonically in the same order as the model neurons, and to what extent.
spearman = SpearmanCorrCoef()
phos = torch.zeros(num_prompts)
for i in range(num_prompts):
    phos[i] = spearman(diffs[:, i][average_indices[-neuron_cutoff:]], average_losses[-neuron_cutoff:])

print(phos.mean().item())
print(phos.shape)
px.histogram(phos.cpu().numpy(), title="Spearman's rank correlation coefficient <br>Mean neuron importances compared with prompt neuron importances")

NameError: name 'SpearmanCorrCoef' is not defined

Baseline: random prompts

In [None]:
diffs = torch.zeros(2048, prompts.shape[0])
# Loss with path patched MLP5 neurons
_, _, _, baseline_loss = haystack_utils.get_direct_effect(german_data[:100], model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
for neuron in tqdm(range(2048)):
    ablate_single_neuron_hook = get_ablate_neurons_hook(neuron, ablated_cache)
    # Loss with path patched MLP5 neurons but a single neuron changed back to original ablated value
    _, _, _, only_deactivated_loss = haystack_utils.get_direct_effect(german_data[:100], model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_single_neuron_hook)
    diffs[neuron] = only_deactivated_loss - baseline_loss