## Setup

In [2]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
# This notebook only runs forward passes, so turn autograd off globally.
# torch.set_grad_enabled and torch.autograd.set_grad_enabled are the same switch; one call is enough.
torch.set_grad_enabled(False)
# Seed the global RNG so the random prompts drawn in later cells are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [3]:
# Ablation target: neuron(s) in the MLP of this layer. Defined first so the activation
# collection below is driven by the same constant instead of a separate hard-coded range.
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
# Number of examples per language used to estimate the neuron's mean activation.
N_ACTIVATION_EXAMPLES = 200

model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")

# Only the ablated layer is needed; the dicts stay keyed by layer so other layers can be added.
english_activations = {}
german_activations = {}
for layer in [LAYER_TO_ABLATE]:
    english_activations[layer] = get_mlp_activations(english_data[:N_ACTIVATION_EXAMPLES], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:N_ACTIVATION_EXAMPLES], layer, model, mean=False)

# Scalar means used as the "on" (German) and "off" (English) values for the target neuron(s).
# NOTE(review): activations are presumably (n_tokens, d_mlp); .mean() also averages across all
# selected neurons, so with several neurons they would all share one value — confirm intent.
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    """Forward hook: pin the target neurons to their mean English-text activation.

    Writes MEAN_ACTIVATION_INACTIVE into `value[:, :, NEURONS_TO_ABLATE]` in place
    (every batch element and position) and returns the same tensor.
    """
    off_value = MEAN_ACTIVATION_INACTIVE
    value[:, :, NEURONS_TO_ABLATE] = off_value
    return value
deactivate_neurons_fwd_hooks = [
    (f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook),
]

def activate_neurons_hook(value, hook):
    """Forward hook: pin the target neurons to their mean German-text activation.

    Writes MEAN_ACTIVATION_ACTIVE into `value[:, :, NEURONS_TO_ABLATE]` in place
    (every batch element and position) and returns the same tensor.
    """
    on_value = MEAN_ACTIVATION_ACTIVE
    value[:, :, NEURONS_TO_ABLATE] = on_value
    return value
activate_neurons_fwd_hooks = [
    (f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook),
]

# Token ids to leave out of the analyses below (`all_ignore` is zeroed in the token counts and
# excluded from top-k selections). `not_ignore` is presumably the complement — TODO confirm in haystack_utils.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

## Find top common German tokens

In [None]:
# Get top common german tokens excluding punctuation
# Histogram of token ids over the German corpus. Tensors live on `device` (not a hard-coded
# .cuda()) so the notebook also runs on CPU. to_tokens prepends a BOS token, which is counted too.
token_counts = torch.zeros(model.cfg.d_vocab, device=device)
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    # Vectorised count instead of a per-token Python loop with a .item() sync per token
    token_counts += torch.bincount(tokens[0], minlength=model.cfg.d_vocab)

punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
# Column 1 = first token after BOS for each punctuation / whitespace / special string
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
token_counts[punctuation_tokens] = 0
token_counts[all_ignore] = 0

N_TOP_TOKENS = 100
top_counts, top_tokens = torch.topk(token_counts, N_TOP_TOKENS)
print(model.to_str_tokens(top_tokens))

  0%|          | 0/2000 [00:00<?, ?it/s]

[' der', 'en', ' die', ' und', 'ung', 'ä', ' in', ' den', ' des', 'ch', 'st', ' zu', 're', ' für', 'äsident', ' Pr', 'n', 'z', 'ischen', ' von', 'ü', 't', 'icht', 'in', 'ge', 'gen', 'te', ' ist', ' auf', 'ig', ' über', ' dass', ' im', 'f', ' er', 'es', ' das', 'men', 'g', 'ß', ' Europ', ' w', 'w', 'le', 'ten', ' eine', ' wir', ' ein', ' an', 'hen', 'ren', 'e', ' ich', 'ungen', ' W', ' Ver', ' B', ' dem', ' mit', ' dies', ' nicht', ' Z', 'h', ' z', 's', 'it', 'hr', ' es', ' zur', ' An', ' Herr', 'ich', 'heit', 'b', 'lich', 'l', ' ver', ' S', 'i', ' G', 'Der', ' V', 'der', 'u', 'ie', ' Ab', 'ungs', 'chte', 'chaft', 'igen', ' werden', 'uss', 'ord', 'em', ' Ber', 'ür', ' haben', 'et', ' um', ' Ich']


## Analysis of ngrams preceded by random prompts

In [None]:
def get_random_selection(tensor, n=12, generator=None):
    """Return up to `n` distinct rows of `tensor` (along dim 0), chosen uniformly at random.

    A stand-in for ``np.random.choice(..., replace=False)``: sampling is without
    replacement, so the result has ``min(n, len(tensor))`` rows.

    Args:
        tensor: tensor to sample rows from.
        n: number of rows to draw.
        generator: optional CPU ``torch.Generator`` for reproducible draws; when None the
            global torch RNG is used (the original behaviour).
    """
    permutation = torch.randperm(len(tensor), generator=generator)
    return tensor[permutation[:n]]

def generate_random_prompts(end_string, n=50, length=12):
    """Build `n` prompts of random common German tokens, each ending with the tokens of `end_string`.

    Each prompt's random prefix is drawn without replacement from the first
    ``max(50, length)`` entries of the global ``top_tokens``. The first token returned by
    ``model.to_tokens(end_string)`` (BOS) is dropped, so prompts carry no BOS token.

    Returns:
        Int tensor of shape (n, length + n_end_tokens) when ``len(top_tokens) >= length``,
        on the same device as the model's tokens.
    """
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    token_pool = top_tokens[:max(50, length)]
    prompts = [
        # Follow end_tokens' device instead of hard-coding .cuda(), so CPU runs work too
        torch.cat([get_random_selection(token_pool, n=length).to(end_tokens.device), end_tokens])
        for _ in range(n)
    ]
    return torch.stack(prompts)

def replace_column(prompts: Int[Tensor, "n_prompts n_tokens"], token_index: int):
    """Return a copy of `prompts` with position `token_index` replaced by random common German tokens.

    Each prompt gets its own token from the first ``max(50, n_prompts)`` entries of the global
    ``top_tokens``. Tokens are distinct across prompts whenever the pool is large enough;
    otherwise (more prompts than pool tokens) they are drawn with replacement, so every row still
    gets a token instead of failing with a shape mismatch.
    """
    n_prompts = prompts.shape[0]
    token_pool = top_tokens[:max(50, n_prompts)]
    if n_prompts <= len(token_pool):
        random_tokens = get_random_selection(token_pool, n=n_prompts)
    else:
        # top_tokens is finite, so sampling without replacement cannot fill every row
        random_tokens = token_pool[torch.randint(len(token_pool), (n_prompts,))]
    new_prompts = prompts.clone()
    # Match the prompts' device instead of hard-coding .cuda()
    new_prompts[:, token_index] = random_tokens.to(prompts.device)
    return new_prompts

def loss_analysis(prompts: Tensor, title=""):
    """Bar-plot last-token (pos=-1) loss for a batch of prompts under three conditions.

    The bars are the 1st, 2nd and 4th outputs of ``haystack_utils.get_direct_effect``, run with
    the context-neuron deactivation/activation hooks, labelled "Original", "Ablated" and
    "MLP5 path patched".
    """
    original_loss, ablated_loss, _, only_activated_loss = haystack_utils.get_direct_effect(
        prompts,
        model,
        pos=-1,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks,
    )
    bar_values = [loss.tolist() for loss in (original_loss, ablated_loss, only_activated_loss)]
    bar_names = ["Original", "Ablated", "MLP5 path patched"]
    haystack_utils.plot_barplot(bar_values, bar_names, ylabel="Loss", title=title)

In [None]:
def loss_analysis_random_prompts(end_string, n=50, length=12, replace_columns: list[int] | None = None):
    """Plot last-token loss on `n` random prompts of `length` tokens ending in `end_string`.

    If `replace_columns` is given, those positions are re-randomised in every prompt before
    plotting, and the original tokens at those positions (taken from the first prompt) are
    named in the title.
    """
    prompts = generate_random_prompts(end_string, n=n, length=length)
    title = f"Average last token loss on {length} random tokens ending in '{end_string}'"
    if replace_columns is None:
        loss_analysis(prompts, title=title)
        return

    # Capture the tokens being replaced before they are overwritten
    replaced_tokens = model.to_str_tokens(prompts[0, replace_columns])
    title += f" replacing {replaced_tokens}"
    for column in replace_columns:
        prompts = replace_column(prompts, column)
    loss_analysis(prompts, title=title)

In [None]:
# Example usage
# loss_analysis_random_prompts(" Vorschlägen", n=100, length=20)
# loss_analysis_random_prompts(" Vorschlägen", n=100, length=20, replace_columns=[-2])

## Widget

In [None]:
# Selector for the ngram that ends every random prompt
dropdown = widgets.Dropdown(
    options=[" Vorschlägen", " Vorschläge", " häufig", " schließt", " beweglich"],
    value=" Vorschlägen",
    description='Ngram: ',
)

# Prompt positions (negative offsets) to re-randomise; 'None' leaves prompts untouched
replace_columns_dropdown = widgets.SelectMultiple(
    options=['-2', '-3', '-4', 'None'],
    value=('None',),
    description='Replace Columns:',
)

# Area that holds the current plot
output = widgets.Output()

# Define a function to call when the widget's value changes
def update_plot(*args):
    """Redraw the loss plot from the current widget selections.

    Used both as the ``observe`` callback (the change payload lands in `*args` and is ignored)
    and for the initial draw.
    """
    selected = replace_columns_dropdown.value
    with output:
        # Replace the previous plot once the new one is ready
        clear_output(wait=True)

        # An empty selection behaves like 'None'; passing [] would replace nothing
        # yet still label the plot as "replacing []".
        if not selected or 'None' in selected:
            replace_columns = None
        else:
            replace_columns = [int(val) for val in selected]

        loss_analysis_random_prompts(dropdown.value, n=100, length=20, replace_columns=replace_columns)

# Re-plot whenever either selection changes
dropdown.observe(update_plot, names='value')
replace_columns_dropdown.observe(update_plot, names='value')

# Show the controls above the plot area, then draw the initial plot
display(dropdown, replace_columns_dropdown, output)
update_plot()


Dropdown(description='Ngram: ', options=(' Vorschlägen', ' Vorschläge', ' häufig', ' schließt', ' beweglich'),…

SelectMultiple(description='Replace Columns:', index=(3,), options=('-2', '-3', '-4', 'None'), value=('None',)…

Output()

## Neuron level analysis

In [None]:
# 100 prompts: 20 random common German tokens followed by the tokens of " Vorschlägen"
prompts = generate_random_prompts(" Vorschlägen", length=20, n=100)
print(prompts.shape)

torch.Size([100, 24])


In [None]:
# Calculate neuron-wise loss change.
# First cache all activations from a run with the context neuron deactivated; the per-neuron
# hooks below patch values from this cache.
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    _, ablated_cache = model.run_with_cache(prompts)

def get_ablate_neurons_hook(neuron: int | list[int], ablated_cache):
    def ablate_neurons_hook(value, hook):
        value[:, :, neuron] = ablated_cache['blocks.5.mlp.hook_post'][:, :, neuron]
        return value
    return [('blocks.5.mlp.hook_post', ablate_neurons_hook)]

# Shared get_direct_effect arguments for every call in this cell.
# NOTE(review): the comments call this "path patched MLP5", but the tuples activate
# blocks.4.hook_mlp_out and deactivate blocks.5.hook_mlp_out — confirm against get_direct_effect.
path_patch_kwargs = dict(
    pos=-1,
    context_ablation_hooks=deactivate_neurons_fwd_hooks,
    deactivated_components=("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.5.hook_mlp_out"),
    activated_components=("blocks.4.hook_mlp_out",),
)

# One row per MLP neuron (model width instead of a hard-coded 2048), one column per prompt
d_mlp = model.cfg.d_mlp
diffs = torch.zeros(d_mlp, prompts.shape[0])
# Loss with path patched MLP5 neurons
_, _, _, baseline_loss = haystack_utils.get_direct_effect(
    prompts, model, context_activation_hooks=activate_neurons_fwd_hooks, **path_patch_kwargs
)
for neuron in tqdm(range(d_mlp)):
    ablate_single_neuron_hook = get_ablate_neurons_hook(neuron, ablated_cache)
    # Loss with path patched MLP5 neurons but a single neuron changed back to original ablated value
    _, _, _, only_activated_loss = haystack_utils.get_direct_effect(
        prompts,
        model,
        context_activation_hooks=activate_neurons_fwd_hooks + ablate_single_neuron_hook,
        **path_patch_kwargs,
    )
    diffs[neuron] = only_activated_loss - baseline_loss

print(diffs.mean())

  0%|          | 0/2048 [00:00<?, ?it/s]

tensor(0.0007)


In [None]:
# Rank neurons by their mean (over prompts) loss change, ascending; `indices` holds neuron ids
sorted_means, indices = diffs.mean(1).sort()
sorted_means = sorted_means.tolist()
haystack_utils.line(
    sorted_means,
    xlabel="Neuron",
    ylabel="Loss change",
    title="Loss change from ablating MLP5 neuron",
)  # pass xticks=indices to label points by neuron id

In [None]:
# Check loss change when reverting the top / bottom `top_neurons_count` MLP5 neurons (ranked by
# mean loss change above) to their values from the context-neuron-ablated run.

top_neurons_count = 10
top_neurons = indices[-top_neurons_count:]
bottom_neurons = indices[:top_neurons_count]

# Only the cache is needed here; this run's loss used to be assigned and then immediately overwritten.
with model.hooks(deactivate_neurons_fwd_hooks):
    _, ablated_cache = model.run_with_cache(prompts)

ablate_top_neurons_hook = get_ablate_neurons_hook(top_neurons, ablated_cache)
ablate_bottom_neurons_hook = get_ablate_neurons_hook(bottom_neurons, ablated_cache)

# Shared get_direct_effect arguments for the three runs below.
# NOTE(review): variables are named *_MLP4_* while the plot labels say MLP5 — confirm which is intended.
path_patch_kwargs = dict(
    pos=-1,
    context_ablation_hooks=deactivate_neurons_fwd_hooks,
    deactivated_components=("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.5.hook_mlp_out"),
    activated_components=("blocks.4.hook_mlp_out",),
)
original_loss, ablated_loss, _, all_MLP4_loss = haystack_utils.get_direct_effect(
    prompts, model, context_activation_hooks=activate_neurons_fwd_hooks, **path_patch_kwargs
)
_, _, _, top_MLP4_ablated_loss = haystack_utils.get_direct_effect(
    prompts, model, context_activation_hooks=activate_neurons_fwd_hooks + ablate_top_neurons_hook, **path_patch_kwargs
)
_, _, _, bottom_MLP4_ablated_loss = haystack_utils.get_direct_effect(
    prompts, model, context_activation_hooks=activate_neurons_fwd_hooks + ablate_bottom_neurons_hook, **path_patch_kwargs
)


names = ["Original", "Ablated", "MLP5 path patched", f"MLP5 path patched + Top {top_neurons_count} MLP5 neurons ablated", f"MLP5 path patched + Bottom {top_neurons_count} MLP5 neurons ablated"]
short_names = ["Original", "Ablated", "MLP5 path patched", "Top MLP5 ablated", "Bottom MLP5 ablated"]

values = [original_loss.tolist(), ablated_loss.tolist(), all_MLP4_loss.tolist(), top_MLP4_ablated_loss.tolist(), bottom_MLP4_ablated_loss.tolist()]
haystack_utils.plot_barplot(values, names, short_names=short_names, ylabel="Loss", title="Average last token loss for different MLP5 neuron sets")

Hypothesis: there is a second token the model likes, and it is represented by the bottom neurons boosting something other than the correct token.

Question: would it generate the same bar plot for that token, i.e. does the result depend on the context neuron?

In [None]:
# Patch in MLP5 neurons, ablated bottom neurons, compare with patched MLP5 neurons without ablating bottom neurons

with model.hooks(deactivate_neurons_fwd_hooks):
    ablated_logits, ablated_cache = model.run_with_cache(prompts)

ablate_top_neurons_hook = get_ablate_neurons_hook(top_neurons, ablated_cache)
ablate_bottom_neurons_hook = get_ablate_neurons_hook(bottom_neurons, ablated_cache)

# Log probs at pos=-2 (return_type='logprobs'); the first output is named *_logits for
# historical reasons.
logprob_kwargs = dict(pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, return_type='logprobs')
original_logits, ablated_logprobs, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(
    prompts, model, context_activation_hooks=activate_neurons_fwd_hooks, **logprob_kwargs)
_, _, _, top_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(
    prompts, model, context_activation_hooks=activate_neurons_fwd_hooks + ablate_top_neurons_hook, **logprob_kwargs)
_, _, _, bottom_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(
    prompts, model, context_activation_hooks=activate_neurons_fwd_hooks + ablate_bottom_neurons_hook, **logprob_kwargs)

# Mean (over dim 0, presumably prompts) log-prob gain per token from keeping the bottom neurons
bottom_neuron_high_difference_logprobs = (all_MLP5_logprobs - bottom_MLP5_ablated_logprobs).mean(0)
# Ignore tokens the patched model finds implausible (mean log prob below -7)
bottom_neuron_high_difference_logprobs[all_MLP5_logprobs.mean(0) < -7.0] = 0
bottom_non_zero_count = (bottom_neuron_high_difference_logprobs > 0).sum()
# At most 50 tokens; int() so k is a plain Python int rather than a 0-d tensor
bottom_neuron_high_difference_logprobs, bottom_indices = haystack_utils.top_k_with_exclude(
    bottom_neuron_high_difference_logprobs, min(int(bottom_non_zero_count), 50), all_ignore)


names = ["Original", "Ablated", "MLP5 path patched", f"MLP5 path patched + Top {top_neurons_count} MLP5 neurons ablated", f"MLP5 path patched + Bottom {top_neurons_count} MLP5 neurons ablated"]
short_names = ["Original", "Ablated", "MLP5 path patched", "Top MLP5 ablated", "Bottom MLP5 ablated"]


# Re-plots the last-token losses computed in the previous cell; those are stored under the
# *_MLP4_* names (the *_MLP5_loss names used here before were never defined -> NameError).
values = [original_loss.tolist(), ablated_loss.tolist(), all_MLP4_loss.tolist(), top_MLP4_ablated_loss.tolist(), bottom_MLP4_ablated_loss.tolist()]
haystack_utils.plot_barplot(values, names, short_names=short_names, ylabel="Loss", title="Average last token loss for different MLP5 neuron sets")

In [None]:
haystack_utils.line(
    bottom_neuron_high_difference_logprobs.cpu().numpy(),
    title='Largest positive difference in log probs for tokens when bottom neurons are not ablated',
    xticks=[model.to_str_tokens([token_id])[0] for token_id in bottom_indices],
)
# possible interference completions " Vorschlägar", " Vorschläges", " Vorschläcen", " Vorschläg", " Vorschläu", " Vorschlägt", " Vorschlägs"

In [None]:
# If the log prob for an incorrect token is significantly lower, then that's where the extra probability mass on the correct answer is coming from
# Constructive interference increases correct token log prob and uniformly decreases other log probs
# Destructive interference decreases specific other log probs and uniformly increases other log probs
 
# Most neurons are a mixture of the above
# Decompose neurons into what % of their effect is each
# 


# Mean (over dim 0, presumably prompts) log-prob gain per token from keeping the top neurons un-ablated
top_neuron_high_difference_logprobs = (all_MLP5_logprobs - top_MLP5_ablated_logprobs).mean(0)
# Ignore tokens the patched model finds implausible (mean log prob below -7)
top_neuron_high_difference_logprobs[all_MLP5_logprobs.mean(0) < -7.0] = 0
top_non_zero_count = (top_neuron_high_difference_logprobs > 0).sum()
# At most 50 tokens; int() so k is a plain Python int rather than a 0-d tensor
top_neuron_high_difference_logprobs, top_indices = haystack_utils.top_k_with_exclude(
    top_neuron_high_difference_logprobs, min(int(top_non_zero_count), 50), all_ignore)
haystack_utils.line(
    top_neuron_high_difference_logprobs.cpu().numpy(),
    title='Largest positive difference in log probs for tokens when top neurons are not ablated',
    xticks=[model.to_str_tokens([i])[0] for i in top_indices],
)

I want to decompose an MLP5 neuron's effect into its boost to the correct logit and its deboost of other logits. I want to discover how these two effects change the log prob. I think I can just get the logits, modify, calculate logprobs, and compare.

Possible metric: loss reduction vs. token boost.

~~run direct effect, patch each neuron in top 10 individually, get overall loss reduction from neuron controlled by context neuron (equivalent to logprob increase for correct token)~~
decompose loss reduction into two parts:
run direct effect, patch the correct token logit boost from the neuron (removing the neuron's other effects), get overall loss reduction from neuron (equivalent to logprob increase for correct token)
destructive interference loss reduction = overall loss reduction - loss reduction from correct token logit boost (component of the logprob increase for correct token due to it deboosting incorrect token)

Patch the correct token logit boost from the neuron (removing the neuron's other effects).
1. Get baseline logprobs for a prompt
2. Get the difference in logits from activating the neuron under test. Can use `haystack_utils.get_direct_effect` with return_type='logits'.
3. Run the model with return_type='logits' and the neuron under test zero-ablated.
4. Add the correct token's logit difference from step 2 to a copy of the output logits. Convert to logprobs.
5. Add the incorrect tokens' logit differences from step 2 to a copy of the output logits. Convert to logprobs.
6. Compare A. logprobs with the correct answer token logit increase, B. logprobs with the incorrect answer token logit increases, and C. baseline logprobs.

~~If the context neuron gives each neuron a flat boost, then decomposing the resulting flat boost to one MLP5 neuron into a boost to one logit vs. boosts to all other logits will change the logprobs (the first will increase the answer probability and the second will reduce it).~~

~~The two resulting log probs at the correct answer token won't add up to the original log probs (?).~~

~~If the boost is flat, is the correct percentage decomposition 1/50000 and 49999/50000? In practice/all other factors being equal~~

New plan:

Difference between baseline log prob and neuron log prob?
Classify individual neurons by percentage constructive vs destructive by looking at their log probs and summing the incorrect token log probs

Correct log prob difference
Incorrect log prob difference (summed over every plausible token?)

Largest boost

In [None]:
# Log probs at pos=-2 for the path-patched run, with and without the top / bottom MLP5 neurons
# reverted to their context-ablated values.
logprob_kwargs = dict(pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, return_type='logprobs')
original_logits, ablated_logprobs, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(
    prompts, model, context_activation_hooks=activate_neurons_fwd_hooks, **logprob_kwargs)
_, _, _, top_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(
    prompts, model, context_activation_hooks=activate_neurons_fwd_hooks + ablate_top_neurons_hook, **logprob_kwargs)
_, _, _, bottom_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(
    prompts, model, context_activation_hooks=activate_neurons_fwd_hooks + ablate_bottom_neurons_hook, **logprob_kwargs)


def _top_positive_logprob_gains(patched_logprobs, neuron_ablated_logprobs, min_mean_logprob=-7, max_tokens=50):
    """Rank tokens by how much keeping a neuron set un-ablated raises their log prob.

    Per-prompt differences are clipped at 0 (only gains count), then averaged over dim 0
    (presumably the prompt batch). Tokens whose mean patched log prob is below
    `min_mean_logprob` are zeroed as implausible. Up to `max_tokens` tokens with a positive gain
    are then selected with ``haystack_utils.top_k_with_exclude``, skipping ``all_ignore``.

    Returns:
        (gains, token_ids, non_zero_count), where non_zero_count is the number of tokens with
        a positive gain before the top-k cut.
    """
    gains = patched_logprobs - neuron_ablated_logprobs
    gains[gains < 0] = 0
    gains = gains.mean(0)
    gains[patched_logprobs.mean(0) < min_mean_logprob] = 0
    non_zero_count = int((gains > 0).sum())
    top_gains, token_ids = haystack_utils.top_k_with_exclude(gains, min(non_zero_count, max_tokens), all_ignore)
    return top_gains, token_ids, non_zero_count


bottom_neuron_high_difference_logprobs, bottom_indices, bottom_non_zero_count = _top_positive_logprob_gains(
    all_MLP5_logprobs, bottom_MLP5_ablated_logprobs)
haystack_utils.line(
    bottom_neuron_high_difference_logprobs.cpu().numpy(),
    title='Largest positive difference in log probs for tokens when bottom neurons are not ablated',
    xticks=[model.to_str_tokens([i])[0] for i in bottom_indices],
)


top_neuron_high_difference_logprobs, top_indices, top_non_zero_count = _top_positive_logprob_gains(
    all_MLP5_logprobs, top_MLP5_ablated_logprobs)
# Top-neuron counterpart of the plot above; the title names the top neurons
haystack_utils.line(
    top_neuron_high_difference_logprobs.cpu().numpy(),
    title='Largest positive difference in log probs for tokens when top neurons are not ablated',
    xticks=[model.to_str_tokens([i])[0] for i in top_indices],
)


NameError: name 'date' is not defined