In [1]:
from __future__ import print_function

from miscc.utils import mkdir_p
from miscc.utils import build_super_images
from miscc.losses import sent_loss, words_loss
from miscc.config import cfg, cfg_from_file

from datasets import TextDataset
from datasets import prepare_data

from model import RNN_ENCODER, CNN_ENCODER

import os
import sys
import time
import random
import pprint
import datetime
import dateutil.tz
import argparse
import numpy as np
from PIL import Image
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:

# Report / log training progress every this many batches.
UPDATE_INTERVAL = 100


def parse_args():
    """Return the notebook's fixed run options as an argparse-style namespace.

    Stands in for command-line parsing inside a notebook; each call returns a
    fresh object holding the defaults, so mutating one result (e.g. its
    ``manualSeed``) does not affect later calls.
    """
    return argparse.Namespace(
        cfg_file='../code/cfg/DAMSM/coco.yml',
        gpu_id=3,
        data_dir='../data/coco/',
        manualSeed=1,
    )


args = parse_args()

In [3]:


def train(dataloader, cnn_model, trx_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    """Run one DAMSM training epoch over ``dataloader``.

    For every batch: zero both models' gradients, encode images with
    ``cnn_model`` (largest scale ``imgs[-1]``) and captions with
    ``trx_model(captions)``, sum the word-level (``words_loss``) and
    sentence-level (``sent_loss``) losses, backpropagate, clip the text
    encoder's gradients to ``cfg.TRAIN.RNN_GRAD_CLIP`` and step ``optimizer``.

    Every ``UPDATE_INTERVAL`` steps (including step 0) the mean losses over the
    batches accumulated since the previous report are printed and written to
    the module-level TensorBoard writer ``tbw`` at the global step, and an
    attention-map image from ``build_super_images`` is saved to ``image_dir``.

    Args:
        dataloader: iterable of batches accepted by ``prepare_data``.
        cnn_model: image encoder returning ``(words_features, sent_code)``.
        trx_model: text encoder returning ``(words_emb, sent_emb)``.
        batch_size: samples per batch, forwarded to the losses.
        labels: match labels forwarded to the losses.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (used for logging and file names).
        ixtoword: index-to-word mapping used to caption attention maps.
        image_dir: directory where attention-map PNGs are written.

    Returns:
        The global step ``epoch * len(dataloader) + step`` of the last report,
        or ``(epoch + 1) * len(dataloader)`` if no batch was processed.
    """
    cnn_model.train()
    trx_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    n_accum = 0  # batches accumulated since the last report
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        trx_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, _keys = prepare_data(data)

        # Image features from the largest image scale.
        words_features, sent_code = cnn_model(imgs[-1])
        # Spatial size of the region-feature map; needed to render attention maps.
        att_sze = words_features.size(2)

        words_emb, sent_emb = trx_model(captions)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                                 cap_lens, class_ids, batch_size)
        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss = (w_loss0 + w_loss1) + (s_loss0 + s_loss1)

        # Accumulate detached values so the running totals keep no autograd graph.
        w_total_loss0 += w_loss0.detach()
        w_total_loss1 += w_loss1.detach()
        s_total_loss0 += s_loss0.detach()
        s_total_loss1 += s_loss1.detach()
        n_accum += 1

        loss.backward()
        # `clip_grad_norm_` guards against exploding gradients; only the text
        # encoder's parameters are clipped.
        torch.nn.utils.clip_grad_norm_(trx_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            # Average over the batches actually accumulated (only 1 at step 0),
            # not a fixed UPDATE_INTERVAL.
            s_cur_loss0 = (s_total_loss0 / n_accum).item()
            s_cur_loss1 = (s_total_loss1 / n_accum).item()
            w_cur_loss0 = (w_total_loss0 / n_accum).item()
            w_cur_loss1 = (w_total_loss1 / n_accum).item()

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / n_accum,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            # Log at the global step so several reports per epoch do not
            # overwrite each other in TensorBoard.
            tbw.add_scalar('train_w_loss0', w_cur_loss0, count)
            tbw.add_scalar('train_s_loss0', s_cur_loss0, count)
            tbw.add_scalar('train_w_loss1', w_cur_loss1, count)
            tbw.add_scalar('train_s_loss1', s_cur_loss1, count)

            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            n_accum = 0
            start_time = time.time()

            # Save the attention maps of the current batch.
            img_set, _ = build_super_images(imgs[-1].cpu(), captions, ixtoword,
                                            attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '{0}/attention_maps_e{1}_s{2}.png'.format(image_dir, epoch, step)
                im.save(fullpath)
    return count

In [4]:

def evaluate(dataloader, cnn_model, rnn_model, batch_size, max_batches=51):
    """Compute mean validation DAMSM losses over at most ``max_batches`` batches.

    Supports both text-encoder flavours used in this notebook:

    - an RNN encoder exposing ``init_hidden``, called as
      ``rnn_model(captions, cap_lens, hidden)``;
    - the transformer encoder, called as ``rnn_model(captions)``.

    Uses the module-level ``labels`` for the matching losses. Runs under
    ``torch.no_grad()`` because no backward pass is needed.

    Args:
        dataloader: iterable of batches accepted by ``prepare_data``.
        cnn_model: image encoder returning ``(words_features, sent_code)``.
        rnn_model: text encoder (RNN or transformer, see above).
        batch_size: samples per batch, forwarded to the losses.
        max_batches: maximum number of batches to score (default 51, the
            number the original ``step == 50`` cut-off processed).

    Returns:
        ``(s_cur_loss, w_cur_loss)``: per-batch means of the summed sentence
        losses and the summed word losses, as tensors (callers use ``.item()``).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    n_batches = 0
    with torch.no_grad():
        for data in dataloader:
            real_imgs, captions, cap_lens, class_ids, _keys = prepare_data(data)

            words_features, sent_code = cnn_model(real_imgs[-1])

            if hasattr(rnn_model, 'init_hidden'):
                # Recurrent encoder: needs an initial hidden state and lengths.
                hidden = rnn_model.init_hidden(batch_size)
                words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
            else:
                # Transformer encoder: takes the token indices only.
                words_emb, sent_emb = rnn_model(captions)

            w_loss0, w_loss1, _attn = words_loss(words_features, words_emb, labels,
                                                 cap_lens, class_ids, batch_size)
            w_total_loss += (w_loss0 + w_loss1).data

            s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
            s_total_loss += (s_loss0 + s_loss1).data

            n_batches += 1
            if n_batches >= max_batches:
                break

    if n_batches == 0:
        raise ValueError('evaluate() received an empty dataloader')

    # Divide by the number of batches actually scored (the original divided
    # by the last step index, one less than the batch count).
    s_cur_loss = s_total_loss / n_batches
    w_cur_loss = w_total_loss / n_batches

    return s_cur_loss, w_cur_loss

In [5]:
import torch.nn.functional as F
################ Transformer: Text Encoder ############
class SelfAttention(nn.Module):
    """Unmasked multi-head self-attention where every head works at full width ``k``.

    Each of the ``heads`` heads has its own ``k``-dimensional query/key/value
    projection (computed together by one bias-free linear layer per role).
    The head outputs are concatenated to width ``heads * k`` and projected
    back to ``k`` by ``unifyheads``.

    Args:
        k: input/output feature width.
        heads: number of attention heads.
    """

    def __init__(self, k, heads=8):
        super().__init__()
        self.k = k
        self.heads = heads
        all_heads = k * heads
        # Projections for all heads at once, as a single concatenated vector.
        self.tokeys = nn.Linear(k, all_heads, bias=False)
        self.toqueries = nn.Linear(k, all_heads, bias=False)
        self.tovalues = nn.Linear(k, all_heads, bias=False)
        # Merges the concatenated head outputs into a single k-vector.
        self.unifyheads = nn.Linear(heads * k, k)

    def forward(self, x):
        """Attend over ``x`` of shape ``(b, t, k)``; returns a ``(b, t, k)`` tensor."""
        b, t, k = x.size()
        h = self.heads

        def fold_heads(projection):
            # (b, t, h*k) -> (b*h, t, k): move the head axis into the batch axis.
            return projection(x).view(b, t, h, k).transpose(1, 2).contiguous().view(b * h, t, k)

        # Dividing queries and keys by k**(1/4) each scales the scores by 1/sqrt(k).
        root4 = k ** (1/4)
        q = fold_heads(self.toqueries) / root4
        kk = fold_heads(self.tokeys) / root4
        v = fold_heads(self.tovalues)

        # Row-normalised attention weights of shape (b*h, t, t).
        weights = F.softmax(torch.bmm(q, kk.transpose(1, 2)), dim=2)

        mixed = torch.bmm(weights, v).view(b, h, t, k)
        merged = mixed.transpose(1, 2).contiguous().view(b, t, h * k)
        return self.unifyheads(merged)

class TransformerBlock(nn.Module):
    """Post-norm transformer block.

    Applies self-attention and then a position-wise feed-forward network.
    Each sub-layer is followed by a residual add, a LayerNorm and dropout.

    ``mask`` is stored on the instance but not used by ``forward``;
    ``seq_length`` and ``wide`` are accepted for interface compatibility and
    ignored.
    """

    def __init__(self, emb, heads, mask, seq_length, ff_hidden_mult=4, dropout=0.0, wide=True):
        super().__init__()
        self.attention = SelfAttention(k=emb, heads=heads)
        self.mask = mask

        self.norm1 = nn.LayerNorm(emb)
        self.norm2 = nn.LayerNorm(emb)

        hidden_width = ff_hidden_mult * emb
        self.ff = nn.Sequential(
            nn.Linear(emb, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, emb),
        )

        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, seq, emb)`` to a tensor of the same shape."""
        x = self.do(self.norm1(x + self.attention(x)))
        return self.do(self.norm2(x + self.ff(x)))
class TEXT_TRANSFORMER_ENCODER(nn.Module):
    """Transformer text encoder producing word- and sentence-level embeddings.

    Token embeddings and learned absolute position embeddings (both of width
    ``emb``) are summed, passed through dropout and then through ``depth``
    ``TransformerBlock`` layers.

    Args:
        emb: embedding width (also the output feature size).
        heads: attention heads per block.
        depth: number of transformer blocks.
        seq_length: maximum supported sequence length (size of the
            positional-embedding table).
        num_tokens: vocabulary size.
        dropout: dropout probability.
        wide: forwarded to each ``TransformerBlock``.
    """

    def __init__(self, emb, heads, depth, seq_length, num_tokens, dropout=0.0, wide=False):
        super().__init__()

        self.num_tokens = num_tokens

        self.token_embedding = nn.Embedding(embedding_dim=emb, num_embeddings=num_tokens)
        self.pos_embedding = nn.Embedding(embedding_dim=emb, num_embeddings=seq_length)

        self.tblocks = nn.Sequential(*[
            TransformerBlock(emb=emb,
                             heads=heads,
                             seq_length=seq_length,
                             mask=False,
                             dropout=dropout,
                             wide=wide)
            for _ in range(depth)])

        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """Encode a batch of token-index sequences.

        :param x: ``(batch, seq_len)`` integer tensor of token indices, with
            ``seq_len`` no larger than the ``seq_length`` given at construction.
        :return: ``(words_emb, sent_emb)``. ``words_emb`` has shape
            ``(batch, emb, seq_len)`` and holds the per-token transformer
            outputs, channels first. ``sent_emb`` has shape ``(batch, emb)``
            and is the mean of those outputs over the sequence dimension.
        :raises ValueError: if ``seq_len`` exceeds the positional table size.
        """
        tokens = self.token_embedding(x)
        b, t, e = tokens.size()
        if t > self.pos_embedding.num_embeddings:
            raise ValueError('sequence length {0} exceeds the maximum of {1}'
                             .format(t, self.pos_embedding.num_embeddings))

        # Build position indices on the input's device (works on CPU and any GPU).
        positions = torch.arange(t, device=x.device)
        positions = self.pos_embedding(positions)[None, :, :].expand(b, t, e)

        hidden = self.do(tokens + positions)
        hidden = self.tblocks(hidden)

        # Pool the transformer outputs, not the raw input embeddings, so the
        # sentence embedding also goes through self-attention.
        sent_emb = hidden.mean(dim=1)
        words_emb = torch.transpose(hidden, 1, 2)

        return words_emb, sent_emb


In [6]:
# (Leftover scratch cell `np.arange(3)` and its output removed; it did not affect training.)

In [7]:

def build_models(heads=8, depth=4):
    """Build the transformer text encoder, the CNN image encoder and the match labels.

    Reads the module-level ``cfg``, ``dataset`` (for ``n_words``) and
    ``batch_size``; these must be defined before calling.

    If ``cfg.TRAIN.NET_E`` is non-empty:

    - the text-encoder weights are loaded from that path;
    - the image-encoder weights are loaded from the same path with
      ``'text_encoder'`` replaced by ``'image_encoder'`` in the file name;
    - training resumes at one past the number that ends the checkpoint's
      file stem (e.g. ``text_encoder200.pth`` -> 201).

    Args:
        heads: attention heads per transformer block.
        depth: number of transformer blocks.

    Returns:
        ``(text_encoder, image_encoder, labels, start_epoch)``. Models and
        labels are moved to the GPU when ``cfg.CUDA`` is set.

    Raises:
        ValueError: if no epoch number can be read from ``cfg.TRAIN.NET_E``.
    """
    import re

    text_encoder = TEXT_TRANSFORMER_ENCODER(emb=cfg.TEXT.EMBEDDING_DIM,
                                            heads=heads,
                                            depth=depth,
                                            seq_length=cfg.TEXT.WORDS_NUM,
                                            num_tokens=dataset.n_words)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    # Labels 0..batch_size-1 for the matching losses (presumably sample i
    # matches caption i).
    labels = torch.arange(batch_size, dtype=torch.long)
    start_epoch = 0
    if cfg.TRAIN.NET_E != '':
        text_path = cfg.TRAIN.NET_E
        print('Loading... ', text_path)
        # map_location='cpu' lets checkpoints saved on any GPU load anywhere;
        # the models are moved to the GPU below when cfg.CUDA is set.
        state_dict = torch.load(text_path, map_location='cpu')
        text_encoder.load_state_dict(state_dict)
        print('Load ', text_path)

        # Only rewrite the file name, never directory components.
        text_dir, text_file = os.path.split(text_path)
        name = os.path.join(text_dir, text_file.replace('text_encoder', 'image_encoder'))
        state_dict = torch.load(name, map_location='cpu')
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        stem = os.path.splitext(text_file)[0]
        epoch_match = re.search(r'(\d+)$', stem)
        if epoch_match is None:
            raise ValueError('Cannot infer the epoch from checkpoint name: {0}'
                             .format(text_path))
        start_epoch = int(epoch_match.group(1)) + 1
        print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch


In [8]:

# Per-epoch validation losses, filled by the training loop.
wlosses = []
slosses = []

# Re-create args so re-running this cell starts again from the defaults.
args = parse_args()
if args.cfg_file is not None:
    cfg_from_file(args.cfg_file)

if args.gpu_id == -1:
    cfg.CUDA = False
else:
    cfg.GPU_ID = args.gpu_id

if args.data_dir != '':
    cfg.DATA_DIR = args.data_dir
print('Using config:')
pprint.pprint(cfg)

# Seeding: fixed seed 5000 when cfg.TRAIN.FLAG is false, a random seed when
# none was given, otherwise args.manualSeed.
if not cfg.TRAIN.FLAG:
    args.manualSeed = 5000
elif args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
np.random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if cfg.CUDA:
    torch.cuda.manual_seed_all(args.manualSeed)

##########################################################################
now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
output_dir = '../output/transfomer_test_{0}_{1}_{2}'.format(cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

model_dir = os.path.join(output_dir, 'Model')
image_dir = os.path.join(output_dir, 'Image')
metrics_dir = os.path.join(output_dir, 'Metrics')
mkdir_p(model_dir)
mkdir_p(image_dir)
mkdir_p(metrics_dir)

# Only select a GPU when CUDA is enabled; torch.cuda.set_device fails on
# machines without CUDA (the gpu_id == -1 / cfg.CUDA = False case).
if cfg.CUDA:
    torch.cuda.set_device(cfg.GPU_ID)
cudnn.benchmark = True



Using config:
{'B_VALIDATION': False,
 'CONFIG_NAME': 'DAMSM',
 'CUDA': True,
 'DATASET_NAME': 'coco',
 'DATA_DIR': '../data/coco/',
 'GAN': {'B_ATTENTION': True,
         'B_DCGAN': False,
         'CONDITION_DIM': 100,
         'DF_DIM': 64,
         'GF_DIM': 128,
         'R_NUM': 2,
         'Z_DIM': 100},
 'GPU_ID': 3,
 'RNN_TYPE': 'LSTM',
 'TEXT': {'CAPTIONS_PER_IMAGE': 5, 'EMBEDDING_DIM': 256, 'WORDS_NUM': 15},
 'TRAIN': {'BATCH_SIZE': 96,
           'B_NET_D': True,
           'DISCRIMINATOR_LR': 0.0002,
           'ENCODER_LR': 0.002,
           'FLAG': True,
           'GENERATOR_LR': 0.0002,
           'MAX_EPOCH': 600,
           'NET_E': '',
           'NET_G': '',
           'RNN_GRAD_CLIP': 0.25,
           'SMOOTH': {'GAMMA1': 4.0,
                      'GAMMA2': 5.0,
                      'GAMMA3': 10.0,
                      'LAMBDA': 1.0},
           'SNAPSHOT_INTERVAL': 5},
 'TREE': {'BASE_SIZE': 299, 'BRANCH_NUM': 1},
 'WORKERS': 1}


  yaml_cfg = edict(yaml.load(f))


In [None]:
# Get data loader ##################################################
# Image size of the largest generator branch.
imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
batch_size = cfg.TRAIN.BATCH_SIZE
# Training augmentation: upscale by 76/64, random-crop back to imsize, random flip.
# transforms.Resize replaces transforms.Scale, which is deprecated and removed
# in newer torchvision releases (same behaviour for an int size).
image_transform = transforms.Compose([
    transforms.Resize(int(imsize * 76 / 64)),
    transforms.RandomCrop(imsize),
    transforms.RandomHorizontalFlip()])
dataset = TextDataset(cfg.DATA_DIR, 'train',
                      base_size=cfg.TREE.BASE_SIZE,
                      transform=image_transform)
# Check the dataset before anything uses it.
assert len(dataset) > 0, 'training dataset is empty'

print(dataset.n_words, dataset.embeddings_num)


In [None]:
# drop_last keeps every batch at exactly batch_size, matching the fixed-size
# labels built in build_models.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, drop_last=True,
    shuffle=True, num_workers=int(cfg.WORKERS))

# Validation data.
# Use deterministic preprocessing (resize + center crop, no flip) and a fixed
# order. evaluate() only scores a capped number of batches, so this keeps the
# scored subset, and therefore the per-epoch validation losses, comparable
# across epochs.
val_image_transform = transforms.Compose([
    transforms.Resize(int(imsize * 76 / 64)),
    transforms.CenterCrop(imsize)])
dataset_val = TextDataset(cfg.DATA_DIR, 'test',
                          base_size=cfg.TREE.BASE_SIZE,
                          transform=val_image_transform)
dataloader_val = torch.utils.data.DataLoader(
    dataset_val, batch_size=batch_size, drop_last=True,
    shuffle=False, num_workers=int(cfg.WORKERS))



In [None]:
# Train ##############################################################
# Build both encoders (optionally resuming from cfg.TRAIN.NET_E) and the match labels.
(text_encoder, image_encoder,
 labels, start_epoch) = build_models()


In [None]:
# Show the text encoder's module structure (the cell's last-expression output).
text_encoder

In [None]:
# Parameters to optimise: every text-encoder parameter, plus only those
# image-encoder parameters with requires_grad=True (frozen ones are skipped).
# Training can be stopped early at any point with Ctrl+C.
para = list(text_encoder.parameters()) + [
    weight for weight in image_encoder.parameters() if weight.requires_grad
]


In [None]:
# TensorBoard run directory, named like the output directory of this run.
tb_dir = '../tensorboard/transfomer_test_%s_%s_%s' % (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
mkdir_p(tb_dir)
tbw = SummaryWriter(log_dir=tb_dir)  # used by train() and the training loop


In [None]:
lr = cfg.TRAIN.ENCODER_LR
# Create the optimizer once so Adam's moment estimates persist across epochs.
# The per-epoch learning-rate decay is applied to its param groups instead of
# rebuilding the optimizer (which reset its state every epoch).
optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999))
try:
    for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):
        for group in optimizer.param_groups:
            group['lr'] = lr

        count = train(dataloader, image_encoder, text_encoder,
                      batch_size, labels, optimizer, epoch,
                      dataset.ixtoword, image_dir)
        print('-' * 89)
        if len(dataloader_val) > 0:
            s_loss, w_loss = evaluate(dataloader_val, image_encoder,
                                      text_encoder, batch_size)
            # Store plain floats so the CSV below holds numbers, not tensor reprs.
            wlosses.append(float(w_loss.item()))
            slosses.append(float(s_loss.item()))
            print('| end epoch {:3d} | valid loss '
                  '{:5.2f} {:5.2f} | lr {:.5f}|'
                  .format(epoch, s_loss, w_loss, lr))
            tbw.add_scalar('val_w_loss', float(w_loss.item()), epoch)
            tbw.add_scalar('val_s_loss', float(s_loss.item()), epoch)
        print('-' * 89)
        # Decay the learning rate by 2% per epoch down to a floor of ENCODER_LR / 10.
        if lr > cfg.TRAIN.ENCODER_LR / 10.:
            lr *= 0.98

        # MAX_EPOCH - 1 is the last epoch range() yields; comparing with
        # MAX_EPOCH could never be true, so the final epoch was not always saved.
        if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or
                epoch == cfg.TRAIN.MAX_EPOCH - 1):
            torch.save(image_encoder.state_dict(),
                       '{0}/image_encoder{1}.pth'.format(model_dir, epoch))
            torch.save(text_encoder.state_dict(),
                       '{0}/text_encoder{1}.pth'.format(model_dir, epoch))
            print('Save encoder models.')
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')
finally:
    # Persist the validation losses even when training is interrupted.
    df = pd.DataFrame({'eval_wlosses': wlosses, 'eval_slosses': slosses})
    df.to_csv('{0}/val_losses.csv'.format(metrics_dir))


In [None]:
# captions: tensor([[   10,   115,   485,  ...,    21,     8,   423],
#         [   10,    78,   946,  ...,    11,    10,   423],
#         [   56,   673,   674,  ...,    47,   795,    11],
#         ...,
#         [   56,    89,   868,  ...,     0,     0,     0],
#         [   10,  1353,   115,  ...,     0,     0,     0],
#         [   10, 14329,   115,  ...,     0,     0,     0]], device='cuda:3') torch.Size([96, 15])
# b:96, t:15, e:256
# positions: torch.Size([96, 15, 256])
# x: torch.Size([96, 15, 256])
# words_emb: torch.Size([96, 15, 256]) sent_emb: torch.Size([96, 256])
# words_emb: torch.Size([96, 15, 256]) , sent_emb: torch.Size([96, 256])
# contextT: torch.Size([96, 289, 256]) query: torch.Size([96, 15, 15])