## Setup

In [2]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
from datasets import load_dataset

pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [3]:
# Neuron(s) under study: NEURONS_TO_ABLATE in the MLP of layer LAYER_TO_ABLATE
LAYER_TO_ABLATE = 4
NEURONS_TO_ABLATE = [663]

# First 500 examples of each corpus
uspto_data = haystack_utils.load_txt_data('data/uspto.txt')[:500]
non_uspto_data = haystack_utils.load_txt_data('data/non_uspto.txt')[:500]

# Aliases used by the rest of the notebook (names presumably kept from an earlier German/English analysis):
# USPTO text plays the "german" role, non-USPTO text the "english" role.
german_data, english_data = uspto_data, non_uspto_data


data/uspto.txt: Loaded 3509 examples with 0 to 12506 characters each.
data/non_uspto.txt: Loaded 85495 examples with 0 to 61007 characters each.


In [4]:
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

# MLP activations for the first 200 prompts of each dataset, computed only for the layer under study.
# Iterating over [LAYER_TO_ABLATE] (instead of a hard-coded range(4, 5)) keeps the lookups below valid
# when LAYER_TO_ABLATE is changed.
english_activations = {}
german_activations = {}
for layer in [LAYER_TO_ABLATE]:
    english_activations[layer] = get_mlp_activations(english_data[:200], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:200], layer, model, mean=False)

# Mean activation of the studied neuron(s) on each dataset; the hooks clamp the neurons to these values.
# NOTE(review): assumes get_mlp_activations(mean=False) returns (n_tokens, d_mlp) — TODO confirm.
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    # Forward hook: overwrite the studied neurons with MEAN_ACTIVATION_INACTIVE (global, read at call time).
    # NOTE(review): assumes value is (batch, pos, d_mlp) — TODO confirm.
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value


def activate_neurons_hook(value, hook):
    # Forward hook: overwrite the studied neurons with MEAN_ACTIVATION_ACTIVE (global, read at call time).
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value


# Both hooks attach to the post-activation of the MLP in layer LAYER_TO_ABLATE
deactivate_neurons_fwd_hooks = [(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]
activate_neurons_fwd_hooks = [(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

# `all_ignore` is excluded from the token rankings further down
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

## Find top common tokens in the USPTO data (stored as `german_data`)

In [5]:
# Count token frequencies over the USPTO prompts (`german_data`), excluding punctuation and ignored tokens.
# Allocated on `device` rather than via .cuda() so the cell also runs on CPU-only machines.
token_counts = torch.zeros(model.cfg.d_vocab, device=device)
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    # Vectorised per-prompt histogram instead of a Python loop over every token
    token_counts += torch.bincount(tokens[0], minlength=model.cfg.d_vocab).to(token_counts.device)

punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
# Take the token at position 1 of each string (position 0 is presumably the BOS token added by to_tokens)
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1]
token_counts[punctuation_tokens] = 0
token_counts[all_ignore] = 0

top_counts, top_tokens = torch.topk(token_counts, 100)
print(model.to_str_tokens(top_tokens))

  0%|          | 0/500 [00:00<?, ?it/s]

[' the', ' a', ' of', ' to', ' and', ' is', ' in', ' for', ' an', ' are', ' as', ' be', ' by', ' or', ' which', ' with', ' that', ' on', ' such', ' The', ' from', ' at', ' can', ' invention', ' data', ' fluid', ' method', ' may', '2', ' it', ' has', 'phase', ' one', ' not', ' In', ' present', ' have', ' example', ' this', ' been', ' when', ' used', ' system', ' device', ' metal', ' motor', ' into', ' light', ' other', ' cell', ' having', ' part', ' chamber', 'ting', ' layer', '.,', ' speed', 'x', ' will', ' use', 'The', ' through', ' high', ' processing', ' 1', ' micro', ' water', ' mobile', ' jet', ' U', ' order', ' gas', ' A', ' there', ' process', '1', ' using', ' material', ' more', ' signal', ' between', ' power', ' current', ' portion', 'In', ' 3', ' projection', ' devices', ' position', ' apparatus', ' also', ' hand', ' body', 'ing', ' each', ' provided', ' electron', ' nozzle', 'A', ' time']


In [None]:
def interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss):
    """Per-position score of how interesting a token is for the MLP5 path analysis.

    score = 0.5 * (ablated_loss - original_loss) - max(only_activated_loss - original_loss, 0)

    High when ablating the studied neuron(s) raises the loss a lot while the "only activated"
    run raises it little. Positions whose original loss exceeds 6, or where ablation lowers the
    loss, score 0. `context_and_activated_loss` is accepted for interface symmetry but unused.
    Inputs are not modified; a new tensor is returned.
    """
    ablation_increase = ablated_loss - original_loss
    leftover_increase = (only_activated_loss - original_loss).clamp(min=0)
    score = 0.5 * ablation_increase - leftover_increase
    ignored = (original_loss > 6) | (original_loss > ablated_loss)
    return torch.where(ignored, torch.zeros_like(score), score)

def print_prompt(prompt: str):
    """Render `prompt` token by token, coloured by `interest_measure`, with the four loss series as extra measures."""
    str_token_prompt = model.to_str_tokens(model.to_tokens(prompt))
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = haystack_utils.get_direct_effect(
        prompt, model, pos=None,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks,
        activated_components=('blocks.5.hook_mlp_out',),
        deactivated_components=('blocks.5.hook_attn_out',),
    )
    # interest_measure takes all four loss tensors; passing only two raises a TypeError.
    pos_wise_diff = interest_measure(
        original_loss, ablated_loss, context_and_activated_loss, only_activated_loss
    ).flatten().cpu().tolist()

    loss_list = [loss.flatten().cpu().tolist() for loss in [original_loss, ablated_loss, context_and_activated_loss, only_activated_loss]]
    loss_names = ["original_loss", "ablated_loss", "context_and_activated_loss", "only_activated_loss"]
    # Drop the first string token (presumably BOS) so the labels line up with the per-position losses
    haystack_utils.clean_print_strings_as_html(str_token_prompt[1:], pos_wise_diff, max_value=5, additional_measures=loss_list, additional_measure_names=loss_names)

In [7]:
def get_mlp5_decrease_measure(losses: list[tuple[Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"]]]):
    """Score each prompt by its highest per-position `interest_measure` value.

    `losses` holds one (original, ablated, context_and_activated, only_activated) tuple per prompt;
    returns a list of Python floats in the same order.
    """
    return [
        interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss).max().item()
        for original_loss, ablated_loss, context_and_activated_loss, only_activated_loss in losses
    ]

# Per-position losses for every USPTO prompt (`german_data`)
german_losses = [
    haystack_utils.get_direct_effect(
        prompt, model, pos=None,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks,
        activated_components=('blocks.5.hook_mlp_out',),
        deactivated_components=('blocks.5.hook_attn_out',),
    )
    for prompt in tqdm(german_data)
]
measure = get_mlp5_decrease_measure(german_losses)

# (prompt index, score) pairs, highest score first; sorted() is stable, so ties keep dataset order
sorted_measure = sorted(enumerate(measure), key=lambda pair: pair[1], reverse=True)

In [8]:
# Show the three highest-scoring prompts
for prompt_index, _score in sorted_measure[:3]:
    print_prompt(german_data[prompt_index])

In [9]:
# Plotting function bound to the neuron activation/deactivation hooks, with MLP5 output as the
# activated component and layer-5 attention output as the deactivated one (see haystack_utils)
average_loss_plot = haystack_utils.get_average_loss_plot_method(
    activate_neurons_fwd_hooks,
    deactivate_neurons_fwd_hooks,
    "MLP4",
    activated_components=("blocks.5.hook_mlp_out",),
    deactivated_components=("blocks.5.hook_attn_out",),
)

prompt = "As shown in FIG"
average_loss_plot([prompt], model, token="FIG")

In [None]:
# NOTE(review): this cell repeats the previous cell's plot call verbatim; consider deleting it.
prompt = "As shown in FIG"
average_loss_plot([prompt], model, token="FIG")

## Analysis of ngrams preceded by random prompts

In [10]:
def get_random_selection(tensor, n=12):
    """Return up to `n` elements of `tensor` (along dim 0), sampled without replacement.

    Stand-in for np.random.choice(..., replace=False); uses torch's global RNG.
    """
    permutation = torch.randperm(len(tensor))
    chosen = permutation[:n]
    return tensor[chosen]

def generate_random_prompts(end_string, n=50, length=12):
    """Build a batch of `n` prompts, each `length` random common tokens followed by the tokens of `end_string`.

    The random prefix is drawn without replacement from the first max(50, length) entries of the
    global `top_tokens`. The leading token of `model.to_tokens(end_string)` (presumably BOS) is
    dropped, and no BOS is added to the prompts.

    Returns an integer tensor of shape (n, length + number of end tokens), provided `top_tokens`
    has at least `length` entries. It is placed on the same device as the model's tokens
    (previously hard-coded to CUDA).
    """
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    token_pool = top_tokens[:max(50, length)]
    prompts = [
        torch.cat([get_random_selection(token_pool, n=length).to(end_tokens.device), end_tokens])
        for _ in range(n)
    ]
    return torch.stack(prompts)

def replace_column(prompts: Int[Tensor, "n_prompts n_tokens"], token_index: int):
    """Return a copy of `prompts` whose column `token_index` holds random common tokens.

    Each row gets a distinct token, sampled without replacement from the first
    max(50, n_prompts) entries of the global `top_tokens` (the most frequent tokens of the
    dataset). The input tensor is not modified.
    """
    n_prompts = prompts.shape[0]
    new_prompts = prompts.clone()
    random_tokens = get_random_selection(top_tokens[:max(50, n_prompts)], n=n_prompts)
    # Match the prompts' device instead of forcing CUDA
    new_prompts[:, token_index] = random_tokens.to(new_prompts.device)
    return new_prompts

def loss_analysis(prompts: Tensor, title=""):
    """Bar-plot last-position loss for a batch of prompts under three conditions.

    The conditions are the first, second and fourth losses returned by
    haystack_utils.get_direct_effect with the neuron ablation/activation hooks, labelled
    "Original", "Ablated" and "MLP5 path patched".
    """
    original_loss, ablated_loss, _, only_activated_loss = haystack_utils.get_direct_effect(
        prompts,
        model,
        pos=-1,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks,
    )
    losses = [loss.tolist() for loss in (original_loss, ablated_loss, only_activated_loss)]
    haystack_utils.plot_barplot(losses, ["Original", "Ablated", "MLP5 path patched"], ylabel="Loss", title=title)

In [11]:
def loss_analysis_random_prompts(end_string, n=50, length=12, replace_columns: list[int] | None = None):
    """Plot last-token loss on `n` random prompts of `length` tokens ending in `end_string`.

    If `replace_columns` is given, each listed token position is overwritten with fresh random
    tokens before the losses are computed. The title records which tokens the first prompt
    originally had at those positions.
    """
    prompts = generate_random_prompts(end_string, n=n, length=length)
    title = f"Average last token loss on {length} random tokens ending in '{end_string}'"
    if replace_columns is None:
        loss_analysis(prompts, title=title)
        return

    replaced_tokens = model.to_str_tokens(prompts[0, replace_columns])
    title += f" replacing {replaced_tokens}"
    for column in replace_columns:
        prompts = replace_column(prompts, column)
    loss_analysis(prompts, title=title)

In [12]:
# Example usage
# loss_analysis_random_prompts(" Vorschlägen", n=100, length=20)
# loss_analysis_random_prompts(" Vorschlägen", n=100, length=20, replace_columns=[-2])

## Widget

In [13]:
# Ngram that ends every random prompt
dropdown = widgets.Dropdown(
    options=[" Vorschlägen", " Vorschläge", " häufig", " schließt", " beweglich"],
    value=" Vorschlägen",
    description='Ngram: ',
)

# Token positions (negative, i.e. counted from the end) to overwrite with random tokens; 'None' disables replacement
replace_columns_dropdown = widgets.SelectMultiple(
    options=['-2', '-3', '-4', 'None'],
    value=['None'],
    description='Replace Columns:',
)

# Output area that holds the current plot
output = widgets.Output()


def update_plot(*args):
    """Redraw the loss plot for the current widget selections (also used as the observe callback)."""
    with output:
        clear_output(wait=True)
        selected_columns = replace_columns_dropdown.value
        # Selecting 'None' disables replacement even if other columns are selected as well
        replace_columns = None if 'None' in selected_columns else [int(val) for val in selected_columns]
        loss_analysis_random_prompts(dropdown.value, n=100, length=20, replace_columns=replace_columns)


# Redraw whenever either widget changes
dropdown.observe(update_plot, 'value')
replace_columns_dropdown.observe(update_plot, 'value')

display(dropdown, replace_columns_dropdown, output)

# Draw the initial plot
update_plot()


Dropdown(description='Ngram: ', options=(' Vorschlägen', ' Vorschläge', ' häufig', ' schließt', ' beweglich'),…

SelectMultiple(description='Replace Columns:', index=(3,), options=('-2', '-3', '-4', 'None'), value=('None',)…

Output()

## Identify relevant neurons

In [14]:
# 100 prompts: 20 random common tokens followed by the tokens of " Vorschlägen"
prompts = generate_random_prompts(end_string=" Vorschlägen", n=100, length=20)
print(prompts.shape)

torch.Size([100, 24])


In [15]:
# Calculate neuron-wise loss change.
# First cache all activations from a run with the deactivation hooks applied; the per-neuron
# hooks below copy values out of this cache.
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    _, ablated_cache = model.run_with_cache(prompts)

def get_ablate_neurons_hook(neuron: int | list[int], ablated_cache):
    def ablate_neurons_hook(value, hook):
        value[:, :, neuron] = ablated_cache['blocks.5.mlp.hook_post'][:, :, neuron]
        return value
    return [('blocks.5.mlp.hook_post', ablate_neurons_hook)]

# One row per MLP neuron (model.cfg.d_mlp, 2048 for pythia-70m), one column per prompt
diffs = torch.zeros(model.cfg.d_mlp, prompts.shape[0])
# Baseline: loss with the MLP5 neurons path patched
_, _, _, baseline_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
for neuron in tqdm(range(model.cfg.d_mlp)):
    ablate_single_neuron_hook = get_ablate_neurons_hook(neuron, ablated_cache)
    # Same path-patched run, but with this single MLP5 neuron set back to its value in the ablated run
    _, _, _, only_activated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks + ablate_single_neuron_hook)
    diffs[neuron] = only_activated_loss - baseline_loss

print(diffs.mean())

  0%|          | 0/2048 [00:00<?, ?it/s]

In [None]:
# Rank neurons by mean loss change over prompts (ascending); `indices` is reused by later cells
sorted_means, indices = diffs.mean(dim=1).sort()
sorted_means = sorted_means.tolist()
haystack_utils.line(sorted_means, xlabel="Sorted neurons", ylabel="Loss change", title="Loss change from ablating MLP5 neuron")  # xticks=indices

In [None]:
# Check loss change when ablating top / bottom neurons.
# `indices` is sorted by ascending mean loss change, so its last entries are the neurons whose
# individual reversion raised the loss most, and its first entries those that raised it least.
top_neurons_count = 10
top_neurons = indices[-top_neurons_count:]
bottom_neurons = indices[:top_neurons_count]

with model.hooks(deactivate_neurons_fwd_hooks):
    # Only the cache is needed here; `ablated_loss` is computed by get_direct_effect below
    _, ablated_cache = model.run_with_cache(prompts, return_type="loss")

ablate_top_neurons_hook = get_ablate_neurons_hook(top_neurons, ablated_cache)
ablate_bottom_neurons_hook = get_ablate_neurons_hook(bottom_neurons, ablated_cache)

original_loss, ablated_loss, _, all_MLP5_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks)
_, _, _, top_MLP5_ablated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks + ablate_top_neurons_hook)
_, _, _, bottom_MLP5_ablated_loss = haystack_utils.get_direct_effect(prompts, model, pos=-1, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks + ablate_bottom_neurons_hook)

names = ["Original", "Ablated", "MLP5 path patched", f"MLP5 path patched + Top {top_neurons_count} MLP5 neurons ablated", f"MLP5 path patched + Bottom {top_neurons_count} MLP5 neurons ablated"]
short_names = ["Original", "Ablated", "MLP5 path patched", "Top MLP5 removed", "Bottom MLP5 removed"]

values = [original_loss.tolist(), ablated_loss.tolist(), all_MLP5_loss.tolist(), top_MLP5_ablated_loss.tolist(), bottom_MLP5_ablated_loss.tolist()]
haystack_utils.plot_barplot(values, names, short_names=short_names, ylabel="Loss", title="Average last token loss when removing top / bottom neurons from path patching")

## Investigate top / bottom neuron boosts

In [None]:
# Logprobs (at position -2) for the MLP5-path-patched run, with and without the top / bottom
# neurons set back to their ablated values. The ablated cache and hooks are rebuilt so this cell is self-contained.
with model.hooks(deactivate_neurons_fwd_hooks):
    ablated_logits, ablated_cache = model.run_with_cache(prompts)

ablate_top_neurons_hook = get_ablate_neurons_hook(top_neurons, ablated_cache)
ablate_bottom_neurons_hook = get_ablate_neurons_hook(bottom_neurons, ablated_cache)

# NOTE(review): despite its name, `original_logits` is the first value returned with return_type='logprobs'
original_logits, ablated_logprobs, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(
    prompts, model, pos=-2,
    context_ablation_hooks=deactivate_neurons_fwd_hooks,
    context_activation_hooks=activate_neurons_fwd_hooks,
    return_type='logprobs',
)
_, _, _, top_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(
    prompts, model, pos=-2,
    context_ablation_hooks=deactivate_neurons_fwd_hooks,
    context_activation_hooks=activate_neurons_fwd_hooks + ablate_top_neurons_hook,
    return_type='logprobs',
)
_, _, _, bottom_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(
    prompts, model, pos=-2,
    context_ablation_hooks=deactivate_neurons_fwd_hooks,
    context_activation_hooks=activate_neurons_fwd_hooks + ablate_bottom_neurons_hook,
    return_type='logprobs',
)

In [None]:
def get_top_difference_neurons(baseline_logprobs, ablated_logprops, positive=True, logprob_threshold=-7, k=50):
    """Rank vocabulary tokens by their mean logprob change between two runs.

    Despite the function name, the returned indices are token ids (positions along the last axis
    of the logprob tensors), not neuron indices.

    Args:
        baseline_logprobs: logprobs of the reference run; averaged over dim 0 (presumably the prompt batch).
        ablated_logprops: logprobs of the run with some component ablated; same shape.
        positive: rank the largest increases (baseline - ablated > 0) if True, the largest decreases otherwise.
        logprob_threshold: tokens whose mean baseline logprob is below this get a difference of 0.
        k: upper bound on the k passed to the top-k selection. The actual k is capped by how many
           tokens have a change of the requested sign.

    Returns:
        (differences, token_ids) as produced by haystack_utils.top_k_with_exclude, excluding `all_ignore`.
    """
    mean_difference = (baseline_logprobs - ablated_logprops).mean(0)
    mean_difference[baseline_logprobs.mean(0) < logprob_threshold] = 0
    if positive:
        n_candidates = int((mean_difference > 0).sum())
    else:
        n_candidates = int((mean_difference < 0).sum())
    # Plain int: min() of a 0-d tensor and an int returns either type depending on which is smaller
    top_differences, top_token_ids = haystack_utils.top_k_with_exclude(mean_difference, min(n_candidates, k), all_ignore, largest=positive)
    return top_differences, top_token_ids

In [None]:
# Token-level effects of each neuron set. "pos" = boosted tokens (logprob is higher when the set is
# not ablated); "neg" = deboosted tokens (logprob is lower when the set is not ablated).
top_neuron_pos_difference_logprobs, top_pos_indices = get_top_difference_neurons(all_MLP5_logprobs, top_MLP5_ablated_logprobs, positive=True)
top_neuron_neg_difference_logprobs, top_neg_indices = get_top_difference_neurons(all_MLP5_logprobs, top_MLP5_ablated_logprobs, positive=False)

bottom_neuron_pos_difference_logprobs, bottom_pos_indices = get_top_difference_neurons(all_MLP5_logprobs, bottom_MLP5_ablated_logprobs, positive=True)
bottom_neuron_neg_difference_logprobs, bottom_neg_indices = get_top_difference_neurons(all_MLP5_logprobs, bottom_MLP5_ablated_logprobs, positive=False)

In [None]:
def plot_data(boosted_deboosted, top_bottom):
    if boosted_deboosted == 'Boosted' and top_bottom == 'Top':
        logprobs = top_neuron_pos_difference_logprobs
        indices = top_pos_indices
        title_change = 'increase'
    elif boosted_deboosted == 'Boosted' and top_bottom == 'Bottom':
        logprobs = bottom_neuron_pos_difference_logprobs
        indices = bottom_pos_indices
        title_change = 'increase'
    elif boosted_deboosted == 'Deboosted' and top_bottom == 'Top':
        logprobs = top_neuron_neg_difference_logprobs
        indices = top_neg_indices
        title_change = 'decrease'
    else:  # 'Deboosted' and 'Bottom'
        logprobs = bottom_neuron_neg_difference_logprobs
        indices = bottom_neg_indices
        title_change = 'decrease'

    xlabel = boosted_deboosted + " tokens"
    ylabel = "full_logprob - ablated_logprop"
    title = 'Logprob ' + title_change + ' from ' + top_bottom.lower() + ' neurons'
    xticks = [model.to_str_tokens([i])[0] for i in indices]

    haystack_utils.line(logprobs.cpu().numpy(), xlabel=xlabel, ylabel=ylabel, title=title, xticks=xticks, show_legend=False)


boosted_deboosted = widgets.Dropdown(options=['Boosted', 'Deboosted'], value='Boosted', description='Tokens:')
top_bottom = widgets.Dropdown(options=['Top', 'Bottom'], value='Top', description='Neurons:')
widgets.interactive(plot_data, boosted_deboosted=boosted_deboosted, top_bottom=top_bottom)


interactive(children=(Dropdown(description='Tokens:', options=('Boosted', 'Deboosted'), value='Boosted'), Drop…

## Decompose neuron effects

In [None]:
# Get logprob change for an individual neuron

def get_individual_neuron_logprob_effect(sorted_indices, neuron_pos=0, top=True, positive=True, plot=True):
    """Measure (and optionally plot) the token logprob changes caused by a single MLP5 neuron.

    The neuron is picked from `sorted_indices` (assumed ascending, like `indices`): counted from
    the end when `top` is true, from the start otherwise. Logprobs at position -2 of the global
    `prompts` are compared between the MLP5-path-patched run and the same run with that neuron
    set to its value in the global `ablated_cache`; ranking is delegated to get_top_difference_neurons.

    Returns:
        (logprob differences, token ids) when `plot` is false; otherwise draws a line plot and returns None.
    """
    neuron = sorted_indices[-(neuron_pos + 1)] if top else sorted_indices[neuron_pos]
    single_neuron_hook = get_ablate_neurons_hook([neuron], ablated_cache)

    _, _, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(
        prompts, model, pos=-2,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks,
        return_type='logprobs',
    )
    _, _, _, neuron_ablated_logprobs = haystack_utils.get_direct_effect(
        prompts, model, pos=-2,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks + single_neuron_hook,
        return_type='logprobs',
    )
    token_differences, token_ids = get_top_difference_neurons(all_MLP5_logprobs, neuron_ablated_logprobs, positive=positive)

    if not plot:
        return token_differences, token_ids

    change = 'increase' if positive else 'decrease'
    rank = 'top' if top else 'bottom'
    title = f'Logprob {change} from neuron {neuron.item()} ({rank} {neuron_pos + 1} neuron)'
    xlabel = 'Boosted tokens' if positive else 'Deboosted tokens'
    xticks = [model.to_str_tokens([token_id])[0] for token_id in token_ids]
    haystack_utils.line(token_differences.cpu().numpy(), xlabel=xlabel, ylabel="full_logprob - ablated_logprop", title=title, xticks=xticks, show_legend=False)

In [None]:
# Widget wrapper: translate dropdown strings into the boolean flags of get_individual_neuron_logprob_effect
def plot_individual_neuron(neuron_pos, top_bottom, pos_neg):
    get_individual_neuron_logprob_effect(
        indices,
        neuron_pos=neuron_pos,
        top=(top_bottom == 'Top'),
        positive=(pos_neg == 'Positive'),
    )

# Controls
neuron_pos_slider = widgets.IntSlider(min=0, max=20, step=1, value=0, description='Neuron Pos:')
top_bottom_dropdown = widgets.Dropdown(options=['Top', 'Bottom'], value='Top', description='Neurons:')
pos_neg_dropdown = widgets.Dropdown(options=['Positive', 'Negative'], value='Positive', description='Pos/Neg:')

# Bind the controls to the plotting wrapper
widgets.interactive(plot_individual_neuron, neuron_pos=neuron_pos_slider, top_bottom=top_bottom_dropdown, pos_neg=pos_neg_dropdown)


interactive(children=(IntSlider(value=0, description='Neuron Pos:', max=20), Dropdown(description='Neurons:', …

In [None]:
# Sum individual weighted positive / negative effects
def summed_neuron_differences(num_neurons=10, top=True, positive=True, k=50):
    """Sum per-neuron token logprob changes over the top/bottom `num_neurons` MLP5 neurons.

    For each of the first `num_neurons` ranks (counted from the end of the global `indices` when
    `top` is true, from the start otherwise), get_individual_neuron_logprob_effect returns a
    filtered set of token ids and their logprob changes. These are accumulated into one
    vocabulary-sized vector.

    Returns:
        (summed differences, token ids) from haystack_utils.top_k_with_exclude (excluding `all_ignore`),
        for at most `k` tokens, further capped by how many summed changes have the requested sign.
    """
    # Allocate on the notebook's device rather than hard-coding CUDA, so this also runs on CPU
    summed_neuron_diffs = torch.zeros(model.cfg.d_vocab, device=device)
    for neuron_pos in range(num_neurons):
        # It would probably be more principled to do the filtering on the summed diffs instead of per neuron
        neuron_diff, token_indices = get_individual_neuron_logprob_effect(indices, neuron_pos=neuron_pos, top=top, positive=positive, plot=False)
        summed_neuron_diffs[token_indices] += neuron_diff.to(summed_neuron_diffs.device)
    if positive:
        n_candidates = int((summed_neuron_diffs > 0).sum())
    else:
        n_candidates = int((summed_neuron_diffs < 0).sum())
    # Plain int so the top-k helper always receives the same type
    top_logprob_difference, top_tokens = haystack_utils.top_k_with_exclude(summed_neuron_diffs, min(n_candidates, k), all_ignore, largest=positive)
    return top_logprob_difference, top_tokens

def plot_summed_neuron_differences(num_neurons=10, top=True, positive=True):
    """Line-plot the tokens with the largest summed logprob change over the top/bottom `num_neurons` neurons."""
    # num_neurons must be forwarded; otherwise the sum silently uses the default of 10 while the title reports num_neurons
    top_logprob_difference, top_logprob_tokens = summed_neuron_differences(num_neurons=num_neurons, top=top, positive=positive)
    xticks = [model.to_str_tokens([i])[0] for i in top_logprob_tokens]

    if positive:
        title = "Summed individual boosts by"
        xlabel = "Boosted tokens"
    else:
        title = "Summed individual deboosts by"
        xlabel = "Deboosted tokens"
    if top:
        title += f" top {num_neurons} neurons"
    else:
        title += f" bottom {num_neurons} neurons"
    haystack_utils.line(top_logprob_difference.cpu().numpy(), xlabel=xlabel, ylabel="full_logprob - ablated_logprop", title=title, xticks=xticks, show_legend=False)

In [None]:
# Widget wrapper: translate dropdown strings into the boolean flags of plot_summed_neuron_differences
def plot_summed_neurons_widget(num_neurons, top_bottom, pos_neg):
    plot_summed_neuron_differences(
        num_neurons=num_neurons,
        top=(top_bottom == 'Top'),
        positive=(pos_neg == 'Positive'),
    )

# Controls
num_neuron_slider = widgets.IntSlider(min=1, max=20, step=1, value=10, description='Num Neurons:')
top_bottom_dropdown = widgets.Dropdown(options=['Top', 'Bottom'], value='Top', description='Neurons:')
pos_neg_dropdown = widgets.Dropdown(options=['Positive', 'Negative'], value='Positive', description='Pos/Neg:')

# Bind the controls to the plotting wrapper
widgets.interactive(plot_summed_neurons_widget, num_neurons=num_neuron_slider, top_bottom=top_bottom_dropdown, pos_neg=pos_neg_dropdown)

interactive(children=(IntSlider(value=10, description='Num Neurons:', max=20, min=1), Dropdown(description='Ne…

## Brainstorm notes

I want to decompose an MLP5 neuron's effect into its boost to the correct logit and its deboost of other logits. I want to discover how these two effects change the log prob.

Candidate metric: loss reduction vs. token boost.

~~run direct effect, patch each neuron in top 10 individually, get overall loss reduction from neuron controlled by context neuron (equivalent to logprob increase for correct token)~~
decompose loss reduction into two parts:
run direct effect, patch the correct token logit boost from the neuron (removing the neuron's other effects), get overall loss reduction from neuron (equivalent to logprob increase for correct token)
destructive interference loss reduction = overall loss reduction - loss reduction from correct token logit boost (component of the logprob increase for correct token due to it deboosting incorrect token)

Patch the correct token logit boost from the neuron (removing the neuron's other effects).
1. Get baseline logprobs for a prompt
2. Get the difference in logits from activating the neuron under test. Can use `haystack_utils.get_direct_effect` with `return_type='logits'`.
3. Run the model with `return_type='logits'` and the neuron under test zero-ablated.
4. Add the correct token's logit difference from step 2 to a copy of the output logits. Convert to logprobs.
5. Add the incorrect tokens' logit differences from step 2 to a copy of the output logits. Convert to logprobs.
6. Compare A. logprobs with the correct answer token logit increase, B. logprobs with the incorrect answer token logit increases, and C. baseline logprobs.

~~If the context neuron gives each neuron a flat boost, then decomposing the resulting flat boost to one MLP5 neuron into a boost to one logit vs. boosts to all other logits will change the logprobs (the first will increase the answer probability and the second will reduce it).~~

~~The two resulting logprobs at the correct answer token won't add up to the original logprobs (?).~~

~~If the boost is flat, the correct percentage decomposition is 1/50000 and 49999/50000? In practice/all other factors being equal~~

New plan:

Difference between baseline log prob and neuron log prob?
Classify individual neurons by percentage constructive vs destructive by looking at their log probs and summing the incorrect token log probs

Correct log prob difference
Incorrect log prob difference (summed over every plausible token?)

Largest boost

In [None]:
# If the log prob for an incorrect token is significantly lower, then that's where the extra probability density on the correct answer is coming from
# Constructive interference increases correct token log prob and uniformly decreases other log probs
# Destructive interference decreases specific other log probs and uniformly increases other log probs
 
# Most neurons are a mixture of the above
# Decompose neurons into what % of their effect is each

In [None]:
# original_logits, ablated_logprobs, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks, return_type='logprobs')
# _, _, _, top_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_top_neurons_hook, return_type='logprobs')
# _, _, _, bottom_MLP5_ablated_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks+ablate_bottom_neurons_hook, return_type='logprobs')

# bottom_neuron_high_difference_logprobs = (all_MLP5_logprobs - bottom_MLP5_ablated_logprobs)
# bottom_neuron_high_difference_logprobs[bottom_neuron_high_difference_logprobs < 0] = 0
# bottom_neuron_high_difference_logprobs = bottom_neuron_high_difference_logprobs.mean(0)
# bottom_neuron_high_difference_logprobs[all_MLP5_logprobs.mean(0) < -7] = 0

# bottom_non_zero_count = (bottom_neuron_high_difference_logprobs > 0).sum()
# bottom_neuron_high_difference_logprobs, bottom_indices = haystack_utils.top_k_with_exclude(bottom_neuron_high_difference_logprobs, min(bottom_non_zero_count, 50), all_ignore)
# haystack_utils.line(bottom_neuron_high_difference_logprobs.cpu().numpy(), title='Largest positive difference in log probs for tokens when bottom neurons are not ablated', xticks=[model.to_str_tokens([i])[0] for i in bottom_indices])


# top_neuron_high_difference_logprobs = (all_MLP5_logprobs - top_MLP5_ablated_logprobs)
# top_neuron_high_difference_logprobs[top_neuron_high_difference_logprobs < 0] = 0
# top_neuron_high_difference_logprobs = top_neuron_high_difference_logprobs.mean(0)
# top_neuron_high_difference_logprobs[all_MLP5_logprobs.mean(0) < -7] = 0

# top_non_zero_count = (top_neuron_high_difference_logprobs > 0).sum()
# top_neuron_high_difference_logprobs, top_indices = haystack_utils.top_k_with_exclude(top_neuron_high_difference_logprobs, min(top_non_zero_count, 50), all_ignore)
# haystack_utils.line(top_neuron_high_difference_logprobs.cpu().numpy(), title='Largest positive difference in log probs for tokens when bottom neurons are not ablated', xticks=[model.to_str_tokens([i])[0] for i in top_indices])
