## Setup

In [82]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output

# Render plotly figures through the "notebook_connected" renderer.
pio.renderers.default = "notebook_connected"
# Prefer the GPU when one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Turn gradient tracking off for the whole notebook (no backward passes are run here).
torch.autograd.set_grad_enabled(False)
# NOTE(review): this looks redundant with the line above (same global switch) — confirm and drop one.
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# (e.g. haystack_utils) are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
# Load Pythia-70m as a HookedTransformer on the selected device, with the
# unembed/writing-weight centering and LayerNorm folding options switched on.
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

# Text examples for each language, read from JSON files relative to the notebook.
german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")


# MLP activations for layers 3, 4 and 5, keyed by layer index, computed on the
# first 200 examples of each language (mean=False keeps the un-averaged values).
# NOTE(review): presumably each entry is (n_tokens, d_mlp), since it is indexed
# as [:, neurons] below — confirm against get_mlp_activations.
english_activations = {}
german_activations = {}
for layer in range(3, 6):
    english_activations[layer] = get_mlp_activations(english_data[:200], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:200], layer, model, mean=False)

# The neuron(s) under study and the two constants the hooks write into them:
# "active" is the mean activation on the German sample, "inactive" the mean on
# the English sample (averaged over every selected neuron and every row).
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    """Forward hook that pins the selected neurons to their mean English activation.

    Overwrites `value[:, :, NEURONS_TO_ABLATE]` in place with
    MEAN_ACTIVATION_INACTIVE and returns the same tensor. `hook` is accepted
    because the hook interface passes it, but it is not used.
    """
    inactive_level = MEAN_ACTIVATION_INACTIVE
    value[:, :, NEURONS_TO_ABLATE] = inactive_level
    return value


# (hook point name, hook function) pairs targeting the ablated layer's MLP post-activation.
deactivate_neurons_fwd_hooks = [
    (f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook),
]

def activate_neurons_hook(value, hook):
    """Forward hook that pins the selected neurons to their mean German activation.

    Overwrites `value[:, :, NEURONS_TO_ABLATE]` in place with
    MEAN_ACTIVATION_ACTIVE and returns the same tensor. `hook` is accepted
    because the hook interface passes it, but it is not used.
    """
    active_level = MEAN_ACTIVATION_ACTIVE
    value[:, :, NEURONS_TO_ABLATE] = active_level
    return value


# (hook point name, hook function) pairs targeting the ablated layer's MLP post-activation.
activate_neurons_fwd_hooks = [
    (f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook),
]

# Ask haystack_utils for the model's "weird" tokens, with norm plotting disabled.
# NOTE(review): `all_ignore` is used below to zero entries of the token-count
# vector, so it is presumably a set of token ids to exclude; `not_ignore` is not
# used in the cells shown here — confirm both against get_weird_tokens.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

## Find the most common German tokens

In [11]:
# Rank vocabulary tokens by how often they occur in the German corpus, after
# excluding punctuation, whitespace and other ignorable tokens.
# The counter lives on `device` (not a hard-coded .cuda()) so the cell also runs
# on CPU-only machines, matching the fallback chosen in the setup cell.
token_counts = torch.zeros(model.cfg.d_vocab, device=device)
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    # One vectorised histogram per example replaces the per-token Python loop.
    example_counts = torch.bincount(tokens[0], minlength=model.cfg.d_vocab)
    token_counts += example_counts.to(token_counts.device)

punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
# Tokenise every excluded string and keep the token at position 1 of each row.
# NOTE(review): position 0 is presumably a prepended BOS token — confirm.
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
# Zero the excluded tokens so they can never appear among the top counts.
token_counts[punctuation_tokens] = 0
token_counts[all_ignore] = 0

# Keep the 100 most frequent remaining tokens; `top_tokens` is reused by later cells.
top_counts, top_tokens = torch.topk(token_counts, 100)
print(model.to_str_tokens(top_tokens[:100]))

  0%|          | 0/2000 [00:00<?, ?it/s]

[' der', 'en', ' die', ' und', 'ung', 'ä', ' in', ' den', ' des', 'ch', 'st', ' zu', 're', ' für', 'äsident', ' Pr', 'n', 'z', 'ischen', ' von', 'ü', 't', 'icht', 'in', 'ge', 'gen', 'te', ' ist', ' auf', 'ig', ' über', ' dass', ' im', 'f', ' er', 'es', ' das', 'men', 'g', 'ß', ' Europ', ' w', 'w', 'le', 'ten', ' eine', ' wir', ' ein', ' an', 'hen', 'ren', 'e', ' ich', 'ungen', ' W', ' Ver', ' B', ' dem', ' mit', ' dies', ' nicht', ' Z', 'h', ' z', 's', 'it', 'hr', ' es', ' zur', ' An', ' Herr', 'ich', 'heit', 'b', 'lich', 'l', ' ver', ' S', 'i', ' G', 'Der', ' V', 'der', 'u', 'ie', ' Ab', 'ungs', 'chte', 'chaft', 'igen', ' werden', 'uss', 'ord', 'em', ' Ber', 'ür', ' haben', 'et', ' um', ' Ich']


## Analysis of ngrams preceded by random prompts

In [72]:
def get_random_selection(tensor, n=12):
    """Return up to `n` entries of `tensor`, drawn at random without replacement.

    Stand-in for np.random.choice: shuffle the first-dimension indices and keep
    the first `n`. When `tensor` has fewer than `n` entries, all of them are
    returned in shuffled order.
    """
    shuffled_indices = torch.randperm(len(tensor))
    chosen = shuffled_indices[:n]
    return tensor[chosen]

def generate_random_prompts(end_string, n=50, length=12):
    """Build a batch of random token prompts that all end with the same n-gram.

    Each prompt is `length` tokens sampled without replacement from the most
    common German tokens (the first max(50, length) entries of `top_tokens`),
    followed by the tokens of `end_string`.

    Args:
        end_string: Text whose tokens are appended to every prompt.
        n: Number of prompts to generate.
        length: Number of random tokens placed before the n-gram.

    Returns:
        Integer tensor with one row per prompt; each row is the random prefix
        followed by the `end_string` tokens.

    Note:
        If `length` exceeds the number of entries in `top_tokens`, each prefix
        contains only as many tokens as `top_tokens` holds.
    """
    # Drop the first token of the tokenised string so only the n-gram's own tokens remain.
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    # The candidate pool does not change between prompts, so slice it once.
    candidate_pool = top_tokens[:max(50, length)]
    prompts = []
    for _ in range(n):
        # Follow the device of the tokenised n-gram rather than assuming CUDA,
        # so this also works when the model runs on CPU.
        prefix = get_random_selection(candidate_pool, n=length).to(end_tokens.device)
        prompts.append(torch.cat([prefix, end_tokens]))
    return torch.stack(prompts)

def replace_column(prompts: Int[Tensor, "n_prompts n_tokens"], token_index: int):
    """Return a copy of `prompts` with one token position randomised.

    Position `token_index` of every prompt is overwritten with a random common
    German token drawn from the first max(50, n_prompts) entries of
    `top_tokens`. The input tensor is left untouched.

    Args:
        prompts: Batch of token ids, one prompt per row.
        token_index: Column to overwrite (negative indices count from the end).

    Returns:
        A new tensor with the same shape as `prompts`.
    """
    n_prompts = prompts.shape[0]
    candidate_pool = top_tokens[:max(50, n_prompts)]
    if len(candidate_pool) >= n_prompts:
        # Enough candidates: one distinct token per prompt.
        random_tokens = get_random_selection(candidate_pool, n=n_prompts)
    else:
        # More prompts than candidate tokens: sampling without replacement would
        # return too few tokens and break the column assignment below, so sample
        # with replacement instead.
        picks = torch.randint(len(candidate_pool), (n_prompts,))
        random_tokens = candidate_pool[picks]
    new_prompts = prompts.clone()
    # Match the prompts' device instead of assuming CUDA.
    new_prompts[:, token_index] = random_tokens.to(prompts.device)
    return new_prompts

def loss_analysis(prompts: Tensor, title=""):
    """Plot a bar chart comparing last-position loss under three conditions.

    Calls haystack_utils.get_direct_effect on `prompts` at position -1 with the
    neuron deactivation hooks as the ablation hooks and the activation hooks as
    the activation hooks, then plots three of the four returned losses.

    Args:
        prompts: Batch of token ids to evaluate.
        title: Title for the bar plot.
    """
    direct_effect = haystack_utils.get_direct_effect(
        prompts,
        model,
        pos=-1,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks,
    )
    # Four values come back; the third is not plotted.
    original_loss, ablated_loss, _, only_activated_loss = direct_effect

    # NOTE(review): the third bar is labelled "MLP5 path patched" but shows
    # `only_activated_loss` — confirm the label against get_direct_effect.
    bar_labels = ["Original", "Ablated", "MLP5 path patched"]
    bar_values = [loss.tolist() for loss in (original_loss, ablated_loss, only_activated_loss)]
    haystack_utils.plot_barplot(bar_values, bar_labels, ylabel="Loss", title=title)

In [78]:
def loss_analysis_random_prompts(end_string, n=50, length=12, replace_columns: list[int] | None = None):
    """Plot the loss comparison for random prompts that end in `end_string`.

    Args:
        end_string: N-gram appended to every random prompt.
        n: Number of prompts to generate.
        length: Number of random tokens in front of the n-gram.
        replace_columns: Optional token positions to overwrite with random
            common German tokens before measuring the loss. When given, the
            tokens found at those positions in the first prompt are named in
            the plot title.
    """
    prompts = generate_random_prompts(end_string, n=n, length=length)
    title = f"Average last token loss on {length} random tokens ending in '{end_string}'"

    # Nothing to replace: plot the prompts as generated.
    if replace_columns is None:
        loss_analysis(prompts, title=title)
        return

    # Record what is about to be overwritten (read from the first prompt) before replacing.
    replaced_tokens = model.to_str_tokens(prompts[0, replace_columns])
    title += f" replacing {replaced_tokens}"
    for column in replace_columns:
        prompts = replace_column(prompts, column)

    loss_analysis(prompts, title=title)

In [86]:
# Example usage
# loss_analysis_random_prompts(" Vorschlägen", n=100, length=20)
# loss_analysis_random_prompts(" Vorschlägen", n=100, length=20, replace_columns=[-2])

## Widget

In [85]:
# Single-choice selector for the n-gram that ends every prompt.
dropdown = widgets.Dropdown(
    options=[" Vorschlägen", " häufig", " schließt", " beweglich"],
    value=" Vorschlägen",
    description='Ngram: ',
)

# Multi-choice selector for the token positions to randomise; 'None' (the
# default) means no position is replaced.
replace_columns_dropdown = widgets.SelectMultiple(
    options=['-2', '-3', '-4', 'None'],
    value=['None'],
    description='Replace Columns:',
)

# Output area that the plot is rendered into.
output = widgets.Output()

# Callback that redraws the plot from the current widget selections.
def update_plot(*args):
    """Re-render the loss plot inside `output` for the current widget values.

    Accepts and ignores any positional arguments so it can be used directly as
    a widget observer callback.
    """
    with output:
        # Drop the previous plot; wait=True delays clearing until new output arrives.
        clear_output(wait=True)

        selection = replace_columns_dropdown.value
        # 'None' anywhere in the selection disables replacement; otherwise the
        # selected strings are converted to integer positions.
        replace_columns = None if 'None' in selection else [int(choice) for choice in selection]

        loss_analysis_random_prompts(dropdown.value, n=100, length=20, replace_columns=replace_columns)

# Redraw whenever either widget's selected value changes.
dropdown.observe(update_plot, names='value')
replace_columns_dropdown.observe(update_plot, names='value')

# Show both controls followed by the plot area.
display(dropdown, replace_columns_dropdown, output)

# Draw the initial plot for the default selections.
update_plot()


Dropdown(description='Ngram: ', options=(' Vorschlägen', ' häufig', ' schließt', ' beweglich'), value=' Vorsch…

SelectMultiple(description='Replace Columns:', index=(3,), options=('-2', '-3', '-4', 'None'), value=('None',)…

Output()

## Neuron level analysis