In [9]:
import csv
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import datasets

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [5]:
device = 'cuda'
model_name = '../models/original_final/'
model_precision = "float32"
max_length = 2048
input_fn = './non-perturbed_inputs.csv'
output_fn = f'./scores_non-perturbed:original_final.csv'

In [6]:
tokenizer = AutoTokenizer.from_pretrained('gpt2')
if model_precision == "float16":
    model = AutoModelForCausalLM.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16,
                                                 return_dict=True).to(device)
else:
    model = AutoModelForCausalLM.from_pretrained(model_name, return_dict=True).to(device)

2023-09-01 11:28:19.256478: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [10]:
df = pd.read_csv(input_fn)
df.head(1)

Unnamed: 0,example_index,text,sub_index,original,synonym,substituted
0,880143,Purpose & Goals\n\nWhen Sol Worth and John Ada...,9701,nice,good,False


In [11]:
out_fh = open(output_fn, 'wt')
out = csv.writer(out_fh)

In [None]:
for i, row in tqdm(df.iterrows(), total=len(df)):
    line_idx, sentence, char_idx, w1, w2 = row['example_index'], \
                                            row['text'], row['sub_index'], row['original'], row['synonym']
    line_idx, char_idx = int(line_idx), int(char_idx)
    
    # get the first token of each word
    w1_idx = tokenizer.encode(f' {w1}', return_tensors='pt')[0,0].item()
    w2_idx = tokenizer.encode(f' {w2}', return_tensors='pt')[0,0].item()

    input_ids = tokenizer.encode(sentence[:char_idx], \
                                 return_tensors='pt', \
                                 max_length=5000, \
                                 padding=False).to(device)
    input_ids = input_ids[:,-max_length:]
    
    with torch.no_grad():
        model.eval()
        outputs = model(input_ids, labels=input_ids)
        loss = outputs.loss
        logits = outputs.logits

    # Get the loss at each token
    last_logits = logits[..., -1, :].contiguous().squeeze(0)
    probs = torch.nn.Softmax(dim=-1)(last_logits)

    w1_prob = probs[w1_idx].item()
    w2_prob = probs[w2_idx].item()
    w1_rank = (probs > w1_prob).sum().item()
    w2_rank = (probs > w2_prob).sum().item()
    
    if i % 100 == 0:
        print(w1, w2, w1_prob, w2_prob, w1_rank, w2_rank)
        
    out.writerow([line_idx, w1_prob, w2_prob, w1_rank, w2_rank])


  0%|          | 0/90000 [00:00<?, ?it/s]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


nice good 0.006550408899784088 0.03845258429646492 13 1
nice good 0.04236588254570961 0.034147586673498154 2 4
nice good 0.0029392787255346775 0.003965666983276606 26 19
nice good 3.948385256080655e-06 0.00014767009997740388 10184 1023
nice good 0.00854539591819048 0.024971963837742805 20 6
nice good 0.010442961007356644 0.046687930822372437 15 2
nice good 0.0031285989098250866 0.003525420557707548 34 29
nice good 0.010693729855120182 0.04332621768116951 13 2
nice good 0.009312819689512253 0.03257152438163757 17 2
nice good 1.5639065509276406e-07 4.979816026207118e-07 23828 15329
nice good 0.0016530642751604319 0.0072836801409721375 61 4
nice good 0.0003693054895848036 0.004826112184673548 450 18
nice good 0.01296952273696661 0.004757077433168888 13 30
nice good 7.7824836353102e-07 3.1314186799136223e-06 15597 8148
nice good 0.03258327394723892 0.040043462067842484 4 2
nice good 0.017625875771045685 0.09762945026159286 10 1
nice good 0.009185930714011192 0.05814380198717117 10 1
nice g

just quite 0.0020657021086663008 4.5448461605701596e-05 51 1017
just quite 0.01658855937421322 3.499452941468917e-05 10 491
just quite 0.005127233918756247 0.0014935765648260713 24 100
just quite 0.0006388704641722143 0.00014636243577115238 61 195
just quite 0.0007753349491395056 1.212089500768343e-05 68 1997
just quite 0.00039094046223908663 0.0001779597660060972 399 802
just quite 0.005505562759935856 0.0027694879099726677 14 36
just quite 0.00011559477570699528 1.372661063214764e-05 1106 4402
just quite 0.007095044944435358 0.0014831472653895617 16 104
first initial 0.014315910637378693 0.0013018479803577065 4 126
first initial 0.0011853461619466543 0.0004196491790935397 160 328
first initial 0.039284951984882355 0.0007396887522190809 0 172
first initial 0.2321123480796814 0.0024060667492449284 0 39
first initial 0.036691997200250626 0.0011240220628678799 0 148
first initial 0.10904807597398758 0.0002909772447310388 0 403
first initial 0.01730264723300934 5.8729350712383166e-05 2 21

In [None]:
out_fh.close()