<a href="https://colab.research.google.com/github/graviraja/100-Days-of-NLP/blob/applications%2Fgeneration/applications/generation/image_captioning/Image%20Captioning%20with%20Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Flickr8k Dataset

In [None]:
from google.colab import files
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"ravirajag","key":"3ebe1017281784a1ca4b048c597b0c87"}'}

In [None]:
!mkdir ~/.kaggle

In [None]:
!cp kaggle.json ~/.kaggle

In [None]:
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d adityajn105/flickr8k

Downloading flickr8k.zip to /content
 99% 1.03G/1.04G [00:32<00:00, 40.0MB/s]
100% 1.04G/1.04G [00:32<00:00, 34.4MB/s]


In [None]:
!unzip -qq flickr8k.zip -d flickr8k

In [None]:
!ls flickr8k

captions.txt  Images


### Imports

In [None]:
import os
import time
import nltk
import pickle

import numpy as np

from PIL import Image
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torch.utils.data as data

from torch.nn.utils.rnn import pack_padded_sequence
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [None]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

### Captions

In [None]:
captions = dict()

with open('flickr8k/captions.txt', 'r') as f:
    lines= f.readlines()
    for idx, line in enumerate(lines[1:]):
        index=line.index(',')
        img_id = line[:index].split('.')[0]
        img_caption = line[index+1:].strip()
        captions[idx] = {"image_id": img_id, "caption": img_caption}


In [None]:
len(captions)

40455

In [None]:
captions[90]

{'caption': 'A black dog running in the surf .',
 'image_id': '1022975728_75515238d8'}

### Vocabulary

In [None]:
class Vocabulary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

In [None]:
def build_vocab(captions_dict, threshold):
    """Build a simple vocabulary wrapper."""
    counter = Counter()
    for i, id in enumerate(captions_dict):
        caption = captions_dict[id]["caption"]
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        counter.update(tokens)

        if (i+1) % 5000 == 0:
            print("[{}/{}] Tokenized the captions.".format(i+1, len(captions_dict)))

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab.add_word(word)
    return vocab

In [None]:
vocab = build_vocab(captions, threshold=5)
vocab_path = 'vocab.pkl'
with open(vocab_path, 'wb') as f:
    pickle.dump(vocab, f)
print("Total vocabulary size: {}".format(len(vocab)))
print("Saved the vocabulary wrapper to '{}'".format(vocab_path))

[5000/40455] Tokenized the captions.
[10000/40455] Tokenized the captions.
[15000/40455] Tokenized the captions.
[20000/40455] Tokenized the captions.
[25000/40455] Tokenized the captions.
[30000/40455] Tokenized the captions.
[35000/40455] Tokenized the captions.
[40000/40455] Tokenized the captions.
Total vocabulary size: 3006
Saved the vocabulary wrapper to 'vocab.pkl'


### Image Processing

In [None]:
def resize_image(image, size):
    return image.resize(size, Image.ANTIALIAS)

def resize_images(image_dir, output_dir, size):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    images = os.listdir(image_dir)
    num_images = len(images)
    for i, image in enumerate(images):
        with open(os.path.join(image_dir, image), 'r+b') as f:
            with Image.open(f) as img:
                img = resize_image(img, size)
                img.save(os.path.join(output_dir, image), img.format)
        if (i+1) % 1000 == 0:
            print (f"[{i+1}/{num_images}] Resized the images and saved into '{output_dir}'.")
    

In [None]:
image_dir = "flickr8k/Images"
processed_image_dir = "processed_images"
size = [256, 256]

In [None]:
resize_images(image_dir, processed_image_dir, size)

[1000/8091] Resized the images and saved into 'processed_images'.
[2000/8091] Resized the images and saved into 'processed_images'.
[3000/8091] Resized the images and saved into 'processed_images'.
[4000/8091] Resized the images and saved into 'processed_images'.
[5000/8091] Resized the images and saved into 'processed_images'.
[6000/8091] Resized the images and saved into 'processed_images'.
[7000/8091] Resized the images and saved into 'processed_images'.
[8000/8091] Resized the images and saved into 'processed_images'.


### Caption Dataset

In [None]:
class CaptionDataset(data.Dataset):
    def __init__(self, image_dir, annotations, vocab, transform=None):
        self.image_dir = image_dir
        self.annotations = annotations
        self.vocab = vocab
        self.transform = transform
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, item):
        img_id = self.annotations[item]["image_id"]
        caption = self.annotations[item]["caption"]

        image = Image.open(os.path.join(self.image_dir, img_id+'.jpg')).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        
        caption = torch.Tensor(caption)
        return image, caption

In [None]:
# Image preprocessing, normalization for the pretrained resnet
crop_size = 224
transform = transforms.Compose([ 
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(), 
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), 
                            (0.229, 0.224, 0.225))])

In [None]:
train_captions, valid_captions = train_test_split(captions, test_size=0.1, random_state=42)

In [None]:
len(train_captions), len(valid_captions)

(36409, 4046)

In [None]:
image_dir = "processed_images"

train_caption_dataset = CaptionDataset(image_dir, train_captions, vocab, transform)
valid_caption_dataset = CaptionDataset(image_dir, valid_captions, vocab, transform)

### DataLoader

In [None]:
def collate_fn(data):
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]        
    return images, targets, lengths

In [None]:
train_data_loader = torch.utils.data.DataLoader(
    dataset=train_caption_dataset, 
    batch_size=32,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn)

valid_data_loader = torch.utils.data.DataLoader(
    dataset=valid_caption_dataset, 
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn)

In [None]:
sample = next(iter(train_data_loader))
sample[0].shape, sample[1].shape, len(sample[2])

(torch.Size([32, 3, 224, 224]), torch.Size([32, 24]), 32)

### Encoder

In [None]:
class Encoder(nn.Module):
    def __init__(self, encoded_image_size=14):
        super().__init__()

        resnet = models.resnet101(pretrained=True)

        # remove the linear and pooling layers
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

    
    def forward(self, images):
        # images => [batch_size, 3, 224, 224]
        with torch.no_grad():
            features = self.resnet(images)
            # features => [batch_size, 2048, 7, 7]
        
        pooled = self.adaptive_pool(features)
        # pooled => [batch_size, 2048, encoded_img_size, encoded_img_size]
        #        => [batch_size, 2048, 14, 14]

        out = pooled.permute(0, 2, 3, 1)
        # out => [batch_size, 14, 14, 2048]

        return out

### Attention


In [None]:
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()

        self.encoder_attn = nn.Linear(encoder_dim, attention_dim)
        self.decoder_attn = nn.Linear(decoder_dim, attention_dim)
        self.full_attn = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, encoder_out, decoder_hidden):
        enc_attn = self.encoder_attn(encoder_out)
        # enc_attn => [batch_size, num_pixels, attention_dim]

        dec_attn = self.decoder_attn(decoder_hidden)
        # dec_attn => [batch_size, attention_dim]

        attn = self.full_attn(self.relu(enc_attn + dec_attn.unsqueeze(1)))
        # attn => [batch_size, num_pixels, 1]

        attn = attn.squeeze(2)
        # attn => [batch_size, num_pixels]

        alpha = self.softmax(attn)
        # alpha => [batch_size, num_pixels]

        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        # attention_weighted_encoding => [batch_size, encoder_dim]

        return attention_weighted_encoding, alpha

### Decoder

In [None]:
class Decoder(nn.Module):
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        super().__init__()

        self.vocab_size = vocab_size
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()

        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(decoder_dim, vocab_size)
        self.init_weights()

    def init_weights(self):
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.out.bias.data.fill_(0)
        self.out.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        # encoder_out => [batch_size, num_pixels, encoder_dim]

        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, captions, lengths):
        # encoder_out => [batch_size, enc_image_size, enc_image_size, encoder_dim]
        # captions => [batch_size, max_seq_len]
        # lengths => [batch_size]

        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)
        num_pixels = encoder_out.size(1)

        embeddings = self.embedding(captions)
        # embeddings => [batch_size, max_seq_len, emb_dim]

        # initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        # So, decoding lengths are actual lengths - 1
        decode_lengths = [length - 1 for length in lengths]

        # Create tensors to hold word predicion scores and alphas
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)

        # At each time-step, decode by
        # attention-weighing the encoder's output based on the decoder's previous hidden state output
        # then generate a new word in the decoder with the previous word and the attention weighted encoding
        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.out(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
        
        return predictions, alphas


### Model

In [None]:
embed_dim = 256
decoder_dim = 512
attention_dim = 512

In [None]:
encoder = Encoder().to(device)
decoder = Decoder(attention_dim, embed_dim, decoder_dim, len(vocab)).to(device)

### Loss & Optimizer

In [None]:
criterion = nn.CrossEntropyLoss()

In [None]:
params = list(decoder.parameters()) + list(encoder.adaptive_pool.parameters())
optimizer = torch.optim.Adam(params, lr=4e-4)

### Configurations

In [None]:
num_epochs = 2
log_step = 100
model_path = "models"

In [None]:
if not os.path.exists(model_path):
    os.makedirs(model_path)

### Train Method

In [None]:
def train(data_loader, device, alpha_c=1.0):

    epoch_loss = 0
    total_steps = len(data_loader)
    encoder.train()
    decoder.train()
    for i, (images, captions, lengths) in enumerate(data_loader):
        images = images.to(device)
        captions = captions.to(device)

        encoded_img = encoder(images)
        predictions, alphas = decoder(encoded_img, captions, lengths)

        decode_lengths = [length - 1 for length in lengths]

        outputs = pack_padded_sequence(predictions, decode_lengths, batch_first=True)[0]
        targets = pack_padded_sequence(captions, decode_lengths, batch_first=True)[0]

        loss = criterion(outputs, targets)
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
        epoch_loss += loss.item()

        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(params, 5.)
        optimizer.step()

        # Print log info
        if i % log_step == 0:
            print(f'Step [{i}/{total_steps}], Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}') 
            
    return epoch_loss / total_steps

### Validation Method

In [None]:
def evaluate(data_loader, device, alpha_c=1.0):
    epoch_loss = 0
    total_steps = len(data_loader)
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        for i, (images, captions, lengths) in enumerate(data_loader):
            images = images.to(device)
            captions = captions.to(device)

            encoded_img = encoder(images)
            predictions, alphas = decoder(encoded_img, captions, lengths)

            decode_lengths = [length - 1 for length in lengths]

            outputs = pack_padded_sequence(predictions, decode_lengths, batch_first=True)[0]
            targets = pack_padded_sequence(captions, decode_lengths, batch_first=True)[0]

            loss = criterion(outputs, targets)
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            epoch_loss += loss.item()

            # Print log info
            if i % log_step == 0:
                print(f'Step [{i}/{total_steps}], Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}') 
                
    return epoch_loss / total_steps

In [None]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

### Training

In [None]:
best_valid_loss = float('inf')
for epoch in range(num_epochs):
    
    start_time = time.time()
    train_loss = train(train_data_loader, device)
    val_loss = evaluate(valid_data_loader, device)
    end_time = time.time()

    elapsed_mins, elapsed_secs = epoch_time(start_time, end_time)

    if val_loss < best_valid_loss:
        best_valid_loss = val_loss
        torch.save(decoder.state_dict(), os.path.join(
            model_path, 'decoder-{}.ckpt'.format(epoch+1)))
        torch.save(encoder.state_dict(), os.path.join(
            model_path, 'encoder-{}.ckpt'.format(epoch+1)))
    
    print(f"Epoch: {epoch+1:02} | Time: {elapsed_mins}m {elapsed_secs}s")
    print(f"\t Train Loss: {train_loss:.3f} | Train PPL: {np.exp(train_loss):5.4f} | Valid Loss: {val_loss:.3f} | Valid PPL: {np.exp(val_loss):5.4f}")

Step [0/1138], Loss: 8.9380, Perplexity: 7616.2601
Step [100/1138], Loss: 4.6478, Perplexity: 104.3503
Step [200/1138], Loss: 3.5743, Perplexity: 35.6696
Step [300/1138], Loss: 2.9959, Perplexity: 20.0039
Step [400/1138], Loss: 2.5779, Perplexity: 13.1695
Step [500/1138], Loss: 2.2090, Perplexity: 9.1069
Step [600/1138], Loss: 1.8936, Perplexity: 6.6430
Step [700/1138], Loss: 1.8594, Perplexity: 6.4202
Step [800/1138], Loss: 1.7878, Perplexity: 5.9762
Step [900/1138], Loss: 1.7243, Perplexity: 5.6086
Step [1000/1138], Loss: 1.5251, Perplexity: 4.5958
Step [1100/1138], Loss: 1.4220, Perplexity: 4.1452
Step [0/253], Loss: 1.5548, Perplexity: 4.7340
Step [100/253], Loss: 1.3285, Perplexity: 3.7753
Step [200/253], Loss: 1.3164, Perplexity: 3.7301
Epoch: 01 | Time: 16m 10s
	 Train Loss: 2.497 | Train PPL: 12.1511 | Valid Loss: 1.359 | Valid PPL: 3.8906
Step [0/1138], Loss: 1.3846, Perplexity: 3.9934
Step [100/1138], Loss: 1.3800, Perplexity: 3.9751
Step [200/1138], Loss: 1.3966, Perplexity:

In [None]:
!ls -lah models

total 421M
drwxr-xr-x 2 root root 4.0K Jul 13 17:46 .
drwxr-xr-x 1 root root 4.0K Jul 13 17:13 ..
-rw-r--r-- 1 root root  48M Jul 13 17:30 decoder-1.ckpt
-rw-r--r-- 1 root root  48M Jul 13 17:46 decoder-2.ckpt
-rw-r--r-- 1 root root 163M Jul 13 17:30 encoder-1.ckpt
-rw-r--r-- 1 root root 163M Jul 13 17:46 encoder-2.ckpt


### Inference

In [None]:
device

device(type='cuda')

In [None]:
vocab_path = "vocab.pkl"
# Load vocabulary wrapper
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

In [None]:
# Build models
encoder = Encoder()
decoder = Decoder(attention_dim, embed_dim, decoder_dim, len(vocab))

In [None]:
encoder.to(device).eval()
decoder.to(device).eval()

Decoder(
  (attention): Attention(
    (encoder_attn): Linear(in_features=2048, out_features=512, bias=True)
    (decoder_attn): Linear(in_features=512, out_features=512, bias=True)
    (full_attn): Linear(in_features=512, out_features=1, bias=True)
    (relu): ReLU()
    (softmax): Softmax(dim=1)
  )
  (embedding): Embedding(3006, 256)
  (decode_step): LSTMCell(2304, 512)
  (init_h): Linear(in_features=2048, out_features=512, bias=True)
  (init_c): Linear(in_features=2048, out_features=512, bias=True)
  (f_beta): Linear(in_features=512, out_features=2048, bias=True)
  (sigmoid): Sigmoid()
  (dropout): Dropout(p=0.5, inplace=False)
  (out): Linear(in_features=512, out_features=3006, bias=True)
)

In [None]:
encoder_path = "models/encoder-2.ckpt"
decoder_path = "models/decoder-2.ckpt"

# Load the trained model parameters
encoder.load_state_dict(torch.load(encoder_path))
decoder.load_state_dict(torch.load(decoder_path))

<All keys matched successfully>

In [None]:
transform = transforms.Compose([
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])

In [None]:
def load_image(image_path, transform=None):
    image = Image.open(image_path).convert('RGB')
    image = image.resize([224, 224], Image.LANCZOS)
    
    if transform is not None:
        image = transform(image).unsqueeze(0)
    
    return image

In [None]:
def inference(image_path, beam_size=3):
    k = beam_size
    vocab_size = len(vocab)

    # Prepare an image
    image = load_image(image_path, transform)

    image_tensor = image.to(device)
    with torch.no_grad():
        encoded_img = encoder(image_tensor)
        # encoded_img => [1, enc_img_size, enc_img_size, encoder_dim]

        enc_img_size = encoded_img.size(1)
        encoded_dim = encoded_img.size(3)

        # Flatten encoding
        encoder_out = encoded_img.view(1, -1, encoded_dim)
        num_pixels = encoder_out.size(1)

        encoder_out = encoder_out.expand(k, num_pixels, encoded_dim)

        k_prev_words = torch.LongTensor([[vocab.word2idx['<start>']]] * k).to(device)
        # k_prev_words => [k, 1]

        # tensor to store top k sequences
        seqs = k_prev_words
        # seqs => [k, 1]

        # tensor to store top k sequences scores
        top_k_scores = torch.zeros(k, 1).to(device)
        # top_k_scores => [k, 1]

        # tensor to store top k sequences attentions
        seqs_alpha = torch.ones(k, 1, enc_img_size, enc_img_size).to(device)
        # seqs_alpha => [k, 1, enc_img_size, enc_img_size]

        # to store completed seqs, scores, attentions
        complete_seqs = list()
        complete_seqs_alpha = list()
        complete_seqs_scores = list()

        # decoding
        step = 1
        h, c = decoder.init_hidden_state(encoder_out)

        # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
        while True:
            embeddings = decoder.embedding(k_prev_words).squeeze(1)
            # embeddings => [s, emb_dim]

            awe, alpha = decoder.attention(encoder_out, h)
            # awe => [s, encoder_dim]
            # alpha => [s, num_pixels]

            alpha = alpha.view(-1, enc_img_size, enc_img_size)
            # alpha => [s, enc_img_size, enc_img_size]

            gate = decoder.sigmoid(decoder.f_beta(h))
            # gate => [s, encoder_dim]

            awe = gate * awe
            # awe => [s, encoder_dim]

            h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))
            # h, c => [s, decoder_dim]

            scores = decoder.out(h)
            # scores => [s, vocab_dim]

            scores = F.log_softmax(scores, dim=1)

            # add the current word scores to the previous one
            scores = top_k_scores.expand_as(scores) + scores
            # scores => [s, vocab_size]

            # for the first step all the k points will have same scores
            if step == 1:
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)
            
            prev_word_inds = top_k_words / vocab_size
            next_word_inds = top_k_words % vocab_size

            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # s+1
            seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], dim=1) # s+1

            # incomplete sequences
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if next_word != vocab.word2idx['<end>']]
            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds])
            k -= len(complete_inds)

            if k == 0:
                break
            seqs = seqs[incomplete_inds]
            seqs_alpha = seqs_alpha[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]

            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            if step > 20:
                break
            step += 1
        
        if len(complete_seqs) == 0:
            complete_seqs = seqs.tolist()
            complete_seqs_scores = top_k_scores.tolist()
            complete_seqs_alpha = seqs_alpha.tolist()

        i = complete_seqs_scores.index(max(complete_seqs_scores))
        seq = complete_seqs[i]
        alphas = complete_seqs_alpha[i]

    return seq, alphas

In [None]:
image_path = "flickr8k/Images/1009434119_febe49276a.jpg"
seq, _ = inference(image_path)
print(seq)

[1, 1, 572, 572, 572, 572, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]


In [None]:
image_path = "flickr8k/Images/1016887272_03199f49c4.jpg"
seq, _ = inference(image_path)
print(seq)

[1, 1, 572, 572, 572, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]


In [None]:
image_path = "flickr8k/Images/1022975728_75515238d8.jpg"
inference(image_path)

([1,
  1,
  572,
  572,
  572,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9,
  9],
 [[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,