In [1]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from typing import List, Tuple
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
# Render Plotly figures with the "notebook_connected" renderer (plotly.js is
# loaded from a CDN instead of being embedded in the notebook file).
pio.renderers.default = "notebook_connected"
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# This notebook only runs inference, so gradient tracking is switched off
# globally. torch.set_grad_enabled is the same callable that is exposed as
# torch.autograd.set_grad_enabled, so a single call is sufficient.
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# Load Pythia-70m through TransformerLens on `device`, with LayerNorm folded
# into the weights and the writing / unembedding weights centred.
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

# Text examples for each language, read from the *_europarl.json files.
german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")


# MLP activations for layers 3, 4 and 5, keyed by layer index, computed on the
# first 200 examples of each language.
# NOTE(review): mean=False presumably keeps one row per token instead of
# averaging over them - confirm in haystack_utils.get_mlp_activations.
english_activations = {}
german_activations = {}
for layer in range(3, 6):
    english_activations[layer] = get_mlp_activations(english_data[:200], layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data[:200], layer, model, mean=False)

# Layer and neuron indices targeted by the ablation hooks defined below.
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
# Scalar replacement values: .mean() with no dim averages over every selected
# entry (all rows and all listed neurons) of the German / English activations.
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    """Forward hook that pins the targeted neurons to MEAN_ACTIVATION_INACTIVE.

    Overwrites `value[:, :, NEURONS_TO_ABLATE]` in place and returns the same
    tensor. `hook` is accepted but not used.
    """
    # assumes value is indexed as (batch, position, neuron) - TODO confirm
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
# (hook point name, hook function) pair targeting the MLP post-activation hook
# of layer LAYER_TO_ABLATE.
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

def activate_neurons_hook(value, hook):
    """Forward hook that pins the targeted neurons to MEAN_ACTIVATION_ACTIVE.

    Overwrites `value[:, :, NEURONS_TO_ABLATE]` in place and returns the same
    tensor. `hook` is accepted but not used.
    """
    # assumes value is indexed as (batch, position, neuron) - TODO confirm
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value
# (hook point name, hook function) pair targeting the MLP post-activation hook
# of layer LAYER_TO_ABLATE.
activate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neurons_hook)]

# NOTE(review): presumably returns token ids to exclude / keep (all_ignore is
# later used to index a per-vocabulary count tensor) - confirm in haystack_utils.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [11]:
# Get top common german tokens excluding punctuation
# Counts live on `device` (not a hard-coded .cuda()) so the cell also runs on
# CPU-only machines.
token_counts = torch.zeros(model.cfg.d_vocab, device=device)
for example in tqdm(german_data):
    tokens = model.to_tokens(example)
    # Tally the whole example in one vectorised call instead of incrementing
    # one token at a time (each .item() forces a device sync).
    example_counts = torch.bincount(tokens[0], minlength=model.cfg.d_vocab)
    token_counts += example_counts.to(token_counts.device)

# Zero out punctuation (with and without a leading space), whitespace and the
# end-of-text token so they cannot appear among the top tokens. Column 1 is the
# first token after the one to_tokens prepends to every string.
punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
leading_space_punctuation = [" " + char for char in punctuation]
punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
token_counts[punctuation_tokens] = 0
token_counts[all_ignore] = 0

# The 100 most frequent remaining tokens, most frequent first.
top_counts, top_tokens = torch.topk(token_counts, 100)
print(model.to_str_tokens(top_tokens[:100]))

  0%|          | 0/2000 [00:00<?, ?it/s]

[' der', 'en', ' die', ' und', 'ung', 'ä', ' in', ' den', ' des', 'ch', 'st', ' zu', 're', ' für', 'äsident', ' Pr', 'n', 'z', 'ischen', ' von', 'ü', 't', 'icht', 'in', 'ge', 'gen', 'te', ' ist', ' auf', 'ig', ' über', ' dass', ' im', 'f', ' er', 'es', ' das', 'men', 'g', 'ß', ' Europ', ' w', 'w', 'le', 'ten', ' eine', ' wir', ' ein', ' an', 'hen', 'ren', 'e', ' ich', 'ungen', ' W', ' Ver', ' B', ' dem', ' mit', ' dies', ' nicht', ' Z', 'h', ' z', 's', 'it', 'hr', ' es', ' zur', ' An', ' Herr', 'ich', 'heit', 'b', 'lich', 'l', ' ver', ' S', 'i', ' G', 'Der', ' V', 'der', 'u', 'ie', ' Ab', 'ungs', 'chte', 'chaft', 'igen', ' werden', 'uss', 'ord', 'em', ' Ber', 'ür', ' haben', 'et', ' um', ' Ich']


In [63]:
def get_random_selection(tensor, n=12):
    """Return up to `n` entries of `tensor` picked at random without replacement.

    A random permutation of the indices along the first dimension is drawn and
    its first `n` positions are used, so at most `len(tensor)` entries come back.
    """
    shuffled_indices = torch.randperm(len(tensor))
    chosen = shuffled_indices[:n]
    return tensor[chosen]

def generate_random_prompts(end_string, n=50, length=12):
    """Build `n` prompts of `length` random common tokens followed by `end_string`.

    Args:
        end_string: Text appended (as tokens) to the end of every prompt.
        n: Number of prompts to generate.
        length: Number of random tokens placed before the end tokens.

    Returns:
        A 2-D token tensor with one prompt per row, on the same device as the
        tokens produced by `model.to_tokens`.
    """
    # Drop the first token returned by to_tokens so only the tokens of
    # end_string itself are appended.
    end_tokens = model.to_tokens(end_string).flatten()[1:]
    # Sample from the top 50 tokens, or from the top `length` when more than 50
    # distinct tokens are needed per prompt.
    candidate_pool = top_tokens[:max(50, length)]
    prompts = []
    for _ in range(n):
        # Follow the device of end_tokens instead of calling .cuda(), so the
        # function also works when no GPU is available.
        prefix = get_random_selection(candidate_pool, n=length).to(end_tokens.device)
        prompts.append(torch.cat([prefix, end_tokens]))
    return torch.stack(prompts)

def replace_column(prompts: Int[Tensor, "n_prompts n_tokens"], token_index: int):
    """Return a copy of `prompts` with one token position replaced by random common tokens.

    Args:
        prompts: 2-D token tensor, one prompt per row. It is not modified.
        token_index: Column to overwrite (negative indices count from the end).

    Returns:
        A new tensor of the same shape in which column `token_index` holds one
        randomly drawn token from `top_tokens` per prompt.
    """
    n_prompts = prompts.shape[0]
    candidate_pool = top_tokens[:max(50, n_prompts)]
    if len(candidate_pool) >= n_prompts:
        # Enough candidates: draw without replacement, one distinct token per prompt.
        random_tokens = get_random_selection(candidate_pool, n=n_prompts)
    else:
        # More prompts than candidate tokens: drawing without replacement would
        # return too few tokens and break the column assignment below, so
        # sample with replacement instead.
        random_indices = torch.randint(len(candidate_pool), (n_prompts,))
        random_tokens = candidate_pool[random_indices]
    new_prompts = prompts.clone()
    # Follow the device of `prompts` rather than assuming CUDA is available.
    new_prompts[:, token_index] = random_tokens.to(prompts.device)
    return new_prompts

def loss_analysis(prompts: Tensor, title=""):
    """Plot a bar chart comparing three of the losses returned for `prompts`.

    Calls `haystack_utils.get_direct_effect` with pos=-1 and the module-level
    deactivate / activate hook lists, then plots the first, second and fourth
    returned losses (the third is discarded) under the labels "Original",
    "Ablated" and "MLP5 path patched".
    """
    direct_effect = haystack_utils.get_direct_effect(
        prompts,
        model,
        pos=-1,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks,
    )
    baseline_loss, deactivated_loss, _unused, activated_only_loss = direct_effect
    bar_values = [
        loss.tolist()
        for loss in (baseline_loss, deactivated_loss, activated_only_loss)
    ]
    bar_labels = ["Original", "Ablated", "MLP5 path patched"]
    haystack_utils.plot_barplot(bar_values, bar_labels, ylabel="Loss", title=title)

In [64]:
def loss_analysis_random_prompts(end_string, n=50, length=12, replace_columns: list[int] | None = None):
    """Generate random prompts ending in `end_string` and plot their loss comparison.

    Args:
        end_string: Text every prompt ends with.
        n: Number of prompts.
        length: Number of random tokens preceding the end tokens.
        replace_columns: Optional token positions that are overwritten with
            random tokens (one `replace_column` call per position) before the
            analysis. The plot title names the tokens the first prompt held at
            those positions before replacement.
    """
    prompts = generate_random_prompts(end_string, n=n, length=length)
    title_parts = [f"Average last token loss on {length} random tokens ending in '{end_string}'"]
    if replace_columns is not None:
        # Read the label from the first prompt before any column is overwritten.
        tokens_before_replacement = model.to_str_tokens(prompts[0, replace_columns])
        title_parts.append(f" replacing {tokens_before_replacement}")
        for column_index in replace_columns:
            prompts = replace_column(prompts, column_index)

    loss_analysis(prompts, title="".join(title_parts))

In [65]:
# Baseline: 100 prompts of 20 random common tokens, each ending in " Vorschlägen".
loss_analysis_random_prompts(" Vorschlägen", n=100, length=20)

In [66]:
# Same setup, but position -2 (second-to-last token) of every prompt is
# overwritten with a random common token before the losses are computed.
loss_analysis_random_prompts(" Vorschlägen", n=100, length=20, replace_columns=[-2])

torch.Size([100, 24]) torch.Size([100])
