In [1]:
from __future__ import print_function

from miscc.utils import mkdir_p
from miscc.utils import build_super_images
from miscc.losses import sent_loss, words_loss
from miscc.config import cfg, cfg_from_file

from datasets import TextDataset
from datasets import prepare_data

from model import RNN_ENCODER, CNN_ENCODER

import os
import sys
import time
import random
import pprint
import datetime
import dateutil.tz
import argparse
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms


# dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))
# sys.path.append(dir_path)



In [2]:
class parse_args():
    cfg_file='../code/cfg/DAMSM/bird.yml'
    gpu_id=1
    data_dir='../data/birds/'
    manualSeed=1
args = parse_args()

In [3]:

args = parse_args()
if args.cfg_file is not None:
    cfg_from_file(args.cfg_file)

if args.gpu_id == -1:
    cfg.CUDA = False
else:
    cfg.GPU_ID = args.gpu_id

if args.data_dir != '':
    cfg.DATA_DIR = args.data_dir
print('Using config:')
pprint.pprint(cfg)

if not cfg.TRAIN.FLAG:
    args.manualSeed = 100
elif args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
np.random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if cfg.CUDA:
    torch.cuda.manual_seed_all(args.manualSeed)



Using config:
{'B_VALIDATION': False,
 'CONFIG_NAME': 'DAMSM',
 'CUDA': True,
 'DATASET_NAME': 'birds',
 'DATA_DIR': '../data/birds/',
 'GAN': {'B_ATTENTION': True,
         'B_DCGAN': False,
         'CONDITION_DIM': 100,
         'DF_DIM': 64,
         'GF_DIM': 128,
         'R_NUM': 2,
         'Z_DIM': 100},
 'GPU_ID': 1,
 'RNN_TYPE': 'LSTM',
 'TEXT': {'CAPTIONS_PER_IMAGE': 10, 'EMBEDDING_DIM': 256, 'WORDS_NUM': 18},
 'TRAIN': {'BATCH_SIZE': 48,
           'B_NET_D': True,
           'DISCRIMINATOR_LR': 0.0002,
           'ENCODER_LR': 0.002,
           'FLAG': True,
           'GENERATOR_LR': 0.0002,
           'MAX_EPOCH': 600,
           'NET_E': '',
           'NET_G': '',
           'RNN_GRAD_CLIP': 0.25,
           'SMOOTH': {'GAMMA1': 4.0,
                      'GAMMA2': 5.0,
                      'GAMMA3': 10.0,
                      'LAMBDA': 1.0},
           'SNAPSHOT_INTERVAL': 50},
 'TREE': {'BASE_SIZE': 299, 'BRANCH_NUM': 1},
 'WORKERS': 1}


  yaml_cfg = edict(yaml.load(f))


In [4]:
##########################################################################
now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
output_dir = '../output/%s_%s_%s' % \
    (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

model_dir = os.path.join(output_dir, 'Model')
image_dir = os.path.join(output_dir, 'Image')
mkdir_p(model_dir)
mkdir_p(image_dir)

torch.cuda.set_device(cfg.GPU_ID)
cudnn.benchmark = True



In [5]:

# Get data loader ##################################################
imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))
batch_size = cfg.TRAIN.BATCH_SIZE
image_transform = transforms.Compose([
    transforms.Scale(int(imsize * 76 / 64)),
    transforms.RandomCrop(imsize),
    transforms.RandomHorizontalFlip()])
dataset = TextDataset(cfg.DATA_DIR, 'train',
                      base_size=cfg.TREE.BASE_SIZE,
                      transform=image_transform)

print(dataset.n_words, dataset.embeddings_num)





Total filenames:  11788 001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg
Load filenames from: ../data/birds//train/filenames.pickle (8855)
Load filenames from: ../data/birds//test/filenames.pickle (2933)
Load from:  ../data/birds/captions.pickle
../data/birds/train
../data/birds/train/class_info.pickle
5450 10


In [6]:
assert dataset

In [7]:
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, drop_last=True,
    shuffle=True, num_workers=int(cfg.WORKERS))

# # validation data #
dataset_val = TextDataset(cfg.DATA_DIR, 'test',
                          base_size=cfg.TREE.BASE_SIZE,
                          transform=image_transform)
dataloader_val = torch.utils.data.DataLoader(
    dataset_val, batch_size=batch_size, drop_last=True,
    shuffle=True, num_workers=int(cfg.WORKERS))



Total filenames:  11788 001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg
Load filenames from: ../data/birds//train/filenames.pickle (8855)
Load filenames from: ../data/birds//test/filenames.pickle (2933)
Load from:  ../data/birds/captions.pickle
../data/birds/test
../data/birds/test/class_info.pickle


In [13]:

UPDATE_INTERVAL = 2



def train(dataloader, cnn_model, rnn_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
#         print(words_features.shape,sent_code.shape)
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)
#         print(nef,att_sze)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
#         print('here')
        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                                 cap_lens, class_ids, batch_size)
        
        
        
        
#         print(w_loss0.data)
#         print('--------------------------')
#         print(w_loss1.data)
#         print('--------------------------')
#         print(attn_maps[0].shape)
    
    
    
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1
#         print(loss)
        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        
#         print(s_total_loss0[0],s_total_loss1[0])
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

#             print(count)
            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            
#             print(imgs[-1].cpu().shape, captions.shape, len(attn_maps),attn_maps[-1].shape, att_sze)
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count


def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    s_cur_loss = s_total_loss[0] / step
    w_cur_loss = w_total_loss[0] / step

    return s_cur_loss, w_cur_loss


def build_models():
    # build model ############################################################
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    if cfg.TRAIN.NET_E != '':
        state_dict = torch.load(cfg.TRAIN.NET_E)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
        print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch



In [14]:
# Train ##############################################################
text_encoder, image_encoder, labels, start_epoch = build_models()
para = list(text_encoder.parameters())
for v in image_encoder.parameters():
    if v.requires_grad:
        para.append(v)
# optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999))
# At any point you can hit Ctrl + C to break out of training early.

lr = cfg.TRAIN.ENCODER_LR


  "num_layers={}".format(dropout, num_layers))


Load pretrained model from  https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth


In [15]:
# def words_loss(img_features, words_emb, labels,
#                cap_lens, class_ids, batch_size):
#     """
#         words_emb(query): batch x nef x seq_len
#         img_features(context): batch x nef x 17 x 17
#     """
#     masks = []
#     att_maps = []
#     similarities = []
#     cap_lens = cap_lens.data.tolist()
#     for i in range(batch_size):
#         if class_ids is not None:
#             print('class_ids:',class_ids)
#             mask = (class_ids == class_ids[i]).astype(np.bool)
#             print('mask:',mask)
#             mask[i] = 0
#             masks.append(mask.reshape((1, -1)))
            
#         # Get the i-th text description
#         words_num = cap_lens[i]
#         # -> 1 x nef x words_num
#         word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()
#         # -> batch_size x nef x words_num
#         word = word.repeat(batch_size, 1, 1)
#         # batch x nef x 17*17
#         context = img_features
#         """
#             word(query): batch x nef x words_num
#             context: batch x nef x 17 x 17
#             weiContext: batch x nef x words_num
#             attn: batch x words_num x 17 x 17
#         """
#         weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1)
#         att_maps.append(attn[i].unsqueeze(0).contiguous())
#         # --> batch_size x words_num x nef
#         word = word.transpose(1, 2).contiguous()
#         weiContext = weiContext.transpose(1, 2).contiguous()
#         # --> batch_size*words_num x nef
#         word = word.view(batch_size * words_num, -1)
#         weiContext = weiContext.view(batch_size * words_num, -1)
#         #
#         # -->batch_size*words_num
#         row_sim = cosine_similarity(word, weiContext)
#         # --> batch_size x words_num
#         row_sim = row_sim.view(batch_size, words_num)

#         # Eq. (10)
#         row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_()
#         row_sim = row_sim.sum(dim=1, keepdim=True)
#         row_sim = torch.log(row_sim)

#         # --> 1 x batch_size
#         # similarities(i, j): the similarity between the i-th image and the j-th text description
#         similarities.append(row_sim)

#     # batch_size x batch_size
#     similarities = torch.cat(similarities, 1)
#     if class_ids is not None:
#         masks = np.concatenate(masks, 0)
#         # masks: batch_size x batch_size
#         masks = torch.ByteTensor(masks)
#         if cfg.CUDA:
#             masks = masks.cuda()

#     similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3
#     if class_ids is not None:
#         similarities.data.masked_fill_(masks, -float('inf'))
#     similarities1 = similarities.transpose(0, 1)
#     if labels is not None:
#         loss0 = nn.CrossEntropyLoss()(similarities, labels)
#         loss1 = nn.CrossEntropyLoss()(similarities1, labels)
#     else:
#         loss0, loss1 = None, None
#     return loss0, loss1, att_maps


In [16]:
# import os
# import errno
# import numpy as np
# from torch.nn import init

# import torch
# import torch.nn as nn

# from PIL import Image, ImageDraw, ImageFont
# from copy import deepcopy
# import skimage.transform
# # For visualization ################################################
# COLOR_DIC = {0:[128,64,128],  1:[244, 35,232],
#              2:[70, 70, 70],  3:[102,102,156],
#              4:[190,153,153], 5:[153,153,153],
#              6:[250,170, 30], 7:[220, 220, 0],
#              8:[107,142, 35], 9:[152,251,152],
#              10:[70,130,180], 11:[220,20, 60],
#              12:[255, 0, 0],  13:[0, 0, 142],
#              14:[119,11, 32], 15:[0, 60,100],
#              16:[0, 80, 100], 17:[0, 0, 230],
#              18:[0,  0, 70],  19:[0, 0,  0]}
# FONT_MAX = 50
# def build_super_images(real_imgs, captions, ixtoword,
#                        attn_maps, att_sze, lr_imgs=None,
#                        batch_size=cfg.TRAIN.BATCH_SIZE,
#                        max_word_num=cfg.TEXT.WORDS_NUM):
#     nvis = 8
#     real_imgs = real_imgs[:nvis]
#     if lr_imgs is not None:
#         lr_imgs = lr_imgs[:nvis]
#     if att_sze == 17:
#         vis_size = att_sze * 16
#     else:
#         vis_size = real_imgs.size(2)

# #     print('vis_size:',vis_size)
#     text_convas = \
#         np.ones([batch_size * FONT_MAX,
#                  (max_word_num + 2) * (vis_size + 2), 3],
#                 dtype=np.uint8)

#     for i in range(max_word_num):
#         istart = (i + 2) * (vis_size + 2)
#         iend = (i + 3) * (vis_size + 2)
#         text_convas[:, istart:iend, :] = COLOR_DIC[i]


#     real_imgs = \
#         nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
#     # [-1, 1] --> [0, 1]
#     real_imgs.add_(1).div_(2).mul_(255)
#     real_imgs = real_imgs.data.numpy()
#     # b x c x h x w --> b x h x w x c
#     real_imgs = np.transpose(real_imgs, (0, 2, 3, 1))
#     pad_sze = real_imgs.shape
#     middle_pad = np.zeros([pad_sze[2], 2, 3])
#     post_pad = np.zeros([pad_sze[1], pad_sze[2], 3])
#     if lr_imgs is not None:
#         lr_imgs = \
#             nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs)
#         # [-1, 1] --> [0, 1]
#         lr_imgs.add_(1).div_(2).mul_(255)
#         lr_imgs = lr_imgs.data.numpy()
#         # b x c x h x w --> b x h x w x c
#         lr_imgs = np.transpose(lr_imgs, (0, 2, 3, 1))

#     # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17
#     seq_len = max_word_num
#     img_set = []
#     num = nvis  # len(attn_maps)
# #     print('num:',num)
#     text_map, sentences = \
#         drawCaption(text_convas, captions, ixtoword, vis_size)
#     text_map = np.asarray(text_map).astype(np.uint8)

#     bUpdate = 1
#     for i in range(num): #num: 8
# #         print('1 attn_maps[i].shape:',attn_maps[i].shape)
#         attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)
# #         print('2 attn.shape:',attn.shape)
#         # --> 1 x 1 x 17 x 17
#         attn_max = attn.max(dim=1, keepdim=True)
#         attn = torch.cat([attn_max[0], attn], 1)
# #         print('3 attn.shape:',attn.shape)
#         #
#         attn = attn.view(-1, 1, att_sze, att_sze)
# #         print('4 attn.shape:',attn.shape)
#         attn = attn.repeat(1, 3, 1, 1).data.numpy()
# #         print('5 attn.shape:',attn.shape)
#         # n x c x h x w --> n x h x w x c
#         attn = np.transpose(attn, (0, 2, 3, 1))
# #         print('6 attn.shape:',attn.shape)
#         num_attn = attn.shape[0]
# #         print('num_attn:',num_attn)
#         #
#         img = real_imgs[i]
#         if lr_imgs is None:
#             lrI = img
#         else:
#             lrI = lr_imgs[i]
#         row = [lrI, middle_pad]
#         row_merge = [img, middle_pad]
#         row_beforeNorm = []
#         minVglobal, maxVglobal = 1, 0
#         for j in range(num_attn):
# #             print('attn.shape:',attn.shape)
#             one_map = attn[j]
# #             print('0 one_map.shape:',one_map.shape)
#             if (vis_size // att_sze) > 1:
#                 one_map = skimage.transform.pyramid_expand(one_map, sigma=20
#                                                            ,upscale=vis_size // att_sze
#                                                            ,multichannel=True)
# #             print('1 one_map.shape:',one_map.shape)
#             row_beforeNorm.append(one_map)
#             minV = one_map.min()
#             maxV = one_map.max()
#             if minVglobal > minV:
#                 minVglobal = minV
#             if maxVglobal < maxV:
#                 maxVglobal = maxV
# #         print('len(row_beforeNorm):',len(row_beforeNorm))
#         for j in range(seq_len + 1):
#             if j < num_attn:
# #                 print('2 one_map.shape:',one_map.shape)
#                 one_map = row_beforeNorm[j]
# #                 print('3 one_map.shape:',one_map.shape)
#                 one_map = (one_map - minVglobal) / (maxVglobal - minVglobal)
#                 one_map *= 255
# #                 print('4 one_map.shape:',one_map.shape)
#                 #
                
                
# #                 print(np.uint8(one_map).shape,type(np.uint8(one_map)))
                
                
# #                 print(img.shape,one_map.shape)
#                 PIL_im = Image.fromarray(np.uint8(img))
# #                 display(PIL_im)
#                 PIL_att = Image.fromarray(np.uint8(one_map))
#                 merged = \
#                     Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))
#                 mask = Image.new('L', (vis_size, vis_size), (210))
# #                 print('---------------')
# #                 print(np.asarray(mask))
#                 merged.paste(PIL_im, (0, 0))
                
                
#                 merged.paste(PIL_att, (0, 0), mask)
#                 merged = np.array(merged)[:, :, :3]
#             else:
#                 one_map = post_pad
#                 merged = post_pad
#             row.append(one_map)
#             row.append(middle_pad)
#             #
#             row_merge.append(merged)
#             row_merge.append(middle_pad)
#         row = np.concatenate(row, 1)
#         row_merge = np.concatenate(row_merge, 1)
#         txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX]
#         if txt.shape[1] != row.shape[1]:
# #             print('txt', txt.shape, 'row', row.shape)
#             bUpdate = 0
#             break
#         row = np.concatenate([txt, row, row_merge], 0)
#         img_set.append(row)
#     if bUpdate:
#         img_set = np.concatenate(img_set, 0)
#         img_set = img_set.astype(np.uint8)
#         return img_set, sentences
#     else:
#         return None
# def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2):
#     num = captions.size(0)
#     img_txt = Image.fromarray(convas)
#     # get a font
#     # fnt = None  # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)
#     fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)
#     # get a drawing context
#     d = ImageDraw.Draw(img_txt)
#     sentence_list = []
#     for i in range(num):
#         cap = captions[i].data.cpu().numpy()
#         sentence = []
#         for j in range(len(cap)):
#             if cap[j] == 0:
#                 break
#             word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii')
#             d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]),
#                    font=fnt, fill=(255, 255, 255, 255))
#             sentence.append(word)
#         sentence_list.append(sentence)
#     return img_txt, sentence_list


In [17]:
for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):
    optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999))
    epoch_start_time = time.time()
    count = train(dataloader, image_encoder, text_encoder,
                  batch_size, labels, optimizer, epoch,
                  dataset.ixtoword, image_dir)
    print('-' * 89)
    
    

  "See the documentation of nn.Upsample for details.".format(mode))


| epoch   0 |     0/  184 batches | ms/batch 328.79 | s_loss  1.95  1.96 | w_loss  2.78  2.24
| epoch   0 |     2/  184 batches | ms/batch 6115.63 | s_loss  3.92  3.91 | w_loss  5.30  4.28
| epoch   0 |     4/  184 batches | ms/batch 5972.23 | s_loss  3.90  3.90 | w_loss  4.73  3.98
| epoch   0 |     6/  184 batches | ms/batch 6043.24 | s_loss  3.88  3.89 | w_loss  4.65  3.96
| epoch   0 |     8/  184 batches | ms/batch 6008.20 | s_loss  3.87  3.87 | w_loss  4.36  3.97
| epoch   0 |    10/  184 batches | ms/batch 5968.83 | s_loss  3.86  3.87 | w_loss  4.25  3.94
| epoch   0 |    12/  184 batches | ms/batch 5741.01 | s_loss  3.86  3.87 | w_loss  4.05  3.90
| epoch   0 |    14/  184 batches | ms/batch 5822.53 | s_loss  3.86  3.87 | w_loss  3.95  3.91
| epoch   0 |    16/  184 batches | ms/batch 5874.10 | s_loss  3.86  3.87 | w_loss  3.92  3.89
| epoch   0 |    18/  184 batches | ms/batch 5872.86 | s_loss  3.86  3.87 | w_loss  3.88  3.90
| epoch   0 |    20/  184 batches | ms/batch 6025.4

KeyboardInterrupt: 

In [None]:
if len(dataloader_val) > 0:
        s_loss, w_loss = evaluate(dataloader_val, image_encoder,
                                  text_encoder, batch_size)
        print('| end epoch {:3d} | valid loss '
              '{:5.2f} {:5.2f} | lr {:.5f}|'
              .format(epoch, s_loss, w_loss, lr))
    print('-' * 89)
    if lr > cfg.TRAIN.ENCODER_LR/10.:
        lr *= 0.98

    if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or
        epoch == cfg.TRAIN.MAX_EPOCH):
        torch.save(image_encoder.state_dict(),
                   '%s/image_encoder%d.pth' % (model_dir, epoch))
        torch.save(text_encoder.state_dict(),
                   '%s/text_encoder%d.pth' % (model_dir, epoch))
        print('Save G/Ds models.')
