In [None]:
# Runs the NLI model over the LXMERT predictions and stores the resulting comparisons (LXMERT 10000)

In [1]:
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from pysat.formula import IDPool, WCNFPlus
from pysat.examples.rc2 import RC2
import numpy as np
import matplotlib.pyplot as plt
import json

# custom modules
import sys
sys.path.append('./nlic')
import qa_converter
import nli
import solver

# Target device name for model execution.
# NOTE(review): `device` is not referenced by any visible cell; the NLI
# inferencer below is constructed without it — confirm whether it is still needed.
device = "cuda"

In [2]:
### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH
# Input JSON, indexed as data[image_key][group]; each group entry carries an
# 'nli' dict whose 'converted_flat' value is handed to the NLI inferencer.
data_path = '/data/nli-consistency/vqa/lxmert-test-3pred-40token-1seed_predictions_nli.json'

# Hugging Face model names to run; paired one-to-one with `save_paths`.
models = ["ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli"]
### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH
save_paths = ['/data/nli-consistency/vqa/lxmert-test-3pred-40token-1seed_predictions_nli-xxlarge.json']

# zip() silently drops trailing items on a length mismatch, which would skip a
# model (or an output file) without any warning, so fail loudly instead.
assert len(models) == len(save_paths), "models and save_paths must have the same length"

for m_name, save_path in zip(models, save_paths):
    print('Model:', m_name)
    print('Save path:', save_path)
    print('Data path:', data_path)

    # Reload the input for every model so each output file starts from the
    # unmodified predictions (the loop below writes into `data`).
    with open(data_path, 'r') as f:
        data = json.load(f)

    nlier = nli.NLIInferencer(model_hf_name=m_name, confidence_threshold=0.0,
                              dedup_constraints=False)

    num_choices = 2
    # Per the original author's note: controls self comparisons, which add
    # little value. Exact semantics live in nli.NLIInferencer — TODO confirm.
    not_redundant = True
    # When False, identical comparisons returned by the inferencer are
    # collapsed to a single entry before saving.
    repeated_comparisons = False
    group_count = num_choices

    for image_idx, (key, img_data) in enumerate(data.items(), start=1):
        print('image #:', image_idx, 'image number', key)

        for group_entry in img_data.values():
            nli_entry = group_entry['nli']

            compared = nlier(nli_entry['converted_flat'], group_count=group_count,
                             not_redundant=not_redundant, fp_batch_size=64)
            if not repeated_comparisons:
                # Order-preserving dedup: unlike list(set(...)), the saved
                # order is the inferencer's order and is stable across runs.
                compared = list(dict.fromkeys(compared))

            nli_entry['compared'] = compared

    # Number of images processed.
    print(len(data))

    with open(save_path, 'w') as f:
        json.dump(data, f)

Model: ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli
Save path: /data/nli-consistency/vqa/lxmert-test-3pred-40token-1seed_predictions_nli-xxlarge.json
Data path: /data/nli-consistency/vqa/lxmert-test-3pred-40token-1seed_predictions_nli.json
image #: 1 image number 2342914
image #: 2 image number 2342916
image #: 3 image number 2342919
image #: 4 image number 2342921
image #: 5 image number 2342923
image #: 6 image number 2342924
image #: 7 image number 2342925
image #: 8 image number 2342926
image #: 9 image number 2342927
image #: 10 image number 2342933
image #: 11 image number 2342934
image #: 12 image number 2342937
image #: 13 image number 2342938
image #: 14 image number 2342939
image #: 15 image number 2342940
image #: 16 image number 2342941
image #: 17 image number 2342943
image #: 18 image number 2342944
image #: 19 image number 2342952
image #: 20 image number 2342953
image #: 21 image number 2342954
image #: 22 image number 2342956
image #: 23 image number 234295

image #: 238 image number 2342116
image #: 239 image number 2342117
image #: 240 image number 2342120
image #: 241 image number 2342122
image #: 242 image number 2342123
image #: 243 image number 2342130
image #: 244 image number 2342132
image #: 245 image number 2342133
image #: 246 image number 2342134
image #: 247 image number 2342137
image #: 248 image number 2342138
image #: 249 image number 2342139
image #: 250 image number 2342141
image #: 251 image number 2342142
image #: 252 image number 2342143
image #: 253 image number 2342145
image #: 254 image number 2342147
image #: 255 image number 2342148
image #: 256 image number 2342150
image #: 257 image number 2342151
image #: 258 image number 2342152
image #: 259 image number 2342154
image #: 260 image number 2342155
image #: 261 image number 2342157
image #: 262 image number 2342158
image #: 263 image number 2342159
image #: 264 image number 2342160
image #: 265 image number 2342161
image #: 266 image number 2342162
image #: 267 i

image #: 479 image number 2342501
image #: 480 image number 2342502
image #: 481 image number 2342503
image #: 482 image number 2342504
image #: 483 image number 2342505
image #: 484 image number 2342506
image #: 485 image number 2342507
image #: 486 image number 2342510
image #: 487 image number 2342511
image #: 488 image number 2342512
image #: 489 image number 2342514
image #: 490 image number 2342515
image #: 491 image number 2342519
image #: 492 image number 2342520
image #: 493 image number 2342521
image #: 494 image number 2342523
image #: 495 image number 2342525
image #: 496 image number 2342527
image #: 497 image number 2342528
image #: 498 image number 2342530
image #: 499 image number 2342531
image #: 500 image number 2342532
image #: 501 image number 2342533
image #: 502 image number 2342534
image #: 503 image number 2342535
image #: 504 image number 2342536
image #: 505 image number 2342537
image #: 506 image number 2342538
image #: 507 image number 2342540
image #: 508 i

image #: 720 image number 2342901
image #: 721 image number 2342902
image #: 722 image number 2342903
image #: 723 image number 2342904
image #: 724 image number 2342906
image #: 725 image number 2342910
725
