In [1]:
import json
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import skimage.transform
from scipy.misc import imread, imresize
from PIL import Image

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

In [2]:
# Pick the compute device once for the whole notebook: GPU when PyTorch can see
# one, CPU otherwise.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

In [3]:
def _load_image_tensor(image_path):
    """Load an image file as a normalized (1, 3, 256, 256) float tensor on `device`.

    Uses PIL rather than `scipy.misc.imread`/`imresize`, which were removed from
    SciPy. Converting to RGB also covers grayscale, palette and RGBA images, so the
    result always has exactly three channels.
    """
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    # bilinear matches the default interpolation of the old scipy `imresize`
    image = image.resize((256, 256), Image.BILINEAR)

    # ToTensor requires a PIL image / HWC ndarray: it scales pixels to [0, 1] and
    # moves channels first, so no manual transpose or division by 255 is needed.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                         std=(0.229, 0.224, 0.225))])
    return transform(image).unsqueeze(0).to(device)  # (1, 3, 256, 256)


@torch.no_grad()  # pure inference: no autograd graph needs to be recorded
def caption_image(encoder, decoder, image_path, word_vocab_path, beam_size=3, max_steps=50):
    """Generate a caption for a single image using beam search.

    Args:
        encoder: image encoder. Its output is read as
            (1, encoder_image_size, encoder_image_size, encoder_dim); sizes are
            taken from dims 1 and 3.
        decoder: attention decoder exposing `init_state`, `embedding_layer`,
            `attention_layer`, `f_beta`, `sigmoid`, `decode_step` and `fc_layer`.
        image_path: path to the image file to caption.
        word_vocab_path: path to a JSON file mapping word -> index. It must
            contain the '<start>' and '<end>' tokens.
        beam_size: number of beams (k) kept at each decoding step; must be >= 1.
        max_steps: decoding stops once `step > max_steps`, i.e. after at most
            max_steps + 1 steps, even if some beams never produced '<end>'.

    Returns:
        (sequence, alphas, index2word):
            sequence: list of word indices of the best-scoring beam. It includes
                '<start>', and '<end>' when reached.
            alphas: nested list of that beam's attention maps, one
                (encoder_image_size x encoder_image_size) map per token. The
                first map is an all-ones placeholder for '<start>'.
            index2word: dict mapping index -> word.

    Raises:
        ValueError: if beam_size < 1.
    """
    if beam_size < 1:
        raise ValueError('beam_size must be at least 1, got %r' % (beam_size,))

    # load word vocab (word -> index) and build the inverse index -> word mapping
    with open(word_vocab_path, 'r') as file:
        word_vocab = json.load(file)
    index2word = {index: word for word, index in word_vocab.items()}
    start_index = word_vocab['<start>']
    end_index = word_vocab['<end>']

    k = beam_size
    vocab_size = len(word_vocab)

    image = _load_image_tensor(image_path)  # (1, 3, 256, 256)

    # encoding process
    features = encoder(image)
    encoder_image_size = features.size(1)
    encoder_dim = features.size(3)

    # flatten features / encoding outputs
    encoder_outputs = features.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
    num_pixels = encoder_outputs.size(1)

    # treat the problem as having a batch size of k
    encoder_outputs = encoder_outputs.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

    # top k previous words at each step (starting from just <start>)
    k_prev_words = torch.LongTensor([[start_index]] * k).to(device)  # (k, 1)
    # top k sequences (starting from just <start>)
    sequences = k_prev_words  # (k, 1)
    # cumulative log-probability of each beam (starting from 0)
    top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)
    # attention maps of each beam (all-ones placeholder for <start>)
    sequences_alpha = torch.ones(k, 1, encoder_image_size, encoder_image_size).to(device)

    # finished beams: token lists, attention maps and float scores
    complete_sequences = []
    complete_sequences_alpha = []
    complete_sequences_scores = []

    step = 1
    hidden, cell = decoder.init_state(encoder_outputs)

    # s (the number of live beams) is <= k: a beam leaves once it emits <end>
    while True:
        embeddings = decoder.embedding_layer(k_prev_words).squeeze(1)  # (s, embedding_dim)
        attention_weighted_encoding, alpha = decoder.attention_layer(encoder_outputs, hidden)  # (s, encoder_dim), (s, num_pixels)
        alpha = alpha.view(-1, encoder_image_size, encoder_image_size)  # (s, enc, enc)

        gate = decoder.sigmoid(decoder.f_beta(hidden))  # gating scalar, (s, encoder_dim)
        attention_weighted_encoding = gate * attention_weighted_encoding

        hidden, cell = decoder.decode_step(
            torch.cat([embeddings, attention_weighted_encoding], dim=1), (hidden, cell))  # (s, decoder_dim)

        scores = F.log_softmax(decoder.fc_layer(hidden), dim=1)  # (s, vocab_size)
        # accumulate log-probabilities along each beam
        scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

        if step == 1:
            # all beams are identical at the first step (same word, hidden, cell)
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
        else:
            # unroll and find top scores together with their flat indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

        # Map flat indices back to (beam, word). This must be integer division:
        # `/` is true division in current PyTorch and yields float tensors, which
        # cannot be used as indices.
        prev_word_indices = top_k_words // vocab_size  # (s)
        next_word_indices = top_k_words % vocab_size  # (s)

        # extend the surviving beams with their new word and attention map
        sequences = torch.cat([sequences[prev_word_indices], next_word_indices.unsqueeze(1)], dim=1)  # (s, step+1)
        sequences_alpha = torch.cat([sequences_alpha[prev_word_indices], alpha[prev_word_indices].unsqueeze(1)],
                                    dim=1)  # (s, step+1, enc, enc)

        next_words = next_word_indices.tolist()
        incomplete_indices = [i for i, word in enumerate(next_words) if word != end_index]
        complete_indices = [i for i, word in enumerate(next_words) if word == end_index]

        # set aside complete sequences; scores are stored as Python floats
        if complete_indices:
            complete_sequences.extend(sequences[complete_indices].tolist())
            complete_sequences_alpha.extend(sequences_alpha[complete_indices].tolist())
            complete_sequences_scores.extend(top_k_scores[complete_indices].tolist())
        k -= len(complete_indices)  # shrink the beam accordingly

        if k == 0:
            break

        # continue with the incomplete beams only
        kept_beams = prev_word_indices[incomplete_indices]
        sequences = sequences[incomplete_indices]
        sequences_alpha = sequences_alpha[incomplete_indices]
        hidden = hidden[kept_beams]
        cell = cell[kept_beams]
        encoder_outputs = encoder_outputs[kept_beams]
        top_k_scores = top_k_scores[incomplete_indices].unsqueeze(1)
        k_prev_words = next_word_indices[incomplete_indices].unsqueeze(1)

        # stop if decoding has been going on too long
        if step > max_steps:
            break
        step += 1

    if not complete_sequences:
        # No beam reached <end> within the step limit. Fall back to the live beams
        # instead of crashing on max([]).
        complete_sequences.extend(sequences.tolist())
        complete_sequences_alpha.extend(sequences_alpha.tolist())
        complete_sequences_scores.extend(top_k_scores.squeeze(1).tolist())

    # index of the first beam with the highest cumulative log-probability
    best = max(range(len(complete_sequences_scores)), key=complete_sequences_scores.__getitem__)
    return complete_sequences[best], complete_sequences_alpha[best], index2word

In [4]:
def visualize_image(image_path, sequence, alphas, index2word, smooth=True, upscale=24):
    """Show every generated word over the image, with its attention map overlaid.

    Args:
        image_path: path to the image file that was captioned.
        sequence: list of word indices, e.g. as returned by `caption_image`.
        alphas: per-word attention maps (nested list, ndarray or tensor) of shape
            (len(sequence), grid, grid).
        index2word: mapping from word index to word string.
        smooth: if True, upsample attention maps with a Gaussian pyramid
            (`pyramid_expand`); otherwise with plain `resize`.
        upscale: factor from the attention grid to displayed pixels. The image
            is resized to (grid * upscale) pixels square. The default of 24 keeps
            the previous 14 * 24 = 336 px display for a 14x14 grid.
    """
    # convert once (previously this was redone on every loop iteration)
    alphas = torch.as_tensor(alphas, dtype=torch.float32).cpu()
    display_size = alphas.size(-1) * upscale

    image = Image.open(image_path)
    image = image.resize([display_size, display_size], Image.LANCZOS)

    # at most 51 words are drawn, as before (the loop used to break at t > 50);
    # sizing the grid from the truncated list avoids empty subplot rows
    words = [index2word[index] for index in sequence][:51]
    num_rows = int(np.ceil(len(words) / 5.))  # plt.subplot requires an integer row count

    for t, word in enumerate(words):
        plt.subplot(num_rows, 5, t + 1)
        plt.text(0, 1, '%s' % (word), color='black', backgroundcolor='white', fontsize=12)
        plt.imshow(image)

        current_alpha = alphas[t, :].numpy()
        if smooth:
            alpha = skimage.transform.pyramid_expand(current_alpha, upscale=upscale, sigma=8)
        else:
            alpha = skimage.transform.resize(current_alpha, [display_size, display_size])

        # The first token's map is drawn fully transparent. With caption_image's
        # output it is the all-ones placeholder for <start>. Passing cmap here
        # avoids mutating the global default colormap via plt.set_cmap.
        plt.imshow(alpha, alpha=0 if t == 0 else 0.8, cmap=cm.Greys_r)
        plt.axis('off')

    plt.show()

In [5]:
def load_model(model_path, device):
    """Load the encoder/decoder pair from a checkpoint and switch both to eval mode.

    The checkpoint is read as a dict holding whole pickled modules under the keys
    'encoder' and 'decoder'. Tensors are mapped onto `device` while loading, so a
    checkpoint saved on a GPU can be opened on a CPU-only machine.

    Security: this unpickles arbitrary Python objects; only load trusted checkpoints.

    Args:
        model_path: path to the checkpoint file.
        device: torch.device the networks are moved to.

    Returns:
        (encoder, decoder), both on `device` and in eval mode.
    """
    import inspect

    load_kwargs = {'map_location': device}
    # The checkpoint stores full nn.Module objects. PyTorch >= 2.6 refuses these
    # under its default `weights_only=True`, and older releases lack the argument.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(model_path, **load_kwargs)

    encoder = checkpoint['encoder'].to(device)
    decoder = checkpoint['decoder'].to(device)

    # inference behaviour for layers such as dropout / batch norm
    encoder.eval()
    decoder.eval()

    return encoder, decoder

## Caption The Image!

In [6]:
# NOTE(review): all three paths below end in '/', so they look like directories,
# yet the rest of the notebook uses each of them as a single file (a checkpoint,
# an image, and a JSON word vocab opened with open() + json.load).
# TODO: point each constant at the actual file.
MODEL_PATH = './weights/'  # checkpoint holding 'encoder' and 'decoder' entries
IMAGE_PATH = './images/'  # image to caption and visualize
WORD_VOCAB_PATH = './datasets/data/'  # JSON word -> index map; must contain '<start>' and '<end>'
BEAM_SIZE = 3  # number of beams kept during caption decoding

In [None]:
# Load the trained networks (load_model; there is no load_image) and caption the
# image with beam search using the configured BEAM_SIZE.
encoder, decoder = load_model(MODEL_PATH, device)
sequence, alphas, index2word = caption_image(encoder, decoder, IMAGE_PATH, WORD_VOCAB_PATH, BEAM_SIZE)

In [None]:
# Overlay each generated word's attention map on the captioned image (the image
# path, not the vocab path, is what visualize_image opens).
visualize_image(IMAGE_PATH, sequence, alphas, index2word, smooth=True)

---