In [1]:
import sys
from importlib import reload
sys.path.append("..")

from scene_description import model, data_utils, global_vars

In [2]:
import torch
import torch.nn as nn
from torch.optim import Adam

import json
from tqdm import tqdm

In [3]:
# All QA artefacts live under a single directory; every path below is derived
# from it so relocating the data only requires editing QA_DATA_DIR.
QA_DATA_DIR = "../../qa_data/"
mappings_path = QA_DATA_DIR + "mappings.json"
processed_data_path = QA_DATA_DIR + "processed_data.json"
# The dataset receives the directory itself as its image root.
map_images_path = QA_DATA_DIR

In [4]:
# Context managers close each file as soon as its JSON is parsed; the bare
# `json.load(open(...))` form leaves closing to the garbage collector.
# JSON interchange is UTF-8, so decode explicitly instead of using the locale default.
with open(mappings_path, encoding="utf-8") as mappings_file:
    mappings = json.load(mappings_file)

with open(processed_data_path, encoding="utf-8") as processed_file:
    # Only the "data" entry of the processed file is used downstream.
    qa_data = json.load(processed_file)["data"]

In [6]:
# Wrap the QA records, the mappings dict and the directory given by
# map_images_path into the project's SceneDescriptionDataset.
sd_dataset = data_utils.SceneDescriptionDataset(
    qa_data,
    mappings,
    map_images_path,
)

In [13]:
# Shuffled mini-batches of 64 samples, assembled by the project's custom
# collate function and loaded by 4 DataLoader worker subprocesses.
sd_dataloader = torch.utils.data.DataLoader(
    sd_dataset,
    collate_fn=data_utils.collate_fn,
    batch_size=64,
    num_workers=4,
    shuffle=True,
)

In [8]:
# Output sizes are taken from the mappings: the number of entries in the VQA
# answer table and the number of entries in the vocabulary's word table.
sd_model = model.SceneDescription(
    num_answers=len(mappings["aid_2_answer_vqa"]),
    vocab_size=len(mappings["vocab"]["wid_2_word"]),
)

In [9]:
# Fall back to the CPU so the notebook still runs on machines without a GPU
# (a hard-coded "cuda" fails on `.to(device)` there).
device = "cuda" if torch.cuda.is_available() else "cpu"
sd_model = sd_model.to(device)
num_epochs = 10
optimizer = Adam(sd_model.parameters(), lr=3e-4)

In [14]:
def train(model, dataloader, optimizer, epochs, device=None):
    """Jointly train the VQA head and the description decoder.

    For every mini-batch the VQA cross-entropy loss and the description
    cross-entropy loss are summed and back-propagated together. Per-batch
    losses are shown on a tqdm progress bar; epoch-averaged losses are printed
    after each epoch.

    Args:
        model: module called as ``model(questions, imgs, answers, device=...)``
            that returns ``(op_vqa, op_des)``.
        dataloader: iterable yielding ``((questions, imgs), (answers, answers_vqa))``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``dataloader``.
        device: where batches are moved. Defaults to the device holding the
            model's parameters instead of relying on a module-level global.

    Returns:
        A list with one dict per epoch holding the mean ``vqa_loss``,
        ``des_loss`` and ``total_loss`` over that epoch's batches.
    """
    if device is None:
        # NOTE(review): this passes a torch.device (not the "cuda" string) to the
        # model; pass `device` explicitly if SceneDescription needs a string.
        device = next(model.parameters()).device

    lossfn_vqa = nn.CrossEntropyLoss()
    # Target positions equal to PAD_IDX are excluded from the description loss.
    lossfn_des = nn.CrossEntropyLoss(ignore_index=global_vars.PAD_IDX)

    # Put the model in training mode (affects layers such as dropout/batch-norm, if present).
    model.train()
    history = []
    # Guard the per-epoch averages against an empty dataloader.
    num_batches = max(len(dataloader), 1)

    for e in range(epochs):
        # The description is constant for the whole epoch, so set it once here.
        train_loop = tqdm(dataloader, total=len(dataloader), position=0, leave=False,
                          desc="Epoch={}".format(e + 1))

        loss_vqa = 0.
        loss_des = 0.
        loss_total = 0.
        for (questions, imgs), (answers, answers_vqa) in train_loop:
            answers = answers.long().to(device)
            answers_vqa = answers_vqa.long().to(device)
            imgs = imgs.to(device)

            optimizer.zero_grad()  # clear gradients from the previous mini-batch

            op_vqa, op_des = model(questions, imgs, answers, device=device)

            batch_loss_vqa = lossfn_vqa(op_vqa, answers_vqa)
            # Outputs at positions 0..T-2 are scored against target tokens 1..T-1
            # (a one-step shift). Assumes op_des is (batch, seq_len, vocab) — TODO
            # confirm; permute moves the class dim to position 1 as CrossEntropyLoss expects.
            batch_loss_des = lossfn_des(op_des[:, :-1, :].permute(0, 2, 1), answers[:, 1:])

            total_loss = batch_loss_des + batch_loss_vqa
            total_loss.backward()

            # Read each scalar once per batch: every .item() synchronises with the device.
            vqa_val = batch_loss_vqa.item()
            des_val = batch_loss_des.item()
            total_val = total_loss.item()
            loss_vqa += vqa_val
            loss_des += des_val
            loss_total += total_val

            optimizer.step()

            train_loop.set_postfix(vqa_loss=vqa_val, des_loss=des_val, total_loss=total_val)

        epoch_stats = {
            "vqa_loss": loss_vqa / num_batches,
            "des_loss": loss_des / num_batches,
            "total_loss": loss_total / num_batches,
        }
        history.append(epoch_stats)
        # Four placeholders need four arguments; the epoch number was previously
        # missing, which raised IndexError at the end of the first epoch.
        print("Epoch={} vqa_loss={} des_loss={} total_loss={}".format(
            e + 1,
            epoch_stats["vqa_loss"],
            epoch_stats["des_loss"],
            epoch_stats["total_loss"],
        ))

    return history

In [None]:
# Use the epoch count from the configuration cell rather than repeating the literal.
train(sd_model, sd_dataloader, optimizer, epochs=num_epochs)

Epoch=1:  12%|█▏        | 359/2990 [03:40<27:58,  1.57it/s, des_loss=1.01, total_loss=1.47, vqa_loss=0.46]  