## Setup

In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]


english_activations = {}
german_activations = {}
for layer in range(3, 4):
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)

LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

def activate_neurons_hook(value, hook):
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value
activate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Downloading model.safetensors:   0%|          | 0.00/166M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

## Presentation plots

In [11]:
english_context_activations = get_mlp_activations(english_data, 3, model, mean=False, hook_pre=True, neurons=[669])[:50000, 0]
german_context_activations =  get_mlp_activations(german_data, 3, model, mean=False, hook_pre=True, neurons=[669])[:50000, 0]
haystack_utils.two_histogram(english_context_activations, german_context_activations, "English", "German", "Context neuron L3N669 activation", "Activation value (pre-GELU)", y_label="Percent")

In [14]:
import plotting_utils

In [24]:
downstream_components = ("blocks.4.hook_attn_out", "blocks.5.hook_attn_out", "blocks.4.hook_mlp_out", "blocks.5.hook_mlp_out")

original_metrics = []
activated_metrics = []
ablated_metrics = []
direct_effect_metrics = []
indirect_effect_metrics = []
for prompt in tqdm(german_data):
    original_metric, activated_metric, ablated_metric, direct_effect_metric, indirect_effect_metric = haystack_utils.get_context_effect(german_data[0], model, 
                    context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=[], downstream_components=downstream_components)
    original_metrics.extend(original_metric.flatten().tolist())
    activated_metrics.extend(activated_metric.flatten().tolist())
    ablated_metrics.extend(ablated_metric.flatten().tolist())
    direct_effect_metrics.extend(direct_effect_metric.flatten().tolist())
    indirect_effect_metrics.extend(indirect_effect_metric.flatten().tolist())

  0%|          | 0/200 [00:00<?, ?it/s]

In [28]:
direct_effect_ablation = [a - b for a, b in zip(indirect_effect_metrics, original_metrics)]
indirect_effect_ablation = [a - b for a, b in zip(direct_effect_metrics, original_metrics)]
total_effect_ablation = [a - b for a, b in zip(ablated_metrics, original_metrics)]

In [39]:
import numpy as np
def percent_increase (ablation_effect, original_loss):
    original_mean_loss = np.mean(original_loss)
    mean_ablation_effect = np.mean(ablation_effect)
    return (mean_ablation_effect / original_mean_loss) * 100

ablation_percent_increases = [percent_increase(ablation_effect, original_metrics) for ablation_effect in [direct_effect_ablation, indirect_effect_ablation, total_effect_ablation]]
print(ablation_percent_increases)

[1.7648537318481978, 11.785123451106351, 12.034168413348215]


In [44]:
fig = plotting_utils.plot_barplot([direct_effect_ablation, indirect_effect_ablation, total_effect_ablation], confidence_interval=True, show=False, names=["Direct effect", "Indirect effect", "Total effect"], title="Effect of ablating context neuron L3N669 on German text", xlabel="", ylabel="Loss Increase")

for i, percent_increase in enumerate(ablation_percent_increases):
    fig.add_annotation(
        x=i,
        y=0.02,
        text=f"<b>+{percent_increase:.2f}%</b>",
        showarrow=False,
        font=dict(size=16)
    )
fig.show()

In [None]:
german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE]

tensor(-0.0844, device='cuda:0')

## Find top common German tokens

In [3]:
# Get top common german tokens excluding punctuation
token_counts = torch.zeros(model.cfg.d_vocab).cuda()
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    for token in tokens[0]:
        token_counts[token.item()] += 1

punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
token_counts[punctuation_tokens] = 0
token_counts[all_ignore] = 0

top_counts, top_tokens = torch.topk(token_counts, 100)
print(model.to_str_tokens(top_tokens[:100]))

  0%|          | 0/200 [00:00<?, ?it/s]

[' der', 'en', ' die', ' und', 'ung', 'ä', ' in', ' den', ' des', ' zu', 'ch', 'n', 'st', 're', 'z', ' von', ' für', 'äsident', ' Pr', 'ischen', 't', 'ü', 'icht', 'gen', ' ist', ' auf', ' dass', 'ge', 'ig', ' im', 'in', ' über', 'g', ' das', 'te', ' er', 'men', ' w', 'es', ' an', 'ß', ' wir', ' eine', 'f', ' W', 'hen', 'w', ' Europ', ' ich', 'ungen', 'ren', 'le', ' dem', 'ten', ' ein', 'e', ' Z', ' Ver', 'der', ' B', ' mit', ' dies', 'h', ' nicht', 'ungs', 's', ' G', ' z', 'it', ' Herr', ' es', 'l', ' S', 'ich', 'lich', ' An', 'heit', 'ie', ' Er', ' zur', ' V', ' ver', 'u', 'hr', 'chaft', 'Der', ' Ich', ' Ab', ' haben', 'i', 'ant', 'chte', ' mö', 'er', ' K', 'igen', ' Ber', 'ür', ' Fra', 'em']


## Analysis of ngrams preceded by random prompts

In [4]:
def get_random_selection(tensor, n=12):
    # Hacky replacement for np.random.choice
    return tensor[torch.randperm(len(tensor))[:n]]

def generate_random_prompts(end_string, n=50, length=12):
    # Generate a batch of random prompts ending with a specific ngram
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    prompts = []
    for i in range(n):
        prompt = get_random_selection(top_tokens[:max(50, length)], n=length).cuda()
        prompt = torch.cat([prompt, end_tokens])
        prompts.append(prompt)
    prompts = torch.stack(prompts)
    return prompts

def replace_column(prompts: Int[Tensor, "n_prompts n_tokens"], token_index: int):
    # Replaces a specific token position in a batch of prompts with random common German tokens
    new_prompts = prompts.clone()
    random_tokens = get_random_selection(top_tokens[:max(50, prompts.shape[0])], n=prompts.shape[0]).cuda()
    new_prompts[:, token_index] = random_tokens
    return new_prompts 

def loss_analysis(prompts: Tensor, title=""):
    # Loss plot for a batch of prompts
    names = ["Original", "Ablated", "MLP5 path patched"]
    original_loss, ablated_loss, _, only_activated_loss = \
        haystack_utils.get_direct_effect(prompts, model, pos=-1,
                                        context_ablation_hooks=deactivate_neurons_fwd_hooks, 
                                        context_activation_hooks=activate_neurons_fwd_hooks, 
                                        )
    haystack_utils.plot_barplot([original_loss.tolist(), ablated_loss.tolist(), only_activated_loss.tolist()], names, ylabel="Loss", title=title)

In [5]:
def loss_analysis_random_prompts(end_string, n=50, length=12, replace_columns: list[int] | None = None):
    # Loss plot for a batch of random prompts ending with a specific ngram and optionally replacing specific tokens
    prompts = generate_random_prompts(end_string, n=n, length=length)
    title=f"Average last token loss on {length} random tokens ending in '{end_string}'"
    if replace_columns is not None:
        replaced_tokens = model.to_str_tokens(prompts[0, replace_columns])
        title += f" replacing {replaced_tokens}"
        for column in replace_columns:
            prompts = replace_column(prompts, column)
    
    loss_analysis(prompts, title=title)

In [6]:
# Example usage
# loss_analysis_random_prompts(" Vorschlägen", n=100, length=20)
# loss_analysis_random_prompts(" Vorschlägen", n=100, length=20, replace_columns=[-2])

## Widget

In [7]:
# Create a dropdown menu widget
dropdown = widgets.Dropdown(
    options = [" statt", " Vorschlägen", " Vorschläge", " häufig", " schließt", " beweglich", " seine Ansicht"], 
    value = " Vorschlägen",
    description = 'Ngram: ',
)

replace_columns_dropdown = widgets.SelectMultiple(
    options = ['-2', '-3', '-4', 'None'],
    value = ['None'],  # default selected value
    description = 'Replace Columns:',
)

# Create an output widget to hold the plot
output = widgets.Output()

# Define a function to call when the widget's value changes
def update_plot(*args):
    # Clear the old plot from the output
    with output:
        clear_output(wait=True)

        if 'None' in replace_columns_dropdown.value:
            replace_columns = None
        else:
            # If 'None' not selected, convert the selected values to integers
            replace_columns = [int(val) for val in replace_columns_dropdown.value]
        
        # Call your function with the values from the widgets
        loss_analysis_random_prompts(dropdown.value, n=100, length=20, replace_columns=replace_columns)

# Set the function to be called when the widget's value changes
dropdown.observe(update_plot, 'value')
replace_columns_dropdown.observe(update_plot, 'value')

# Display the widget and the output
display(dropdown, replace_columns_dropdown, output)

# Run once at startup
update_plot()


Dropdown(description='Ngram: ', index=1, options=(' statt', ' Vorschlägen', ' Vorschläge', ' häufig', ' schlie…

SelectMultiple(description='Replace Columns:', index=(3,), options=('-2', '-3', '-4', 'None'), value=('None',)…

Output()

## Identify relevant neurons

In [8]:
prompts = generate_random_prompts(" statt", n=100, length=20)
print(prompts.shape)

torch.Size([100, 22])


In [9]:
with model.hooks(deactivate_neurons_fwd_hooks):
    _, ablated_cache = model.run_with_cache(prompts)

def get_ablate_neurons_hook(neuron: int | list[int], ablated_cache, layer=5):
    def ablate_neurons_hook(value, hook):
        value[:, :, neuron] = ablated_cache[f'blocks.{layer}.mlp.hook_post'][:, :, neuron]
        return value
    return [(f'blocks.{layer}.mlp.hook_post', ablate_neurons_hook)]

diffs = torch.zeros(2048, prompts.shape[0])
# Loss with path patched MLP5 neurons
_, _, _, baseline_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
for neuron in tqdm(range(2048)):
    ablate_single_neuron_hook = get_ablate_neurons_hook(neuron, ablated_cache)
    # Loss with path patched MLP5 neurons but a single neuron changed back to original ablated value
    _, _, _, only_deactivated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_single_neuron_hook)
    diffs[neuron] = only_deactivated_loss - baseline_loss

print(diffs.mean())

  0%|          | 0/2048 [00:00<?, ?it/s]

tensor(0.0008)


In [10]:
sorted_means, indices = torch.sort(diffs.mean(1))
sorted_means = sorted_means.tolist()
haystack_utils.line(sorted_means, xlabel="Sorted neurons", ylabel="Loss change", title="Loss change from ablating MLP5 neuron") # xticks=indices
sorted_top_neuron_indices = indices

In [11]:
# Check loss change when ablating top / bottom neurons
top_neurons_count = 10
top_neurons = indices[-top_neurons_count:]
bottom_neurons = indices[:top_neurons_count]

with model.hooks(deactivate_neurons_fwd_hooks):
    ablated_loss, ablated_cache = model.run_with_cache(prompts, return_type="loss")

ablate_top_neurons_hook = get_ablate_neurons_hook(top_neurons, ablated_cache)
ablate_bottom_neurons_hook = get_ablate_neurons_hook(bottom_neurons, ablated_cache)

original_loss, ablated_loss, _, all_MLP5_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
_, _, _, top_MLP5_ablated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_top_neurons_hook)
_, _, _, bottom_MLP5_ablated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_bottom_neurons_hook)

names = ["Original", "Ablated", "MLP5 path patched", f"MLP5 path patched + Top {top_neurons_count} MLP5 neurons ablated", f"MLP5 path patched + Bottom {top_neurons_count} MLP5 neurons ablated"]
short_names = ["Original", "Ablated", "MLP5 path patched", f"Top MLP5 removed", f"Bottom MLP5 removed"]

values = [original_loss.tolist(), ablated_loss.tolist(), all_MLP5_loss.tolist(), top_MLP5_ablated_loss.tolist(), bottom_MLP5_ablated_loss.tolist()]
haystack_utils.plot_barplot(values, names, short_names=short_names, ylabel="Loss", title=f"Average last token loss when removing top / bottom neurons from path patching")

## Investigate top / bottom neuron boosts

In [13]:
# Patch in MLP5 neurons, ablated bottom neurons, compare with patched MLP5 neurons without ablating bottom neurons

with model.hooks(deactivate_neurons_fwd_hooks):
    ablated_logits, ablated_cache = model.run_with_cache(prompts)

ablate_top_neurons_hook = get_ablate_neurons_hook(top_neurons, ablated_cache)
ablate_bottom_neurons_hook = get_ablate_neurons_hook(bottom_neurons, ablated_cache)

original_logits, ablated_logprobs, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks, return_type='logprobs')
_, _, _, top_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_top_neurons_hook, return_type='logprobs')
_, _, _, bottom_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_bottom_neurons_hook, return_type='logprobs')

In [14]:
def get_top_difference_neurons(baseline_logprobs, ablated_logprops, positive=True, logprob_threshold=-7, k=50):
    neuron_logprob_difference = (baseline_logprobs - ablated_logprops).mean(0)
    neuron_logprob_difference[baseline_logprobs.mean(0) < logprob_threshold] = 0
    if positive:
        non_zero_count = (neuron_logprob_difference > 0).sum()
    else:
        non_zero_count = (neuron_logprob_difference < 0).sum()
    print(neuron_logprob_difference.shape)
    top_logprob_difference, top_neurons = haystack_utils.top_k_with_exclude(neuron_logprob_difference, min(non_zero_count, k), all_ignore, largest=positive)
    return top_logprob_difference, top_neurons

In [15]:
# Boosted tokens
bottom_neuron_pos_difference_logprobs, bottom_pos_indices = get_top_difference_neurons(all_MLP5_logprobs, bottom_MLP5_ablated_logprobs, positive=True)
top_neuron_pos_difference_logprobs, top_pos_indices = get_top_difference_neurons(all_MLP5_logprobs, top_MLP5_ablated_logprobs, positive=True)
# Deboosted tokens
bottom_neuron_neg_difference_logprobs, bottom_neg_indices = get_top_difference_neurons(all_MLP5_logprobs, bottom_MLP5_ablated_logprobs, positive=False)
top_neuron_neg_difference_logprobs, top_neg_indices = get_top_difference_neurons(all_MLP5_logprobs, top_MLP5_ablated_logprobs, positive=False)

torch.Size([50304])
torch.Size([50304])
torch.Size([50304])
torch.Size([50304])


In [16]:
def plot_data(boosted_deboosted, top_bottom):
    if boosted_deboosted == 'Boosted' and top_bottom == 'Top':
        logprobs = top_neuron_pos_difference_logprobs
        indices = top_pos_indices
        title_change = 'increase'
    elif boosted_deboosted == 'Boosted' and top_bottom == 'Bottom':
        logprobs = bottom_neuron_pos_difference_logprobs
        indices = bottom_pos_indices
        title_change = 'increase'
    elif boosted_deboosted == 'Deboosted' and top_bottom == 'Top':
        logprobs = top_neuron_neg_difference_logprobs
        indices = top_neg_indices
        title_change = 'decrease'
    else:  # 'Deboosted' and 'Bottom'
        logprobs = bottom_neuron_neg_difference_logprobs
        indices = bottom_neg_indices
        title_change = 'decrease'

    xlabel = boosted_deboosted + " tokens"
    ylabel = "full_logprob - ablated_logprop"
    title = 'Logprob ' + title_change + ' from ' + top_bottom.lower() + ' neurons'
    xticks = [model.to_str_tokens([i])[0] for i in indices]

    haystack_utils.line(logprobs.cpu().numpy(), xlabel=xlabel, ylabel=ylabel, title=title, xticks=xticks, show_legend=False)


boosted_deboosted = widgets.Dropdown(options=['Boosted', 'Deboosted'], value='Boosted', description='Tokens:')
top_bottom = widgets.Dropdown(options=['Top', 'Bottom'], value='Top', description='Neurons:')
widgets.interactive(plot_data, boosted_deboosted=boosted_deboosted, top_bottom=top_bottom)


interactive(children=(Dropdown(description='Tokens:', options=('Boosted', 'Deboosted'), value='Boosted'), Drop…

## Decompose neuron effects

In [25]:
# Get logprob change for individual neuron

def get_individual_neuron_logprob_effect(sorted_indices, neuron_pos=0, top=True, positive=True, plot=True):
    if top:
        neurons = sorted_indices[-(neuron_pos+1)]
    else:
        neurons = sorted_indices[neuron_pos]
    neurons_hook = get_ablate_neurons_hook([neurons], ablated_cache)

    original_logits, ablated_logprobs, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks, return_type='logprobs')
    _, _, _, neuron_ablated_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+neurons_hook, return_type='logprobs')
    neuron_difference_logprobs, difference_indices = \
        get_top_difference_neurons(all_MLP5_logprobs, neuron_ablated_logprobs, positive=positive)
    
    if not plot:
        return neuron_difference_logprobs, difference_indices
    
    if positive:
        xlabel = 'Boosted tokens'
        title = 'Logprob increase from neuron ' + str(neurons.item())
    else: 
        xlabel = 'Deboosted tokens'
        title = 'Logprob decrease from neuron ' + str(neurons.item())
    if top:
        title += f' (top {neuron_pos+1} neuron)'
    else:
        title += f' (bottom {neuron_pos+1} neuron)'

    xticks = [model.to_str_tokens([i])[0] for i in difference_indices]
    haystack_utils.line(neuron_difference_logprobs.cpu().numpy(), xlabel=xlabel, ylabel="full_logprob - ablated_logprop", title=title, xticks=xticks, show_legend=False)

In [26]:
# Define a function to call your function with widget inputs
def plot_individual_neuron(neuron_pos, top_bottom, pos_neg):
    top_bool = True if top_bottom == 'Top' else False
    positive_bool = True if pos_neg == 'Positive' else False
    get_individual_neuron_logprob_effect(indices, neuron_pos=neuron_pos, top=top_bool, positive=positive_bool)

# Define widgets
neuron_pos_slider = widgets.IntSlider(min=0, max=20, step=1, value=0, description='Neuron Pos:')
top_bottom_dropdown = widgets.Dropdown(options=['Top', 'Bottom'], value='Top', description='Neurons:')
pos_neg_dropdown = widgets.Dropdown(options=['Positive', 'Negative'], value='Positive', description='Pos/Neg:')

# Use interactive function to bind the widgets to the plotting function
widgets.interactive(plot_individual_neuron, neuron_pos=neuron_pos_slider, top_bottom=top_bottom_dropdown, pos_neg=pos_neg_dropdown)


interactive(children=(IntSlider(value=0, description='Neuron Pos:', max=20), Dropdown(description='Neurons:', …

In [18]:
# Sum individual weighted positive / negative effects
def summed_neuron_differences(num_neurons=10, top=True, positive=True, k=50):
    summed_neuron_diffs = torch.zeros(model.cfg.d_vocab).cuda()
    for neuron_pos in range(num_neurons):
        # It would probably be more principle to do our filtering on the summed diffs instead of per neuron
        neuron_diff, token_indices = get_individual_neuron_logprob_effect(indices, neuron_pos=neuron_pos, top=top, positive=positive, plot=False)
        summed_neuron_diffs[token_indices] += neuron_diff
    if positive:
        non_zero_count = (summed_neuron_diffs > 0).sum()
    else:
        non_zero_count = (summed_neuron_diffs < 0).sum()
    top_logprob_difference, top_tokens = haystack_utils.top_k_with_exclude(summed_neuron_diffs, min(non_zero_count, k), all_ignore, largest=positive)
    return top_logprob_difference, top_tokens

def plot_summed_neuron_differences(num_neurons=10, top=True, positive=True):
    top_logprob_difference, top_logprob_tokens = summed_neuron_differences(num_neurons=num_neurons, top=top, positive=positive)
    xticks = [model.to_str_tokens([i])[0] for i in top_logprob_tokens]
    
    if positive:
        title = "Summed individual boosts by"
        xlabel = "Boosted tokens"
    else:
        title = "Summed individual deboosts by"
        xlabel = "Deboosted tokens"
    if top:
        title += f" top {num_neurons} neurons"
    else:
        title += f" bottom {num_neurons} neurons"
    haystack_utils.line(top_logprob_difference.cpu().numpy(), xlabel=xlabel, ylabel="full_logprob - ablated_logprop", title=title, xticks=xticks, show_legend=False)

In [19]:
# Define a function to call your function with widget inputs
def plot_summed_neurons_widget(num_neurons, top_bottom, pos_neg):
    top_bool = True if top_bottom == 'Top' else False
    positive_bool = True if pos_neg == 'Positive' else False
    plot_summed_neuron_differences(num_neurons=num_neurons, top=top_bool, positive=positive_bool)

# Define widgets
num_neuron_slider = widgets.IntSlider(min=1, max=20, step=1, value=10, description='Num Neurons:')
top_bottom_dropdown = widgets.Dropdown(options=['Top', 'Bottom'], value='Top', description='Neurons:')
pos_neg_dropdown = widgets.Dropdown(options=['Positive', 'Negative'], value='Positive', description='Pos/Neg:')

# Use interactive function to bind the widgets to the plotting function
widgets.interactive(plot_summed_neurons_widget, num_neurons=num_neuron_slider, top_bottom=top_bottom_dropdown, pos_neg=pos_neg_dropdown)

interactive(children=(IntSlider(value=10, description='Num Neurons:', max=20, min=1), Dropdown(description='Ne…

Loss is logprob(correct_token_logit) = logit(correct_token_logit) - LogSumExp(all_token_logits)
Loss = -logprob

LogSumExp approximates a maximum function. If the neuron engages in destructive interference of a high logit for a non-answer token, then the exp(logit) for the token will be lower and so the LogSumExp will be more similar to logit(correct_token_logit) so the loss will be lower. So a lower logsumexp(all vocab) is good.

If the neuron engages in destructive interference of a low logit for a non-answer token, then the exp(logit) for the the token won't really change and so the logsumexp and the loss will both be the same.

For a single neuron:
1. For each "gen" prompt, zero centre the logits and record the logit of "gen". Calculate the mean "gen" logit.
2. For each "gen" prompt, disable the neuron under test, zero centre the logits, edit the logit of "gen" to the mean with a hook, then get the logprob of "gen" and calculate logsumexp(all vocab) = logit(gen) - logprob(gen).
2. For each "gen" prompt, enable the neuron, zero centre the logits, edit the logit of "gen" to the mean with a hook, then get the logprob of "gen" and calculate logsumexp(all vocab) = logit(gen) - logprob(gen).
3. Take the difference in logsum exps. If it's positive, the neuron is reducing the loss via destructive interference by the difference.
Can use the same procedure for sets of neurons, or for all neurons, to find high level effects of the context neuron

Logprobs are logits with a constant subtracted and the constant is the same for every logit within a prompt.

Taking the difference in terms with and without a neuron's effect via the context neuron:
- If log sum exp increases, the neuron is boosting tokens on average. 
- If logit increases, the neuron is boosting the correct token


Remove the neuron's effects on the gen logit. Take the mean on the prompt and position?

In [29]:
# 1. Calculate the mean "gen" logit.
gen_index = model.to_single_token('icht')
gen_logits = []
logits = model(prompts, return_type='logits') # batch pos vocab
logits = logits - logits.mean(-1).unsqueeze(-1) # batch pos vocab, batch pos 1

mean_gen_logit = logits[:, -2, gen_index].mean(0)
mean_gen_logit

tensor(38.6459, device='cuda:0')

In [30]:
import plotly.express as px
import numpy as np
# from transformer_lens import utils
# 
# px.histogram(np.random.choice(logits[:, -1, gen_index].flatten().cpu().numpy(), 1000), nbins=100)
# utils.test_prompt("".join(model.to_str_tokens(prompts[0, :-1])), "gen", model, prepend_space_to_answer=False)

In [31]:
def decompose_interference_diff(ablated_cache, disabled_neurons=[top_neurons[0]], mean=True
                                ) -> tuple[float, float] | tuple[Float[Tensor, "n_prompts"], Float[Tensor, "n_prompts"]]:
    '''
    Finds the effect of the German context neuron ablation via a given set of MLP5 neurons on the logit of each final 
    token. Decomposes it into constructive and destructive interference. A positive constructive interference difference
    means the neurons boost the logit of the correct token, a positive destructive interference difference means the
    LogSumExp is getting closer to the correct token logit because neurons deboost the logits of the higher probability 
    incorrect tokens.

    Loss = -logprob(correct_token_logit)
    logprob(correct_token_logit) = correct_token_logit - LogSumExp(all_token_logits)

    LogSumExp is a smooth maximum function, so it approximates the max of all logits. A neuron that destructively 
    interfers with a non-answer token with a high logit lowers the LogSumExp(all_token_logits) and thus the loss.'''
    ablate_top_neuron_hook = get_ablate_neurons_hook(disabled_neurons, ablated_cache)

    _, _, _, mlp5_enabled_logits = haystack_utils.get_direct_effect(prompts, model, pos=-2, 
                                                                    context_ablation_hooks=deactivate_neurons_fwd_hooks, 
                                                                    context_activation_hooks=activate_neurons_fwd_hooks, return_type='logits')
    _, _, _, mlp5_enabled_top_neuron_ablated_logits = haystack_utils.get_direct_effect(prompts, model, pos=-2, 
                                                                                       context_ablation_hooks=deactivate_neurons_fwd_hooks, 
                                                                                       context_activation_hooks=activate_neurons_fwd_hooks+ablate_top_neuron_hook, return_type='logits')
    
    # Mean center logits to avoid picking up on constant boosts / deboosts
    mlp5_enabled_logits = mlp5_enabled_logits - mlp5_enabled_logits.mean(-1).unsqueeze(-1)
    mlp5_enabled_top_neuron_ablated_logits = mlp5_enabled_top_neuron_ablated_logits - mlp5_enabled_top_neuron_ablated_logits.mean(-1).unsqueeze(-1)
    
    # 1. Constructive interference difference
    # This is the change in the correct answer token logit from ablating the neuron, positive is good
    constructive_interference_diffs = mlp5_enabled_logits[:, gen_index] - mlp5_enabled_top_neuron_ablated_logits[:, gen_index]

    # 2. Destructive interference difference
    # This is the change in the LogSumExp of all logits from not ablating the neuron positive/increase in LogSumExp is bad 
    # even though it means that all tokens get deboosted more by a constant amount because the LogSumExp is guaranteed to be
    # >= the correct token logit

    # Set the gen logit to its mean value so the neuron's constructive interference doesn't affect the LogSumExp difference
    mlp5_enabled_logits[:, gen_index] = mean_gen_logit
    mlp5_enabled_top_neuron_ablated_logits[:, gen_index] = mean_gen_logit
    
    # Compute logsumexp
    mlp5_enabled_log_sum_exp = mlp5_enabled_logits.exp().sum(-1).log()
    mlp5_enabled_top_neuron_ablated_log_sum_exp = mlp5_enabled_top_neuron_ablated_logits.exp().sum(-1).log()

    # Check for errors
    assert torch.allclose(mlp5_enabled_log_sum_exp, mlp5_enabled_logits[:, gen_index] - mlp5_enabled_logits.log_softmax(-1)[:, gen_index])
    assert torch.allclose(mlp5_enabled_top_neuron_ablated_log_sum_exp, mlp5_enabled_top_neuron_ablated_logits[:, gen_index] - mlp5_enabled_top_neuron_ablated_logits.log_softmax(-1)[:, gen_index])

    # Difference in logsumexp
    # Logsumexp of enabled should be higher than ablated if the neuron does something good
    # Negative results are good - they mean that all tokens are deboosted more when the neuron is active
    destructive_interference_diffs = mlp5_enabled_log_sum_exp - mlp5_enabled_top_neuron_ablated_log_sum_exp
    # Convert the results so positive is good
    destructive_interference_diffs *= -1

    if mean:
        return constructive_interference_diffs.mean().item(), destructive_interference_diffs.mean().item(), 
    return constructive_interference_diffs, destructive_interference_diffs,


destructive_diffs = []
# Calculate neuron-wise loss change
print("constructive interference diff, destructive interference diff")
for i in range(10):
    constructive_diff, destructive_diff = decompose_interference_diff(ablated_cache, disabled_neurons=[top_neurons[i]], mean=True)
    print(f'{constructive_diff:2f}', f'{destructive_diff:2f}')
    destructive_diffs.append(destructive_diff)

constructive_diff, destructive_diff = decompose_interference_diff(ablated_cache, top_neurons)
print(f'{constructive_diff:2f}', f'{destructive_diff:2f}')

# line plot of logsumexp
haystack_utils.plot_barplot([[item] for item in destructive_diffs], names=[f"Neuron {i}" for i in range(10)], ylabel="LogSumExp reduction", title="LogSumExp reduction from ablating top neurons")

constructive interference diff, destructive interference diff
0.053818 0.069388
0.157467 0.031159
0.273463 -0.016081
0.163135 0.037566
0.089304 0.073434
0.092345 0.071882
-0.005005 0.118236
0.256493 0.025861
-0.076165 0.150736
0.625586 -0.014790
1.725928 0.645747


### Neuron -> Token Logit Analysis

In [23]:
# Our top neurons are selected by the difference in their boost in gen based on the context neuron
# Many other neurons boost gen more per unit of activation

cosine_sim = torch.nn.CosineSimilarity(dim=1)
answer_residual_direction = model.tokens_to_residual_directions("gen")
neuron_weights = model.state_dict()['blocks.5.mlp.W_out'][top_neurons]

cosine_sims = cosine_sim(neuron_weights, answer_residual_direction.unsqueeze(0))
print(cosine_sims)

cosine_sim = torch.nn.CosineSimilarity(dim=1)
answer_residual_direction = model.tokens_to_residual_directions("gen")
neuron_weights = model.state_dict()['blocks.5.mlp.W_out']

cosine_sims = cosine_sim(neuron_weights, answer_residual_direction.unsqueeze(0))
top, indices = torch.topk(cosine_sims, 20)
print(top)


tensor([ 0.0346, -0.0237, -0.0347,  0.0397,  0.0595,  0.0501, -0.0274,  0.1405,
         0.0417, -0.0392], device='cuda:0')
tensor([0.1634, 0.1461, 0.1405, 0.1399, 0.1349, 0.1207, 0.1196, 0.1161, 0.1156,
        0.1139, 0.1130, 0.1128, 0.1120, 0.1104, 0.1096, 0.1070, 0.1056, 0.1048,
        0.1042, 0.1028], device='cuda:0')


In [24]:
haystack_utils.clean_cache()
english_activations_l5 = {}
german_activations_l5 = {}
all_activations_l5 = {}
for layer in range(5, 6):
    english_activations_l5 = get_mlp_activations(english_data, layer, model, mean=True)
    german_activations_l5 = get_mlp_activations(german_data, layer, model, mean=True)
    all_activations_l5 = get_mlp_activations(german_data[:50] + english_data[:150], layer, model, mean=True)

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

### Gen dataframe

In [25]:
import pandas as pd
import einops

W_out = model.state_dict()['blocks.5.mlp.W_out']
W_U = model.W_U
avg_W_out_all = einops.einsum(all_activations_l5, W_out, 'd_mlp, d_mlp d_model -> d_mlp d_model')
avg_mlp5_boosts_all = einops.einsum(avg_W_out_all, W_U, 'd_mlp d_model, d_model d_vocab -> d_mlp d_vocab')
avg_W_out_enabled = einops.einsum(german_activations_l5, W_out, 'd_mlp, d_mlp d_model -> d_mlp d_model')
avg_mlp5_boosts_enabled = einops.einsum(avg_W_out_enabled, W_U, 'd_mlp d_model, d_model d_vocab -> d_mlp d_vocab')
avg_W_out_disabled = einops.einsum(english_activations_l5, W_out, 'd_mlp, d_mlp d_model -> d_mlp d_model')
avg_mlp5_boosts_disabled = einops.einsum(avg_W_out_disabled, W_U, 'd_mlp d_model, d_model d_vocab -> d_mlp d_vocab')

prompt_strs = [model.tokenizer.decode(prompts[i].tolist()) for i in range(prompts.shape[0])]

gen_acts_l5_all = get_mlp_activations(prompt_strs, 5, model, mean=True, pos=-2)
with model.hooks(deactivate_neurons_fwd_hooks):
    gen_acts_l5_disabled = get_mlp_activations(prompt_strs, 5, model, mean=True, pos=-2)
with model.hooks(activate_neurons_fwd_hooks):
    gen_acts_l5_enabled = get_mlp_activations(prompt_strs, 5, model, mean=True, pos=-2)

with model.hooks(deactivate_neurons_fwd_hooks):
    all_gen_acts_l5_disabled = get_mlp_activations(prompt_strs, 5, model, mean=False, pos=-2)
with model.hooks(activate_neurons_fwd_hooks):
    all_gen_acts_l5_enabled = get_mlp_activations(prompt_strs, 5, model, mean=False, pos=-2)

data = {
    'firing count diff': ((all_gen_acts_l5_enabled > 0).sum(0) - (all_gen_acts_l5_disabled > 0).sum(0)).tolist(),
    # 'neuron index': list(range(2048)),
    'cos sim gen': cosine_sims.tolist(),
    # 'avg act (enabled)': german_activations_l5.tolist(),
    # 'avg act (disabled)': english_activations_l5.tolist(),
    'avg act increase enabled': (german_activations_l5 - english_activations_l5).tolist(),
    'avg act': all_activations_l5.tolist(),
    'avg gen act': gen_acts_l5_all.tolist(),
    'avg gen act enabled': gen_acts_l5_enabled.tolist(),
    'avg gen act disabled': gen_acts_l5_disabled.tolist(),
    # 'avg boost gen': avg_mlp5_boosts_all[:, gen_index].tolist(),
    # 'avg boost enabled gen': avg_mlp5_boosts_enabled[:, gen_index].tolist(),
    # 'avg boost disabled gen': avg_mlp5_boosts_disabled[:, gen_index].tolist(),
    'enabled firing count': (all_gen_acts_l5_enabled > 0).sum(0).tolist(),
    'disabled firing count': (all_gen_acts_l5_disabled > 0).sum(0).tolist(),
    'firing count': (all_gen_acts_l5_disabled.bool().sum(0).tolist())
    
    # 'Loss change when ablated': 
}

df = pd.DataFrame(data)
# Line plot of each neuron's gen boost with dotted vertical lines where the German context neuron-boosted ones are

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [26]:
for prompt in english_data:
    prompt_tokens = model.to_tokens(prompt)
    if gen_index in prompt_tokens:
        index = prompt_tokens[0].tolist().index(gen_index)
        print(model.to_str_tokens(prompt_tokens[0, index-5:index+5]))

[' and', ' creating', ' the', ' Sc', 'hen', 'gen', ' area', '.', ' From', ' 2012']
[' Application', ' of', ' the', ' Sc', 'hen', 'gen', ' acquis', ' relating', ' to', ' the']
[' Bulgaria', ' to', ' the', ' Sc', 'hen', 'gen', ' Area', ',', ' nor', ' the']
[' can', ' enter', ' the', ' Sc', 'hen', 'gen', ' area', ' without', ' having', ' to']
[' to', ' restrict', ' the', ' Sc', 'hen', 'gen', ' area', ',', ' compatible', ' with']


In [27]:
# Weight analysis
gen_df = df[(df['cos sim gen'] > 0)]
print(f"{len(gen_df)}/{len(df)} MLP5 neurons generically increase \"gen\", {len(gen_df)/len(df)*100:.2f}%")

# Activation analysis
# All on German prompts
filtered_df = gen_df[(gen_df['avg act increase enabled'] > 0)]
print(f"{len(filtered_df)} of these fire more when context neuron enabled, {len(filtered_df)/len(df)*100:.2f}%")
filtered_df = gen_df[(gen_df['avg act increase enabled'] < 0)]
print(f"{len(filtered_df)} of these fire less when context neuron enabled, {len(filtered_df)/len(df)*100:.2f}%")

filtered_df = gen_df[(gen_df['avg gen act enabled'] > 0)]
print(f"{len(filtered_df)} of these fire on avg when context neuron enabled, {len(filtered_df)/len(df)*100:.2f}%")
filtered_df = gen_df[(gen_df['avg gen act disabled'] > 0)]
print(f"{len(filtered_df)} of these fire on avg when context neuron disabled, {len(filtered_df)/len(df)*100:.2f}%")

pd.set_option('display.max_rows', 100)

sorted = gen_df.sort_values(by=['firing count diff'], ascending=False)

sorted.head(15)

1107/2048 MLP5 neurons generically increase "gen", 54.05%
461 of these fire more when context neuron enabled, 22.51%
646 of these fire less when context neuron enabled, 31.54%
277 of these fire on avg when context neuron enabled, 13.53%
250 of these fire on avg when context neuron disabled, 12.21%


Unnamed: 0,firing count diff,cos sim gen,avg act increase enabled,avg act,avg gen act,avg gen act enabled,avg gen act disabled,enabled firing count,disabled firing count,firing count
13,93,0.057594,0.739129,0.160572,0.25159,0.15186,-0.124767,94,1,100
1927,92,0.025486,0.347759,0.020567,0.212794,0.149408,-0.075781,96,4,100
2038,91,0.003579,0.793564,0.159707,0.416826,0.289465,-0.061763,100,9,100
1667,90,0.03845,0.321733,0.052155,0.176787,0.096939,-0.121565,90,0,100
1716,88,0.146135,1.097669,0.571957,0.511899,0.311726,-0.100262,98,10,100
987,88,0.005353,0.352996,0.051033,0.144142,0.081435,-0.109428,89,1,100
84,87,0.041687,0.345945,-0.001036,0.625658,0.441772,-0.063952,100,13,100
358,87,0.008868,0.30744,0.041471,0.372113,0.258444,-0.056477,100,13,100
1877,87,0.030004,0.049302,0.190596,0.420401,0.318768,-0.066352,100,13,100
1747,86,0.053287,0.200168,-0.063503,0.163022,0.089683,-0.109547,89,3,100


Interested in how often the context neuron makes the difference between firing and not, rather than firing on average.
Neurons with top difference between number of fires with neuron enabled vs. disabled
x% more likely to fire
