In [1]:
from __future__ import print_function

from miscc.utils import mkdir_p
from miscc.utils import build_super_images
from miscc.losses import sent_loss, words_loss
from miscc.config import cfg, cfg_from_file

from datasets import TextDataset
from datasets import prepare_data

from model import RNN_ENCODER, CNN_ENCODER

import os
import sys
import time
import random
import pprint
import datetime
import dateutil.tz
import argparse
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms


# dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))
# sys.path.append(dir_path)



In [18]:
class parse_args():
    cfg_file='../code/cfg/DAMSM/bird.yml'
    gpu_id=1
    data_dir='../data/birds/'
    manualSeed=1
args = parse_args()

In [19]:

args = parse_args()
if args.cfg_file is not None:
    cfg_from_file(args.cfg_file)

if args.gpu_id == -1:
    cfg.CUDA = False
else:
    cfg.GPU_ID = args.gpu_id

if args.data_dir != '':
    cfg.DATA_DIR = args.data_dir
print('Using config:')
pprint.pprint(cfg)

if not cfg.TRAIN.FLAG:
    args.manualSeed = 100
elif args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
np.random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if cfg.CUDA:
    torch.cuda.manual_seed_all(args.manualSeed)



Using config:
{'B_VALIDATION': False,
 'CONFIG_NAME': 'DAMSM',
 'CUDA': True,
 'DATASET_NAME': 'birds',
 'DATA_DIR': '../data/birds/',
 'GAN': {'B_ATTENTION': True,
         'B_DCGAN': False,
         'CONDITION_DIM': 100,
         'DF_DIM': 64,
         'GF_DIM': 128,
         'R_NUM': 2,
         'Z_DIM': 100},
 'GPU_ID': 1,
 'RNN_TYPE': 'LSTM',
 'TEXT': {'CAPTIONS_PER_IMAGE': 10, 'EMBEDDING_DIM': 256, 'WORDS_NUM': 18},
 'TRAIN': {'BATCH_SIZE': 48,
           'B_NET_D': True,
           'DISCRIMINATOR_LR': 0.0002,
           'ENCODER_LR': 0.002,
           'FLAG': True,
           'GENERATOR_LR': 0.0002,
           'MAX_EPOCH': 600,
           'NET_E': '',
           'NET_G': '',
           'RNN_GRAD_CLIP': 0.25,
           'SMOOTH': {'GAMMA1': 4.0,
                      'GAMMA2': 5.0,
                      'GAMMA3': 10.0,
                      'LAMBDA': 1.0},
           'SNAPSHOT_INTERVAL': 50},
 'TREE': {'BASE_SIZE': 299, 'BRANCH_NUM': 1},
 'WORKERS': 1}


In [4]:
##########################################################################
now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
output_dir = '../output/%s_%s_%s' % \
    (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

model_dir = os.path.join(output_dir, 'Model')
image_dir = os.path.join(output_dir, 'Image')
mkdir_p(model_dir)
mkdir_p(image_dir)

torch.cuda.set_device(cfg.GPU_ID)
cudnn.benchmark = True



In [5]:

# Get data loader ##################################################
imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))
batch_size = cfg.TRAIN.BATCH_SIZE
image_transform = transforms.Compose([
    transforms.Scale(int(imsize * 76 / 64)),
    transforms.RandomCrop(imsize),
    transforms.RandomHorizontalFlip()])
dataset = TextDataset(cfg.DATA_DIR, 'train',
                      base_size=cfg.TREE.BASE_SIZE,
                      transform=image_transform)

print(dataset.n_words, dataset.embeddings_num)





Total filenames:  11788 001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg
Load filenames from: ../data/birds//train/filenames.pickle (8855)
Load filenames from: ../data/birds//test/filenames.pickle (2933)
Load from:  ../data/birds/captions.pickle
../data/birds/train
../data/birds/train/class_info.pickle
5450 10


In [6]:
assert dataset

In [7]:
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, drop_last=True,
    shuffle=True, num_workers=int(cfg.WORKERS))

# # validation data #
dataset_val = TextDataset(cfg.DATA_DIR, 'test',
                          base_size=cfg.TREE.BASE_SIZE,
                          transform=image_transform)
dataloader_val = torch.utils.data.DataLoader(
    dataset_val, batch_size=batch_size, drop_last=True,
    shuffle=True, num_workers=int(cfg.WORKERS))



Total filenames:  11788 001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg
Load filenames from: ../data/birds//train/filenames.pickle (8855)
Load filenames from: ../data/birds//test/filenames.pickle (2933)
Load from:  ../data/birds/captions.pickle
../data/birds/test
../data/birds/test/class_info.pickle


In [16]:

UPDATE_INTERVAL = 2



def train(dataloader, cnn_model, rnn_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
#         print(words_features.shape,sent_code.shape)
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)
#         print(nef,att_sze)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
#         print('here')
        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                                 cap_lens, class_ids, batch_size)
        
        
        
        
#         print(w_loss0.data)
#         print('--------------------------')
#         print(w_loss1.data)
#         print('--------------------------')
#         print(attn_maps[0].shape)
    
    
    
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1
#         print(loss)
        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        
#         print(s_total_loss0[0],s_total_loss1[0])
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

#             print(count)
            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            
#             print(imgs[-1].cpu().shape, captions.shape, len(attn_maps),attn_maps[-1].shape, att_sze)
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count


def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    s_cur_loss = s_total_loss[0] / step
    w_cur_loss = w_total_loss[0] / step

    return s_cur_loss, w_cur_loss


def build_models():
    print('here')
    # build model ############################################################
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    print(cfg.TRAIN.NET_E)
    if cfg.TRAIN.NET_E != '':
        state_dict = torch.load(cfg.TRAIN.NET_E)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        print(name)
        state_dict = torch.load(name)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
        print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch



In [17]:
# Train ##############################################################
text_encoder, image_encoder, labels, start_epoch = build_models()
para = list(text_encoder.parameters())
for v in image_encoder.parameters():
    if v.requires_grad:
        para.append(v)
# optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999))
# At any point you can hit Ctrl + C to break out of training early.

lr = cfg.TRAIN.ENCODER_LR


here
Load pretrained model from  https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth



In [11]:
for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):
    optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999))
    epoch_start_time = time.time()
    count = train(dataloader, image_encoder, text_encoder,
                  batch_size, labels, optimizer, epoch,
                  dataset.ixtoword, image_dir)
    print('-' * 89)
    if len(dataloader_val) > 0:
        s_loss, w_loss = evaluate(dataloader_val, image_encoder,
                                  text_encoder, batch_size)
        print('| end epoch {:3d} | valid loss '
              '{:5.2f} {:5.2f} | lr {:.5f}|'
              .format(epoch, s_loss, w_loss, lr))
    print('-' * 89)
    if lr > cfg.TRAIN.ENCODER_LR/10.:
        lr *= 0.98

    if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or
        epoch == cfg.TRAIN.MAX_EPOCH):
        torch.save(image_encoder.state_dict(),
                   '%s/image_encoder%d.pth' % (model_dir, epoch))
        torch.save(text_encoder.state_dict(),
                   '%s/text_encoder%d.pth' % (model_dir, epoch))
        print('Save G/Ds models.')

    

| epoch   0 |     0/  184 batches | ms/batch 363.27 | s_loss  1.93  1.93 | w_loss  1.94  1.93
| epoch   0 |     2/  184 batches | ms/batch 5974.01 | s_loss  3.84  3.87 | w_loss  3.85  3.93
| epoch   0 |     4/  184 batches | ms/batch 5504.15 | s_loss  3.80  3.84 | w_loss  3.84  3.84
| epoch   0 |     6/  184 batches | ms/batch 5671.67 | s_loss  3.78  3.82 | w_loss  3.84  3.81
| epoch   0 |     8/  184 batches | ms/batch 5658.30 | s_loss  3.73  3.76 | w_loss  3.97  3.77
| epoch   0 |    10/  184 batches | ms/batch 5555.09 | s_loss  3.73  3.83 | w_loss  3.90  3.81
| epoch   0 |    12/  184 batches | ms/batch 5614.50 | s_loss  3.74  3.85 | w_loss  3.88  3.78
| epoch   0 |    14/  184 batches | ms/batch 5800.69 | s_loss  3.69  3.75 | w_loss  3.81  3.73
| epoch   0 |    16/  184 batches | ms/batch 5585.99 | s_loss  3.69  3.71 | w_loss  3.93  3.70
| epoch   0 |    18/  184 batches | ms/batch 5750.51 | s_loss  3.65  3.72 | w_loss  3.92  3.67
| epoch   0 |    20/  184 batches | ms/batch 5657.8

KeyboardInterrupt: 