In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import json

# Render plotly figures with the "notebook_connected" renderer.
pio.renderers.default = "notebook_connected"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Switch gradient tracking off; both calls pass False.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# N-gram string studied in this notebook.
ngram = "orschlägen"
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

# Forward-hook lists built for layer 3 / neuron 669. By their names they switch that
# neuron on and off — TODO confirm against haystack_utils.get_context_ablation_hooks.
activate_neurons_fwd_hooks, deactivate_neurons_fwd_hooks = haystack_utils.get_context_ablation_hooks(3, [669], model)
# Token set that is excluded below and zeroed out in later diffs; the second return value is unused.
all_ignore, _ = haystack_utils.get_weird_tokens(model, plot_norms=False)

# Only the first 200 loaded examples are passed on; k=100 is forwarded to get_common_tokens.
german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
common_tokens = haystack_utils.get_common_tokens(german_data, model, all_ignore, k=100)

# Sort tokens into new word vs continuation
# (a token whose string form starts with a space is treated as starting a new word).
new_word_tokens = []
continuation_tokens = []
for token in common_tokens:
    str_token = model.to_single_str_token(token.item())
    if str_token.startswith(" "):
        new_word_tokens.append(token)
    else:
        continuation_tokens.append(token)
# torch.stack raises on an empty list, so both groups must be non-empty.
new_word_tokens = torch.stack(new_word_tokens)
continuation_tokens = torch.stack(continuation_tokens)

# Output weights of layer 3, neuron 669 (the same layer/neuron the hooks above were built for).
context_direction = model.W_out[3, 669, :]

def get_cosine_sim(direction: Float[Tensor, "d_res"], layer=5) -> Float[Tensor, "d_mlp"]:
    """Cosine similarity between `direction` and every row of `model.W_in[layer].T`.

    Args:
        direction: vector compared against the rows (annotated as residual-stream sized).
        layer: index into `model.W_in` selecting which MLP input weights to use.

    Returns:
        One similarity per row of `model.W_in[layer].T` (annotated as one per MLP neuron).
    """
    neuron_input_weights = model.W_in[layer].T
    # A leading axis is added to `direction` so it broadcasts against every row.
    return torch.nn.functional.cosine_similarity(
        neuron_input_weights, direction.unsqueeze(0), dim=1
    )

def plot_histogram(t1, t2, t3, name1, name2, name3,
                   title="Individual MLP5 similarities to direction vectors"):
    """Overlay three probability-density histograms of tensor values in one plotly figure.

    Args:
        t1, t2, t3: tensors whose values are histogrammed (moved to CPU and converted to numpy).
        name1, name2, name3: legend labels for the corresponding traces.
        title: figure title. The default keeps the previous hard-coded text; pass a
            different title when the plotted similarities are not for MLP5.

    Returns:
        None. The figure is displayed with `fig.show()`.
    """
    bin_width = 0.01
    fig = go.Figure()
    # All three traces share the same bin width, opacity and normalisation so they are comparable.
    for values, label in ((t1, name1), (t2, name2), (t3, name3)):
        fig.add_trace(go.Histogram(
            x=values.cpu().numpy(),
            name=label,
            opacity=0.5,
            histnorm='probability density',
            xbins=dict(size=bin_width),
        ))

    fig.update_layout(
        title=title,
        xaxis_title="Cosine Similarity",
        yaxis_title="Probability Density",
        barmode="overlay",
    )

    fig.show()

def compute_mlp_loss(prompts, df, neurons, ablate_mode="NNN", layer=5, compute_original_loss=False):
    """Loss on the last token when selected MLP pre-activations are overwritten.

    The pre-activations of `neurons` at `blocks.{layer}.mlp.hook_pre` are replaced, at
    every batch element and position, by the per-neuron values stored in
    `df[ablate_mode]`.

    Args:
        prompts: model input passed straight to `model(...)`.
        df: DataFrame indexed by neuron id with one column per ablation mode.
        neurons: tensor of neuron ids (must support `.tolist()`); every id must be
            present in `df.index`.
        ablate_mode: name of the `df` column holding the replacement values.
        layer: MLP layer whose `hook_pre` activations are overwritten.
        compute_original_loss: when True, also run the unhooked model.

    Returns:
        `ablated_loss`, or `(original_loss, ablated_loss)` when
        `compute_original_loss` is True. Each is the mean over prompts of the last
        per-token loss entry.

    Raises:
        KeyError: if a requested neuron id is missing from `df.index`.
    """
    # Look the values up by label so they come back in the same order as `neurons`.
    # A boolean `isin` filter would return them in DataFrame order instead, which
    # misaligns values and neurons whenever `neurons` is not sorted like the index.
    mean_activations = torch.tensor(
        df.loc[neurons.tolist(), ablate_mode].to_numpy(), dtype=torch.float32
    )

    def ablate_mlp_hook(value, hook):
        # Follow the activation's device/dtype instead of assuming CUDA.
        value[:, :, neurons] = mean_activations.to(device=value.device, dtype=value.dtype)
        return value

    with model.hooks(fwd_hooks=[(f"blocks.{layer}.mlp.hook_pre", ablate_mlp_hook)]):
        ablated_loss = model(prompts, return_type="loss", loss_per_token=True)[:, -1].mean().item()

    if compute_original_loss:
        loss = model(prompts, return_type="loss", loss_per_token=True)[:, -1].mean().item()
        return loss, ablated_loss
    return ablated_loss

Downloading (…)lve/main/config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/166M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

In [3]:
# Loss change for different AND thresholds
option = "orschlägen"
# Per-neuron table for this n-gram. NOTE(review): unpickling can execute arbitrary code,
# so this file must come from a trusted source.
df = pd.read_pickle(f"data/and_neurons/df_{option}.pkl") 

# Precomputed losses; later indexed as all_losses[option][ablation_mode][name].
with open(f"data/and_neurons/set_losses.json", "r") as f:
    all_losses = json.load(f)

# Prompts generated around `option`; 500 and length=20 are forwarded to the helper.
prompts = haystack_utils.generate_random_prompts(option, model, common_tokens, 500, length=20)

In [4]:
# Single-token ids for the strings "gen" and "ge".
gen_token = model.to_single_token("gen")
ge_token = model.to_single_token("ge")
# Residual-stream directions for those tokens, as returned by tokens_to_residual_directions.
gen_dir = model.tokens_to_residual_directions(gen_token)
ge_dir = model.tokens_to_residual_directions(ge_token)

# Cosine similarity of each row of model.W_out[5] with each token direction.
cos = torch.nn.CosineSimilarity(dim=1)
gen_sims = cos(model.W_out[5], gen_dir.unsqueeze(0)).cpu().numpy()
ge_sims = cos(model.W_out[5], ge_dir.unsqueeze(0)).cpu().numpy()

# Adds two columns to df in place; assumes df has one row per W_out[5] row, in the same order — TODO confirm.
df["GenSim"] = gen_sims
df["GeSim"] = ge_sims

In [43]:
# Check how similar gen and ge dirs are
# Note: this rebinds `cos` to a dim=0 similarity, replacing the dim=1 instance from the previous cell.
cos = torch.nn.CosineSimilarity(dim=0)
cos(gen_dir, ge_dir)

tensor(0.3292, device='cuda:0')

In [None]:
# Global logprob effect from context neuron
# HookedTransformer.forward only accepts "logits", "loss", "both" or None as return_type;
# "logprob" makes it return None, so indexing the result fails. The per-token loss is the
# negative log-prob of the correct next token, so negate it to get the log-prob on the
# last token of each prompt.

original_logprobs = -model(prompts, return_type="loss", loss_per_token=True)[:, -1].cpu().numpy()
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    deactivated_logprobs = -model(prompts, return_type="loss", loss_per_token=True)[:, -1].cpu().numpy()

# Negative values mean the correct token became less likely once the hooks were applied.
logprob_diffs = deactivated_logprobs - original_logprobs

In [5]:
# Column means for rows with AblationDiff above 0.2, then for rows with AblationDiff below -0.2.
print(df[df["AblationDiff"]>0.2][["AblationDiff", "GenSim", "GeSim"]].mean())
print(df[df["AblationDiff"]<-0.2][["AblationDiff", "GenSim", "GeSim"]].mean())

Prev/Curr/Context
AblationDiff    0.434787
GenSim          0.002992
GeSim          -0.000410
dtype: float64
Prev/Curr/Context
AblationDiff   -0.431541
GenSim          0.000029
GeSim           0.004855
dtype: float64


In [54]:
df_tmp = df.copy()
ablation_mode = "YYN"

# A neuron is selected when its YYY value is positive, its GenSim exceeds its GeSim,
# and its YYY value is strictly larger than its value in every other column listed here.
other_patterns = ["NNN", "YYN", "YNY", "NYY", "NYN", "NNY", "YNN"]
is_and_neuron = (df["YYY"] > 0) & (df["GenSim"] > df["GeSim"])
for pattern in other_patterns:
    is_and_neuron &= df["YYY"] > df[pattern]
df_tmp["Custom"] = is_and_neuron  # & df["PosSim"]

print(df_tmp["Custom"].sum())
# Use the notebook-wide `device` so the cell also runs when CUDA is unavailable.
pos_and_neurons = torch.LongTensor(df_tmp[df_tmp["Custom"]].index.tolist()).to(device)

# Disabled alternative: rank by YYY - YYN and keep the first 30 neurons.
#df_tmp["context_diff"] = df_tmp["YYY"] - df_tmp["YYN"]
#df_tmp = df_tmp.sort_values(by=["Custom", "context_diff"], ascending=False)
#pos_and_neurons = torch.LongTensor(df_tmp.index.tolist()[:30]).cuda()

original_loss, ablated_loss = compute_mlp_loss(prompts, df, pos_and_neurons, ablate_mode=ablation_mode, compute_original_loss=True)
print(original_loss, ablated_loss)

102
1.4382295608520508 7.869058609008789


In [63]:
# Call the project's cache-cleaning helper, then rebind `prompts` to a fresh batch
# (500 and length=20 are forwarded to the generator).
haystack_utils.clean_cache()
prompts = haystack_utils.generate_random_prompts(option, model, common_tokens, 500, length=20)


In [78]:
# Run the model with the deactivation hooks applied and keep the activation cache.
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    deactivated_loss, cache = model.run_with_cache(prompts, return_type="loss", loss_per_token=True)

# Disabled alternative: collapse `cache` to the batch mean of blocks.5.mlp.hook_pre at position -2.
#cache = cache["blocks.5.mlp.hook_pre"][:, -2].mean(0)
# Mean over prompts of the last per-token loss entry.
print(deactivated_loss[:, -1].mean().item())

3.2064943313598633


In [83]:
# Shape of the cache entry addressed as ("post", 5); the cell output shows the result.
cache["post", 5].shape

torch.Size([500, 23, 2048])

In [59]:
# Overwrite all 2048 neurons with their "YYN" values and show (original_loss, ablated_loss).
compute_mlp_loss(prompts, df, torch.arange(2048), "YYN", compute_original_loss=True)

(1.4382295608520508, 5.653645992279053)

In [69]:
# "YYN" values for the rows whose index is in 0..2047, in DataFrame order, moved to the GPU.
mean_activations_df = torch.Tensor(df[df.index.isin([i for i in range(2048)])]["YYN"].tolist()).cuda()

In [71]:
# `cache` is the ActivationCache returned by run_with_cache, which has no `.shape` and
# cannot be sliced like a tensor. Compute the per-neuron mean pre-activation at
# position -2 explicitly, so this cell does not depend on `cache` having been rebound.
mean_pre_activations = cache["blocks.5.mlp.hook_pre"][:, -2].mean(0)
print(mean_activations_df.shape, mean_pre_activations.shape)
print(mean_activations_df[:10])
print(mean_pre_activations[:10])

torch.Size([2048]) torch.Size([2048])
tensor([-1.3228, -1.1839, -0.7734, -2.4799, -0.9134, -0.9777, -1.1465, -1.8528,
        -0.9403, -0.6645], device='cuda:0')
tensor([-1.3351, -1.1835, -0.7855, -2.4714, -0.9022, -0.9851, -1.1769, -1.8635,
        -0.9359, -0.6646], device='cuda:0')


In [50]:
# And thresholds without similarity constraint
prompts = haystack_utils.generate_random_prompts(option, model, common_tokens, 2000, length=20)
ablation_mode = "YYN"
# Columns that must stay at or below the threshold while YYY exceeds it (NNN is not part of this filter).
non_and_patterns = ["YYN", "YNY", "NYY", "YNN", "NNY", "NYN"]
losses = []
lens = []
thresholds = []
for and_threshold in np.arange(-0.5, 2, 0.1):
    # Build the mask directly from df; no per-iteration DataFrame copy is needed.
    is_and_neuron = (df["YYY"] > and_threshold) & (df[non_and_patterns] <= and_threshold).all(axis=1)
    pos_and_neurons = torch.tensor(
        df.index[is_and_neuron.to_numpy()].tolist(), dtype=torch.long, device=device
    )

    if not losses:
        # The unablated loss does not depend on the threshold, so compute it once only.
        original_loss, ablated_loss = compute_mlp_loss(prompts, df, pos_and_neurons, ablate_mode=ablation_mode, compute_original_loss=True)
        losses.append([original_loss])
    else:
        ablated_loss = compute_mlp_loss(prompts, df, pos_and_neurons, ablate_mode=ablation_mode)
    losses.append([ablated_loss])
    lens.append(len(pos_and_neurons))
    thresholds.append(and_threshold)
# One bar for the original loss, then one per threshold labelled with the number of selected neurons.
names = ["Original"] + [f"AND {thr:.2f} ({length})" for thr, length in zip(thresholds, lens)]
haystack_utils.plot_barplot(losses, names, title="AND neuron loss increase for different thresholds")

In [51]:
# And thresholds with similarity constraint
# Reuses the `prompts` already defined in the kernel.
ablation_mode = "YYN"
# Columns that must stay at or below the threshold while YYY exceeds it (NNN is not part of this filter).
non_and_patterns = ["YYN", "YNY", "NYY", "YNN", "NNY", "NYN"]
# The similarity constraint does not depend on the threshold, so evaluate it once.
prefers_gen = df["GenSim"] > df["GeSim"]
losses = []
lens = []
thresholds = []
for and_threshold in np.arange(-0.5, 2, 0.1):
    # Build the mask directly from df; no per-iteration DataFrame copy is needed.
    is_and_neuron = (
        prefers_gen
        & (df["YYY"] > and_threshold)
        & (df[non_and_patterns] <= and_threshold).all(axis=1)
    )
    pos_and_neurons = torch.tensor(
        df.index[is_and_neuron.to_numpy()].tolist(), dtype=torch.long, device=device
    )

    if not losses:
        # The unablated loss does not depend on the threshold, so compute it once only.
        original_loss, ablated_loss = compute_mlp_loss(prompts, df, pos_and_neurons, ablate_mode=ablation_mode, compute_original_loss=True)
        losses.append([original_loss])
    else:
        ablated_loss = compute_mlp_loss(prompts, df, pos_and_neurons, ablate_mode=ablation_mode)
    losses.append([ablated_loss])
    lens.append(len(pos_and_neurons))
    thresholds.append(and_threshold)
# One bar for the original loss, then one per threshold labelled with the number of selected neurons.
names = ["Original"] + [f"AND {thr:.2f} ({length})" for thr, length in zip(thresholds, lens)]
haystack_utils.plot_barplot(losses, names, title="AND neuron loss increase with cosine sim constraint: gen > ge")

In [10]:
# Run with the deactivation hooks applied and cache every activation.
# Note: this rebinds `ablated_loss`; no loss_per_token is requested here.
with model.hooks(deactivate_neurons_fwd_hooks):
    ablated_loss, ablated_cache = model.run_with_cache(prompts, return_type="loss")

def get_ablate_neurons_hook(neuron: int | list[int], ablated_cache, layer=5):
    def ablate_neurons_hook(value, hook):
        value[:, :, neuron] = ablated_cache[f'blocks.{layer}.mlp.hook_post'][:, :, neuron]
        return value
    return [(f'blocks.{layer}.mlp.hook_post', ablate_neurons_hook)]

# Hook list that overwrites all 2048 layer-5 neurons with their values from `ablated_cache`.
ablate_top_neurons_hook = get_ablate_neurons_hook(list(range(2048)), ablated_cache)

In [11]:
# Log-probs from haystack_utils.get_direct_effect at pos=-2; the third return value is unused.
# See that helper for the exact meaning of each returned tensor.
original_logprobs, ablated_logprobs, _, all_MLP5_logprobs = haystack_utils.get_direct_effect(prompts, model, pos=-2, context_ablation_hooks=deactivate_neurons_fwd_hooks, context_activation_hooks=activate_neurons_fwd_hooks, return_type='logprobs')

# Mean (over the first axis) drop in log-prob per token: original minus ablated.
diffs = (original_logprobs - ablated_logprobs).mean(0)
# Zero out ignored tokens and tokens whose mean original log-prob is below -7.
diffs[all_ignore] = 0
diffs[original_logprobs.mean(0)<-7] = 0
# The 20 tokens with the largest remaining drop.
top_diff, top_token = torch.topk(diffs, 20)
print(top_diff)
print(model.to_str_tokens(top_token))

tensor([1.5872, 1.4587, 0.8471, 0.3234, 0.0033, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000], device='cuda:0')
['ß', 'gen', 'ßen', 'ger', 'gt', '-', '*', ',', '(', '&', ')', '+', "'", '!', '<|endoftext|>', '<|padding|>', '%', '$', '"', '#']


In [None]:

# %%

option = "orschlägen"
ablation_mode = "YYN"
# NOTE(review): `prompts` is regenerated here but not used in the rest of this section.
prompts = haystack_utils.generate_random_prompts(option, model, common_tokens, 1000, length=20)

# Bar names and values come from the precomputed all_losses[option][ablation_mode] mapping.
names = list(all_losses[option][ablation_mode].keys())
# Each value is wrapped in its own list before being passed to plot_barplot.
losses = [[all_losses[option][ablation_mode][name]] for name in names]

# Sanity check: the number of names and bars, and the length of each bar's list.
print(len(names), len(losses))
print([len(x) for x in losses])
haystack_utils.plot_barplot(losses, names)


In [None]:


layer = 2
ngram = "orschlägen"
prompts = haystack_utils.generate_random_prompts(ngram, model, common_tokens, 500, length=20)

# An n-gram starting with a space uses the new-word tokens as the first token set;
# otherwise the continuation tokens are used for both sets.
first_token_set = new_word_tokens if ngram.startswith(" ") else continuation_tokens
prompt_tuple = haystack_utils.get_trigram_prompts(prompts, first_token_set, continuation_tokens)
prev_token_direction, curr_token_direction = haystack_utils.get_residual_trigram_directions(prompt_tuple, model, layer-1)

# Similarity of each direction to the input weights of `layer`, in the order prev, curr, context.
prev_token_sim, curr_token_sim, context_sim = (
    get_cosine_sim(direction, layer)
    for direction in (prev_token_direction, curr_token_direction, context_direction)
)

plot_histogram(prev_token_sim, curr_token_sim, context_sim, "Prev Token", "Curr Token", "Context")
# %%
# Indices where the previous-token similarity exceeds 0.05 / the current-token similarity exceeds 0.03.
prev_sim_neurons = torch.argwhere(prev_token_sim>0.05)
curr_sim_neurons = torch.argwhere(curr_token_sim>0.03)

print(len(prev_sim_neurons), len(curr_sim_neurons))
# Neurons selected by haystack_utils.union_where from both similarity tensors with 0.07 as its
# second argument (presumably a threshold — verify against the helper).
union = haystack_utils.union_where([prev_token_sim, curr_token_sim], 0.07)
print(union)
# %%

# Get random mean cache
# The last three positions of each generated prompt are dropped before caching activations.
random_prompts = haystack_utils.generate_random_prompts(ngram, model, common_tokens, 500, length=20)[:, :-3]
_, random_cache = model.run_with_cache(random_prompts)


# %%
# Define ablate neuron hook

# Layer 1
# orschlägen: tensor([  61,  188, 1011], device='cuda:0')
# häufig: 268 (almost doubles loss)
# beweglich: neurons decrease loss - maybe they boost alternative completion

def get_ablate_neurons_hook(neurons, layer):
    """Build a forward hook that sets selected MLP post-activations to their random-prompt mean.

    NOTE(review): this reuses the name of an earlier helper in this notebook that takes
    `(neuron, ablated_cache, layer=5)`; once this cell runs, the earlier version is shadowed.

    Args:
        neurons: indices along the last axis to overwrite; printed when the hook is built.
        layer: layer number used to build the `blocks.{layer}.mlp.hook_post` hook name.

    Returns:
        A one-element list `[(hook_name, hook_fn)]`. `hook_fn` reads the global
        `random_cache` each time it runs, modifies the activation in place and returns it.
    """
    print(neurons)
    hook_name = f'blocks.{layer}.mlp.hook_post'

    def ablate_neurons_hook(value, hook):
        random_activations = random_cache[hook_name][:, :, neurons]
        # Mean over the first two axes leaves one replacement value per selected neuron.
        value[:, :, neurons] = random_activations.mean((0, 1))
        return value

    return [(hook_name, ablate_neurons_hook)]

# Check loss increase
# Baseline pair: compute_mlp_loss over all d_mlp neurons with its default ablate_mode and layer.
original_loss, original_ablated_loss = compute_mlp_loss(prompts, df, torch.LongTensor([i for i in range(model.cfg.d_mlp)]).cuda(), compute_original_loss=True)

# Loss on the last token when only neuron 1789 of `layer` is set to its random-prompt mean.
with model.hooks(fwd_hooks=get_ablate_neurons_hook([1789], layer)):
    ablated_loss = model(prompts, return_type="loss", loss_per_token=True)[:, -1].mean().item()

print(original_loss, original_ablated_loss, ablated_loss)
# %%

# 1011 increases loss on both "gen" and "ge"
# Either it boosts both completions (trigram table)
# Or it combines "orsch" and "lä" into a single representation that later components use

# Check if trigram table by looking at the direct effect
# Total effect of L1N1011
# NOTE(review): the notes above refer to neuron 1011, but the hook below ablates neuron 1406
# in `layer` — confirm which neuron these notes are meant to describe.
with model.hooks(fwd_hooks=get_ablate_neurons_hook([1406], layer)):
    _, ablated_cache = model.run_with_cache(prompts)

def ablate_component_hook(value, hook):
    # Replace the activation with the value cached during the ablated run for the same hook name.
    value = ablated_cache[hook.name]
    return value

# MLP post-activations and attention z for layers 3, 4 and 5.
components = [f"blocks.{layer}.mlp.hook_post" for layer in range(3, 6)] + [f"blocks.{layer}.attn.hook_z" for layer in range(3, 6)]
hooks = [(component, ablate_component_hook) for component in components]

# Despite the variable names, both tensors below hold log-probabilities
# (log_softmax of the logits at position -2, averaged over prompts).
with model.hooks(fwd_hooks=hooks):
    ablated_logits = model(prompts, return_type="logits", loss_per_token=True)[:, -2].log_softmax(-1).mean(0)

original_logits = model(prompts, return_type="logits", loss_per_token=True)[:, -2].log_softmax(-1).mean(0)

print(ablated_logits.shape, original_logits.shape)

# Log-prob drop per token; ignored tokens and tokens with original log-prob below -7 are zeroed.
prob_diff = original_logits - ablated_logits
prob_diff[all_ignore] = 0
prob_diff[original_logits < -7] = 0
diffs, tokens = torch.topk(prob_diff, 20)
print(diffs)
print(tokens)
print(model.to_str_tokens(tokens))

# %% 
# Direct effect
# Cache the unmodified run so the components in `components` can be reset to their original values.
_, original_cache = model.run_with_cache(prompts)

def activate_component_hook(value, hook):
    # Replace the activation with the value cached during the unmodified run for the same hook name.
    value = original_cache[hook.name]
    return value

activate_hooks = [(component, activate_component_hook) for component in components]

# Neuron 1406 in `layer` is set to its random-prompt mean while the listed components are restored.
# As above, the result holds log-probabilities at position -2 averaged over prompts.
with model.hooks(fwd_hooks=activate_hooks + get_ablate_neurons_hook([1406], layer)):
    activated_logits = model(prompts, return_type="logits", loss_per_token=True)[:, -2].log_softmax(-1).mean(0)

# Log-prob drop per token; ignored tokens and tokens with original log-prob below -7 are zeroed.
prob_diff = original_logits - activated_logits
prob_diff[all_ignore] = 0
prob_diff[original_logits < -7] = 0
diffs, tokens = torch.topk(prob_diff, 20)
print(diffs)
print(tokens)
print(model.to_str_tokens(tokens))

# Check later components + context neuron effects of 1011
# %%

#output_direction = model.W_out[1, 1011]
# Output weights of layer 2 / neuron 1406 and of layer 3 / neuron 669.
output_direction = model.W_out[2, 1406]
context_direction = model.W_out[3, 669]

# Similarity of both directions to the layer-5 input weights.
output_sims = get_cosine_sim(output_direction, 5)
context_sims = get_cosine_sim(context_direction, 5)

# plot_histogram takes three series, so an all-zero tensor fills the third slot.
plot_histogram(output_sims, context_sims, torch.zeros_like(output_sims), "Output", "Context", "Zero")
# %%
# Neurons selected by haystack_utils.union_where from both similarity tensors with 0.05 as its
# second argument (presumably a threshold — verify against the helper). This rebinds `union`.
union = haystack_utils.union_where([output_sims, context_sims], 0.05)
len(union)
# %%
ngram = "orschlägen"
prompts = haystack_utils.generate_random_prompts(ngram, model, common_tokens, 1000, length=20)

# Baseline pair: compute_mlp_loss over all d_mlp neurons with its default ablate_mode and layer.
original_loss, original_ablated_loss = compute_mlp_loss(prompts, df, torch.LongTensor([i for i in range(model.cfg.d_mlp)]).cuda(), compute_original_loss=True)

# Loss on the last token when the `union` neurons of layer 5 are set to their random-prompt mean.
with model.hooks(fwd_hooks=get_ablate_neurons_hook(union, 5)): #712, 394, 287
    ablated_loss = model(prompts, return_type="loss", loss_per_token=True)[:, -1].mean().item()

print(original_loss, original_ablated_loss, ablated_loss)
# %%
# NOTE(review): `df` was loaded for "orschlägen" but is reused here with prompts built from a
# different n-gram — confirm this is intended.
ngram = " meine Vorschläge"
prompts = haystack_utils.generate_random_prompts(ngram, model, common_tokens, 1000, length=20)
original_loss, original_ablated_loss = compute_mlp_loss(prompts, df, torch.LongTensor([i for i in range(model.cfg.d_mlp)]).cuda(), compute_original_loss=True)
print(original_loss, original_ablated_loss)
# %%
# Show how the model tokenizes this string (no BOS token prepended).
model.to_str_tokens(model.to_tokens(" deinen Vorschläge", prepend_bos=False))
# %%
