In [1]:
# Automatically reload edited project modules (e.g. utils, dataset) before each cell executes,
# and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from utils import load_checkpoint
from dataset import build_vocab, get_loaders
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu
from utils import print_scores

In [3]:
# local
DATA_JSON_PATH = 'data.json'
IMGS_PATH = 'flickr/Images/'
CHECKPOINT_PATH = 'models/BEST_checkpoint_flickr8k_5_cap_per_img_2_min_word_freq_resnet101_fullvocab_fix_ds_rmsprop_finetune.pth.tar'
# kaggle paths
# DATA_JSON_PATH = '/kaggle/working/Image-Captioning/data.json'
# IMGS_PATH = '../input/flickr8kimagescaptions/flickr8k/images/'


In [4]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Let cuDNN benchmark and cache the fastest algorithms for the input shapes it encounters.
torch.backends.cudnn.benchmark = True

In [5]:
# Load model
checkpoint = load_checkpoint(CHECKPOINT_PATH)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval();

Loaded Checkpoint!!
Last Epoch: 12
Best Bleu-4: 0.1655661704030113


In [6]:
# Evaluation preprocessing: resize the shorter side to 256, take a 224x224 centre crop,
# convert to a tensor and normalise each channel with the ImageNet statistics below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [7]:
vocab = build_vocab(DATA_JSON_PATH)
# Define vocab_size right where the vocab is built, so later cells (the evaluation loop reads
# it) do not depend on a separate cell further down having been run first.
vocab_size = len(vocab)
vocab_size

100%|██████████| 40000/40000 [00:00<00:00, 370677.91it/s]


5089

In [8]:
# One image per batch: the beam search below flattens a single image and expands it into k beams.
bs, beam_size = 1, 2
loader = get_loaders(
    bs, IMGS_PATH, DATA_JSON_PATH, transform, vocab,
    test=True, n_workers=8,
)

Dataset split: test
Unique images: 1000
Total size: 5000


In [9]:
import torch
import torch.nn.functional as F
import numpy as np
import json
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# import skimage.transform
# import argparse
from PIL import Image
def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3, max_steps=50):
    """Caption one image file with beam search.

    Uses the notebook-global ``device`` for all tensors.

    Args:
        encoder: image encoder; its output for a (1, 3, 224, 224) input is read as
            (1, enc_image_size, enc_image_size, encoder_dim).
        decoder: attention decoder exposing ``init_hidden_state``, ``embedding``,
            ``attention``, ``f_beta``, ``sigmoid``, ``decode_step`` and ``fc``.
        image_path: path of the image to caption.
        word_map: vocabulary with ``stoi`` (token -> index) and ``len()``; must contain
            ``'<sos>'`` and ``'<eos>'``.
        beam_size: number of beams kept at each decoding step.
        max_steps: decoding is cut off once the step counter exceeds this value
            (previously hard-coded to 50).

    Returns:
        (seq, alphas, complete_seqs):
        - seq: best token-index sequence, starting with <sos> (and ending with <eos> if it finished).
        - alphas: its attention maps as nested lists, one enc_image_size x enc_image_size map
          per token (the <sos> slot is all ones).
        - complete_seqs: every finished beam, or the surviving beams if none finished
          before the step limit.
    """
    k = beam_size
    vocab_size = len(word_map)
    sos_idx = word_map.stoi['<sos>']
    eos_idx = word_map.stoi['<eos>']

    # Read image and process: resize, 224x224 centre crop, ImageNet normalisation.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # The context manager closes the file handle that Image.open keeps open lazily.
    with Image.open(image_path) as pil_img:
        img = pil_img.convert("RGB")
    image = transform(img).to(device)  # (3, 224, 224)

    # Pure inference: no autograd graph is needed for anything below.
    with torch.no_grad():
        # Encode
        image = image.unsqueeze(0)  # (1, 3, 224, 224)
        encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
        enc_image_size = encoder_out.size(1)
        encoder_dim = encoder_out.size(3)

        # Flatten the spatial grid, then treat the k beams as a batch of size k.
        encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)
        encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

        # Previous word of each beam, the sequences so far, and their cumulative log-probs.
        k_prev_words = torch.LongTensor([[sos_idx]] * k).to(device)  # (k, 1)
        seqs = k_prev_words  # (k, 1)
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)
        # Attention maps per beam; the <sos> slot is all ones.
        seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device)

        complete_seqs = []
        complete_seqs_alpha = []
        complete_seqs_scores = []

        step = 1
        h, c = decoder.init_hidden_state(encoder_out)

        # s (<= k) is the number of live beams; a beam leaves once it emits <eos>.
        while True:
            embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

            awe, alpha = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)
            alpha = alpha.view(-1, enc_image_size, enc_image_size)  # (s, enc_image_size, enc_image_size)

            gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
            awe = gate * awe

            h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

            scores = F.log_softmax(decoder.fc(h), dim=1)  # (s, vocab_size)
            # Accumulate log-probabilities along each beam.
            scores = top_k_scores.expand_as(scores) + scores

            if step == 1:
                # All beams are identical at the first step, so search row 0 only.
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
            else:
                # Search over the flattened (beam, word) grid.
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

            # Recover (beam, word) from the flattened index.
            prev_word_inds = top_k_words // vocab_size  # (s)
            next_word_inds = top_k_words % vocab_size  # (s)

            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)
            seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
                                   dim=1)  # (s, step+1, enc_image_size, enc_image_size)

            # Split beams into finished (emitted <eos>) and still running.
            next_words = next_word_inds.tolist()
            incomplete_inds = [ind for ind, word in enumerate(next_words) if word != eos_idx]
            complete_inds = [ind for ind, word in enumerate(next_words) if word == eos_idx]

            if complete_inds:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
            k -= len(complete_inds)  # shrink the beam accordingly

            if k == 0:
                break
            seqs = seqs[incomplete_inds]
            seqs_alpha = seqs_alpha[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            if step > max_steps:
                break
            step += 1

        # If no beam reached <eos> before the step limit, use the surviving beams instead of
        # failing on max() of an empty list.
        if not complete_seqs_scores:
            complete_seqs = seqs.tolist()
            complete_seqs_alpha = seqs_alpha.tolist()
            complete_seqs_scores = top_k_scores.squeeze(1).tolist()

    best = complete_seqs_scores.index(max(complete_seqs_scores))
    return complete_seqs[best], complete_seqs_alpha[best], complete_seqs

In [10]:
# Build the sample path from IMGS_PATH so the cell follows the configured data location.
SAMPLE_IMAGE_PATH = IMGS_PATH + '3514019869_7de4ece2a5.jpg'
seq, _, comp_seqs = caption_image_beam_search(encoder, decoder, SAMPLE_IMAGE_PATH, vocab, beam_size=2)

In [11]:
# First candidate sequence, as raw token indices.
comp_seqs[0]

[1, 4, 5, 65, 13, 99, 856, 2]

In [12]:
# Decode every candidate sequence back to words.
[" ".join(vocab.itos[token_id] for token_id in candidate) for candidate in comp_seqs]

['<sos> a dog runs through an obstacle <eos>',
 '<sos> a dog jumps over a hurdle <eos>']

In [13]:
# The highest-scoring caption, token by token.
[vocab.itos[token_id] for token_id in seq]

['<sos>', 'a', 'dog', 'runs', 'through', 'an', 'obstacle', '<eos>']

In [14]:
def _beam_search_best(decoder, encoder_out, beam_size, vocab_size, sos_idx, eos_idx, max_steps):
    """Return the highest-scoring token-index sequence for one encoded image.

    ``encoder_out`` is the encoder output for a single image, read as
    (1, enc_image_size, enc_image_size, encoder_dim). If no beam emits <eos> before the
    step counter exceeds ``max_steps``, the best unfinished beam is returned instead.
    """
    k = beam_size
    encoder_dim = encoder_out.size(3)

    # Flatten the spatial grid, then treat the k beams as a batch of size k.
    encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
    num_pixels = encoder_out.size(1)
    encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

    k_prev_words = torch.full((k, 1), sos_idx, dtype=torch.long, device=encoder_out.device)  # (k, 1)
    seqs = k_prev_words  # (k, 1)
    top_k_scores = torch.zeros(k, 1, device=encoder_out.device)  # (k, 1)

    complete_seqs = []
    complete_seqs_scores = []

    step = 1
    h, c = decoder.init_hidden_state(encoder_out)

    # s (<= k) is the number of live beams; a beam leaves once it emits <eos>.
    while True:
        embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)
        awe, _ = decoder.attention(encoder_out, h)  # (s, encoder_dim)
        awe = decoder.sigmoid(decoder.f_beta(h)) * awe  # gated attention-weighted encoding
        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

        log_probs = F.log_softmax(decoder.fc(h), dim=1)  # (s, vocab_size)
        # Accumulate log-probabilities along each beam.
        scores = top_k_scores.expand_as(log_probs) + log_probs

        if step == 1:
            # All beams are identical at the first step, so search row 0 only.
            top_k_scores, top_k_words = scores[0].topk(k, 0)  # (s)
        else:
            # Search over the flattened (beam, word) grid.
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0)  # (s)

        # Recover (beam, word) from the flattened index.
        prev_word_inds = top_k_words // vocab_size  # (s)
        next_word_inds = top_k_words % vocab_size  # (s)

        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

        # Split beams into finished (emitted <eos>) and still running.
        next_words = next_word_inds.tolist()
        incomplete_inds = [ind for ind, word in enumerate(next_words) if word != eos_idx]
        complete_inds = [ind for ind, word in enumerate(next_words) if word == eos_idx]

        if complete_inds:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
        k -= len(complete_inds)  # shrink the beam accordingly

        if k == 0:
            break
        seqs = seqs[incomplete_inds]
        h = h[prev_word_inds[incomplete_inds]]
        c = c[prev_word_inds[incomplete_inds]]
        encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

        if step > max_steps:
            break
        step += 1

    # No beam finished before the step limit: fall back to the surviving beams rather than
    # dropping the image (which would silently shrink the evaluation set).
    if not complete_seqs_scores:
        complete_seqs = seqs.tolist()
        complete_seqs_scores = top_k_scores.squeeze(1).tolist()

    best = complete_seqs_scores.index(max(complete_seqs_scores))
    return complete_seqs[best]


def evaluate(beam_size, max_steps=50):
    """Beam-search caption every item of the notebook-global test ``loader``.

    Relies on the globals ``loader`` (batches of size 1 yielding
    ``(image, caps, caplens, allcaps)``), ``encoder``, ``decoder``, ``vocab`` and ``device``.

    Args:
        beam_size: number of beams kept at each decoding step.
        max_steps: decoding is cut off once the step counter exceeds this value
            (previously hard-coded to 50).

    Returns:
        (references, hypotheses): parallel lists for corpus-level metrics. Each reference
        entry holds all ground-truth captions of the item and each hypothesis is the predicted
        caption, all as token-index lists with <sos>, <eos> and <pad> removed. Nothing is
        printed; pass the result to ``print_scores``.
    """
    # Computed locally instead of reading a global defined in a later cell.
    vocab_size = len(vocab)
    sos_idx = vocab.stoi['<sos>']
    eos_idx = vocab.stoi['<eos>']
    special_tokens = {sos_idx, eos_idx, vocab.stoi['<pad>']}

    references = []
    hypotheses = []

    # Pure inference: skip building autograd graphs for every test image.
    with torch.no_grad():
        for image, _caps, _caplens, allcaps in tqdm(
                loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size)):
            image = image.to(device)  # (1, 3, 224, 224) with bs = 1
            encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
            seq = _beam_search_best(decoder, encoder_out, beam_size, vocab_size,
                                    sos_idx, eos_idx, max_steps)

            # References: every ground-truth caption of this image, without special tokens.
            img_caps = allcaps[0].tolist()
            references.append([[w for w in cap if w not in special_tokens] for cap in img_caps])

            # Hypothesis: the best beam, without special tokens.
            hypotheses.append([w for w in seq if w not in special_tokens])

    assert len(references) == len(hypotheses)
    return references, hypotheses


In [15]:
# Number of tokens in the vocabulary.
vocab_size = len(vocab)

In [16]:
# Display the vocabulary size.
vocab_size

5089

In [17]:
# Decode the whole test loader with beam size 3; keep the token-index outputs for scoring below.
references, hypotheses = evaluate(3)

EVALUATING AT BEAM SIZE 3: 100%|██████████| 5000/5000 [03:29<00:00, 23.92it/s]


In [18]:
# NOTE(review): the recorded run prints BLEU-1..4 and then fails in the METEOR section with
# "'NoneType' object has no attribute 'itos'"; print_scores presumably needs the vocab to
# decode indices for METEOR. Confirm its signature. Corpus METEOR is computed manually below.
print_scores(references, hypotheses)

----- Bleu-n Scores -----
1: 64.62017684887459
2: 47.13660963689657
3: 33.759143347900405
4: 23.833249590210322
-------------------------
----- METEOR Score -----


AttributeError: 'NoneType' object has no attribute 'itos'

In [95]:
# Turn index outputs into space-joined strings: one string per hypothesis and one list of
# strings (all references) per test item.
hs = [" ".join(words) for words in vocab.indextostring(hypotheses)]
rs = [[" ".join(words) for words in vocab.indextostring(ref_set)] for ref_set in references]

In [102]:
# Spot-check one decoded hypothesis.
hs[1]

'a black and white dog is running through the grass'

In [103]:
# Spot-check one set of reference captions.
# NOTE(review): this is item 2 while the hypothesis above is item 1; use the same index to
# compare a prediction with its own references.
rs[2]

['a boy pushes a wagon full of pumpkins',
 'a boy pushes a wagon with two pumpkins',
 'a boy smiling leaning over a wagon filled with two large pumpkins',
 'a child squats behind a wagon with two pumpkins in it',
 'boy pushing wagon with two pumpkins in it']

In [105]:
import nltk
from statistics import mean
from nltk.translate.meteor_score import meteor_score

# meteor_score and its WordNet data are otherwise only provided by cells that come later in
# document order, so make this cell self-sufficient for Restart & Run All.
nltk.download('wordnet', quiet=True)

# Sum of sentence-level METEOR over all (reference set, hypothesis) pairs; averaged below.
# NOTE(review): newer NLTK releases require pre-tokenised inputs (lists of words) and reject
# plain strings; the strings here match the NLTK version this notebook was run with.
total_meteor = sum(
    meteor_score(r, h) for r, h in tqdm(zip(rs, hs), total=len(rs))
)

100%|██████████| 5000/5000 [00:10<00:00, 460.53it/s]


In [106]:
# Corpus average of the sentence-level METEOR scores.
total_meteor/len(rs)

0.428133409401112

In [None]:
### Turn index outputs into strings, then score the strings with BLEU

In [76]:
# Score the string versions.
# NOTE(review): the recorded BLEU values match those from the index lists above, and the cell
# fails in the same METEOR section ("'NoneType' object has no attribute 'itos'").
print_scores(rs, hs)

----- Bleu-n Scores -----
1: 64.62017684887459
2: 47.13660963689657
3: 33.759143347900405
4: 23.833249590210322
-------------------------
----- METEOR Score -----


AttributeError: 'NoneType' object has no attribute 'itos'

In [29]:
from nltk.translate.meteor_score import meteor_score

In [65]:
# METEOR's synonym matching uses the WordNet corpus; fetch it if it is not already installed.
import nltk
nltk.download('wordnet')

[nltk_data] Downloading package wordnet to /home/kelwa/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


True

In [102]:
# Decode token indices back to words. evaluate() stores each hypothesis as a flat list of
# indices, so iterate it directly: sent[0] would be a single int and cannot be iterated.
hs = [" ".join(vocab.itos[i] for i in sent) for sent in hypotheses]
rs = [[" ".join(vocab.itos[i] for i in sent) for sent in r] for r in references]

In [12]:
# Number of tokens in the vocabulary (same value as computed earlier).
vocab_size = len(vocab)

In [16]:
# evaluate() returns (references, hypotheses) instead of printing. Score them explicitly rather
# than letting the cell echo both 5000-item lists.
print_scores(*evaluate(2), nltk=True)

EVALUATING AT BEAM SIZE 2: 100%|██████████| 5000/5000 [03:27<00:00, 24.14it/s]


----- Bleu-n Scores -----
1: 63.54625550660793
2: 45.03356444855022
3: 31.69343570783961
4: 22.30082901251822
-------------------------


In [17]:
# evaluate() returns (references, hypotheses) instead of printing. Score them explicitly rather
# than letting the cell echo both 5000-item lists.
print_scores(*evaluate(3), nltk=True)

EVALUATING AT BEAM SIZE 3: 100%|██████████| 5000/5000 [03:57<00:00, 21.07it/s]


----- Bleu-n Scores -----
1: 64.31220201306728
2: 45.5251368754951
3: 32.38723129370634
4: 23.07429032664538
-------------------------


In [18]:
# evaluate() returns (references, hypotheses) instead of printing. Score them explicitly rather
# than letting the cell echo both 5000-item lists.
print_scores(*evaluate(1), nltk=True)

EVALUATING AT BEAM SIZE 1: 100%|██████████| 5000/5000 [02:59<00:00, 27.82it/s]


----- Bleu-n Scores -----
1: 61.123197163806296
2: 42.936593246185
3: 29.775037258795304
4: 20.646167109205283
-------------------------


In [19]:
# Sweep beam sizes 1..5. evaluate() returns its outputs rather than printing them, so score
# each run explicitly; otherwise the results are discarded and nothing is reported.
for size in range(1, 6):
    print('*'*15, f"Beam size of {size}", '*'*15)
    print_scores(*evaluate(size), nltk=True)

EVALUATING AT BEAM SIZE 1:   0%|          | 0/5000 [00:00<?, ?it/s]

*************** Beam size of 1 ***************


EVALUATING AT BEAM SIZE 1: 100%|██████████| 5000/5000 [02:59<00:00, 27.92it/s]


----- Bleu-n Scores -----
1: 61.123197163806296
2: 42.936593246185
3: 29.775037258795304


EVALUATING AT BEAM SIZE 2:   0%|          | 0/5000 [00:00<?, ?it/s]

4: 20.646167109205283
-------------------------
*************** Beam size of 2 ***************


EVALUATING AT BEAM SIZE 2: 100%|██████████| 5000/5000 [03:58<00:00, 20.99it/s]


----- Bleu-n Scores -----
1: 63.54625550660793
2: 45.03356444855022
3: 31.69343570783961


EVALUATING AT BEAM SIZE 3:   0%|          | 0/5000 [00:00<?, ?it/s]

4: 22.30082901251822
-------------------------
*************** Beam size of 3 ***************


EVALUATING AT BEAM SIZE 3: 100%|██████████| 5000/5000 [04:43<00:00, 17.62it/s]


----- Bleu-n Scores -----
1: 64.31220201306728
2: 45.5251368754951
3: 32.38723129370634


EVALUATING AT BEAM SIZE 4:   0%|          | 0/5000 [00:00<?, ?it/s]

4: 23.07429032664538
-------------------------
*************** Beam size of 4 ***************


EVALUATING AT BEAM SIZE 4: 100%|██████████| 5000/5000 [04:58<00:00, 16.76it/s]


----- Bleu-n Scores -----
1: 64.40847503864691
2: 45.8716912526639
3: 32.61201588518963


EVALUATING AT BEAM SIZE 5:   0%|          | 0/5000 [00:00<?, ?it/s]

4: 23.235058922423306
-------------------------
*************** Beam size of 5 ***************


EVALUATING AT BEAM SIZE 5: 100%|██████████| 5000/5000 [05:38<00:00, 14.77it/s]


----- Bleu-n Scores -----
1: 64.8227213662521
2: 46.212067416932584
3: 32.929193446645684
4: 23.41989863202648
-------------------------
