# 1. Build caption vocab

In [None]:
import nltk
import pickle
import os
from collections       import Counter
from pycocotools.coco  import COCO

class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

def build_vocab(json, threshold):
    """Build a simple vocabulary wrapper."""
    coco = COCO(json)
    counter = Counter()
    ids = coco.anns.keys()
    for i, id in enumerate(ids):
        caption = str(coco.anns[id]['caption'])
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        counter.update(tokens)

        if (i+1) % 1000 == 0:
            print("[{}/{}] Tokenized the captions.".format(i+1, len(ids)))

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab.add_word(word)
    return vocab

def save_vocab(kwargs, overwrite=False):
    vocab_path = kwargs.get('vocab_path', None)
    exists = os.path.isfile(vocab_path)
    
    if not exists or overwrite:
        vocab = build_vocab(json=kwargs.get('caption_path', None), threshold=kwargs.get('threshold', None))
        with open(vocab_path, 'wb') as f:
            pickle.dump(vocab, f)
        print("Total vocabulary size: {}".format(len(vocab)))
        print("Saved the vocabulary wrapper to '{}'".format(vocab_path))

In [None]:
kwargs = {}
kwargs['caption_path'] = './data/annotations/captions_val2017.json' # path for train annotation file'
kwargs['vocab_path'] = './data/vocab.pkl' # path for saving vocabulary wrapper
kwargs['threshold'] = 4 # minimum word count threshold'
save_vocab(kwargs)

# 2. Resize image

In [2]:
from PIL import Image

def resize_images(image_dir, output_dir, size):
    """Resize the images in 'image_dir' and save into 'output_dir'."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    images = os.listdir(image_dir)
    num_images = len(images)
    for i, image in enumerate(images):
        img_output_dir = os.path.join(output_dir, image)
        if os.path.exists(img_output_dir):
            continue
            
        with open(os.path.join(image_dir, image), 'r+b') as f:
            with Image.open(f) as img:
                img = img.resize(size, Image.ANTIALIAS)
                img.save(img_output_dir, img.format)
        if (i+1) % 100 == 0:
            print ("[{}/{}] Resized the images and saved into '{}'."
                   .format(i+1, num_images, output_dir))


In [None]:
image_dir = './data/val2017/' # directory for train images'
output_dir = './data/resizedval2017/' # directory for saving resized images'
image_size = (256, 256) 
args = resize_images(image_dir, output_dir, image_size)

# 3. Load precessed data

In [3]:
# From pytorch tutorial
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import numpy as np
from PIL import Image

class CocoDataset(data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self, root, json, vocab, transform=None):
        """Set the path for images, captions and vocabulary wrapper.
        
        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper.
            transform: image transformer.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""
        coco = self.coco
        vocab = self.vocab
        ann_id = self.ids[index]
        caption = coco.anns[ann_id]['caption']
        img_id = coco.anns[ann_id]['image_id']
        path = coco.loadImgs(img_id)[0]['file_name']

        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)
        return image, target

    def __len__(self):
        return len(self.ids)


def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).
    
    We should build custom collate_fn rather than using default collate_fn, 
    because merging caption (including padding) is not supported in default.

    Args:
        data: list of tuple (image, caption). 
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort a data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]        
    return images, targets, lengths

def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Returns torch.utils.data.DataLoader for custom coco dataset."""
    # COCO caption dataset
    coco = CocoDataset(root=root,
                       json=json,
                       vocab=vocab,
                       transform=transform)
    
    # Data loader for COCO dataset
    # This will return (images, captions, lengths) for each iteration.
    # images: a tensor of shape (batch_size, 3, 224, 224).
    # captions: a tensor of shape (batch_size, padded_length).
    # lengths: a list indicating valid length for each caption. length is (batch_size).
    data_loader = torch.utils.data.DataLoader(dataset=coco, 
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader

# 4. Train

## 4.1 Model

In [4]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn import init


class Encoder(nn.Module):
    def __init__(self, embed_size, encoder=models.resnet50):
        super(Encoder, self).__init__()
        if 'resnet' in encoder.__module__:
            resnet = encoder(pretrained=True)
            modules = list(resnet.children())[:-2]
            self.dim = 2048
        elif 'vgg' in encoder.__module__:
            resnet = encoder(pretrained=True)
            modules = list(resnet.children())[:-1]
            self.dim = 512
        else:
            raise NotImplementedError
        self.resnet = nn.Sequential(*modules)

    def forward(self, images):
        """Extract the image feature vectors."""
        features = self.resnet(images) # [batch, self.dim, 7, 7]
        features = features.permute(0, 2, 3, 1)
        features = features.view(features.size(0), -1, features.size(-1)) # [batch, 49, self.dim=2048]
        return features


class Attention(nn.Module):
    def __init__(self, encoder_dim, hidden_size=512):
        super(Attention, self).__init__()

        self.affine_W = nn.Linear(encoder_dim, hidden_size)
        self.affine_U = nn.Linear(hidden_size, hidden_size)
        
        self.affine_V = nn.Linear(hidden_size, 1)
        
    def init_weights(self):
        init.xavier_uniform(self.affine_W.weight )
        init.xavier_uniform(self.affine_U.weight )
        init.xavier_uniform(self.affine_V.weight )

    def forward(self, a, prev_hidden_state):
        att = torch.tanh(self.affine_W(a) + self.affine_U(prev_hidden_state).unsqueeze(1)) # [batch, 49, 1]
        e_t = self.affine_V(att).squeeze(2)

        alpha_t = nn.Softmax(1)(e_t) # [batch, 49]
        context_t = (a * alpha_t.unsqueeze(2)).sum(1) # [batch, 2048]

        return context_t, alpha_t


class Decoder(nn.Module):
    def __init__(self, encoder_dim, vocab_size, hidden_size=512):
        super(Decoder, self).__init__()

        self.vocab_size = vocab_size
        self.encoder_dim = encoder_dim

        self.init_affine_h = nn.Linear(encoder_dim, hidden_size)
        self.init_affine_c = nn.Linear(encoder_dim, hidden_size)

        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.attention = Attention(encoder_dim, hidden_size=hidden_size)
        self.f_beta = nn.Linear(hidden_size, encoder_dim)
        self.lstm = nn.LSTMCell(hidden_size+encoder_dim, hidden_size)
        self.output_W = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        batch_size = features.size(0)

        features_avg = features.mean(dim=1)
        h = torch.tanh(self.init_affine_h(features_avg))
        c = torch.tanh(self.init_affine_c(features_avg))

        T = max([len(caption) for caption in captions])
        prev_word = torch.zeros(batch_size, 1).long()
        pred_words = torch.zeros(batch_size, T, self.vocab_size) # [128, 26, 2699]
        alphas = torch.zeros(batch_size, T, features.size(1))

        embedding = self.embedding(prev_word)

        for t in range(T):
            context_t, alpha_t = self.attention.forward(features, h)
            gate = torch.sigmoid(self.f_beta(h))
            gated_context = gate * context_t
            
            # dim = 3 is for mini batch
            embedding = embedding.squeeze(1) if embedding.dim() == 3 else embedding
            lstm_input = torch.cat((embedding, gated_context), dim=1)

            h, c = self.lstm(lstm_input, (h, c))
            output = torch.sigmoid(self.output_W(h))

            pred_words[:, t] = output
            alphas[:, t] = alpha_t
            if not self.training:
                embedding = self.embedding(output.max(1)[1].reshape(batch_size, 1))
            else:
                prev_word = captions[:,t]
                embedding = self.embedding(prev_word)
        return pred_words, alphas

    def caption(self, features, beam_size):
        '''
        From https://github.com/kelvinxu/arctic-captions/blob/master/generate_caps.py        
        '''
        
        prev_words = torch.zeros(beam_size, 1).long()

        sentences = prev_words
        top_preds = torch.zeros(beam_size, 1)
        alphas = torch.ones(beam_size, 1, features.size(1))

        completed_sentences = []
        completed_sentences_alphas = []
        completed_sentences_preds = []

        step = 1
        h = torch.tanh(self.init_affine_h(features_avg))
        c = torch.tanh(self.init_affine_c(features_avg))

        while True:
            embedding = self.embedding(prev_words).squeeze(1)
            context, alpha = self.attention(img_features, h)
            gate = torch.sigmoid(self.f_beta(h))
            gated_context = gate * context

            lstm_input = torch.cat((embedding, gated_context), dim=1)
            h, c = self.lstm(lstm_input, (h, c))
            output = top_preds.expand_as(output) + output

            if step == 1:
                top_preds, top_words = output[0].topk(beam_size, 0, True, True)
            else:
                top_preds, top_words = output.view(-1).topk(beam_size, 0, True, True)
            prev_word_idxs = top_words / output.size(1)
            next_word_idxs = top_words % output.size(1)

            sentences = torch.cat((sentences[prev_word_idxs], next_word_idxs.unsqueeze(1)), dim=1)
            alphas = torch.cat((alphas[prev_word_idxs], alpha[prev_word_idxs].unsqueeze(1)), dim=1)

            incomplete = [idx for idx, next_word in enumerate(next_word_idxs) if next_word != 1]
            complete = list(set(range(len(next_word_idxs))) - set(incomplete))

            if len(complete) > 0:
                completed_sentences.extend(sentences[complete].tolist())
                completed_sentences_alphas.extend(alphas[complete].tolist())
                completed_sentences_preds.extend(top_preds[complete])
            beam_size -= len(complete)

            if beam_size == 0:
                break
            sentences = sentences[incomplete]
            alphas = alphas[incomplete]
            h = h[prev_word_idxs[incomplete]]
            c = c[prev_word_idxs[incomplete]]
            features = features[prev_word_idxs[incomplete]]
            top_preds = top_preds[incomplete].unsqueeze(1)
            prev_words = next_word_idxs[incomplete].unsqueeze(1)

            if step > 50:
                break
            step += 1

        idx = completed_sentences_preds.index(max(completed_sentences_preds))
        sentence = completed_sentences[idx]
        alpha = completed_sentences_alphas[idx]
        return sentence, alpha
    

## 4.2 Measures

In [5]:
import torch

class AverageMeter(object):
    '''From https://github.com/pytorch/examples/blob/master/imagenet/main.py'''
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(preds, targets, k):
    batch_size = targets.size(0)
    _, pred = preds.topk(k, 1, True, True)
    pred = pred.permute(1, 0, 2)
    correct = pred.eq(targets.expand_as(pred))
    correct_total = correct.view(-1).float().sum()
    print(correct_total, correct.numel())
    return correct_total.item() / float(correct.numel() / k)

def calculate_caption_lengths(word_dict, captions):
    lengths = 0
    for caption_tokens in captions:
        for token in caption_tokens:
            if token in (word_dict['<start>'], word_dict['<end>'], word_dict['<pad>']):
                continue
            else:
                lengths += 1
    return lengths

## 4.3 Train

In [6]:
from torchvision import transforms

# Device configuration
device = torch.device('cpu')

def train(model_path, crop_size, vocab_path, image_dir, caption_path, log_step=10,
      save_step=1000, embed_size=256, hidden_size=512, num_epochs=5,
      batch_size=128, num_workers=2, learning_rate=0.001, alpha_c=1):
    # Create model directory
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # used for calculating bleu scores
    references = []
    hypotheses = []
    
    
    # Load vocabulary wrapper
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(image_dir, caption_path, vocab,
                             transform, batch_size,
                             shuffle=True, num_workers=num_workers)

    # Build the models
    encoder = Encoder(models.vgg19).to(device)
    decoder = Decoder(encoder.dim, len(vocab), hidden_size=hidden_size).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters())
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    # Train the models
    total_step = len(data_loader)
    for epoch in range(num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):

            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)

            optimizer.zero_grad()

            # Forward, backward and optimize
            features = encoder.forward(images)
            prediction, alphas = decoder.forward(features, captions)

            att_regularization = alpha_c * ((1 - alphas.sum(1))**2).mean()
            loss = criterion(prediction.permute(0,2,1), captions) + att_regularization
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_caption_length = calculate_caption_lengths(vocab.word2idx, captions)
            acc1 = accuracy(prediction.permute(0,2,1), captions, 1)
            acc5 = accuracy(prediction.permute(0,2,1), captions, 5)
            losses.update(loss.item(), total_caption_length)
            top1.update(acc1, total_caption_length)
            top5.update(acc5, total_caption_length)

            # Print log info
            if i % log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                      .format(epoch, num_epochs, i, total_step, loss.item(), np.exp(loss.item())))
                print('Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f}), Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                        top1=top1, top5=top5))
            # Save the model checkpoints
            if (i + 1) % save_step == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    model_path, 'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
                torch.save(encoder.state_dict(), os.path.join(
                    model_path, 'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))

In [None]:
model_path = './models/'  # path for saving trained models
crop_size = 224  # size for randomly cropping images
vocab_path = './data/vocab.pkl'  # path for vocabulary wrapper
image_dir = './data/resizedval2017'
caption_path = './data/annotations/captions_val2017.json'
train(model_path, crop_size, vocab_path, image_dir, caption_path, log_step=10,
      save_step=1000, embed_size=256, hidden_size=512, num_epochs=5,
      batch_size=128, num_workers=2, learning_rate=0.001)

loading annotations into memory...
Done (t=0.06s)
creating index...
index created!
tensor(0.) 2560
tensor(3.) 12800
Epoch [0/5], Step [0/196], Loss: 8.2578, Perplexity: 3857.5028
Top 1 Accuracy 0.000 (0.000), Top 5 Accuracy 0.001 (0.001)
tensor(2914.) 4480
tensor(3191.) 22400
tensor(1756.) 3328
tensor(2098.) 16640


In [None]:
def validate(model_path, crop_size, vocab_path, image_dir, caption_path, log_step=10):
    raise NotImplementError