In [1]:
import sys
from importlib import reload
sys.path.append("..")

from scene_description import model, data_utils, global_vars

In [2]:
import torch
import torch.nn as nn
from torch.optim import Adam

import json
from tqdm import tqdm

In [3]:
# Select CUDA device index 1 as the current device for subsequent CUDA work.
# NOTE(review): the index is hard-coded — presumably this host has at least two GPUs;
# confirm before running the notebook on a different machine.
torch.cuda.set_device(1)

In [4]:
# Input locations, written relative to the current working directory.
mappings_path = "../../qa_data/mappings.json"  # JSON file loaded into `mappings`
processed_data_path = "../../qa_data/processed_data.json"  # JSON file whose "data" entry becomes `qa_data`
map_images_path = "../../qa_data/"  # path handed to SceneDescriptionDataset as its third argument

In [5]:
# Load the mappings and the processed QA records.
# Context managers close each file as soon as it has been parsed instead of
# leaving the handle open until garbage collection.
with open(mappings_path) as mappings_file:
    mappings = json.load(mappings_file)

with open(processed_data_path) as processed_file:
    # Only the "data" entry of the processed file is used downstream.
    qa_data = json.load(processed_file)["data"]

In [6]:
# Build the dataset from the QA records, the mappings, and the image path.
sd_dataset = data_utils.SceneDescriptionDataset(qa_data, mappings, map_images_path)

In [7]:
# Shuffled mini-batches of 64 samples, assembled by the project's collate_fn and
# fetched by 4 worker processes.
# NOTE(review): this same shuffled loader is also passed to eval() further down.
sd_dataloader = torch.utils.data.DataLoader(
    sd_dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=data_utils.collate_fn,
    num_workers=4
)

In [8]:
# Size the model from the mappings: one output per entry of the VQA answer table
# and one vocabulary slot per entry of the word-id table.
sd_model = model.SceneDescription(
    num_answers=len(mappings["aid_2_answer_vqa"]),
    vocab_size=len(mappings["vocab"]["wid_2_word"]),
)

In [9]:
# "cuda" carries no explicit index, so it refers to the current CUDA device.
device="cuda"
sd_model = sd_model.to(device)
num_epochs = 10  # epoch count for training
# The optimizer is created after the move so it is built from the on-device parameters.
optimizer = Adam(sd_model.parameters(), lr=1e-4)

In [10]:
def train(model, dataloader, optimizer, epochs):
    """Train `model` jointly on the VQA and description objectives.

    Every batch yielded by `dataloader` is unpacked as
    ``((questions, imgs), (answers, answers_vqa))``. The model is called with the
    questions, the images and the ground-truth `answers`, and must return a pair
    ``(op_vqa, op_des)``. One cross-entropy loss is computed per output; their
    sum is back-propagated.

    Args:
        model: Module invoked as ``model(questions, imgs, answers, device=device)``.
        dataloader: Sized iterable of batches in the layout described above.
        optimizer: Optimizer stepping the model's parameters.
        epochs: Number of passes over `dataloader`.

    Returns:
        None. The mean per-batch losses are printed after every epoch.

    Note:
        Tensors are moved to the notebook-level global `device`.
    """
    lossfn_vqa = nn.CrossEntropyLoss()
    # Target positions equal to PAD_IDX are excluded from the description loss.
    lossfn_des = nn.CrossEntropyLoss(ignore_index=global_vars.PAD_IDX)

    # Denominator for the epoch means; the floor of 1 avoids a ZeroDivisionError
    # when the dataloader is empty.
    num_batches = max(len(dataloader), 1)

    for e in range(epochs):
        # eval() in this notebook switches the model to eval mode and never
        # switches it back, so make sure every epoch really runs in train mode.
        model.train()

        train_loop = tqdm(dataloader, total=len(dataloader), position=0, leave=False)
        # The label is constant within an epoch, so set it once up front.
        train_loop.set_description("Epoch={}".format(e+1))

        loss_vqa = 0.
        loss_des = 0.
        loss_total = 0.
        for (questions, imgs), (answers, answers_vqa) in train_loop:

            answers = answers.long().to(device)
            answers_vqa = answers_vqa.long().to(device)
            imgs = imgs.to(device)

            optimizer.zero_grad() # clear gradients from previous minibatch

            op_vqa, op_des = model(questions, imgs, answers, device=device)

            batch_loss_vqa = lossfn_vqa(op_vqa, answers_vqa)
            # Output step t is scored against target token t+1: the last output
            # step and the first target token are dropped. The permute moves the
            # last axis of op_des (presumably the vocabulary logits — TODO confirm)
            # into position 1, where CrossEntropyLoss expects the class axis.
            batch_loss_des = lossfn_des(op_des[:, :-1, :].permute(0, 2, 1), answers[:, 1:])

            total_loss = batch_loss_des + batch_loss_vqa
            total_loss.backward()
            optimizer.step()

            # Pull the scalars off the device once and reuse them below.
            vqa_value = batch_loss_vqa.item()
            des_value = batch_loss_des.item()
            total_value = total_loss.item()

            loss_vqa += vqa_value
            loss_des += des_value
            loss_total += total_value

            train_loop.set_postfix(vqa_loss=vqa_value, des_loss=des_value, total_loss=total_value)

        # Report the epoch 1-based so the summary agrees with the progress-bar label.
        print("Epoch={} vqa_loss={} des_loss={} total_loss={}".format(e+1, loss_vqa/num_batches,
                                                                    loss_des/num_batches,
                                                                    loss_total/num_batches))

In [18]:
# Use the epoch count defined next to the optimizer rather than repeating the literal,
# so changing `num_epochs` actually changes the length of the run.
train(sd_model, sd_dataloader, optimizer, epochs=num_epochs)

  0%|          | 0/2990 [00:00<?, ?it/s]                                                                       

Epoch=0 vqa_loss=0.17870642439801757 des_loss=0.5668478375394209 total_loss=0.7455542616620909


  0%|          | 0/2990 [00:00<?, ?it/s]                                                                       

Epoch=1 vqa_loss=0.11835765855505215 des_loss=0.5483511634893641 total_loss=0.6667088216762479


                                                                                                              

KeyboardInterrupt: 

In [19]:
# Persist only the model's state_dict; the optimizer state is not saved.
torch.save(sd_model.state_dict(), "../../scene_description.ckpt")

## Evaluate

In [11]:
# Restore the saved weights. map_location remaps the stored tensors onto `device`,
# so the checkpoint also loads when the GPU index it was saved from is not present.
sd_model.load_state_dict(torch.load("../../scene_description.ckpt", map_location=device))

<All keys matched successfully>

In [12]:
def eval(model, dataloader, max_batches=1):
    """Run `model` over `dataloader` without gradients and collect predictions.

    Every batch is unpacked as ``((questions, imgs), (answers, answers_vqa))``.
    The model is called as ``model(questions, imgs, device=device)`` (no
    ground-truth answers are passed) and must return ``(op_vqa, op_des)``.

    Args:
        model: Module to evaluate. It is switched to eval mode and left there.
        dataloader: Sized iterable of batches in the layout described above.
        max_batches: Maximum number of batches to process. Defaults to 1, which
            matches the previous behaviour of stopping after the first batch.
            Pass ``None`` to process the whole dataloader; be aware that every
            image tensor is kept in the returned dict.

    Returns:
        dict with the keys "questions", "vqa_preds", "des_preds", "vqa_true",
        "des_true" and "imgs", each a list extended batch by batch.

    Raises:
        ValueError: If `max_batches` is given and is smaller than 1.

    Note:
        Tensors are moved to the notebook-level global `device`. The function
        name shadows the built-in ``eval``; it is kept for existing callers.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError("max_batches must be None or >= 1")

    model.eval()
    lossfn_vqa = nn.CrossEntropyLoss()
    # Target positions equal to PAD_IDX are excluded from the description loss.
    lossfn_des = nn.CrossEntropyLoss(ignore_index=global_vars.PAD_IDX)

    preds = {
        "questions": [],
        "vqa_preds": [],
        "des_preds": [],
        "vqa_true": [],
        "des_true": [],
        "imgs": [],
    }

    # Size the progress bar to the number of batches that will actually run.
    planned_batches = len(dataloader) if max_batches is None else min(len(dataloader), max_batches)
    eval_loop = tqdm(dataloader, total=planned_batches, position=0, leave=False)
    eval_loop.set_description("EVAL")

    loss_vqa = 0.
    loss_des = 0.
    loss_total = 0.
    total_batches = 0

    # No gradients are needed for evaluation; skipping autograd saves memory.
    with torch.no_grad():
        for (questions, imgs), (answers, answers_vqa) in eval_loop:

            answers = answers.long().to(device)
            answers_vqa = answers_vqa.long().to(device)
            imgs = imgs.to(device)

            op_vqa, op_des = model(questions, imgs, device=device)

            batch_loss_vqa = lossfn_vqa(op_vqa, answers_vqa)
            # NOTE(review): the model is called without `answers`, so the sequence
            # length of op_des is whatever the model produces on its own; this loss
            # presumably requires it to line up with `answers` — confirm against
            # the model's inference path.
            batch_loss_des = lossfn_des(op_des[:, :-1, :].permute(0, 2, 1), answers[:, 1:])

            total_loss = batch_loss_des + batch_loss_vqa

            # Pull the scalars off the device once and reuse them below.
            vqa_value = batch_loss_vqa.item()
            des_value = batch_loss_des.item()
            total_value = total_loss.item()

            loss_vqa += vqa_value
            loss_des += des_value
            loss_total += total_value
            total_batches += 1

            preds["vqa_true"].extend(answers_vqa.cpu().numpy())
            preds["des_true"].extend(answers.cpu().numpy())
            preds["questions"].extend(questions)
            # softmax is monotonic, so the argmax of the raw outputs is the same index.
            preds["vqa_preds"].extend(torch.argmax(op_vqa, dim=-1).cpu().numpy())
            preds["des_preds"].extend(torch.argmax(op_des, dim=-1).cpu().numpy())
            preds["imgs"].extend(imgs.cpu())

            eval_loop.set_postfix(vqa_loss=vqa_value, des_loss=des_value, total_loss=total_value)

            if max_batches is not None and total_batches >= max_batches:
                break

    # Floor of 1 avoids a ZeroDivisionError when the dataloader yields nothing.
    denom = max(total_batches, 1)
    print("vqa_loss={} des_loss={} total_loss={}".format(loss_vqa/denom,
                                                         loss_des/denom,
                                                         loss_total/denom))
    return preds

In [13]:
# NOTE(review): `sd_dataloader` is the same shuffled loader that train() consumes,
# so these predictions and losses are on training data — confirm that is intended.
eval_preds = eval(sd_model, sd_dataloader)

  0%|          | 0/2990 [00:00<?, ?it/s]

tensor([[100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100],
        [100]], device='cuda:1') torch.Size([64, 1])
to

                                        

RuntimeError: CUDA error: device-side assert triggered