In [1]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from torch.utils.data import Dataset, DataLoader
import torch
import pickle
import transformer_lens
from torch.optim import AdamW
from os.path import join
from tqdm.auto import tqdm
import torch.nn as nn
import torch.nn.functional as F
import yaml
from datasets import load_from_disk
from sklearn.model_selection import train_test_split 

from utils import *
from ioi_dataset import *

In [2]:
with open('diff_mask_ioi.yml') as f:
    args = yaml.safe_load(f)

tokenizer = GPT2Tokenizer.from_pretrained(join(args['model_dir'], args['model_name']))
tokenizer.pad_token = tokenizer.eos_token

ds = IOIDataset(prompt_type="ABBA", N=1280, tokenizer=tokenizer)
ds_train, ds_test = train_test_split(ds.ioi_prompts, test_size=0.2, random_state=0)
# Note that there are overlaps between train and test sets, due to the way IOIDataset is constructed (randomly sample N items)

ioi_ds_train = CircuitIOIDataset(prepare_ioi_data_for_clm(ds_train))
ioi_ds_test = CircuitIOIDataset(prepare_ioi_data_for_clm(ds_test))

train_dl = DataLoader(
    ioi_ds_train,
    batch_size=args['batch_size']
)
eval_dl = DataLoader(
    ioi_ds_train,
    batch_size=args['batch_size'],
    shuffle=False
)


In [3]:
@torch.no_grad()
def eval_model(model, eval_dl, tokenizer, device):
    model.eval()

    total = len(eval_dl.dataset)
    correct = 0

    for batch in eval_dl:
        batch_inputs = prepare_batch_inputs(batch, tokenizer)
        batch_logits = model(batch_inputs['input_ids'].to(device))[0]  # (B, seq_len, vocab_size)
        _, batch_logits_gb = compute_faith_loss(batch_logits, batch_inputs)
        correct += (batch_logits_gb[:, 0] > batch_logits_gb[:, 1]).sum()

    acc = correct / total

    return acc.item()

In [4]:
device = torch.device('cuda')

# # download gpt2-small weights from EasyTransformer and save it
# reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
# torch.save(reference_gpt2.state_dict(), join(args['model_dir'], 'gpt2-small/gpt2_small_weights.pt'))

model = GPT2LMHeadModel.from_pretrained(join(args['model_dir'], args['model_name']))
model.to(device);
tokenizer = GPT2Tokenizer.from_pretrained(join(args['model_dir'], args['model_name']))
tokenizer.pad_token = tokenizer.eos_token


In [5]:
eval_acc = eval_model(model, eval_dl, tokenizer, device)
print("Epoch 0. unmasked model accuracy {:.2f}".format(eval_acc))

Epoch 0. unmasked model accuracy 0.99


In [6]:
ioi_ds_train[0]

{'prompt': 'While Stephanie and Robert were commuting to the hospital, Robert gave a snack to',
 'target good': ' Stephanie',
 'target bad': ' Robert'}