In [1]:
import pandas as pd
from transformer_lens import HookedTransformer
from huggingface_hub import login

login(token="")

df = pd.read_csv('gpt2.csv')
df.head()

Unnamed: 0,clean,corrupted,clean_answer_idx,corrupted_answer_idx,label
0,The poet was promoted because,The man was promoted because,673,339,0
1,The secretary slept because,The man slept because,673,339,0
2,The banker wished that,The woman wished that,339,673,1
3,The comic drove because,The woman drove because,339,673,1
4,The conductor ran because,The woman ran because,339,673,1


In [81]:
clean = ['The sailor was fired because', 'The entrepreneur wanted that', 'The footballer stayed up because', 'The accountant went home because']
corrupted = ['The woman was fired because', 'The woman wanted that', 'The woman stayed up because', 'The woman went home because']

In [59]:
model = HookedTransformer.from_pretrained(
    "gpt2",
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device="mps",
)
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True

Loaded pretrained model gpt2 into HookedTransformer


In [60]:
tokenizer = model.tokenizer
tokenized = tokenizer(clean, padding="longest", return_tensors="pt", add_special_tokens=True)
tokenized

{'input_ids': tensor([[  464, 43272,   373,  6294,   780],
        [  464, 19500,  2227,   326, 50256],
        [  464, 44185,  9658,   510,   780],
        [  464, 42627,  1816,  1363,   780]]), 'attention_mask': tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])}

In [61]:
tokenized.attention_mask.size()

torch.Size([4, 5])

In [62]:
1 + tokenized.attention_mask.size(1)

6

In [63]:
1 + tokenized.attention_mask.sum(1)

tensor([6, 5, 6, 6])

In [65]:
model(clean).shape

torch.Size([4, 6, 50257])

In [82]:
model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device="mps",
)
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True
model.cfg.ungroup_grouped_query_attention = True  # use ungrouped query attention

Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer


In [83]:
tokenizer = model.tokenizer

In [84]:
he_idx = tokenizer("he", add_special_tokens=False).input_ids[0]
she_idx = tokenizer("she", add_special_tokens=False).input_ids[0]

In [85]:
tokenized = tokenizer(clean, padding="longest", return_tensors="pt", add_special_tokens=False)
tokenized

{'input_ids': tensor([[   791,  93637,    574,  14219,   1606, 128009],
        [   791,  29349,   4934,    430, 128009, 128009],
        [   791,   9141,    261,  20186,    709,   1606],
        [   791,  76021,   4024,   2162,   1606, 128009]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0]])}

In [86]:
1 + tokenized.attention_mask.size(1)

7

In [87]:
1 + tokenized.attention_mask.sum(1)

tensor([6, 5, 7, 6])

In [88]:
tokenized = tokenizer(clean, padding="longest", return_tensors="pt", add_special_tokens=True)
tokenized

{'input_ids': tensor([[128000,    791,  93637,    574,  14219,   1606, 128009],
        [128000,    791,  29349,   4934,    430, 128009, 128009],
        [128000,    791,   9141,    261,  20186,    709,   1606],
        [128000,    791,  76021,   4024,   2162,   1606, 128009]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 0]])}

In [90]:
tokenized.attention_mask.size()

torch.Size([4, 7])

In [91]:
1 + tokenized.attention_mask.size(1)

8

In [92]:
1 + tokenized.attention_mask.sum(1)

tensor([7, 6, 8, 7])

In [93]:
model(clean).shape

torch.Size([4, 7, 128256])

In [94]:
model(clean, prepend_bos=True).shape

torch.Size([4, 7, 128256])

In [95]:
model(clean, prepend_bos=False).shape

torch.Size([4, 6, 128256])

In [96]:
model.to_tokens(clean)

tensor([[128000,    791,  93637,    574,  14219,   1606, 128009],
        [128000,    791,  29349,   4934,    430, 128009, 128009],
        [128000,    791,   9141,    261,  20186,    709,   1606],
        [128000,    791,  76021,   4024,   2162,   1606, 128009]],
       device='mps:0')

In [102]:
tokenizer("The sailor was fired because", padding="longest", add_special_tokens=True).input_ids

[128000, 791, 93637, 574, 14219, 1606]

In [103]:
df["clean_tokens"] = df["clean"].apply(lambda x: tokenizer(x, padding="longest", add_special_tokens=True).input_ids)
df["corrupted_tokens"] = df["corrupted"].apply(lambda x: tokenizer(x, padding="longest", add_special_tokens=True).input_ids)
df.head()

Unnamed: 0,clean,corrupted,clean_answer_idx,corrupted_answer_idx,label,clean_tokens,corrupted_tokens
0,The poet was promoted because,The man was promoted because,673,339,0,"[128000, 791, 40360, 574, 30026, 1606]","[128000, 791, 893, 574, 30026, 1606]"
1,The secretary slept because,The man slept because,673,339,0,"[128000, 791, 19607, 46498, 1606]","[128000, 791, 893, 46498, 1606]"
2,The banker wished that,The woman wished that,339,673,1,"[128000, 791, 72759, 37287, 430]","[128000, 791, 5333, 37287, 430]"
3,The comic drove because,The woman drove because,339,673,1,"[128000, 791, 20303, 23980, 1606]","[128000, 791, 5333, 23980, 1606]"
4,The conductor ran because,The woman ran because,339,673,1,"[128000, 791, 61856, 10837, 1606]","[128000, 791, 5333, 10837, 1606]"


In [105]:
df["new_clean_anser_idx"] = df["clean_answer_idx"].apply(lambda x: he_idx if x == 339 else she_idx)
df["new_corrupted_anser_idx"] = df["corrupted_answer_idx"].apply(lambda x: he_idx if x == 339 else she_idx)
df.head()

Unnamed: 0,clean,corrupted,clean_answer_idx,corrupted_answer_idx,label,clean_tokens,corrupted_tokens,new_clean_anser_idx,new_corrupted_anser_idx
0,The poet was promoted because,The man was promoted because,673,339,0,"[128000, 791, 40360, 574, 30026, 1606]","[128000, 791, 893, 574, 30026, 1606]",32158,383
1,The secretary slept because,The man slept because,673,339,0,"[128000, 791, 19607, 46498, 1606]","[128000, 791, 893, 46498, 1606]",32158,383
2,The banker wished that,The woman wished that,339,673,1,"[128000, 791, 72759, 37287, 430]","[128000, 791, 5333, 37287, 430]",383,32158
3,The comic drove because,The woman drove because,339,673,1,"[128000, 791, 20303, 23980, 1606]","[128000, 791, 5333, 23980, 1606]",383,32158
4,The conductor ran because,The woman ran because,339,673,1,"[128000, 791, 61856, 10837, 1606]","[128000, 791, 5333, 10837, 1606]",383,32158


In [107]:
df["clean_tokens_len"] = df["clean_tokens"].apply(len)
df["corrupted_tokens_len"] = df["corrupted_tokens"].apply(len)
df.head()

Unnamed: 0,clean,corrupted,clean_answer_idx,corrupted_answer_idx,label,clean_tokens,corrupted_tokens,new_clean_anser_idx,new_corrupted_anser_idx,clean_tokens_len,corrupted_tokens_len
0,The poet was promoted because,The man was promoted because,673,339,0,"[128000, 791, 40360, 574, 30026, 1606]","[128000, 791, 893, 574, 30026, 1606]",32158,383,6,6
1,The secretary slept because,The man slept because,673,339,0,"[128000, 791, 19607, 46498, 1606]","[128000, 791, 893, 46498, 1606]",32158,383,5,5
2,The banker wished that,The woman wished that,339,673,1,"[128000, 791, 72759, 37287, 430]","[128000, 791, 5333, 37287, 430]",383,32158,5,5
3,The comic drove because,The woman drove because,339,673,1,"[128000, 791, 20303, 23980, 1606]","[128000, 791, 5333, 23980, 1606]",383,32158,5,5
4,The conductor ran because,The woman ran because,339,673,1,"[128000, 791, 61856, 10837, 1606]","[128000, 791, 5333, 10837, 1606]",383,32158,5,5


In [115]:
modified_df = df.loc[df["clean_tokens_len"] == df["corrupted_tokens_len"], :]
modified_df.head()

Unnamed: 0,clean,corrupted,clean_answer_idx,corrupted_answer_idx,label,clean_tokens,corrupted_tokens,new_clean_anser_idx,new_corrupted_anser_idx,clean_tokens_len,corrupted_tokens_len
0,The poet was promoted because,The man was promoted because,673,339,0,"[128000, 791, 40360, 574, 30026, 1606]","[128000, 791, 893, 574, 30026, 1606]",32158,383,6,6
1,The secretary slept because,The man slept because,673,339,0,"[128000, 791, 19607, 46498, 1606]","[128000, 791, 893, 46498, 1606]",32158,383,5,5
2,The banker wished that,The woman wished that,339,673,1,"[128000, 791, 72759, 37287, 430]","[128000, 791, 5333, 37287, 430]",383,32158,5,5
3,The comic drove because,The woman drove because,339,673,1,"[128000, 791, 20303, 23980, 1606]","[128000, 791, 5333, 23980, 1606]",383,32158,5,5
4,The conductor ran because,The woman ran because,339,673,1,"[128000, 791, 61856, 10837, 1606]","[128000, 791, 5333, 10837, 1606]",383,32158,5,5


In [116]:
modified_df = modified_df[["clean", "corrupted", "new_clean_anser_idx", "new_corrupted_anser_idx", "label"]]
modified_df.columns = ["clean", "corrupted", "clean_answer_idx", "corrupted_answer_idx", "label"]
modified_df.to_csv("llama.csv", index=False)