In [1]:
import random
from data import ImageDetectionsField, TextField, RawField
from data import COCO, DataLoader
import evaluation
from evaluation import PTBTokenizer, Cider
from models.transformer import Transformer, LinearEncoder, Decoder
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import NLLLoss
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import argparse, os, pickle
import numpy as np
import itertools
import multiprocessing
from shutil import copyfile

import horovod.torch as hvd
import torch.multiprocessing as mp
from apex import amp

# Give the NumPy, Python and PyTorch random generators the same fixed seed so
# that repeated runs draw the same random numbers.
np.random.seed(1234)
random.seed(1234)
torch.manual_seed(1234)

class objectview:
    """Attribute-style view over a dictionary.

    The dictionary handed to the constructor becomes the instance's
    ``__dict__`` itself (no copy is made), so ``view.key`` reads ``d['key']``
    and writes through either side are visible through the other.
    """

    def __init__(self, d):
        # Install the caller's dict as the attribute storage of this object.
        object.__setattr__(self, '__dict__', d)

In [2]:
def evaluate_loss(model, dataloader, loss_fn, text_field):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is put in eval mode and run without gradient tracking. For each
    ``(detections, captions)`` batch, the model output without its last step
    along dim 1 is scored against ``captions`` without its first entry along
    dim 1.

    Relies on two module-level names: ``device`` (where each batch is moved)
    and ``e`` (the epoch number shown in the progress-bar title).

    Args:
        model: callable as ``model(detections, captions)``.
        dataloader: sized iterable yielding ``(detections, captions)`` batches.
        loss_fn: callable taking the flattened scores and flattened targets.
        text_field: object whose ``vocab`` length is the width of the
            flattened scores.

    Returns:
        The sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.eval()
    total = .0
    with tqdm(desc='Epoch %d - validation' % e, unit='it', total=len(dataloader)) as pbar, \
            torch.no_grad():
        for step, (detections, captions) in enumerate(dataloader, start=1):
            detections = detections.to(device)
            captions = captions.to(device)
            scores = model(detections, captions)

            # Pair output step t with caption entry t + 1.
            targets = captions[:, 1:].contiguous()
            scores = scores[:, :-1].contiguous()
            batch_loss = loss_fn(scores.view(-1, len(text_field.vocab)), targets.view(-1))
            total += batch_loss.item()

            # Show the running mean over the batches processed so far.
            pbar.set_postfix(loss=total / step)
            pbar.update()

    return total / len(dataloader)


def evaluate_metrics(model, dataloader, text_field):
    """Generate captions with beam search and score them against references.

    For every ``(images, caps_gt)`` batch, captions are produced with
    ``model.beam_search`` (length 20, beam size 5, one output per image),
    decoded to word lists, and consecutive repeated words are collapsed to a
    single occurrence. Generated and reference captions are keyed by
    ``'<batch index>_<sample index>'``, tokenized with
    ``evaluation.PTBTokenizer`` and passed to ``evaluation.compute_scores``.

    Relies on the module-level names ``device`` and ``e`` (epoch number used
    in the progress-bar title).

    Args:
        model: object exposing ``eval()`` and ``beam_search(...)``.
        dataloader: sized iterable yielding ``(images, caps_gt)`` batches.
        text_field: object exposing ``vocab.stoi`` and ``decode``.

    Returns:
        The first element returned by ``evaluation.compute_scores``.
    """
    from itertools import groupby

    model.eval()
    generated = {}
    references = {}
    with tqdm(desc='Epoch %d - evaluation' % e, unit='it', total=len(dataloader)) as pbar:
        for batch_idx, (images, caps_gt) in enumerate(dataloader):
            images = images.to(device)
            with torch.no_grad():
                out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
            caps_gen = text_field.decode(out, join_words=False)
            for sample_idx, (ref_caps, words) in enumerate(zip(caps_gt, caps_gen)):
                key = '%d_%d' % (batch_idx, sample_idx)
                # Keep one word out of every run of identical consecutive words.
                collapsed = [word for word, _run in groupby(words)]
                generated[key] = [' '.join(collapsed), ]
                references[key] = ref_caps
            pbar.update()

    references = evaluation.PTBTokenizer.tokenize(references)
    generated = evaluation.PTBTokenizer.tokenize(generated)
    scores, _ = evaluation.compute_scores(references, generated)
    return scores


def train_xe(model, dataloader, optim, text_field):
    """Run one epoch of cross-entropy training and return the mean batch loss.

    For each ``(detections, captions)`` batch, the model output without its
    last step along dim 1 is scored against ``captions`` without its first
    entry along dim 1. The learning-rate scheduler is stepped once per batch.

    Relies on the module-level names ``device``, ``e`` (epoch number for the
    progress bar), ``args`` (``use_amp`` flag), ``loss_fn`` and ``scheduler``.

    Args:
        model: callable as ``model(detections, captions)``.
        dataloader: sized iterable yielding ``(detections, captions)`` batches.
        optim: optimizer; when ``args.use_amp`` is set it must be a Horovod
            ``DistributedOptimizer`` returned by ``amp.initialize``.
        text_field: object whose ``vocab`` length is the width of the
            flattened scores.

    Returns:
        The sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()
    running_loss = .0
    with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar:
        for it, (detections, captions) in enumerate(dataloader):
            detections, captions = detections.to(device), captions.to(device)

            optim.zero_grad()
            out = model(detections, captions)
            captions_gt = captions[:, 1:].contiguous()
            out = out[:, :-1].contiguous()
            loss = loss_fn(out.view(-1, len(text_field.vocab)), captions_gt.view(-1))

            if args.use_amp:
                # Horovod + Apex: finish the gradient allreduce while the loss
                # scaler is still active, then step without synchronizing a
                # second time. skip_synchronize() is a context manager; calling
                # it as a plain statement never runs optim.step().
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
                    optim.synchronize()
                with optim.skip_synchronize():
                    optim.step()
            else:
                loss.backward()
                optim.step()

            running_loss += loss.item()

            pbar.set_postfix(loss=running_loss / (it + 1))
            pbar.update()
            scheduler.step()

    loss = running_loss / len(dataloader)
    return loss


def train_scst(model, dataloader, optim, cider, text_field):
    """Run one epoch of self-critical training.

    For each batch, ``beam_size`` captions per image are produced with
    ``model.beam_search``. Each caption's reward is the per-caption score
    from ``cider.compute_score``; the baseline is the mean reward over the
    beams of the same image. The loss is the mean of
    ``-mean(log_probs, -1) * (reward - baseline)``.

    Relies on the module-level names ``device``, ``e`` (epoch number for the
    progress bar) and ``args`` (``use_amp`` flag).

    Args:
        model: object exposing ``train()`` and ``beam_search(...)``.
        dataloader: sized iterable yielding ``(detections, caps_gt)`` batches.
        optim: optimizer; when ``args.use_amp`` is set it must be a Horovod
            ``DistributedOptimizer`` returned by ``amp.initialize``.
        cider: scorer exposing ``compute_score(gts, gen)``.
        text_field: object exposing ``vocab.stoi`` and ``decode``.

    Returns:
        Tuple ``(loss, reward, reward_baseline)``, each the per-batch sum
        divided by ``len(dataloader)``.
    """
    running_reward = .0
    running_reward_baseline = .0
    model.train()
    running_loss = .0
    seq_len = 20
    beam_size = 5

    # The pool is scoped to this call so its worker processes are released
    # when the epoch ends instead of piling up across epochs.
    with multiprocessing.Pool() as tokenizer_pool, \
            tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar:
        for it, (detections, caps_gt) in enumerate(dataloader):
            detections = detections.to(device)

            optim.zero_grad()
            outs, log_probs = model.beam_search(detections, seq_len, text_field.vocab.stoi['<eos>'],
                                                beam_size, out_size=beam_size)

            # Rewards
            caps_gen = text_field.decode(outs.view(-1, seq_len))
            # Repeat each image's references once per generated beam.
            caps_gt = list(itertools.chain(*([c, ] * beam_size for c in caps_gt)))
            caps_gen, caps_gt = tokenizer_pool.map(evaluation.PTBTokenizer.tokenize, [caps_gen, caps_gt])
            reward = cider.compute_score(caps_gt, caps_gen)[1].astype(np.float32)
            reward = torch.from_numpy(reward).to(device).view(detections.shape[0], beam_size)
            reward_baseline = torch.mean(reward, -1, keepdim=True)
            loss = -torch.mean(log_probs, -1) * (reward - reward_baseline)

            loss = loss.mean()

            if args.use_amp:
                # Horovod + Apex: allreduce inside the loss-scaling context,
                # then step without synchronizing again. skip_synchronize() is
                # a context manager; a bare call never runs optim.step().
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
                    optim.synchronize()
                with optim.skip_synchronize():
                    optim.step()
            else:
                loss.backward()
                optim.step()

            running_loss += loss.item()
            running_reward += reward.mean().item()
            running_reward_baseline += reward_baseline.mean().item()
            pbar.set_postfix(loss=running_loss / (it + 1), reward=running_reward / (it + 1),
                             reward_baseline=running_reward_baseline / (it + 1))
            pbar.update()

    loss = running_loss / len(dataloader)
    reward = running_reward / len(dataloader)
    reward_baseline = running_reward_baseline / len(dataloader)
    return loss, reward, reward_baseline

In [3]:
# Run configuration for this notebook. It is wrapped in ``objectview`` below
# so the rest of the code can read entries as ``args.<name>``.
cfg = {
    # Experiment name: used in the log directory, vocabulary file and
    # checkpoint file names.
    'exp_name': 'transformer',
    'batch_size': 200,
    'workers': 2,
    'head': 8,
    'warmup': 10000,
    # Checkpoint to resume from, if any.
    'resume_last': False,
    'resume_best': False,
    # Input data and output locations.
    'features_path': './data/coco_detections.hdf5',
    'annotation_folder': './data/annotations/',
    'logs_folder': './tensorboard_logs',
    # Numbers handed to the encoder and decoder constructors.
    'N_enc': 1,
    'N_dec': 1,
    'use_amp': True,
    'cuda': True,
}

args = objectview(cfg)

In [4]:
# Device that models and batches are moved to. It follows the ``cuda`` entry
# of the run configuration, like the other CUDA-specific setup in this
# notebook, instead of being hard-coded.
device = torch.device('cuda' if args.cuda else 'cpu')

In [5]:
print('Transformer Training')

# TensorBoard event files are written under <logs_folder>/<exp_name>.
writer = SummaryWriter(log_dir=os.path.join(args.logs_folder, args.exp_name))

# Pipeline for image regions
image_field = ImageDetectionsField(
    detections_path=args.features_path,
    max_detections=50,
    load_in_tmp=False,
)

# Pipeline for text
text_field = TextField(
    init_token='<bos>',
    eos_token='<eos>',
    lower=True,
    tokenize='spacy',
    remove_punctuation=True,
    nopoints=False,
)

# Horovod: initialize library.
hvd.init()
# NOTE: this re-seeds PyTorch with 1, replacing the seed of 1234 set at the
# top of the notebook.
torch.manual_seed(1)

# Extra keyword arguments shared by every DataLoader built later on.
kwargs = {}
if args.cuda:
    # Horovod: pin GPU to local rank.
    torch.cuda.set_device(hvd.local_rank())
    torch.cuda.manual_seed(1)
    kwargs['pin_memory'] = True

# Horovod: limit # of CPU threads to be used per worker.
torch.set_num_threads(1)

# When supported, use 'forkserver' to spawn dataloader workers instead of
# 'fork' to prevent issues with Infiniband implementations that are not
# fork-safe.
# NOTE(review): ``kwargs`` never contains 'num_workers' at this point, so the
# first test is always false and 'forkserver' is never selected.
if (kwargs.get('num_workers', 0) > 0
        and getattr(mp, '_supports_context', False)
        and 'forkserver' in mp.get_all_start_methods()):
    kwargs['multiprocessing_context'] = 'forkserver'



Transformer Training


In [6]:
# Create the dataset
# Build the COCO dataset from the two fields. The annotation folder is passed
# twice, as the last two positional arguments.
dataset = COCO(
    image_field,
    text_field,
    'coco/images/',
    args.annotation_folder,
    args.annotation_folder,
)
# ``dataset.splits`` unpacks into the train, validation and test splits.
(train_dataset, val_dataset, test_dataset) = dataset.splits

In [7]:
# Notebook inspection: display the class of the training split.
type(train_dataset)

data.dataset.PairedDataset

In [None]:
# Build the vocabulary once and cache it on disk as vocab_<exp_name>.pkl;
# later runs load the cached copy. The files are opened in ``with`` blocks so
# the handles are closed (and the dump flushed) as soon as the work is done.
# NOTE(review): every Horovod process runs this cell, so on a first run
# several processes may build and write the same file at once.
if not os.path.isfile('vocab_%s.pkl' % args.exp_name):
    print("Building vocabulary")
    text_field.build_vocab(train_dataset, val_dataset, min_freq=5)
    with open('vocab_%s.pkl' % args.exp_name, 'wb') as vocab_file:
        pickle.dump(text_field.vocab, vocab_file)
else:
    # pickle executes code from the file it loads: only use vocabulary files
    # produced by this notebook.
    with open('vocab_%s.pkl' % args.exp_name, 'rb') as vocab_file:
        text_field.vocab = pickle.load(vocab_file)

# Model and dataloaders
# Model and dataloaders
encoder = LinearEncoder(args.N_enc, 0)
decoder = Decoder(
    len(text_field.vocab),
    54,
    args.N_dec,
    text_field.vocab.stoi['<pad>'],
)
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

# Training split as an image dictionary, plus a CIDEr scorer built from the
# tokenized training captions.
dict_dataset_train = train_dataset.image_dictionary(
    {'image': image_field, 'text': RawField()}
)
ref_caps_train = list(train_dataset.text)
cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train))

# The validation and test splits get the same image-dictionary treatment.
dict_dataset_val = val_dataset.image_dictionary(
    {'image': image_field, 'text': RawField()}
)
dict_dataset_test = test_dataset.image_dictionary(
    {'image': image_field, 'text': RawField()}
)


def lambda_lr(s):
    """Return the learning-rate factor for scheduler step ``s``.

    With ``step = s + 1`` the factor is
    ``model.d_model ** -0.5 * min(step ** -0.5, step * args.warmup ** -1.5)``:
    it grows linearly in ``step`` while the second term is the smaller one and
    decays as ``step ** -0.5`` afterwards.

    Reads the module-level ``args.warmup`` and ``model.d_model``.
    """
    step = s + 1
    decay = step ** -.5
    ramp = step * args.warmup ** -1.5
    return (model.d_model ** -.5) * min(decay, ramp)


# Initial conditions
# Initial conditions
# The base learning rate is 1, so the factor returned by ``lambda_lr`` is the
# learning rate that LambdaLR actually applies.
optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))

# Horovod: wrap optimizer with DistributedOptimizer.
optim = hvd.DistributedOptimizer(optim, named_parameters=model.named_parameters())

# Apex mixed precision: patch the model and optimizer (opt level O1) when
# requested. The call has to sit inside the ``if`` body.
if args.use_amp:
    model, optim = amp.initialize(model, optim, opt_level="O1")

scheduler = LambdaLR(optim, lambda_lr)

# Training state. The resume step below may overwrite the last four values
# from a checkpoint.
loss_fn = NLLLoss(ignore_index=text_field.vocab.stoi['<pad>'])
use_rl = False
best_cider = .0
patience = 0
start_epoch = 0

# Optionally resume from the last or the best checkpoint. Everything below is
# nested under the resume test so that ``fname`` is only used when it has
# been defined.
if args.resume_last or args.resume_best:
    if args.resume_last:
        fname = 'saved_models/%s_last.pth' % args.exp_name
    else:
        fname = 'saved_models/%s_best.pth' % args.exp_name

    # A missing checkpoint file is not an error: training starts from scratch.
    if os.path.exists(fname):
        data = torch.load(fname)
        # Restore the state of every random generator saved in the checkpoint.
        torch.set_rng_state(data['torch_rng_state'])
        torch.cuda.set_rng_state(data['cuda_rng_state'])
        np.random.set_state(data['numpy_rng_state'])
        random.setstate(data['random_rng_state'])
        # strict=False: keys missing from or unknown to the model are ignored.
        model.load_state_dict(data['state_dict'], strict=False)
        optim.load_state_dict(data['optimizer'])
        scheduler.load_state_dict(data['scheduler'])
        # Continue with the epoch after the one stored in the checkpoint.
        start_epoch = data['epoch'] + 1
        best_cider = data['best_cider']
        patience = data['patience']
        use_rl = data['use_rl']
        print('Resuming from epoch %d, validation loss %f, and best cider %f' % (
            data['epoch'], data['val_loss'], data['best_cider']))


# Horovod Distribute Sampler
# Horovod Distribute Sampler: one sampler per dataset, each configured with
# the number of Horovod processes and the rank of this one.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
val_sampler = torch.utils.data.distributed.DistributedSampler(
    val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
dict_train_sampler = torch.utils.data.distributed.DistributedSampler(
    dict_dataset_train, num_replicas=hvd.size(), rank=hvd.rank())
dict_val_sampler = torch.utils.data.distributed.DistributedSampler(
    dict_dataset_val, num_replicas=hvd.size(), rank=hvd.rank())
dict_test_sampler = torch.utils.data.distributed.DistributedSampler(
    dict_dataset_test, num_replicas=hvd.size(), rank=hvd.rank())

# Broadcast parameters from rank 0 to all other processes.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optim, root_rank=0)

# Loaders over the caption-level datasets use the full batch size.
dataloader_train = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    drop_last=True,
    sampler=train_sampler,
    **kwargs
)
dataloader_val = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    sampler=val_sampler,
    **kwargs
)

# Loaders over the image-dictionary datasets use a fifth of the batch size.
dict_dataloader_train = DataLoader(
    dict_dataset_train,
    batch_size=args.batch_size // 5,
    shuffle=False,
    num_workers=args.workers,
    sampler=dict_train_sampler,
    **kwargs
)
dict_dataloader_val = DataLoader(
    dict_dataset_val,
    batch_size=args.batch_size // 5,
    sampler=dict_val_sampler,
    **kwargs
)
dict_dataloader_test = DataLoader(
    dict_dataset_test,
    batch_size=args.batch_size // 5,
    sampler=dict_test_sampler,
    **kwargs
)

# Main training loop: up to 100 epochs. Training starts with cross-entropy
# and, after 5 epochs without a validation CIDEr improvement, switches to
# self-critical (RL) training; a second run of 5 such epochs stops training.
print("Training starts")
for e in range(start_epoch, start_epoch + 100):
    if not use_rl:
        # DistributedSampler derives its shuffling from the epoch number it
        # is given; without this call every epoch reuses the same ordering.
        train_sampler.set_epoch(e)
        train_loss = train_xe(model, dataloader_train, optim, text_field)
        writer.add_scalar('data/train_loss', train_loss, e)
    else:
        dict_train_sampler.set_epoch(e)
        hvd.broadcast_optimizer_state(optim, root_rank=0)

        train_loss, reward, reward_baseline = train_scst(model, dict_dataloader_train, optim, cider_train, text_field)
        writer.add_scalar('data/train_loss', train_loss, e)
        writer.add_scalar('data/reward', reward, e)
        writer.add_scalar('data/reward_baseline', reward_baseline, e)

    # Validation loss
    val_loss = evaluate_loss(model, dataloader_val, loss_fn, text_field)
    writer.add_scalar('data/val_loss', val_loss, e)

    # Validation scores
    scores = evaluate_metrics(model, dict_dataloader_val, text_field)
    print("Validation scores", scores)
    # Each process only scored its own shard of the validation set. Average
    # the CIDEr over all processes so that every rank takes the same
    # best/patience/RL-switch decisions below and the ranks stay in lockstep.
    val_cider = hvd.allreduce(torch.tensor(float(scores['CIDEr']), dtype=torch.float64),
                              name='val_cider').item()
    writer.add_scalar('data/val_cider', val_cider, e)
    writer.add_scalar('data/val_bleu1', scores['BLEU'][0], e)
    writer.add_scalar('data/val_bleu4', scores['BLEU'][3], e)
    writer.add_scalar('data/val_meteor', scores['METEOR'], e)
    writer.add_scalar('data/val_rouge', scores['ROUGE'], e)

    # Test scores
    scores = evaluate_metrics(model, dict_dataloader_test, text_field)
    print("Test scores", scores)
    writer.add_scalar('data/test_cider', scores['CIDEr'], e)
    writer.add_scalar('data/test_bleu1', scores['BLEU'][0], e)
    writer.add_scalar('data/test_bleu4', scores['BLEU'][3], e)
    writer.add_scalar('data/test_meteor', scores['METEOR'], e)
    writer.add_scalar('data/test_rouge', scores['ROUGE'], e)

    # Prepare for next epoch
    best = False
    if val_cider >= best_cider:
        best_cider = val_cider
        patience = 0
        best = True
    else:
        patience += 1

    switch_to_rl = False
    exit_train = False
    if patience == 5:
        if not use_rl:
            use_rl = True
            switch_to_rl = True
            patience = 0
            # Reconfigure the existing optimizer in place instead of creating
            # a plain Adam: the existing object is the one wrapped by Horovod
            # (gradient allreduce, synchronize/skip_synchronize) and
            # registered with Apex AMP, which the training functions rely on.
            # The values are those a fresh ``Adam(lr=5e-6)`` would start with.
            for group in optim.param_groups:
                group['lr'] = 5e-6
                group['betas'] = (0.9, 0.999)
            optim.state.clear()
            print("Switching to RL")
        else:
            print('patience reached.')
            exit_train = True

    if switch_to_rl and not best:
        # Start the RL phase from the best cross-entropy model.
        data = torch.load('saved_models/%s_best.pth' % args.exp_name)
        torch.set_rng_state(data['torch_rng_state'])
        torch.cuda.set_rng_state(data['cuda_rng_state'])
        np.random.set_state(data['numpy_rng_state'])
        random.setstate(data['random_rng_state'])
        model.load_state_dict(data['state_dict'])
        print('Resuming from epoch %d, validation loss %f, and best cider %f' % (
            data['epoch'], data['val_loss'], data['best_cider']))

    # Only one process per node writes checkpoints, so processes sharing a
    # filesystem do not write the same file at the same time.
    if hvd.local_rank() == 0:
        os.makedirs('saved_models', exist_ok=True)
        torch.save({
            'torch_rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state(),
            'numpy_rng_state': np.random.get_state(),
            'random_rng_state': random.getstate(),
            'epoch': e,
            'val_loss': val_loss,
            'val_cider': val_cider,
            'state_dict': model.state_dict(),
            'optimizer': optim.state_dict(),
            'scheduler': scheduler.state_dict(),
            'patience': patience,
            'best_cider': best_cider,
            'use_rl': use_rl,
        }, 'saved_models/%s_last.pth' % args.exp_name)

        if best:
            copyfile('saved_models/%s_last.pth' % args.exp_name, 'saved_models/%s_best.pth' % args.exp_name)

    if exit_train:
        writer.close()
        break