In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [14]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from utils import load_checkpoint
from dataset import build_vocab, get_loaders
from tqdm import tqdm

In [15]:
vocabulary = build_vocab('data.json') 

100%|██████████| 30000/30000 [00:00<00:00, 400361.20it/s]


In [16]:
DATA_NAME = 'flickr8k_5_cap_per_img_2_min_word_freq_resnet101'

# local
DATA_JSON_PATH = 'data.json'
IMGS_PATH = 'flickr/Images/'
CHECKPOINT_PATH = 'models/BEST_checkpoint_flickr8k_5_cap_per_img_2_min_word_freq_resnet101.pth.tar'
# kaggle paths
# DATA_JSON_PATH = '/kaggle/working/Image-Captioning/data.json'
# IMGS_PATH = '../input/flickr8kimagescaptions/flickr8k/images/'


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

In [18]:
# Load model
checkpoint = load_checkpoint(CHECKPOINT_PATH)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval();

Loaded Checkpoint!!
Last Epoch: 8
Best Bleu-4: 0.13934344941076532


In [19]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [20]:
vocab = build_vocab(DATA_JSON_PATH)
len(vocab)

100%|██████████| 30000/30000 [00:00<00:00, 290537.37it/s]


4451

In [21]:
bs = 1
beam_size=1
test_loader = get_loaders(bs, IMGS_PATH, DATA_JSON_PATH, transform, vocab, test=True, n_workers=8)

Dataset split: test
Unique images: 1000
Total size: 5000


In [25]:
references = list()
hypotheses = list()

for i, (image, caps, caplens, allcaps) in enumerate(
        tqdm(test_loader, desc=f'Evaluating at Beam size {beam_size}')):
    
    k = beam_size
    
    image = image.to(device)
    
    # encoder
    encoder_out = encoder(image) # [1, enc_img_size, enc_img_size, encoder_dim]
    enc_img_size = encoder_out.size(1)
    encoder_dim = encoder_out.size(3)
    
    # flatten encoding
    encoder_out = encoder_out.view(1, -1, encoder_dim) # [1, num_pixels, encoder_dim]
    num_pixels = encoder_out.size(1)
    
    # treat it as we have a batch of size k
    encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # [k, num_pixels, encoder_dim]
    
    # tensor to store top k previous words at each step; currently it's only <sos>
    k_prev_words = torch.LongTensor([[vocab.stoi['<sos>']]]*k).to(device) # [k, 1]
    
    # tensor to store top k sequences
    seqs = k_prev_words
    
    # tensor to store top k sequences' scores
    top_k_scores = torch.zeros(k, 1).to(device)
    
    # lists to store completed sequences and scores
    complete_seqs = list()
    complete_seqs_scores = list()
    
    # decoding
    step = 1
    h, c = decoder.init_hidden_state(encoder_out)
    
    while True:
        embeddings = decoder.embedding(k_prev_words).squeeze(1) # [s, embed_dim]
        
        awe, _ = decoder.attention(encoder_out, h) # [s, encoder_dim], [s, num_pixels]
        
        gate = decoder.sigmoid(decoder.f_beta(h))
        awe = gate * awe
        
        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))
        
        scores = decoder.fc(h) # [s, vocab size]
        scores = F.log_softmax(scores, dim=1)
        
        # add 
        scores = top_k_scores.expand_as(scores) + scores # [s, vocab_size]
        
        # for the first step: all k points will have the same score; since same k previous words, h,c
        if step == 1:
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)
        else:
            # unroll and find top scores and their unrolled indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)
            
        # convert unrolled indices to actual indices of scores
        print(prev_word_inds)
        prev_word_inds = top_k_words / len(vocab)
        next_word_inds = top_k_words % len(vocab)
        
        print(prev_word_inds)
        print(next_word_inds)
        
        # add new words to sequences
        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
        
        # which sequences are incomplete - didn't reach <end>
        incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if next_word != vocab['<eos>']]
        
        complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
        
        # set aside complete sequences
        if len(complete_inds) > 0:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds])
        k -= len(complete_inds) # reduce beam length accordingly
        
        # proceed with incomplete sequences
        if k == 0:
            break
        
        seqs = seqs[incomplete_inds]
        h = h[prev_word_inds[incomplete_inds]]
        c = c[prev_word_inds[incomplete_inds]]
        encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
        
        # break if things have been going on too long
        if step > 50:
            break
        step += 1
        
    i = complete_seqs_scores.index(max(complete_seqs_scores))
    seq = complete_seqs[1]
    
    # references 
    for j in range(allcaps.shape[0]):
            img_caps = allcaps[j].tolist()
            references.append(vocab.indextostring(img_caps))

    # hypotheses
    hypotheses.append([w for w in seq if w not in {vocab.stoi['<sos>'], vocab.stoi['<eos>'], vocab.stoi['<pad>']}])
    
    assert len(references) == len(hypotheses)
    
bleu4 = corpus_bleu(references, hypotheses)

print(bleu4)

Evaluating at Beam size 1:   0%|          | 0/5000 [00:00<?, ?it/s]

tensor([0.0009], device='cuda:0')
tensor([0.0009], device='cuda:0')
tensor([4], device='cuda:0')





IndexError: tensors used as indices must be long, byte or bool tensors

In [26]:
len(vocab)

4451