In [1]:
import shutil, os, csv, itertools, glob

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from dataloader import train_loader, val_loader

  from ._conv import register_converters as _register_converters


Loading data...
Done
Loading data...
Done


In [7]:
%load_ext autoreload
%autoreload 1
%aimport configs
cfgs = configs
cuda = cfgs.USE_CUDA
%aimport model
CAN = model.CAN

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
def get_acc(output, target):
    # takes in two tensors to compute accuracy
    pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
    correct = pred.eq(target.data.view_as(pred)).cpu().sum()
#     print("Output: ")
#     print(output.data.squeeze().cpu().numpy())
#     print("Pred: ")
#     print(pred.squeeze().cpu().numpy())
#     print("Target: ")
#     print(target.data.cpu().numpy())
    return correct, target.size(0)

def run_trainer(model_path, model, train_loader, test_loader, get_acc, resume, num_epoch):

    if not os.path.exists(model_path):
        os.makedirs(model_path)
    def save_checkpoint(state, is_best, filename=model_path+'checkpoint.pth.tar'):
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename, model_path+'model_best.pth.tar')
    def get_last_checkpoint(model_path):
        fs = sorted([f for f in os.listdir(model_path) if 'Epoch' in f], key=lambda k: int(k.split()[1]))
        return model_path+fs[-1] if len(fs) > 0 else None
    
    start_epoch = 0
    best_res = 0
    resume_state = get_last_checkpoint(model_path) if resume else None
    if resume_state and os.path.isfile(resume_state):
        print("=> loading checkpoint '{}'".format(resume_state))
        checkpoint = torch.load(resume_state)
        start_epoch = checkpoint['epoch']+1
        best_res = checkpoint['val_acc']
        model.load_state_dict(checkpoint['state_dict'])
        if cuda:
            model.cuda()
        optimizer = optim.Adam(model.parameters(), **cfgs.OPT_PARAM)
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(resume_state, checkpoint['epoch']))
    else:
        if cuda:
            model.cuda()
        optimizer = optim.Adam(model.parameters(), **cfgs.OPT_PARAM)

    criterion = nn.CrossEntropyLoss()
    # scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5) # optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5)

    def train(epoch):
        model.train()
        total, total_correct = 0., 0.
        for batch_idx, (img_feats, question, answer) in enumerate(train_loader):
            img_feats, question, answer = Variable(img_feats.float()), Variable(question.float()), Variable(answer.long())
            if cuda:
                img_feats, question, answer = img_feats.cuda(), question.cuda(), answer.cuda()
            optimizer.zero_grad()
            output = model(img_feats, question)
            loss = criterion(output, answer)
            loss.backward()
            optimizer.step()

            correct, num_instance = get_acc(output, answer)
            total_correct += correct
            total += num_instance
            if batch_idx % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Acc: {:.2f}%/{:.2f}%'.format(
                    epoch, batch_idx * cfgs.BATCH_SIZE, len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.data[0],
                    100. * correct / num_instance, 100. * total_correct / total ))
        
        return 100. * total_correct / total

    def test():
        model.eval()
        test_loss = 0.
        total, total_correct = 0., 0.
        for img_feats, question, answer in test_loader:
            img_feats, question, answer = Variable(img_feats.float()), Variable(question.float()), Variable(answer.long())
            if cuda:
                img_feats, question, answer = img_feats.cuda(), question.cuda(), answer.cuda()
            optimizer.zero_grad()
            output = model(img_feats, question)
            test_loss += criterion(output, answer).data[0] # sum up batch loss
            
            correct, num_instance = get_acc(output, answer)
            total_correct += correct
            total += num_instance

        test_acc = 100. * total_correct / total
        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, total_correct, total,
            test_acc))

        return test_acc


    for epoch in range(start_epoch, num_epoch):
        is_best = False

        train_acc = train(epoch)
        val_acc = test()
        
        # scheduler.step(val_loss)

        if val_acc > best_res:
            best_res = val_acc
            is_best = True

        save_checkpoint({
                'epoch': epoch,
                'state_dict': model.cpu().state_dict(),
                'train_acc':train_acc,
                'val_acc': val_acc,
                'optimizer' : optimizer.state_dict(),
            }, is_best,
            model_path+"Epoch %d Acc %.4f.pt"%(epoch, val_acc))

        if cuda:
            model.cuda()

In [None]:
model = CAN(**cfgs.NET_PARAM)
run_trainer(
    model_path = './ckpt/', 
    model = model, 
    train_loader = train_loader, 
    test_loader = val_loader, 
    get_acc = get_acc, 
    resume = False, 
    num_epoch = 100
)




















Test set: Average loss: 0.0143, Accuracy: 75467.0/149991.0 (50.31%)




















Test set: Average loss: 0.0128, Accuracy: 83755.0/149991.0 (55.84%)




















Test set: Average loss: 0.0098, Accuracy: 102120.0/149991.0 (68.08%)






















Test set: Average loss: 0.0092, Accuracy: 106150.0/149991.0 (70.77%)




















Test set: Average loss: 0.0084, Accuracy: 111155.0/149991.0 (74.11%)




















Test set: Average loss: 0.0073, Accuracy: 117333.0/149991.0 (78.23%)













