In [1]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from typing import List, Tuple
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer depending on the front end
# (VSCode / classic notebooks vs Colab); "notebook_connected" is the one chosen here.
import plotly.io as pio
pio.renderers.default = "notebook_connected"
# Run on the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Gradient tracking is switched off globally for the whole notebook.
# NOTE(review): torch.set_grad_enabled is the same callable re-exported from
# torch.autograd, so the second call below is redundant with the first.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

# NOTE: `line` imported here rebinds the name `line` that was imported from
# plotly.express earlier in this cell; from this point `line` is haystack_utils.line.
from haystack_utils import load_txt_data, get_mlp_activations, line
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [8]:
# Pretrained Pythia-70m wrapped in a HookedTransformer, placed on `device`.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")

# MLP activations (mean=False) for the first 200 prompts of each language,
# stored per layer for layers 3, 4 and 5. English is computed before German
# within each layer.
english_activations = {}
german_activations = {}
for layer in range(3, 6):
    for activation_store, corpus in (
        (english_activations, english_data),
        (german_activations, german_data),
    ):
        activation_store[layer] = get_mlp_activations(corpus[:200], layer, model, mean=False)

# Not referenced by any of the cells visible in this notebook section.
LOG_PROB_THRESHOLD = -7
# The neuron(s) whose activation the hooks below overwrite, and the layer they live in.
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]

# Replacement values for the hooks: the German mean is the "active" value,
# the English mean is the "inactive" value.
# NOTE(review): .mean() has no dim argument, so with more than one entry in
# NEURONS_TO_ABLATE the result is a single value pooled over all listed neurons.
german_layer_acts = german_activations[LAYER_TO_ABLATE]
english_layer_acts = english_activations[LAYER_TO_ABLATE]
MEAN_ACTIVATION_ACTIVE = german_layer_acts[:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_layer_acts[:, NEURONS_TO_ABLATE].mean()

def _overwrite_ablated_neurons(value, fill):
    """Write `fill` into the NEURONS_TO_ABLATE slots of `value` in place and return `value`.

    The write indexes the third axis of `value`.
    # assumes value is (batch, pos, d_mlp) - TODO confirm against the hook point
    """
    value[:, :, NEURONS_TO_ABLATE] = fill
    return value


def deactivate_neurons_hook(value, hook):
    """Forward hook that sets the selected neurons to MEAN_ACTIVATION_INACTIVE.

    `hook` is accepted to satisfy the hook call signature and is not used.
    The constant is looked up at call time, so rebinding it later takes effect.
    """
    return _overwrite_ablated_neurons(value, MEAN_ACTIVATION_INACTIVE)


def activate_neurons_hook(value, hook):
    """Forward hook that sets the selected neurons to MEAN_ACTIVATION_ACTIVE.

    `hook` is accepted to satisfy the hook call signature and is not used.
    The constant is looked up at call time, so rebinding it later takes effect.
    """
    return _overwrite_ablated_neurons(value, MEAN_ACTIVATION_ACTIVE)


# Both hook lists attach to the same hook point of the ablated layer.
_MLP_POST_HOOK_NAME = f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post'
deactivate_neurons_fwd_hooks = [(_MLP_POST_HOOK_NAME, deactivate_neurons_hook)]
activate_neurons_fwd_hooks = [(_MLP_POST_HOOK_NAME, activate_neurons_hook)]

# Unpacks the two values returned by haystack_utils.get_weird_tokens for this model.
# plot_norms=False presumably suppresses a diagnostic plot - verify in haystack_utils.
# Neither name is used by the cells visible in this section of the notebook.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [15]:
# One 4-tuple of losses per German prompt, in the same order as german_data:
# (original, ablated, context_and_activated, only_activated).
german_losses = []
for prompt in tqdm(german_data):
    (
        original_loss,
        ablated_loss,
        context_and_activated_loss,
        only_activated_loss,
    ) = haystack_utils.get_direct_effect(
        prompt,
        model,
        pos=None,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks,
    )
    german_losses.append(
        (original_loss, ablated_loss, context_and_activated_loss, only_activated_loss)
    )

  0%|          | 0/2000 [00:00<?, ?it/s]

In [23]:
def get_mlp5_decrease_measure(losses: list[tuple[Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"]]]):
    """Return, for each prompt, the largest per-position value of ablated - only_activated.

    Args:
        losses: one 4-tuple per prompt, ordered
            (original, ablated, context_and_activated, only_activated).
            Only the second and fourth entries are read.

    Returns:
        A list of Python floats, one per input tuple, in input order.
    """
    return [
        (ablated - activated_only).max().item()
        for _original, ablated, _context_and_activated, activated_only in losses
    ]

measure = get_mlp5_decrease_measure(german_losses)
index = list(range(len(measure)))

# (prompt index, measure) pairs ordered from largest to smallest measure;
# ties keep their original prompt order because the sort is stable.
sorted_measure = sorted(zip(index, measure), key=lambda pair: pair[1], reverse=True)

In [None]:
def print_prompt(index: int, losses, max_value=5):
    """Render German prompt `index` with each token scored by ablated - only_activated loss.

    Args:
        index: position of the prompt in the module-level `german_data`; the same
            index is used to look up its loss tuple in `losses`, so the two must
            be aligned.
        losses: sequence of 4-tuples ordered
            (original, ablated, context_and_activated, only_activated).
        max_value: forwarded to haystack_utils.clean_print_strings_as_html as
            `max_value`. Defaults to 5, the value that was previously hard-coded,
            mirroring the `max_value` parameter of `show_token_loss`.
    """
    prompt = german_data[index]
    str_tokens = model.to_str_tokens(model.to_tokens(prompt))
    _original, ablated_loss, _context_and_activated, only_activated_loss = losses[index]

    per_token_delta = (ablated_loss - only_activated_loss).flatten().cpu().tolist()
    # The first string token is dropped before display.
    # presumably it is the BOS token, which has no loss entry - TODO confirm
    haystack_utils.clean_print_strings_as_html(str_tokens[1:], per_token_delta, max_value=max_value)

In [9]:
def show_token_loss(prompt: str, model: HookedTransformer, max_value = 5):
    """Compute direct-effect losses for `prompt` and render its tokens scored by loss difference.

    The score for each token is ablated loss minus only-activated loss, so large
    values mark positions where the ablated loss exceeds the only-activated loss
    the most. The hook lists are taken from the module-level
    `deactivate_neurons_fwd_hooks` and `activate_neurons_fwd_hooks`.

    Args:
        prompt: text to evaluate.
        model: model passed to haystack_utils.get_direct_effect and used for tokenisation.
        max_value: forwarded to haystack_utils.print_strings_as_html as `max_value`.
    """
    losses = haystack_utils.get_direct_effect(
        prompt,
        model,
        pos=None,
        context_ablation_hooks=deactivate_neurons_fwd_hooks,
        context_activation_hooks=activate_neurons_fwd_hooks,
    )
    _original, ablated_loss, _context_and_activated, only_activated_loss = losses

    tokens = model.to_tokens(prompt)
    str_tokens = model.to_str_tokens(tokens)

    per_token_delta = (ablated_loss - only_activated_loss).flatten().cpu().tolist()
    # The first string token is dropped before display.
    haystack_utils.print_strings_as_html(str_tokens[1:], per_token_delta, max_value=max_value)

In [10]:
# Render the per-token loss difference for the first German prompt.
show_token_loss(german_data[0], model)

In [12]:
# Hand-picked German prompt used by the single-prompt probes in the cells below.
# NOTE: this rebinds the global `prompt` that the german_data loop also assigns.
prompt = "Ich möchte nochmals meine Ansicht"

In [14]:
# Direct-effect losses for the hand-picked prompt at pos=-1 only.
(
    original_loss,
    ablated_loss,
    context_and_activated_loss,
    only_activated_loss,
) = haystack_utils.get_direct_effect(
    prompt,
    model,
    pos=-1,
    context_ablation_hooks=deactivate_neurons_fwd_hooks,
    context_activation_hooks=activate_neurons_fwd_hooks,
)
# Bare tuple as the last expression so the notebook displays the four values.
original_loss, ablated_loss, context_and_activated_loss, only_activated_loss

(4.7153401374816895, 12.92672348022461, 9.561723709106445, 8.331337928771973)

In [13]:
# Six loss values for the same prompt at answer_pos=-1, using only the
# deactivation hooks; return_mlp4_less_mlp5=True is what yields the sixth value.
(
    original_losses,
    total_effect_losses,
    direct_mlp3_mlp5_losses,
    direct_mlp3_losses,
    frozen_losses,
    frozen_loss_without_MLP4,
) = haystack_utils.pos_wise_mlp_effect_on_single_prompt(
    prompt,
    model,
    k=2048,
    ablation_hooks=deactivate_neurons_fwd_hooks,
    log=False,
    answer_pos=-1,
    return_mlp4_less_mlp5=True,
)
# Bare tuple as the last expression so the notebook displays all six values.
original_losses, total_effect_losses, direct_mlp3_mlp5_losses, direct_mlp3_losses, frozen_losses, frozen_loss_without_MLP4

(4.7153401374816895,
 12.92672348022461,
 7.974599838256836,
 3.314662218093872,
 6.4687347412109375,
 8.842937469482422)