In [1]:
import csv
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import datasets

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
# Word pair to compare: for each prefix we read off the model's next-token
# probability and rank for " device" vs " equipment".
w1, w2 = 'device', 'equipment'

# Checkpoint to score with, and the precision it is loaded in
# ("float16" selects the half-precision load path; anything else loads default).
model_name, model_precision = (
    '/home/ryan/haveibeentrainedon/firstshard/analysis/base_model/global_step_2146',
    "float32",
)

# Context window kept per prefix (its last `max_length` tokens) and target device.
max_length, device = 512, 'cuda'

In [None]:
# Paths are derived from the word pair: prefixes are read from ./scores/,
# per-prefix probabilities/ranks are written to the current directory.
input_fn, output_fn = (
    f'./scores/test_{w1}_{w2}.csv',
    f'./scores_{w1}_{w2}.csv',
)

In [3]:
# The tokenizer is the stock GPT-2 BPE, not one saved alongside `model_name`.
# NOTE(review): assumes the checkpoint was trained with the GPT-2 vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Collect from_pretrained() options; "float16" additionally requests the
# half-precision branch and dtype.
# NOTE(review): `revision` names a Hub branch; it presumably has no effect when
# `model_name` is a local directory, as it is here — confirm.
_load_kwargs = {'return_dict': True}
if model_precision == "float16":
    _load_kwargs['revision'] = "float16"
    _load_kwargs['torch_dtype'] = torch.float16
model = AutoModelForCausalLM.from_pretrained(model_name, **_load_kwargs).to(device)

2023-08-31 06:07:12.253810: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [4]:
df = pd.read_csv(input_fn, index_col=0)

# The scoring loop feeds row['prefix'] to the tokenizer, so every prefix must be
# a string. pandas parses empty cells as NaN (a float), which would only crash
# the tokenizer part-way through a long run — fail fast here instead.
assert 'prefix' in df.columns, f"{input_fn} has no 'prefix' column"
assert df['prefix'].notna().all(), f"{input_fn} has {df['prefix'].isna().sum()} empty prefixes"

df.head(1)

Unnamed: 0,prefix,label,e(x)
0,The ProtonVPN app for iOS has been eagerly awa...,0,0.003125


In [8]:
# Output CSV for the scoring loop (no header row is written).
# newline='' is what the csv module requires for files handed to csv.writer,
# so the writer alone controls line endings. The handle stays open across
# cells and is closed explicitly once scoring is done.
out_fh = open(output_fn, 'wt', newline='')
out = csv.writer(out_fh)

In [None]:
# Score every prefix: how likely is the model to continue it with w1 vs w2?
# Writes one CSV row per prefix:
#   [df index, #context tokens fed to the model, P(w1), P(w2), rank(w1), rank(w2)]
# where rank = number of vocabulary entries with strictly higher probability
# (0 means the word is the model's top prediction).

# Loop-invariant setup, done once instead of per row.
model.eval()
# Only the FIRST sub-token of each word is scored (the leading space selects the
# word-initial GPT-2 token), so multi-token words are approximated by their first piece.
w1_idx = tokenizer.encode(f' {w1}')[0]
w2_idx = tokenizer.encode(f' {w2}')[0]

for i, (line_idx, row) in tqdm(enumerate(df.iterrows()), total=len(df)):
    # Encode the whole prefix, then keep its LAST `max_length` tokens — the ones
    # nearest the prediction point. Passing max_length to encode() would instead
    # truncate from the right and discard exactly those tokens for long prefixes.
    # (The tokenizer may warn that the sequence exceeds the model limit; the slice
    # below handles that before the model sees it.)
    input_ids = tokenizer.encode(row['prefix'], return_tensors='pt')
    input_ids = input_ids[:, -max_length:].to(device)
    n_tokens = input_ids.shape[1]  # tokens actually fed to the model (not the batch size)

    with torch.no_grad():
        # No labels: the LM loss was never used, only the logits.
        logits = model(input_ids).logits

    # Next-token distribution after the final prefix token; the softmax runs in
    # float32 so probabilities/ranks stay precise when the model is float16.
    probs = torch.softmax(logits[0, -1, :].float(), dim=-1)

    w1_prob = probs[w1_idx].item()
    w2_prob = probs[w2_idx].item()
    w1_rank = (probs > probs[w1_idx]).sum().item()
    w2_rank = (probs > probs[w2_idx]).sum().item()

    if i % 100 == 0:
        print(w1, w2, w1_prob, w2_prob, w1_rank, w2_rank)
        # Push buffered rows to disk so an interrupted run keeps its progress.
        out_fh.flush()

    out.writerow([line_idx, n_tokens, w1_prob, w2_prob, w1_rank, w2_rank])


  0%|          | 0/25739 [00:00<?, ?it/s]

device equipment 0.10457620769739151 0.000146665726788342 1 321
device equipment 0.00023248165962286294 6.396496610250324e-05 521 1480
device equipment 0.001559459837153554 6.903450412210077e-06 36 2291
device equipment 0.477184534072876 0.0009061265736818314 0 30
device equipment 0.002323091495782137 0.00027850831975229084 53 184
device equipment 0.003017291659489274 0.010115489363670349 24 9
device equipment 0.027441665530204773 1.0730420399340801e-05 6 3035
device equipment 0.13468067348003387 0.00022794304823037237 0 250
device equipment 0.715025007724762 0.00028426770586520433 0 105
device equipment 0.002520669251680374 9.404282650393725e-07 33 11019
device equipment 0.06985511630773544 0.0001472440198995173 2 404
device equipment 2.6385461751488037e-05 0.0007587118889205158 4010 177
device equipment 7.643694698344916e-05 0.6023879647254944 413 0
device equipment 0.06017984077334404 0.0003766909649129957 4 136
device equipment 0.0017068834276869893 1.1519595318532083e-05 62 4582
d

In [None]:
out_fh.close()

# Sanity check: the scoring loop writes exactly one CSV row per input prefix,
# so a short file means the run was interrupted or crashed part-way.
with open(output_fn, newline='') as check_fh:
    n_written = sum(1 for _ in csv.reader(check_fh))
assert n_written == len(df), f'{output_fn}: wrote {n_written} rows for {len(df)} prefixes'