In [2]:
import argparse, yaml
from openvqa.models.model_loader import CfgLoader
from utils1.exec1 import Execution

In [3]:
import os, copy
import sys
from openvqa.datasets.dataset_loader import DatasetLoader

class Execution:
    """Load the dataset(s) described by a config and dispatch a run mode.

    This notebook-local definition shadows the ``Execution`` imported from
    ``utils1.exec1`` in the first cell.

    ``train_engine`` and ``test_engine`` are looked up at call time from the
    notebook's global namespace. They are defined/imported in another cell,
    which must be executed before ``run`` is called.
    """

    def __init__(self, __C):
        """Store the config and load the dataset for ``__C.RUN_MODE``.

        When ``__C.EVAL_EVERY_EPOCH`` is truthy, a deep copy of the config with
        ``RUN_MODE = 'val'`` is used to additionally load a validation dataset
        (stored as ``self.dataset_eval``). Otherwise ``self.dataset_eval`` is None.
        """
        # Inside the class, ``self.__C`` is name-mangled to ``self._Execution__C``.
        self.__C = __C

        print('Loading dataset........')
        self.dataset = DatasetLoader(__C).DataSet()

        # If trigger the evaluation after every epoch,
        # create a new cfgs copy with RUN_MODE = 'val'.
        self.dataset_eval = None
        if __C.EVAL_EVERY_EPOCH:
            __C_eval = copy.deepcopy(__C)
            __C_eval.RUN_MODE = 'val'

            print('Loading validation set for per-epoch evaluation........')
            self.dataset_eval = DatasetLoader(__C_eval).DataSet()

    def run(self, run_mode):
        """Dispatch to training, validation or testing.

        Args:
            run_mode: one of 'train', 'val', 'test'.

        Raises:
            ValueError: if ``run_mode`` is not one of the supported modes.
        """
        if run_mode == 'train':
            # Treat a falsy RESUME (False or None) as a fresh run, matching
            # train_engine's own ``if __C.RESUME:`` check: start a clean log.
            if not self.__C.RESUME:
                self.empty_log(self.__C.VERSION)
            train_engine(self.__C, self.dataset, self.dataset_eval)

        elif run_mode == 'val':
            test_engine(self.__C, self.dataset, validation=True)

        elif run_mode == 'test':
            test_engine(self.__C, self.dataset)

        else:
            # Raising (instead of exit(-1)) reports the problem and does not
            # kill the notebook kernel.
            raise ValueError(
                "Unknown run_mode {!r}; expected 'train', 'val' or 'test'".format(run_mode)
            )

    def empty_log(self, version):
        """Delete ``<LOG_PATH>/log_run_<version>.txt`` if it exists."""
        print('Initializing log file........')
        log_file = self.__C.LOG_PATH + '/log_run_' + version + '.txt'
        if os.path.exists(log_file):
            os.remove(log_file)
        print('Finished!')
        print('')



In [10]:
def parse_args(argv=None):
    '''
    Parse input arguments.

    Args:
        argv: optional list of command-line tokens. When None, the notebook's
            hard-coded run configuration (see "CHANGE MODEL AND VERSION HERE"
            below) is parsed. sys.argv is never consulted, presumably because
            inside a Jupyter kernel it holds the kernel's own launch flags.

    Returns:
        argparse.Namespace with one attribute per option's ``dest``.
    '''
    # Single source of truth for the model names: the --MODEL help text is
    # generated from this list so the two cannot drift apart.
    model_choices = [
        'mcan_small',
        'mcan_small_wa',
        'mcan_large',
        'ban_4',
        'ban_8_wa',
        'baseline_wa',
        'ban_8',
        'mfb',
        'mfb_wa',
        'mfh',
        'mfh_wa',
        'mem',
        'butd',
        'butd_wa',
        'baseline',
        'baseline_wa_no_fusion',
        'positional',
        'mcan_large_wa',
        'mcan_small_augmented',
        'mcan_small_without_a',
    ]
    dataset_choices = ['vqa', 'gqa', 'clevr']

    parser = argparse.ArgumentParser(description='OpenVQA Args')

    parser.add_argument('--RUN', dest='RUN_MODE',
                        choices=['train', 'val', 'test'],
                        help='{train, val, test}',
                        default='train',
                        type=str, required=False)

    parser.add_argument('--MODEL', dest='MODEL',
                        choices=model_choices,
                        help='{' + ', '.join(model_choices) + '}',
                        type=str, required=True)

    parser.add_argument('--DATASET', dest='DATASET',
                        choices=dataset_choices,
                        help='{' + ', '.join(dataset_choices) + '}',
                        default='vqa',
                        type=str, required=False)

    parser.add_argument('--SPLIT', dest='TRAIN_SPLIT',
                        choices=['train', 'train+val', 'train+val+vg'],
                        help="set training split, "
                             "vqa: {'train', 'train+val', 'train+val+vg'}; "
                             "gqa: {'train', 'train+val'}; "
                             "clevr: {'train', 'train+val'}",
                        default='train', required=False,
                        type=str)

    parser.add_argument('--EVAL_EE', dest='EVAL_EVERY_EPOCH',
                        choices=['True', 'False'],
                        help='True: evaluate the val split when an epoch finished, '
                             'False: do not evaluate on local',
                        default='True',
                        required=False,
                        type=str)

    parser.add_argument('--SAVE_PRED', dest='TEST_SAVE_PRED',
                        choices=['True', 'False'],
                        help='True: save the prediction vectors, '
                             'False: do not save the prediction vectors',
                        default='True',
                        required=False,
                        type=str)

    parser.add_argument('--BS', dest='BATCH_SIZE',
                        help='batch size in training',
                        type=int)

    parser.add_argument('--GPU', dest='GPU',
                        help="gpu choose, eg.'0, 1, 2, ...'",
                        default='0, 1',
                        type=str)

    parser.add_argument('--SEED', dest='SEED',
                        help='fix random seed',
                        type=int)

    parser.add_argument('--VERSION', dest='VERSION',
                        help='Enter descriptive name here (eg baseline_wa_gru), will be used for WANDB and for version',
                        required=True,
                        type=str)

    parser.add_argument('--RESUME', dest='RESUME',
                        choices=['True', 'False'],
                        help='True: use checkpoint to resume training, '
                             'False: start training with random init',
                        type=str)

    parser.add_argument('--CKPT_V', dest='CKPT_VERSION',
                        help='checkpoint version',
                        type=str)

    parser.add_argument('--CKPT_E', dest='CKPT_EPOCH',
                        help='checkpoint epoch',
                        type=int)

    parser.add_argument('--CKPT_PATH', dest='CKPT_PATH',
                        help='load checkpoint path, we '
                             'recommend that you use '
                             'CKPT_VERSION and CKPT_EPOCH '
                             'instead, it will override '
                             'CKPT_VERSION and CKPT_EPOCH',
                        type=str)

    parser.add_argument('--ACCU', dest='GRAD_ACCU_STEPS',
                        help='split batch to reduce gpu memory usage',
                        type=int)

    parser.add_argument('--NW', dest='NUM_WORKERS',
                        help='multithreaded loading to accelerate IO',
                        type=int)

    parser.add_argument('--PINM', dest='PIN_MEM',
                        choices=['True', 'False'],
                        help='True: use pin memory, False: not use pin memory',
                        type=str)

    parser.add_argument('--VERB', dest='VERBOSE',
                        choices=['True', 'False'],
                        help='True: verbose print, False: simple print',
                        type=str)

    parser.add_argument('--USE_NEW_QUESTION', dest='USE_NEW_QUESTION',
                        choices=['True', 'False'],
                        help='whether to use new question while testing',
                        default='False',
                        type=str)

    parser.add_argument('--NEW_QUESTION', dest='NEW_QUESTION',
                        help='the new question to be asked while testing',
                        type=str)

    parser.add_argument('--IMAGE_ID', dest='IMAGE_ID',
                        help='image id on which the questions to be asked',
                        type=str)

    if argv is None:
        ######################################################
        #########  CHANGE MODEL AND VERSION HERE #############
        ######################################################
        argv = ['--MODEL', 'mfb_wa', '--VERSION', 'fixed_decoder', '--GPU', '0',
                '--DATASET', 'vqa', '--CKPT_V', 'mfb_ans_img_only', '--CKPT_E', '13']

    return parser.parse_args(args=argv)

In [11]:
# Parse the run configuration. Called with no arguments, parse_args() parses
# the hard-coded notebook argument list in its body rather than a real CLI.
args = parse_args()
print(args)

Namespace(BATCH_SIZE=None, CKPT_EPOCH=13, CKPT_PATH=None, CKPT_VERSION='mfb_ans_img_only', DATASET='vqa', EVAL_EVERY_EPOCH='True', GPU='0', GRAD_ACCU_STEPS=None, IMAGE_ID=None, MODEL='mfb_wa', NEW_QUESTION=None, NUM_WORKERS=None, PIN_MEM=None, RESUME=None, RUN_MODE='train', SEED=None, TEST_SAVE_PRED='True', TRAIN_SPLIT='train', USE_NEW_QUESTION='False', VERBOSE=None, VERSION='fixed_decoder')


In [12]:
cfg_file = "configs/{}/{}.yml".format(args.DATASET, args.MODEL)
with open(cfg_file, 'r') as f:

    # Loads the yaml file. yaml.load() without an explicit Loader is deprecated
    # in PyYAML 5.x and a TypeError in PyYAML >= 6; safe_load builds only plain
    # Python data (assumes the model configs are plain mappings/scalars/lists).
    yaml_dict = yaml.safe_load(f)

# Loads the model_cfgs + base_cfgs
__C = CfgLoader(yaml_dict['MODEL_USE']).load()

# Loads the command line cfgs
args = __C.str_to_bool(args)
args_dict = __C.parse_to_dict(args)

# {**dict1, **dict2} creates a new dictionary by merging dict1 and dict2, using dict2 for key clashes
args_dict = {**yaml_dict, **args_dict}
__C.add_args(args_dict)
__C.proc()

# FINAL PREFERENCE OF CFGS:
# COMMAND LINE > YAML FILE > MODEL CFGS > BASE CFGS

print('Hyper Parameters:')
print(__C)

Checking dataset ........
Finished!

Hyper Parameters:
{ ALPHA             }->1
{ ANS_STDDEV        }->0.1
{ AUGMENTED_ANSWER  }->True
{ BATCH_SIZE        }->64
{ BETA              }->30.0
{ CACHE_PATH        }->./results/cache
{ CAP_DIST          }->0.3
{ CKPTS_PATH        }->./ckpts
{ CKPT_EPOCH        }->13
{ CKPT_PATH         }->None
{ CKPT_VERSION      }->mfb_ans_img_only
{ DATASET           }->vqa
{ DATA_PATH         }->{'vqa': './data/vqa', 'gqa': './data/gqa', 'clevr': './data/clevr'}
{ DATA_ROOT         }->./data
{ DEVICES           }->[0]
{ DROPOUT_R         }->0.1
{ EVAL_BATCH_SIZE   }->32
{ EVAL_EVERY_EPOCH  }->True
{ FEATS_PATH        }->{'vqa': {'train': './data/vqa/feats/train2014', 'val': './data/vqa/feats/val2014', 'test': './data/vqa/feats/test2015'}, 'gqa': {'default-frcn': './data/gqa/feats/gqa-frcn', 'default-grid': './data/gqa/feats/gqa-grid'}, 'clevr': {'train': './data/clevr/feats/train', 'val': './data/clevr/feats/val', 'test': './data/clevr/feats/test'}}
{ FEA

  """


In [13]:
# Build the runner: loads the training split and, when EVAL_EVERY_EPOCH is
# set, a second dataset with RUN_MODE='val' for per-epoch evaluation.
execution = Execution(__C)

Loading dataset........
Loading all questions (for statistics)
Loading all image features
Loading split questions and answers

Tokenising questions
Tokenising answers
Finished!

Loading validation set for per-epoch evaluation........
Loading all questions (for statistics)
Loading all image features
Loading split questions and answers

Tokenising questions
Tokenising answers
Finished!



In [8]:
# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------

import os, torch, datetime, shutil, time
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import wandb
from openvqa.models.model_loader import ModelLoader
from openvqa.utils.optim import get_optim, adjust_lr
from utils1.test_engine import test_engine, ckpt_proc
from vis import plotter, vis_func
from multiprocessing import Pool
import multiprocessing
import sys

def _ckpt_path(__C, version, epoch):
    """Return the checkpoint file path <CKPTS_PATH>/ckpt_<version>/epoch<epoch>.pkl."""
    return __C.CKPTS_PATH + '/ckpt_' + version + '/epoch' + str(epoch) + '.pkl'


def _append_log(__C, text):
    """Append ``text`` to the run log <LOG_PATH>/log_run_<VERSION>.txt."""
    with open(__C.LOG_PATH + '/log_run_' + __C.VERSION + '.txt', 'a+') as logfile:
        logfile.write(text)


def _apply_loss_nonlinear(loss_item, loss_nonlinear_list):
    """Return a copy of ``loss_item`` with the per-position transforms applied.

    Position i of ``loss_nonlinear_list`` describes the transform for
    ``loss_item[i]``:
    - 'flat' flattens it with ``.view(-1)``;
    - any other truthy name is looked up on ``torch.nn.functional`` and
      called with ``dim=1``;
    - a falsy entry leaves it unchanged.
    """
    loss_item = list(loss_item)
    for item_ix, loss_nonlinear in enumerate(loss_nonlinear_list):
        if loss_nonlinear in ['flat']:
            loss_item[item_ix] = loss_item[item_ix].view(-1)
        elif loss_nonlinear:
            loss_item[item_ix] = getattr(F, loss_nonlinear)(loss_item[item_ix], dim=1)
    return loss_item


def train_engine(__C, dataset, dataset_eval=None):
    """Train the model selected by ``__C`` on ``dataset``.

    After every epoch this saves a checkpoint, logs to file and wandb, spawns a
    visualisation process (``vis_func``) and, if ``dataset_eval`` is given,
    runs ``test_engine`` on it.

    Args:
        __C: config object. Reads e.g. MODEL_USE, N_GPU, DEVICES, BATCH_SIZE,
            SUB_BATCH_SIZE, GRAD_ACCU_STEPS, MAX_EPOCH, WITH_ANSWER,
            WITH_FUSION_LOSS, ALPHA, BETA, LOAD_PRETRAINED, RESUME and the
            CKPT_* / *_PATH settings.
        dataset: training dataset. Must expose data_size, token_size, ans_size,
            pretrained_emb, pretrained_emb_ans, token_size_ans. Each batch
            unpacks as (frcn_feat, grid_feat, bbox_feat, ques_ix, ans_ix, ans,
            ques_type).
        dataset_eval: optional validation dataset evaluated after each epoch.

    Side effects:
        - Writes <LOG_PATH>/log_run_<VERSION>.txt.
        - Writes <CKPTS_PATH>/ckpt_<VERSION>/epoch<N>.pkl.
        - Creates <SAVED_PATH>/<VERSION> when WITH_ANSWER is set.
        - Uses wandb.
    """
    data_size = dataset.data_size
    token_size = dataset.token_size
    ans_size = dataset.ans_size
    pretrained_emb = dataset.pretrained_emb

    # Answer-side embedding and vocabulary for the answer encoder
    pretrained_emb_ans = dataset.pretrained_emb_ans
    token_size_ans = dataset.token_size_ans

    print("Model being used is {}".format(__C.MODEL_USE))

    net = ModelLoader(__C).Net(
        __C,
        pretrained_emb,
        token_size,
        ans_size,
        pretrained_emb_ans,
        token_size_ans
    )

    net.cuda()
    net.train()

    if __C.N_GPU > 1:
        net = nn.DataParallel(net, device_ids=__C.DEVICES)

    # Define Loss Function, e.g. nn.<LossName>(reduction=<LOSS_REDUCTION>)
    loss_fn = getattr(nn, __C.LOSS_FUNC_NAME_DICT[__C.LOSS_FUNC])(
        reduction=__C.LOSS_REDUCTION
    ).cuda()

    # creating a folder for saving the numpy visualization arrays
    if __C.WITH_ANSWER:
        os.makedirs(__C.SAVED_PATH + '/' + __C.VERSION, exist_ok=True)

    ###############################################################
    ######## Load the pretrained ans+img only model ###############
    ###############################################################
    if __C.LOAD_PRETRAINED:
        print("using the pretrained model for ans+img encoder and decoder both parts")

        path = _ckpt_path(__C, __C.CKPT_VERSION, __C.CKPT_EPOCH)
        print('Loading ckpt from {}'.format(path))
        pretrained_state_dict = torch.load(path)['state_dict']
        print('Finish!')

        # Checkpoints written by this engine store the unwrapped module's
        # weights (see the save step below), so merge into the unwrapped
        # module too; otherwise DataParallel's 'module.' key prefix would not
        # match when N_GPU > 1.
        base_net = net.module if isinstance(net, nn.DataParallel) else net
        net_state_dict = base_net.state_dict()

        print("updating keys in net_state_dict form pretrained state dict")
        net_state_dict.update(pretrained_state_dict)

        print("loading this state dict in the net")
        base_net.load_state_dict(net_state_dict)
        print("loaded net state dict succesfully")

    # Load checkpoint if resume training
    if __C.RESUME:
        print(' ========== Resume training')

        if __C.CKPT_PATH is not None:
            print('Warning: Now using CKPT_PATH args, '
                  'CKPT_VERSION and CKPT_EPOCH will not work')
            path = __C.CKPT_PATH
        else:
            path = _ckpt_path(__C, __C.CKPT_VERSION, __C.CKPT_EPOCH)

        # Load the network parameters
        print('Loading ckpt from {}'.format(path))
        ckpt = torch.load(path)
        print('Finish!')

        if __C.N_GPU > 1:
            net.load_state_dict(ckpt_proc(ckpt['state_dict']))
        else:
            net.load_state_dict(ckpt['state_dict'])
        start_epoch = ckpt['epoch']

        # Load the optimizer paramters
        optim = get_optim(__C, net, data_size, ckpt['lr_base'])
        optim._step = int(data_size / __C.BATCH_SIZE * start_epoch)
        optim.optimizer.load_state_dict(ckpt['optimizer'])
    else:
        optim = get_optim(__C, net, data_size)
        start_epoch = 0

    os.makedirs(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION, exist_ok=True)

    loss_sum = 0

    # Define multi-thread dataloader
    dataloader = Data.DataLoader(
        dataset,
        batch_size=__C.BATCH_SIZE,
        shuffle=True,
        num_workers=__C.NUM_WORKERS,
        pin_memory=__C.PIN_MEM,
        drop_last=True
    )

    _append_log(__C, str(__C))

    # For dry runs:
    # os.environ['WANDB_MODE'] = 'dryrun'

    # initializing the wandb project
    # TODO to change the name of project later, once the proper coding starts
    wandb.init(project="openvqa", name=__C.VERSION, config=__C)

    # obtain histogram of each gradients in network as it trains
    wandb.watch(net, log="all")

    wandb.save("./openvqa/models/" + str(__C.MODEL_USE) + "/net.py")
    wandb.save("./utils1/train_engine.py")

    # Training script
    for epoch in range(start_epoch, __C.MAX_EPOCH):

        _append_log(
            __C,
            '=====================================\nnowTime: ' +
            datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') +
            '\n'
        )

        # Learning Rate Decay
        if epoch in __C.LR_DECAY_LIST:
            adjust_lr(optim, __C.LR_DECAY_R)

        time_start = time.time()
        step = -1  # stays -1 only if the dataloader yields no batches

        # Iteration
        for step, (
                frcn_feat_iter,
                grid_feat_iter,
                bbox_feat_iter,
                ques_ix_iter,
                ans_ix_iter,
                ans_iter,
                ques_type
        ) in enumerate(dataloader):

            optim.zero_grad()

            frcn_feat_iter = frcn_feat_iter.cuda()
            grid_feat_iter = grid_feat_iter.cuda()
            bbox_feat_iter = bbox_feat_iter.cuda()
            ques_ix_iter = ques_ix_iter.cuda()
            ans_ix_iter = ans_ix_iter.cuda()
            ans_iter = ans_iter.cuda()

            # Progress-line totals for this batch. They are reset once per
            # batch (not per sub-batch) so they sum over every
            # gradient-accumulation sub-batch.
            loss_tmp = 0
            loss_img_ques_tmp = 0
            loss_ans_tmp = 0
            loss_interp_tmp = 0
            loss_fusion_tmp = 0

            for accu_step in range(__C.GRAD_ACCU_STEPS):
                sub = slice(accu_step * __C.SUB_BATCH_SIZE,
                            (accu_step + 1) * __C.SUB_BATCH_SIZE)
                sub_frcn_feat_iter = frcn_feat_iter[sub]
                sub_grid_feat_iter = grid_feat_iter[sub]
                sub_bbox_feat_iter = bbox_feat_iter[sub]
                sub_ques_ix_iter = ques_ix_iter[sub]
                sub_ans_ix_iter = ans_ix_iter[sub]
                sub_ans_iter = ans_iter[sub]

                # The answer tokens are always passed to the net. With
                # WITH_ANSWER the net returns a 6-tuple of predictions and
                # latent vectors; otherwise only the img+ques prediction.
                net_out = net(
                    sub_frcn_feat_iter,
                    sub_grid_feat_iter,
                    sub_bbox_feat_iter,
                    sub_ques_ix_iter,
                    sub_ans_ix_iter,
                    step,
                    epoch
                )
                if __C.WITH_ANSWER:
                    pred_img_ques, pred_ans, pred_fused, z_img_ques, z_ans, z_fused = net_out
                else:
                    pred_img_ques = net_out

                # Same nonlinearity on every prediction (a no-op for 'bce')
                loss_nonlinear_list = __C.LOSS_FUNC_NONLINEAR[__C.LOSS_FUNC]

                loss_item_img_ques = _apply_loss_nonlinear(
                    [pred_img_ques, sub_ans_iter], loss_nonlinear_list)
                loss_img_ques = loss_fn(loss_item_img_ques[0], loss_item_img_ques[1])

                # Out-of-place arithmetic throughout, so the component tensors
                # logged below are never mutated by the running total.
                loss = loss_img_ques

                # only calculate the ans and interp loss in case of WITH_ANSWER
                if __C.WITH_ANSWER:
                    # loss for the prediction from the answer
                    loss_item_ans = _apply_loss_nonlinear(
                        [pred_ans, sub_ans_iter], loss_nonlinear_list)
                    loss_ans = loss_fn(loss_item_ans[0], loss_item_ans[1])

                    # interpolation loss: prediction from the fused vector,
                    # weighted by the hyperparameter ALPHA
                    loss_item_interp = _apply_loss_nonlinear(
                        [pred_fused, sub_ans_iter], loss_nonlinear_list)
                    loss_interp = loss_fn(loss_item_interp[0], loss_item_interp[1]) * __C.ALPHA

                    loss = loss + (loss_ans + loss_interp)

                    if __C.WITH_FUSION_LOSS:
                        # 1. Penalise distance between the img+ques and ans
                        #    latents of the same example.
                        dist_calc = (z_img_ques - z_ans).pow(2).sum(1).sqrt()
                        # 2. Reward spread between latents of different
                        #    examples from the same encoder; then scale by BETA.
                        loss_fusion = (
                            dist_calc.mean()
                            - torch.pdist(z_img_ques, 2).mean()
                            - torch.pdist(z_ans, 2).mean()
                        ) * __C.BETA

                        loss = loss + loss_fusion

                loss = loss / __C.GRAD_ACCU_STEPS
                loss.backward()

                # Undo the accumulation scaling so the logged value is the
                # sub-batch loss.
                loss_value = loss.cpu().data.numpy() * __C.GRAD_ACCU_STEPS
                loss_tmp += loss_value
                loss_sum += loss_value

                # Components were never divided by GRAD_ACCU_STEPS, so they are
                # accumulated unscaled to stay on the same scale as loss_tmp.
                loss_img_ques_tmp += loss_img_ques.cpu().data.numpy()
                if __C.WITH_ANSWER:
                    loss_ans_tmp += loss_ans.cpu().data.numpy()
                    loss_interp_tmp += loss_interp.cpu().data.numpy()
                    if __C.WITH_FUSION_LOSS:
                        loss_fusion_tmp += loss_fusion.cpu().data.numpy()

            if __C.VERBOSE:
                print("\r[Version %s][Epoch %2d][Step %4d/%4d] Loss: %.4f [iq: %.4f,ans: %.4f,interp: %.4f,fusion: %.4f]" % (
                    __C.VERSION,
                    epoch + 1,
                    step,
                    int(data_size / __C.BATCH_SIZE),
                    loss_tmp / __C.SUB_BATCH_SIZE,
                    loss_img_ques_tmp / __C.SUB_BATCH_SIZE,
                    loss_ans_tmp / __C.SUB_BATCH_SIZE,
                    loss_interp_tmp / __C.SUB_BATCH_SIZE,
                    loss_fusion_tmp / __C.SUB_BATCH_SIZE
                ), end='          ')

            # Gradient norm clipping
            if __C.GRAD_NORM_CLIP > 0:
                nn.utils.clip_grad_norm_(
                    net.parameters(),
                    __C.GRAD_NORM_CLIP
                )

            optim.step()

        time_end = time.time()
        elapse_time = time_end - time_start
        print('Finished in {}s'.format(int(elapse_time)))
        epoch_finish = epoch + 1

        # Save checkpoint. With DataParallel, store the unwrapped module's weights.
        state = {
            'state_dict': net.module.state_dict() if __C.N_GPU > 1 else net.state_dict(),
            'optimizer': optim.optimizer.state_dict(),
            'lr_base': optim.lr_base,
            'epoch': epoch_finish
        }
        ckpt_file = _ckpt_path(__C, __C.VERSION, epoch_finish)
        torch.save(state, ckpt_file)

        # Sync the checkpoint file just written to wandb
        wandb.save(ckpt_file)

        # Logging. ``step`` is the 0-based index of the last batch.
        n_batches = max(step + 1, 1)
        _append_log(
            __C,
            'Epoch: ' + str(epoch_finish) +
            ', Loss: ' + str(loss_sum / data_size) +
            ', Lr: ' + str(optim._rate) + '\n' +
            'Elapsed time: ' + str(int(elapse_time)) +
            ', Speed(s/batch): ' + str(elapse_time / n_batches) +
            '\n\n'
        )

        wandb.log({
            'Loss': float(loss_sum / data_size),
            'Learning Rate': optim._rate,
            'Elapsed time': int(elapse_time)
            })

        # ---------------------------------------------- #
        # ---- Create visualizations in new processes----#
        # ---------------------------------------------- #
        dic = {}
        dic['version'] = __C.VERSION
        dic['epoch'] = epoch
        dic['num_samples'] = 1000

        # Runs concurrently with the evaluation below; joined afterwards.
        p = Pool(processes=1)
        p.map_async(vis_func, (dic, ))
        p.close()

        # Eval after every epoch
        epoch_dict = {
                'current_epoch': epoch
                }
        __C.add_args(epoch_dict)
        if dataset_eval is not None:
            # NOTE(review): epoch is always passed as 0 while the real epoch
            # goes via __C.current_epoch; confirm this is what test_engine expects.
            test_engine(
                __C,
                dataset_eval,
                state_dict=net.state_dict(),
                validation=True,
                epoch=0
            )
        p.join()

        loss_sum = 0


In [14]:
# Dispatch on the configured run mode (RUN_MODE='train' in the parsed args above).
execution.run(__C.RUN_MODE)

Initializing log file........
Finished!

Model being used is mfb
Training________________________________
using the pretrained model for ans+img encoder and decoder both parts
Loading ckpt from ./ckpts/ckpt_mfb_ans_img_only/epoch13.pkl
Finish!
updating keys in net_state_dict form pretrained state dict
loading this state dict in the net
loaded net state dict succesfully


Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
wandb: Wandb version 0.8.28 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


[Version fixed_decoder][Epoch  1][Step  165/6933] Loss: 8.7682 [iq: 2.4543,ans: 3.8242,interp: 2.3413,fusion: 0.1484]             

KeyboardInterrupt: 