In [27]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter
import pickle
import os
import haystack_utils
from transformer_lens import utils
from fancy_einsum import einsum
import einops
import json
import ipywidgets as widgets
from IPython.display import display
from datasets import load_dataset
import random
import math
import random
import neel.utils as nutils
from neel_plotly import *
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.utils import shuffle
import probing_utils
import pickle
from sklearn.metrics import matthews_corrcoef
import gzip
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotting_utils
import re

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
#torch.autograd.set_grad_enabled(False)
#torch.set_grad_enabled(False)

%reload_ext autoreload
%autoreload 2

In [7]:
def get_model(checkpoint: int) -> HookedTransformer:
    model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
        checkpoint_index=checkpoint,
        center_unembed=True,
        center_writing_weights=True,
        fold_ln=True,
        device=device)
    return model

NUM_CHECKPOINTS = 143
LAYER, NEURON = 3, 669
model = get_model(142)
german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")
all_ignore, _ = haystack_utils.get_weird_tokens(model, plot_norms=False)
common_tokens = haystack_utils.get_common_tokens(german_data, model, all_ignore, k=100)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/2000 [00:00<?, ?it/s]

## Find ambiguous bigrams and trigrams


In [8]:
def get_bigram_token_counts(data, model: HookedTransformer):
    counts = torch.zeros((model.cfg.d_vocab, model.cfg.d_vocab))
    for sentence in tqdm(data):
        tokens = model.to_tokens(sentence).flatten().cpu()
        next_tokens = tokens[1:]
        for i in range(len(tokens) - 2):
            counts[tokens[i], next_tokens[i]] += 1
    return counts

english_bigram_counts = get_bigram_token_counts(english_data, model)
german_bigram_counts = get_bigram_token_counts(german_data, model)

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

In [47]:
def get_punctuation_tokens(model):
    punctuation = ["\n", ".", ",", "!", "?", ";", ":", "-", "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "\"", "'"]
    leading_space_punctuation = [" " + char for char in punctuation]
    punctuation_tokens = model.to_tokens(punctuation + leading_space_punctuation + [' –', " ", '  ', "<|endoftext|>"])[:, 1].flatten()
    return punctuation_tokens

def get_number_tokens(model:HookedTransformer):
    all_tokens = [i for i in range(model.cfg.d_vocab)]
    number_tokens = []
    for token in all_tokens:
        str_token = model.to_single_str_token(token)
        try:
            float(str_token)
            number_tokens.append(token)
        except ValueError:
            pass
    return torch.LongTensor(number_tokens)

def get_all_non_letter_tokens(model: HookedTransformer):
    all_tokens = [i for i in range(model.cfg.d_vocab)]
    letter_tokens = []
    for token in all_tokens:
        str_token = model.to_single_str_token(token)
        if not bool(re.search(r'[a-zA-Z]', str_token)):
            letter_tokens.append(token)
    return torch.LongTensor(letter_tokens)

In [48]:
punctuation_tokens = get_punctuation_tokens(model)
number_tokens = get_number_tokens(model)
non_letter_tokens = get_all_non_letter_tokens(model)

In [49]:
def get_common_tokens(token_counts, threshold=100):
    top_tokens = torch.argwhere(token_counts > threshold).flatten()
    return top_tokens

In [None]:
german_bigram_counts[non_letter_tokens] = 0
german_bigram_counts[:, non_letter_tokens] = 0
english_bigram_counts[non_letter_tokens] = 0
english_bigram_counts[:, non_letter_tokens] = 0

In [52]:
german_bigram_counts.shape

torch.Size([50304, 50304])

In [40]:
top_german_tokens = get_common_tokens(german_bigram_counts.sum(0))
top_english_tokens = get_common_tokens(english_bigram_counts.sum(0))
common_both = torch.tensor(list(set(top_english_tokens.tolist()).intersection(set(top_german_tokens.tolist()))))
print(len(common_both), model.to_str_tokens(torch.LongTensor(common_both)))

93 [' all', ' EU', ' inform', 'ations', 'idents', 'ia', ' find', 'a', 'b', 'i', 'k', 'l', ' will', 'n', 'o', ' so', 's', 'y', 'z', ' international', ' K', ' Mon', ' her', ' hand', 'ann', ' V', ' best', ' St', ' also', 'PE', ' am', ' get', ' t', ' a', 'in', 're', ' bring', 'er', ' Union', 'en', 'is', 'it', 'ed', 'es', 'an', ' an', 'ing', 'ar', ' in', 'ou', 'as', ' and', 'ro', ' national', 'el', ' T', ' I', 'ol', 'am', 'ation', ' be', ' S', ' for', ' C', ' he', ' M', ' we', ' set', ' P', ' de', 'ise', 'os', ' B', ' H', 'ers', ' D', ' F', ' W', ' R', ' not', ' L', 'ort', ' un', ' G', ' E', 'ies', ' Bar', ' O', ' end', ' me', ' J', 'ast', 'ans']


In [56]:
for common_token in common_both:
    german_next_token = german_bigram_counts[common_token].argmax()
    english_next_token = english_bigram_counts[common_token].argmax()
    num_occurrences_german = german_bigram_counts[common_token, german_next_token].item()
    num_occurrences_english = english_bigram_counts[common_token, english_next_token].item()
    if (german_next_token != english_next_token) and (num_occurrences_german>50) and (num_occurrences_english>50):
        print(model.to_str_tokens(torch.LongTensor([common_token, german_next_token, english_next_token])), num_occurrences_german, num_occurrences_english)

[' all', 'en', ' the'] 270.0 326.0
[' EU', ' und', "'s"] 51.0 76.0
[' inform', 'ieren', ' you'] 58.0 59.0
['b', 'ens', 'oda'] 292.0 87.0
['l', 'uss', 'ause'] 302.0 214.0
[' will', 'kom', ' be'] 102.0 640.0
[' so', 'z', ' that'] 196.0 170.0
[' Mon', 'aten', 'etary'] 70.0 106.0
[' St', 'ell', 'ras'] 116.0 147.0
[' am', ' Don', 'ending'] 109.0 130.0
[' t', 'ats', 'abled'] 63.0 83.0
[' Union', ' und', ' and'] 111.0 149.0
[' an', ' die', ' important'] 424.0 129.0
[' in', ' der', ' the'] 1261.0 3212.0
[' and', 'eren', ' the'] 339.0 1773.0
['ro', 'ffen', 'so'] 108.0 77.0
[' I', 'hn', ' would'] 528.0 1020.0
['ol', 'ge', 'ences'] 89.0 53.0
['am', 'ten', ' President'] 110.0 629.0
[' be', 'gr', ' a'] 257.0 161.0
[' S', 'itz', 'wo'] 901.0 78.0
[' for', 'dern', ' the'] 59.0 1641.0
[' we', 'il', ' are'] 291.0 645.0
[' P', 'unk', 'PE'] 973.0 79.0
[' not', 'wend', ' only'] 125.0 150.0
['ort', 'en', 'eur'] 72.0 380.0
[' end', 'g', ' of'] 94.0 129.0
[' me', 'ine', ' to'] 533.0 111.0


## Trigrams

In [59]:
len(non_letter_tokens)

6600

In [60]:
valid_tokens = [i for i in range(model.cfg.d_vocab) if i not in non_letter_tokens]
print(len(valid_tokens))

43704


In [67]:
common_bigrams = []
top_german_tokens = get_common_tokens(german_bigram_counts.sum(0), 50)
bigram_threshold = 20
for first_token in tqdm(top_german_tokens):
    german_bigrams = german_bigram_counts[first_token]
    english_bigrams = english_bigram_counts[first_token]
    new_bigrams = torch.argwhere((german_bigrams > bigram_threshold) & (english_bigrams > bigram_threshold)).flatten()
    for second_token in new_bigrams.tolist():
        common_bigrams.append((first_token, second_token))
print(len(common_bigrams))

  0%|          | 0/1410 [00:00<?, ?it/s]

33


In [74]:
last_trigram_tokens = torch.zeros((len(common_bigrams), model.cfg.d_vocab, 2), dtype=torch.long)
for bigram_index, bigram in tqdm(enumerate(common_bigrams)):
    for sentence in german_data:
        tokens = model.to_tokens(sentence).flatten().cpu()
        next_tokens = tokens[1:]
        for i in range(len(tokens) - 3):
            if (tokens[i] == bigram[0]) and (tokens[i+1] == bigram[1]):
                last_trigram_tokens[bigram_index, tokens[i+2], 0] += 1
    for sentence in english_data:
        tokens = model.to_tokens(sentence).flatten().cpu()
        next_tokens = tokens[1:]
        for i in range(len(tokens) - 3):
            if (tokens[i] == bigram[0]) and (tokens[i+1] == bigram[1]):
                last_trigram_tokens[bigram_index, tokens[i+2], 1] += 1


0it [00:00, ?it/s]

In [85]:
last_trigram_tokens[:, non_letter_tokens, :] = 0

In [86]:
trigram_threshold = 10
for trigram in range(len(common_bigrams)):
    most_common_german_token = torch.argmax(last_trigram_tokens[trigram, :, 0])
    most_common_english_token = torch.argmax(last_trigram_tokens[trigram, :, 1])
    german_occurrences = last_trigram_tokens[trigram, most_common_german_token, 0].item()
    english_occurrences = last_trigram_tokens[trigram, most_common_english_token, 1].item()
    if (most_common_german_token != most_common_english_token) and (german_occurrences>trigram_threshold) and (english_occurrences>trigram_threshold):
        print(model.to_str_tokens(torch.LongTensor([common_bigrams[trigram][0], common_bigrams[trigram][1], most_common_german_token, most_common_english_token])), german_occurrences, english_occurrences)

[' in', ' all', 'en', ' the'] 48 18


In [89]:
utils.test_prompt("Wir haben in all", "en", model, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', 'W', 'ir', ' haben', ' in', ' all']
Tokenized answer: ['en']


Top 0th token. Logit: 18.37 Prob: 26.08% Token: |en|
Top 1th token. Logit: 17.03 Prob:  6.79% Token: |i|
Top 2th token. Logit: 16.44 Prob:  3.77% Token: | the|
Top 3th token. Logit: 16.41 Prob:  3.66% Token: |erd|
Top 4th token. Logit: 16.39 Prob:  3.61% Token: | der|
Top 5th token. Logit: 15.66 Prob:  1.74% Token: | of|
Top 6th token. Logit: 15.42 Prob:  1.36% Token: |g|
Top 7th token. Logit: 15.17 Prob:  1.06% Token: |o|
Top 8th token. Logit: 14.95 Prob:  0.85% Token: | den|
Top 9th token. Logit: 14.87 Prob:  0.78% Token: |ig|


In [90]:
utils.test_prompt("We have all", "the", model)

Tokenized prompt: ['<|endoftext|>', 'We', ' have', ' all']
Tokenized answer: [' the']


Top 0th token. Logit: 19.45 Prob: 18.02% Token: | the|
Top 1th token. Logit: 18.69 Prob:  8.41% Token: | been|
Top 2th token. Logit: 18.41 Prob:  6.32% Token: | sorts|
Top 3th token. Logit: 18.34 Prob:  5.92% Token: | of|
Top 4th token. Logit: 18.04 Prob:  4.38% Token: | seen|
Top 5th token. Logit: 17.61 Prob:  2.85% Token: | heard|
Top 6th token. Logit: 17.46 Prob:  2.46% Token: | got|
Top 7th token. Logit: 17.36 Prob:  2.22% Token: | our|
Top 8th token. Logit: 17.32 Prob:  2.14% Token: | a|
Top 9th token. Logit: 17.17 Prob:  1.83% Token: | kinds|


In [91]:
def deactivate_neurons_hook(value, hook):
    value[:, :, NEURON] = 0
    return value
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER}.mlp.hook_post', deactivate_neurons_hook)]

In [92]:
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    utils.test_prompt("Wir haben in all", "en", model, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', 'W', 'ir', ' haben', ' in', ' all']
Tokenized answer: ['en']


Top 0th token. Logit: 16.77 Prob:  7.15% Token: |i|
Top 1th token. Logit: 16.77 Prob:  7.14% Token: | the|
Top 2th token. Logit: 15.91 Prob:  3.01% Token: |en|
Top 3th token. Logit: 15.82 Prob:  2.76% Token: | of|
Top 4th token. Logit: 15.60 Prob:  2.22% Token: | but|
Top 5th token. Logit: 15.55 Prob:  2.10% Token: |
|
Top 6th token. Logit: 15.47 Prob:  1.94% Token: | 50|
Top 7th token. Logit: 15.05 Prob:  1.28% Token: |erd|
Top 8th token. Logit: 15.01 Prob:  1.23% Token: |o|
Top 9th token. Logit: 14.96 Prob:  1.16% Token: | g|


In [None]:
german_answer = model.to_single_token("en")
english_answer = model.to_single_token(" the")

german_logits = []
english_logits = []
german_ablated_logits = []
english_ablated_logits = []
for i in range(NUM_CHECKPOINTS):
    model = get_model(i)
    logits = model("Wir haben in all", return_type="logits")[0, -1]
    german_logits.append(logits[german_answer].item())
    english_logits.append(logits[english_answer].item())
    with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
        logits = model("Wir haben in all", return_type="logits")[0, -1]
        german_ablated_logits.append(logits[german_answer].item())
        english_ablated_logits.append(logits[english_answer].item())

In [107]:
lines = [german_logits, english_logits, german_ablated_logits, english_ablated_logits] + [[i for i in range(NUM_CHECKPOINTS)]]
names = ["German", "English", "German (Ablated)", "English (Ablated)", "Checkpoint"]
title = "Logits for 'Wir haben in all' - 'en' (GER) vs ' the' (ENG)"
df = pd.DataFrame({name:line for name, line in zip(names, lines)})
df = df.melt(id_vars=["Checkpoint"], var_name="Type", value_name="Logit")
px.line(df, x="Checkpoint", y="Logit", color="Type", title=title)