### Setup

In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter
import pickle
import os

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# Select the plotly renderer(s) used for figures in this notebook.
pio.renderers.default = "notebook_connected+notebook"
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Inference-only notebook: switch autograd off globally.
# NOTE(review): these two calls look like the same global toggle reached via
# two paths, so one of them is probably redundant — confirm before removing.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# Clear caches before loading the model (exact behaviour lives in haystack_utils).
haystack_utils.clean_cache()
# Load Pythia-160m into a HookedTransformer on the selected device, with
# unembed/writing-weight centering and LayerNorm folding enabled.
model = HookedTransformer.from_pretrained("EleutherAI/pythia-160m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

# Keep only the first 200 examples of each Europarl file.
german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]

# `all_ignore` is passed on as the ignore list when picking common tokens;
# `not_ignore` is not used in the visible part of this notebook.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)
# presumably the k=100 most common tokens in the German data, excluding
# `all_ignore` — TODO confirm against haystack_utils.get_common_tokens.
common_tokens = haystack_utils.get_common_tokens(german_data, model, all_ignore, k=100)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

In [6]:
# Neurons of interest, one row per neuron: [layer, neuron index, F1 score].
# The loop further down appends two more columns to every row:
# mean activation on the German data, then mean activation on the English data.
german_neurons_with_f1 = [
    [5, 2649, 1.0],
    [8, 2994, 1.0],
    [11, 2911, 0.99],
    [10, 1129, 0.97],
    [6, 1838, 0.65],
    [7, 1594, 0.65],
    [11, 1819, 0.61],
    [11, 2014, 0.56],
    [10, 753, 0.54],
    [11, 205, 0.48],
]

# Collect MLP activations once per distinct layer, for both languages.
english_activations = {}
german_activations = {}
for layer in {layer for layer, *_ in german_neurons_with_f1}:
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)

# Extend each row in place with the neuron's mean activation (German first, then English).
for item in german_neurons_with_f1:
    layer, neuron, f1 = item
    item.extend([
        german_activations[layer][:, neuron].mean(0),
        english_activations[layer][:, neuron].mean(0),
    ])

def get_neuron_hook(layer, neuron, act_value):
    """Build a forward hook that pins one MLP neuron to a fixed activation.

    Args:
        layer: Index of the transformer block whose MLP is hooked.
        neuron: Index of the neuron along the last axis of the hooked tensor.
        act_value: Value written into that neuron at every position.

    Returns:
        A `(hook point name, hook function)` tuple targeting
        `blocks.{layer}.mlp.hook_post`.
    """
    hook_name = f'blocks.{layer}.mlp.hook_post'

    def neuron_hook(value, hook):
        # Overwrites the neuron in place across the first two axes and hands
        # the very same tensor back to the caller.
        value[:, :, neuron] = act_value
        return value

    return hook_name, neuron_hook

# The two context neurons (layer, neuron index) that the hooks below clamp.
_context_neurons = [(5, 2649), (8, 2994)]
# "Deactivate" pins each context neuron to its mean activation on the English data ...
deactivate_context_hooks = [
    get_neuron_hook(ctx_layer, ctx_neuron, english_activations[ctx_layer][:, ctx_neuron].mean())
    for ctx_layer, ctx_neuron in _context_neurons
]
# ... while "activate" pins it to its mean activation on the German data.
activate_context_hooks = [
    get_neuron_hook(ctx_layer, ctx_neuron, german_activations[ctx_layer][:, ctx_neuron].mean())
    for ctx_layer, ctx_neuron in _context_neurons
]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [4]:
def mlp_effects_german(prompt, index):
    """Collect losses for the context-neuron ablation and per-MLP patching runs.

    Customised to the L5 and L8 context neurons.

    Args:
        prompt: Prompt(s) forwarded unchanged to `haystack_utils.get_direct_effect`.
        index: Position forwarded as `pos`.

    Returns:
        A list of six entries: the first three values returned by
        `get_direct_effect` when every downstream component is deactivated
        (unpacked here as original, ablated and direct effect), followed by the
        fourth return value of one further call each for MLP9, MLP10 and MLP11,
        in which only that MLP is taken out of the deactivated set and activated.
    """
    downstream_components = [
        f"blocks.{layer}.hook_{component}_out"
        for layer in [6, 7, 9, 10, 11]
        for component in ['mlp', 'attn']
    ]
    # Arguments shared by every get_direct_effect call in this function.
    shared_kwargs = dict(
        pos=index,
        context_ablation_hooks=deactivate_context_hooks,
        context_activation_hooks=activate_context_hooks,
    )

    original, ablated, direct_effect, _ = haystack_utils.get_direct_effect(
        prompt, model,
        deactivated_components=tuple(downstream_components),
        activated_components=("blocks.5.hook_mlp_out", "blocks.8.hook_mlp_out",),
        **shared_kwargs)

    data = [original, ablated, direct_effect]
    for layer in [9, 10, 11]:
        mlp_component = f"blocks.{layer}.hook_mlp_out"
        # Everything downstream except the MLP currently under study.
        still_deactivated = tuple(name for name in downstream_components if name != mlp_component)
        _, _, _, activated_component_loss = haystack_utils.get_direct_effect(
            prompt, model,
            deactivated_components=still_deactivated,
            activated_components=(mlp_component,),
            **shared_kwargs)
        data.append(activated_component_loss)
    return data

def attn_effects_german(prompt, index):
    """Collect the per-attention-layer patching losses for layers 9, 10 and 11.

    Customised to the L5 and L8 context neurons. Attention counterpart of
    `mlp_effects_german`.

    Args:
        prompt: Prompt(s) forwarded unchanged to `haystack_utils.get_direct_effect`.
        index: Position forwarded as `pos`.

    Returns:
        A list with one entry per layer in [9, 10, 11]: the fourth value
        returned by `get_direct_effect` when that layer's attention output is
        the only activated component and every other downstream component is
        deactivated.
    """
    downstream_components = [
        f"blocks.{layer}.hook_{component}_out"
        for layer in [6, 7, 9, 10, 11]
        for component in ['mlp', 'attn']
    ]

    data = []
    for layer in [9, 10, 11]:
        attn_component = f"blocks.{layer}.hook_attn_out"
        # The activated component must be removed from the deactivated set.
        # Filtering out the layer's MLP here instead (as the code previously
        # did) left the attention output in both lists and silently exempted
        # that layer's MLP from deactivation.
        still_deactivated = tuple(name for name in downstream_components if name != attn_component)
        _, _, _, activated_component_loss = haystack_utils.get_direct_effect(
            prompt, model, pos=index,
            context_ablation_hooks=deactivate_context_hooks,
            context_activation_hooks=activate_context_hooks,
            deactivated_components=still_deactivated,
            activated_components=(attn_component,))
        data.append(activated_component_loss)
    return data

def component_analysis(end_strings: list[str] | str):
    """Plot the mean MLP-effect losses for each end string.

    For every end string this prints its string tokens, generates 400 random
    prompts (`length=20`) from `common_tokens` via
    `haystack_utils.generate_random_prompts`, evaluates `mlp_effects_german`
    at the last position, and draws one bar plot of the mean of each result.

    Args:
        end_strings: A single end string or a list of them.
    """
    targets = [end_strings] if isinstance(end_strings, str) else end_strings
    bar_names = ['original', 'ablated', 'direct effect'] + [f'MLP{layer}' for layer in [9, 10, 11]]

    for end_string in targets:
        print(model.to_str_tokens(end_string))
        random_prompts = haystack_utils.generate_random_prompts(end_string, model, common_tokens, 400, length=20)
        losses = mlp_effects_german(random_prompts, -1)

        # One single-element list per bar, each holding that run's mean loss.
        mean_losses = [[loss.cpu().flatten().mean().item()] for loss in losses]
        haystack_utils.plot_barplot(
            mean_losses,
            names=bar_names,
            title=f'Loss increases from ablating various MLP components for end string \"{end_string}\"')
        
def interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss):
    """Per-token score combining the context-neuron effect with the MLP11 effect.

    The score is `0.5 * (ablated - original) - max(only_activated - original, 0)`,
    then forced to zero wherever the original loss exceeds 6 or exceeds the
    ablated loss.

    Args:
        original_loss: Loss without any intervention.
        ablated_loss: Loss with the context neurons ablated.
        context_and_activated_loss: Accepted for a uniform signature; not used.
        only_activated_loss: Loss from the MLP11-only run.

    Returns:
        A new tensor of scores; none of the inputs are modified.
    """
    # Loss increase caused by the context neurons.
    context_penalty = ablated_loss - original_loss
    # Loss increase attributed to MLP11, floored at zero.
    mlp_11_penalty = torch.clamp(only_activated_loss - original_loss, min=0)
    score = 0.5 * context_penalty - mlp_11_penalty
    # Drop tokens that were already badly predicted, or that ablation improved.
    suppressed = (original_loss > 6) | (original_loss > ablated_loss)
    return score.masked_fill(suppressed, 0)

def print_prompt(prompt: str):
    """Render a prompt as HTML with each token coloured by its interest measure.

    The four per-position losses come from `haystack_utils.get_direct_effect`
    (with `pos=None`) where MLP11 is the activated component and the layer 9-11
    attention outputs plus MLP9/MLP10 are deactivated. They are shown as
    additional measures next to the interest measure (described in the original
    note as a red/blue scale).

    Args:
        prompt: Text to tokenize and analyse.
    """
    str_token_prompt = model.to_str_tokens(model.to_tokens(prompt))

    deactivated = (
        "blocks.9.hook_attn_out",
        "blocks.10.hook_attn_out",
        "blocks.11.hook_attn_out",
        "blocks.9.hook_mlp_out",
        "blocks.10.hook_mlp_out",
    )
    losses = haystack_utils.get_direct_effect(
        prompt, model, pos=None,
        context_ablation_hooks=deactivate_context_hooks,
        context_activation_hooks=activate_context_hooks,
        deactivated_components=deactivated,
        activated_components=("blocks.11.hook_mlp_out",))
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = losses

    pos_wise_diff = interest_measure(
        original_loss, ablated_loss, context_and_activated_loss, only_activated_loss
    ).flatten().cpu().tolist()

    loss_names = ["original_loss", "ablated_loss", "context_and_activated_loss", "only_activated_loss"]
    loss_list = [
        loss.flatten().cpu().tolist()
        for loss in (original_loss, ablated_loss, context_and_activated_loss, only_activated_loss)
    ]
    # The first string token is skipped — presumably the BOS token, which has no
    # loss of its own; TODO confirm.
    haystack_utils.clean_print_strings_as_html(
        str_token_prompt[1:], pos_wise_diff,
        additional_measures=loss_list, additional_measure_names=loss_names)

def get_mlp11_decrease_measure(losses: list[tuple[Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"], Float[Tensor, "pos"]]]):
    """Return the largest interest measure found in each loss tuple.

    Args:
        losses: One `(original_loss, ablated_loss, context_and_activated_loss,
            only_activated_loss)` tuple per prompt.

    Returns:
        A list of Python floats, one per tuple: the maximum of
        `interest_measure` over that tuple's positions.
    """
    return [interest_measure(*loss_tuple).max().item() for loss_tuple in losses]

def left_pad(prompts, model):
    """Tokenize prompts individually and left-pad them to a common length.

    The target length is the width of `model.to_tokens(prompts)` for the whole
    batch. Each prompt is then tokenized on its own and zeros are prepended
    until it reaches that width.

    Args:
        prompts: Iterable of prompt strings.
        model: Object exposing `to_tokens`, which returns a 2-D token tensor.

    Returns:
        A stacked tensor with one left-padded row of token ids per prompt.

    NOTE(review): padding uses token id 0 rather than the tokenizer's pad token —
    confirm this is what downstream code expects.
    """
    target_length = model.to_tokens(prompts).shape[1]

    padded_rows = []
    for prompt in prompts:
        prompt_tokens = model.to_tokens(prompt)[0]
        # Allocate the padding with the tokens' own device and dtype, so this
        # also works on CPU-only machines instead of requiring CUDA.
        padding = prompt_tokens.new_zeros((target_length - prompt_tokens.shape[0],))
        padded_rows.append(torch.cat([padding, prompt_tokens]))

    return torch.stack(padded_rows)

def print_counter(token: str, data: list[str], next_tokens_count = 1):
    """Print how often each continuation follows the first occurrence of `token`.

    For every prompt containing `token`, the token and the `next_tokens_count`
    tokens after its first occurrence are joined into a string and counted.
    Prompts without the token, or with too few tokens after it, are skipped.

    Args:
        token: String that must map to a single model token.
        data: Prompts to scan.
        next_tokens_count: Number of following tokens to include in each key.
    """
    counter = Counter()
    token_index = model.to_single_token(token)
    for prompt in data:
        tokens = model.to_tokens(prompt)[0]
        try:
            index = tokens.tolist().index(token_index)
        except ValueError:
            # list.index signals "not present" with ValueError; catching only
            # that lets real errors (and KeyboardInterrupt) propagate.
            continue
        if index + next_tokens_count < len(tokens):
            next_tokens = tokens[index : index + next_tokens_count + 1]
            next_tokens_str = "".join(model.to_str_tokens(next_tokens))
            counter.update([next_tokens_str])
    print(counter)

In [None]:
def mlp_language_logprob_diffs(prompt, index, german_token, french_token):
    """Log-probability gap (French token minus German token) for each patching run.

    Customised to the L5 and L8 context neurons. Meant for French text, so the
    German context hooks are used in reverse: the hooks passed as
    `context_ablation_hooks` are the ones that switch the German context on,
    and no activation hooks are passed.

    Args:
        prompt: Prompt(s) forwarded unchanged to `haystack_utils.get_direct_effect`.
        index: Position forwarded as `pos`.
        german_token: Token id whose log-probability is subtracted.
        french_token: Token id whose log-probability is kept.

    Returns:
        A list of six tensors, `logprobs[:, french_token] - logprobs[:, german_token]`,
        for the original, ablated and direct-effect runs followed by the
        MLP9, MLP10 and MLP11 runs.
    """
    downstream_components = [
        f"blocks.{layer}.hook_{component}_out"
        for layer in [6, 7, 9, 10, 11]
        for component in ['mlp', 'attn']
    ]
    # Arguments shared by every get_direct_effect call in this function.
    shared_kwargs = dict(
        pos=index,
        context_ablation_hooks=activate_context_hooks,
        context_activation_hooks=[],
        return_type='logprobs',
    )

    original_logprobs, ablated_logprobs, direct_effect_logprobs, _ = haystack_utils.get_direct_effect(
        prompt, model,
        deactivated_components=tuple(downstream_components),
        activated_components=("blocks.5.hook_mlp_out", "blocks.8.hook_mlp_out",),
        **shared_kwargs)

    all_logprobs = [original_logprobs, ablated_logprobs, direct_effect_logprobs]
    for layer in [9, 10, 11]:
        mlp_component = f"blocks.{layer}.hook_mlp_out"
        # Everything downstream except the MLP currently under study.
        still_deactivated = tuple(name for name in downstream_components if name != mlp_component)
        _, _, _, activated_component_logprobs = haystack_utils.get_direct_effect(
            prompt, model,
            deactivated_components=still_deactivated,
            activated_components=(mlp_component,),
            **shared_kwargs)
        all_logprobs.append(activated_component_logprobs)

    return [logprobs[:, french_token] - logprobs[:, german_token] for logprobs in all_logprobs]

### ord -> n

In [None]:
# component_analysis handles its end strings one at a time, in order, so a
# single call with all three is equivalent to three separate calls.
# NOTE(review): ' die Bräsidentin' may be a typo for ' die Präsidentin' —
# confirm the spelling is intentional before changing it.
component_analysis([' legen', ' die Bräsidentin', ' Anerkennung'])