In [3]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [4]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from utils import load_checkpoint
from dataset import build_vocab, get_loaders
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu

In [5]:
DATA_NAME = 'flickr8k_5_cap_per_img_2_min_word_freq_resnet101finetuneencoder'

# local
DATA_JSON_PATH = 'data.json'
IMGS_PATH = 'flickr/Images/'
CHECKPOINT_PATH = 'models/BEST_checkpoint_flickr8k_5_cap_per_img_2_min_word_freq_resnet101_fullvocab_fix_ds_rmsprop_finetune.pth.tar'
# kaggle paths
# DATA_JSON_PATH = '/kaggle/working/Image-Captioning/data.json'
# IMGS_PATH = '../input/flickr8kimagescaptions/flickr8k/images/'


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

In [7]:
# Load model
checkpoint = load_checkpoint(CHECKPOINT_PATH)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval();

Loaded Checkpoint!!
Last Epoch: 12
Best Bleu-4: 0.1655661704030113


In [8]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [9]:
vocab = build_vocab(DATA_JSON_PATH)
len(vocab)

100%|██████████| 40000/40000 [00:00<00:00, 383978.61it/s]


5089

In [10]:
bs = 1
beam_size=2
loader = get_loaders(bs, IMGS_PATH, DATA_JSON_PATH, transform, vocab, test=True, n_workers=8)

Dataset split: test
Unique images: 1000
Total size: 5000


In [123]:
import torch
import torch.nn.functional as F
import numpy as np
import json
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import skimage.transform
import argparse
from PIL import Image
def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):
    """
    Reads an image and captions it with beam search.
    :param encoder: encoder model
    :param decoder: decoder model
    :param image_path: path to image
    :param word_map: word map
    :param beam_size: number of sequences to consider at each decode-step
    :return: caption, weights for visualization
    """

    k = beam_size
    vocab_size = len(word_map)

    # Read image and process
    img = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    image = transform(img).to(device)  # (3, 256, 256)

    # Encode
    image = image.unsqueeze(0)  # (1, 3, 256, 256)
    encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
    enc_image_size = encoder_out.size(1)
    encoder_dim = encoder_out.size(3)

    # Flatten encoding
    encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
    num_pixels = encoder_out.size(1)

    # We'll treat the problem as having a batch size of k
    encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

    # Tensor to store top k previous words at each step; now they're just <start>
    k_prev_words = torch.LongTensor([[word_map.stoi['<sos>']]] * k).to(device)  # (k, 1)

    # Tensor to store top k sequences; now they're just <start>
    seqs = k_prev_words  # (k, 1)

    # Tensor to store top k sequences' scores; now they're just 0
    top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

    # Tensor to store top k sequences' alphas; now they're just 1s
    seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device)  # (k, 1, enc_image_size, enc_image_size)

    # Lists to store completed sequences, their alphas and scores
    complete_seqs = list()
    complete_seqs_alpha = list()
    complete_seqs_scores = list()

    # Start decoding
    step = 1
    h, c = decoder.init_hidden_state(encoder_out)

    # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
    while True:

        embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

        awe, alpha = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)

        alpha = alpha.view(-1, enc_image_size, enc_image_size)  # (s, enc_image_size, enc_image_size)

        gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
        awe = gate * awe

        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

        scores = decoder.fc(h)  # (s, vocab_size)
        scores = F.log_softmax(scores, dim=1)

        # Add
        scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

        # For the first step, all k points will have the same scores (since same k previous words, h, c)
        if step == 1:
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
        else:
            # Unroll and find top scores, and their unrolled indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

        # Convert unrolled indices to actual indices of scores
        prev_word_inds = top_k_words // vocab_size  # (s)
        next_word_inds = top_k_words % vocab_size  # (s)
        
        # Add new words to sequences, alphas
        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)
        seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
                               dim=1)  # (s, step+1, enc_image_size, enc_image_size)
#         print(seqs[prev_word_inds], prev_word_inds)
#         if step == 5:
#             return seqs
        # Which sequences are incomplete (didn't reach <end>)?
        incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                           next_word != word_map.stoi['<eos>']]
        complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

        # Set aside complete sequences
        if len(complete_inds) > 0:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds])
        k -= len(complete_inds)  # reduce beam length accordingly

        # Proceed with incomplete sequences
        if k == 0:
            break
        seqs = seqs[incomplete_inds]
        seqs_alpha = seqs_alpha[incomplete_inds]
        h = h[prev_word_inds[incomplete_inds]]
        c = c[prev_word_inds[incomplete_inds]]
        encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

        # Break if things have been going on too long
        if step > 50:
            break
        step += 1

    i = complete_seqs_scores.index(max(complete_seqs_scores))
    seq = complete_seqs[i]
    alphas = complete_seqs_alpha[i]

    return seq, alphas, complete_seqs

In [143]:
seq, _, comp_seqs = caption_image_beam_search(encoder, decoder, 'flickr/Images/3514019869_7de4ece2a5.jpg', vocab, beam_size=2)

In [144]:
[sent for sent in comp_seqs][0]

[1, 4, 5, 65, 13, 99, 856, 2]

In [145]:
[" ".join([vocab.itos[i] for i in sent]) for sent in comp_seqs]

['<sos> a dog runs through an obstacle <eos>',
 '<sos> a dog jumps over a hurdle <eos>']

In [146]:
[vocab.itos[i] for i in seq]

['<sos>', 'a', 'dog', 'runs', 'through', 'an', 'obstacle', '<eos>']

In [20]:
def evaluate(beam_size):
    """
    Evaluation
    :param beam_size: beam size at which to generate captions for evaluation
    :return: BLEU-4 score
    """

    # TODO: Batched Beam Search
    # Therefore, do not use a batch_size greater than 1 - IMPORTANT!

    # Lists to store references (true captions), and hypothesis (prediction) for each image
    # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
    # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
    references = list()
    hypotheses = list()

    # For each image
    for i, (image, caps, caplens, allcaps) in enumerate(
            tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))):

        k = beam_size

        # Move to GPU device, if available
        image = image.to(device)  # (1, 3, 256, 256)

        # Encode
        encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
        enc_image_size = encoder_out.size(1)
        encoder_dim = encoder_out.size(3)

        # Flatten encoding
        encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # We'll treat the problem as having a batch size of k
        encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

        # Tensor to store top k previous words at each step; now they're just <start>
        k_prev_words = torch.LongTensor([[vocab.stoi['<sos>']]] * k).to(device)  # (k, 1)
        
        # Tensor to store top k sequences; now they're just <start>
        seqs = k_prev_words  # (k, 1)

        # Tensor to store top k sequences' scores; now they're just 0
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

        # Lists to store completed sequences and scores
        complete_seqs = list()
        complete_seqs_scores = list()

        # Start decoding
        step = 1
        h, c = decoder.init_hidden_state(encoder_out)

        # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
        while True:

            embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

            awe, _ = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)

            gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
            awe = gate * awe

            h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

            scores = decoder.fc(h)  # (s, vocab_size)
            scores = F.log_softmax(scores, dim=1)

            # Add
            scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

            # For the first step, all k points will have the same scores (since same k previous words, h, c)
            if step == 1:
                top_k_scores, top_k_words = scores[0].topk(k, 0)  # (s)
            else:
                # Unroll and find top scores, and their unrolled indices
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0)  # (s)

a                
            # Convert unrolled indices to actual indices of scores
            prev_word_inds = top_k_words // vocab_size  # (s)
            next_word_inds = top_k_words % vocab_size  # (s)
            if prev_word_inds.item() != 0:
                print(prev_word_inds)
            
#             print(top_k_scores, top_k_words)
            # Add new words to sequences
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

            # Which sequences are incomplete (didn't reach <end>)?
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                               next_word != vocab.stoi['<eos>']]
            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            # Set aside complete sequences
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds])
            k -= len(complete_inds)  # reduce beam length accordingly

            # Proceed with incomplete sequences
            if k == 0:
                break
            seqs = seqs[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            # Break if things have been going on too long
            if step > 50:
                break
            step += 1

        i = complete_seqs_scores.index(max(complete_seqs_scores))
        seq = complete_seqs[i]

        # References
        img_caps = allcaps[0].tolist()
        img_captions = list(
            map(lambda c: [w for w in c if w not in {vocab.stoi['<sos>'], vocab.stoi['<eos>'], vocab.stoi['<pad>']}],
                img_caps))  # remove <start> and pads
        references.append(img_captions)

        # Hypotheses
        hypotheses.append([w for w in seq if w not in {vocab.stoi['<sos>'], vocab.stoi['<eos>'], vocab.stoi['<pad>']}])

        assert len(references) == len(hypotheses)

    # Calculate BLEU-4 scores
    bleu4 = corpus_bleu(references, hypotheses)

    return bleu4



In [17]:
vocab_size = len(vocab)
evaluate(1)

EVALUATING AT BEAM SIZE 1:   4%|▍         | 223/5000 [00:07<02:41, 29.49it/s]


KeyboardInterrupt: 

In [54]:
vocab.stoi
evaluate(1)

EVALUATING AT BEAM SIZE 1: 100%|██████████| 5000/5000 [02:42<00:00, 30.81it/s]


0.20798872886259115

In [55]:
evaluate(2)

EVALUATING AT BEAM SIZE 2: 100%|██████████| 5000/5000 [03:09<00:00, 26.34it/s]


0.22867002491504052

In [28]:
references = list()
hypotheses = list()

for i, (image, caps, caplens, allcaps) in enumerate(
        tqdm(test_loader, desc=f'Evaluating at Beam size {beam_size}')):
    
    k = beam_size
    
    image = image.to(device)
    
    # encoder
    encoder_out = encoder(image) # [1, enc_img_size, enc_img_size, encoder_dim]
    enc_img_size = encoder_out.size(1)
    encoder_dim = encoder_out.size(3)
    
    # flatten encoding
    encoder_out = encoder_out.view(1, -1, encoder_dim) # [1, num_pixels, encoder_dim]
    num_pixels = encoder_out.size(1)
    
    # treat it as we have a batch of size k
    encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # [k, num_pixels, encoder_dim]
    
    # tensor to store top k previous words at each step; currently it's only <sos>
    k_prev_words = torch.LongTensor([[vocab.stoi['<sos>']]]*k).to(device) # [k, 1]
    
    
    
    # tensor to store top k sequences
    seqs = k_prev_words
    
    # tensor to store top k sequences' scores
    top_k_scores = torch.zeros(k, 1).to(device)
    
    # lists to store completed sequences and scores
    complete_seqs = list()
    complete_seqs_scores = list()
    
    # decoding
    step = 1
    h, c = decoder.init_hidden_state(encoder_out)
    
    while True:
        embeddings = decoder.embedding(k_prev_words).squeeze(1) # [s, embed_dim]
        
        awe, _ = decoder.attention(encoder_out, h) # [s, encoder_dim], [s, num_pixels]
        
        gate = decoder.sigmoid(decoder.f_beta(h))
        awe = gate * awe
        
        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))
        
        scores = decoder.fc(h) # [s, vocab size]
        scores = F.log_softmax(scores, dim=1)
        break
        # add 
        scores = top_k_scores.expand_as(scores) + scores # [s, vocab_size]
                
        # for the first step: all k points will have the same score; since same k previous words, h,c
        if step == 1:
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # [s]
        else:
            # Unroll and find top scores, and their unrolled indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # [s]
            
        # convert unrolled indices to actual indices of scores
        prev_word_inds = top_k_words / len(vocab)
        next_word_inds = top_k_words % len(vocab)
        
        print(top_k_words)
        print(prev_word_inds)
        print(next_word_inds)
        
        # add new words to sequences
        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
        
        # which sequences are incomplete - didn't reach <end>
        incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if next_word != vocab['<eos>']]
        
        complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
        
        # set aside complete sequences
        if len(complete_inds) > 0:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds])
        k -= len(complete_inds) # reduce beam length accordingly
        
        # proceed with incomplete sequences
        if k == 0:
            break
        
        seqs = seqs[incomplete_inds]
        h = h[prev_word_inds[incomplete_inds]]
        c = c[prev_word_inds[incomplete_inds]]
        encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
        
        # break if things have been going on too long
        if step > 50:
            break
        step += 1
        
    i = complete_seqs_scores.index(max(complete_seqs_scores))
    seq = complete_seqs[1]
    
    # references 
    for j in range(allcaps.shape[0]):
            img_caps = allcaps[j].tolist()
            references.append(vocab.indextostring(img_caps))

    # hypotheses
    hypotheses.append([w for w in seq if w not in {vocab.stoi['<sos>'], vocab.stoi['<eos>'], vocab.stoi['<pad>']}])
    
    assert len(references) == len(hypotheses)
    
bleu4 = corpus_bleu(references, hypotheses)

print(bleu4)

Evaluating at Beam size 2:   0%|          | 0/5000 [00:00<?, ?it/s]


ValueError: max() arg is an empty sequence

In [None]:
w, ww = scores.view(-1).topk(k, 0, True, True)

In [20]:
w, ww

(tensor([-0.1701, -0.1701], device='cuda:0', grad_fn=<TopkBackward>),
 tensor([5093,    4], device='cuda:0'))