In [1]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line, two_histogram
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
def load_pile(stream=False):
    dataset = load_dataset("EleutherAI/pile", name="europarl", split="train")
    dataset = dataset.shuffle(seed=42)
    return dataset

def get_random_samples(dataset, language="fr", n=100, min_length=100, max_length=1000):
    pbar = tqdm(total=n)
    language_data = []
    for i, example in enumerate(dataset):
        sentence_language = example["meta"][-4:-2]
        sentence = example["text"]
        if (len(sentence) >= min_length):
            if (len(sentence) > max_length):
                sentence = sentence[:max_length]
            if (len(language_data) < n) and (sentence_language==language):
                language_data.append(sentence)
                pbar.update(1)
        if (len(language_data) >= n):
            pbar.close()
            return language_data
    print(f"Warning: not enough data found, returning {len(language_data)} samples")
    pbar.close()
    return language_data

In [3]:
dataset = load_pile()

Downloading builder script:   0%|          | 0.00/9.53k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

Downloading and preparing dataset pile/europarl to /root/.cache/huggingface/datasets/EleutherAI___pile/europarl/0.0.0/ebea56d358e91cf4d37b0fde361d563bed1472fbd8221a21b38fc8bb4ba554fb...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.48G [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset pile downloaded and prepared to /root/.cache/huggingface/datasets/EleutherAI___pile/europarl/0.0.0/ebea56d358e91cf4d37b0fde361d563bed1472fbd8221a21b38fc8bb4ba554fb. Subsequent calls will reuse this data.


In [4]:
german_europarl = get_random_samples(dataset, language="de", n=2000, min_length=100, max_length=2000)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [5]:
english_europarl = get_random_samples(dataset, language="en", n=2000, min_length=100, max_length=2000)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [6]:
def write_data(data: list[str], path: str):
    with open(path, 'w') as f:
        json.dump(data, f)

def load_data(path: str):
    with open(path, 'r') as f:
        return json.load(f)

write_data(german_europarl, "german_europarl.json")
write_data(english_europarl, "english_europarl.json")

In [6]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m", device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer


In [7]:
def count_token_occurrences(prompts: list[str]):
    token_counts = torch.zeros(model.cfg.d_vocab).to(device)
    for prompt in tqdm(prompts):
        # Remove BOS
        tokens = model.to_tokens(prompt).flatten()[1:]
        token_counts[tokens] += 1
    return token_counts

In [8]:
def get_top_unigrams(prompts: list[str], n: int=1000):
    token_counts = count_token_occurrences(prompts)
    counts, tokens = torch.topk(token_counts, n)
    return top_unigrams

In [9]:
german_unigram_counts = count_token_occurrences(german_europarl)
german_unigram_highest_counts, german_unigram_tokens = torch.topk(german_unigram_counts, 1000)
german_unigram_labels = model.to_str_tokens(german_unigram_tokens)

print(german_unigram_labels[:100])

  0%|          | 0/2000 [00:00<?, ?it/s]

['\n', '.', ' der', ' die', 'äsident', ' Pr', 'en', ' und', ',', 'ung', 'ä', 'st', 'ch', ' den', ' in', 're', 't', ' des', ' zu', ' für', 'z', ' von', 'ischen', 'n', 'ü', 'ge', 'gen', ' auf', ')', ' ist', 'icht', ' über', 'men', 'te', ' er', 'in', 'ig', 'g', ' im', 'es', 'f', '-', ' das', 'Der', ' eine', 'le', ' w', 'ten', ' an', 'ß', ' (', ' dass', ' ein', 'ren', 'hen', 'e', ' dem', 'w', 's', ' mit', ' dies', ' ', ' Europ', ' wir', ' z', ' W', 'ungen', ' ich', 'it', ' Z', ' Herr', 'h', ' nicht', '!', ' zur', ' B', 'ich', ' ver', ' -', ' es', 'lich', 'chte', ' Ver', 'i', ' An', 'igen', 'hr', ' mö', 'l', 'et', ' um', 'em', 'heit', ' V', ' werden', 'u', ' g', ' d', ' be', 'b']


In [10]:
english_unigram_counts = count_token_occurrences(english_europarl)
english_unigram_highest_counts, english_unigram_tokens = torch.topk(english_unigram_counts, 1000)
english_unigram_labels = model.to_str_tokens(english_unigram_tokens)

print(english_unigram_labels[:100])

  0%|          | 0/2000 [00:00<?, ?it/s]

[' the', '\n', '.', ',', ' of', ' to', ' and', ' in', ' on', ' a', ' is', ' that', ' for', ' President', ' this', ' I', ')', 'President', '-', ' have', ' by', ' be', ' (', ' European', ' with', ' we', 'The', ' are', ' it', ' as', ' not', ' Mr', ' which', ' -', ' has', ' will', ' ', ' Parliament', ' an', ' next', ' would', ' item', ' at', ' The', ' all', ' been', "'s", ' Commission', ' was', ':', ' from', ' like', ' our', ' you', ' also', ' but', ' should', ' report', ' behalf', '(', ' very', ' Council', 'I', ' Committee', ' Union', ' its', ' there', ' my', 'ate', ' We', ' their', 'Mr', ' vote', ' am', ' It', ' can', 'deb', ' debate', ' one', ' who', ' so', ' This', ' or', ' time', '/', ' do', ' Member', ' gentlemen', ' more', ' us', ' these', ' other', ' about', ' now', "'", ' were', ' important', ' States', ' first', ' must']


In [11]:
duplicate_unigrams = set(german_unigram_tokens.cpu().tolist()).intersection(set(english_unigram_tokens.cpu().tolist()))
unique_german_unigrams = set(german_unigram_tokens.cpu().tolist()).difference(set(english_unigram_tokens.cpu().tolist()))
unique_english_unigrams = set(english_unigram_tokens.cpu().tolist()).difference(set(german_unigram_tokens.cpu().tolist()))
print(f"Found {len(duplicate_unigrams)} duplicate unigrams")
print(f"Found {len(unique_german_unigrams)} unique german unigrams and {len(unique_english_unigrams)} unique english unigrams")
unique_german_unigrams = torch.LongTensor(list(unique_german_unigrams))
unique_english_unigrams = torch.LongTensor(list(unique_english_unigrams))

Found 158 duplicate unigrams
Found 842 unique german unigrams and 842 unique english unigrams


In [12]:
print(model.to_str_tokens(unique_german_unigrams))
print(model.to_str_tokens(unique_english_unigrams))

['inf', ' geb', '"', ' Zeit', ' fre', 'schen', 'org', ' als', 'geb', ' Dam', 'glich', 'cks', ' Le', 'reich', 'la', ' dav', ' bin', 'ß', 'uten', 'ichte', ' Ber', 'W', 'Z', 'nd', ' dan', ' wel', 'd', 'f', 'g', 'lass', 'fn', ' Text', ' vie', 'p', 'r', 'inder', 'v', 'w', ' war', 'olit', ' soll', 'ler', 'ser', ' da', ' Kol', 'rac', ' Menschen', 'hl', ' Vert', ' aus', ' Abs', 'hend', 'mal', 'abe', 'pf', 'hr', 'was', 'ilen', 'ucht', ' bes', 'gang', 'Mit', 'loss', ' unter', ' muss', ' fol', 'wer', ' vom', 'keit', 'run', ' Fre', ' Dem', ' Volks', 'ichen', ' Bet', 'ahren', 'tn', ' stim', ' oh', ' auch', 'ete', 'eb', ' Stand', ' Bes', ' mir', 'ube', 'rl', 'agt', 'he', 'at', ' wer', ' von', ' w', 'nis', ' keine', ' polit', ' f', ' b', 'af', ' Akt', ' d', ' Be', ' m', 'gie', 'ker', ' ver', ' h', 'ieren', 'ent', 'et', ' dieser', ' n', 'st', 'om', ' l', ' e', 'il', 'id', ' Hand', 'enz', 'ig', ' g', ' beg', 'ut', ' gle', ' hier', ' Red', 'us', 'ac', ' auf', ' zw', 'elle', 'ur', ' bit', 'uf', 'un', ' s

In [13]:
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
MEAN_ACTIVATION_ACTIVE = 3.8170
MEAN_ACTIVATION_INACTIVE = -0.0738

In [14]:
all_ignore, not_ignore = haystack_utils.get_weird_tokens(model, plot_norms=False)

## Unigram

In [15]:
# Unembed neuron direction directly

# Only works for individual neurons
# Shape batch pos d_resid
neuron_weight = model.W_out[LAYER_TO_ABLATE, NEURONS_TO_ABLATE].view(1, 1, -1)
neuron_direction_active = neuron_weight * MEAN_ACTIVATION_ACTIVE # Set German neuron to activated value (~3)
neuron_direction_inactive = neuron_weight * MEAN_ACTIVATION_INACTIVE # Set German neuron to disabled value (~0)

tokens_active = model.unembed(neuron_direction_active)
tokens_inactive = model.unembed(neuron_direction_inactive)
# Active: German neuron is active - we expect German tokens boosted
# Inactive: German neuron is inactive - we expect no boost to German tokens
# Active - Inactive: If the neuron boosts German tokens, we expect this to be positive
token_differences = (tokens_active - tokens_inactive).flatten()

boosted_values, boosted_tokens = haystack_utils.top_k_with_exclude(token_differences, 100, exclude=all_ignore)
inhibited_values, inhibited_tokens =haystack_utils.top_k_with_exclude(token_differences, 100, largest=False, exclude=all_ignore)
boosted_labels = model.to_str_tokens(boosted_tokens)
inhibited_labels = model.to_str_tokens(inhibited_tokens)

px.histogram(token_differences.cpu().numpy(), nbins=1000, title="Token differences")

In [16]:
haystack_utils.line(x=boosted_values.cpu().numpy(), xticks=boosted_labels, title="Boosted tokens", width=1200)

In [17]:
haystack_utils.line(x=inhibited_values.cpu().numpy(), xticks=inhibited_labels, title="Boosted tokens", width=1200)

In [18]:
top_english_token_differences = token_differences[unique_english_unigrams]
top_german_token_differences = token_differences[unique_german_unigrams]
print(top_german_token_differences.mean())
print(top_english_token_differences.mean())
haystack_utils.two_histogram(top_german_token_differences, top_english_token_differences, 
                             "German unigrams", "English unigrams", "German and English unigram boost from context neuron",
                             x_label="Logit difference (active - inactive)", y_label="count")

tensor(1.8280, device='cuda:0')
tensor(0.1989, device='cuda:0')


In [19]:
tokens = [i for i in range(model.cfg.d_vocab)]
differences = token_differences.cpu().numpy()
labels = []
for token in tokens:
    if token in unique_english_unigrams:
        labels.append("English")
    elif token in unique_german_unigrams:
        labels.append("German")
    else:
        labels.append("Other")
df = pd.DataFrame({"token": tokens, "difference": differences, "label": labels})
px.histogram(df, x="difference", color="label", nbins=300, title="Unigram effects of German neuron", width=1200)

## Bigram

In [32]:
def count_word_occurrences(prompts: list[str], word: str):
    occurrences = 0
    for prompt in prompts:
        # Count number of occurrences in a single prompt
        occurrences += prompt.count(word)
    return occurrences

def get_word_contexts(prompts: list[str], word: str, context_size: int=40):
    contexts = []
    for prompt in prompts:
        # Count number of occurrences in a single prompt
        if word in prompt:
            word_index = prompt.index(word)
            start_index = max(0, word_index - context_size)
            if word_index - start_index > 1:
                contexts.append(prompt[start_index:word_index])
    return contexts

## Gemeinsam

## Gleichstellung

In [33]:
print(count_word_occurrences(german_europarl, "Gleichstellung"))
print(count_word_occurrences(german_europarl, "Gleichheit"))
print(count_word_occurrences(german_europarl, "Gleichst"))
print(count_word_occurrences(german_europarl, "ichst"))
print(count_word_occurrences(german_europarl, "ichheit"))

49
5
49
119
33


In [34]:
context_st = get_word_contexts(german_europarl, "Gleichstellung")
for context in context_st:
    print(context)

chusses für die Rechte der Frau und die 
genheit dar, wichtige Diskussionen über 
chusses für die Rechte der Frau und die 
Integrierter Ansatz der 
Die Frage des gleichen Entgelts und der 
ng der Ergebnisse des Fahrplans für die 
chusses für die Rechte der Frau und die 
aaten haben sich dazu verpflichtet, die 
chusses für die Rechte der Frau und die 
chusses für die Rechte der Frau und die 
chusses für die Rechte der Frau und die 
nträchtigt sind, bis zum Jahr 2020 eine 
chusses für die Rechte der Frau und die 
chusses für die Rechte der Frau und die 
eses Verfahrens geleitet haben, war die 
chusses für die Rechte der Frau und die 
chusses für die Rechte der Frau und die 
orschriften, die dazu verpflichten, die 
chusses für die Rechte der Frau und die 
chusses für die Rechte der Frau und die 
gen sind von großer Bedeutung, wenn die 
chusses für die Rechte der Frau und die 
chusses für die Rechte der Frau und die 
chusses für die Rechte der Frau und die 


In [35]:
context_st = get_word_contexts(german_europarl, "ichheit")
for context in context_st:
    print(context)

 Stehende zu tun, um die Geschlechtergle
13. Fortschritte in Bezug auf Chancengle
für diesen Fonds vor, das die Chancengle
ndel als Chance begreifen und Chancengle
7. Abbau gesundheitlicher Ungle
ngs in einer ausgewogenen Art, die Ungle
e Stelle von Kolonialismus und Armut Gle
estehende Richtlinie zu Wettbewerbsungle
on Diskriminierung zu machen und die Gle
n und tief verankerter kultureller Ungle
ihre Arbeit zur Förderung der Chancengle
erte - Würde des Menschen, Freiheit, Gle
ichtdiskriminierung basiert, auf dem Gle
s für die Rechte der Frau und Chancengle
gsantrag unterbreiten, um die Chancengle
entrieren wird. Erstens auf Armut, Ungle
s für die Rechte der Frau und Chancengle
Abbau gesundheitlicher Ungle
s für die Rechte der Frau und Chancengle
ein Antrag abgelehnt, weil es Stimmengle
ganisiert, die auf dem Grundsatz der Gle
en zur Konfliktverhütung weiterhin Ungle
fortdauernden Existenz erheblicher Ungle


In [36]:
print(count_word_occurrences(german_europarl, "Veränder"))
print(count_word_occurrences(german_europarl, "Verord"))

43
419
