In [57]:
# Different tokens lengths
sentence_pairs = [
    ("I am eating an apple.", "Je mange une pomme."),
    ("She reads the book.", "Elle lit le livre."),
    ("He is walking in the park.", "Il se promène dans le parc."),
    ("We are studying French.", "Nous étudions le français."),
    ("They are watching a movie.", "Ils regardent un film."),
    ("You are drinking water.", "Tu bois de l'eau."),
    ("The cat is sleeping on the chair.", "Le chat dort sur la chaise."),
    ("The children are playing in the garden.", "Les enfants jouent dans le jardin."),
    ("He loves her.", "Il l'aime."),
    ("She is going to school.", "Elle va à l'école.")
]

same_token_len_pairs = [
    ("she", "elle"),
    ("is", "va"),
    ("the", "le"),
    ("friend", "ami"),
    ("of", "de"),
    ("city", "ville"),
    ("cat", "chat"),
    # ("dog", "chien"),
    ("I", "Je"),
    # ("am", "suis"),
    # ("drink", "bois")
]

In [2]:
from collections import defaultdict

import torch
import numpy as np
from transformer_lens import HookedTransformer

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "colab+vscode"

from haystack_utils import load_txt_data, get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [3]:
haystack_utils.clean_cache()
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)
model = HookedTransformer.from_pretrained("pythia-70m-v0", fold_ln=True, device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m-v0 into HookedTransformer


In [60]:
# for english, french in same_token_len_pairs:
#     if model.to_tokens(english).shape[1] == model.to_tokens(french).shape[1]:
#         print(english)
# print(len(model.to_tokens("dog")), len(model.to_tokens("chien")))
# print(len(model.to_tokens("cat")), len(model.to_tokens("chat")))

# french_prompt = " ".join(pair[1] for pair in same_token_len_pairs)
# english_prompt = " ".join(pair[0] for pair in same_token_len_pairs)

french_tokens = torch.stack([model.to_tokens(pair[1], prepend_bos=False) for pair in same_token_len_pairs]).squeeze(1, 2).unsqueeze(0)
english_tokens = torch.stack([model.to_tokens(pair[0], prepend_bos=False) for pair in same_token_len_pairs]).squeeze(1, 2).unsqueeze(0)

print(french_tokens.shape)
print(english_tokens.shape)

torch.Size([1, 8])
torch.Size([1, 8])


In [74]:
_, french_cache = model.run_with_cache(french_tokens)
_, english_cache = model.run_with_cache(english_tokens)

differences = []
labels = []
for key in french_cache.keys():
    if 'hook_attn_out' not in key and 'hook_mlp_out' not in key:
        continue
    french_acts = french_cache[key]
    english_acts = english_cache[key]
    labels.append(key)
    differences.append((french_acts - english_acts).abs().mean().item())
    # absolute sum of activation differences at each layer

print([key for key in french_cache.keys()])
print(differences)
haystack_utils.line(differences, xticks=labels, title="Absolute Mean Activation Differences", xlabel="Component", ylabel="Absolute Mean Activation Difference")

['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'bl