In [None]:
import os
import numpy as np
import h5py
import json
import torch
import torch.nn as nn
from torch.nn import Parameter
from scipy.misc import imread, imresize
from tqdm import tqdm
from collections import Counter
from random import seed, choice, sample
from torch.utils.data import Dataset
import torchvision
import time
import torch.backends.cudnn as cudnn
import torch.optim
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
from cococaptioncider.pycocotools.coco import COCO
from cococaptioncider.pycocoevalcap.eval import COCOEvalCap
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.cm as cm
import skimage.transform
from PIL import Image

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class Encoder(nn.Module):
    """ResNet-101 backbone that returns per-location and global image features.

    The last two children of the torchvision ResNet-101 (its average pooling
    and fully connected classifier) are discarded. A separate 7x7 average
    pooling is applied to the remaining feature map to build the global vector.
    """

    def __init__(self, hidden_size, embed_size):
        # NOTE: hidden_size and embed_size are accepted but not used by this class.
        super(Encoder, self).__init__()
        backbone = torchvision.models.resnet101(pretrained = True)
        # Drop the final average pooling layer and the FC classification layer.
        trunk = list(backbone.children())[:-2]
        self.resnet = nn.Sequential(*trunk)
        self.avgpool = nn.AvgPool2d(7)
        # Frozen by default; call self.fine_tune(status = True) to unfreeze part of the CNN.
        self.fine_tune()

    def forward(self, images):
        """
        Encode a batch of images.
        input: resized images of shape (batch_size,3,224,224)
        returns: (enc_image, global_features) where enc_image has shape
                 (batch_size,num_pixels,channels) and global_features has shape
                 (batch_size,channels); for 224x224 inputs that is
                 (batch_size,49,2048) and (batch_size,2048)
        """
        feature_map = self.resnet(images)                       # (batch_size,2048,7,7)
        batch_size, features, height, width = feature_map.shape
        # Global descriptor: average over the spatial grid, then flatten
        global_features = self.avgpool(feature_map).view(batch_size, -1)   # (batch_size, 2048)
        # Channels last, then merge the two spatial axes into one "pixel" axis
        channels_last = feature_map.permute(0, 2, 3, 1)         # (batch_size,7,7,2048)
        enc_image = channels_last.view(batch_size, height * width, features)   # (batch_size,num_pixels,2048)
        return enc_image, global_features

    def fine_tune(self, status = False):
        """
        Toggle gradient computation for the backbone.
        status = False: every backbone parameter is frozen.
        status = True: parameters of the backbone children from index 7 onwards
                       are made trainable; the other children are left untouched.
        """
        if status:
            trainable_modules = list(self.resnet.children())[7:]   # last child only; the trunk has 8 children
            for module in trainable_modules:
                for param in module.parameters():
                    param.requires_grad = True
            return

        for param in self.resnet.parameters():
            param.requires_grad = False

In [None]:
class AdaptiveLSTMCell(nn.Module):
    """LSTM cell that also emits a sentinel vector alongside its new state.

    The sentinel is sigmoid(x_gate(input) + h_gate(previous hidden)) multiplied
    element-wise with tanh of the new cell state.
    """

    def __init__(self, input_size, hidden_size):
        super(AdaptiveLSTMCell, self).__init__()
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)
        self.x_gate = nn.Linear(input_size, hidden_size)
        self.h_gate = nn.Linear(hidden_size, hidden_size)

    def forward(self, inp, states):
        """
        inp: input at the current step, of shape (batch_size, input_size)
        states: tuple (h, c) holding the previous hidden and cell states
        returns: (ht, ct, st) - new hidden state, new cell state and sentinel
        """
        prev_hidden, prev_cell = states
        ht, ct = self.lstm_cell(inp, (prev_hidden, prev_cell))
        # The gate is computed from the current input and the *previous* hidden state
        gate = torch.sigmoid(self.x_gate(inp) + self.h_gate(prev_hidden))
        st = gate * torch.tanh(ct)
        return ht, ct, st

In [None]:
class AdaptiveAttention(nn.Module):
    """Soft attention over the image locations plus one extra sentinel slot."""

    def __init__(self, hidden_size, att_dim):
        super(AdaptiveAttention,self).__init__()
        self.sen_affine = nn.Linear(hidden_size, hidden_size)
        self.sen_att = nn.Linear(hidden_size, att_dim)
        self.h_affine = nn.Linear(hidden_size, hidden_size)
        self.h_att = nn.Linear(hidden_size, att_dim)
        self.v_att = nn.Linear(hidden_size, att_dim)
        self.alphas = nn.Linear(att_dim, 1)
        self.context_hidden = nn.Linear(hidden_size, hidden_size)

    def forward(self, spatial_image, decoder_out, st):
        """
        spatial_image: projected image features of shape (batch_size,num_pixels,hidden_size)
        decoder_out: the decoder hidden state of shape (batch_size, hidden_size)
        st: sentinel vector produced by the LSTM cell, of shape (batch_size, hidden_size)
        returns: (out_l, att_weights, beta_value) with shapes (batch_size, hidden_size),
                 (batch_size, num_pixels+1) and (batch_size, 1); beta_value is the
                 weight assigned to the sentinel (the last attention slot)
        """
        # Project the sentinel and the hidden state
        sentinel_affine = torch.relu(self.sen_affine(st))         # (batch_size,hidden_size)
        hidden_affine = torch.tanh(self.h_affine(decoder_out))    # (batch_size,hidden_size)

        # Candidates to attend over: every image location followed by the sentinel
        sentinel_slot = sentinel_affine.unsqueeze(1)              # (batch_size,1,hidden_size)
        concat_features = torch.cat([spatial_image, sentinel_slot], dim = 1)   # (batch_size, num_pixels+1, hidden_size)

        # Attention-space embeddings of the same candidates
        attended_features = torch.cat(
            [self.v_att(spatial_image), self.sen_att(sentinel_affine).unsqueeze(1)],
            dim = 1,
        )                                                         # (batch_size, num_pixels+1, att_dim)

        # The hidden-state term is broadcast across all num_pixels+1 candidates
        hidden_attn = self.h_att(hidden_affine).unsqueeze(1)      # (batch_size,1,att_dim)
        attention = torch.tanh(attended_features + hidden_attn)   # (batch_size, num_pixels+1, att_dim)

        alpha = self.alphas(attention).squeeze(2)                 # (batch_size, num_pixels+1)
        att_weights = torch.softmax(alpha, dim=1)                 # (batch_size, num_pixels+1)

        # Weighted sum of the candidates; the sentinel weight is reported separately
        context = (concat_features * att_weights.unsqueeze(2)).sum(dim=1)   # (batch_size, hidden_size)
        beta_value = att_weights[:,-1].unsqueeze(1)               # (batch_size, 1)

        out_l = torch.tanh(self.context_hidden(context + hidden_affine))

        return out_l, att_weights, beta_value

In [None]:
class DecoderWithAttention(nn.Module):
    """Caption decoder: an adaptive LSTM cell followed by adaptive attention.

    At every step the LSTM input is the word embedding concatenated with the
    projected global image feature, hence an input size of embed_size * 2.
    """

    def __init__(self,hidden_size, vocab_size, att_dim, embed_size, encoded_dim):
        super(DecoderWithAttention,self).__init__()
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.encoded_to_hidden = nn.Linear(encoded_dim, hidden_size)
        self.global_features = nn.Linear(encoded_dim, embed_size)
        # input to the LSTMCell should be of shape (batch, input_size). We concatenate the word with
        # the global image features, therefore the input features are embed_size * 2
        self.LSTM = AdaptiveLSTMCell(embed_size * 2,hidden_size)
        self.adaptive_attention = AdaptiveAttention(hidden_size, att_dim)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.vocab_size = vocab_size
        self.dropout = nn.Dropout(p=0.5)
        self.init_weights()

    def init_weights(self):
        """Uniform(-0.1, 0.1) init for the output layer and embedding weights; zero output bias."""
        self.fc.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.embedding.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, enc_image):
        """
        Create zero-filled initial LSTM states.
        enc_image: tensor whose first dimension is the batch size; only its batch size,
                   device and dtype are used
        returns: (h, c), each of shape (batch_size, hidden_size)
        """
        # The hidden size is read from the projection layer instead of being hard-coded
        # (it used to be fixed at 512). Reading it from a submodule also works for
        # decoder objects restored from a pickled checkpoint.
        hidden_size = self.encoded_to_hidden.out_features
        batch_size = enc_image.shape[0]
        h = enc_image.new_zeros(batch_size, hidden_size)
        c = enc_image.new_zeros(batch_size, hidden_size)
        return h, c

    def forward(self, enc_image, global_features, encoded_captions, caption_lengths):

        """
        enc_image: the encoded images from the encoder, of shape (batch_size, num_pixels, 2048)
        global_features: the global image features returned by the Encoder, of shape: (batch_size, 2048)
        encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
        caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
        returns: (predictions, alphas, betas, encoded_captions, decode_lengths, sort_ind) where
                 predictions is (batch_size, max(decode_lengths), vocab_size),
                 alphas is (batch_size, max(decode_lengths), num_pixels+1),
                 betas is (batch_size, max(decode_lengths), 1),
                 encoded_captions are the captions re-ordered by decreasing length,
                 decode_lengths is a list of (caption length - 1) in that order and
                 sort_ind is the permutation that was applied to the batch
        """
        spatial_image = F.relu(self.encoded_to_hidden(enc_image))  # (batch_size,num_pixels,hidden_size)
        global_image = F.relu(self.global_features(global_features))      # (batch_size,embed_size)
        batch_size = spatial_image.shape[0]
        num_pixels = spatial_image.shape[1]
        # Sort input data by decreasing lengths
        # caption_lengths will contain the sorted lengths, and sort_ind contains the sorted elements indices
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        # sort_ind holds batch indices. For example, if sort_ind is [3,2,0], the descending order
        # starts with sample 3, then sample 2, and finally sample 0.
        spatial_image = spatial_image[sort_ind]           # (batch_size,num_pixels,hidden_size) with sorted batches
        global_image = global_image[sort_ind]             # (batch_size, embed_size) with sorted batches
        encoded_captions = encoded_captions[sort_ind]     # (batch_size, max_caption_length) with sorted batches

        # Embedding. Every caption has the same number of rows (words), since captions shorter than
        # max_caption_length were padded beforehand
        embeddings = self.embedding(encoded_captions)     # (batch_size, max_caption_length, embed_dim)

        # Initialize the LSTM state. Only the batch size/device/dtype of enc_image are used,
        # so enc_image does not need to be re-ordered (which would copy the whole tensor).
        h,c = self.init_hidden_state(enc_image)          # (batch_size, hidden_size)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        decode_lengths = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        # Create tensors to hold word prediction scores, alphas and betas. They are allocated next to
        # the activations so the module does not depend on a notebook-level `device` variable.
        predictions = spatial_image.new_zeros(batch_size, max_decode_length, self.vocab_size)
        alphas = spatial_image.new_zeros(batch_size, max_decode_length, num_pixels+1)
        betas = spatial_image.new_zeros(batch_size, max_decode_length, 1)

        # Concatenate the embeddings and global image features for input to LSTM
        global_image = global_image.unsqueeze(1).expand_as(embeddings)
        inputs = torch.cat((embeddings,global_image), dim = 2)    # (batch_size, max_caption_length, embed_dim * 2)

        #Start decoding
        for timestep in range(max_decode_length):
            # Manual "packed sequence": because the batch is sorted by decreasing length, the samples that
            # still need decoding at this timestep are exactly the first batch_size_t ones. pack_padded_sequence
            # cannot be used here because we step an LSTMCell by hand instead of running an LSTM.
            batch_size_t = sum(l > timestep for l in decode_lengths)
            current_input = inputs[:batch_size_t, timestep, :]             # (batch_size_t, embed_dim * 2)
            h, c, st = self.LSTM(current_input, (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, hidden_size)
            # Run the adaptive attention model
            out_l, alpha_t, beta_t = self.adaptive_attention(spatial_image[:batch_size_t],h,st)
            # Compute the scores over the vocabulary
            pt = self.fc(self.dropout(out_l))                  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, timestep, :] = pt
            alphas[:batch_size_t, timestep, :] = alpha_t
            betas[:batch_size_t, timestep, :] = beta_t
        return predictions, alphas, betas, encoded_captions, decode_lengths, sort_ind

In [None]:
cudnn.benchmark = True

# Load the trained encoder/decoder. The checkpoint stores whole pickled modules, so the model
# classes defined above must already exist in this session.
# NOTE(security): torch.load unpickles arbitrary Python objects - only load checkpoints you trust.
checkpoint_path = 'BEST_checkpoint_12.pth.tar'
# map_location remaps the stored tensors onto the selected device, so a checkpoint saved on a
# GPU can also be loaded on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval()

# Vocabulary: word -> index, and the inverse mapping used to turn predictions back into words
with open('caption data/WORDMAP_coco.json', 'r') as j:
    word_map = json.load(j)

rev_word_map = {v: k for k, v in word_map.items()}  # idx2word

In [None]:
def predict_output(image, max_len=20):
    """
    Predict output with beam size of 1 (predict the word and feed it to the next LSTM step).
    Prints out the generated sentence, followed by the beta value of every generated word.

    Uses the notebook-level `encoder`, `decoder`, `word_map` and `device`.

    :param image: path to the image to caption
    :param max_len: maximum number of words to generate; decoding stops earlier as soon as
                    <end> is predicted
    """
    end_token = word_map['<end>']
    rev_word_map = {v: k for k, v in word_map.items()}  # idx2word

    # Read image and process
    img = imread(image)
    if len(img.shape) == 2:
        # Grayscale image: replicate the single channel so the CNN receives 3 channels
        img = img[:, :, np.newaxis]
        img = np.concatenate([img, img, img], axis=2)
    img = imresize(img, (224, 224))
    img = img.transpose(2, 0, 1)
    img = img / 255.
    img = torch.FloatTensor(img).to(device)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
    image_tensor = normalize(img).unsqueeze(0)  # (1, 3, 224, 224)

    sampled = []
    betas_stored = torch.zeros(max_len,1)

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        # Encode
        enc_image, global_features = encoder(image_tensor)
        spatial_image = F.relu(decoder.encoded_to_hidden(enc_image))  # (1,num_pixels,hidden_size)
        global_image = F.relu(decoder.global_features(global_features))      # (1,embed_size)
        pred = torch.LongTensor([word_map['<start>']]).to(device)   # (1,)
        h,c = decoder.init_hidden_state(enc_image)                    #  (1,hidden_size)

        for timestep in range(max_len):
            embeddings = decoder.embedding(pred)                      # (1,embed_dim)
            inputs = torch.cat((embeddings,global_image), dim = 1)    # (1, embed_dim * 2)
            h, c, st = decoder.LSTM(inputs, (h, c))  # (1, hidden_size)
            # Run the adaptive attention model
            out, alpha, beta = decoder.adaptive_attention(spatial_image, h, st)
            # Pick the highest-scoring word and feed it back at the next step
            pt = decoder.fc(out)
            _,pred = pt.max(1)                                        # (1,)
            token = pred.item()
            if token == end_token:
                # The caption is finished; anything generated after <end> would be noise
                break
            sampled.append(token)
            betas_stored[timestep] = beta.item()

    generated_words = [rev_word_map[token] for token in sampled]
    print(' '.join(generated_words))
    # Only the rows that belong to generated words are meaningful
    print(betas_stored[:len(sampled)])

In [None]:
# Caption a sample image with predict_output; it prints the generated words and the stored betas.
predict_output('test_imgs/test1.jpg')

In [None]:
# Implementation with Beam Search
def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):
    """
    Caption an image with beam search.

    Uses the notebook-level `device` for the tensors it creates.

    :param encoder: encoder model, called as encoder(image) -> (enc_image, global_features)
    :param decoder: decoder model providing embedding, LSTM, adaptive_attention, fc,
                    encoded_to_hidden, global_features and init_hidden_state
    :param image_path: path to the image to caption
    :param word_map: word -> index mapping; must contain '<start>' and '<end>'
    :param beam_size: number of sequences to consider at each decode step
    :return: (seq, alphas, betas) where seq is a list of word indices starting with <start>,
             alphas is a nested list with one square attention map per entry of seq and
             betas is a tensor with one row per entry of seq. The best completed sequence is
             returned; if no sequence reached <end> within the step limit, the first live
             beam truncated to 20 entries is returned instead, together with its own
             alphas and betas.
    """
    k = beam_size
    vocab_size = len(word_map)
    end_token = word_map['<end>']
    max_steps = 50        # give up once decoding has gone on for this many steps
    fallback_len = 20     # truncation used when no beam ever produced <end>

    # Read image and process
    img = imread(image_path)
    if len(img.shape) == 2:
        # Grayscale image: replicate the single channel so the CNN receives 3 channels
        img = img[:, :, np.newaxis]
        img = np.concatenate([img, img, img], axis=2)
    img = img[:, :, :3]   # drop an alpha channel if the image has one
    img = imresize(img, (224, 224))
    img = img.transpose(2, 0, 1)
    img = img / 255.
    img = torch.FloatTensor(img).to(device)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
    image = normalize(img).unsqueeze(0)  # (1, 3, 224, 224)

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        # Encode
        enc_image, global_features = encoder(image)  # enc_image of shape (1,num_pixels,features)
        num_pixels = enc_image.size(1)
        encoder_dim = enc_image.size(2)
        # Side of the (square) attention grid, e.g. 7 when num_pixels is 49
        enc_image_size = int(round(num_pixels ** 0.5))

        # We'll treat the problem as having a batch size of k
        enc_image = enc_image.expand(k, num_pixels, encoder_dim)
        # Tensor to store top k previous words at each step; now they're just <start>
        k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)
        # Tensor to store top k sequences; now they're just <start>
        seqs = k_prev_words  # (k, 1)
        # Tensor to store top k sequences' scores; now they're just 0
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)
        # Tensor to store top k sequences' alphas; now they're just 1s
        seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device)  # (k, 1, enc_image_size, enc_image_size)
        # Tensor to store the top k sequences' betas
        seqs_betas = torch.ones(k,1,1).to(device)
        # Lists to store completed sequences, their alphas, betas and scores
        complete_seqs = list()
        complete_seqs_alpha = list()
        complete_seqs_scores = list()
        complete_seqs_betas = list()

        # Start decoding
        step = 1
        h, c = decoder.init_hidden_state(enc_image)
        spatial_image = F.relu(decoder.encoded_to_hidden(enc_image))  # (k,num_pixels,hidden_size)
        global_image = F.relu(decoder.global_features(global_features))      # (1,embed_dim)

        # Below, s is the number of live beams: a number less than or equal to beam_size,
        # because sequences are removed from the process once they hit <end>
        while True:
            embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s,embed_dim)
            inputs = torch.cat((embeddings, global_image.expand_as(embeddings)), dim = 1)
            h, c, st = decoder.LSTM(inputs , (h, c))  # (s, hidden_size)
            # Run the adaptive attention model
            out_l, alpha, beta_t = decoder.adaptive_attention(spatial_image, h, st)
            # Drop the sentinel slot and fold the image weights back into a square map
            alpha = alpha[:,:-1].reshape(-1, enc_image_size, enc_image_size)  # (s, enc_image_size, enc_image_size)
            # Compute the log-probability over the vocabulary
            scores = F.log_softmax(decoder.fc(out_l), dim=1)   # (s, vocab_size)
            # Accumulate: (s,1) is broadcast to (s,vocab_size)
            scores = top_k_scores.expand_as(scores) + scores
            # For the first step, all k points will have the same scores (since same k previous words, h, c)
            if step == 1:
                # torch.topk returns the top k scores first, and their respective indices second
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
            else:
                # Unroll and find top scores, and their unrolled indices
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

            # Convert unrolled indices to (beam, word) indices. Floor division is required:
            # true division would produce float values that cannot be used as indices.
            prev_word_inds = top_k_words // vocab_size  # (s)
            next_word_inds = top_k_words % vocab_size  # (s)
            # Add new words to sequences, alphas and betas
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)
            # (s, step+1, enc_image_size, enc_image_size)
            seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],dim=1)
            seqs_betas = torch.cat([seqs_betas[prev_word_inds], beta_t[prev_word_inds].unsqueeze(1)], dim=1)

            # Which sequences are incomplete (didn't reach <end>)?
            next_words = next_word_inds.tolist()
            incomplete_inds = [ind for ind, next_word in enumerate(next_words) if next_word != end_token]
            complete_inds = [ind for ind, next_word in enumerate(next_words) if next_word == end_token]

            # Set aside complete sequences
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
                complete_seqs_betas.extend(seqs_betas[complete_inds])
            k -= len(complete_inds)  # reduce beam length accordingly

            # Proceed with incomplete sequences
            if k == 0:
                break

            seqs = seqs[incomplete_inds]
            seqs_alpha = seqs_alpha[incomplete_inds]
            seqs_betas = seqs_betas[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            spatial_image = spatial_image[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            # Break if things have been going on too long
            if step > max_steps:
                break

            step += 1

    if complete_seqs_scores:
        # Best finished hypothesis, with the attention weights that belong to it
        i = complete_seqs_scores.index(max(complete_seqs_scores))
        seq = complete_seqs[i]
        alphas = complete_seqs_alpha[i]
        betas = complete_seqs_betas[i]
    else:
        # No beam produced <end>: fall back to the first live beam. Its alphas/betas are taken
        # from the same beam so they stay aligned with the returned words.
        seq = seqs[0][:fallback_len].tolist()
        alphas = seqs_alpha[0][:fallback_len].tolist()
        betas = seqs_betas[0][:fallback_len]

    return seq, alphas, betas

In [None]:
def visualize_att(image_path, seq, alphas, betas, rev_word_map, smooth=True):
    """
    Visualizes caption with weights at every word.
    Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb
    :param image_path: path to image that has been captioned
    :param seq: caption, as a sequence of word indices
    :param alphas: attention weights, a tensor indexed as alphas[t, :] for word t
    :param betas: per-word beta values; 1 - betas[t] is printed on each panel
    :param rev_word_map: reverse word mapping, i.e. ix2word
    :param smooth: smooth weights?
    """
    image = Image.open(image_path)
    image = image.resize([7 * 7, 7 * 7], Image.LANCZOS)
    words = [rev_word_map[ind] for ind in seq]

    # Print the caption without the leading <start>; a trailing <end> is dropped only when it
    # is actually present, so a sequence that never reached <end> keeps its last word.
    caption_words = words[1:]
    if caption_words and caption_words[-1] == '<end>':
        caption_words = caption_words[:-1]
    print(' '.join(caption_words))

    # plt.subplot needs an integer row count; np.ceil returns a float
    num_rows = int(np.ceil(len(words) / 5.))
    # At most 51 panels are drawn (t = 0..50)
    for t in range(min(len(words), 51)):
        plt.subplot(num_rows, 5, t + 1)
        plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12)
        plt.text(10, 65, '%.2f' % (1-(betas[t].item())), color='green', backgroundcolor='white', fontsize=15)
        plt.imshow(image)
        current_alpha = alphas[t, :]
        if smooth:
            alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=7, sigma=7)
        else:
            alpha = skimage.transform.resize(current_alpha.numpy(), [7 * 7, 7 * 7])
        # The first panel shows the plain image: its attention map is drawn fully transparent
        if t == 0:
            plt.imshow(alpha, alpha=0)
        else:
            plt.imshow(alpha, alpha=0.8)
        plt.set_cmap('jet')
        plt.axis('off')

    plt.show()

In [None]:
%matplotlib inline
plt.rcParams['figure.figsize'] = (7, 7)  # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
# Encode, decode with attention and beam search k=3
seq, alphas, betas = caption_image_beam_search(encoder, decoder, 'test_imgs/test.jpg', word_map)
alphas = torch.FloatTensor(alphas)
# Visualize caption and attention of best sequence
visualize_att('test_imgs/test.jpg', seq, alphas, betas, rev_word_map, smooth=True)