In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from tqdm.auto import tqdm
import plotly.io as pio
import pandas as pd
import numpy as np
import plotly.express as px 
import pickle

import haystack_utils
import hook_utils
import plotting_utils
import probing_utils
from probing_utils import get_and_score_new_word_probe
from sklearn import preprocessing
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.datasets import make_classification
from concept_erasure import LeaceEraser

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

%reload_ext autoreload
%autoreload 2

model = HookedTransformer.from_pretrained("EleutherAI/pythia-160m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]

LAYER, NEURON = 8, 2994
hook_name = f'blocks.{8}.mlp.hook_post'

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


In [None]:
def _load_probe_scores(filename):
    """Unpickle a probe-score DataFrame from the local layer-8 results directory.

    These pickles are produced locally by this project; never point this at untrusted files.
    """
    with open(f'data/pythia_160m/layer_8/{filename}', 'rb') as handle:
        return pickle.load(handle)

# Single-neuron probes, top 10 by MCC.
one_sparse_probe_scores_df = _load_probe_scores('single_neurons_df.pkl')
top_one_scores = one_sparse_probe_scores_df.sort_values(by='mcc', ascending=False).head(10)
print(top_one_scores)

# Two-neuron probe results from the `_mcc` pickle, top 10 by MCC.
two_sparse_probe_scores_df = _load_probe_scores('probes_two_sparse_df_10_mcc.pkl')
top_two_scores_mcc = two_sparse_probe_scores_df.sort_values(by='mcc', ascending=False).head(10)
print(top_two_scores_mcc)

# Two-neuron probe results from the `_f1` pickle, top 10 by F1
# (this rebinds two_sparse_probe_scores_df to the F1 results).
two_sparse_probe_scores_df = _load_probe_scores('probes_two_sparse_df_10_f1.pkl')
top_two_scores_f1 = two_sparse_probe_scores_df.sort_values(by='f1', ascending=False).head(10)
print(top_two_scores_f1)

In [None]:
# Probe only neuron 1426 (the best single-neuron probe) at hook_name.
activation_slice = np.s_[0, :-1, [1426]]
x, y = probing_utils.get_new_word_labels_and_activations(
    model,
    german_data,
    hook_name,
    activation_slice,
)
probe = probing_utils.get_probe(x, y)


In [None]:
# Train probe without scaling X to get an idea of what the activation threshold is for classification
# activation_slice = np.s_[0, :-1, :]
# unscaled_x, y = probing_utils.get_new_word_labels_and_activations(model, german_data, hook_name, activation_slice, scale_x=False)
# probe = probing_utils.get_probe(unscaled_x[:20_000, [1426]], y[:20_000])

# default classifier output threshold is 0.5
# print(probe.predict(unscaled_x[:30, [1426]]))
# print(probe.intercept_)
# print(probe.coef_[0])

In [None]:
# Neuron 1426, the best single-neuron probe, seems to act as a German
# retokenization context neuron; plot its activations on the German data.
plotting_utils.plot_neuron_acts(model, german_data, [[LAYER, 1426]])

In [5]:
# Fit a LEACE eraser for the new-word labels using every neuron at hook_name.
# The slice presumably means (batch 0, all positions but the last, all neurons);
# the axis order is defined by probing_utils.
activation_slice = np.s_[0, :-1, :]
x, y = probing_utils.get_new_word_labels_and_activations(
    model,
    german_data,
    hook_name,
    activation_slice,
)
eraser = probing_utils.get_leace_eraser(x, y)

# d = 3071
# k = 1

# x = torch.randn(1, d)
# bias = torch.randn(d)
# proj_left = torch.randn(d, k)
# proj_right = torch.randn(d, k)

# def test(x):
#     """Apply the projection to the input tensor."""
#     delta = x - bias if bias is not None else x

#     # Ensure we do the matmul in the most efficient order.
#     x_ = x - (delta @ proj_right.T) @ proj_left.T
#     return x_.type_as(x)

# test(x)

LeaceEraser(proj_left=tensor([[ 0.1065],
        [ 0.3068],
        [ 0.3447],
        ...,
        [ 0.1113],
        [ 0.1074],
        [-0.0787]], dtype=torch.float64), proj_right=tensor([[ 0.0027,  0.0156,  0.0022,  ...,  0.0022,  0.0017, -0.0062]],
       dtype=torch.float64), bias=tensor([ 4.9738e-18, -9.2371e-18, -1.8119e-17,  ..., -4.2588e-17,
        -1.0933e-16, -3.5527e-18], dtype=torch.float64))

In [121]:
# The eraser stores its projection as a low-rank factorisation (proj_left @ proj_right).
# Erasure computes x - ((x - bias) @ proj_right.T) @ proj_left.T, which equals
# (x - bias) @ leace_proj.T + bias with leace_proj = I - proj_left @ proj_right.
# So row i of leace_proj shows how erased neuron i is composed from the centred inputs.
# Size and dtype come from the eraser itself instead of a hard-coded 3072 with
# implicit float32/float64 promotion.
erase_dim = eraser.proj_left.shape[0]
leace_proj = torch.eye(erase_dim, dtype=eraser.proj_left.dtype) - eraser.proj_left @ eraser.proj_right
print(leace_proj[NEURON])
print(leace_proj[2995])  # adjacent neuron index, presumably for comparison

tensor([ 0.0017,  0.0101,  0.0014,  ...,  0.0014,  0.0011, -0.0040],
       dtype=torch.float64)
tensor([-1.1895e-04, -6.9421e-04, -9.8921e-05,  ..., -9.6575e-05,
        -7.7456e-05,  2.7653e-04], dtype=torch.float64)


In [8]:
# Recompute the activations/labels so this check does not depend on earlier cell state.
x, y = probing_utils.get_new_word_labels_and_activations(
    model, german_data, hook_name, activation_slice
)

# Before erasure, a logistic-regression probe finds a non-trivial direction...
real_lr = LogisticRegression(max_iter=2000)
real_lr.fit(x, y)
beta = torch.from_numpy(real_lr.coef_)
assert beta.norm(p=torch.inf) > 0.1

eraser = probing_utils.get_leace_eraser(x, y)
X_ = eraser(torch.from_numpy(x))

# ...after LEACE erasure it should learn (numerically) nothing.
null_lr = LogisticRegression(max_iter=2000, tol=0.0)
null_lr.fit(X_.numpy(), y)
beta = torch.from_numpy(null_lr.coef_)
assert beta.norm(p=torch.inf) < 1e-4


lbfgs failed to converge (status=2):
ABNORMAL_TERMINATION_IN_LNSRCH.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression



In [9]:
# Apply the fitted LEACE eraser to the numpy activations.
X_ = eraser(
    torch.from_numpy(x)
)

In [None]:
def top_erasure_neurons(leace_eraser, k=10):
    """Return (indices, signed weights) of the k proj_left entries with the largest magnitude.

    Erasure depends only on proj_left together with proj_right. Flipping the sign of both
    leaves the erasure unchanged, so the sign of proj_left alone is not meaningful. Ranking
    by raw value would miss strongly negative neurons, so rank by |weight|. The returned
    weights keep their sign.
    Assumes a rank-1 eraser (proj_left of shape (d, 1)), as the original squeeze(1) did.
    """
    direction = leace_eraser.proj_left.squeeze(1)
    _, indices = torch.topk(direction.abs(), k)
    return indices, direction[indices]

# top neurons used to erase "is space" feature
print(top_erasure_neurons(eraser))
# top neurons used to erase "is mid-word" feature.
# np.logical_not inverts correctly whether y is a bool or a 0/1 integer array
# (bitwise ~ on integers would give -1/-2).
inverse_eraser = probing_utils.get_leace_eraser(x, np.logical_not(y))
print(top_erasure_neurons(inverse_eraser))

In [None]:
# Compare the LEACE erasure direction with a 2-neuron logistic-regression probe
# direction (observed cosine similarity: 0.08).
probe_neurons = [1426, 1507]
probe = probing_utils.get_probe(x[:20_000, probe_neurons], y[:20_000])
# Build the probe direction in the eraser's dtype so the two vectors compared below
# share a dtype, instead of mixing float32 with the eraser's float64.
probe_dir = torch.zeros(model.cfg.d_mlp, dtype=eraser.proj_left.dtype)
probe_dir[probe_neurons] = torch.from_numpy(probe.coef_[0]).to(probe_dir.dtype)
cosine_sim = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
print(cosine_sim(eraser.proj_left.squeeze(1), probe_dir))

In [10]:
# Baseline average loss on the German data with no hooks.
print(haystack_utils.get_average_loss(german_data, model))

def erase_feature_l8(value, hook):
    """Forward hook: apply the fitted LEACE `eraser` to the L8 MLP post-activations.

    The eraser is applied on the CPU (it was fitted from numpy activations). The result is
    moved back to the device the activations came from. The previous hard-coded `.cuda()`
    crashed whenever the model ran on CPU.
    """
    return eraser(value.cpu()).to(value.device)

erase_feature_l8_hooks = [('blocks.8.mlp.hook_post', erase_feature_l8)]

def deactivate_retokenization_l8(value, hook):
    """Forward hook: apply `inverse_eraser` to the L8 MLP post-activations.

    `inverse_eraser` is the LEACE eraser fitted on the inverted new-word labels. It must
    already exist (it is created in the "top neurons" cell); otherwise calling this hook
    raises NameError. The result is returned on the input's own device rather than
    hard-coding `.cuda()`, so CPU runs also work.
    """
    return inverse_eraser(value.cpu()).to(value.device)

deactivate_retokenization_l8_hooks = [('blocks.8.mlp.hook_post', deactivate_retokenization_l8)]

# Average German loss with each L8 eraser hook applied: first the eraser fitted on y,
# then the one fitted on the inverted labels. The two numbers are expected to match
# (the original note: "Same thing"). Presumably this is because, for a binary label,
# negating y only flips the sign of the statistics LEACE uses.
for l8_hooks in (erase_feature_l8_hooks, deactivate_retokenization_l8_hooks):
    with model.hooks(l8_hooks):
        print(haystack_utils.get_average_loss(german_data, model))

tensor(2.4177, device='cuda:0')
tensor(2.4574, device='cuda:0')


NameError: name 'inverse_eraser' is not defined

### Repeat loss investigation with LEACE direction

In [11]:
# Path patching via direct effects. Collect the tokens to ignore and the common German
# tokens (k=100) that component_analysis uses to build random prompts.
import pythia_160m_utils

all_ignore, not_ignore = haystack_utils.get_weird_tokens(
    model, plot_norms=False
)
common_tokens = haystack_utils.get_common_tokens(
    german_data, model, all_ignore, k=100
)

def component_analysis(model, common_tokens, end_strings: list[str] | str, deactivate_context_hooks=erase_feature_l8_hooks, activate_context_hooks=None):
    """Bar-plot mean losses of random prompts ending in each end string.

    For every end string, this:
    - prints its tokens;
    - generates random prompts from ``common_tokens`` (presumably 400 prompts of length 20);
    - runs ``mlp_effects_german`` at the last position (index -1);
    - plots the mean of each returned loss tensor as one bar
      (original, ablated, direct effect, MLP9, MLP10, MLP11).

    Args:
        model: the HookedTransformer.
        common_tokens: token pool passed to ``haystack_utils.generate_random_prompts``.
        end_strings: one end string or a list of them.
        deactivate_context_hooks: hooks that ablate the context feature.
        activate_context_hooks: hooks that activate it. ``None`` (default) means no hooks;
            this replaces the previous shared mutable ``[]`` default.
    """
    if activate_context_hooks is None:
        activate_context_hooks = []
    if isinstance(end_strings, str):
        end_strings = [end_strings]
    bar_names = ['original', 'ablated', 'direct effect'] + [f'MLP{layer}' for layer in [9, 10, 11]]
    for end_string in end_strings:
        print(model.to_str_tokens(end_string))
        random_prompts = haystack_utils.generate_random_prompts(end_string, model, common_tokens, 400, length=20)
        data = mlp_effects_german(model, random_prompts, -1, deactivate_context_hooks, activate_context_hooks)

        # The bars are raw mean losses, not loss increases; the title says so.
        haystack_utils.plot_barplot([[item.cpu().flatten().mean().item()] for item in data],
                                    names=bar_names,
                                    title=f'Mean loss when ablating various MLP components for end string "{end_string}"')

def mlp_effects_german(model, prompt, index, deactivate_context_hooks=erase_feature_l8_hooks, activate_context_hooks=None):
    """Losses for the direct-effect analysis of downstream MLPs. Customised to L5 and L8 context neurons.

    Args:
        model: the HookedTransformer.
        prompt: forwarded as-is to ``haystack_utils.get_direct_effect``.
            ``component_analysis`` passes the output of ``generate_random_prompts`` here.
        index: forwarded as ``pos`` to ``get_direct_effect``.
        deactivate_context_hooks: hooks that ablate the context feature.
        activate_context_hooks: hooks that activate it. ``None`` (default) means no hooks;
            this replaces the previous shared mutable ``[]`` default.

    Returns:
        list: ``[original, ablated, direct_effect, mlp9, mlp10, mlp11]``.
        - The first three come from one run that deactivates all downstream components
          (MLP/attn outputs of layers 6, 7, 9, 10, 11), with the MLP5 and MLP8 outputs
          as ``activated_components``.
        - Each ``mlpN`` is the last value of a run where ``blocks.N.hook_mlp_out`` is
          the only activated component and is removed from the deactivated set.
    """
    if activate_context_hooks is None:
        activate_context_hooks = []
    downstream_components = [f"blocks.{layer}.hook_{component}_out" for layer in [6, 7, 9, 10, 11] for component in ['mlp', 'attn']]

    original, ablated, direct_effect, _ = haystack_utils.get_direct_effect(
        prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
        deactivated_components=tuple(downstream_components), activated_components=("blocks.5.hook_mlp_out", "blocks.8.hook_mlp_out",))

    data = [original, ablated, direct_effect]
    for layer in [9, 10, 11]:
        kept_component = f"blocks.{layer}.hook_mlp_out"
        _, _, _, activated_component_loss = haystack_utils.get_direct_effect(
            prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
            deactivated_components=tuple(component for component in downstream_components if component != kept_component),
            activated_components=(kept_component,))
        data.append(activated_component_loss)
    return data

def attn_effects_german(model, prompt, index, deactivate_context_hooks=erase_feature_l8_hooks, activate_context_hooks=None):
    """Losses isolating the attention output of each of layers 9-11. Customised to L5 and L8 context neurons.

    For each layer N in 9, 10, 11, this calls ``haystack_utils.get_direct_effect`` with
    ``blocks.N.hook_attn_out`` as the activated component. Every other downstream component
    (MLP/attn outputs of layers 6, 7, 9, 10, 11) is deactivated.

    Previously the component excluded from the deactivated set was ``hook_mlp_out``. That
    deactivated the very attention output being activated and left that layer's MLP alone.

    Args:
        model: the HookedTransformer.
        prompt: forwarded as-is to ``get_direct_effect``.
        index: forwarded as ``pos``.
        deactivate_context_hooks: hooks that ablate the context feature.
        activate_context_hooks: hooks that activate it. ``None`` (default) means no hooks.

    Returns:
        list: ``[attn9, attn10, attn11]``, the last value returned by each run.
    """
    if activate_context_hooks is None:
        activate_context_hooks = []
    downstream_components = [f"blocks.{layer}.hook_{component}_out" for layer in [6, 7, 9, 10, 11] for component in ['mlp', 'attn']]

    data = []
    for layer in [9, 10, 11]:
        kept_component = f"blocks.{layer}.hook_attn_out"
        _, _, _, activated_component_loss = haystack_utils.get_direct_effect(
            prompt, model, pos=index, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
            deactivated_components=tuple(component for component in downstream_components if component != kept_component),
            activated_components=(kept_component,))
        data.append(activated_component_loss)
    return data

def get_mlp11_decrease_measure(losses):
    """Score each prompt by its largest per-token interest measure.

    Args:
        losses: iterable of 4-tuples
            (original_loss, ablated_loss, context_and_activated_loss, only_activated_loss).

    Returns:
        list[float]: ``interest_measure(...).max().item()`` for each tuple, in input order.
    """
    return [
        interest_measure(orig, ablated, ctx_and_act, only_act).max().item()
        for orig, ablated, ctx_and_act, only_act in losses
    ]

def interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss,
                     loss_diff_weight=0.5, max_original_loss=6):
    """Per-token measure, mixture of overall loss increase and loss increase from ablating MLP11.

    combined = loss_diff_weight * (ablated_loss - original_loss)
               - max(only_activated_loss - original_loss, 0)

    Tokens are set to 0 where ``original_loss > max_original_loss``, or where
    ``original_loss > ablated_loss`` (the ablation lowered the loss).

    Args:
        original_loss, ablated_loss, only_activated_loss: torch tensors of matching
            (broadcastable) shape.
        context_and_activated_loss: unused; kept so callers can pass the full 4-tuple.
        loss_diff_weight: weight on the overall loss increase (previously hard-coded 0.5).
        max_original_loss: tokens with a higher original loss are zeroed
            (previously hard-coded 6).

    Returns:
        A new tensor; the inputs are not modified.
    """
    loss_diff = ablated_loss - original_loss  # Loss increase from context neuron
    mlp_11_power = (only_activated_loss - original_loss).clamp(min=0)  # Loss increase from MLP11, floored at 0
    combined = loss_diff_weight * loss_diff - mlp_11_power
    combined = combined.masked_fill(original_loss > max_original_loss, 0)
    combined = combined.masked_fill(original_loss > ablated_loss, 0)
    return combined

def print_prompt(prompt: str, activate_context_hooks=None, deactivate_context_hooks=erase_feature_l8_hooks):
    """Red/blue scale showing the interest measure for each token.

    Runs ``haystack_utils.get_direct_effect`` on ``prompt`` with:
    - layers 9-11 attention outputs and the MLP9/MLP10 outputs deactivated;
    - ``blocks.11.hook_mlp_out`` activated.
    It then renders the prompt's tokens as HTML coloured by ``interest_measure``, with the
    four per-token losses as additional measures.

    Args:
        prompt: text to analyse (uses the global ``model``).
        activate_context_hooks: context activation hooks. ``None`` (default) means no hooks.
        deactivate_context_hooks: context ablation hooks. These are now actually used;
            they were previously ignored in favour of a hard-coded ``erase_feature_l8_hooks``.
    """
    if activate_context_hooks is None:
        activate_context_hooks = []
    str_token_prompt = model.to_str_tokens(model.to_tokens(prompt))
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = haystack_utils.get_direct_effect(
        prompt, model, pos=None, context_ablation_hooks=deactivate_context_hooks, context_activation_hooks=activate_context_hooks,
        deactivated_components=("blocks.9.hook_attn_out", "blocks.10.hook_attn_out", "blocks.11.hook_attn_out", "blocks.9.hook_mlp_out", "blocks.10.hook_mlp_out"),
        activated_components=("blocks.11.hook_mlp_out",))

    pos_wise_diff = interest_measure(original_loss, ablated_loss, context_and_activated_loss, only_activated_loss).flatten().cpu().tolist()

    losses = [original_loss, ablated_loss, context_and_activated_loss, only_activated_loss]
    loss_list = [loss.flatten().cpu().tolist() for loss in losses]
    loss_names = ["original_loss", "ablated_loss", "context_and_activated_loss", "only_activated_loss"]
    # Drop the first string token (presumably the BOS token prepended by to_tokens).
    haystack_utils.clean_print_strings_as_html(str_token_prompt[1:], pos_wise_diff, additional_measures=loss_list, additional_measure_names=loss_names)


  0%|          | 0/200 [00:00<?, ?it/s]

In [12]:
# Per-token (original, ablated, context_and_activated, only_activated) losses for every
# German prompt. The L8 eraser hooks ablate the context; the listed L9-L11 components are
# deactivated and MLP11 is the activated component.
german_losses = []
for prompt in tqdm(german_data):
    direct_effect_losses = haystack_utils.get_direct_effect(
        prompt,
        model,
        pos=None,
        context_ablation_hooks=erase_feature_l8_hooks,
        context_activation_hooks=[],
        deactivated_components=(
            "blocks.9.hook_attn_out",
            "blocks.10.hook_attn_out",
            "blocks.11.hook_attn_out",
            "blocks.9.hook_mlp_out",
            "blocks.10.hook_mlp_out",
        ),
        activated_components=("blocks.11.hook_mlp_out",),
    )
    original_loss, ablated_loss, context_and_activated_loss, only_activated_loss = direct_effect_losses
    german_losses.append(
        (original_loss, ablated_loss, context_and_activated_loss, only_activated_loss)
    )


# component_analysis(model, common_tokens, end_strings, erase_feature_l8_hooks, [])

  0%|          | 0/200 [00:00<?, ?it/s]

In [13]:
# Per-prompt interest score; show the 10 highest-scoring prompts.
measure = get_mlp11_decrease_measure(german_losses)
index = list(range(len(measure)))

# (prompt index, score) pairs, highest score first (sorted is stable, like list.sort).
sorted_measure = sorted(zip(index, measure), key=lambda pair: pair[1], reverse=True)
# Loop over `score` rather than `measure` so the per-prompt list above is not
# overwritten with a single float.
for i, score in sorted_measure[:10]:
    print(score)
    print_prompt(german_data[i])

0.8646253347396851


0.629259467124939


0.5446069240570068


0.51181960105896


0.4889712333679199


0.47079598903656006


0.46613597869873047


0.46566808223724365


0.45053958892822266


0.4457130432128906
