In [1]:
from comet_ml import Experiment, OfflineExperiment

import os
import sys
import torch
import random
import logging
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from visdial.model import get_model
from visdial.data.dataset import VisDialDataset
from visdial.metrics import SparseGTMetrics, NDCG
from visdial.utils.checkpointing import CheckpointManager, load_checkpoint_from_config
from visdial.utils import move_to_cuda
from options import get_comet_experiment, get_training_config_and_args
from torch.utils.tensorboard import SummaryWriter
from visdial.optim import Adam, LRScheduler, get_weight_decay_params
from visdial.loss import FinetuneLoss
import argparse

In [7]:
import yaml
config_path = 'configs/v002_abc_LP_lkf_D36.yml'
config = yaml.load(open(config_path),Loader=yaml.SafeLoader)

In [24]:
config['dataset']['train_json_dense_dialog_path'] = '/media/local_workspace/quang/datasets/visdial/annotations/visdial_1.0_train_dense_sample.json'
config['dataset']['finetune'] = True
config['callbacks']['path_pretrained_ckpt'] = '/media/local_workspace/quang/checkpoints/visdial/CVPR/v002_abc_LP_lkf_D36/checkpoint_29.pth'

In [15]:
train_dataset = VisDialDataset(config, split='train')

train_dataloader = DataLoader(train_dataset,
                              batch_size=config['solver']['batch_size'] * torch.cuda.device_count(),
                              num_workers=config['solver']['cpu_workers'],
                              shuffle=True)

[train] Tokenizing questions...
[train] Tokenizing answers...
[train] Tokenizing captions...


In [None]:
train_dataloader

In [17]:
val_dataset = VisDialDataset(config, split='val')

val_dataloader = DataLoader(val_dataset,
                            batch_size=config['solver']['batch_size'] * torch.cuda.device_count(),
                            num_workers=config['solver']['cpu_workers'],
                            shuffle=True)

eval_dataloader = DataLoader(val_dataset,
                            batch_size=1 * torch.cuda.device_count(),
                            num_workers=config['solver']['cpu_workers'],
                            shuffle=False)

[val2018] Tokenizing questions...
[val2018] Tokenizing answers...
[val2018] Tokenizing captions...


In [16]:
len(train_dataset)

2000

In [18]:
len(val_dataset)

2064

In [19]:
device = torch.device('cuda')

In [20]:
model = get_model(config)

In [27]:
model.load_state_dict(torch.load(config['callbacks']['path_pretrained_ckpt'])['model'])

<All keys matched successfully>

In [28]:
model = model.to(device)

In [35]:
config['callbacks']['log_dir']

'/home/quang/workspace/log/tensorboard/v002_abc_LP_lkf_D36'

In [33]:
"""OPTIMIZER"""
from torch import optim

optimizer = optim.Adam(model.parameters(), lr=1e-5)
# =============================================================================
#   SETUP BEFORE TRAINING LOOP
# =============================================================================
finetune_path = config['callbacks']['log_dir'] + '/finetune'
os.makedirs(finetune_path)

summary_writer = SummaryWriter(log_dir=finetune_path)
checkpoint_manager = CheckpointManager(model, optimizer, finetune_path, config=config)

sparse_metrics = SparseGTMetrics()
disc_metrics = SparseGTMetrics()
gen_metrics = SparseGTMetrics()
ndcg = NDCG()
disc_ndcg = NDCG()
gen_ndcg = NDCG()

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)


In [37]:
ls /home/quang/workspace/log/tensorboard/v002_abc_LP_lkf_D36/finetune

config.json  events.out.tfevents.1570148078.local.29781.0


In [39]:
config["solver"]["training_splits"] = "trainval"

In [40]:
start_epoch = 0
if config["solver"]["training_splits"] == "trainval":
    iterations = (len(train_dataset) + len(val_dataset)) // (
                torch.cuda.device_count() * config["solver"]["batch_size"]) + 1
    num_examples = torch.tensor(len(train_dataset) + len(val_dataset), dtype=torch.float)
else:
    iterations = len(train_dataset) // (config['solver']['batch_size'] * torch.cuda.device_count()) + 1
    num_examples = torch.tensor(len(train_dataset), dtype=torch.float)

global_iteration_step = start_epoch * iterations

In [48]:
import itertools
from tqdm import tqdm
disc_criterion = FinetuneLoss()

In [52]:
start_epoch = 0
if config["solver"]["training_splits"] == "trainval":
    iterations = (len(train_dataset) + len(val_dataset)) // (
                torch.cuda.device_count() * config["solver"]["batch_size"]) + 1
    num_examples = torch.tensor(len(train_dataset) + len(val_dataset), dtype=torch.float)
else:
    iterations = len(train_dataset) // (config['solver']['batch_size'] * torch.cuda.device_count()) + 1
    num_examples = torch.tensor(len(train_dataset), dtype=torch.float)

global_iteration_step = start_epoch * iterations

for epoch in range(start_epoch, config['solver']['num_epochs']):
    logging.info(f"Training for epoch {epoch}:")

    with tqdm(total=iterations) as pbar:
        if config["solver"]["training_splits"] == "trainval":
            combined_dataloader = itertools.chain(train_dataloader, val_dataloader)
        else:
            combined_dataloader = itertools.chain(train_dataloader)

        epoch_loss = torch.tensor(0.0)
        for i, batch in enumerate(combined_dataloader):
            batch = move_to_cuda(batch, device)

            # zero out gradients
            optimizer.zero_grad()

            # do forward
            out = model(batch)

            # compute loss
            batch_loss = torch.tensor(0.0, requires_grad=True, device='cuda')
            if out.get('opt_scores') is not None:
                scores = out['opt_scores']

                sparse_metrics.observe(out['opt_scores'], batch['ans_ind'])
                batch_loss = disc_criterion(scores, batch)

            # compute gradients
            batch_loss.backward()

            # update params
            optimizer.step()

            pbar.update(1)
            pbar.set_postfix(epoch=epoch,
                             batch_loss=batch_loss.item())

            # log metrics
            summary_writer.add_scalar(f'train/batch_loss', batch_loss.item(), global_iteration_step)

            global_iteration_step += 1
            torch.cuda.empty_cache()

            epoch_loss += batch["ans"].size(0) * batch_loss.detach()

    if out.get('opt_scores') is not None:
        avg_metric_dict = {}
        avg_metric_dict.update(sparse_metrics.retrieve(reset=True))

        for metric_name, metric_value in avg_metric_dict.items():
            logging.info(f"{metric_name}: {metric_value}")

        summary_writer.add_scalars(f"train/metrics", avg_metric_dict, global_iteration_step)

    epoch_loss /= num_examples
    logging.info(f"train/epoch_loss: {epoch_loss.item()}\n")
    summary_writer.add_scalar(f'train/epoch_loss', epoch_loss.item(), global_iteration_step)

    # -------------------------------------------------------------------------
    #   ON EPOCH END  (checkpointing and validation)
    # -------------------------------------------------------------------------
    # Validate and report automatic metrics.

    if config['callbacks']['validate']:
        # Switch dropout, batchnorm etc to the correct mode.
        model.eval()

        logging.info(f"\nValidation after epoch {epoch}:")

        for batch in tqdm(eval_dataloader):
            move_to_cuda(batch, device)

            with torch.no_grad():
                out = model(batch)

                if out.get('opt_scores') is not None:
                    scores = out['opt_scores']
                    disc_metrics.observe(scores, batch["ans_ind"])

                    if "gt_relevance" in batch:
                        scores = scores[
                                 torch.arange(scores.size(0)),
                                 batch["round_id"] - 1, :]

                        disc_ndcg.observe(scores, batch["gt_relevance"])

                if out.get('opts_out_scores') is not None:
                    scores = out['opts_out_scores']
                    gen_metrics.observe(scores, batch["ans_ind"])

                    if "gt_relevance" in batch:
                        scores = scores[
                                 torch.arange(scores.size(0)),
                                 batch["round_id"] - 1, :]

                        gen_ndcg.observe(scores, batch["gt_relevance"])

                if out.get('opt_scores') is not None and out.get('opts_out_scores') is not None:
                    scores = (out['opts_out_scores'] + out['opt_scores']) / 2

                    sparse_metrics.observe(scores, batch["ans_ind"])
                    if "gt_relevance" in batch:
                        scores = scores[
                                 torch.arange(scores.size(0)),
                                 batch["round_id"] - 1, :]

                        ndcg.observe(scores, batch["gt_relevance"])

        avg_metric_dict = {}
        avg_metric_dict.update(sparse_metrics.retrieve(reset=True, key='avg_'))
        avg_metric_dict.update(ndcg.retrieve(reset=True, key='avg_'))

        disc_metric_dict = {}
        disc_metric_dict.update(disc_metrics.retrieve(reset=True, key='disc_'))
        disc_metric_dict.update(disc_ndcg.retrieve(reset=True, key='disc_'))

        gen_metric_dict = {}
        gen_metric_dict.update(gen_metrics.retrieve(reset=True, key='gen_'))
        gen_metric_dict.update(gen_ndcg.retrieve(reset=True, key='gen_'))

        for metric_dict in [avg_metric_dict, disc_metric_dict, gen_metric_dict]:
            for metric_name, metric_value in metric_dict.items():
                logging.info(f"{metric_name}: {metric_value}")
            summary_writer.add_scalars(f"val/metrics", metric_dict, global_iteration_step)

        model.train()
        torch.cuda.empty_cache()


100%|█████████▉| 254/255 [04:45<00:01,  1.11s/it, batch_loss=0.291, epoch=0]
100%|██████████| 1032/1032 [07:00<00:00,  2.47it/s]
100%|█████████▉| 254/255 [04:59<00:01,  1.22s/it, batch_loss=0.33, epoch=1] 
100%|██████████| 1032/1032 [07:20<00:00,  2.33it/s]
100%|█████████▉| 254/255 [04:49<00:01,  1.12s/it, batch_loss=0.361, epoch=2]
100%|██████████| 1032/1032 [07:09<00:00,  2.46it/s]
100%|█████████▉| 254/255 [04:55<00:01,  1.18s/it, batch_loss=0.315, epoch=3]
100%|██████████| 1032/1032 [07:07<00:00,  2.47it/s]
100%|█████████▉| 254/255 [04:54<00:01,  1.17s/it, batch_loss=0.316, epoch=4]
100%|██████████| 1032/1032 [07:07<00:00,  2.45it/s]
100%|█████████▉| 254/255 [04:52<00:01,  1.25s/it, batch_loss=0.244, epoch=5]
100%|██████████| 1032/1032 [07:09<00:00,  2.46it/s]
 22%|██▏       | 55/255 [01:03<03:54,  1.17s/it, batch_loss=0.204, epoch=6]


KeyboardInterrupt: 

In [54]:
os.path.curdir()

TypeError: 'str' object is not callable

In [56]:
os.path.dirname(config['callbacks']['path_pretrained_ckpt'])

'/media/local_workspace/quang/checkpoints/visdial/CVPR/v002_abc_LP_lkf_D36'