In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision as tv
import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision import tv_tensors  # we'll describe this a bit later, bare with us

import torchvision.datasets as datasets
from pathlib import Path

from torchview import draw_graph
from pathlib import Path

import constants
import dataset
import util
import json
import pandas as pd
import models 
from models import VQANet
import matplotlib.pyplot as plt
import numpy as np
import time
import gc
from datetime import datetime

from transformers import AutoTokenizer
import traceback

USE_GPU = True
dtype = torch.float32 # We will be using float throughout this tutorial.

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():      
    device = 'mps'                         
else:
    device = torch.device('cpu')
from resumable_dataset import ResumableDataset


print('using device:', device)
    



In [None]:

%load_ext autoreload
%autoreload 2

In [None]:
test = dataset.Coco("test")
val = dataset.Coco("validation")


In [None]:
print(len(val))
print(len(test))

In [None]:
if False: # debug
    img = train.__getitem__(1)
    print(img)
    print(img.image_id)
    print(img.image_path)

    print(">>>>")
    print(img.captions())

    print(">>>>")
    print(img.qa())
    print("shape", img.image_tensor().shape)

    show([img.image_tensor()])

#plt.imshow(  img.image_tensor().permute(1, 2, 0)  )


In [None]:
tokenizer  = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# Add the Q and A token as special token
tokenizer.add_special_tokens(constants.QA_TOKEN_DICT)

In [None]:
from imagedata import get_image

def path_to_image(paths, device):
    result = []
    for p in paths:
        result.append(get_image(path, device))
    image = torch.stack(result, dim = 0).to(device)
    print("image shape", image.shape)
    return image
    
def collate_fn2(batch):
    result = {}
    
    result['image_ids'] = []
    result['image_paths'] = []
    result['c2i'] = [] # index for images for a given caption. same len as 'caption'
    result['qa2i'] = [] # index of corresponding image for a given qa. same len as 'qa'
    result['q_id'] = [] # id of the questions 'qa'
    
    
    target  = [] # the corresponding target for the qa.
    raw_captions = []  # plain text 
    raw_qa = []   # plain text
    raw_qids = []   # question ids
    for idx, data in enumerate(batch):
        result['image_ids'].append(data.image_id)
        result['image_paths'].append(data.image_path)
        caption_list = data.captions()
        if caption_list is not None:
            raw_captions += caption_list
            for c in range(len(caption_list)):
                result['c2i'].append(idx)
        
        qa_list = data.qa()
        q_id_list = data.qids()
        if qa_list is not None:
            raw_qa += qa_list
            raw_qids += q_id_list
            for c in range(len(qa_list)):
                result['qa2i'].append(idx)
    #print("raw_cap", len(raw_captions))
    #print("raw_qa", len(raw_qa))
    
    result['raw_cap'] = raw_captions
    result['captions'] = None if len(raw_captions) == 0 else \
                                tokenizer(raw_captions, padding='max_length',truncation=True,  max_length=32, return_tensors="pt").to(device)
#    print('captions:', result['captions'])
    result['raw_qa'] = raw_qa
    result['qids'] = raw_qids
    if len(raw_qa) != 0:
#        print("raw_qa:", raw_qa)
        result['qa'] =  tokenizer(raw_qa, padding='max_length', truncation=True, max_length=32, return_tensors="pt")['input_ids'].to(device, dtype=torch.int64)
#        print('qa:', result['qa'])

        end_padding = torch.broadcast_to(torch.zeros(1), (result['qa'].shape[0], 1)).to(device, dtype=torch.int64)
        #print(end_padding.shape)
        # return a shape {seq, batch}
        target = torch.column_stack((result['qa'][:, 1:], end_padding)).transpose(0, 1)
    else:
        result['qa'] = None
        target = None
    return result, target


In [None]:
from torch.utils.data import DataLoader
batch_size = 32
fn = collate_fn2 
shuffle = False
test_dataloader = DataLoader(test, batch_size=batch_size, shuffle=shuffle, collate_fn=fn)
#val_dataloader = DataLoader(val, batch_size=batch_size, shuffle=shuffle, collate_fn=fn)

In [None]:
def get_answers(x):
    qa = x["qa"]
    answer_token_id = tokenizer.convert_tokens_to_ids(constants.ANSWER_TOKEN)

    answer_start =  (qa == answer_token_id).nonzero()
    mask = torch.zeros_like(qa)
    mask[answer_start[:, 0], answer_start[:, 1]] = 1
    # fill the elements after the [ANSWER] token to be 1.
    mask = mask.cumsum(dim=1)
    just_answers = qa * mask
    return tokenizer.batch_decode(just_answers, skip_special_tokens = True)
    

In [None]:
def create_output(x, answers):
    result = []
    qids = x['qids']
    assert len(qids) == len(answers)
    for i in range(len(qids)):
        d = {"question_id" : qids[i], "answer": answers[i]}
        result.append(d)
    return result

In [None]:
import json

def run_test(name, dataloader , model, early_terminate = None,
             sync_after_every_n = 200, print_every = 10, pre_result = []):
    with torch.no_grad():
        model.eval()
        result = pre_result
        print("loading preresults:", len(result))
        for idx, (x, target) in enumerate(dataloader):
            should_print = print_every is not None and (print_every == 1 or idx % (print_every -1) == 0)
            if should_print:
                print(">>>> Batch # ", idx,  x['image_ids'] )
            if early_terminate is not None:
                if idx > early_terminate - 1:
                    print("early terminating. at ", idx)
                    break;
            image_embedding_for_captions, captions_embedding, output_logits = None, None, None

            try:
                answers = model.answer(x, device, max_length = 30)
                real_answers = get_answers(answers)
                out = create_output(x, real_answers)
                if should_print:
                    print(">>>> Batch Output # ", out)

                result += out
                if should_print:
                    print("result size:", len(result))

            except Exception as e:
                print(">>>> FAILED! Batch # ", idx,  x['image_ids'])
                print(f"current mps allocated memory: {torch.mps.current_allocated_memory()}")
                print(f"current mps driver allocated memory: {torch.mps.driver_allocated_memory()}")
                traceback.print_exc()
                return
                
            if sync_after_every_n is not None and (idx + 1) % sync_after_every_n == 0:
                print("=========== mps sync, gc, and mps empty cache and reload ==========")
                print(f"current mps allocated memory: {torch.mps.current_allocated_memory()}")
                print(f"current mps driver allocated memory: {torch.mps.driver_allocated_memory()}")
                gc.collect()
                torch.mps.synchronize()
                torch.mps.empty_cache()
                torch.mps.synchronize()
                print("after empyt cache")
                print(f"current mps allocated memory: {torch.mps.current_allocated_memory()}")
                print(f"current mps driver allocated memory: {torch.mps.driver_allocated_memory()}")
                print(">>>> Batch # ", idx, " >>>> Storing JSON data ending with result:", len(result) )
                with open(constants.TEST_OUTPUT.joinpath(name), 'w+') as f:
                    json.dump(result, f)
        with open(constants.TEST_OUTPUT.joinpath(name), 'w+') as f:
            json.dump(result, f)

In [None]:
from torch.utils.tensorboard import SummaryWriter

def reload_model(lr, name, use_captions):
    model = VQANet(tokenizer, use_captions).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    if name is not None:
        checkpoint = torch.load(constants.MODEL_OUT_PATH.joinpath(name))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.to(device)
    return model, optimizer

if False:
    model_name = "train-vqa_with_caption_Jun02_17-27-12-epoch-0-batch-1949"
    model, _ = reload_model(0.1, model_name, use_captions= False)
    run_test(f'val_first_50_model_shuffle_{model_name}', val_dataloader, model, early_terminate= 50, print_every = 10, sync_after_every_n = 50)

In [None]:

if False:
    model_name = "train-vqa_no_caption_Jun03_02-34-26"
    model, _ = reload_model(0.1, model_name, use_captions= False)
    run_test(f'val_first_50_model_shuffle_{model_name}', val_dataloader, model, early_terminate= 50, print_every = 10, sync_after_every_n = 50)

In [None]:
if False:
    model_name = "train-vqa_with_caption_Jun02_17-27-12-epoch-0-batch-1949"
    model, _ = reload_model(0.1, model_name, use_captions= False)
    run_test(f'val_200_model_shuffle_{model_name}', val_dataloader, model, early_terminate= 200, print_every = 10, sync_after_every_n = 50)

In [None]:

if False:
    model_name = "train-vqa_no_caption_Jun03_02-34-26"
    model, _ = reload_model(0.1, model_name, use_captions= False)
    run_test(f'val_200_model_shuffle_{model_name}', val_dataloader, model, early_terminate= 200, print_every = 10, sync_after_every_n = 50)

In [None]:
if False:
    et = 1000
    model_name = "train-vqa_with_caption_Jun02_17-27-12-epoch-0-batch-1949"
    model, _ = reload_model(0.1, model_name, use_captions= False)
    run_test(f'val_{et}_model_shuffle_{model_name}', val_dataloader, model,
             early_terminate= et, print_every = 10, sync_after_every_n = 50)

In [None]:
if False:
    et = 5000
    model_name = "train-vqa_no_caption_Jun03_02-34-26"
    model, _ = reload_model(0.1, model_name, use_captions= False)
    run_test(f'val_{et}_model_shuffle_{model_name}', val_dataloader, model,
             early_terminate= et, print_every = 10, sync_after_every_n = 50)

In [None]:
if False: # Done
    et = None
    dataloader = test_dataloader
    model_name = "train-vqa_with_caption_Jun02_17-27-12-epoch-0-batch-1949"
    model, _ = reload_model(0.1, model_name, use_captions= False)
    run_test(f'test_{et}_model_shuffle_{model_name}', dataloader, model,
             early_terminate= et, print_every = 10, sync_after_every_n = 50)

In [None]:
if False:  # Done
    et = None
    model_name = "train-vqa_with_caption_Jun02_17-27-12-epoch-0-batch-1949"
    test = dataset.VqaCoco()
    dataset = ResumableDataset(test, existing_answer_name = "test_None_model_shuffle_train-vqa_with_caption_Jun02_17-27-12-epoch-0-batch-1949")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=fn)
    model, _ = reload_model(0.1, model_name, use_captions= False)
    run_test(f'test_{et}_model_shuffle_{model_name}', dataloader, model,
            early_terminate= et, print_every = 10, sync_after_every_n = 50, pre_result = dataset.answers)
    

In [None]:
if False:  # TODO turn this back to true
    et = None
    dataloader = test_dataloader
    model_name = "train-vqa_no_caption_Jun03_02-34-26"
    model, _ = reload_model(0.1, model_name, use_captions= False)
    run_test(f'test_{et}_model_shuffle_{model_name}', dataloader, model,
             early_terminate= et, print_every = 10, sync_after_every_n = 50)

In [22]:
if True: 
    et = None
    model_name = "train-vqa_no_caption_Jun03_02-34-26"
    test = dataset.VqaCoco()
    dataset = ResumableDataset(test, existing_answer_name = "test_None_model_shuffle_train-vqa_no_caption_Jun03_02-34-26")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=fn)
    model, _ = reload_model(0.1, model_name, use_captions= False)
    run_test(f'test_{et}_model_shuffle_{model_name}', dataloader, model,
            early_terminate= et, print_every = 10, sync_after_every_n = 50, pre_result = dataset.answers)

>>>> Batch Output #  [{'question_id': 228740000, 'answer': 'red'}, {'question_id': 228740001, 'answer': '2'}, {'question_id': 393022001, 'answer': '2'}, {'question_id': 142610004, 'answer': 'cow'}, {'question_id': 142610001, 'answer': 'cow'}, {'question_id': 454149013, 'answer': 'water'}, {'question_id': 454149018, 'answer': '2'}, {'question_id': 454149056, 'answer': 'yes'}, {'question_id': 454149004, 'answer': '2'}, {'question_id': 454149065, 'answer': 'water'}, {'question_id': 454149017, 'answer': '2'}, {'question_id': 454149050, 'answer': 'yes'}, {'question_id': 454149010, 'answer': 'yes'}, {'question_id': 454149080, 'answer': 'red'}, {'question_id': 454149057, 'answer': 'yes'}, {'question_id': 454149038, 'answer': 'giraffe'}, {'question_id': 454149011, 'answer': 'yes'}, {'question_id': 454149047, 'answer': '2'}, {'question_id': 454149070, 'answer': 'yes'}, {'question_id': 454149086, 'answer': 'water'}, {'question_id': 454149062, 'answer': 'in water'}, {'question_id': 454149032, 'an

>>>> Batch #  639 [195757, 179439, 311200, 151258, 100009, 243493, 213038, 254580, 393742, 27774, 12165, 471614, 113865, 113865, 303669, 44050, 362494, 151723, 373545, 231985, 393058, 297923, 245606, 178043, 341826, 259064, 93315, 307131, 373800, 126782, 126782, 216701]
>>>> Batch Output #  [{'question_id': 195757002, 'answer': 'umbrella'}, {'question_id': 195757001, 'answer': 'skateboarding'}, {'question_id': 195757003, 'answer': 'yes'}, {'question_id': 195757006, 'answer': '2'}, {'question_id': 179439001, 'answer': 'yes'}, {'question_id': 311200008, 'answer': 'eating'}, {'question_id': 311200018, 'answer': 'banana'}, {'question_id': 311200019, 'answer': 'yes'}, {'question_id': 311200016, 'answer': 'table'}, {'question_id': 311200031, 'answer': 'phone'}, {'question_id': 311200007, 'answer': 'yes'}, {'question_id': 151258004, 'answer': 'white'}, {'question_id': 151258002, 'answer': 'skateboard'}, {'question_id': 100009000, 'answer': '2'}, {'question_id': 243493018, 'answer': 'yes'}, {'

>>>> Batch #  648 [450036, 97034, 173857, 302804, 337474, 576553, 56793, 492046, 492046, 381951, 400096, 88079, 316379, 199361, 274115, 105275, 510108, 315335, 42485, 102491, 544275, 103672, 157879, 126990, 563715, 523386, 463062, 499557, 556294, 176199, 157031, 44712]
>>>> Batch Output #  [{'question_id': 450036033, 'answer': 'banana'}, {'question_id': 450036009, 'answer': 'yes'}, {'question_id': 450036036, 'answer': 'yes'}, {'question_id': 450036035, 'answer': 'yes'}, {'question_id': 450036007, 'answer': 'table'}, {'question_id': 450036031, 'answer': 'nothing'}, {'question_id': 450036020, 'answer': 'banana'}, {'question_id': 450036000, 'answer': '2'}, {'question_id': 450036018, 'answer': 'nothing'}, {'question_id': 450036019, 'answer': 'banana'}, {'question_id': 97034002, 'answer': '2'}, {'question_id': 173857000, 'answer': 'on road'}, {'question_id': 173857006, 'answer': '2'}, {'question_id': 173857004, 'answer': 'red'}, {'question_id': 173857009, 'answer': 'red'}, {'question_id': 3

current mps allocated memory: 2324486400
current mps driver allocated memory: 31682494464
after empyt cache
current mps allocated memory: 2324486400
current mps driver allocated memory: 15920300032
>>>> Batch #  649  >>>> Storing JSON data ending with result: 179083
>>>> Batch #  657 [58293, 396377, 455115, 548106, 454278, 63802, 222173, 437505, 209342, 136629, 183547, 56880, 250438, 225590, 534654, 135217, 13735, 130208, 463808, 138660, 317599, 454188, 23973, 415698, 360023, 318441, 394628, 242360, 184936, 374459, 27187, 437606]
>>>> Batch Output #  [{'question_id': 58293004, 'answer': 'yes'}, {'question_id': 396377003, 'answer': 'yes'}, {'question_id': 455115004, 'answer': 'yes'}, {'question_id': 455115011, 'answer': 'yes'}, {'question_id': 455115000, 'answer': 'yes'}, {'question_id': 455115006, 'answer': 'yes'}, {'question_id': 455115012, 'answer': 'yes'}, {'question_id': 548106002, 'answer': 'white'}, {'question_id': 454278010, 'answer': 'yes'}, {'question_id': 454278020, 'answer':

>>>> Batch #  675 [334369, 227387, 318927, 439376, 254321, 100574, 305426, 229768, 515223, 315716, 266194, 158257, 323383, 465382, 329774, 209417, 555466, 555466, 282739, 350673, 290271, 67610, 287728, 418542, 557332, 472458, 8129, 49584, 307485, 467185, 307485, 563893]
>>>> Batch Output #  [{'question_id': 334369000, 'answer': 'yes'}, {'question_id': 227387002, 'answer': 'yes'}, {'question_id': 227387007, 'answer': 'yes'}, {'question_id': 318927015, 'answer': 'yes'}, {'question_id': 318927002, 'answer': 'yes'}, {'question_id': 439376007, 'answer': 'yes'}, {'question_id': 439376053, 'answer': '2'}, {'question_id': 439376066, 'answer': 'tile'}, {'question_id': 439376062, 'answer': 'bed'}, {'question_id': 439376021, 'answer': 'living room'}, {'question_id': 439376046, 'answer': 'yes'}, {'question_id': 439376015, 'answer': 'yes'}, {'question_id': 439376044, 'answer': '2'}, {'question_id': 439376051, 'answer': 'yes'}, {'question_id': 439376008, 'answer': 'yes'}, {'question_id': 439376050, 

>>>> Batch #  684 [68400, 253191, 493396, 296841, 75366, 198966, 350760, 573350, 437679, 437679, 437679, 461184, 572773, 326241, 451838, 326241, 433394, 277159, 458440, 546536, 120878, 469528, 527673, 69845, 557007, 124586, 381094, 491923, 492577, 67863, 548531, 539248]
>>>> Batch Output #  [{'question_id': 68400012, 'answer': 'yes'}, {'question_id': 68400000, 'answer': 'tennis'}, {'question_id': 68400004, 'answer': 'standing'}, {'question_id': 68400009, 'answer': '4'}, {'question_id': 68400008, 'answer': 'tennis'}, {'question_id': 253191005, 'answer': 'yes'}, {'question_id': 253191006, 'answer': 'yes'}, {'question_id': 253191011, 'answer': 'yes'}, {'question_id': 253191008, 'answer': 'yes'}, {'question_id': 253191018, 'answer': '4'}, {'question_id': 493396008, 'answer': 'bus'}, {'question_id': 493396001, 'answer': 'red'}, {'question_id': 493396005, 'answer': 'yes'}, {'question_id': 493396007, 'answer': 'yes'}, {'question_id': 296841001, 'answer': 'yes'}, {'question_id': 296841002, 'an

>>>> Batch #  693 [93315, 435597, 408380, 229768, 462128, 363029, 440203, 565237, 193563, 447635, 448790, 357110, 293234, 34813, 516817, 417073, 499099, 243794, 106509, 418286, 52304, 388306, 504913, 94032, 557707, 517475, 137987, 568844, 30081, 377827, 307753, 467134]
>>>> Batch Output #  [{'question_id': 93315076, 'answer': 'stop'}, {'question_id': 93315052, 'answer': 'yes'}, {'question_id': 93315046, 'answer': 'stop'}, {'question_id': 93315047, 'answer': 'red'}, {'question_id': 93315044, 'answer': 'yes'}, {'question_id': 93315039, 'answer': 'yes'}, {'question_id': 93315057, 'answer': '2'}, {'question_id': 93315056, 'answer': 'yes'}, {'question_id': 93315028, 'answer': 'yes'}, {'question_id': 93315089, 'answer': 'left'}, {'question_id': 93315063, 'answer': 'yes'}, {'question_id': 93315041, 'answer': 'yes'}, {'question_id': 93315086, 'answer': '12 : 10'}, {'question_id': 93315083, 'answer': 'red'}, {'question_id': 93315023, 'answer': 'yes'}, {'question_id': 93315064, 'answer': 'yes'},

current mps allocated memory: 2324026624
current mps driver allocated memory: 30372364288
after empyt cache
current mps allocated memory: 2324026624
current mps driver allocated memory: 15906209792
>>>> Batch #  699  >>>> Storing JSON data ending with result: 186116
>>>> Batch #  702 [382076, 311609, 489390, 489390, 478601, 205426, 53441, 348714, 299814, 100609, 179694, 419841, 413412, 376120, 55283, 109014, 11738, 321738, 94432, 151446, 11738, 122449, 456402, 92893, 360930, 160431, 455837, 518616, 165902, 75191, 540020, 88196]
>>>> Batch Output #  [{'question_id': 382076008, 'answer': 'on top'}, {'question_id': 382076007, 'answer': 'yes'}, {'question_id': 311609010, 'answer': 'yes'}, {'question_id': 311609007, 'answer': 'yes'}, {'question_id': 311609003, 'answer': 'yes'}, {'question_id': 489390005, 'answer': 'banana?沢 [unused199] [unused199] [unused199]'}, {'question_id': 489390001, 'answer': 'yes? accused [unused199] [unused199] [unused199]'}, {'question_id': 489390005, 'answer': 'ba

>>>> Batch #  720 [41134, 390103, 415989, 502802, 156419, 132118, 477856, 340031, 216561, 574072, 249312, 533421, 235932, 163188, 382944, 203476, 536076, 549859, 505525, 371181, 477981, 354262, 172414, 134191, 209332, 46208, 330268, 355313, 401964, 325881, 431348, 431348]
>>>> Batch Output #  [{'question_id': 41134010, 'answer': 'yes'}, {'question_id': 41134011, 'answer': '2'}, {'question_id': 41134005, 'answer': 'stop'}, {'question_id': 41134008, 'answer': 'red'}, {'question_id': 41134018, 'answer': 'stop'}, {'question_id': 390103010, 'answer': 'bus'}, {'question_id': 390103018, 'answer': 'yes'}, {'question_id': 415989003, 'answer': '2'}, {'question_id': 415989001, 'answer': 'yes'}, {'question_id': 415989011, 'answer': 'yes'}, {'question_id': 415989014, 'answer': 'bus'}, {'question_id': 415989010, 'answer': 'yes'}, {'question_id': 502802002, 'answer': 'yes'}, {'question_id': 156419002, 'answer': 'right'}, {'question_id': 156419000, 'answer': 'yes'}, {'question_id': 156419009, 'answer'

In [None]:
stuff = iter(dataloader)
print(next(stuff))

In [None]:
print(dataset.missing)
dataset.missing.iloc[0]['question_id']

In [None]:
standard_test = dataset.VqaCoco("test", use_standard = True)


In [None]:
print(standard_test.image_paths[0])

In [None]:
import imagedata
data = standard_test.__getitem__(0)
print(imagedata.to_dict(data))


In [None]:
image_ids = test.questions['image_id'].unique()
print(len(image_ids))

In [None]:
import os

def check_images(image_in_question, coco):
    image_ids = set(map(lambda image_path: int(os.path.basename(image_path).removesuffix('.jpg')), coco.image_paths))
    print(">>> 1:", next(iter(image_ids)))
    diff =  set(image_in_question.flatten()).difference(image_ids)
    print("existing diff: ", len(diff))

check_images(test.questions['image_id'].unique(), test)

In [None]:
print(len(test.image_paths))