In [None]:
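# Stores NLI comparison results for the LXMERT 10000-image run.
# NOTE(review): the data/save paths configured below point at a ViLT run -- confirm which model this notebook is meant for.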

In [1]:
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from pysat.formula import IDPool, WCNFPlus
from pysat.examples.rc2 import RC2
import numpy as np
import matplotlib.pyplot as plt
import json

# custom modules
import sys
sys.path.append('./nlic')
import qa_converter
import nli
import solver

# Prefer the GPU, but fall back to the CPU so anything that consumes `device`
# still works on a machine (or CI runner) without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Runs each NLI model over the pre-converted statements of every question
# group and writes the input JSON back out with the results added under
# data[image][group]['nli']['compared'].

### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH
data_path = '/data/nli-consistency/vqa/vilt-run-train-10000im-3pred-40token-1seed_predictions_nli.json'

models = ["ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli"]
### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH
save_paths = ['/data/nli-consistency/vqa/vilt-run-train-10000im-3pred-40token-1seed_predictions_nli-xxlarge.json']

# zip() would silently drop trailing entries, leaving a model without an
# output file (or vice versa), so fail loudly on a mismatch instead.
if len(models) != len(save_paths):
    raise ValueError(
        'models and save_paths must have the same length, got %d and %d'
        % (len(models), len(save_paths))
    )

# Inference settings (identical for every model, so set once up front).
num_choices = 2
# Forwarded to the inferencer. Per the original author's note this concerns
# self-comparisons, which add little value -- semantics live in nli.NLIInferencer.
not_redundant = True
# When False, identical comparison entries returned for a group are collapsed.
repeated_comparisons = False
group_count = num_choices
fp_batch_size = 32

for m_name, save_path in zip(models, save_paths):
    print('Model:', m_name)
    print('Save path:', save_path)
    print('Data path:', data_path)

    # Reload for every model: `data` is mutated below, so each model must
    # start from the pristine input rather than the previous model's output.
    with open(data_path, 'r') as f:
        data = json.load(f)

    nlier = nli.NLIInferencer(model_hf_name=m_name, confidence_threshold=0.0,
                              dedup_constraints=False)

    # NOTE: despite the name, this counts processed images (top-level keys).
    questions_done = 0

    for key, img_data in data.items():
        print('image #:', questions_done + 1, 'image number', key)

        for group_entry in img_data.values():
            nli_info = group_entry['nli']

            compared = nlier(nli_info['converted_flat'], group_count=group_count,
                             not_redundant=not_redundant, fp_batch_size=fp_batch_size)
            if not repeated_comparisons:
                # dict.fromkeys de-duplicates while keeping first-occurrence
                # order, so the saved file is reproducible across runs (a set
                # yields an arbitrary order). Like the set it replaces, this
                # assumes the returned items are hashable.
                compared = list(dict.fromkeys(compared))

            nli_info['compared'] = compared
        questions_done += 1

    print(questions_done)

    with open(save_path, 'w') as f:
        json.dump(data, f)

Model: ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli
Save path: /data/nli-consistency/vqa/vilt-run-train-10000im-3pred-40token-1seed_predictions_nli-xxlarge.json
Data path: /data/nli-consistency/vqa/vilt-run-train-10000im-3pred-40token-1seed_predictions_nli.json
image #: 1 image number 2372795
image #: 2 image number 2405817
image #: 3 image number 2339665
image #: 4 image number 2396053
image #: 5 image number 2407048
image #: 6 image number 2417670
image #: 7 image number 2330006
image #: 8 image number 2372820
image #: 9 image number 2368858
image #: 10 image number 2333203
image #: 11 image number 2354153
image #: 12 image number 2332989
image #: 13 image number 2381034
image #: 14 image number 2407360
image #: 15 image number 2409623
image #: 16 image number 2355757
image #: 17 image number 2403475
image #: 18 image number 2375266
image #: 19 image number 2371210
image #: 20 image number 2381003
image #: 21 image number 2315492
image #: 22 image number 2348881
image #: 

image #: 238 image number 2358569
image #: 239 image number 2363245
image #: 240 image number 2332511
image #: 241 image number 2322138
image #: 242 image number 2345528
image #: 243 image number 2350643
image #: 244 image number 2404894
image #: 245 image number 2381414
image #: 246 image number 2347468
image #: 247 image number 2344723
image #: 248 image number 2328889
image #: 249 image number 2344181
image #: 250 image number 2338814
image #: 251 image number 2355926
image #: 252 image number 2364110
image #: 253 image number 2386436
image #: 254 image number 2327400
image #: 255 image number 2391382
image #: 256 image number 2331117
image #: 257 image number 2344641
image #: 258 image number 2369616
image #: 259 image number 2375347
image #: 260 image number 2346285
image #: 261 image number 2407367
image #: 262 image number 2406260
image #: 263 image number 2332443
image #: 264 image number 2317355
image #: 265 image number 2369213
image #: 266 image number 2403989
image #: 267 i