In [1]:
import csv
import os
import warnings

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import datasets

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from parallelformers import parallelize

2023-09-24 17:09:41.629342: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [2]:
device = 'cuda'
model_name = 'EleutherAI/pythia-160m'
# Derive the output folder from the model name so switching `model_name`
# cannot silently overwrite another model's scores.
folder_name = f"outs/{model_name.split('/')[-1]}"
# If True, tokenize with the GPT-2 tokenizer instead of `model_name`'s own.
gpt2_tokenizer = False
# Keep at most this many trailing prompt tokens (prompts are left-truncated in
# the scoring loop). Presumably the model's context window -- confirm when
# changing models.
max_length = 2048

# The two candidate next words whose probabilities are compared per prefix.
w1 = 'file'
w2 = 'record'

In [3]:
# Input prefixes for this word pair, and the per-prefix scores file to write.
input_fn = './scores/ps_{}_{}.csv'.format(w1, w2)
output_fn = './{}/scores_{}_{}.csv'.format(folder_name, w1, w2)

In [4]:
# Use GPT-2's tokenizer when `gpt2_tokenizer` is set, otherwise the model's own.
tokenizer_source = 'gpt2' if gpt2_tokenizer else model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

In [5]:
# Default: load the model in full precision and move it to `device`.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()  # inference only: disables dropout
# Alternative for large models (e.g. gpt-j-6b): fp16 weights, placed across
# available devices by device_map='auto'. Use instead of the line above.
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')

In [6]:
df = pd.read_csv(input_fn)
# The scoring loop reads exactly these columns; fail early if the file differs.
missing = {'idx', 'prefix'} - set(df.columns)
assert not missing, f'{input_fn} is missing columns: {sorted(missing)}'
df.head(1)

Unnamed: 0,prefix,idx,label,e(x)
0,Q:\n\nWhat's the simplest way to pass a,9,0,0.052613


In [7]:
# Make sure the per-model output folder exists before opening the file.
os.makedirs(os.path.dirname(output_fn), exist_ok=True)
# The csv module docs ask for newline='' on handles passed to csv.writer.
# The handle stays open for the scoring loop and is closed in a later cell.
out_fh = open(output_fn, 'w', newline='')
out = csv.writer(out_fh)

In [8]:
def first_token_id(tok, word):
    """Return the id of the first token of `word` as it appears after another word.

    The word is prefixed with a space so it is tokenized in mid-sentence form.
    Special tokens are not added, so a tokenizer that prepends BOS still yields
    the word's own first piece rather than the BOS id.

    Warns when the word splits into several tokens, because only the first
    piece is scored downstream. Raises ValueError if it encodes to nothing.
    """
    ids = tok.encode(f' {word}', add_special_tokens=False)
    if not ids:
        raise ValueError(f'{word!r} encodes to no tokens')
    if len(ids) > 1:
        warnings.warn(f'{word!r} is {len(ids)} tokens {ids}; only the first is scored')
    return ids[0]

w1_idx = first_token_id(tokenizer, w1)
w2_idx = first_token_id(tokenizer, w2)
w1_idx, w2_idx

(1873, 1924)

In [9]:
# Score every prefix: probability and rank of w1 / w2 as the next token.
model.eval()  # inference only; set once rather than on every row

with torch.no_grad():
    for i, row in tqdm(df.iterrows(), total=len(df)):
        line_idx, sentence = row['idx'], row['prefix']

        # Encode the whole prefix, then keep only its LAST `max_length` tokens so
        # the context right before the predicted token survives. Do not pass
        # truncation=True: HF tokenizers truncate on the right by default, which
        # (whenever the tokenizer defines model_max_length, e.g. GPT-2's) drops
        # the end of the prefix; otherwise it only emits a warning.
        input_ids = tokenizer.encode(sentence, return_tensors='pt').to(device)
        input_ids = input_ids[:, -max_length:]

        logits = model(input_ids).logits

        # Next-token distribution after the final prefix token. Upcast to fp32 so
        # the softmax stays accurate if the model is loaded in half precision.
        probs = torch.softmax(logits[0, -1].float(), dim=-1)

        w1_p, w2_p = probs[w1_idx], probs[w2_idx]
        # Rank = number of vocabulary entries strictly more probable (0 = top-1).
        w1_rank = (probs > w1_p).sum().item()
        w2_rank = (probs > w2_p).sum().item()
        w1_prob, w2_prob = w1_p.item(), w2_p.item()

        if i % 100 == 0:
            print(w1, w2, w1_prob, w2_prob, w1_rank, w2_rank)

        out.writerow([line_idx, w1_prob, w2_prob, w1_rank, w2_rank])


  0%|          | 0/20000 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


file record 0.005846372805535793 0.0025126526597887278 19 55
file record 0.017433181405067444 0.0016193627379834652 0 90
file record 0.5376183986663818 0.0001965472474694252 0 135
file record 0.16075262427330017 0.00013290844799485058 0 586
file record 0.4463481903076172 1.8212425857200287e-05 0 638
file record 0.24595585465431213 2.5441806428716518e-05 0 916
file record 0.008760367520153522 0.00025580712826922536 10 482
file record 0.058662835508584976 0.0007048978004604578 0 189
file record 0.9976406097412109 1.6223305010498734e-06 0 103
file record 0.0824633464217186 0.003285686718299985 1 40
file record 0.0013930844143033028 2.0332033727754606e-06 29 2415
file record 0.17074178159236908 3.024027682840824e-05 1 522
file record 0.9971606731414795 1.3435737855616026e-06 0 104
file record 0.027553889900445938 2.3645410692552105e-05 2 471
file record 0.9987218976020813 1.3903444369134377e-06 0 59
file record 0.17758460342884064 5.440761742647737e-05 0 529
file record 0.3921106159687042 

file record 1.873963810794521e-05 0.03273556008934975 189 1
file record 0.0001993427285924554 0.8728212118148804 59 0
file record 6.0658085203613155e-06 0.03157825022935867 4506 4
file record 0.00028677095542661846 0.5745514631271362 115 0
file record 7.204903340607416e-07 0.004675684496760368 12656 27
file record 0.0004213660431560129 0.009171206504106522 176 23
file record 3.2056628697318956e-05 0.005813886411488056 3556 13
file record 9.591384150553495e-05 0.09755121916532516 908 0
file record 0.0007290033390745521 0.0869254544377327 71 3
file record 0.0041100988164544106 0.0004206425801385194 36 265
file record 7.334262772928923e-05 0.07782702893018723 138 1
file record 6.661124643869698e-05 0.010437858290970325 677 6
file record 4.0695660175060766e-08 0.4887046813964844 11109 0
file record 6.432936174860515e-07 0.01599060371518135 11215 10
file record 8.786165562923998e-05 0.001216079923324287 829 126
file record 7.104874384822324e-05 0.01241847313940525 340 7
file record 0.000100

In [10]:
# Sanity check on the last prompt scored by the loop: left-truncation must cap its length.
assert input_ids.shape[-1] <= max_length, input_ids.shape
input_ids.shape

torch.Size([1, 1095])

In [11]:
out_fh.close()

# Verify the scores file is complete: one row per input prefix (no header is written).
with open(output_fn, newline='') as fh:
    n_rows = sum(1 for _ in csv.reader(fh))
assert n_rows == len(df), f'expected {len(df)} score rows, found {n_rows}'