## Setup

In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]


english_activations = {}
german_activations = {}
for layer in range(3, 4):
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)

LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

def activate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value
activate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Downloading (…)lve/main/config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/166M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

## Find top common German tokens

In [3]:
# Get top common german tokens excluding punctuation
token_counts = torch.zeros(model.cfg.d_vocab).cuda()
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    for token in tokens[0]:
        token_counts[token.item()] += 1

punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
token_counts[punctuation_tokens] = 0
token_counts[all_ignore] = 0

top_counts, top_tokens = torch.topk(token_counts, 100)
print(model.to_str_tokens(top_tokens[:100]))

  0%|          | 0/200 [00:00<?, ?it/s]

[' der', 'en', ' die', ' und', 'ung', 'ä', ' in', ' den', ' des', ' zu', 'ch', 'n', 'st', 're', 'z', ' von', ' für', 'äsident', ' Pr', 'ischen', 't', 'ü', 'icht', 'gen', ' ist', ' auf', ' dass', 'ge', 'ig', ' im', 'in', ' über', 'g', ' das', 'te', ' er', 'men', ' w', 'es', ' an', 'ß', ' wir', ' eine', 'f', ' W', 'hen', 'w', ' Europ', ' ich', 'ungen', 'ren', 'le', ' dem', 'ten', ' ein', 'e', ' Z', ' Ver', 'der', ' B', ' mit', ' dies', 'h', ' nicht', 'ungs', 's', ' G', ' z', 'it', ' Herr', ' es', 'l', ' S', 'ich', 'lich', ' An', 'heit', 'ie', ' Er', ' zur', ' V', ' ver', 'u', 'hr', 'chaft', 'Der', ' Ich', ' Ab', ' haben', 'i', 'ant', 'chte', ' mö', 'er', ' K', 'igen', ' Ber', 'ür', ' Fra', 'em']


## Analysis of ngrams preceded by random prompts

In [4]:
def get_random_selection(tensor, n=12):
    # Hacky replacement for np.random.choice
    return tensor[torch.randperm(len(tensor))[:n]]

def generate_random_prompts(end_string, n=50, length=12):
    # Generate a batch of random prompts ending with a specific ngram
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    prompts = []
    for i in range(n):
        prompt = get_random_selection(top_tokens[:max(50, length)], n=length).cuda()
        prompt = torch.cat([prompt, end_tokens])
        prompts.append(prompt)
    prompts = torch.stack(prompts)
    return prompts

def replace_column(prompts: Int[Tensor, "n_prompts n_tokens"], token_index: int):
    # Replaces a specific token position in a batch of prompts with random common German tokens
    new_prompts = prompts.clone()
    random_tokens = get_random_selection(top_tokens[:max(50, prompts.shape[0])], n=prompts.shape[0]).cuda()
    new_prompts[:, token_index] = random_tokens
    return new_prompts 

def loss_analysis(prompts: Tensor, title=""):
    # Loss plot for a batch of prompts
    names = ["Original", "Ablated", "MLP5 path patched"]
    original_loss, ablated_loss, _, only_activated_loss = \
        haystack_utils.get_direct_effect(prompts, model, pos=-1,
                                        context_ablation_hooks=deactivate_neurons_fwd_hooks, 
                                        context_activation_hooks=activate_neurons_fwd_hooks, 
                                        )
    haystack_utils.plot_barplot([original_loss.tolist(), ablated_loss.tolist(), only_activated_loss.tolist()], names, ylabel="Loss", title=title)

In [5]:
def loss_analysis_random_prompts(end_string, n=50, length=12, replace_columns: list[int] | None = None):
    # Loss plot for a batch of random prompts ending with a specific ngram and optionally replacing specific tokens
    prompts = generate_random_prompts(end_string, n=n, length=length)
    title=f"Average last token loss on {length} random tokens ending in '{end_string}'"
    if replace_columns is not None:
        replaced_tokens = model.to_str_tokens(prompts[0, replace_columns])
        title += f" replacing {replaced_tokens}"
        for column in replace_columns:
            prompts = replace_column(prompts, column)
    
    loss_analysis(prompts, title=title)

In [6]:
# Example usage
# loss_analysis_random_prompts(" Vorschlägen", n=100, length=20)
# loss_analysis_random_prompts(" Vorschlägen", n=100, length=20, replace_columns=[-2])

## Widget

In [7]:
# Create a dropdown menu widget
dropdown = widgets.Dropdown(
    options = [" Vorschlägen", " Vorschläge", " häufig", " schließt", " beweglich"], 
    value = " Vorschlägen",
    description = 'Ngram: ',
)

replace_columns_dropdown = widgets.SelectMultiple(
    options = ['-2', '-3', '-4', 'None'],
    value = ['None'],  # default selected value
    description = 'Replace Columns:',
)

# Create an output widget to hold the plot
output = widgets.Output()

# Define a function to call when the widget's value changes
def update_plot(*args):
    # Clear the old plot from the output
    with output:
        clear_output(wait=True)

        if 'None' in replace_columns_dropdown.value:
            replace_columns = None
        else:
            # If 'None' not selected, convert the selected values to integers
            replace_columns = [int(val) for val in replace_columns_dropdown.value]
        
        # Call your function with the values from the widgets
        loss_analysis_random_prompts(dropdown.value, n=100, length=20, replace_columns=replace_columns)

# Set the function to be called when the widget's value changes
dropdown.observe(update_plot, 'value')
replace_columns_dropdown.observe(update_plot, 'value')

# Display the widget and the output
display(dropdown, replace_columns_dropdown, output)

# Run once at startup
update_plot()


Dropdown(description='Ngram: ', options=(' Vorschlägen', ' Vorschläge', ' häufig', ' schließt', ' beweglich'),…

SelectMultiple(description='Replace Columns:', index=(3,), options=('-2', '-3', '-4', 'None'), value=('None',)…

Output()

## Identify relevant neurons

In [8]:
prompts = generate_random_prompts(" Vorschlägen", n=100, length=20)
print(prompts.shape)

torch.Size([100, 24])


In [9]:
with model.hooks(deactivate_neurons_fwd_hooks):
    _, ablated_cache = model.run_with_cache(prompts)

def get_ablate_neurons_hook(neuron: int | list[int], ablated_cache, layer=5):
    def ablate_neurons_hook(value, hook):
        value[:, :, neuron] = ablated_cache[f'blocks.{layer}.mlp.hook_post'][:, :, neuron]
        return value
    return [(f'blocks.{layer}.mlp.hook_post', ablate_neurons_hook)]

diffs = torch.zeros(2048, prompts.shape[0])
# Loss with path patched MLP5 neurons
_, _, _, baseline_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
for neuron in tqdm(range(2048)):
    ablate_single_neuron_hook = get_ablate_neurons_hook(neuron, ablated_cache)
    # Loss with path patched MLP5 neurons but a single neuron changed back to original ablated value
    _, _, _, only_activated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_single_neuron_hook)
    diffs[neuron] = only_activated_loss - baseline_loss

print(diffs.mean())

  0%|          | 0/2048 [00:00<?, ?it/s]

tensor(0.0007)


In [10]:
sorted_means, indices = torch.sort(diffs.mean(1))
sorted_means = sorted_means.tolist()
haystack_utils.line(sorted_means, xlabel="Sorted neurons", ylabel="Loss change", title="Loss change from ablating MLP5 neuron") # xticks=indices

In [11]:
# Check loss change when ablating top / bottom neurons
top_neurons_count = 10
top_neurons = indices[-top_neurons_count:]
bottom_neurons = indices[:top_neurons_count]

with model.hooks(deactivate_neurons_fwd_hooks):
    ablated_loss, ablated_cache = model.run_with_cache(prompts, return_type="loss")

ablate_top_neurons_hook = get_ablate_neurons_hook(top_neurons, ablated_cache)
ablate_bottom_neurons_hook = get_ablate_neurons_hook(bottom_neurons, ablated_cache)

original_loss, ablated_loss, _, all_MLP5_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
_, _, _, top_MLP5_ablated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_top_neurons_hook)
_, _, _, bottom_MLP5_ablated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_bottom_neurons_hook)

names = ["Original", "Ablated", "MLP5 path patched", f"MLP5 path patched + Top {top_neurons_count} MLP5 neurons ablated", f"MLP5 path patched + Bottom {top_neurons_count} MLP5 neurons ablated"]
short_names = ["Original", "Ablated", "MLP5 path patched", f"Top MLP5 removed", f"Bottom MLP5 removed"]

values = [original_loss.tolist(), ablated_loss.tolist(), all_MLP5_loss.tolist(), top_MLP5_ablated_loss.tolist(), bottom_MLP5_ablated_loss.tolist()]
haystack_utils.plot_barplot(values, names, short_names=short_names, ylabel="Loss", title=f"Average last token loss when removing top / bottom neurons from path patching")

## Investigate top / bottom neuron boosts

In [12]:
# Patch in MLP5 neurons, ablated bottom neurons, compare with patched MLP5 neurons without ablating bottom neurons

with model.hooks(deactivate_neurons_fwd_hooks):
    ablated_logits, ablated_cache = model.run_with_cache(prompts)

ablate_top_neurons_hook = get_ablate_neurons_hook(top_neurons, ablated_cache)
ablate_bottom_neurons_hook = get_ablate_neurons_hook(bottom_neurons, ablated_cache)

original_logits, ablated_logprobs, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks, return_type='logprobs')
_, _, _, top_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_top_neurons_hook, return_type='logprobs')
_, _, _, bottom_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_bottom_neurons_hook, return_type='logprobs')

In [13]:
def get_top_difference_neurons(baseline_logprobs, ablated_logprops, positive=True, logprob_threshold=-7, k=50):
    neuron_logprob_difference = (baseline_logprobs - ablated_logprops).mean(0)
    neuron_logprob_difference[baseline_logprobs.mean(0) < logprob_threshold] = 0
    if positive:
        non_zero_count = (neuron_logprob_difference > 0).sum()
    else:
        non_zero_count = (neuron_logprob_difference < 0).sum()
    top_logprob_difference, top_neurons = haystack_utils.top_k_with_exclude(neuron_logprob_difference, min(non_zero_count, k), all_ignore, largest=positive)
    return top_logprob_difference, top_neurons

In [14]:
# Boosted tokens
bottom_neuron_pos_difference_logprobs, bottom_pos_indices = get_top_difference_neurons(all_MLP5_logprobs, bottom_MLP5_ablated_logprobs, positive=True)
top_neuron_pos_difference_logprobs, top_pos_indices = get_top_difference_neurons(all_MLP5_logprobs, top_MLP5_ablated_logprobs, positive=True)
# Deboosted tokens
bottom_neuron_neg_difference_logprobs, bottom_neg_indices = get_top_difference_neurons(all_MLP5_logprobs, bottom_MLP5_ablated_logprobs, positive=False)
top_neuron_neg_difference_logprobs, top_neg_indices = get_top_difference_neurons(all_MLP5_logprobs, top_MLP5_ablated_logprobs, positive=False)

In [15]:
def plot_data(boosted_deboosted, top_bottom):
    if boosted_deboosted == 'Boosted' and top_bottom == 'Top':
        logprobs = top_neuron_pos_difference_logprobs
        indices = top_pos_indices
        title_change = 'increase'
    elif boosted_deboosted == 'Boosted' and top_bottom == 'Bottom':
        logprobs = bottom_neuron_pos_difference_logprobs
        indices = bottom_pos_indices
        title_change = 'increase'
    elif boosted_deboosted == 'Deboosted' and top_bottom == 'Top':
        logprobs = top_neuron_neg_difference_logprobs
        indices = top_neg_indices
        title_change = 'decrease'
    else:  # 'Deboosted' and 'Bottom'
        logprobs = bottom_neuron_neg_difference_logprobs
        indices = bottom_neg_indices
        title_change = 'decrease'

    xlabel = boosted_deboosted + " tokens"
    ylabel = "full_logprob - ablated_logprop"
    title = 'Logprob ' + title_change + ' from ' + top_bottom.lower() + ' neurons'
    xticks = [model.to_str_tokens([i])[0] for i in indices]

    haystack_utils.line(logprobs.cpu().numpy(), xlabel=xlabel, ylabel=ylabel, title=title, xticks=xticks, show_legend=False)


boosted_deboosted = widgets.Dropdown(options=['Boosted', 'Deboosted'], value='Boosted', description='Tokens:')
top_bottom = widgets.Dropdown(options=['Top', 'Bottom'], value='Top', description='Neurons:')
widgets.interactive(plot_data, boosted_deboosted=boosted_deboosted, top_bottom=top_bottom)


interactive(children=(Dropdown(description='Tokens:', options=('Boosted', 'Deboosted'), value='Boosted'), Drop…

## Decompose neuron effects

In [16]:
# Get logprob change for individual neuron

def get_individual_neuron_logprob_effect(sorted_indices, neuron_pos=0, top=True, positive=True, plot=True):
    if top:
        neurons = sorted_indices[-(neuron_pos+1)]
    else:
        neurons = sorted_indices[neuron_pos]
    neurons_hook = get_ablate_neurons_hook([neurons], ablated_cache)

    original_logits, ablated_logprobs, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks, return_type='logprobs')
    _, _, _, neuron_ablated_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+neurons_hook, return_type='logprobs')
    neuron_difference_logprobs, difference_indices = \
        get_top_difference_neurons(all_MLP5_logprobs, neuron_ablated_logprobs, positive=positive)
    
    if not plot:
        return neuron_difference_logprobs, difference_indices
    
    if positive:
        xlabel = 'Boosted tokens'
        title = 'Logprob increase from neuron ' + str(neurons.item())
    else: 
        xlabel = 'Deboosted tokens'
        title = 'Logprob decrease from neuron ' + str(neurons.item())
    if top:
        title += f' (top {neuron_pos+1} neuron)'
    else:
        title += f' (bottom {neuron_pos+1} neuron)'

    xticks = [model.to_str_tokens([i])[0] for i in difference_indices]
    haystack_utils.line(neuron_difference_logprobs.cpu().numpy(), xlabel=xlabel, ylabel="full_logprob - ablated_logprop", title=title, xticks=xticks, show_legend=False)

In [17]:
# Define a function to call your function with widget inputs
def plot_individual_neuron(neuron_pos, top_bottom, pos_neg):
    top_bool = True if top_bottom == 'Top' else False
    positive_bool = True if pos_neg == 'Positive' else False
    get_individual_neuron_logprob_effect(indices, neuron_pos=neuron_pos, top=top_bool, positive=positive_bool)

# Define widgets
neuron_pos_slider = widgets.IntSlider(min=0, max=20, step=1, value=0, description='Neuron Pos:')
top_bottom_dropdown = widgets.Dropdown(options=['Top', 'Bottom'], value='Top', description='Neurons:')
pos_neg_dropdown = widgets.Dropdown(options=['Positive', 'Negative'], value='Positive', description='Pos/Neg:')

# Use interactive function to bind the widgets to the plotting function
widgets.interactive(plot_individual_neuron, neuron_pos=neuron_pos_slider, top_bottom=top_bottom_dropdown, pos_neg=pos_neg_dropdown)


interactive(children=(IntSlider(value=0, description='Neuron Pos:', max=20), Dropdown(description='Neurons:', …

In [18]:
# Sum individual weighted positive / negative effects
def summed_neuron_differences(num_neurons=10, top=True, positive=True, k=50):
    summed_neuron_diffs = torch.zeros(model.cfg.d_vocab).cuda()
    for neuron_pos in range(num_neurons):
        # It would probably be more principle to do our filtering on the summed diffs instead of per neuron
        neuron_diff, token_indices = get_individual_neuron_logprob_effect(indices, neuron_pos=neuron_pos, top=top, positive=positive, plot=False)
        summed_neuron_diffs[token_indices] += neuron_diff
    if positive:
        non_zero_count = (summed_neuron_diffs > 0).sum()
    else:
        non_zero_count = (summed_neuron_diffs < 0).sum()
    top_logprob_difference, top_tokens = haystack_utils.top_k_with_exclude(summed_neuron_diffs, min(non_zero_count, k), all_ignore, largest=positive)
    return top_logprob_difference, top_tokens

def plot_summed_neuron_differences(num_neurons=10, top=True, positive=True):
    top_logprob_difference, top_logprob_tokens = summed_neuron_differences(num_neurons=num_neurons, top=top, positive=positive)
    xticks = [model.to_str_tokens([i])[0] for i in top_logprob_tokens]
    
    if positive:
        title = "Summed individual boosts by"
        xlabel = "Boosted tokens"
    else:
        title = "Summed individual deboosts by"
        xlabel = "Deboosted tokens"
    if top:
        title += f" top {num_neurons} neurons"
    else:
        title += f" bottom {num_neurons} neurons"
    haystack_utils.line(top_logprob_difference.cpu().numpy(), xlabel=xlabel, ylabel="full_logprob - ablated_logprop", title=title, xticks=xticks, show_legend=False)

In [19]:
# Define a function to call your function with widget inputs
def plot_summed_neurons_widget(num_neurons, top_bottom, pos_neg):
    top_bool = True if top_bottom == 'Top' else False
    positive_bool = True if pos_neg == 'Positive' else False
    plot_summed_neuron_differences(num_neurons=num_neurons, top=top_bool, positive=positive_bool)

# Define widgets
num_neuron_slider = widgets.IntSlider(min=1, max=20, step=1, value=10, description='Num Neurons:')
top_bottom_dropdown = widgets.Dropdown(options=['Top', 'Bottom'], value='Top', description='Neurons:')
pos_neg_dropdown = widgets.Dropdown(options=['Positive', 'Negative'], value='Positive', description='Pos/Neg:')

# Use interactive function to bind the widgets to the plotting function
widgets.interactive(plot_summed_neurons_widget, num_neurons=num_neuron_slider, top_bottom=top_bottom_dropdown, pos_neg=pos_neg_dropdown)

interactive(children=(IntSlider(value=10, description='Num Neurons:', max=20, min=1), Dropdown(description='Ne…

Loss is logprob(correct_token_logit) = logit(correct_token_logit) - LogSumExp(all_token_logits)
Loss = -logprob

LogSumExp approximates a maximum function. If the neuron engages in destructive interference of a high logit for a non-answer token, then the exp(logit) for the token will be lower and so the LogSumExp will be more similar to logit(correct_token_logit) so the loss will be lower. So a lower logsumexp(all vocab) is good.

If the neuron engages in destructive interference of a low logit for a non-answer token, then the exp(logit) for the the token won't really change and so the logsumexp and the loss will both be the same.

For a single neuron:
1. For each "gen" prompt, zero centre the logits and record the logit of "gen". Calculate the mean "gen" logit.
2. For each "gen" prompt, disable the neuron under test, zero centre the logits, edit the logit of "gen" to the mean with a hook, then get the logprob of "gen" and calculate logsumexp(all vocab) = logit(gen) - logprob(gen).
2. For each "gen" prompt, enable the neuron, zero centre the logits, edit the logit of "gen" to the mean with a hook, then get the logprob of "gen" and calculate logsumexp(all vocab) = logit(gen) - logprob(gen).
3. Take the difference in logsum exps. If it's positive, the neuron is reducing the loss via destructive interference by the difference.
Can use the same procedure for sets of neurons, or for all neurons, to find high level effects of the context neuron

Logprobs are logits with a constant subtracted and the constant is the same for every logit within a prompt.

Taking the difference in terms with and without a neuron's effect via the context neuron:
- If log sum exp increases, the neuron is boosting tokens on average. 
- If logit increases, the neuron is boosting the correct token


Remove the neuron's effects on the gen logit. Take the mean on the prompt and position?

In [20]:
# 1. Calculate the mean "gen" logit.
gen_index = model.to_single_token('gen')
gen_logits = []
logits = model(prompts, return_type='logits') # batch pos vocab
logits = logits - logits.mean(-1).unsqueeze(-1) # batch pos vocab, batch pos 1

mean_gen_logit = logits[:, -1, gen_index].mean(0)
mean_gen_logit

tensor(11.2822, device='cuda:0')

In [21]:
import plotly.express as px
import numpy as np
# from transformer_lens import utils
# 
# px.histogram(np.random.choice(logits[:, -1, gen_index].flatten().cpu().numpy(), 1000), nbins=100)
# utils.test_prompt("".join(model.to_str_tokens(prompts[0, :-1])), "gen", model, prepend_space_to_answer=False)

In [27]:
def decompose_interference_diff(ablated_cache, disabled_neurons=[top_neurons[0]], mean=True
                                ) -> tuple[float, float] | tuple[Float[Tensor, "n_prompts"], Float[Tensor, "n_prompts"]]:
    '''
    Finds the effect of the German context neuron ablation via a given set of MLP5 neurons on the logit of each final 
    token. Decomposes it into constructive and destructive interference. A positive constructive interference difference
    means the neurons boost the logit of the correct token, a positive destructive interference difference means the
    neurons deboost the logits of the incorrect tokens.

    Loss = -logprob(correct_token_logit)
    logprob(correct_token_logit) = correct_token_logit - LogSumExp(all_token_logits)

    LogSumExp is a smooth maximum function, so it approximates the max of all logits. A neuron that destructively 
    interfers with a non-answer token with a high logit lowers the LogSumExp(all_token_logits) and thus the loss.'''
    ablate_top_neuron_hook = get_ablate_neurons_hook(disabled_neurons, ablated_cache)

    _, _, _, mlp5_enabled_logits = haystack_utils.get_direct_effect(prompts, model, pos=-1, 
                                                                    context_ablation_hooks=deactivate_neurons_fwd_hooks, 
                                                                    context_activation_hooks=activate_neurons_fwd_hooks, return_type='logits')
    _, _, _, mlp5_enabled_top_neuron_ablated_logits = haystack_utils.get_direct_effect(prompts, model, pos=-1, 
                                                                                       context_ablation_hooks=deactivate_neurons_fwd_hooks, 
                                                                                       context_activation_hooks=activate_neurons_fwd_hooks+ablate_top_neuron_hook, return_type='logits')
    
    # Mean center logits to avoid picking up on constant boosts / deboosts
    mlp5_enabled_logits = mlp5_enabled_logits - mlp5_enabled_logits.mean(-1).unsqueeze(-1)
    mlp5_enabled_top_neuron_ablated_logits = mlp5_enabled_top_neuron_ablated_logits - mlp5_enabled_top_neuron_ablated_logits.mean(-1).unsqueeze(-1)
    
    # 1. Constructive interference difference
    # This is the change in the correct answer token logit from ablating the neuron, positive is good
    constructive_interference_diffs = mlp5_enabled_logits[:, gen_index] - mlp5_enabled_top_neuron_ablated_logits[:, gen_index]

    # 2. Destructive interference difference
    # This is the change in the LogSumExp of all logits from not ablating the neuron, 
    # Lucia: positive/increase in LogSumExp is bad
    # Lovis: I don't think so - positive logsumexp means that all tokens get deboosted which is good

    # Set the gen logit to its mean value so the neuron's constructive interference doesn't affect the LogSumExp difference
    mlp5_enabled_logits[:, gen_index] = mean_gen_logit
    mlp5_enabled_top_neuron_ablated_logits[:, gen_index] = mean_gen_logit
    
    # Compute logsumexp
    mlp5_enabled_log_sum_exp = mlp5_enabled_logits.exp().sum(-1).log()
    mlp5_enabled_top_neuron_ablated_log_sum_exp = mlp5_enabled_top_neuron_ablated_logits.exp().sum(-1).log()

    # Check for errors
    assert torch.allclose(mlp5_enabled_log_sum_exp, mlp5_enabled_logits[:, gen_index] - mlp5_enabled_logits.log_softmax(-1)[:, gen_index])
    assert torch.allclose(mlp5_enabled_top_neuron_ablated_log_sum_exp, mlp5_enabled_top_neuron_ablated_logits[:, gen_index] - mlp5_enabled_top_neuron_ablated_logits.log_softmax(-1)[:, gen_index])

    # Difference in logsumexp
    # Logsumexp of enabled should be higher than ablated if the neuron does something good
    # Positive results are good - they mean that all tokens are deboosted more when the neuron is active
    destructive_interference_diffs = mlp5_enabled_log_sum_exp - mlp5_enabled_top_neuron_ablated_log_sum_exp
    # 3. Flip the diff's sign to ease interpretation, now positive is good for both constructive and destructive interference
    #destructive_interference_diffs *= -1

    if mean:
        return constructive_interference_diffs.mean().item(), destructive_interference_diffs.mean().item(), 
    return constructive_interference_diffs, destructive_interference_diffs,


# Calculate neuron-wise loss change
print("constructive interference diff, destructive interference diff")
for i in range(5):
    constructive_diff, destructive_diff = decompose_interference_diff(ablated_cache, disabled_neurons=[top_neurons[i]], mean=True)
    print(f'{constructive_diff:2f}', f'{destructive_diff:2f}')

constructive_diff, destructive_diff = decompose_interference_diff(ablated_cache, top_neurons)
print(f'{constructive_diff:2f}', f'{destructive_diff:2f}')

constructive interference diff, destructive interference diff
torch.Size([100]) torch.Size([])
0.038536 -0.028058
torch.Size([100]) torch.Size([])
-0.000311 -0.000128
torch.Size([100]) torch.Size([])
0.028768 0.006270
torch.Size([100]) torch.Size([])
0.000000 0.000000
torch.Size([100]) torch.Size([])
0.026021 -0.003454
torch.Size([100]) torch.Size([])
0.337268 -0.055345


### Neuron -> Token Logit Analysis

In [23]:
# Our top neurons are selected by the difference in their boost in gen based on the context neuron
# Many other neurons boost gen more per unit of activation

cosine_sim = torch.nn.CosineSimilarity(dim=1)
answer_residual_direction = model.tokens_to_residual_directions("gen")
neuron_weights = model.state_dict()['blocks.5.mlp.W_out'][top_neurons]

cosine_sims = cosine_sim(neuron_weights, answer_residual_direction.unsqueeze(0))
print(cosine_sims)

cosine_sim = torch.nn.CosineSimilarity(dim=1)
answer_residual_direction = model.tokens_to_residual_directions("gen")
neuron_weights = model.state_dict()['blocks.5.mlp.W_out']

cosine_sims = cosine_sim(neuron_weights, answer_residual_direction.unsqueeze(0))
top, indices = torch.topk(cosine_sims, 20)
print(top)


tensor([ 0.0346, -0.0237, -0.0347, -0.0274,  0.0397,  0.0595,  0.0501,  0.1405,
         0.0417, -0.0392], device='cuda:0')
tensor([0.1634, 0.1461, 0.1405, 0.1399, 0.1349, 0.1207, 0.1196, 0.1161, 0.1156,
        0.1139, 0.1130, 0.1128, 0.1120, 0.1104, 0.1096, 0.1070, 0.1056, 0.1048,
        0.1042, 0.1028], device='cuda:0')


In [24]:
haystack_utils.clean_cache()
english_activations_l5 = {}
german_activations_l5 = {}
all_activations_l5 = {}
for layer in range(5, 6):
    english_activations_l5[layer] = get_mlp_activations(english_data, layer, model, mean=True)
    german_activations_l5[layer] = get_mlp_activations(german_data, layer, model, mean=True)
    all_activations_l5[layer] = get_mlp_activations(german_data[:50] + english_data[:150], layer, model, mean=True)

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [25]:
import pandas as pd
# token based measures:
# index
# next index

# neuron based measures:
# index
# cosine sim with gen
# average activation when context neuron enabled
# average activation when context neuron disabled
# 5 if cosine sim > 0 else -5
# loss change when neuron ablated

# boost_gen_acts = cosine_sims
# boost_gen_acts[cosine_sims > 0] = 5.0
# boost_gen_acts[cosine_sims <= 0] = -5.0

# def get_deactivate_neuron_hook(neuron):
#     def deactivate_neurons_hook(value, hook):
#         value[:, :, neuron] = MEAN_ACTIVATION_INACTIVE
#         return value
# deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]
# loss_diffs = []
# for i in range(2048):

W_out = model.state_dict()['blocks.5.mlp.W_out']
W_U = model.W_U
print(W_out.shape, W_U.shape, all_activations_l5.unsqueeze(1).shape)

average_boost = W_out * all_activations_l5.unsqueeze(1) * W_U
print(average_boost.shape) # d_mlp d_vocab

# 'Average gen boost': [i zip(all_activations_l5,

# mean activation
data = {
    'Neuron index': list(range(2048)),
    'Cosine similarity with \"gen\"': cosine_sims.tolist(),
    'Average act - context neuron enabled': german_activations_l5.tolist(),
    'Average act - context neuron disabled': english_activations_l5.tolist(),
    'Average act': all_activations_l5.tolist()
    # 'Boost \"gen\" act': boost_gen_acts.tolist(),
    # 'Loss change when ablated': 
}

# Create the dataframe
df = pd.DataFrame(data)

# Print the dataframe
print(df)

# More complete picture of neurons that boost gen.
# Line plot of each neuron's gen boost with dotted vertical lines where the German context neuron-boosted ones are

AttributeError: 'dict' object has no attribute 'unsqueeze'

## Brainstorm notes

I want to decompose an MLP5 neuron's effect into its boost to the correct logit and its deboost of other logits. I want to discover how these two effects change the log prob.

metric like loss reduction vs. token boost

~~run direct effect, patch each neuron in top 10 individually, get overall loss reduction from neuron controlled by context neuron (equivalent to logprob increase for correct token)~~
decompose loss reduction into two parts:
run direct effect, patch the correct token logit boost from the neuron (removing the neuron's other effects), get overall loss reduction from neuron (equivalent to logprob increase for correct token)
destructive interference loss reduction = overall loss reduction - loss reduction from correct token logit boost (component of the logprob increase for correct token due to it deboosting incorrect token)

Patch the correct token logit boost from the neuron (removing the neuron's other effects).
1. Get baseline logprobs for a prompt
2. Get difference in logits from activating the neuron under test. Can use get_direct_effects with return_type='logits'.
3. Run the model with return_type='logits' and the neuron under test zero ablated. 
4. Add the correct token logit from step 1 to a copy of the output logits. Convert to logprobs
5. Add the incorrect token logits from step 1 to a copy of the output logits. Convert to logprobs
6. Compare A. lobprobs with correct answer token logit increase, B. logprobs with incorrect answer token logit increases, and C. baseline logprobs

~~If the context neuron gives each neuron a flat boost then if we decompose the resulting flat boost to one MLP5 neuron into a boost to one logit vs. boosts to all other logits it will change the logprobs (first will increase answer probability and second will reduce answer probability). 

the two resulting log probs at the correct answer token won't add up to the original log probs (?). 

If the boost is flat the correct percentage decomposition is 1/50000 and 49999/50000? In practice/all other factors being equal

New plan:

Difference between baseline log prob and neuron log prob?
Classify individual neurons by percentage constructive vs destructive by looking at their log probs and summing the incorrect token log probs

Correct log prob difference
Incorrect log prob difference (summed over every plausible token?)

Largest boost

In [None]:
# If the log prob for an incorrect token is significantly lower then that's where the extra probability density on the correct answer is coming from 
# Constructive interference increases correct token log prob and uniformly decreases other log probs
# Destructive interference decreases specific other log probs and uniformly increases other log probs
 
# Most neurons are a mixture of the above
# Decompose neurons into what % of their effect is each

In [None]:
# original_logits, ablated_logprobs, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks, return_type='logprobs')
# _, _, _, top_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_top_neurons_hook, return_type='logprobs')
# _, _, _, bottom_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_bottom_neurons_hook, return_type='logprobs')

# bottom_neuron_high_difference_logprobs = (all_MLP5_logprobs - bottom_MLP5_ablated_logprobs)
# bottom_neuron_high_difference_logprobs[bottom_neuron_high_difference_logprobs < 0] = 0
# bottom_neuron_high_difference_logprobs = bottom_neuron_high_difference_logprobs.mean(0)
# bottom_neuron_high_difference_logprobs[all_MLP5_logprobs.mean(0) < -7] = 0

# bottom_non_zero_count = (bottom_neuron_high_difference_logprobs > 0).sum()
# bottom_neuron_high_difference_logprobs, bottom_indices = haystack_utils.top_k_with_exclude(bottom_neuron_high_difference_logprobs, min(bottom_non_zero_count, 50), all_ignore)
# haystack_utils.line(bottom_neuron_high_difference_logprobs.cpu().numpy(), title='Largest positive difference in log probs for tokens when bottom neurons are not ablated', xticks=[model.to_str_tokens([i])[0] for i in bottom_indices])


# top_neuron_high_difference_logprobs = (all_MLP5_logprobs - top_MLP5_ablated_logprobs)
# top_neuron_high_difference_logprobs[top_neuron_high_difference_logprobs < 0] = 0
# top_neuron_high_difference_logprobs = top_neuron_high_difference_logprobs.mean(0)
# top_neuron_high_difference_logprobs[all_MLP5_logprobs.mean(0) < -7] = 0

# top_non_zero_count = (top_neuron_high_difference_logprobs > 0).sum()
# top_neuron_high_difference_logprobs, top_indices = haystack_utils.top_k_with_exclude(top_neuron_high_difference_logprobs, min(top_non_zero_count, 50), all_ignore)
# haystack_utils.line(top_neuron_high_difference_logprobs.cpu().numpy(), title='Largest positive difference in log probs for tokens when bottom neurons are not ablated', xticks=[model.to_str_tokens([i])[0] for i in top_indices])
