In [2]:
import torch
from transformer_lens import HookedTransformer

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "colab+vscode"

import haystack_utils

%reload_ext autoreload
%autoreload 2

In [3]:
haystack_utils.clean_cache()
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)
model = HookedTransformer.from_pretrained("pythia-70m-v0", fold_ln=True, device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m-v0 into HookedTransformer


In [10]:
same_token_len_pairs = [
    (" she", " elle"),
    (" is", " va"),
    (" the", " le"),
    # ("friend", "ami"),
    (" of", " de"),
    (" city", " ville"),
    (" cat", " chat"),
    # ("dog", "chien"),
    (" I", " Je"),
    (" am", " suis"),
    # ("drink", "bois")
]
more_same_token_len_pairs = [(' day', ' jour'), (' fruit', ' fruit'), (' wind', ' vent'), (' lake', ' lac'), (' sea', ' mer')]
same_token_len_pairs += more_same_token_len_pairs

for english, french in same_token_len_pairs:
    assert model.to_tokens(english).shape[1] == model.to_tokens(french).shape[1]

In [11]:
french_prompt = "".join(pair[1] for pair in same_token_len_pairs)
english_prompt = "".join(pair[0] for pair in same_token_len_pairs)

assert model.to_tokens(french_prompt).shape == model.to_tokens(english_prompt).shape

In [6]:
# Prompts generate reasonable English and French text respectively
# print(haystack_utils.generate_text(english_prompt, model))
# print(haystack_utils.generate_text(french_prompt, model))

In [13]:
_, french_cache = model.run_with_cache(french_prompt)
_, english_cache = model.run_with_cache(english_prompt)

differences = []
labels = []
for key in french_cache.keys():
    if 'hook_mlp_out' not in key:
        continue
    french_acts = french_cache[key]
    english_acts = english_cache[key]
    labels.append(key)
    differences.append((french_acts - english_acts).abs().mean().item())
    # absolute sum of activation differences at each layer

print([key for key in french_cache.keys()])
print(differences)
haystack_utils.line(differences, xticks=labels, title="Language Specificity of MLPs - French vs. English", xlabel="Component", ylabel="Absolute Mean Activation Difference")

differences = []
labels = []
for key in french_cache.keys():
    if 'hook_attn_out' not in key:
        continue
    french_acts = french_cache[key]
    english_acts = english_cache[key]
    labels.append(key)
    differences.append((french_acts - english_acts).abs().mean().item())
    # absolute sum of activation differences at each layer

print([key for key in french_cache.keys()])
print(differences)
haystack_utils.line(differences, xticks=labels, title="Language Specificity of Attention Layers - French vs. English", xlabel="Component", ylabel="Absolute Mean Activation Difference")

['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'bl

['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'bl

I thought the mid layer would contain a non-language specific representation and would thus have a lower activation difference. The general pattern seems to hold true but the "mid layer" is earlier than I expected, in MLPs one and two.

I should repeat with more data to confirm.

In [9]:
# More French and English words with the same tokenized shape

more = [
    (" chair", " chaise"),
    (" light", " lumière"),
    (" night", " nuit"),
    (" day", " jour"),
    (" food", " repas"),
    (" drink", " boisson"),
    (" fruit", " fruit"),
    (" flower", " fleur"),
    (" tree", " arbre"),
    (" sun", " soleil"),
    (" moon", " lune"),
    (" star", " étoile"),
    (" bird", " oiseau"),
    (" fish", " poisson"),
    (" rain", " pluie"),
    (" snow", " neige"),
    (" wind", " vent"),
    (" river", " rivière"),
    (" lake", " lac"),
    (" sea", " mer"),
    (" mountain", " montagne"),
    (" sky", " ciel"),
    (" cloud", " nuage")
]

same = []
for english, french in more:
    if model.to_tokens(english).shape[1] == model.to_tokens(french).shape[1]:
        same.append((english, french))

print(same)
# print(model.to_tokens(english).shape[1] == model.to_tokens(french).shape[1])




[(' day', ' jour'), (' fruit', ' fruit'), (' wind', ' vent'), (' lake', ' lac'), (' sea', ' mer')]
