In [1]:
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import torch
import pickle
import transformer_lens
from torch.optim import AdamW
from os.path import join
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from sklearn.model_selection import train_test_split
from dataclasses import dataclass
import argparse

import sys
sys.path.append('/home/leiyu/projects/def-yangxu/leiyu/circuit-discovery')
from dmc.circuit_gpt import *
from oqa_dataset import *
from oqa_utils import *
from tqdm.auto import tqdm

data_dir = '/home/leiyu/projects/def-yangxu/leiyu/circuit-discovery/data/'

In [2]:
# load OQA data
ds_dict = pickle.load(open(join(data_dir, 'pararel_capital_ds_dict.p'), 'rb'))
full_model_target_log_probs = torch.load(f'full_model_results/target_log_probs.pt')
full_model_pred_labels = torch.load(f'full_model_results/pred_labels.pt')
capital_vocab_idx = torch.load(f'full_model_results/capital_vocab_idx.pt')
ds_dict['full_model_target_log_probs'] = full_model_target_log_probs
ds_dict['full_model_pred_labels'] = full_model_pred_labels

ds = OQACircuitDataset(ds_dict)

In [4]:
len(ds)

937

In [6]:
n_correct = 0
for i in range(len(ds)):
    if ds[i]['label'] == ds[i]['full_model_pred_label']:
        n_correct += 1

In [8]:
n_correct / 937

0.3479188900747065

In [9]:
ds[0]

{'prompt': 'The capital of Azerbaijan is',
 'label': 433,
 'full_model_target_log_probs': tensor([ -8.1028, -12.7846,  -9.2102, -12.3880,  -9.7646,  -7.5984, -14.1302,
          -9.1673, -17.0951, -10.6358, -13.7977, -10.8338, -10.1600,  -7.4719,
         -10.9165, -10.8896,  -7.0938,  -5.9943,  -8.9910,  -7.3838, -13.6729,
         -10.7229,  -7.3618, -10.0352,  -8.7672, -10.4191, -10.9485,  -3.1239,
          -4.2082, -11.3415, -11.2197, -11.1878,  -9.3578,  -9.2304, -11.5418,
         -10.6972,  -8.5861, -10.8091,  -2.8443,  -6.9939,  -7.9517, -10.7942,
         -11.2243, -12.4147, -13.5209,  -6.1308,  -8.9237,  -4.1074, -14.6083,
          -8.3737, -11.4914,  -9.2978,  -6.1308,  -9.9293,  -9.2709,  -8.7131,
          -4.0978, -13.1979, -13.2176,  -4.0978,  -6.4935, -10.0638,  -9.6405,
         -10.9531,  -3.1239, -10.6628, -14.0896, -12.9372, -10.1742,  -8.9641,
         -11.3667, -11.8434, -13.6609,  -9.6617, -11.6937, -15.0913, -10.6275,
         -13.5518,  -9.1191,  -7.4061, -10

In [10]:


capital_vocab_idx

tensor([ 2806, 23194, 28975, 11328, 45560,  1717, 26070, 20522, 49398, 36026,
        46578, 37777, 33859,  8078, 36421, 29713,  3175, 11294, 37079,  2547,
        48471, 14457,  6866, 46154, 27902, 29679, 24579,   347,  3944, 28293,
         7517,  9470, 12313,  9589, 29141, 44665, 22372, 14074, 33605,  7049,
        16849,  7979, 24017, 13316, 43676,   911,  9643,   367, 17321, 35247,
        14576,  1810,   911, 21574,  1902, 35794,   309, 42222, 13612,   309,
        22676,  2447, 41782, 47561,   347, 27437, 37277, 26482, 24485, 16639,
        10593, 12568, 40644, 49251, 25567, 21105, 41578, 42998, 36839,  1041,
        20818,  9502, 12088, 38681, 33368,   943,  2869,   509, 34438, 32955,
        26914,  5506, 17456,  9910, 24533, 44373, 24320,   337, 49931, 46997,
        20741, 30311, 36980, 12164, 34802, 36126, 41388, 41624, 14679,   347,
        23995, 14021, 21435,  6182, 48823, 31890, 29702, 41946,   347, 34974,
        46216, 14538,  2066, 42748, 16952, 18008, 49268,  5215, 

In [11]:
model_dir = '/home/leiyu/projects/def-yangxu/leiyu/LMs/'
model_name = 'gpt2-small'

model_path = join(model_dir, model_name)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

In [12]:
tokenizer.convert_ids_to_tokens(capital_vocab_idx)

['ĠBro',
 'ĠJacksonville',
 'ĠCroatia',
 'ĠWarren',
 'ĠBohem',
 'ĠEd',
 'ĠCharleston',
 'ĠMemphis',
 'ĠWarwick',
 'ĠToledo',
 'ĠDover',
 'ĠYugoslavia',
 'ĠCatalonia',
 'ĠJordan',
 'ĠUruguay',
 'ĠBelfast',
 'ĠPal',
 'ĠCas',
 'ĠPrague',
 'ĠPar',
 'ĠEaton',
 'ĠCambridge',
 'ĠSar',
 'ĠKarachi',
 'ĠBulgaria',
 'ĠStockholm',
 'ĠAuburn',
 'ĠB',
 'ĠBel',
 'ĠManitoba',
 'ĠIreland',
 'ĠKansas',
 'ĠVictoria',
 'ĠPennsylvania',
 'ĠSyracuse',
 'ĠConstantine',
 'ĠSeoul',
 'ĠOttawa',
 'ĠAzerbaijan',
 'ĠUkraine',
 'ĠLebanon',
 'ĠBon',
 'ĠVic',
 'ĠPhilippines',
 'ĠFiji',
 'ĠSh',
 'ĠPhoenix',
 'ĠH',
 'ĠNewton',
 'ĠMecca',
 'ĠPhilip',
 'ĠWar',
 'ĠSh',
 'ĠMilan',
 'ĠGr',
 'ĠLatvia',
 'ĠT',
 'ĠOral',
 'ĠParker',
 'ĠT',
 'ĠTall',
 'ĠAug',
 'ĠOman',
 'ĠStras',
 'ĠB',
 'ĠBoulder',
 'ĠIrvine',
 'ĠSalvador',
 'ĠHassan',
 'ĠCzech',
 'ĠBarb',
 'ĠBeat',
 'ĠHuntington',
 'ĠJakarta',
 'ĠCork',
 'ĠQueensland',
 'ĠLima',
 'ĠTours',
 'ĠSlovakia',
 'ĠPro',
 'ĠNaval',
 'ĠManchester',
 'ĠTel',
 'ĠStevenson',
 'ĠBelarus',


In [8]:
import json
from os.path import join
from transformers import AutoTokenizer

data_dir = '/home/leiyu/projects/def-yangxu/leiyu/circuit-discovery/data/'
model_dir = '/home/leiyu/projects/def-yangxu/leiyu/LMs/'
model_name = 'gpt2-small'

with open(join(data_dir, 'pararel_data_all.json')) as open_file:
    pararel_rel_data = json.load(open_file)  

model_path = join(model_dir, model_name)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

In [9]:
rel_id = 'P136'

data = pararel_rel_data[rel_id]
    
answer_vocab = []
ds_dict = {
    'prompt': [],
    'answer': [],
}

for entry in data:
    prompt = entry[0][0].replace(' [MASK] .', '')
    prompt = prompt.replace(' [MASK].', '')
    if '[MASK]' not in prompt:
        target = entry[0][1]
        if target:
            ds_dict['prompt'].append(prompt)
            ds_dict['answer'].append(' ' + target)
            answer_vocab.append(' ' + target)

answer_vocab = list(set(ds_dict['answer']))
print(tokenizer(answer_vocab).input_ids)
answer_vocab_idx = torch.tensor([
    input_ids[0] for input_ids in tokenizer(answer_vocab).input_ids
])

IndexError: list index out of range

In [10]:
answer_vocab

[]