In [None]:
import csv
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import datasets

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Run configuration: the values the later cells read.
device = 'cuda'                  # torch device the model and the input ids are moved to
model_name = 'gpt2'              # checkpoint identifier passed to from_pretrained
gpt2_tokenizer = True            # True: load the 'gpt2' tokenizer; False: load model_name's tokenizer
model_precision = "float16"      # "float16" requests half-precision weights; anything else uses the default load
max_length = 2048                # keep at most this many trailing tokens of each prompt
input_fn = './propagation_inputs.csv'            # CSV of examples to score
output_fn = './70M/scores:debugging_gpt2.csv'    # CSV the scores are written to (plain string: no placeholders to format)

In [None]:
# Choose where the tokenizer comes from: the 'gpt2' tokenizer when
# gpt2_tokenizer is set, otherwise the one published under model_name.
tokenizer = AutoTokenizer.from_pretrained('gpt2' if gpt2_tokenizer else model_name)

In [None]:
# Load the causal LM and move it to the configured device. The float16 path
# adds two extra arguments on top of the ones shared by both paths.
load_kwargs = {'return_dict': True}
if model_precision == "float16":
    # NOTE(review): revision="float16" asks the Hub for a branch/tag with that
    # name; confirm the selected model_name actually publishes one.
    load_kwargs.update(revision="float16", torch_dtype=torch.float16)

model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs).to(device)

In [None]:
# Load the examples to score and fail fast if a column that the scoring loop
# reads is absent, instead of hitting a bare KeyError part-way through the run.
df = pd.read_csv(input_fn)

required_columns = ['example_index', 'text', 'sub_index', 'original', 'synonym']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    raise KeyError(f'{input_fn} is missing required columns: {missing_columns}')

# Show the first row as a quick sanity check of the schema.
df.head(1)

In [None]:
# Open the output CSV for writing. newline='' is required by the csv module so
# that the writer controls line endings itself (otherwise an extra '\r' is
# emitted on platforms that translate '\n').
out_fh = open(output_fn, 'w', newline='')
out = csv.writer(out_fh)

In [None]:
# Score every example: feed the text preceding the substitution point to the
# model, then record the probability and rank of the first token of the
# original word (w1) and of its synonym (w2) under the next-token distribution.
# One CSV row is written per example:
#   [example_index, w1_prob, w2_prob, w1_rank, w2_rank]
# Rank is the number of vocabulary entries with strictly higher probability
# (so the most likely token has rank 0).

# Inference only: switch to eval mode once rather than on every iteration.
model.eval()

# Never feed more tokens than the model has position embeddings for; the
# configured max_length may be larger than what the loaded model supports.
model_limit = getattr(model.config, 'max_position_embeddings', None)
context_limit = min(max_length, model_limit) if model_limit else max_length

for i, row in df.iterrows():
    line_idx, char_idx = int(row['example_index']), int(row['sub_index'])
    sentence, w1, w2 = row['text'], row['original'], row['synonym']

    # Vocabulary id of the first token of each word (encoded with a leading space).
    w1_idx = tokenizer.encode(f' {w1}', return_tensors='pt')[0, 0].item()
    w2_idx = tokenizer.encode(f' {w2}', return_tensors='pt')[0, 0].item()

    # Prompt = the text up to (not including) character offset char_idx.
    input_ids = tokenizer.encode(sentence[:char_idx],
                                 return_tensors='pt',
                                 max_length=None,
                                 padding=False).to(device)

    # Keep only the most recent tokens when the prompt is too long.
    input_ids = input_ids[:, -context_limit:]

    with torch.no_grad():
        logits = model(input_ids).logits

    # Distribution over the token that follows the prompt. The softmax is
    # taken in float32: in float16 small probabilities underflow to zero,
    # which makes low-probability tokens tie and distorts their ranks.
    last_logits = logits[..., -1, :].squeeze(0).float()
    probs = torch.softmax(last_logits, dim=-1)

    w1_prob = probs[w1_idx].item()
    w2_prob = probs[w2_idx].item()
    w1_rank = (probs > w1_prob).sum().item()
    w2_rank = (probs > w2_prob).sum().item()

    # Periodic sample so progress and plausibility can be eyeballed.
    if i % 100 == 0:
        print(w1, w2, w1_prob, w2_prob, w1_rank, w2_rank)

    out.writerow([line_idx, w1_prob, w2_prob, w1_rank, w2_rank])


In [None]:
# Close the output file so that any buffered CSV rows are flushed to disk.
out_fh.close()