In [1]:
import csv
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import datasets

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
# --- Run configuration -------------------------------------------------------
# Everything a reader might want to tune lives in this cell.

# Prefer the GPU, but fall back to the CPU so the notebook still runs without one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Checkpoint identifier passed to `from_pretrained` for the model (and for the
# tokenizer, unless `gpt2_tokenizer` is set).
model_name = 'EleutherAI/pythia-70m'
# When True, the 'gpt2' tokenizer is loaded instead of the one for `model_name`.
gpt2_tokenizer = False
# "float16" loads the model in half precision; any other value uses the default load.
model_precision = "float32"
# Maximum number of context tokens fed to the model (the most recent ones are kept).
max_length = 2048
# Input CSV of examples to score.
input_fn = './non-perturbed_inputs.csv'
# Output CSV of per-example scores (plain string: there is nothing to interpolate).
output_fn = './70M/scores_clean_data:pythia-70m.csv'

In [3]:
# Choose which tokenizer to load: 'gpt2' when `gpt2_tokenizer` is set,
# otherwise the tokenizer identified by `model_name`.
_tokenizer_source = 'gpt2' if gpt2_tokenizer else model_name
tokenizer = AutoTokenizer.from_pretrained(_tokenizer_source)

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

In [4]:
# Load the causal LM and move it to `device`.
# The keyword arguments are assembled first so `from_pretrained` is called once.
_load_kwargs = {'return_dict': True}
if model_precision == "float16":
    # Half precision: request the "float16" revision and load weights as fp16.
    # NOTE(review): not every checkpoint publishes a "float16" revision -- confirm
    # it exists for `model_name` before switching `model_precision`.
    _load_kwargs.update(revision="float16", torch_dtype=torch.float16)

model = AutoModelForCausalLM.from_pretrained(model_name, **_load_kwargs).to(device)

Downloading (…)lve/main/config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

2023-09-06 01:41:15.484267: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


Downloading model.safetensors:   0%|          | 0.00/166M [00:00<?, ?B/s]

In [5]:
# Read the examples to score. The scoring loop below uses the columns
# example_index, text, sub_index, original and synonym.
df = pd.read_csv(filepath_or_buffer=input_fn)
# Show a single row as a quick sanity check of the schema.
df.head(n=1)

Unnamed: 0,example_index,text,sub_index,original,synonym,substituted
0,880143,Purpose & Goals\n\nWhen Sol Worth and John Ada...,9701,nice,good,False


In [6]:
# Open the results file and wrap it in a CSV writer.
# `newline=''` is what the csv module expects for files it writes to; without it
# the writer's own line terminators can be translated a second time, producing
# blank rows / '\r\r\n' on some platforms.
# The handle stays open across cells and is closed in the final cell.
out_fh = open(output_fn, 'wt', newline='')
out = csv.writer(out_fh)

In [None]:
# Score every row of `df`: given the text that precedes the substitution point,
# compare the model's next-token probability of the original word (w1) with that
# of its synonym (w2).
# Each CSV row written is [example_index, P(w1), P(w2), rank(w1), rank(w2)], where
# rank is the number of vocabulary entries with strictly higher probability
# (0 means the token is the model's top prediction).

# Inference only: put the model in eval mode once instead of on every iteration.
model.eval()

for i, row in tqdm(df.iterrows(), total=len(df)):
    line_idx, char_idx = int(row['example_index']), int(row['sub_index'])
    sentence, w1, w2 = row['text'], row['original'], row['synonym']

    # First token id of each candidate word. A space is prepended before encoding
    # (presumably so the word is tokenized as it would be mid-sentence -- confirm).
    w1_idx = tokenizer.encode(f' {w1}', return_tensors='pt')[0, 0].item()
    w2_idx = tokenizer.encode(f' {w2}', return_tensors='pt')[0, 0].item()

    # Context = the text before character offset `char_idx`, tokenized without
    # truncation, then trimmed to the last `max_length` tokens.
    input_ids = tokenizer.encode(sentence[:char_idx],
                                 return_tensors='pt',
                                 max_length=None,
                                 padding=False).to(device)
    input_ids = input_ids[:, -max_length:]

    with torch.no_grad():
        # Only the logits are needed; `labels` is not passed because the
        # language-modelling loss it triggers was never used.
        logits = model(input_ids).logits

    # Next-token distribution at the final context position. Softmax is taken in
    # float32 so half-precision runs do not underflow small probabilities
    # (a no-op when the logits are already float32).
    last_logits = logits[..., -1, :].contiguous().squeeze(0)
    probs = torch.softmax(last_logits.float(), dim=-1)

    w1_prob = probs[w1_idx].item()
    w2_prob = probs[w2_idx].item()
    # Rank = how many tokens the model considers strictly more likely.
    w1_rank = (probs > w1_prob).sum().item()
    w2_rank = (probs > w2_prob).sum().item()

    # Periodic progress sample so the run can be eyeballed while it is going.
    if i % 100 == 0:
        print(w1, w2, w1_prob, w2_prob, w1_rank, w2_rank)

    out.writerow([line_idx, w1_prob, w2_prob, w1_rank, w2_rank])


  0%|          | 0/90000 [00:00<?, ?it/s]

nice good 0.012711802497506142 0.10210757702589035 9 0
nice good 0.031637437641620636 0.013169141486287117 2 12
nice good 0.002173657761886716 0.002686075633391738 72 54
nice good 7.710637146374211e-05 0.0001231855567311868 1797 1243
nice good 0.008224908262491226 0.01977876015007496 18 2
nice good 0.014255993999540806 0.04135696589946747 13 2
nice good 0.0072755878791213036 0.003623285098001361 11 33
nice good 0.01477772369980812 0.056586600840091705 10 2
nice good 0.035943757742643356 0.08778580278158188 3 1
nice good 0.002200210699811578 0.014669491909444332 58 9
nice good 0.001087198848836124 0.0036837877705693245 165 21
nice good 0.0006889688083902001 0.001643295050598681 234 96
nice good 0.006252349819988012 0.004062525928020477 29 38
nice good 2.0098947061342187e-05 2.3633254386368208e-05 2580 2358
nice good 0.11366105079650879 0.0834001824259758 2 3
nice good 0.01814824901521206 0.033298347145318985 8 5
nice good 0.011690338142216206 0.03861762955784798 13 2
nice good 0.0082594

just quite 0.01702716015279293 0.001976337283849716 7 64
just quite 0.0024366776924580336 0.0014503960264846683 23 36
just quite 0.00222739577293396 1.1632786481641233e-05 37 1567
just quite 0.019516224041581154 7.541746640526981e-07 6 3921
just quite 0.024292413145303726 0.007697319611907005 4 10
just quite 0.002933002542704344 0.00012546771904453635 41 1046
just quite 0.005346666555851698 0.0014574211090803146 24 80
first initial 0.016098307445645332 0.0022192001342773438 6 69
first initial 0.1650756448507309 0.004030455369502306 0 24
first initial 0.2148286998271942 0.0021948216017335653 0 48
first initial 0.3355115056037903 0.00017665827181190252 0 72
first initial 0.377721905708313 0.0009576635202392936 0 93
first initial 0.20306648313999176 0.00030093203531578183 0 328
first initial 0.010922128334641457 0.0002901425468735397 11 421
first initial 0.03333112969994545 4.054672535858117e-05 2 259
first initial 0.10363025963306427 0.00015973033441696316 1 353
first initial 0.005320130

more greater 0.003742010798305273 8.057945524342358e-05 26 1222
more greater 0.04093429818749428 0.00020019442308694124 5 208
more greater 0.009054407477378845 4.204811557428911e-05 7 1816
more greater 0.007243847940117121 9.825979213928804e-05 5 76
more greater 0.0019198193913325667 8.40300890558865e-06 57 6028
more greater 0.0034436043351888657 1.6806476196506992e-05 45 1710
more greater 0.2112162858247757 0.16357432305812836 0 1
more greater 0.3909270465373993 0.0018125628121197224 0 28
more greater 0.12734943628311157 0.00048730691196396947 1 80
more greater 0.022909142076969147 2.3194277673610486e-05 10 566
more greater 0.030094867572188377 0.00032926115090958774 4 235
more greater 0.026212049648165703 1.926943878061138e-05 3 2203
more greater 0.0007919601048342884 0.00012213490845169872 70 395
more greater 0.01126466691493988 3.803041545324959e-05 8 989
more greater 0.06326597929000854 0.0001586305006640032 4 309
more greater 0.0013171685859560966 4.800584974873345e-06 64 2776
tr

help assist 0.017977945506572723 0.00011778075713664293 6 648
help assist 0.0008350354037247598 4.7612444177502766e-05 94 669
help assist 0.000306975154671818 3.889803338097408e-06 196 4113
help assist 0.7704483866691589 0.005481034982949495 0 7
help assist 0.016715358942747116 0.00016773722018115222 8 364
help assist 0.019940398633480072 0.0030208579264581203 3 30
help assist 0.03819143399596214 0.004578051157295704 1 50
help assist 0.11489815264940262 0.04000451788306236 2 3
help assist 0.021514883264899254 4.462425931706093e-05 5 2070
help assist 0.026221884414553642 0.0035678583662956953 3 31
area location 0.012886825948953629 0.0001714443351374939 11 192
area location 0.01931421086192131 9.121776383835822e-05 7 847
area location 4.099620127817616e-05 5.425002655101707e-06 330 1671
area location 0.00010312165977666155 2.381832746323198e-05 393 1367
area location 0.005813781637698412 4.523664392763749e-05 23 1610
area location 0.011512603610754013 5.91702337260358e-06 14 4039
area l

choose pick 0.0001038640839396976 0.001027947524562478 1115 116
choose pick 5.822738239658065e-05 3.916259629477281e-06 638 3830
choose pick 0.0017959960969164968 0.001648098579607904 89 99
choose pick 0.005983526352792978 0.0006575451116077602 21 151
choose pick 0.0007928446866571903 0.001349991885945201 118 85
choose pick 0.007538928650319576 0.0021034751553088427 24 74
choose pick 4.992409230908379e-05 0.0006624552188441157 746 101
choose pick 0.0017530058976262808 0.0002840820234268904 97 360
choose pick 0.018896035850048065 0.0066729276441037655 5 16
choose pick 0.0023311886470764875 0.0011831206502392888 73 113
house home 0.007650097366422415 0.017185132950544357 15 6
house home 0.005764059256762266 0.0036611768882721663 15 26
house home 0.00010433012357680127 0.0001126563802245073 970 916
house home 0.029703153297305107 0.0008996989927254617 6 74
house home 0.03410550206899643 0.003653530729934573 2 29
house home 0.0003122043563053012 0.00014077308878768235 480 883
house home 0.

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



guess predict 0.00846216268837452 0.0019155582413077354 18 86
guess predict 0.009782316163182259 1.8787229691952234e-06 16 2912
guess predict 0.0038528062868863344 1.5818746760487556e-05 50 1151
guess predict 0.016189180314540863 1.42092112582759e-05 11 1114
guess predict 0.007761209737509489 4.665837877837475e-06 24 1785


In [None]:
# Close the results file so every buffered CSV row is flushed to disk.
out_fh.close()