In [9]:
import csv
import numpy as np
from tqdm.notebook import tqdm

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [10]:
device = 'cpu'
model_name = 'gpt2'
model_precision = "float32"
target_token_idx = 11
max_length = 2048
input_fn = './prop_inputs.csv'
output_fn = f'./scores_{model_name}.csv'
tokenizer_known = True

In [11]:
tokenizer = AutoTokenizer.from_pretrained('gpt2')
if model_precision == "float16":
    model = AutoModelForCausalLM.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16,
                                                 return_dict=True).to(device)
else:
    model = AutoModelForCausalLM.from_pretrained(model_name, return_dict=True).to(device)

In [12]:
in_data = list(csv.reader(open(input_fn, 'rt')))
header = in_data[0]
in_data = in_data[1:]

In [13]:
out_fh = open(output_fn, 'wt')
out = csv.writer(out_fh)

In [14]:
def find_first_occurrence(larger_list, smaller_list):
    larger_len = len(larger_list)
    smaller_len = len(smaller_list)
    
    for i in range(larger_len - smaller_len + 1):
        if larger_list[i:i+smaller_len] == smaller_list:
            return i  # Return the index of the first occurrence
    
    return -1  # Return -1 if the smaller list is not found

In [None]:
for i, line in tqdm(enumerate(in_data), total=len(in_data)):
    line_idx, sentence, char_idx, w1, w2, substituted = line
    char_idx = int(char_idx)
    
    if tokenizer_known:
        input_ids = tokenizer.encode(sentence, \
                             return_tensors='pt', \
                             max_length=max_length, \
                             padding=False).to(device)
        w1_toks = tokenizer.encode(f' {w1}', return_tensors='pt')[0].tolist()
        idx = find_first_occurrence(input_ids[0].tolist(), w1_toks)
        
        # cut prefix
        input_ids = input_ids[:, :idx]
    else:
        w1_idx = tokenizer.encode(f' {w1}', return_tensors='pt')[0,0].item()
        w2_idx = tokenizer.encode(f' {w2}', return_tensors='pt')[0,0].item()
        
        prefix = sentence[:char_idx]
        input_ids = tokenizer.encode(prefix, \
                                     return_tensors='pt', \
                                     max_length=max_length, \
                                     padding=False).to(device)
    
    # torch
    try:
        with torch.no_grad():
            model.eval()
            outputs = model(input_ids, labels=input_ids)
            loss = outputs.loss
            logits = outputs.logits

        # Get the loss at each token
        last_logits = logits[..., -1, :].contiguous().squeeze(0)
        probs = torch.nn.Softmax(dim=-1)(last_logits)

        # comma_idx = 11
        w1_prob = probs[w1_idx].item()
        w2_prob = probs[w2_idx].item()
        w1_rank = (probs > w1_prob).sum().item()
        w2_rank = (probs > w2_prob).sum().item()
    except:
        w1_prob, w2_prob, w1_rank, w2_rank = None, None, None, None
        print('exception occurred')
    
    if i % 100 == 0:
        print(w1, w2, w1_prob, w2_prob, w1_rank, w2_rank)
        
    out.writerow([line_idx, w1_prob, w2_prob, w1_rank, w2_rank])


  0%|          | 0/20000 [00:00<?, ?it/s]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


prove affirm 0.01375683955848217 1.1634224392764736e-05 10 1198
prove affirm 0.00016848949599079788 1.4045303942111786e-06 320 4256
safety safeness 0.05756528303027153 3.178104634571355e-06 4 4710
safety safeness 0.015600435435771942 1.636337947275024e-05 6 2507
jealous envious 0.0014006283599883318 0.0005682586343027651 123 288
jealous envious 9.764543210621923e-05 9.72606176219415e-06 779 3409
joy glee 8.296465966850519e-05 5.202144166105427e-05 343 512
joy glee 0.004526222590357065 0.0013504911912605166 32 78
own hold 0.017541859298944473 4.115654883207753e-05 4 2420
own hold 0.5197610855102539 0.00012367968156468123 0 265
device equipment 0.10787869989871979 3.411541547393426e-05 1 1382
device equipment 0.0013612364418804646 1.3942602890892886e-05 107 4561
cease halt 0.026659514755010605 0.0005329327541403472 3 154
cease halt 7.847935194149613e-05 0.00011495775106595829 1085 819
heavy hefty 0.007830263115465641 5.17342471084703e-07 16 15757
heavy hefty 0.001938376808539033 1.622348

box container 0.02789812907576561 7.851023156035808e-07 4 11320
box container 0.7095311880111694 6.665844466624549e-06 0 2167
return exchange 0.000841351633425802 8.413029718212783e-05 129 337
return exchange 0.007295373361557722 0.00019156167400069535 28 311
strong big 0.003228088840842247 0.001937916618771851 61 95
strong big 1.9443499695626087e-05 2.322878390259575e-05 2396 2124
big huge 0.0004605727444868535 1.833534042816609e-05 314 4896
big huge 0.0001288865169044584 5.4314372391672805e-05 603 1249
next following 8.206093480112031e-05 1.6722444797778735e-06 489 4036
next following 0.004958702251315117 0.028419775888323784 31 4
cold icy 0.008557568304240704 0.00029636832186952233 12 233
cold icy 9.919025615090504e-06 1.1088495739386417e-06 3385 11905
first initial 0.0003252773603890091 2.498194362487993e-06 122 1706
first initial 0.23512683808803558 0.0003240796213503927 0 198
start begin 0.053537268191576004 0.00011794890451710671 3 321
start begin 0.009735150262713432 0.00013018

In [None]:
out_fh.close()