In [1]:
from __future__ import print_function

from miscc.utils import mkdir_p
from miscc.utils import build_super_images
from miscc.losses import sent_loss, words_loss
from miscc.config import cfg, cfg_from_file

from datasets import TextDataset
from datasets import prepare_data

from model import RNN_ENCODER, CNN_ENCODER

import os
import sys
import time
import random
import pprint
import datetime
import dateutil.tz
#import argparse
import numpy as np
from PIL import Image
import easydict
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms

In [2]:
dir_path = (os.path.abspath(os.path.join(os.path.realpath("__file__"), './.')))
sys.path.append(dir_path)

In [3]:
UPDATE_INTERVAL = 200
"""def parse_args():
    parser = argparse.ArgumentParser(description='Train a DAMSM network')
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default='cfg/DAMSM/bird.yml', type=str)
    parser.add_argument('--gpu', dest='gpu_id', type=int, default=0)
    parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
    parser.add_argument('--manualSeed', type=int, help='manual seed')
    args = parser.parse_args()
    return args
    """

"def parse_args():\n    parser = argparse.ArgumentParser(description='Train a DAMSM network')\n    parser.add_argument('--cfg', dest='cfg_file',\n                        help='optional config file',\n                        default='cfg/DAMSM/bird.yml', type=str)\n    parser.add_argument('--gpu', dest='gpu_id', type=int, default=0)\n    parser.add_argument('--data_dir', dest='data_dir', type=str, default='')\n    parser.add_argument('--manualSeed', type=int, help='manual seed')\n    args = parser.parse_args()\n    return args\n    "

In [4]:
args = easydict.EasyDict({
    "cfg_file": 'cfg/DAMSM/coco.yml',
    "gpu_id": 0,
    "data_dir": '../data/coco',
    "manualSeed": 100
})

In [5]:
def train(dataloader, cnn_model, rnn_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    # print(count)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)


        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                                 cap_lens, class_ids, batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count

In [6]:
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    s_cur_loss = s_total_loss / step
    w_cur_loss = w_total_loss / step

    return s_cur_loss, w_cur_loss

In [7]:
def build_models():
    # build model ############################################################
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    if cfg.TRAIN.NET_E != '':
        state_dict = torch.load(cfg.TRAIN.NET_E)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
        print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch

In [8]:
if __name__ == "__main__":
    args #= parse_args()
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    if args.gpu_id == -1:
        cfg.CUDA = False
    else:
        cfg.GPU_ID = args.gpu_id

    if args.data_dir != '':
        cfg.DATA_DIR = args.data_dir
    print('Using config:')
    pprint.pprint(cfg)

    if not cfg.TRAIN.FLAG:
        args.manualSeed = 100
    elif args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    random.seed(args.manualSeed)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)

Using config:
{'B_VALIDATION': False,
 'CONFIG_NAME': 'DAMSM',
 'CUDA': True,
 'DATASET_NAME': 'coco',
 'DATA_DIR': '../data/coco',
 'GAN': {'B_ATTENTION': True,
         'B_DCGAN': False,
         'CONDITION_DIM': 100,
         'DF_DIM': 64,
         'GF_DIM': 128,
         'R_NUM': 2,
         'Z_DIM': 100},
 'GPU_ID': 0,
 'RNN_TYPE': 'LSTM',
 'TEXT': {'CAPTIONS_PER_IMAGE': 10, 'EMBEDDING_DIM': 256, 'WORDS_NUM': 15},
 'TRAIN': {'BATCH_SIZE': 110,
           'B_NET_D': True,
           'DISCRIMINATOR_LR': 0.0002,
           'ENCODER_LR': 0.0002,
           'FLAG': True,
           'GENERATOR_LR': 0.0002,
           'MAX_EPOCH': 101,
           'NET_E': '',
           'NET_G': '',
           'RNN_GRAD_CLIP': 0.25,
           'SMOOTH': {'GAMMA1': 4.0,
                      'GAMMA2': 5.0,
                      'GAMMA3': 10.0,
                      'LAMBDA': 1.0},
           'SNAPSHOT_INTERVAL': 10},
 'TREE': {'BASE_SIZE': 299, 'BRANCH_NUM': 1},
 'WORKERS': 0}


  yaml_cfg = edict(yaml.load(f))


In [9]:
##########################################################################
now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
output_dir = '../output/%s_%s_%s' % \
    (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

model_dir = os.path.join(output_dir, 'Model')
image_dir = os.path.join(output_dir, 'Image')
mkdir_p(model_dir)
mkdir_p(image_dir)

torch.cuda.set_device(cfg.GPU_ID)
cudnn.benchmark = True

In [10]:
# Get data loader ##################################################
imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))
batch_size = cfg.TRAIN.BATCH_SIZE
image_transform = transforms.Compose([
    transforms.Resize(int(imsize * 76 / 64)),
    transforms.RandomCrop(imsize),
    transforms.RandomHorizontalFlip()])
dataset = TextDataset(cfg.DATA_DIR, 'train',
                      base_size=cfg.TREE.BASE_SIZE,
                      transform=image_transform)

print(dataset.n_words, dataset.embeddings_num)
assert dataset
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, drop_last=True,
    shuffle=True, num_workers=int(cfg.WORKERS))

# # validation data #
dataset_val = TextDataset(cfg.DATA_DIR, 'test',
                          base_size=cfg.TREE.BASE_SIZE,
                          transform=image_transform)
dataloader_val = torch.utils.data.DataLoader(
    dataset_val, batch_size=batch_size, drop_last=True,
    shuffle=True, num_workers=int(cfg.WORKERS))

Load filenames from: ../data/coco/train/filenames.pickle (82783)
Load filenames from: ../data/coco/test/filenames.pickle (40504)
Save to:  ../data/coco\captions.pickle
33916 10
Load filenames from: ../data/coco/train/filenames.pickle (82783)
Load filenames from: ../data/coco/test/filenames.pickle (40504)
Load from:  ../data/coco\captions.pickle


In [11]:
warnings.filterwarnings(action='ignore')

In [12]:
# Train ##############################################################
text_encoder, image_encoder, labels, start_epoch = build_models()
para = list(text_encoder.parameters())
for v in image_encoder.parameters():
    if v.requires_grad:
        para.append(v)
# optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999))
# At any point you can hit Ctrl + C to break out of training early.
try:
    lr = cfg.TRAIN.ENCODER_LR
    for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):
        optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999))
        epoch_start_time = time.time()
        count = train(dataloader, image_encoder, text_encoder,
                      batch_size, labels, optimizer, epoch,
                      dataset.ixtoword, image_dir)
        print('-' * 89)
        if len(dataloader_val) > 0:
            s_loss, w_loss = evaluate(dataloader_val, image_encoder,
                                      text_encoder, batch_size)
            print('| end epoch {:3d} | valid loss '
                  '{:5.2f} {:5.2f} | lr {:.5f}|'
                  .format(epoch, s_loss, w_loss, lr))
        print('-' * 89)
        if lr > cfg.TRAIN.ENCODER_LR/10.:
            lr *= 0.98

        if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or
            epoch == cfg.TRAIN.MAX_EPOCH):
            torch.save(image_encoder.state_dict(),
                       '%s/image_encoder%d.pth' % (model_dir, epoch))
            torch.save(text_encoder.state_dict(),
                       '%s/text_encoder%d.pth' % (model_dir, epoch))
            print('Save G/Ds models.')
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

Load pretrained model from  https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth
| epoch   0 |     0/  752 batches | ms/batch 49.16 | s_loss  0.02  0.02 | w_loss  0.03  0.03
| epoch   0 |   200/  752 batches | ms/batch 4524.48 | s_loss  4.50  4.51 | w_loss  4.64  4.39
| epoch   0 |   400/  752 batches | ms/batch 4039.73 | s_loss  3.65  3.62 | w_loss  3.43  3.33
| epoch   0 |   600/  752 batches | ms/batch 4344.10 | s_loss  3.16  3.11 | w_loss  2.89  2.85
-----------------------------------------------------------------------------------------
| end epoch   0 | valid loss  5.09  4.44 | lr 0.00020|
-----------------------------------------------------------------------------------------
Save G/Ds models.
| epoch   1 |     0/  752 batches | ms/batch 15.18 | s_loss  0.01  0.01 | w_loss  0.01  0.01
| epoch   1 |   200/  752 batches | ms/batch 3642.48 | s_loss  2.70  2.66 | w_loss  2.45  2.45
| epoch   1 |   400/  752 batches | ms/batch 3866.94 | s_loss  2.52  2.48 | w_loss  

| epoch  13 |   200/  752 batches | ms/batch 3852.48 | s_loss  1.42  1.43 | w_loss  1.23  1.27
| epoch  13 |   400/  752 batches | ms/batch 3682.09 | s_loss  1.42  1.42 | w_loss  1.23  1.26
| epoch  13 |   600/  752 batches | ms/batch 3894.68 | s_loss  1.42  1.43 | w_loss  1.22  1.26
-----------------------------------------------------------------------------------------
| end epoch  13 | valid loss  2.71  2.35 | lr 0.00015|
-----------------------------------------------------------------------------------------
| epoch  14 |     0/  752 batches | ms/batch 21.50 | s_loss  0.01  0.01 | w_loss  0.01  0.01
| epoch  14 |   200/  752 batches | ms/batch 3890.91 | s_loss  1.41  1.42 | w_loss  1.21  1.26
| epoch  14 |   400/  752 batches | ms/batch 3685.65 | s_loss  1.39  1.40 | w_loss  1.20  1.24
| epoch  14 |   600/  752 batches | ms/batch 3827.28 | s_loss  1.41  1.41 | w_loss  1.22  1.26
-----------------------------------------------------------------------------------------
| end epoch 

-----------------------------------------------------------------------------------------
| end epoch  26 | valid loss  2.58  2.17 | lr 0.00012|
-----------------------------------------------------------------------------------------
| epoch  27 |     0/  752 batches | ms/batch 19.31 | s_loss  0.01  0.01 | w_loss  0.00  0.00
| epoch  27 |   200/  752 batches | ms/batch 3714.64 | s_loss  1.28  1.29 | w_loss  1.06  1.10
| epoch  27 |   400/  752 batches | ms/batch 3761.98 | s_loss  1.28  1.29 | w_loss  1.06  1.10
| epoch  27 |   600/  752 batches | ms/batch 3757.66 | s_loss  1.29  1.30 | w_loss  1.06  1.10
-----------------------------------------------------------------------------------------
| end epoch  27 | valid loss  2.57  2.15 | lr 0.00012|
-----------------------------------------------------------------------------------------
| epoch  28 |     0/  752 batches | ms/batch 14.00 | s_loss  0.01  0.01 | w_loss  0.01  0.01
| epoch  28 |   200/  752 batches | ms/batch 3799.86 | s_lo

| epoch  40 |     0/  752 batches | ms/batch 17.12 | s_loss  0.01  0.01 | w_loss  0.01  0.01
| epoch  40 |   200/  752 batches | ms/batch 3555.16 | s_loss  1.22  1.23 | w_loss  0.98  1.02
| epoch  40 |   400/  752 batches | ms/batch 3444.92 | s_loss  1.22  1.23 | w_loss  0.99  1.03
| epoch  40 |   600/  752 batches | ms/batch 3601.93 | s_loss  1.23  1.24 | w_loss  0.99  1.04
-----------------------------------------------------------------------------------------
| end epoch  40 | valid loss  2.47  2.03 | lr 0.00009|
-----------------------------------------------------------------------------------------
Save G/Ds models.
| epoch  41 |     0/  752 batches | ms/batch 13.13 | s_loss  0.01  0.01 | w_loss  0.00  0.00
| epoch  41 |   200/  752 batches | ms/batch 3633.32 | s_loss  1.22  1.23 | w_loss  0.99  1.02
| epoch  41 |   400/  752 batches | ms/batch 3785.35 | s_loss  1.21  1.22 | w_loss  0.97  1.02
| epoch  41 |   600/  752 batches | ms/batch 3741.88 | s_loss  1.22  1.24 | w_loss  0.

| epoch  53 |   400/  752 batches | ms/batch 3707.10 | s_loss  1.19  1.20 | w_loss  0.94  0.99
| epoch  53 |   600/  752 batches | ms/batch 3957.19 | s_loss  1.18  1.19 | w_loss  0.93  0.98
-----------------------------------------------------------------------------------------
| end epoch  53 | valid loss  2.43  2.01 | lr 0.00007|
-----------------------------------------------------------------------------------------
| epoch  54 |     0/  752 batches | ms/batch 13.60 | s_loss  0.01  0.01 | w_loss  0.01  0.01
| epoch  54 |   200/  752 batches | ms/batch 3763.25 | s_loss  1.17  1.19 | w_loss  0.93  0.98
| epoch  54 |   400/  752 batches | ms/batch 3842.81 | s_loss  1.20  1.21 | w_loss  0.96  1.00
| epoch  54 |   600/  752 batches | ms/batch 3744.07 | s_loss  1.19  1.20 | w_loss  0.94  0.98
-----------------------------------------------------------------------------------------
| end epoch  54 | valid loss  2.45  2.01 | lr 0.00007|
----------------------------------------------------

| end epoch  66 | valid loss  2.44  2.01 | lr 0.00005|
-----------------------------------------------------------------------------------------
| epoch  67 |     0/  752 batches | ms/batch 15.48 | s_loss  0.01  0.01 | w_loss  0.01  0.01
| epoch  67 |   200/  752 batches | ms/batch 3777.23 | s_loss  1.15  1.17 | w_loss  0.91  0.95
| epoch  67 |   400/  752 batches | ms/batch 3595.68 | s_loss  1.15  1.16 | w_loss  0.91  0.94
| epoch  67 |   600/  752 batches | ms/batch 3904.73 | s_loss  1.16  1.18 | w_loss  0.91  0.95
-----------------------------------------------------------------------------------------
| end epoch  67 | valid loss  2.42  1.99 | lr 0.00005|
-----------------------------------------------------------------------------------------
| epoch  68 |     0/  752 batches | ms/batch 14.23 | s_loss  0.01  0.01 | w_loss  0.01  0.01
| epoch  68 |   200/  752 batches | ms/batch 3780.61 | s_loss  1.16  1.18 | w_loss  0.91  0.96
| epoch  68 |   400/  752 batches | ms/batch 3767.24 |

| epoch  80 |   200/  752 batches | ms/batch 3704.40 | s_loss  1.15  1.16 | w_loss  0.90  0.94
| epoch  80 |   400/  752 batches | ms/batch 3705.37 | s_loss  1.14  1.16 | w_loss  0.89  0.93
| epoch  80 |   600/  752 batches | ms/batch 3712.01 | s_loss  1.16  1.18 | w_loss  0.90  0.94
-----------------------------------------------------------------------------------------
| end epoch  80 | valid loss  2.40  1.96 | lr 0.00004|
-----------------------------------------------------------------------------------------
Save G/Ds models.
| epoch  81 |     0/  752 batches | ms/batch 13.67 | s_loss  0.01  0.01 | w_loss  0.00  0.01
| epoch  81 |   200/  752 batches | ms/batch 3713.65 | s_loss  1.15  1.17 | w_loss  0.90  0.94
| epoch  81 |   400/  752 batches | ms/batch 3961.78 | s_loss  1.14  1.15 | w_loss  0.89  0.93
| epoch  81 |   600/  752 batches | ms/batch 3740.76 | s_loss  1.15  1.17 | w_loss  0.91  0.95
------------------------------------------------------------------------------------

| epoch  93 |   600/  752 batches | ms/batch 3751.30 | s_loss  1.15  1.16 | w_loss  0.89  0.93
-----------------------------------------------------------------------------------------
| end epoch  93 | valid loss  2.39  1.94 | lr 0.00003|
-----------------------------------------------------------------------------------------
| epoch  94 |     0/  752 batches | ms/batch 14.69 | s_loss  0.01  0.01 | w_loss  0.00  0.00
| epoch  94 |   200/  752 batches | ms/batch 3768.92 | s_loss  1.13  1.15 | w_loss  0.88  0.92
| epoch  94 |   400/  752 batches | ms/batch 3571.70 | s_loss  1.14  1.16 | w_loss  0.89  0.93
| epoch  94 |   600/  752 batches | ms/batch 3711.72 | s_loss  1.15  1.16 | w_loss  0.89  0.94
-----------------------------------------------------------------------------------------
| end epoch  94 | valid loss  2.37  1.92 | lr 0.00003|
-----------------------------------------------------------------------------------------
| epoch  95 |     0/  752 batches | ms/batch 13.36 | s_lo