OG Note: This demo is adapted from the LXMERT Demo present here: https://github.com/huggingface/transformers/tree/main/examples/research_projects/lxmert

JN Note: The bulk of the code was sourced from: https://github.com/huggingface/transformers/tree/main/examples/research_projects/visual_bert

In [1]:
# Basic packages
from IPython.display import Image, display
import PIL.Image
from PIL import Image
import io
import torch
import numpy as np
import json
import glob
import re
from os.path import exists
from tqdm import tqdm
import time
import random

random.seed(1)

# Model
from transformers import ViltProcessor, ViltForQuestionAnswering

### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH
output_title_base='/u/scr/nlp/data/nli-consistency/vilt_results/vilt-test-3pred-40token-1seed_predictions'

### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH
# Data from https://visualgenome.org/api/v0/api_home.html
# Used Version 1.2 as that one had Part 1 and Part 2 (selected Part 1)
VG_path = '/u/scr/nlp/data/nli-consistency/vg_data/VG_100K/'
VG_path2 = '/u/scr/nlp/data/nli-consistency/vg_data/VG_100K_2/'

### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH
# L-ConVQA from https://arijitray1993.github.io/ConVQA/
file_source='/u/scr/nlp/data/nli-consistency/cleaned_Logical_ConVQA_test.json'
file_paths=[]
with open(file_source, 'r') as stream:
    data = json.load(stream)
    
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

In [2]:
def convert_multiple_answers(predictions, model):
    return [model.config.id2label[predictions[i].item()] for i in range(len(predictions))]

In [3]:
# Vilt
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)

In [4]:
# Partially based on code from :
# https://huggingface.co/docs/transformers/model_doc/visual_bert

predictions = {}
num_predictions = 3

# for key in cleaned_data.keys():
#     image_num = key
#     image_path = VG_path + str(image_num) + '.jpg'
    
#     qas = cleaned_data[key]

counter = 0

no_image = []

start = time.process_time()

# It takes 4.1 seconds for 20 images; 5 images per second
data_keys = list(data.keys())
for i in tqdm(range(len(data_keys))):
    key = data_keys[i]
    image_num = key
    
    image_path1 = VG_path + str(image_num) + '.jpg'
    image_path2 = VG_path2 + str(image_num) + '.jpg'
    
    if exists(image_path1):
        image_path = image_path1
    elif exists(image_path2):
        image_path = image_path2
    else:
        no_image.append(image_num)
        continue
    
    # From the main demo
    image = Image.open(image_path).convert('RGB')
    
    predictions[image_num] = {}
    
    qas = data[key]
    
    # Sets of consistent questions
    for j in range(len(qas)):
        predictions[image_num][j] = []
        for k in range(len(qas[j])):
            qs_only = qas[j][k]['question']
            image
            encoding = processor(
                image,
                qs_only,
                return_tensors='pt', 
                padding = "max_length", 
                max_length=40,
                truncation = True, 
                add_special_tokens = True, 
                return_token_type_ids=True, 
                return_attention_mask=True,
            ).to(device)

            # forward pass
            outputs = model(**encoding)
            logits = outputs.logits
            
            l_softmax = logits.softmax(dim = -1)
            l_sorted, l_indices = torch.sort(l_softmax, dim=-1, descending=True)
            
            predictions[image_num][j].append(qas[j][k]|{'prediction':convert_multiple_answers(l_indices[0][0:num_predictions], model), 'prob':l_sorted[0][0:num_predictions].tolist()})
        
    counter += 1    
    
total_time = time.process_time() - start

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 725/725 [04:09<00:00,  2.91it/s]


In [5]:
correct = 0
incorrect = 0

qa_correct = []
qa_incorrect = []

for key in predictions.keys():
    for set_num in predictions[key].keys():
        for qas in predictions[key][set_num]:
            if qas['answer'] != qas['prediction'][0]:
                incorrect += 1
                qa_incorrect.append((qas['answer'], qas['prediction'][0]))
            else:
                correct += 1
                qa_correct.append((qas['answer'], qas['prediction'][0]))

accuracy = float(correct) / (float(correct) + float(incorrect))

accuracy

0.7840319952599615

In [6]:
with open(output_title_base + '.json', 'w') as f:
    json.dump(predictions, f)
with open(output_title_base + '.txt', 'w') as f:
    json.dump({'accuracy: ':accuracy, 
               'qa_correct: ':len(qa_correct), 
               'qa_incorrect: ':len(qa_incorrect), 
               'no_image: ':len(no_image),
               'total_questions: ':len(qa_correct) + len(qa_incorrect) + len(no_image),
               'time_taken: ':total_time,
              }, f)

In [7]:
accuracy

0.7840319952599615

In [None]:
predictions

In [8]:
total_time

495.394285222