OG Note: This demo is adapted from the LXMERT demo available here: https://github.com/huggingface/transformers/tree/main/examples/research_projects/lxmert

JN Note: The bulk of the code was sourced from: https://github.com/huggingface/transformers/tree/main/examples/research_projects/visual_bert

In [None]:
# Basic packages
from IPython.display import Image, display
import PIL.Image
import io
import torch
import numpy as np
import json
import glob
import re
from os.path import exists
from tqdm import tqdm
import time
import random

random.seed(1)

# Local packages
from processing_image import Preprocess
from visualizing_image import SingleImageViz
from modeling_frcnn import GeneralizedRCNN
from utils import Config
import utils

# Model
from transformers import LxmertForQuestionAnswering, LxmertTokenizer

### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH FOR PREDICTION
# Base path (no extension) for the outputs of this notebook: the predictions are
# written to `<base>.json` and the summary metrics to `<base>.txt`.
output_title_base='/u/scr/nlp/data/nli-consistency/vqa-camera/lxmert_results/lxmert-run-train-10000im-3pred-40token-1seed_predictions'

### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH FOR PREDICTION
# Data from https://visualgenome.org/api/v0/api_home.html
# Used Version 1.2 as that one had Part 1 and Part 2 (selected Part 1)
# Images are looked up in VG_path first, then in VG_path2.
VG_path = '/u/scr/nlp/data/nli-consistency/vg_data/VG_100K/'
VG_path2 = '/u/scr/nlp/data/nli-consistency/vg_data/VG_100K_2/'

### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH FOR PREDICTION
# L-ConVQA from https://arijitray1993.github.io/ConVQA/
# `file_source` lists one per-image QA file path per line; every path must
# contain `qas_<image number>.json`.
file_source='/u/scr/nlp/data/nli-consistency/vg-data-2022-06-13-16:03:37-n=10000-seed=1.txt'
with open(file_source, 'r', encoding='utf-8') as stream:
    # Strip only the line terminator. The previous `line[:-1]` chopped a real
    # character off the last path when the file had no trailing newline.
    # Blank lines are skipped so they cannot reach the regex below.
    file_paths = [line.rstrip('\r\n') for line in stream if line.strip()]


def _image_number(path):
    """Return the image number embedded as `qas_<number>.json` in *path*."""
    match = re.search(r'qas_(.*?)\.json', path)
    if match is None:
        raise ValueError(f'Cannot extract an image number from QA file path: {path!r}')
    return match.group(1)


file_names = [_image_number(path) for path in file_paths] # Gets the image number

### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH FOR PREDICTION
# Test set
cleaned_path = '/u/scr/nlp/data/nli-consistency/cleaned_Logical_ConVQA_test.json'
with open(cleaned_path, 'r', encoding='utf-8') as r_file:
    cleaned_data = json.load(r_file)

# Index of the current CUDA device when a GPU is available, otherwise 'cpu'.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

### INSTRUCTION FOR USERS : INDICATE APPROPRIATE PATH FOR PREDICTION
# VQA_URL = "https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json"
VQA_URL = '/u/scr/nlp/data/nli-consistency/trainval_label2ans.json' # downloaded version from above

In [None]:
# load answer labels
# NOTE(review): presumably a list whose i-th entry is the answer string for
# LXMERT answer label i (file name `trainval_label2ans.json`) — confirm.
# VQA_URL is a local path here; utils.get_data is assumed to accept it — confirm.
vqa_answers = utils.get_data(VQA_URL)

def convert_multiple_answers(predictions, vqa_answers):
    """Map predicted answer-label indices to their answer strings.

    Args:
        predictions: iterable of integer label indices, e.g. a 1-D integer
            tensor holding the top-k label ids for one question.
        vqa_answers: lookup indexed by label id (a list, or a dict keyed by int).

    Returns:
        List of answer strings in the same order as ``predictions``.
    """
    # int() turns 0-d tensors / numpy integers into plain ints. A tensor used
    # directly as a dict key hashes by identity and never matches an int key.
    return [vqa_answers[int(label)] for label in predictions]

In [None]:
# load models and model components
frcnn_cfg = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
frcnn_cfg.model.device = device

frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=frcnn_cfg)
image_preprocess = Preprocess(frcnn_cfg)

lxmert_tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased").to(device)

In [None]:
# Report the device used for inference. `device` already falls back to 'cpu';
# torch.cuda.current_device() would raise on a machine without CUDA.
print(device)

In [None]:
# Partially based on code from :
# https://huggingface.co/docs/transformers/model_doc/visual_bert
#
# For every image listed in `file_paths`:
#   1. extract region features with the Faster R-CNN;
#   2. answer each set of consistent questions with LXMERT;
#   3. keep the `num_predictions` most probable answers per question.
#
# Result: predictions[image_num][set_index] is the list of that set's QA dicts,
# each extended with 'prediction' (top answers, most probable first) and
# 'prob' (their softmax probabilities). Images found in neither VG folder are
# recorded in `no_image` and skipped.
# (An earlier variant iterated over `cleaned_data` instead of `file_paths`.)

predictions = {}
num_predictions = 3
max_question_tokens = 40  # questions are padded/truncated to this many tokens
max_images = None  # e.g. 1000 for a quick partial run; None processes every image

counter = 0  # number of images actually processed

no_image = []

# perf_counter() measures elapsed wall-clock time. process_time() would report
# CPU time summed over all of the process's threads instead.
start = time.perf_counter()

image_items = list(zip(file_paths, file_names))
if max_images is not None:
    image_items = image_items[:max_images]

# Earlier measurement: about 4.1 seconds for 20 images (~5 images per second).
# Inference only: no_grad avoids building an autograd graph for every forward pass.
with torch.no_grad():
    for file_path, image_num in tqdm(image_items):
        image_path1 = VG_path + str(image_num) + '.jpg'
        image_path2 = VG_path2 + str(image_num) + '.jpg'

        if exists(image_path1):
            image_path = image_path1
        elif exists(image_path2):
            image_path = image_path2
        else:
            no_image.append(image_num)
            continue

        # From the main demo
        images, sizes, scales_yx = image_preprocess(image_path)
        output_dict = frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=frcnn_cfg.max_detections,
            return_tensors="pt",
        )

        # Move this image's visual inputs to the device once, not once per question set.
        normalized_boxes = output_dict.get("normalized_boxes").to(device)
        features = output_dict.get("roi_features").to(device)
        # From the main demo

        predictions[image_num] = {}

        with open(file_path, 'r', encoding='utf-8') as file_stream:
            qas = json.load(file_stream)['consistent']

        # Each entry of `qas` is one set of consistent questions about this image.
        for set_index, qa_set in enumerate(qas):
            if not qa_set:
                # Nothing to tokenize; keep an empty result for this set.
                predictions[image_num][set_index] = []
                continue

            questions = [qa['question'] for qa in qa_set]
            inputs = lxmert_tokenizer(
                questions,
                return_tensors='pt',
                padding="max_length",
                max_length=max_question_tokens,
                truncation=True,
                add_special_tokens=True,
                return_token_type_ids=True,
                return_attention_mask=True,
            )
            batch_size = len(inputs.input_ids)

            output_vqa = lxmert_vqa(
                input_ids=inputs.input_ids.to(device),
                attention_mask=inputs.attention_mask.to(device),
                # Every question in the batch is about the same image.
                visual_feats=features.expand(batch_size, -1, -1),
                visual_pos=normalized_boxes.expand(batch_size, -1, -1),
                token_type_ids=inputs.token_type_ids.to(device),
                output_attentions=False,
            )

            answer_probs = output_vqa['question_answering_score'].softmax(dim=-1)
            # topk returns the k largest probabilities in descending order. That
            # is the same head a full descending sort gives, without sorting
            # every answer label.
            top_k = min(num_predictions, answer_probs.shape[-1])
            top_probs, top_labels = torch.topk(answer_probs, k=top_k, dim=-1)

            predictions[image_num][set_index] = [
                qa | {
                    'prediction': convert_multiple_answers(top_labels[k], vqa_answers),
                    'prob': top_probs[k].tolist(),
                }
                for k, qa in enumerate(qa_set)
            ]

        counter += 1

total_time = time.perf_counter() - start

In [None]:
# Top-1 accuracy: a question is correct when its most probable prediction
# exactly equals the annotated answer string.
qa_correct = []    # (answer, top prediction) pairs that match
qa_incorrect = []  # (answer, top prediction) pairs that differ

for question_sets in predictions.values():
    for qa_list in question_sets.values():
        for qa in qa_list:
            pair = (qa['answer'], qa['prediction'][0])
            if pair[0] == pair[1]:
                qa_correct.append(pair)
            else:
                qa_incorrect.append(pair)

correct = len(qa_correct)
incorrect = len(qa_incorrect)
total = correct + incorrect
# Avoid ZeroDivisionError when nothing was evaluated (e.g. no images were found).
accuracy = correct / total if total else float('nan')

accuracy

In [None]:
# Save the raw predictions (JSON turns the integer set indices into string
# keys) and a summary of the run. The summary keys keep their original
# spelling, including the trailing ': ', so existing readers keep working.
with open(output_title_base + '.json', 'w', encoding='utf-8') as f:
    json.dump(predictions, f)
with open(output_title_base + '.txt', 'w', encoding='utf-8') as f:
    json.dump({'accuracy: ': accuracy,
               'qa_correct: ': len(qa_correct),
               'qa_incorrect: ': len(qa_incorrect),
               'no_image: ': len(no_image),
               # Answered questions only. `no_image` counts skipped images,
               # not questions, so adding it here mixed units.
               'total_questions: ': len(qa_correct) + len(qa_incorrect),
               'time_taken: ': total_time,
              }, f)

In [None]:
# Display the overall top-1 accuracy (the value also written to the .txt summary).
accuracy

In [None]:
# Display the per-image predictions built by the prediction loop.
predictions

In [None]:
# Display the time taken by the prediction loop, in seconds.
total_time