In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision as tv
import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision import tv_tensors  # we'll describe this a bit later, bare with us

import torchvision.datasets as datasets
from pathlib import Path

from torchview import draw_graph
from pathlib import Path

import constants
import dataset
import util
import json
import pandas as pd
import models 
from models import VQANet
import matplotlib.pyplot as plt
import numpy as np
import time
import gc
from datetime import datetime

from transformers import AutoTokenizer
import traceback

USE_GPU = True
dtype = torch.float32 # We will be using float throughout this tutorial.

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():      
    device = 'mps'                         
else:
    device = torch.device('cpu')


device = torch.device('cpu')
# Constant to control how frequently we print train loss.
print_every = 100
print('using device:', device)
    



using device: cpu


In [3]:

%load_ext autoreload
%autoreload 2

In [4]:

def show(imgs):
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = T.ToPILImage()(img.to('cpu'))
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


In [5]:
# with open(constants.CAPTION_TRAIN, 'r') as f:
#     data = json.load(f)
#     print(data.keys())
#     print(data["annotations"][0])

# with open(constants.VQA_OPEN_ENDED_QUESTION_TRAIN, 'r') as f:
#     data = json.load(f)
#     print(data.keys())
#     print(data["questions"][0])

# with open(constants.VQA_OPEN_ENDED_ANSWER_TRAIN, 'r') as f:
#     data = json.load(f)
#     print(data.keys())
#     print(data["annotations"][0])
    
# with open(constants.CAPTION_VAL, 'r') as f:
#     data = json.load(f)
#     print(data.keys())

# with open(constants.VQA_OPEN_ENDED_QUESTION_VAL, 'r') as f:
#     data = json.load(f)
#     print(data.keys())

# with open(constants.VQA_OPEN_ENDED_ANSWER_VAL, 'r') as f:
#     data = json.load(f)
#     print(data.keys())

#dataset.load(constants.VQA_OPEN_ENDED_QUESTION_TRAIN, ['image_id', 'id', 'caption'])

In [6]:
train = dataset.Coco()
val = dataset.Coco("validation")
test = dataset.Coco("test")

Downloading split 'train' to '/Users/xiangyuliu/sources/fiftyone_dataset_zoo/coco-2017/train' if necessary
Found annotations at '/Users/xiangyuliu/sources/fiftyone_dataset_zoo/coco-2017/raw/instances_train2017.json'
Images already downloaded
Existing download of split 'train' is sufficient
Loading existing dataset 'coco-2017-train'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use
Downloading split 'validation' to '/Users/xiangyuliu/sources/fiftyone_dataset_zoo/coco-2017/validation' if necessary
Found annotations at '/Users/xiangyuliu/sources/fiftyone_dataset_zoo/coco-2017/raw/instances_val2017.json'
Images already downloaded
Existing download of split 'validation' is sufficient
Loading existing dataset 'coco-2017-validation'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use
Downloading split 'test' to '/Users/xiangyuliu/sources/fiftyone_dataset_zoo/coco-2017/test' if necessary
Found test 

In [7]:
print(len(train))
print(len(train.captions))

118287
591753


In [8]:
if False: # debug
    img = train.__getitem__(1)
    print(img)
    print(img.image_id)
    print(img.image_path)

    print(">>>>")
    print(img.captions())

    print(">>>>")
    print(img.qa())
    print("shape", img.image_tensor().shape)

    show([img.image_tensor()])

#plt.imshow(  img.image_tensor().permute(1, 2, 0)  )


In [9]:
tokenizer  = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# Add the Q and A token as special token
tokenizer.add_special_tokens(constants.QA_TOKEN_DICT)



4

In [10]:
from imagedata import get_image

def path_to_image(paths, device):
    result = []
    for p in paths:
        result.append(get_image(path, device))
    image = torch.stack(result, dim = 0).to(device)
    print("image shape", image.shape)
    return image
    
def collate_fn2(batch):
    result = {}
    
    result['image_ids'] = []
    result['image_paths'] = []
    result['c2i'] = [] # index for images for a given caption. same len as 'caption'
    result['qa2i'] = [] # index of corresponding image for a given qa. same len as 'qa'
    result['q_id'] = [] # id of the questions 'qa'
    
    
    target  = [] # the corresponding target for the qa.
    raw_captions = []  # plain text 
    raw_qa = []   # plain text
    raw_qids = []   # question ids
    for idx, data in enumerate(batch):
        result['image_ids'].append(data.image_id)
        result['image_paths'].append(data.image_path)
        caption_list = data.captions()
        if caption_list is not None:
            raw_captions += caption_list
            for c in range(len(caption_list)):
                result['c2i'].append(idx)
        
        qa_list = data.qa()
        q_id_list = data.qids()
        if qa_list is not None:
            raw_qa += qa_list
            raw_qids += q_id_list
            for c in range(len(qa_list)):
                result['qa2i'].append(idx)
    #print("raw_cap", len(raw_captions))
    #print("raw_qa", len(raw_qa))
    
    result['raw_cap'] = raw_captions
    result['captions'] = None if len(raw_captions) == 0 else \
                                tokenizer(raw_captions, padding=True , return_tensors="pt").to(device)
    result['raw_qa'] = raw_qa
    result['qids'] = raw_qids
    if len(raw_qa) != 0:
#        print("raw_qa:", raw_qa)
        result['qa'] =  tokenizer(raw_qa, padding=True , return_tensors="pt")['input_ids'].to(device, dtype=torch.int64)
        end_padding = torch.broadcast_to(torch.zeros(1), (result['qa'].shape[0], 1)).to(device, dtype=torch.int64)
        #print(end_padding.shape)
        # return a shape {seq, batch}
        target = torch.column_stack((result['qa'][:, 1:], end_padding)).transpose(0, 1)
    else:
        result['qa'] = None
        target = None
    return result, target


In [11]:
from torch.utils.data import DataLoader
batch_size = 5
fn = collate_fn2 
shuffle = False  # True
train_dataloader = DataLoader(train, batch_size=batch_size, shuffle=shuffle, collate_fn=fn)
val_dataloader = DataLoader(val, batch_size=batch_size, shuffle=shuffle, collate_fn=fn)
test_dataloader = DataLoader(test, batch_size=batch_size, shuffle=shuffle, collate_fn=fn)

In [12]:
def spot_check(dataloader, size=5):
    it = iter(dataloader)
    for _ in range(size):
        x, target= next(it)
        print(x)
        show(x["images"])

In [None]:
spot_check(train_dataloader)

In [None]:
spot_check(test_dataloader)

In [13]:
ce_fn = nn.CrossEntropyLoss( reduction='none')
cos_fn = nn.CosineSimilarity(dim=1)

In [None]:
# out = model(x, device)
# image_embedding, captions_embeddings, output_logits = out
# print(captions_embeddings.shape)
# a = output_logits.reshape(-1, len(tokenizer))
# b = target.reshape(-1)
# print("a", a.shape, a)
# print("b", b.shape, b)

# ce_loss = ce_fn(a, b)
# print(ce_loss.shape)
# N = len(x['images'])
# M = len(x['qa2i'])
# ce = ce_loss.reshape(-1, M).transpose(0, 1)
# print(ce.shape)
# print(ce)
# per_qa  = torch.mean(ce, axis = 1)
# print(per_qa.shape)

In [None]:
# blown = models.blow_to(image_embedding, result['c2i'])
# print(image_embedding.shape)
# print(image_embedding)
# print(blown.shape)
# print(blown)
# print("captions_embedding:", captions_embeddings.shape)
# print(result['c2i'])

In [None]:
# print(blown)
# print(captions_embeddings)
# cos= nn.CosineSimilarity(dim = 0)
# print(cos(blown[1], captions_embeddings[1]))

# per_caption_loss = cos_fn(blown, captions_embeddings)
# print(per_caption_loss)
# per_image_caption_loss = cal_average(len(result['images']), per_caption_loss, result['c2i'])
# print(per_image_caption_loss.shape)

# print(per_image_caption_loss)

In [33]:
def reload_model( lr, name):
    model = VQANet(tokenizer).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    if name is not None:
        checkpoint = torch.load(constants.MODEL_OUT_PATH.joinpath(name))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer

def save_model(model, optimizer, name):
    torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()},
        constants.MODEL_OUT_PATH.joinpath(name))

In [34]:
def cal_average(size, blown_loss, replicas):
    result= torch.zeros(size).to(device)
    counts = torch.zeros(size).to(device)
    for index, val in enumerate(replicas):
        result[val] += blown_loss[index]
        counts[val] += 1
        
    for index in range(size):
        if counts[index] == 0:
            counts[index] = 1  # so that result / counts still makes sense.
    #print("result", result)
    #print("counts:", counts)
    counts = counts.detach()  # we don't need gradient for the counts.
    result /= counts
    return result

In [37]:
import gc
gamma = 0.9
DEBUG = False
def do_train(model, optimizer, idx, x, target, should_print = False):
        N = len(x['image_ids'])
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        if DEBUG:
            image_embedding_for_captions, captions_embedding, output_logits = None, None, None
        else:
            image_embedding_for_captions, captions_embedding, output_logits  = model(x, device)

#        image_embedding, captions_embedding, output_logits = None, None, None
        per_image_qa_loss = None
        per_image_caption_loss = None
        
        if output_logits is not None:
#            print("out_logits argmax", torch.argmax(output_logits.transpose(0,1), axis=2))
#            print("target", target)
            a = output_logits.reshape(-1, len(tokenizer))

            b = target.reshape(-1)
#            print("a", a.shape)
#            print("b", b.shape)

            K = len(x['qa2i'])
            # back to (K, seq)
            qa_loss = ce_fn(a, b).reshape(-1, K).transpose(0, 1)
            #print("qa_loss", qa_loss.shape)
            # qa loss, shape of (K) (different images can have diff counts of qas)
            per_qa_loss = torch.mean(qa_loss, axis = 1)

            # per image qa loss, shape of (N)
            per_image_qa_loss = cal_average(N, per_qa_loss, x['qa2i'])
            #print("per_qa_loss", per_qa_loss.shape)
            #print("per_image_qa_loss", per_image_qa_loss.shape)

        if captions_embedding is not None:
            # loss per caption, shape of (M) (different images can have diff counts of captions)
            per_caption_loss = cos_fn(image_embedding_for_captions, captions_embedding)
            # cosine similarity is within [-1, 1] where 1 being similar. 
            # for loss, we invert it and shift it by 1 to keep the value always positive.
            # thus 0 means similar, 2 means completely opposite
            # print("per_captions_loss:", per_caption_loss)
            per_caption_loss = -per_caption_loss + 1
            # print("normalized per_captions_loss:", per_caption_loss)
            # per image loss on the caption scale. shape of (N)
            per_image_caption_loss = cal_average(N, per_caption_loss, x['c2i'])

        total_loss = 0
        if per_image_qa_loss is not None:
            total_loss += gamma * per_image_caption_loss
            
        if per_image_qa_loss is not None:
            total_loss += per_image_qa_loss

        loss = torch.sum(total_loss)

        if not DEBUG:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()


        del per_image_caption_loss
        del per_image_qa_loss
        del x
        del total_loss
        return loss
            
def training(writer, epoches, early_terminate = None, sync_after_every_n = 200, print_every = 100):
    lr = 0.1
    model, optimizer = reload_model(lr, None)
    
    start_time = time.time()
    model_name = 'train-' + datetime.now().strftime("%b%d_%H-%M-%S")
    
    model.train()
    for epoch_idx in range(epoches):
        print ("----- Start Epoch %s -----" % epoch_idx)
        epoch_loss = 0
        for idx, (x, target) in enumerate(train_dataloader):
            should_print = print_every is not None and (print_every == 1 or idx % (print_every -1) == 0)
            if should_print:
                print(">>>> Batch # ", idx,  x['image_ids'] )
            if early_terminate is not None:
                if idx > early_terminate - 1:
                    print("early terminating. at ", idx)
                    break;
            try:
                loss = do_train(model, optimizer, idx, x, target, should_print).detach()
                batch_loss = loss.item()
            except Exception as e:
                print(">>>> FAILED! Batch # ", idx,  x['image_ids'])
                traceback.print_exc()
                break;

            if sync_after_every_n is not None and (idx + 1) % sync_after_every_n == 0:
                print("=========== mps sync, gc, and mps empty cache and reload ==========")
                name = model_name + f"-epoch-{epoch_idx}-batch-{idx}"
                save_model(model, optimizer, name)
                model = None
                optimizer = None
                torch.mps.synchronize()
                gc.collect()
                torch.mps.empty_cache()
                model, optimizer = reload_model(lr, name)
                model.train()

            if should_print:
                print("loss:", batch_loss)
                print("--- %s Per batch time ---" % (time.time() - start_time))
                
            epoch_loss += batch_loss

        epoch_loss /= len(train_dataloader) if early_terminate is None \
                                            else (early_terminate * train_dataloader.batch_size)
        writer.add_scalar("Loss/train", epoch_loss, epoch_idx)

        print(f"---DONE: {epoch_idx} epoch, {(time.time() - start_time)} seconds, loss {epoch_loss} ---")
    return model
    

In [None]:
from torch.utils.tensorboard import SummaryWriter

current_time = datetime.now().strftime("%b%d_%H-%M-%S")
writer = SummaryWriter(Path.joinpath(constants.TB_OUT_PATH, "with_sync_" + current_time))

model = training(writer, 5, sync_after_every_n=5, print_every = 1,  early_terminate = 10)
writer.flush()


----- Start Epoch 0 -----
>>>> Batch #  0 [9, 25, 30, 34, 36]
loss: 58.73096466064453
--- 2.011552095413208 Per batch time ---
>>>> Batch #  1 [42, 49, 61, 64, 71]
loss: 45.55876159667969
--- 3.687784194946289 Per batch time ---
>>>> Batch #  2 [72, 73, 74, 77, 78]
loss: 33.52019500732422
--- 5.3537609577178955 Per batch time ---
>>>> Batch #  3 [81, 86, 89, 92, 94]
loss: 48.776939392089844
--- 7.022082090377808 Per batch time ---
>>>> Batch #  4 [109, 110, 113, 127, 133]
loss: 32.89650344848633
--- 12.380654096603394 Per batch time ---
>>>> Batch #  5 [136, 138, 142, 143, 144]
loss: 25.630126953125
--- 14.17305588722229 Per batch time ---
>>>> Batch #  6 [149, 151, 154, 164, 165]
loss: 28.86085319519043
--- 16.017032146453857 Per batch time ---
>>>> Batch #  7 [192, 194, 196, 201, 208]
loss: 15.156055450439453
--- 17.720378160476685 Per batch time ---
>>>> Batch #  8 [241, 247, 250, 257, 260]
loss: 21.24652862548828
--- 19.4218852519989 Per batch time ---
>>>> Batch #  9 [263, 283, 29

In [31]:

current_time = datetime.now().strftime("%b%d_%H-%M-%S")
writer = SummaryWriter(Path.joinpath(constants.TB_OUT_PATH, current_time))

#training(model, writer, 5, empty_catch_after_every_n=None, early_terminate = 10)
model = training(writer, 5, sync_after_every_n=None, print_every = 1, early_terminate = 10)
#training(model, writer, 2, empty_catch_after_every_n=None)
writer.flush()

----- Start Epoch 0 -----
>>>> Batch #  0 [9, 25, 30, 34, 36]
loss: 56.32360076904297
--- 2.0319719314575195 Per batch time ---
>>>> Batch #  1 [42, 49, 61, 64, 71]
loss: 44.15825271606445
--- 3.8284220695495605 Per batch time ---
>>>> Batch #  2 [72, 73, 74, 77, 78]
loss: 32.260902404785156
--- 5.541014194488525 Per batch time ---
>>>> Batch #  3 [81, 86, 89, 92, 94]
loss: 47.25675964355469
--- 7.24758505821228 Per batch time ---
>>>> Batch #  4 [109, 110, 113, 127, 133]
loss: 31.200536727905273
--- 8.984709978103638 Per batch time ---
>>>> Batch #  5 [136, 138, 142, 143, 144]
loss: 25.550800323486328
--- 10.736034154891968 Per batch time ---
>>>> Batch #  6 [149, 151, 154, 164, 165]
loss: 28.784555435180664
--- 12.54118800163269 Per batch time ---
>>>> Batch #  7 [192, 194, 196, 201, 208]
loss: 14.63433837890625
--- 14.258178949356079 Per batch time ---
>>>> Batch #  8 [241, 247, 250, 257, 260]
loss: 21.023765563964844
--- 15.991867065429688 Per batch time ---
>>>> Batch #  9 [263, 2

In [22]:
writer.flush()

In [None]:
def manual(dataset, size):
    items = []
    
    for i in range(size):
        item = dataset.__getitem__(i)
        # replace the `qa` with just the `qs`
        print(item.annotations['qa'])
        item.annotations['qa'] = item.annotations['qs']
        items.append(item)
    return collate_fn2(items)

In [None]:
print(">>> original qa")
test1x, target = manual(val, 5)
model.eval()
answers = model.answer(test1x, device, max_length = 30)

print(">>> prediction")
def token_to_word(x):
    qa = x["qa"]
    return tokenizer.batch_decode(qa)

token_to_word(answers)

In [None]:
def get_answers(x):
    qa = x["qa"]
    answer_token_id = tokenizer.convert_tokens_to_ids(constants.ANSWER_TOKEN)

    answer_start =  (qa == answer_token_id).nonzero()
    mask = torch.zeros_like(qa)
    mask[answer_start[:, 0], answer_start[:, 1]] = 1
    # fill the elements after the [ANSWER] token to be 1.
    mask = mask.cumsum(dim=1)
    just_answers = qa * mask
    return tokenizer.batch_decode(just_answers, skip_special_tokens = True)
    
real_answers = get_answers(answers)

In [None]:
def create_output(x, answers, result):
    qids = x['qids']
    assert len(qids) == len(answers)
    for i in range(len(qids)):
        d = {"question_id" : qids[i], "answer": answers[i]}
        result.append(d)

In [None]:
result = []
create_output(test1x, real_answers, result)
print(result)

In [None]:
val.questions.loc[val.questions['question_id'] == 139001]