# Image Captioning Using Deep Learning With Attention Mechanism

In [None]:
import os
import json
import random
import time
import numpy as np
from scipy.misc import imread, imresize
from collections import Counter
from tqdm import tqdm_notebook

import nltk
from nltk.translate.bleu_score import corpus_bleu

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

import torchvision
import torchvision.transforms as transforms

### Set Configs

In [None]:
# ---- model hyper-parameters ----
EMBEDDING_SIZE = 512  # passed to the decoder as embedding_size
ATTENTION_SIZE = 512  # passed to the decoder as attention_size
DECODER_SIZE = 512    # passed to the decoder as decoder_size
DROPOUT = 0.5         # dropout probability handed to the decoder

# ---- training hyper-parameters ----
N_EPOCHS = 129
BATCH_SIZE = 32
ENCODER_LR = 1e-4          # learning rate for encoder if fine-tuning
DECODER_LR = 4e-4          # learning rate for decoder
GRAD_CLIP = 5.0            # gradients are clamped to [-GRAD_CLIP, GRAD_CLIP]
ALPHA_C = 1.0              # weight of the attention regularisation term in the loss
FINE_TUNE_ENCODER = False  # whether the encoder is trained alongside the decoder

In [None]:
# ---- data locations ----
# base name shared by every file that create_input_files writes
DATA_NAME = 'coco_5_cap_per_img_5_min_word_freq'
# folder with data files saved by create_input_files
DATA_FOLDER = './datasets/'

In [None]:
# use the GPU when PyTorch can see one, otherwise fall back to the CPU
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

## Set Utils

In [None]:
def clip_gradient(optimizer, grad_clip):
    """Clamp, in place, every gradient managed by ``optimizer``.

    Each gradient element is limited to the range ``[-grad_clip, grad_clip]``.
    Parameters that currently have no gradient are left untouched.
    """
    params_with_grad = (
        p
        for group in optimizer.param_groups
        for p in group['params']
        if p.grad is not None
    )
    for p in params_with_grad:
        p.grad.data.clamp_(-grad_clip, grad_clip)

In [None]:
def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, 
                    encoder_optimizer, decoder_optimizer, bleu4, is_best):
    """Serialise the current training state to ``./weights/``.

    The state is always written to ``./weights/checkpoint_<data_name>.pth.tar``.
    When ``is_best`` is true, a second copy is written next to it as
    ``./weights/BEST_checkpoint_<data_name>.pth.tar`` so that a later, worse
    checkpoint cannot overwrite it.

    Args:
        data_name: base name used to build the checkpoint file name.
        epoch: epoch number stored in the checkpoint.
        epochs_since_improvement: counter stored in the checkpoint.
        encoder, decoder: model objects, stored as-is.
        encoder_optimizer, decoder_optimizer: optimizer objects, stored as-is.
        bleu4: score stored under the ``'bleu-4'`` key.
        is_best: whether to also write the ``BEST_`` copy.
    """
    state = {
        'epoch': epoch,
        'epochs_since_improvement': epochs_since_improvement,
        'bleu-4': bleu4,
        'encoder': encoder,
        'decoder': decoder,
        'encoder_optimizer': encoder_optimizer,
        'decoder_optimizer': decoder_optimizer,
    }

    weights_dir = './weights/'
    os.makedirs(weights_dir, exist_ok=True)

    basename = 'checkpoint_' + data_name + '.pth.tar'
    torch.save(state, os.path.join(weights_dir, basename))

    # The prefix must go on the file name, not on the whole path: prepending it
    # to './weights/...' would point into a non-existent 'BEST_.' directory.
    if is_best:
        torch.save(state, os.path.join(weights_dir, 'BEST_' + basename))

In [None]:
class AverageMeter(object):
    """Track the latest value, running sum, count and mean of a metric.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted sum of all values seen since the last reset.
        count: total weight seen since the last reset.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
def adjust_learning_rate(optimizer, shrink_factor):
    """Scale the learning rate of every parameter group by ``shrink_factor``.

    Prints a notice before the change and the first group's new learning rate
    afterwards.
    """
    print('\nDecaying Learning rate...')
    for group in optimizer.param_groups:
        group.update(lr=group['lr'] * shrink_factor)
    new_lr = optimizer.param_groups[0]['lr']
    print(f'The new learning rate is {new_lr}\n')

In [None]:
def calculate_accuracy(scores, targets, k):
    """Return the top-``k`` accuracy as a percentage.

    A row counts as correct when its target index is among the ``k`` highest
    scoring columns of ``scores``. The number of rows is taken from
    ``targets.size(0)``.
    """
    _, top_k = scores.topk(k, dim=1, largest=True, sorted=True)
    hits = top_k.eq(targets.view(-1, 1).expand_as(top_k))
    n_correct = hits.view(-1).float().sum().item()
    return n_correct * (100.0 / targets.size(0))

In [None]:
def create_input_files(datasets, karpathy_json_path, image_dir, output_dir, captions_per_image, min_word_freq, max_length=100):
    """Turn a Karpathy-split caption JSON plus raw images into training files.

    Everything is written under ``<output_dir>/data/`` (created if missing):

    * ``WORD_VOCAB_<base>.json``        - the word -> index vocabulary,
    * ``<SPLIT>_IMAGES_<base>.hdf5``    - a uint8 dataset called ``images`` of
      shape (n_images, 3, 256, 256) with a ``captions_per_image`` attribute,
    * ``<SPLIT>_CAPTIONS_<base>.json``  - encoded captions padded with <pad>,
    * ``<SPLIT>_CAPLENS_<base>.json``   - caption lengths incl. <start>/<end>,

    for SPLIT in TRAIN, VALID, TEST, where ``<base>`` is
    ``<datasets>_<captions_per_image>_cap_per_img_<min_word_freq>_min_word_freq``.

    Args:
        datasets: one of 'coco', 'flickr8k', 'flickr30k'.
        karpathy_json_path: path of the JSON file with splits and captions.
        image_dir: folder containing the images referenced by the JSON.
        output_dir: folder under which the ``data/`` directory is written.
        captions_per_image: number of captions sampled for every image.
        min_word_freq: words must occur more often than this to enter the
            vocabulary; all other words are encoded as <unk>.
        max_length: captions with more tokens than this are dropped.

    Returns:
        The vocabulary dict mapping each word to its integer index.
    """
    assert datasets in {'coco', 'flickr8k', 'flickr30k'}

    # imported locally, the same way CaptionDataset does it
    import h5py

    # read Karpathy's json
    with open(karpathy_json_path, 'r') as json_file:
        data = json.load(json_file)

    # image paths and captions for each split
    train_image_paths, train_image_captions = [], []
    valid_image_paths, valid_image_captions = [], []
    test_image_paths, test_image_captions = [], []

    word_freq = Counter()

    for image in tqdm_notebook(data['images']):
        captions = []
        for sentence in image['sentences']:
            # every sentence contributes to the word counts, even if it is too long to keep
            word_freq.update(sentence['tokens'])
            if len(sentence['tokens']) <= max_length:
                captions.append(sentence['tokens'])

        if not captions:
            continue

        if datasets == 'coco':
            path = os.path.join(image_dir, image['filepath'], image['filename'])
        else:
            path = os.path.join(image_dir, image['filename'])

        if image['split'] in {'train', 'restval'}:
            train_image_paths.append(path)
            train_image_captions.append(captions)
        elif image['split'] in {'val'}:
            valid_image_paths.append(path)
            valid_image_captions.append(captions)
        elif image['split'] in {'test'}:
            test_image_paths.append(path)
            test_image_captions.append(captions)

    # sanity check
    assert len(train_image_paths) == len(train_image_captions)
    assert len(valid_image_paths) == len(valid_image_captions)
    assert len(test_image_paths) == len(test_image_captions)

    # create vocabulary: index 0 is reserved for <pad>, real words start at 1
    words = [word for word in word_freq.keys() if word_freq[word] > min_word_freq]
    word_vocab = {key: value + 1 for value, key in enumerate(words)}
    word_vocab['<unk>'] = len(word_vocab) + 1
    word_vocab['<start>'] = len(word_vocab) + 1
    word_vocab['<end>'] = len(word_vocab) + 1
    word_vocab['<pad>'] = 0

    # create a base/ root name for all output files
    base_filename = datasets + '_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq'

    # make sure the output folder exists before anything is written to it
    data_dir = os.path.join(output_dir, 'data')
    os.makedirs(data_dir, exist_ok=True)

    # save word vocabulary to a JSON
    with open(os.path.join(data_dir, 'WORD_VOCAB_' + base_filename + '.json'), 'w') as vocab_file:
        json.dump(word_vocab, vocab_file)

    # sample captions for each image, save images to HDF5 file and captions and their lengths to JSON files
    random.seed(9)
    for image_paths, image_captions, split in [(train_image_paths, train_image_captions, 'TRAIN'),
                                              (valid_image_paths, valid_image_captions, 'VALID'),
                                              (test_image_paths, test_image_captions, 'TEST')]:

        # NOTE(review): mode 'a' keeps an existing file, so create_dataset below is
        # expected to fail when 'images' already exists; remove stale files before re-running.
        with h5py.File(os.path.join(data_dir, split + '_IMAGES_' + base_filename + '.hdf5'), 'a') as hdf5_file:

            # make a note of the number of captions we are sampling per image
            hdf5_file.attrs['captions_per_image'] = captions_per_image

            # the dataset must be called 'images': that is the key CaptionDataset reads
            images = hdf5_file.create_dataset('images', (len(image_paths), 3, 256, 256), dtype='uint8')

            print(f'\nReading {split} images and captions, storing to file...\n')

            encoded_captions = []
            captions_length = []

            for i, path in enumerate(image_paths):

                # sample captions: pad with random repeats when there are too few,
                # otherwise draw captions_per_image of them without replacement
                if len(image_captions[i]) < captions_per_image:
                    captions = image_captions[i] + [random.choice(image_captions[i])
                                                    for _ in range(captions_per_image - len(image_captions[i]))]
                else:
                    captions = random.sample(image_captions[i], k=captions_per_image)

                # sanity check
                assert len(captions) == captions_per_image

                # read images
                # NOTE(review): scipy.misc.imread/imresize are absent from recent SciPy
                # releases; this code needs an older SciPy (with Pillow) to run.
                image = imread(path)
                if len(image.shape) == 2:
                    # grayscale image: replicate the single channel three times
                    image = image[:, :, np.newaxis]
                    image = np.concatenate([image, image, image], axis=2)
                image = imresize(image, (256, 256))
                image = image.transpose(2, 0, 1)

                # sanity check
                assert image.shape == (3, 256, 256)
                assert np.max(image) <= 255

                # save image to HDF5 file
                images[i] = image

                for caption in captions:
                    # encode captions: <start> w1 ... wn <end> <pad>..., every row max_length + 2 long
                    encoded_caption = [word_vocab['<start>']] + \
                                      [word_vocab.get(word, word_vocab['<unk>']) for word in caption] + \
                                      [word_vocab['<end>']] + [word_vocab['<pad>']] * (max_length - len(caption))

                    encoded_captions.append(encoded_caption)
                    # caption length counts the <start> and <end> tokens
                    captions_length.append(len(caption) + 2)

            # sanity check
            assert images.shape[0] * captions_per_image == len(encoded_captions) == len(captions_length)

            # save encoded captions and their lengths to JSON files
            with open(os.path.join(data_dir, split + '_CAPTIONS_' + base_filename + '.json'), 'w') as captions_file:
                json.dump(encoded_captions, captions_file)

            with open(os.path.join(data_dir, split + '_CAPLENS_' + base_filename + '.json'), 'w') as caplens_file:
                json.dump(captions_length, caplens_file)

    return word_vocab

In [None]:
# generate the vocabulary, HDF5 image stores and caption JSON files for COCO
word_vocab = create_input_files(
    datasets='coco',
    karpathy_json_path='./datasets/karpathy_captions/datasets_coco.json',
    image_dir='./datasets/',
    output_dir='./datasets/',
    captions_per_image=5,
    min_word_freq=5,
    max_length=50,
)

In [None]:
# reload the vocabulary from the JSON file stored under DATA_FOLDER
word_vocab_file = os.path.join(DATA_FOLDER, f'data/WORD_VOCAB_{DATA_NAME}.json')
with open(word_vocab_file, 'r') as file:
    word_vocab = json.load(file)

## Set Data Loader

In [None]:
class CaptionDataset(torch.utils.data.Dataset):
    """Dataset of (image, caption) pairs read from the HDF5/JSON input files.

    One item exists per caption. Captions are stored image by image, so item
    ``i`` belongs to image ``i // captions_per_image``.

    Args:
        data_folder: folder that contains the ``data/`` directory.
        data_name: base name shared by the data files.
        split: one of 'TRAIN', 'VALID', 'TEST'.
        transform: optional callable applied to the image tensor.
    """

    def __init__(self, data_folder, data_name, split, transform=None):

        super(CaptionDataset, self).__init__()

        self.split = split
        assert self.split in {'TRAIN', 'VALID', 'TEST'}

        # imported here, where it is first needed
        import h5py

        # common prefix of the three files that belong to this split
        prefix = os.path.join(data_folder, 'data/' + self.split)

        # open hdf5 file where images are stored (kept open for the dataset's lifetime)
        self.hdf5 = h5py.File(prefix + '_IMAGES_' + data_name + '.hdf5', 'r')
        self.images = self.hdf5['images']

        # captions per image
        self.cpi = self.hdf5.attrs['captions_per_image']

        # load encoded captions (completely into memory)
        with open(prefix + '_CAPTIONS_' + data_name + '.json', 'r') as file:
            self.captions = json.load(file)

        # load captions lengths (completely into memory)
        with open(prefix + '_CAPLENS_' + data_name + '.json', 'r') as file:
            self.caplens = json.load(file)

        # pytorch transformation pipeline for the image (normalizing, etc.)
        self.transform = transform

        # total number of data points
        self.dataset_size = len(self.captions)

    def __getitem__(self, i):
        """Return item ``i``.

        Returns ``(image, caption, caplen)`` for the TRAIN split and
        ``(image, caption, caplen, all_captions)`` otherwise, where
        ``all_captions`` holds every caption of the same image.
        """
        # remember, the Nth caption corresponds to the (N // captions_per_image)th image
        image_index = i // self.cpi

        # stored pixel values are scaled by 1/255 before the optional transform
        image = torch.FloatTensor(self.images[image_index] / 255.)
        if self.transform is not None:
            image = self.transform(image)

        caption = torch.LongTensor(self.captions[i])
        caplen = torch.LongTensor([self.caplens[i]])

        # value comparison: identity ('is') on strings is not reliable
        if self.split == 'TRAIN':
            return image, caption, caplen

        # for validation or testing, also return all 'captions_per_image' captions to find BLEU-4 score
        first = image_index * self.cpi
        all_captions = torch.LongTensor(self.captions[first:first + self.cpi])
        return image, caption, caplen, all_captions

    def __len__(self):
        return self.dataset_size

In [None]:
# CaptionDataset hands the transform a float tensor already scaled by 1/255, so only
# per-channel normalisation is applied here. transforms.ToTensor() is left out because
# it only accepts PIL images / ndarrays and raises on a tensor input.
transform = transforms.Compose([transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                     std=(0.229, 0.224, 0.225))])

In [None]:
# batched loaders over the TRAIN and VALID splits; both are shuffled
train_loader = torch.utils.data.DataLoader(
    CaptionDataset(DATA_FOLDER, DATA_NAME, split='TRAIN', transform=transform),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

# the keyword must be spelled 'transform' to match CaptionDataset.__init__
valid_loader = torch.utils.data.DataLoader(
    CaptionDataset(DATA_FOLDER, DATA_NAME, split='VALID', transform=transform),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

## Build [Image Captioning](https://arxiv.org/pdf/1411.4555.pdf) Network with [Attention](https://arxiv.org/pdf/1502.03044.pdf)

In [None]:
class EncoderCNN(nn.Module):
    """Image encoder built from a pre-trained torchvision ResNet-101.

    The last two children of the ResNet are dropped and the remaining feature
    map is adaptively average-pooled to ``image_size x image_size``.
    """

    def __init__(self, image_size=14):
        super().__init__()
        self.image_size = image_size

        # pre-trained ImageNet ResNet-101 without its final two children
        # (the pooling and linear layers)
        backbone = torchvision.models.resnet101(pretrained=True)
        self.resnet_layer = nn.Sequential(*list(backbone.children())[:-2])

        # pool to a fixed spatial size so input images may vary in size
        self.adaptive_pool_layer = nn.AdaptiveAvgPool2d((image_size, image_size))

        # default policy: see fine_tune()
        self.fine_tune()

    def fine_tune(self, is_fine_tune=True):
        """Freeze the whole backbone, then set ``requires_grad`` to
        ``is_fine_tune`` for every child from index 5 onwards."""
        for weight in self.resnet_layer.parameters():
            weight.requires_grad = False

        trainable_tail = list(self.resnet_layer.children())[5:]
        for block in trainable_tail:
            for weight in block.parameters():
                weight.requires_grad = is_fine_tune

    def forward(self, images):
        """Encode ``images`` into a channels-last feature map.

        The backbone output is pooled to (batch, channels, image_size, image_size)
        and then permuted to (batch, image_size, image_size, channels).
        """
        pooled = self.adaptive_pool_layer(self.resnet_layer(images))
        return pooled.permute(0, 2, 3, 1)

In [None]:
class Attention(nn.Module):
    """Soft attention over the encoder's pixel features.

    Args:
        encoder_size: feature size of each encoded pixel.
        decoder_size: size of the decoder's hidden state.
        attention_size: size of the shared space both inputs are projected to.
    """

    def __init__(self, encoder_size, decoder_size, attention_size):

        super(Attention, self).__init__()

        self.encoder_attention_layer = nn.Linear(encoder_size, attention_size) # linear layer to transform encoded image
        self.decoder_attention_layer = nn.Linear(decoder_size, attention_size) # linear layer to transform decoder's output
        self.total_attention_layer = nn.Linear(attention_size, 1) # linear layer to calculate values to be softmax-ed
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights

    def forward(self, encoder_output, decoder_hidden):
        """Compute attention weights and the attention-weighted encoding.

        Args:
            encoder_output: tensor of shape (batch_size, num_pixels, encoder_size).
            decoder_hidden: tensor of shape (batch_size, decoder_size).

        Returns:
            ``(attention_weighted_encoding, alpha)`` with shapes
            (batch_size, encoder_size) and (batch_size, num_pixels).
        """
        encoder_attention = self.encoder_attention_layer(encoder_output) # (batch_size, num_pixels, attention_size)
        decoder_attention = self.decoder_attention_layer(decoder_hidden) # (batch_size, attention_size)

        # one scalar score per pixel, then normalised over the pixel axis
        combined = self.relu(encoder_attention + decoder_attention.unsqueeze(1))
        total_attention = self.total_attention_layer(combined).squeeze(2) # (batch_size, num_pixels)
        alpha = self.softmax(total_attention) # (batch_size, num_pixels)

        # weight the original encoder features (not their attention projection) so the
        # result keeps the encoder feature size
        attention_weighted_encoding = (encoder_output * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_size)

        return attention_weighted_encoding, alpha

In [None]:
class AttentionDecoderRNN(nn.Module):
    """LSTM caption decoder that attends over encoder features at every step.

    Args:
        attention_size: size of the attention network's shared space.
        embedding_size: size of the word embeddings.
        decoder_size: size of the LSTM hidden/cell state.
        vocab_size: number of entries in the vocabulary.
        encoder_size: feature size of each encoded pixel.
        dropout: dropout probability applied before the output layer.
    """

    def __init__(self, attention_size, embedding_size, decoder_size, vocab_size, encoder_size=2048, dropout=0.5):

        super(AttentionDecoderRNN, self).__init__()

        self.attention_size = attention_size
        self.embedding_size = embedding_size
        self.decoder_size = decoder_size
        self.vocab_size = vocab_size
        self.encoder_size = encoder_size

        # init attention network
        self.attention_layer = Attention(encoder_size, decoder_size, attention_size)

        self.embedding_layer = nn.Embedding(vocab_size, embedding_size)
        self.dropout = nn.Dropout(p=dropout)
        self.decode_step = nn.LSTMCell(embedding_size + encoder_size, decoder_size, bias=True)
        self.init_hidden = nn.Linear(encoder_size, decoder_size) # linear layer to find initial hidden state of LSTM
        self.init_cell = nn.Linear(encoder_size, decoder_size) # linear layer to find initial cell state of LSTM
        self.f_beta = nn.Linear(decoder_size, encoder_size) # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc_layer = nn.Linear(decoder_size, vocab_size) # linear layer to find scores over vocabulary

        self.init_weights() # initialize some layers with the uniform distribution

    def init_weights(self):
        """Initialise embedding and output weights uniformly in [-0.1, 0.1); zero the output bias."""
        self.embedding_layer.weight.data.uniform_(-0.1, 0.1)
        self.fc_layer.weight.data.uniform_(-0.1, 0.1)
        self.fc_layer.bias.data.fill_(0)

    def init_state(self, encoder_outputs):
        """Derive the initial LSTM hidden and cell state from the mean over dim 1 of ``encoder_outputs``."""
        mean_encoder_outputs = encoder_outputs.mean(dim=1)
        hidden = self.init_hidden(mean_encoder_outputs)
        cell = self.init_cell(mean_encoder_outputs)

        return hidden, cell

    def load_pretrained_embeddings(self, embeddings):
        """Replace the embedding weights with the given tensor."""
        self.embedding_layer.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, is_fine_tune=True):
        """Enable or disable gradient computation for the embedding layer."""
        for param in self.embedding_layer.parameters():
            # set the flag itself; assigning to `requires_grad_` would only shadow the method
            param.requires_grad = is_fine_tune

    def forward(self, encoder_outputs, encoded_captions, caption_lengths):
        """Decode a batch of captions with teacher forcing.

        Args:
            encoder_outputs: encoder features; the first dim is the batch and
                the last dim the encoder feature size.
            encoded_captions: tensor of shape (batch_size, max_caption_length).
            caption_lengths: tensor of shape (batch_size, 1).

        Returns:
            ``(predictions, encoded_captions, decode_lengths, alphas, sort_id)``.
            The batch is re-ordered by decreasing caption length; ``sort_id``
            holds the permutation and ``encoded_captions`` is returned in the
            sorted order. ``decode_lengths`` is a list of ``caption length - 1``.
        """
        batch_size = encoder_outputs.size(0)
        encoder_size = encoder_outputs.size(-1)
        vocab_size = self.vocab_size
        out_device = encoder_outputs.device

        # flatten image
        encoder_outputs = encoder_outputs.view(batch_size, -1, encoder_size) # (batch_size, num_pixels, encoder_size)
        num_pixels = encoder_outputs.size(1)

        # sort input data by decreasing lengths, so that at step t the still-active
        # captions are exactly the first `batch_size_time` rows
        caption_lengths, sort_id = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_outputs = encoder_outputs[sort_id]
        encoded_captions = encoded_captions[sort_id]

        # embedding
        embeddings = self.embedding_layer(encoded_captions) # (batch_size, max_caption_length, embedding_size)

        # init LSTM state
        decoder_hidden, decoder_cell = self.init_state(encoder_outputs) # (batch_size, decoder_size)

        # generation finishes as soon as <end> is produced, so decoding lengths are actual lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        # tensors to hold word prediction scores and alphas, created on the inputs' device
        predictions = torch.zeros(batch_size, max_decode_length, vocab_size, device=out_device)
        alphas = torch.zeros(batch_size, max_decode_length, num_pixels, device=out_device)

        # at each time-step, attend over the encoder's output based on the decoder's previous hidden state,
        # then generate a new word in the decoder with the previous word and the attention-weighted encoding
        for d_time in range(max_decode_length):
            batch_size_time = sum(length > d_time for length in decode_lengths)
            attention_weighted_encoding, alpha = self.attention_layer(encoder_outputs[:batch_size_time],
                                                                      decoder_hidden[:batch_size_time])

            gate = self.sigmoid(self.f_beta(decoder_hidden[:batch_size_time])) # (batch_size_time, encoder_size)
            attention_weighted_encoding = gate * attention_weighted_encoding

            decoder_hidden, decoder_cell = self.decode_step(
                torch.cat([embeddings[:batch_size_time, d_time, :], attention_weighted_encoding], dim=1),
                (decoder_hidden[:batch_size_time], decoder_cell[:batch_size_time])) # (batch_size_time, decoder_size)

            prediction = self.fc_layer(self.dropout(decoder_hidden)) # (batch_size_time, vocab_size)
            predictions[:batch_size_time, d_time, :] = prediction
            alphas[:batch_size_time, d_time, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_id

#### Initialize Image Captioning Network

In [None]:
# build the encoder, apply the configured fine-tuning policy, then move it to the device
encoder = EncoderCNN()
encoder.fine_tune(FINE_TUNE_ENCODER)
encoder.to(device)

In [None]:
# build the decoder from the configured sizes; the vocabulary size comes from word_vocab
decoder = AttentionDecoderRNN(
    attention_size=ATTENTION_SIZE,
    embedding_size=EMBEDDING_SIZE,
    decoder_size=DECODER_SIZE,
    vocab_size=len(word_vocab),
    dropout=DROPOUT,
)

decoder.to(device)

## Set Loss Function

In [None]:
# cross-entropy criterion used by train_network / validate_network on the packed
# word scores and targets
ce_loss = nn.CrossEntropyLoss()
# move the criterion to the selected device (the cell also displays the module)
ce_loss.to(device)

## Set Optimizer

In [None]:
# the encoder only gets an optimizer when it is being fine-tuned; otherwise None
if FINE_TUNE_ENCODER:
    encoder_optimizer = torch.optim.Adam(
        params=[p for p in encoder.parameters() if p.requires_grad],
        lr=ENCODER_LR)
else:
    encoder_optimizer = None

In [None]:
# Adam over the decoder parameters that require gradients
decoder_optimizer = torch.optim.Adam(
    params=[p for p in decoder.parameters() if p.requires_grad],
    lr=DECODER_LR)

## Train The Network

In [None]:
# bookkeeping shared by the training cells
print_every = 100              # print training/validation status every this many batches
best_bleu4 = 0.0               # best validation BLEU-4 seen so far
epochs_since_improvement = 0   # epochs since best_bleu4 last improved
start_epoch = 0

In [None]:
def train_network(train_loader, encoder, decoder, ce_loss, encoder_optimizer, decoder_optimizer, epoch):
    """Run one training epoch.

    Args:
        train_loader: iterable yielding ``(images, captions, caplens)`` batches.
        encoder: image encoder module.
        decoder: caption decoder module.
        ce_loss: criterion applied to the packed scores and targets.
        encoder_optimizer: optimizer for the encoder, or None when the encoder
            is not being fine-tuned.
        decoder_optimizer: optimizer for the decoder.
        epoch: epoch number, used only in the status output.
    """
    # set train mode for the networks
    encoder.train()
    decoder.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top5_accuracy = AverageMeter()

    start = time.time()

    for i, (images, captions, caplens) in enumerate(train_loader):

        # time spent waiting for this batch
        data_time.update(time.time() - start)

        # set mini-batch datasets
        images = images.to(device)
        captions = captions.to(device)
        caplens = caplens.to(device)

        encoder.zero_grad()
        decoder.zero_grad()

        # forward propagation: the decoder consumes the encoder's features, not the raw images
        features = encoder(images)
        scores, encoded_captions, decode_lengths, alphas, sort_id = decoder(features, captions, caplens)

        # since the network decodes starting with <start>, the targets are all words after <start>, up to <end>
        targets = encoded_captions[:, 1:]

        # remove timesteps that the network doesn't decode at, or are pads;
        # `.data` is used because PackedSequence has more than two fields
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

        # cross-entropy plus doubly stochastic attention regularization; the regulariser
        # has to be part of the loss before backward() to influence the gradients
        loss = ce_loss(scores, targets)
        loss = loss + ALPHA_C * ((1. - alphas.sum(dim=1)) ** 2).mean()
        loss.backward()

        # clip gradients and update weights; the encoder optimizer is optional
        if encoder_optimizer is not None:
            clip_gradient(encoder_optimizer, GRAD_CLIP)
        clip_gradient(decoder_optimizer, GRAD_CLIP)

        if encoder_optimizer is not None:
            encoder_optimizer.step()
        decoder_optimizer.step()

        # keep track of metrics
        top5 = calculate_accuracy(scores, targets, k=5)
        top5_accuracy.update(top5, sum(decode_lengths))
        losses.update(loss.item(), sum(decode_lengths))
        batch_time.update(time.time() - start)

        start = time.time()

        # print status
        if i % print_every == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time: {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss: {loss.val: .4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy: {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                           batch_time=batch_time,
                                                                           data_time=data_time, loss=losses,
                                                                           top5=top5_accuracy))

In [None]:
def validate_network(valid_loader, encoder, decoder, ce_loss):
    """Run one pass over the validation data and return the corpus BLEU-4 score.

    Args:
        valid_loader: iterable yielding ``(images, captions, caplens, allcaps)``.
        encoder: image encoder module, or None if the loader already yields
            encoded images.
        decoder: caption decoder module.
        ce_loss: criterion applied to the packed scores and targets.

    Returns:
        The value of ``corpus_bleu(references, hypotheses)``.
    """
    # set eval mode for the networks
    decoder.eval()
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5_accuracy = AverageMeter()

    references = list() # references as true captions for calculating BLEU-4 score
    hypotheses = list() # hypotheses as predictions

    # tokens stripped from the reference captions
    skip_tokens = {word_vocab['<start>'], word_vocab['<pad>']}

    start = time.time()

    with torch.no_grad():

        for i, (images, captions, caplens, allcaps) in enumerate(valid_loader):

            images = images.to(device)
            captions = captions.to(device)
            caplens = caplens.to(device)

            # forward propagation
            if encoder is not None:
                images = encoder(images)
            scores, encoded_captions, decode_lengths, alphas, sort_id = decoder(images, captions, caplens)

            # since the network decodes starting with <start>, the targets are all words after <start>, up to <end>
            targets = encoded_captions[:, 1:]

            # remove timesteps that the network doesn't decode at, or are pads;
            # `.data` is used because PackedSequence has more than two fields
            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            # cross-entropy plus doubly stochastic attention regularization
            # (no backward pass: gradients are disabled during validation)
            loss = ce_loss(scores, targets)
            loss = loss + ALPHA_C * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # keep track of metrics
            top5 = calculate_accuracy(scores, targets, k=5)
            top5_accuracy.update(top5, sum(decode_lengths))
            losses.update(loss.item(), sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_every == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy: {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(valid_loader),
                                                                                 batch_time=batch_time,
                                                                                 loss=losses,
                                                                                 top5=top5_accuracy))

            # store references (true captions) and the hypothesis (prediction) for each image:
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b, ref2c], ...]
            # hypotheses = [hyp1, hyp2, ...]

            # references: re-order like the decoder did; allcaps was never moved to the
            # device, so the index is brought to the CPU first
            allcaps = allcaps[sort_id.cpu()]
            for j in range(allcaps.shape[0]):
                image_caps = allcaps[j].tolist()
                # remove <start> and <pad> tokens
                image_captions = [[word for word in caption if word not in skip_tokens]
                                  for caption in image_caps]
                references.append(image_captions)

            # hypotheses: highest-scoring word per step, cut at each caption's decode length
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            hypotheses.extend(pred[:decode_lengths[j]] for j, pred in enumerate(preds))

            # sanity check
            assert len(references) == len(hypotheses)

        # calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)

        print('\n * LOSS: {loss.avg:.3f}, TOP-5 ACCURACY: {top5.avg:.3f}, BLEU-4: {bleu}\n'.format(
               loss=losses,
               top5=top5_accuracy,
               bleu=bleu4))

        return bleu4

In [None]:
print('Training the network...')
for epoch in range(1, N_EPOCHS + 1):

    # after every 8 consecutive epochs without improvement, shrink the learning rate(s) by 0.8
    if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)
        if FINE_TUNE_ENCODER:
            adjust_learning_rate(encoder_optimizer, 0.8)

    # run one epoch's training
    train_network(train_loader=train_loader, encoder=encoder, decoder=decoder, ce_loss=ce_loss,
                  encoder_optimizer=encoder_optimizer, decoder_optimizer=decoder_optimizer, epoch=epoch)

    # run one epoch's validation
    recent_bleu4 = validate_network(valid_loader=valid_loader, encoder=encoder, decoder=decoder, ce_loss=ce_loss)

    # check if there was an improvement
    is_best = recent_bleu4 > best_bleu4
    best_bleu4 = max(recent_bleu4, best_bleu4)

    if is_best:
        epochs_since_improvement = 0
    else:
        epochs_since_improvement += 1
        print(f'\nEpochs since last improvement: {epochs_since_improvement}\n')

    # save checkpoints under the configured data name
    save_checkpoint(DATA_NAME, epoch, epochs_since_improvement,
                    encoder, decoder, encoder_optimizer, decoder_optimizer, recent_bleu4, is_best)

## Evaluate The Network

In [None]:
# number of candidate sequences kept at each decoding step by evaluate_network
BEAM_SIZE = 3

In [None]:
def evaluate_network(beam_size):
    """Caption the TEST split with beam search and return the corpus BLEU-4 score.

    Uses the module-level ``encoder``, ``decoder``, ``word_vocab`` and
    ``device``. One item is decoded at a time; for each item the beam keeps at
    most ``beam_size`` partial sequences and stops after about 50 steps.

    Args:
        beam_size: number of candidate sequences kept at each decoding step.

    Returns:
        The value of ``corpus_bleu(references, hypotheses)``.
    """
    # the dataset already yields scaled float tensors, so only normalisation is
    # applied (transforms.ToTensor() does not accept tensor inputs)
    transform = transforms.Compose([transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                         std=(0.229, 0.224, 0.225))])

    # batch_size & co. are DataLoader arguments, not CaptionDataset arguments
    test_loader = torch.utils.data.DataLoader(
        CaptionDataset(DATA_FOLDER, DATA_NAME, 'TEST', transform=transform),
        batch_size=1, shuffle=True, num_workers=1, pin_memory=True)

    encoder.eval()
    decoder.eval()

    vocab_size = len(word_vocab)
    start_token = word_vocab['<start>']
    end_token = word_vocab['<end>']
    # references and hypotheses are stripped of the same special tokens so they are comparable
    special_tokens = {start_token, end_token, word_vocab['<pad>']}

    references = list()
    hypotheses = list()

    with torch.no_grad():

        for images, captions, caplens, allcaps in tqdm_notebook(
                test_loader, desc=f'EVALUATING AT BEAM SIZE {str(beam_size)}'):

            k = beam_size

            images = images.to(device)

            # encoding process: forward propagation
            features = encoder(images) # (1, encoder_image_size, encoder_image_size, encoder_size)
            encoder_size = features.size(3)

            # flatten features/ encoding outputs
            encoder_outputs = features.view(1, -1, encoder_size) # (1, num_pixels, encoder_size)
            num_pixels = encoder_outputs.size(1)

            # treat the problem as having a batch size of k
            encoder_outputs = encoder_outputs.expand(k, num_pixels, encoder_size) # (k, num_pixels, encoder_size)

            # tensor to store top k previous words at each step (starting from just <start> token)
            k_prev_words = torch.LongTensor([[start_token]] * k).to(device) # (k, 1)

            # tensor to store top k sequences (starting from just <start> token)
            sequences = k_prev_words # (k, 1)

            # tensor to store top k sequences' scores (starting from 0 values)
            top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)

            # lists to store completed sequences and scores
            complete_sequences = list()
            complete_sequences_scores = list()

            # decoding process
            step = 1
            hidden, cell = decoder.init_state(encoder_outputs)

            # s is a number less than or equal to k, since sequences are removed from this process once they hit <end> token
            while True:

                embeddings = decoder.embedding_layer(k_prev_words).squeeze(1) # (s, embedding_size)
                attention_weighted_encoding, _ = decoder.attention_layer(encoder_outputs, hidden) # (s, encoder_size), (s, num_pixels)

                gate = decoder.sigmoid(decoder.f_beta(hidden)) # gate, (s, encoder_size)
                attention_weighted_encoding = gate * attention_weighted_encoding

                hidden, cell = decoder.decode_step(
                    torch.cat([embeddings, attention_weighted_encoding], dim=1), (hidden, cell)) # (s, decoder_size)

                scores = decoder.fc_layer(hidden) # (s, vocab_size)
                scores = F.log_softmax(scores, dim=1)

                # add the accumulated sequence scores
                scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)

                # for the first step, all k points will have the same scores (since same k previous words, hidden, cell)
                if step == 1:
                    top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
                else:
                    # unroll and find top scores, and their unrolled indices
                    top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)

                # convert unrolled indices to actual indices of scores; integer (floor)
                # division is required because the result is used as a tensor index
                prev_word_indices = top_k_words // vocab_size # (s)
                next_word_indices = top_k_words % vocab_size # (s)

                # add new words to sequences
                sequences = torch.cat([sequences[prev_word_indices], next_word_indices.unsqueeze(1)], dim=1) # (s, step+1)

                incomplete_indices = [index for index, next_word in enumerate(next_word_indices)
                                      if next_word != end_token]
                complete_indices = list(set(range(len(next_word_indices))) - set(incomplete_indices))

                # set aside complete sequences
                if len(complete_indices) > 0:
                    complete_sequences.extend(sequences[complete_indices].tolist())
                    complete_sequences_scores.extend(top_k_scores[complete_indices].tolist())
                k -= len(complete_indices) # reduce beam length accordingly

                # proceed with incomplete sequences
                if k == 0:
                    break

                sequences = sequences[incomplete_indices]
                hidden = hidden[prev_word_indices[incomplete_indices]]
                cell = cell[prev_word_indices[incomplete_indices]]
                encoder_outputs = encoder_outputs[prev_word_indices[incomplete_indices]]
                top_k_scores = top_k_scores[incomplete_indices].unsqueeze(1)
                k_prev_words = next_word_indices[incomplete_indices].unsqueeze(1)

                # break if things have been going on too long
                if step > 50:
                    break
                step += 1

            # if no beam produced <end> within the step budget, fall back to the
            # unfinished beams instead of failing on an empty candidate list
            if not complete_sequences:
                complete_sequences = sequences.tolist()
                complete_sequences_scores = top_k_scores.view(-1).tolist()

            best_index = complete_sequences_scores.index(max(complete_sequences_scores))
            sequence = complete_sequences[best_index]

            # references: all captions of this image without special tokens
            image_caps = allcaps[0].tolist()
            image_captions = [[word for word in caption if word not in special_tokens]
                              for caption in image_caps]
            references.append(image_captions)

            # hypotheses
            hypotheses.append([word for word in sequence if word not in special_tokens])

            assert len(references) == len(hypotheses)

    # calculate BLEU-4 scores
    bleu4 = corpus_bleu(references, hypotheses)

    return bleu4

In [None]:
# run beam-search evaluation on the TEST split; the cell displays the returned BLEU-4 score
evaluate_network(BEAM_SIZE)

---