In [1]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line, two_histogram
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [16]:
def load_pile(stream=False):
    dataset = load_dataset("EleutherAI/pile", name="europarl", split="train")
    dataset = dataset.shuffle(seed=42)
    return dataset

def get_random_samples(dataset, language="fr", n=100, min_length=100, max_length=1000):
    pbar = tqdm(total=n)
    language_data = []
    for i, example in enumerate(dataset):
        sentence_language = example["meta"][-4:-2]
        sentence = example["text"]
        if (len(sentence) >= min_length):
            if (len(sentence) > max_length):
                sentence = sentence[:max_length]
            if (len(language_data) < n) and (sentence_language==language):
                language_data.append(sentence)
                pbar.update(1)
        if (len(language_data) >= n):
            pbar.close()
            return language_data
    print(f"Warning: not enough data found, returning {len(language_data)} samples")
    pbar.close()
    return language_data

In [3]:
dataset = load_pile()

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading readme: 0.00B [00:00, ?B/s]

Downloading and preparing dataset pile/europarl to /root/.cache/huggingface/datasets/EleutherAI___pile/europarl/0.0.0/ebea56d358e91cf4d37b0fde361d563bed1472fbd8221a21b38fc8bb4ba554fb...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.48G [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset pile downloaded and prepared to /root/.cache/huggingface/datasets/EleutherAI___pile/europarl/0.0.0/ebea56d358e91cf4d37b0fde361d563bed1472fbd8221a21b38fc8bb4ba554fb. Subsequent calls will reuse this data.


In [18]:
german_europarl = get_random_samples(dataset, language="de", n=2000, min_length=100, max_length=2000)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [20]:
english_europarl = get_random_samples(dataset, language="en", n=2000, min_length=100, max_length=2000)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [21]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m", device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer


In [22]:
def count_token_occurrences(prompts: list[str]):
    token_counts = torch.zeros(model.cfg.d_vocab).to(device)
    for prompt in tqdm(prompts):
        # Remove BOS
        tokens = model.to_tokens(prompt).flatten()[1:]
        token_counts[tokens] += 1
    return token_counts

In [23]:
german_unigram_counts = count_token_occurrences(german_europarl)
german_unigram_highest_counts, german_unigram_tokens = torch.topk(german_unigram_counts, 100)
german_unigram_labels = model.to_str_tokens(german_unigram_tokens)

print(german_unigram_labels[:100])

  0%|          | 0/2000 [00:00<?, ?it/s]

['\n', '.', ' der', ' die', 'äsident', ' Pr', 'en', ' und', ',', 'ung', 'ä', 'st', 'ch', ' den', ' in', 're', 't', ' des', ' zu', ' für', 'z', ' von', 'ischen', 'n', 'ü', 'ge', 'gen', ' auf', ')', ' ist', 'icht', ' über', 'men', 'te', ' er', 'in', 'ig', 'g', ' im', 'es', 'f', '-', ' das', 'Der', ' eine', 'le', ' w', 'ten', ' an', 'ß', ' (', ' dass', ' ein', 'ren', 'hen', 'e', ' dem', 'w', 's', ' mit', ' dies', ' ', ' Europ', ' wir', ' z', ' W', 'ungen', ' ich', 'it', ' Z', ' Herr', 'h', ' nicht', '!', ' zur', ' B', 'ich', ' ver', ' -', ' es', 'lich', 'chte', ' Ver', 'i', ' An', 'igen', 'hr', ' mö', 'l', 'et', ' um', 'em', 'heit', ' V', ' werden', 'u', ' g', ' d', ' be', 'b']


In [24]:
english_unigram_counts = count_token_occurrences(english_europarl)
english_unigram_highest_counts, english_unigram_tokens = torch.topk(english_unigram_counts, 100)
english_unigram_labels = model.to_str_tokens(english_unigram_tokens)

print(english_unigram_labels[:100])

  0%|          | 0/2000 [00:00<?, ?it/s]

[' the', '\n', '.', ',', ' of', ' to', ' and', ' in', ' on', ' a', ' is', ' that', ' for', ' President', ' this', ' I', ')', 'President', '-', ' have', ' by', ' be', ' (', ' European', ' with', ' we', 'The', ' are', ' it', ' as', ' not', ' Mr', ' which', ' -', ' has', ' will', ' ', ' Parliament', ' an', ' next', ' would', ' item', ' at', ' The', ' all', ' been', "'s", ' Commission', ' was', ':', ' from', ' like', ' our', ' you', ' also', ' but', ' should', ' report', ' behalf', '(', ' very', ' Council', 'I', ' Committee', ' Union', ' its', ' there', ' my', 'ate', ' We', ' their', 'Mr', ' vote', ' am', ' It', ' can', 'deb', ' debate', ' one', ' who', ' so', ' This', ' or', ' time', '/', ' do', ' Member', ' gentlemen', ' more', ' us', ' these', ' other', ' about', ' now', "'", ' were', ' important', ' States', ' first', ' must']


In [None]:
# Todo 
# Take top X unigrams
# Remove unigrams present in both English and German top
# Store unigrams, compare ablation scores