In [6]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import plotly.express as px

# Select plotly's "notebook_connected" and "notebook" renderers for figure display.
pio.renderers.default = "notebook_connected+notebook"
# Use the GPU when PyTorch reports one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Disable gradient tracking globally. `torch.autograd.set_grad_enabled` and
# `torch.set_grad_enabled` are the same callable, so one call is sufficient.
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [7]:
# Load Pythia-70m into a HookedTransformer with the weight-processing flags below.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

# Keep the first 200 examples from each Europarl file.
german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]

# MLP activations per language, keyed by layer index (requested with mean=False).
english_activations, german_activations = {}, {}
for layer in range(3, 4):  # the range contains only layer 3
    english_activations[layer] = get_mlp_activations(
        english_data, layer, model, mean=False
    )
    german_activations[layer] = get_mlp_activations(
        german_data, layer, model, mean=False
    )

# Layer and neuron index that the hooks in this notebook overwrite.
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]

# Mean activation of the selected neuron(s) on German ("active") and on
# English ("inactive") data; these are the values the hooks write.
MEAN_ACTIVATION_ACTIVE = (
    german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
)
MEAN_ACTIVATION_INACTIVE = (
    english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
)

def deactivate_neurons_hook(value, hook):
    """Forward hook that pins the selected neurons to their mean English activation.

    Writes MEAN_ACTIVATION_INACTIVE into ``value[:, :, NEURONS_TO_ABLATE]`` in
    place and returns the same tensor. ``hook`` is part of the hook signature
    but is not used.
    """
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value


# (hook point name, hook function) pairs for the post-MLP activations of LAYER_TO_ABLATE.
deactivate_neurons_fwd_hooks = [
    (f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook),
]

def activate_neurons_hook(value, hook):
    """Forward hook that pins the selected neurons to their mean German activation.

    Writes MEAN_ACTIVATION_ACTIVE into ``value[:, :, NEURONS_TO_ABLATE]`` in
    place and returns the same tensor. ``hook`` is part of the hook signature
    but is not used.
    """
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value


# (hook point name, hook function) pairs for the post-MLP activations of LAYER_TO_ABLATE.
activate_neurons_fwd_hooks = [
    (f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook),
]

# Two token collections from the project helper; `all_ignore` is used below to zero
# out entries of the token counts. NOTE(review): the selection criteria are defined
# in haystack_utils.get_weird_tokens and are not visible here — confirm there.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [8]:
# Get top common german tokens excluding punctuation
# Allocate the counts on `device` rather than with a hard-coded `.cuda()`, so the
# cell also runs on hosts without a GPU.
token_counts = torch.zeros(model.cfg.d_vocab, device=device)
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    # Histogram of the token ids in the first batch row, added in one vectorised
    # step instead of a Python loop calling `.item()` per token.
    token_counts += torch.bincount(tokens[0], minlength=model.cfg.d_vocab)

punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
# Column 1 of each tokenised string is taken as that string's token id
# (assumes to_tokens puts a BOS token in column 0 — TODO confirm).
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
# Exclude punctuation/whitespace tokens and the ignore list from the ranking.
token_counts[punctuation_tokens] = 0
token_counts[all_ignore] = 0

# The 100 most frequent remaining tokens, most frequent first.
top_counts, top_tokens = torch.topk(token_counts, 100)
print(model.to_str_tokens(top_tokens[:100]))

def get_random_selection(tensor, n=12):
    """Return up to ``n`` entries of ``tensor`` sampled without replacement.

    Hacky replacement for np.random.choice: shuffle all indices along the first
    dimension with ``torch.randperm`` and keep the first ``n`` of them.
    """
    shuffled_indices = torch.randperm(len(tensor))
    chosen = shuffled_indices[:n]
    return tensor[chosen]

def generate_random_prompts(end_string, n=50, length=12):
    """Generate a batch of random prompts ending with a specific ngram.

    Each prompt consists of ``length`` tokens sampled without replacement from
    ``top_tokens[:max(50, length)]``, followed by the tokens of ``end_string``.

    Args:
        end_string: Text whose tokens close every prompt; the first token
            returned by ``model.to_tokens`` is dropped.
        n: Number of prompts to generate.
        length: Number of randomly sampled tokens placed before the ending.

    Returns:
        A tensor of token ids with one row per prompt.
    """
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    candidate_tokens = top_tokens[:max(50, length)]
    prompts = [
        torch.cat([
            # Follow the device of `end_tokens` instead of a hard-coded `.cuda()`,
            # so prompt generation also works on CPU-only hosts.
            get_random_selection(candidate_tokens, n=length).to(end_tokens.device),
            end_tokens,
        ])
        for _ in range(n)
    ]
    return torch.stack(prompts)

# 100 prompts of 20 randomly sampled common tokens each, all ending in " Vorschlägen".
# NOTE(review): no random seed is set before sampling, so the prompts (and every
# number derived from them) differ between runs — consider torch.manual_seed.
prompts = generate_random_prompts(" Vorschlägen", n=100, length=20)

  0%|          | 0/200 [00:00<?, ?it/s]

[' der', 'en', ' die', ' und', 'ung', 'ä', ' in', ' den', ' des', ' zu', 'ch', 'n', 'st', 're', 'z', ' von', ' für', 'äsident', ' Pr', 'ischen', 't', 'ü', 'icht', 'gen', ' ist', ' auf', ' dass', 'ge', 'ig', ' im', 'in', ' über', 'g', ' das', 'te', ' er', 'men', ' w', 'es', ' an', 'ß', ' wir', ' eine', 'f', ' W', 'hen', 'w', ' Europ', ' ich', 'ungen', 'ren', 'le', ' dem', 'ten', ' ein', 'e', ' Z', ' Ver', 'der', ' B', ' mit', ' dies', 'h', ' nicht', 'ungs', 's', ' G', ' z', 'it', ' Herr', ' es', 'l', ' S', 'ich', 'lich', ' An', 'heit', 'ie', ' Er', ' zur', ' V', ' ver', 'u', 'hr', 'chaft', 'Der', ' Ich', ' Ab', ' haben', 'i', 'ant', 'chte', ' mö', 'er', ' K', 'igen', ' Ber', 'ür', ' Fra', 'em']


In [9]:
# Sanity check: show the token ids of the first generated prompt.
print(prompts[0])

tensor([19129,   275,  3150, 20150,   249, 10278, 12606,    91, 17079,  2827,
         9527,   296,  3090,    71,  1541,  8449,  2604,   348,   304,  1850,
          657, 34267, 42824,  1541], device='cuda:0')


In [10]:
# Run the prompts with the deactivation hooks applied and keep the resulting
# activation cache; it is the source of the per-neuron values patched in below.
with model.hooks(deactivate_neurons_fwd_hooks):
    _, ablated_cache = model.run_with_cache(prompts)

def get_ablate_neurons_hook(neuron: int | list[int], ablated_cache, layer=5):
    def ablate_neurons_hook(value, hook):
        value[:, :, neuron] = ablated_cache[f'blocks.{layer}.mlp.hook_post'][:, :, neuron]
        return value
    return [(f'blocks.{layer}.mlp.hook_post', ablate_neurons_hook)]

# Number of MLP neurons per layer, read from the model config instead of the
# hard-coded 2048 so the sweep stays correct if the model is swapped.
n_mlp_neurons = model.cfg.d_mlp
# One row per MLP neuron, one column per prompt.
diffs = torch.zeros(n_mlp_neurons, prompts.shape[0])
# Loss with path patched MLP5 neurons
_, _, _, baseline_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
for neuron in tqdm(range(n_mlp_neurons)):
    # Uses the helper's default layer (5), matching the "MLP5" comments.
    ablate_single_neuron_hook = get_ablate_neurons_hook(neuron, ablated_cache)
    # Loss with path patched MLP5 neurons but a single neuron changed back to original ablated value
    _, _, _, only_deactivated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_single_neuron_hook)
    # Row `neuron` holds the loss change caused by restoring only that neuron.
    diffs[neuron] = only_deactivated_loss - baseline_loss

print(diffs.mean())

  0%|          | 0/2048 [00:00<?, ?it/s]

tensor(0.0007)


In [11]:
# Average each neuron's loss change over the prompts, then sort ascending;
# `indices` maps each sorted position back to its neuron index.
sorted_means, indices = torch.sort(diffs.mean(dim=1))
sorted_means = sorted_means.tolist()
haystack_utils.line(
    sorted_means,
    xlabel="Sorted neurons",
    ylabel="Loss change",
    title="Loss change from ablating MLP5 neuron",
)  # xticks=indices
# Keep the ranking under a descriptive name for the cells that follow.
sorted_top_neuron_indices = indices

### Top/bottom neuron activation boxplot with and without context neuron ablated

In [12]:

# Layer-5 MLP activations on the German data, first with the activation hooks
# applied and then with the deactivation hooks applied.
with model.hooks(activate_neurons_fwd_hooks):
    enabled_acts = get_mlp_activations(german_data, 5, model, mean=False)
with model.hooks(deactivate_neurons_fwd_hooks):
    disabled_acts = get_mlp_activations(german_data, 5, model, mean=False)

# Decode every generated prompt back to text so the same helper can consume it.
prompt_strs = [model.tokenizer.decode(row.tolist()) for row in prompts]

# Same two conditions, measured on the generated prompts.
with model.hooks(activate_neurons_fwd_hooks):
    enabled_acts_gen = get_mlp_activations(prompt_strs, 5, model, mean=False)
with model.hooks(deactivate_neurons_fwd_hooks):
    disabled_acts_gen = get_mlp_activations(prompt_strs, 5, model, mean=False)

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [18]:
# Use every third row of the activations for plotting.
skip_indices = torch.arange(0, enabled_acts.shape[0], 3)
# The ranking is ascending, so the last three entries have the largest loss
# change and the first three the smallest.
top_neurons = sorted_top_neuron_indices[-3:]
bottom_neurons = sorted_top_neuron_indices[:3]

# One column per top neuron, labelled with its index.
df = pd.DataFrame({
    f'{neuron_idx}': enabled_acts[skip_indices, neuron_idx].cpu().numpy()
    for neuron_idx in top_neurons
})
df_melted = df.melt(var_name='Neuron', value_name='Activation value')
fig = px.box(
    df_melted,
    x='Neuron',
    y='Activation value',
    title='Top neuron activation values with context neuron enabled',
)
fig.update_layout(xaxis={"showticklabels": False})
fig.show()

# df = pd.DataFrame({f'{i}': disabled_acts[skip_indices, i].cpu().numpy() for i in top_neurons})
# df_melted = df.melt(var_name='Neuron', value_name='Activation value')
# fig = px.box(df_melted, x='Neuron', y='Activation value', title='Top neuron activation values with context neuron disabled')
# fig.update_layout(
#     xaxis=dict(showticklabels=False),
#     yaxis=dict(range=[-1, 7])
# )
# fig.show()

# df = pd.DataFrame({f'{i}': enabled_acts[skip_indices, i].cpu().numpy() for i in bottom_neurons})
# df_melted = df.melt(var_name='Neuron', value_name='Activation value')
# fig = px.box(df_melted, x='Neuron', y='Activation value', title='Bottom neuron activation values with context neuron enabled')
# fig.update_layout(
#     xaxis=dict(showticklabels=False),
#     yaxis=dict(range=[-1, 7])
# )
# fig.show()

# df = pd.DataFrame({f'{i}': disabled_acts[skip_indices, i].cpu().numpy() for i in bottom_neurons})
# df_melted = df.melt(var_name='Neuron', value_name='Activation value')
# fig = px.box(df_melted, x='Neuron', y='Activation value', title='Bottom neuron activation values with context neuron disabled')
# fig.update_layout(xaxis=dict(showticklabels=False))
# fig.show()


In [None]:

# print(len(english_data))
# Take 100 neurons from each end of the ascending loss-change ranking.
top_neurons = sorted_top_neuron_indices[-100:]
bottom_neurons = sorted_top_neuron_indices[:100]
# Layer-5 MLP activations on the first 100 English examples, with no hooks applied.
# NOTE(review): in the cells visible here `english_acts` is only referenced by
# commented-out code — confirm whether it is still needed.
english_acts = get_mlp_activations(english_data[:100], 5, model, mean=False)
# print(english_acts.shape)
# print(top_neurons)
# english_skip_indices = torch.arange(0, english_acts.shape[0], 3)

# df = pd.DataFrame({f'{i}': english_acts[english_skip_indices, i].cpu().numpy() for i in top_neurons})

# df_melted = df.melt(var_name='Neuron', value_name='Activation value')
# fig = px.box(df_melted, x='Neuron', y='Activation value', title='Top neuron activation values on English data')
# fig.update_layout(
#     xaxis=dict(showticklabels=False)
# )
# fig.show()


# Layer-5 MLP activations on the English data, first with the activation hooks
# applied and then with the deactivation hooks applied.
with model.hooks(activate_neurons_fwd_hooks):
    english_enabled_acts = get_mlp_activations(english_data[:200], 5, model, mean=False)
with model.hooks(deactivate_neurons_fwd_hooks):
    english_disabled_acts = get_mlp_activations(english_data[:200], 5, model, mean=False)

# Use every third row; the row count is taken from the enabled run.
english_skip_indices = torch.arange(0, english_enabled_acts.shape[0], 3)

# One column per top neuron, filled from the *disabled* run.
df = pd.DataFrame({
    f'{neuron_idx}': english_disabled_acts[english_skip_indices, neuron_idx].cpu().numpy()
    for neuron_idx in top_neurons
})

df_melted = df.melt(var_name='Neuron', value_name='Activation value')
fig = px.box(
    df_melted,
    x='Neuron',
    y='Activation value',
    title='Top neuron activation values on English data',
)
fig.update_layout(xaxis={"showticklabels": False})
fig.show()

# df = pd.DataFrame({f'{i}': english_disabled_acts[english_skip_indices, i].cpu().numpy() for i in top_neurons})

# df_melted = df.melt(var_name='Neuron', value_name='Activation value')
# fig = px.box(df_melted, x='Neuron', y='Activation value', title='Top neuron activation values on English data')
# fig.update_layout(
#     xaxis=dict(showticklabels=False)
# )
# fig.show()

In [None]:
# Use every third row; the row count is taken from the enabled run.
english_skip_indices = torch.arange(0, english_enabled_acts.shape[0], 3)

# One column per top neuron, filled from the run with the deactivation hooks applied.
df = pd.DataFrame({
    f'{neuron_idx}': english_disabled_acts[english_skip_indices, neuron_idx].cpu().numpy()
    for neuron_idx in top_neurons
})

df_melted = df.melt(var_name='Neuron', value_name='Activation value')
fig = px.box(
    df_melted,
    x='Neuron',
    y='Activation value',
    title='Top neuron activation values on English data',
)
fig.update_layout(xaxis={"showticklabels": False})
fig.show()