## Image Captioning with Pytorch

The following contents are modified from MDS DSCI 575 lecture 8 demo

In [1]:
import os, sys, json
from collections import defaultdict
from tqdm import tqdm
import pickle
from time import time
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from itertools import chain
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models, transforms, datasets
from torchsummary import summary
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

from nltk.translate import bleu_score
from sklearn.model_selection import KFold

sys.path.append('../../scr/evaluation')
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.usc_sim.usc_sim import usc_sim
import subprocess


START = "startseq"
STOP = "endseq"
EPOCHS = 10
AWS = True


In [2]:
torch.manual_seed(123)
np.random.seed(123)

In [3]:
# torch.cuda.empty_cache()
# import gc 
# gc.collect()

In [4]:
# Nicely formatted time string
def hms_string(sec_elapsed):
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60
    return f"{h}:{m:>02}:{s:>05.2f}"
        
if AWS:
    root_captioning = "../../s3"
else:
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
        root_captioning = "/content/drive/My Drive/data"
        COLAB = True
        print("Note: using Google CoLab")
    except:
        print("Note: not using Google CoLab")
        COLAB = False

### Clean/Build Dataset

- Read captions
- Preprocess captions


In [5]:
def get_img_info(name, num=np.inf):
    """
    Returns img paths and captions

    Parameters:
    -----------
    name: str
        the json file name
    num: int (default: np.inf)
        the number of observations to get

    Return:
    --------
    list, dict, int
        img paths, corresponding captions, max length of captions
    """
    img_path = []
    caption = [] 
    max_length = 0
    if AWS:
        with open(f'{root_captioning}/json/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for filename in data.keys():
                if num is not None and len(caption) == num:
                    break
                img_path.append(
                    f'{root_captioning}/{name}/{filename}'
                )
                sen_list = []
                for sentence in data[filename]['sentences']:
                    max_length = max(max_length, len(sentence['tokens']))
                    sen_list.append(sentence['raw'])

                caption.append(sen_list)    
    else:            
        with open(f'{root_captioning}/interim/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for set_name in ['rsicd', 'ucm']:
                for filename in data[set_name].keys():
                    if num is not None and len(caption) == num:
                        break

                    img_path.append(
                        f'{root_captioning}/raw/imgs/{set_name}/{filename}'
                    )
                    sen_list = []
                    for sentence in data[set_name][filename]['sentences']:
                        max_length = max(max_length, len(sentence['tokens']))
                        sen_list.append(sentence['raw'])

                    caption.append(sen_list)
    
    return img_path, caption, max_length            


In [6]:
# get img path and caption list
# # only test 800 train samples and 200 valid samples
# train_paths, train_descriptions, max_length_train = get_img_info('train', 800)
# test_paths, test_descriptions, max_length_test = get_img_info('valid', 200)

train_paths, train_descriptions, max_length_train = get_img_info('train')
test_paths, test_descriptions, max_length_test = get_img_info('valid')
max_length = max(max_length_train, max_length_test)



In [7]:
all_paths = train_paths.copy()
all_paths.extend(test_paths.copy())
all_paths = np.array(all_paths)

all_descriptions = train_descriptions.copy()
all_descriptions.extend(test_descriptions.copy())
all_descriptions = np.array(all_descriptions)

captions = all_descriptions.copy()
max_length_all = max(max_length_train, max_length_test)
max_length = max_length_all + 2
      
lex = set()
for sen in all_descriptions:
    [lex.update(d.split()) for d in sen]
    
# add a start and stop token at the beginning/end
for v in all_descriptions:
    for d in range(len(v)):
        v[d] = f'{START} {v[d]} {STOP}'
        
print(f'There are {len(all_paths)} images') 
print(f'There are {len(lex)} unique words (vocab)')
print(f'The maximum length of captions with start and stop token is {max_length}.')


There are 10416 images
There are 2912 unique words (vocab)
The maximum length of captions with start and stop token is 36.


In [8]:
all_paths[-1]

'../../s3/valid/rsicd_park_33.jpg'

In [9]:
all_descriptions[-1]

array(['startseq a vast artificial lake was built in the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq'],
      dtype='<U184')

### Loading Wikipedia2vec Embeddings

In [10]:
# read the embedding matrix 
with open(f'{root_captioning}/enwiki_20180420_2338_words_500d.json', 'r', encoding='utf-8') as file:
    embeddings_index = json.load(file)

In [11]:
def get_vocab(descriptions, word_count_threshold=10):

    captions = []
    for val in descriptions:
        for cap in val:
            captions.append(cap)
    print(f'There are {len(captions)} captions')
    
    word_counts = {}
    nsents = 0
    for sent in captions:
        nsents += 1
        for w in sent.split(' '):
            word_counts[w] = word_counts.get(w, 0) + 1

    vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
    print('preprocessed words %d ==> %d' % (len(word_counts), len(vocab)))
    return vocab

def get_word_dict(vocab):
    
    idxtoword = {}
    wordtoidx = {}

    ix = 1
    for w in vocab:
        wordtoidx[w] = ix
        idxtoword[ix] = w
        ix += 1

    return idxtoword, wordtoidx

def get_vocab_size(idxtoword):
    
    print(f'The vocabulary size is {len(idxtoword) + 1}.')
    return len(idxtoword) + 1


def get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx):

    embedding_matrix = np.zeros((vocab_size, embedding_dim))
    count = 0

    for word, i in wordtoidx.items():

        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            count += 1
            # Words not found in the embedding index will be all zeros
            embedding_matrix[i] = embedding_vector
            
    print(f'{count} out of {vocab_size} words are found in the pre-trained matrix.')            
    print(f'The size of embedding_matrix is {embedding_matrix.shape}')
    return embedding_matrix

### Building the Neural Network

An embedding matrix is built from Glove.  This will be directly copied to the weight matrix of the neural network.

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [13]:
# resnet 101 expects (224, 224) sized images
resnet_model =\
models.resnet101(pretrained=True).to(device)

In [14]:
class CNNModel(nn.Module):

    def __init__(self, pretrained=True):
        """
        Initializes a CNNModel

        Parameters:
        -----------
        pretrained: bool (default: True)
            use pretrained model if True

        """

        super(CNNModel, self).__init__()
        
        # remove the classification layer
        self.model =\
        nn.Sequential(
            *list(resnet_model.children())[:-3]
        )

        self.input_size = 224

    def forward(self, img_input, train=False):
        """
        forward of the CNNModel

        Parameters:
        -----------
        img_input: torch.Tensor
            the image matrix
        train: bool (default: False)
            use the model only for feature extraction if False

        Return:
        --------
        torch.Tensor
            image feature matrix
        """
        if not train:
          # set the model to evaluation model
          self.model.eval()

        # N x 3 x 224 x 224
        features = self.model(img_input)
        # N x 1024 x 14 x 14

        return features

In [15]:
class AttentionModel(nn.Module):

    def __init__(self, hidden_size=256):
        """
        Initializes a AttentionModel

        Parameters:
        -----------
        cnn_type: str
            the CNN type, either 'vgg16' or 'inception_v3'
        pretrained: bool (default: True)
            use pretrained model if True

        """

        super(AttentionModel, self).__init__()

        self.W_qh = nn.Linear(hidden_size, 256)
        self.W_qI = nn.Linear(hidden_size, 256)
        self.W_q = nn.Linear(256, 1, bias=False)

    def forward(self, I, h):
        """
        forward of the AttentionModel

        Parameters:
        -----------
        img_input: torch.Tensor
            the image matrix
        train: bool (default: False)
            use the model only for feature extraction if False

        Return:
        --------
        torch.Tensor
            image feature matrix
        """
        batch_size = I.size()[0]
        weight_dim = I.size()[1]
        
        # N x hidden_size
        h_a = h.unsqueeze(1).repeat(1, weight_dim, 1)
        # N x weight_dim x hidden_size

        # attention scoring function W_q*tanh(W_qI(I) + W_qh(h))
        attention =\
        self.W_q(
            torch.tanh(
                self.W_qI(I) + self.W_qh(h_a)
            )
        ).permute(0, 2, 1).squeeze(1)
        # N x weight_dim
        
        attention_weights = F.softmax(attention, dim=1)
        # N x weight_dim

        return attention_weights

In [16]:
class RNNModel(nn.Module):

    def __init__(
        self, 
        feature_size,
        vocab_size,
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):
      
        """
        Initializes a RNNModel

        Parameters:
        -----------
        feature_size: int
            the number of features in the image matrix
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """

        super(RNNModel, self).__init__()

        self.feature_size = feature_size
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()
        self.out_dense = nn.Linear(hidden_size, hidden_size)
        self.h_dense = nn.Linear(feature_size, hidden_size)
        self.c_dense = nn.Linear(feature_size, hidden_size)
        self.img_dense = nn.Linear(feature_size, hidden_size)
        self.lstm_dense = nn.Linear(hidden_size, hidden_size)
        
        self.embedding =\
        nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        if embedding_matrix is not None:

            self.embedding.load_state_dict({
                'weight': torch.FloatTensor(embedding_matrix)
            })
            self.embedding.weight.requires_grad = embedding_train

        self.attention1 = AttentionModel(hidden_size)
        self.attention2 = AttentionModel(hidden_size)
        self.attention3 = AttentionModel(hidden_size)

        self.lstm =\
        nn.LSTMCell(embedding_dim, hidden_size, bias=True)
      
    def forward(self, img_features, captions):
        """
        forward of the RNNModel

        Parameters:
        -----------
        img_features: torch.Tensor 
            the image feature matrix
            (N x feature_size(1024) x 14 x 14)
        captions: torch.Tensor 
            the padded caption matrix
            (N x seq_len)

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        # N = batch_size
        batch_size = captions.size(0)
        seq_len = captions.size(1)

        # N x feature_size(1024) x 14 x 14
        img_features = img_features.view(
            batch_size, self.feature_size, -1
        ).permute(0, 2, 1)
        # N x 196 x feature_size(1024)

        # N x 1000
        h = self.h_dense(img_features.mean(dim=1))
        c = self.c_dense(img_features.mean(dim=1))
        # N x hidden_size     

        # N x seq_len
        embed = self.embedding(captions)
        # N x seq_len x embedding_dim 

        # N x 196 x feature_size(1024)
        img_features = self.relu(self.img_dense(img_features))
        # N x 196 x hidden_size

        outputs =\
        torch.zeros(
            batch_size,
            seq_len, 
            self.hidden_size
        ).to(device)
        
        cs =\
        torch.zeros(
            batch_size,
            seq_len + 1, 
            self.hidden_size
        ).to(device)

        v_weights =\
        torch.zeros(
            batch_size,
            seq_len + 1, 
            196
        ).to(device)
            
        z_weights =\
        torch.zeros(
            batch_size,
            seq_len, 
            2
        ).to(device)
        
        for i in range(seq_len + 1):
            
            if i > 0:
                h, c =\
                self.lstm(
                    self.lstm_dense(embed[:, i - 1, :] + z), 
                    (h, c)
                )
                # h: N x hidden_size
                # c: N x hidden_size
            
            v_weight = self.attention1(img_features, h)
            # N x 196
            # weighted sum of image features
            v = (img_features * v_weight.unsqueeze(2)).sum(dim=1)
            # N x hidden_size
            v_weights[:, i, :] = v_weight

            cs[:, i, :] = c
            cs_t = cs.clone()
            s_weight = self.attention2(cs_t[:, :(i + 1), :], h)
            # N x (i + 1)
            # weighted sum of semantic features
            s = (cs_t[:, :(i + 1), :] * s_weight.unsqueeze(2)).sum(dim=1)
            # N x hidden_size
            
            z = torch.cat((s.unsqueeze(1), v.unsqueeze(1)), dim=1)
            # N x 2 x hidden_size    
            z_weight = self.attention3(z, h)
            # N x 2
            # weighted sum of z
            z = (z * z_weight.unsqueeze(2)).sum(dim=1)
            # N x hidden_size

            if i > 0:
                output =\
                self.out_dense(
                    z + h
                )
                outputs[:, i - 1, :] = output
                z_weights[:, i - 1, :] = z_weight
                
        return outputs, v_weights, z_weights



In [17]:
class CaptionModel(nn.Module):

    def __init__(
        self, 
        vocab_size, 
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):

        """
        Initializes a CaptionModel

        Parameters:
        -----------
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """    
        super(CaptionModel, self).__init__() 

        # set feature_size based on cnn_type
        self.feature_size = 1024

        self.decoder = RNNModel(
            self.feature_size,
            vocab_size, 
            embedding_dim,
            hidden_size,
            embedding_matrix,
            embedding_train
        )


        self.dropout = nn.Dropout(p=0.5)        
        self.relu = nn.ReLU()
        self.dense = nn.Linear(hidden_size, vocab_size) 

    # def forward(self, captions):
    def forward(self, img_features, captions):
        """
        forward of the CaptionModel

        Parameters:
        -----------
        img_features: torch.Tensor 
            the image feature matrix
            (N x feature_size(1024) x 14 x 14)
        captions: torch.Tensor 
            the padded caption matrix
            (N x seq_len)

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """
    
        decoder_out, v_weights, z_weights = self.decoder(img_features, captions)

        # add up decoder outputs and image features
        outputs =\
        self.dense(
            self.relu(
                decoder_out
            )
        )

        return outputs, v_weights, z_weights


### Train the Neural Network

In [18]:
def train(model, iterator, optimizer, criterion, clip, vocab_size):
    """
    train the CaptionModel

    Parameters:
    -----------
    model: CaptionModel
        a CaptionModel instance
    iterator: torch.utils.data.dataloader
        a PyTorch dataloader
    optimizer: torch.optim
        a PyTorch optimizer 
    criterion: nn.CrossEntropyLoss
        a PyTorch criterion 

    Return:
    --------
    float
        average loss
    """
    model.train()    
    epoch_loss = 0
    
    for img_features, captions in iterator:
        
        optimizer.zero_grad()

        # for each caption, the end word is not passed for training
        outputs, v_weights, z_weights = model(
            img_features.to(device),
            captions[:, :-1].to(device)
        )

        loss = criterion(
            outputs.view(-1, vocab_size), 
            captions[:, 1:].flatten().to(device)
        ) + 0.001 * ((1. - z_weights.sum(dim=1)) ** 2).mean()\
        + ((1. - v_weights.sum(dim=1)) ** 2).mean() 
        
        epoch_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        
    return epoch_loss / len(iterator)

In [19]:
class SampleDataset(Dataset):
    def __init__(
        self,
        descriptions,
        imgs,
        wordtoidx,
        max_length
    ):
        """
        Initializes a SampleDataset

        Parameters:
        -----------
        descriptions: list
            a list of captions
        imgs: numpy.ndarray
            the image features
        wordtoidx: dict
            the dict to get word index
        max_length: int
            all captions will be padded to this size
        """        
        self.imgs = imgs
        self.descriptions = descriptions
        self.wordtoidx = wordtoidx
        self.max_length = max_length

    def __len__(self):
        """
        Returns the batch size

        Return:
        --------
        int
            the batch size
        """
        # return len(self.descriptions)
        return len(self.imgs)

    def __getitem__(self, idx):
        """
        Prepare data for each image

        Parameters:
        -----------
        idx: int
          the index of the image to process

        Return:
        --------
        list, list, list
            [5 x image feature matrix],
            [five padded captions for this image]
            [the length of each caption]
        """

        img = self.imgs[idx // 5]
        # convert each word into a list of sequences.
        seq = [self.wordtoidx[word] for word 
               in self.descriptions[idx // 5][idx % 5].split(' ')
               if word in self.wordtoidx]
        # pad the sequence with 0 on the right side
        in_seq = np.pad(
            seq, 
            (0, max_length - len(seq)),
            mode='constant',
            constant_values=(0, 0)
            )

        return img, in_seq


In [20]:
def init_weights(model, embedding_pretrained=True):
    """
    Initialize weights and bias in the model

    Parameters:
    -----------
    model: CaptionModel
      a CaptionModel instance
    embedding_pretrained: bool (default: True)
        not initialize the embedding matrix if True
    """  
  
    for name, param in model.named_parameters():
        if embedding_pretrained and 'embedding' in name:
            continue
        elif 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            


In [21]:
def encode_image(model, img_path):
    """
    Process the images to extract features

    Parameters:
    -----------
    model: CNNModel
      a CNNModel instance
    img_path: str
        the path of the image
 
    Return:
    --------
    torch.Tensor
        the extracted feature matrix from CNNModel
    """  

    img = Image.open(img_path)

    # Perform preprocessing needed by pre-trained models
    preprocessor = transforms.Compose([
        transforms.Resize(model.input_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    img = preprocessor(img)
    # Expand to 2D array
    img = img.view(1, *img.shape)
    # Call model to extract the smaller feature set for the image.
    x = model(img.to(device), False) 
    # Shape to correct form to be accepted by LSTM captioning network.
    x = np.squeeze(x)
    return x

In [22]:
def extract_img_features(img_paths, model):
    """
    Extracts, stores and returns image features

    Parameters:
    -----------
    img_paths: list
        the paths of images
    model: CNNModel (default: None)
      a CNNModel instance

    Return:
    --------
    numpy.ndarray
        the extracted image feature matrix from CNNModel
    """ 

    start = time()
    img_features = []

    for image_path in img_paths:
        img_features.append(
            encode_image(model, image_path).cpu().data.numpy()
        )

    print(f"\nGenerating set took: {hms_string(time()-start)}")

    return img_features

In [23]:
def get_train_test(
    encoder,
    train_paths,
    test_paths
):

    train_img_features = extract_img_features(
        train_paths,
        encoder
    )

    test_img_features = extract_img_features(
        test_paths,
        encoder
    )
    return train_img_features, test_img_features

def get_train_dataloader(
    train_descriptions, 
    train_img_features,
    wordtoidx,
    max_length,
    batch_size=200
):
    train_dataset = SampleDataset(
        train_descriptions,
        train_img_features,
        wordtoidx,
        max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size
    )
    
    return train_loader

def train_model(
    train_loader,
    vocab_size,
    embedding_dim, 
    embedding_matrix,
    hidden_size=256,
):

    caption_model = CaptionModel(
        vocab_size, 
        embedding_dim, 
        hidden_size=500,
        embedding_matrix=embedding_matrix, 
        embedding_train=True
    )

    init_weights(
        caption_model,
        embedding_pretrained=True
    )

    caption_model.to(device)

    # we will ignore the pad token in true target set
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    optimizer = torch.optim.Adam(
        caption_model.parameters(), 
        lr=0.001
    )

    clip = 1
    start = time()
    loss = []
    for i in tqdm(range(EPOCHS * 10)):
        
        loss.append(train(caption_model, train_loader, optimizer, criterion, clip, vocab_size))
        print(loss[-1])
        
        if i >=2 and loss[-3] <= loss[-1]:
            # reduce the learning rate
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.8
        if i >= 6 and loss[-6] <= loss[-1]:
            break
    print(f"\Training took: {hms_string(time()-start)}")
        
    return caption_model

In [24]:
def generateCaption(
    model, 
    img_features,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    in_text = START

    for i in range(max_length):

        sequence = [wordtoidx[w] for w in in_text.split() if w in wordtoidx]
        sequence = np.pad(sequence, (0, max_length - len(sequence)),
                          mode='constant', constant_values=(0, 0))
        model.eval()
        yhat, _, _ = model(
            torch.FloatTensor(img_features)\
            .view(-1, model.feature_size).to(device),
            torch.LongTensor(sequence).view(-1, max_length).to(device)
        )

        yhat = yhat.view(-1, vocab_size).argmax(1)
        word = idxtoword[yhat.cpu().data.numpy()[i]]
        in_text += ' ' + word
        if word == STOP:
            break
            
    final = in_text.split()
    final = final[1 : -1]
    final = ' '.join(final)
    return final

### Evaluation

In [25]:
def eval_model(ref_data, results):
    """
    Computes evaluation metrics of the model results against the human annotated captions
    
    Parameters:
    ------------
    ref_data: dict
        a dictionary containing human annotated captions, with image name as key and a list of human annotated captions as values
    
    results: dict
        a dictionary containing model generated caption, with image name as key and a generated caption as value
        
    Returns:
    ------------
    score_dict: a dictionary containing the overall average score for the model
    """
    # download stanford nlp library
    subprocess.call(['../../scr/evaluation/get_stanford_models.sh'])
    
    # format the inputs
    gts = {}
    res = {}

    for imgId in range(len(ref_data)):
        caption_list_sel = []
        for i in range(5):
            lst = {}
            lst['caption'] = ref_data[imgId][i]
            lst['image_id'] = imgId
            lst['id'] = i
            caption_list_sel.append(lst)
        gts[imgId] = caption_list_sel

        res[imgId] = [{'caption': results[imgId]}]
        
    # tokenize
    print('tokenization...')
    tokenizer = PTBTokenizer()
    gts  = tokenizer.tokenize(gts)
    res = tokenizer.tokenize(res)
    
    # compute scores
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(),"METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
        (Spice(), "SPICE"),
        (usc_sim(), "USC_similarity"),  
        ]
    score_dict = {}
    for scorer, method in scorers:
        print('computing %s score...'%(scorer.method()))
        score, scores = scorer.compute_score(gts, res)
        if type(method) == list:
            for sc, scs, m in zip(score, scores, method):
                score_dict[m] = sc
        else:
            score_dict[method] = score
            
    return score_dict


In [26]:
def evaluate_results(
    test_img_features, 
    model,
    ref,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    # generate results
    print('Generating captions...')
    results = {}
    for n in range(len(test_img_features)):
        img_features = test_img_features[n]
        generated = generateCaption(
            model, 
            img_features,
            max_length,
            vocab_size,
            wordtoidx,
            idxtoword
        )
        results[n] = generated
        
    model_score = eval_model(ref, results)

    return model_score

### Cross validation

In [27]:
encoder = CNNModel(pretrained=True)
encoder.to(device)

CNNModel(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 2

In [28]:
def cross_validation(train_index, test_index, count):
    print('=' * 60)
    print(f'Split {count}:')
    print(f'Splitting data...')
    
    train_paths, test_paths = all_paths[train_index], all_paths[test_index]
    train_descriptions, test_descriptions = all_descriptions[train_index], all_descriptions[test_index]
    print(f'{len(train_paths)} images for training and {len(test_paths)} images for testing.')
    
    vocab = get_vocab(train_descriptions, word_count_threshold=10)
    idxtoword, wordtoidx = get_word_dict(vocab)
    vocab_size = get_vocab_size(idxtoword)
    embedding_dim = 500
    embedding_matrix = get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx) 

    print(f'Preparing dataloader...')
    train_img_features, test_img_features = get_train_test(encoder, train_paths, test_paths)

    train_loader = get_train_dataloader(
        train_descriptions, 
        train_img_features,
        wordtoidx,
        max_length,
        batch_size=200
    )

    print(f'Training...')
    caption_model = train_model(
        train_loader,
        vocab_size,
        embedding_dim, 
        embedding_matrix,
        hidden_size=500
    )

    
    ref = captions[test_index]
    model_score = evaluate_results(
        test_img_features, 
        caption_model,
        ref,
        max_length,
        vocab_size,
        wordtoidx,
        idxtoword
    )
    
    return caption_model, model_score

In [29]:
cv = KFold(n_splits=5, random_state=123, shuffle=True)
cv = [(train_index, test_index) for train_index, test_index in cv.split(all_paths)]  

In [30]:
caption_model1, model_score1 = cross_validation(cv[0][0], cv[0][1], 1)    

Split 1:
Splitting data...
8332 images for training and 2084 images for testing.
There are 41660 captions
preprocessed words 2659 ==> 884
The vocabulary size is 885.
793 out of 885 words are found in the pre-trained matrix.
The size of embedding_matrix is (885, 500)
Preparing dataloader...

Generating set took: 0:03:17.45


  0%|          | 0/100 [00:00<?, ?it/s]


Generating set took: 0:00:49.30
Training...


  1%|          | 1/100 [00:18<31:20, 18.99s/it]

5.734052726200649


  2%|▏         | 2/100 [00:37<30:59, 18.98s/it]

4.9464304220108755


  3%|▎         | 3/100 [00:56<30:40, 18.98s/it]

4.303166701680138


  4%|▍         | 4/100 [01:15<30:20, 18.97s/it]

3.805367969331287


  5%|▌         | 5/100 [01:34<30:00, 18.96s/it]

3.52865807783036


  6%|▌         | 6/100 [01:53<29:41, 18.95s/it]

3.33164001646496


  7%|▋         | 7/100 [02:12<29:23, 18.96s/it]

3.178011320886158


  8%|▊         | 8/100 [02:31<29:05, 18.98s/it]

3.0413631257556735


  9%|▉         | 9/100 [02:50<28:50, 19.01s/it]

2.901861520040603


 10%|█         | 10/100 [03:09<28:32, 19.03s/it]

2.7773916323979697


 11%|█         | 11/100 [03:28<28:15, 19.05s/it]

2.7018758966809227


 12%|█▏        | 12/100 [03:47<27:54, 19.03s/it]

2.6119979109082903


 13%|█▎        | 13/100 [04:06<27:34, 19.01s/it]

2.515392820040385


 14%|█▍        | 14/100 [04:25<27:14, 19.00s/it]

2.4327358972458613


 15%|█▌        | 15/100 [04:44<26:53, 18.99s/it]

2.3661125784828547


 16%|█▌        | 16/100 [05:03<26:33, 18.98s/it]

2.305848717689514


 17%|█▋        | 17/100 [05:22<26:15, 18.98s/it]

2.2616894926343645


 18%|█▊        | 18/100 [05:41<25:55, 18.97s/it]

2.2161957877022878


 19%|█▉        | 19/100 [06:00<25:36, 18.97s/it]

2.1632749920799617


 20%|██        | 20/100 [06:19<25:18, 18.98s/it]

2.1297799462363836


 21%|██        | 21/100 [06:38<24:58, 18.97s/it]

2.086217079843794


 22%|██▏       | 22/100 [06:57<24:39, 18.97s/it]

2.0454494953155518


 23%|██▎       | 23/100 [07:16<24:22, 18.99s/it]

2.014845376922971


 24%|██▍       | 24/100 [07:35<24:03, 18.99s/it]

1.9848426495279585


 25%|██▌       | 25/100 [07:54<23:44, 18.99s/it]

1.9358857160522824


 26%|██▌       | 26/100 [08:13<23:24, 18.99s/it]

1.9146472215652466


 27%|██▋       | 27/100 [08:32<23:05, 18.97s/it]

1.87541279338655


 28%|██▊       | 28/100 [08:51<22:45, 18.96s/it]

1.8365068833033245


 29%|██▉       | 29/100 [09:10<22:25, 18.96s/it]

1.8227997989881606


 30%|███       | 30/100 [09:29<22:09, 19.00s/it]

1.792984162058149


 31%|███       | 31/100 [09:48<21:52, 19.01s/it]

1.7541596123150416


 32%|███▏      | 32/100 [10:07<21:34, 19.04s/it]

1.7172393997510274


 33%|███▎      | 33/100 [10:26<21:16, 19.05s/it]

1.69328259570258


 34%|███▍      | 34/100 [10:45<20:57, 19.05s/it]

1.6643013925779433


 35%|███▌      | 35/100 [11:04<20:38, 19.06s/it]

1.6461073756217957


 36%|███▌      | 36/100 [11:23<20:18, 19.05s/it]

1.6129586214110965


 37%|███▋      | 37/100 [11:42<19:59, 19.04s/it]

1.5654275644393194


 38%|███▊      | 38/100 [12:02<19:40, 19.05s/it]

1.5454173542204357


 39%|███▉      | 39/100 [12:21<19:21, 19.03s/it]

1.5284761786460876


 40%|████      | 40/100 [12:40<19:01, 19.03s/it]

1.5074337607338315


 41%|████      | 41/100 [12:59<18:42, 19.03s/it]

1.4726681993121193


 42%|████▏     | 42/100 [13:18<18:22, 19.02s/it]

1.450360425880977


 43%|████▎     | 43/100 [13:37<18:03, 19.01s/it]

1.434130799202692


 44%|████▍     | 44/100 [13:55<17:42, 18.97s/it]

1.4326801640646798


 45%|████▌     | 45/100 [14:14<17:21, 18.94s/it]

1.4129574611073448


 46%|████▌     | 46/100 [14:33<17:02, 18.94s/it]

1.4029367963473003


 47%|████▋     | 47/100 [14:52<16:43, 18.93s/it]

1.4155860486484708


 48%|████▊     | 48/100 [15:11<16:24, 18.93s/it]

1.3907369914509


 49%|████▉     | 49/100 [15:30<16:05, 18.93s/it]

1.352152793180375


 50%|█████     | 50/100 [15:49<15:46, 18.93s/it]

1.3185158570607503


 51%|█████     | 51/100 [16:08<15:27, 18.92s/it]

1.2960589840298606


 52%|█████▏    | 52/100 [16:27<15:07, 18.92s/it]

1.2883822321891785


 53%|█████▎    | 53/100 [16:46<14:49, 18.92s/it]

1.2912849159467787


 54%|█████▍    | 54/100 [17:05<14:30, 18.92s/it]

1.2744559106372653


 55%|█████▌    | 55/100 [17:23<14:10, 18.90s/it]

1.2529438961119879


 56%|█████▌    | 56/100 [17:42<13:52, 18.92s/it]

1.247402432419005


 57%|█████▋    | 57/100 [18:01<13:34, 18.94s/it]

1.230694049880618


 58%|█████▊    | 58/100 [18:20<13:15, 18.94s/it]

1.2183452276956468


 59%|█████▉    | 59/100 [18:39<12:57, 18.95s/it]

1.2165659609295072


 60%|██████    | 60/100 [18:58<12:38, 18.96s/it]

1.2074683336984544


 61%|██████    | 61/100 [19:17<12:19, 18.96s/it]

1.18937733627501


 62%|██████▏   | 62/100 [19:36<12:00, 18.96s/it]

1.1663404561224437


 63%|██████▎   | 63/100 [19:55<11:41, 18.96s/it]

1.15452664239066


 64%|██████▍   | 64/100 [20:14<11:22, 18.97s/it]

1.1386584015119643


 65%|██████▌   | 65/100 [20:33<11:03, 18.97s/it]

1.1253880290758043


 66%|██████▌   | 66/100 [20:52<10:45, 18.97s/it]

1.1194702500388736


 67%|██████▋   | 67/100 [21:11<10:25, 18.97s/it]

1.1152784654072352


 68%|██████▊   | 68/100 [21:30<10:06, 18.96s/it]

1.1168011881056286


 69%|██████▉   | 69/100 [21:49<09:48, 18.97s/it]

1.1122061808904011


 70%|███████   | 70/100 [22:08<09:28, 18.95s/it]

1.1086098012470065


 71%|███████   | 71/100 [22:27<09:09, 18.93s/it]

1.1058670793260847


 72%|███████▏  | 72/100 [22:46<08:50, 18.95s/it]

1.1122196969531832


 73%|███████▎  | 73/100 [23:05<08:31, 18.96s/it]

1.1119944197790963


 74%|███████▍  | 74/100 [23:24<08:13, 18.97s/it]

1.0974581553822471


 75%|███████▌  | 75/100 [23:43<07:54, 18.98s/it]

1.0816652008465357


 76%|███████▌  | 76/100 [24:02<07:35, 18.98s/it]

1.0691209208397638


 77%|███████▋  | 77/100 [24:21<07:16, 18.99s/it]

1.0593171261605763


 78%|███████▊  | 78/100 [24:40<06:58, 19.00s/it]

1.052308752423241


 79%|███████▉  | 79/100 [24:59<06:38, 19.00s/it]

1.0477292594455538


 80%|████████  | 80/100 [25:18<06:19, 19.00s/it]

1.043889005978902


 81%|████████  | 81/100 [25:37<06:00, 19.00s/it]

1.0409677312487648


 82%|████████▏ | 82/100 [25:56<05:41, 18.99s/it]

1.0398858530180795


 83%|████████▎ | 83/100 [26:15<05:22, 19.00s/it]

1.0391599592708407


 84%|████████▍ | 84/100 [26:34<05:04, 19.00s/it]

1.037731525443849


 85%|████████▌ | 85/100 [26:53<04:45, 19.01s/it]

1.0378920492671786


 86%|████████▌ | 86/100 [27:12<04:26, 19.01s/it]

1.035772638661521


 87%|████████▋ | 87/100 [27:31<04:07, 19.04s/it]

1.0346294868560064


 88%|████████▊ | 88/100 [27:50<03:48, 19.07s/it]

1.0332109048253013


 89%|████████▉ | 89/100 [28:09<03:29, 19.07s/it]

1.033754831268674


 90%|█████████ | 90/100 [28:28<03:10, 19.07s/it]

1.0327600183941068


 91%|█████████ | 91/100 [28:47<02:51, 19.05s/it]

1.0328102395648049


 92%|█████████▏| 92/100 [29:06<02:32, 19.05s/it]

1.0328841635159083
1.0338986487615676
\Training took: 0:29:25.76
Generating captions...
tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [31]:
model_score1

{'Bleu_1': 0.6057266732905516,
 'Bleu_2': 0.47195977965507746,
 'Bleu_3': 0.3864751508901778,
 'Bleu_4': 0.3266986863206117,
 'METEOR': 0.27835407428470443,
 'ROUGE_L': 0.5136516817675681,
 'CIDEr': 1.7589434443397518,
 'SPICE': 0.36106246320174085,
 'USC_similarity': 0.5855836546655401}

In [32]:
caption_model2, model_score2 = cross_validation(cv[1][0], cv[1][1], 2)    

Split 2:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2688 ==> 916
The vocabulary size is 917.
819 out of 917 words are found in the pre-trained matrix.
The size of embedding_matrix is (917, 500)
Preparing dataloader...

Generating set took: 0:03:18.18



  0%|          | 0/100 [00:00<?, ?it/s][A


Generating set took: 0:00:49.56
Training...



  1%|          | 1/100 [00:18<31:19, 18.99s/it][A

5.739231961114066



  2%|▏         | 2/100 [00:37<30:47, 18.85s/it][A

5.023782355444772



  3%|▎         | 3/100 [00:56<30:22, 18.78s/it][A

4.4604212783631825



  4%|▍         | 4/100 [01:14<29:57, 18.73s/it][A

3.912609883717128



  5%|▌         | 5/100 [01:33<29:35, 18.69s/it][A

3.6368123292922974



  6%|▌         | 6/100 [01:51<29:14, 18.66s/it][A

3.4545162178221203



  7%|▋         | 7/100 [02:10<28:53, 18.64s/it][A

3.2947608686628795



  8%|▊         | 8/100 [02:29<28:34, 18.63s/it][A

3.161063234011332



  9%|▉         | 9/100 [02:47<28:14, 18.62s/it][A

3.03883311294374



 10%|█         | 10/100 [03:06<27:55, 18.61s/it][A

2.923004911059425



 11%|█         | 11/100 [03:24<27:36, 18.62s/it][A

2.8265279134114585



 12%|█▏        | 12/100 [03:43<27:17, 18.61s/it][A

2.74996010462443



 13%|█▎        | 13/100 [04:02<26:58, 18.61s/it][A

2.6799596434547786



 14%|█▍        | 14/100 [04:20<26:39, 18.60s/it][A

2.6080678076971147



 15%|█▌        | 15/100 [04:39<26:20, 18.60s/it][A

2.5435540278752646



 16%|█▌        | 16/100 [04:57<26:02, 18.60s/it][A

2.4940549646105086



 17%|█▋        | 17/100 [05:16<25:43, 18.60s/it][A

2.4464896690277826



 18%|█▊        | 18/100 [05:35<25:24, 18.60s/it][A

2.3787976673671176



 19%|█▉        | 19/100 [05:53<25:06, 18.60s/it][A

2.3126810051146007



 20%|██        | 20/100 [06:12<24:48, 18.60s/it][A

2.2591709182375954



 21%|██        | 21/100 [06:30<24:30, 18.61s/it][A

2.218374104726882



 22%|██▏       | 22/100 [06:49<24:11, 18.61s/it][A

2.1912786137490046



 23%|██▎       | 23/100 [07:08<23:50, 18.58s/it][A

2.146808385848999



 24%|██▍       | 24/100 [07:26<23:30, 18.56s/it][A

2.11502989984694



 25%|██▌       | 25/100 [07:45<23:09, 18.53s/it][A

2.08121296053841



 26%|██▌       | 26/100 [08:03<22:50, 18.53s/it][A

2.023201803366343



 27%|██▋       | 27/100 [08:22<22:31, 18.52s/it][A

1.985615889231364



 28%|██▊       | 28/100 [08:40<22:15, 18.55s/it][A

1.9480876213028318



 29%|██▉       | 29/100 [08:59<21:58, 18.56s/it][A

1.917640507221222



 30%|███       | 30/100 [09:17<21:40, 18.57s/it][A

1.8803967663219996



 31%|███       | 31/100 [09:36<21:22, 18.58s/it][A

1.8429827349526542



 32%|███▏      | 32/100 [09:55<21:04, 18.59s/it][A

1.8007568007423764



 33%|███▎      | 33/100 [10:13<20:45, 18.59s/it][A

1.7812088302203588



 34%|███▍      | 34/100 [10:32<20:26, 18.59s/it][A

1.7612473113196236



 35%|███▌      | 35/100 [10:50<20:08, 18.59s/it][A

1.7584710546902247



 36%|███▌      | 36/100 [11:09<19:49, 18.59s/it][A

1.7281083436239333



 37%|███▋      | 37/100 [11:27<19:30, 18.58s/it][A

1.694283942381541



 38%|███▊      | 38/100 [11:46<19:10, 18.56s/it][A

1.67653926497414



 39%|███▉      | 39/100 [12:05<18:51, 18.55s/it][A

1.6439700694311232



 40%|████      | 40/100 [12:23<18:34, 18.57s/it][A

1.6046386304355802



 41%|████      | 41/100 [12:42<18:16, 18.58s/it][A

1.5745188224883306



 42%|████▏     | 42/100 [13:00<17:57, 18.58s/it][A

1.5593604814438593



 43%|████▎     | 43/100 [13:19<17:39, 18.59s/it][A

1.5454089528038388



 44%|████▍     | 44/100 [13:38<17:20, 18.59s/it][A

1.5260825072016035



 45%|████▌     | 45/100 [13:56<17:03, 18.61s/it][A

1.5115464556784857



 46%|████▌     | 46/100 [14:15<16:46, 18.63s/it][A

1.4904887335641044



 47%|████▋     | 47/100 [14:34<16:27, 18.64s/it][A

1.4557872471355258



 48%|████▊     | 48/100 [14:52<16:09, 18.65s/it][A

1.4381888622329349



 49%|████▉     | 49/100 [15:11<15:51, 18.66s/it][A

1.4320196991875058



 50%|█████     | 50/100 [15:29<15:30, 18.61s/it][A

1.4233986962409246



 51%|█████     | 51/100 [15:48<15:12, 18.62s/it][A

1.4106879376229786



 52%|█████▏    | 52/100 [16:07<14:53, 18.61s/it][A

1.3802123296828497



 53%|█████▎    | 53/100 [16:25<14:34, 18.61s/it][A

1.3522016406059265



 54%|█████▍    | 54/100 [16:44<14:15, 18.61s/it][A

1.3390515815644037



 55%|█████▌    | 55/100 [17:02<13:57, 18.62s/it][A

1.3128857272011893



 56%|█████▌    | 56/100 [17:21<13:38, 18.60s/it][A

1.3025906483332317



 57%|█████▋    | 57/100 [17:40<13:19, 18.59s/it][A

1.294151632558732



 58%|█████▊    | 58/100 [17:58<13:00, 18.58s/it][A

1.281455167702266



 59%|█████▉    | 59/100 [18:17<12:42, 18.59s/it][A

1.2760366059484936



 60%|██████    | 60/100 [18:35<12:23, 18.59s/it][A

1.2574312999134971



 61%|██████    | 61/100 [18:54<12:05, 18.59s/it][A

1.2459200819333394



 62%|██████▏   | 62/100 [19:13<11:46, 18.59s/it][A

1.2437462153888883



 63%|██████▎   | 63/100 [19:31<11:27, 18.59s/it][A

1.236688202335721



 64%|██████▍   | 64/100 [19:50<11:09, 18.59s/it][A

1.2368552968615578



 65%|██████▌   | 65/100 [20:08<10:50, 18.59s/it][A

1.2099698725200834



 66%|██████▌   | 66/100 [20:27<10:31, 18.59s/it][A

1.1930986642837524



 67%|██████▋   | 67/100 [20:46<10:13, 18.60s/it][A

1.1814285658654713



 68%|██████▊   | 68/100 [21:04<09:55, 18.60s/it][A

1.1726281898362296



 69%|██████▉   | 69/100 [21:23<09:36, 18.61s/it][A

1.1702604407355899



 70%|███████   | 70/100 [21:41<09:18, 18.61s/it][A

1.169128642195747



 71%|███████   | 71/100 [22:00<08:59, 18.60s/it][A

1.1584483924366178



 72%|███████▏  | 72/100 [22:18<08:40, 18.57s/it][A

1.148444482258388



 73%|███████▎  | 73/100 [22:37<08:21, 18.56s/it][A

1.1431063952900113



 74%|███████▍  | 74/100 [22:56<08:02, 18.55s/it][A

1.1316706736882527



 75%|███████▌  | 75/100 [23:14<07:43, 18.54s/it][A

1.1176568610327584



 76%|███████▌  | 76/100 [23:33<07:24, 18.53s/it][A

1.1106196159408206



 77%|███████▋  | 77/100 [23:51<07:05, 18.51s/it][A

1.1087077941213335



 78%|███████▊  | 78/100 [24:10<06:47, 18.54s/it][A

1.1025266675722032



 79%|███████▉  | 79/100 [24:28<06:29, 18.57s/it][A

1.0902453717731295



 80%|████████  | 80/100 [24:47<06:11, 18.58s/it][A

1.0870509828839983



 81%|████████  | 81/100 [25:06<05:53, 18.60s/it][A

1.0893824015344893



 82%|████████▏ | 82/100 [25:24<05:34, 18.61s/it][A

1.086312745298658



 83%|████████▎ | 83/100 [25:43<05:16, 18.61s/it][A

1.087669551372528



 84%|████████▍ | 84/100 [26:01<04:57, 18.61s/it][A

1.087491940884363



 85%|████████▌ | 85/100 [26:20<04:39, 18.62s/it][A

1.0859626843815757



 86%|████████▌ | 86/100 [26:39<04:20, 18.64s/it][A

1.0799098525728499



 87%|████████▋ | 87/100 [26:57<04:02, 18.69s/it][A

1.0726587431771415



 88%|████████▊ | 88/100 [27:16<03:44, 18.70s/it][A

1.0647663218634469



 89%|████████▉ | 89/100 [27:35<03:25, 18.70s/it][A

1.0614243036224729



 90%|█████████ | 90/100 [27:54<03:07, 18.72s/it][A

1.0581752005077543



 91%|█████████ | 91/100 [28:12<02:48, 18.73s/it][A

1.0561862673078264



 92%|█████████▏| 92/100 [28:31<02:29, 18.74s/it][A

1.0533562416122073



 93%|█████████▎| 93/100 [28:50<02:10, 18.70s/it][A

1.0522448619206746



 94%|█████████▍| 94/100 [29:09<01:52, 18.71s/it][A

1.0454322298367817



 95%|█████████▌| 95/100 [29:27<01:33, 18.70s/it][A

1.0401689921106612



 96%|█████████▌| 96/100 [29:46<01:14, 18.64s/it][A

1.0389195538702465



 97%|█████████▋| 97/100 [30:04<00:55, 18.61s/it][A

1.0372892646562486



 98%|█████████▊| 98/100 [30:23<00:37, 18.58s/it][A

1.037779030345735



 99%|█████████▉| 99/100 [30:41<00:18, 18.56s/it][A

1.0358521938323975



100%|██████████| 100/100 [31:00<00:00, 18.60s/it][A

1.0361000867117018
\Training took: 0:31:00.39
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [33]:
model_score2

{'Bleu_1': 0.5858028945554521,
 'Bleu_2': 0.45248459452676837,
 'Bleu_3': 0.37005745553483543,
 'Bleu_4': 0.3134290887472038,
 'METEOR': 0.264380832306559,
 'ROUGE_L': 0.49162618939536,
 'CIDEr': 1.6261822918239672,
 'SPICE': 0.3393889607956865,
 'USC_similarity': 0.5671558655841357}

In [34]:
caption_model3, model_score3 = cross_validation(cv[2][0], cv[2][1], 3)    

Split 3:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2714 ==> 890
The vocabulary size is 891.
800 out of 891 words are found in the pre-trained matrix.
The size of embedding_matrix is (891, 500)
Preparing dataloader...

Generating set took: 0:03:29.38



  0%|          | 0/100 [00:00<?, ?it/s][A


Generating set took: 0:00:52.28
Training...



  1%|          | 1/100 [00:18<30:52, 18.71s/it][A

5.712037881215413



  2%|▏         | 2/100 [00:37<30:30, 18.68s/it][A

4.932338578360421



  3%|▎         | 3/100 [00:55<30:07, 18.64s/it][A

4.326750460125151



  4%|▍         | 4/100 [01:14<29:43, 18.58s/it][A

3.8559012753622874



  5%|▌         | 5/100 [01:32<29:22, 18.55s/it][A

3.580123435883295



  6%|▌         | 6/100 [01:51<29:03, 18.55s/it][A

3.3769991454623995



  7%|▋         | 7/100 [02:09<28:43, 18.53s/it][A

3.212222383135841



  8%|▊         | 8/100 [02:28<28:25, 18.54s/it][A

3.06928521110898



  9%|▉         | 9/100 [02:46<28:07, 18.55s/it][A

2.9355564571562267



 10%|█         | 10/100 [03:05<27:47, 18.53s/it][A

2.8241520609174455



 11%|█         | 11/100 [03:23<27:28, 18.52s/it][A

2.7187563464755105



 12%|█▏        | 12/100 [03:42<27:10, 18.53s/it][A

2.615761473065331



 13%|█▎        | 13/100 [04:01<26:55, 18.57s/it][A

2.527436449414208



 14%|█▍        | 14/100 [04:19<26:41, 18.62s/it][A

2.4546729099182856



 15%|█▌        | 15/100 [04:38<26:24, 18.64s/it][A

2.4075016521272206



 16%|█▌        | 16/100 [04:57<26:06, 18.65s/it][A

2.3443445364634194



 17%|█▋        | 17/100 [05:15<25:47, 18.65s/it][A

2.272489513669695



 18%|█▊        | 18/100 [05:34<25:28, 18.64s/it][A

2.2426955926985968



 19%|█▉        | 19/100 [05:53<25:11, 18.66s/it][A

2.191297392050425



 20%|██        | 20/100 [06:11<24:53, 18.67s/it][A

2.1369940439860025



 21%|██        | 21/100 [06:30<24:36, 18.69s/it][A

2.1047563496090116



 22%|██▏       | 22/100 [06:49<24:18, 18.69s/it][A

2.0621026754379272



 23%|██▎       | 23/100 [07:08<23:59, 18.70s/it][A

1.9991351649874733



 24%|██▍       | 24/100 [07:26<23:40, 18.69s/it][A

1.96524159965061



 25%|██▌       | 25/100 [07:45<23:20, 18.67s/it][A

1.9601436030297053



 26%|██▌       | 26/100 [08:03<23:00, 18.65s/it][A

1.9257129459154039



 27%|██▋       | 27/100 [08:22<22:38, 18.62s/it][A

1.8589158597446622



 28%|██▊       | 28/100 [08:41<22:19, 18.60s/it][A

1.8260627871467954



 29%|██▉       | 29/100 [08:59<21:59, 18.59s/it][A

1.7887815549260093



 30%|███       | 30/100 [09:18<21:42, 18.60s/it][A

1.759597327028002



 31%|███       | 31/100 [09:36<21:25, 18.63s/it][A

1.7391056560334706



 32%|███▏      | 32/100 [09:55<21:07, 18.64s/it][A

1.72751046646209



 33%|███▎      | 33/100 [10:14<20:50, 18.66s/it][A

1.69604838462103



 34%|███▍      | 34/100 [10:32<20:31, 18.65s/it][A

1.674020855199723



 35%|███▌      | 35/100 [10:51<20:11, 18.63s/it][A

1.666692398843311



 36%|███▌      | 36/100 [11:10<19:51, 18.62s/it][A

1.6283243014698936



 37%|███▋      | 37/100 [11:28<19:31, 18.59s/it][A

1.579988916714986



 38%|███▊      | 38/100 [11:47<19:10, 18.55s/it][A

1.5492251884369623



 39%|███▉      | 39/100 [12:05<18:50, 18.53s/it][A

1.5554585655530293



 40%|████      | 40/100 [12:24<18:30, 18.51s/it][A

1.5415214867818923



 41%|████      | 41/100 [12:42<18:11, 18.50s/it][A

1.528829900991349



 42%|████▏     | 42/100 [13:01<17:53, 18.50s/it][A

1.495171464624859



 43%|████▎     | 43/100 [13:19<17:34, 18.50s/it][A

1.4535922890617734



 44%|████▍     | 44/100 [13:38<17:17, 18.53s/it][A

1.4189873479661488



 45%|████▌     | 45/100 [13:56<16:59, 18.54s/it][A

1.3985544812111628



 46%|████▌     | 46/100 [14:15<16:41, 18.55s/it][A

1.3764368352435885



 47%|████▋     | 47/100 [14:33<16:24, 18.58s/it][A

1.369089657352084



 48%|████▊     | 48/100 [14:52<16:05, 18.57s/it][A

1.3591309700693404



 49%|████▉     | 49/100 [15:10<15:45, 18.54s/it][A

1.338428727218083



 50%|█████     | 50/100 [15:29<15:25, 18.51s/it][A

1.3222371879078092



 51%|█████     | 51/100 [15:47<15:06, 18.50s/it][A

1.311677058537801



 52%|█████▏    | 52/100 [16:06<14:47, 18.49s/it][A

1.2960781767254783



 53%|█████▎    | 53/100 [16:24<14:29, 18.50s/it][A

1.2866599361101787



 54%|█████▍    | 54/100 [16:43<14:11, 18.51s/it][A

1.278223190988813



 55%|█████▌    | 55/100 [17:01<13:52, 18.50s/it][A

1.2670866591589791



 56%|█████▌    | 56/100 [17:20<13:34, 18.51s/it][A

1.2652340332667034



 57%|█████▋    | 57/100 [17:38<13:16, 18.52s/it][A

1.2616304670061385



 58%|█████▊    | 58/100 [17:57<12:58, 18.53s/it][A

1.249121089776357



 59%|█████▉    | 59/100 [18:16<12:40, 18.55s/it][A

1.2263264939898537



 60%|██████    | 60/100 [18:34<12:22, 18.55s/it][A

1.2054858803749084



 61%|██████    | 61/100 [18:53<12:03, 18.55s/it][A

1.198428261847723



 62%|██████▏   | 62/100 [19:11<11:44, 18.55s/it][A

1.1945498018037706



 63%|██████▎   | 63/100 [19:30<11:25, 18.53s/it][A

1.1944213793391274



 64%|██████▍   | 64/100 [19:48<11:06, 18.52s/it][A

1.1878352846418108



 65%|██████▌   | 65/100 [20:07<10:47, 18.50s/it][A

1.1811508394422985



 66%|██████▌   | 66/100 [20:25<10:28, 18.48s/it][A

1.1629324158032734



 67%|██████▋   | 67/100 [20:44<10:09, 18.47s/it][A

1.1453009701910473



 68%|██████▊   | 68/100 [21:02<09:51, 18.47s/it][A

1.1374108904883975



 69%|██████▉   | 69/100 [21:21<09:32, 18.48s/it][A

1.135518042814164



 70%|███████   | 70/100 [21:39<09:14, 18.47s/it][A

1.138623487381708



 71%|███████   | 71/100 [21:57<08:55, 18.46s/it][A

1.1220365422112601



 72%|███████▏  | 72/100 [22:16<08:36, 18.46s/it][A

1.109202660265423



 73%|███████▎  | 73/100 [22:34<08:18, 18.46s/it][A

1.0983373749823797



 74%|███████▍  | 74/100 [22:53<07:59, 18.46s/it][A

1.0921446340424674



 75%|███████▌  | 75/100 [23:11<07:41, 18.46s/it][A

1.0849968024662562



 76%|███████▌  | 76/100 [23:30<07:22, 18.45s/it][A

1.0804252539362227



 77%|███████▋  | 77/100 [23:48<07:04, 18.45s/it][A

1.078854575043633



 78%|███████▊  | 78/100 [24:07<06:45, 18.45s/it][A

1.0726097169376554



 79%|███████▉  | 79/100 [24:25<06:27, 18.44s/it][A

1.0672314592770167



 80%|████████  | 80/100 [24:43<06:08, 18.45s/it][A

1.0618795241628374



 81%|████████  | 81/100 [25:02<05:50, 18.45s/it][A

1.0607896816162836



 82%|████████▏ | 82/100 [25:20<05:32, 18.45s/it][A

1.0589216181210108



 83%|████████▎ | 83/100 [25:39<05:13, 18.46s/it][A

1.0562623824392046



 84%|████████▍ | 84/100 [25:57<04:55, 18.45s/it][A

1.0522373347055345



 85%|████████▌ | 85/100 [26:16<04:36, 18.45s/it][A

1.055571266583034



 86%|████████▌ | 86/100 [26:34<04:19, 18.51s/it][A

1.0565070651826405



 87%|████████▋ | 87/100 [26:53<04:00, 18.54s/it][A

1.0554760410672142



 88%|████████▊ | 88/100 [27:12<03:42, 18.58s/it][A

1.0484154933974856



 89%|████████▉ | 89/100 [27:30<03:24, 18.60s/it][A

1.0450875191461473



 90%|█████████ | 90/100 [27:49<03:05, 18.58s/it][A

1.0418530986422585



 91%|█████████ | 91/100 [28:07<02:47, 18.58s/it][A

1.041076742467426



 92%|█████████▏| 92/100 [28:26<02:28, 18.58s/it][A

1.0426572334198725



 93%|█████████▎| 93/100 [28:45<02:09, 18.57s/it][A

1.04421728849411



 94%|█████████▍| 94/100 [29:03<01:51, 18.57s/it][A

1.0370479311261858



 95%|█████████▌| 95/100 [29:22<01:32, 18.55s/it][A

1.0274821037337893



 96%|█████████▌| 96/100 [29:40<01:14, 18.52s/it][A

1.0230711443083627



 97%|█████████▋| 97/100 [29:59<00:55, 18.51s/it][A

1.0210022188368297



 98%|█████████▊| 98/100 [30:17<00:36, 18.50s/it][A

1.0193205305508204



 99%|█████████▉| 99/100 [30:35<00:18, 18.48s/it][A

1.0178089567593165



100%|██████████| 100/100 [30:54<00:00, 18.54s/it][A

1.0164445085184914
\Training took: 0:30:54.47
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [35]:
model_score3

{'Bleu_1': 0.5854734230715952,
 'Bleu_2': 0.4572394421707651,
 'Bleu_3': 0.3775559892032205,
 'Bleu_4': 0.32208131711482635,
 'METEOR': 0.2658116557131826,
 'ROUGE_L': 0.49503577487950834,
 'CIDEr': 1.7618325602697977,
 'SPICE': 0.34379926977510017,
 'USC_similarity': 0.572967501116565}

In [36]:
caption_model4, model_score4 = cross_validation(cv[3][0], cv[3][1], 4)    

Split 4:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2680 ==> 894
The vocabulary size is 895.
806 out of 895 words are found in the pre-trained matrix.
The size of embedding_matrix is (895, 500)
Preparing dataloader...

Generating set took: 0:03:28.76



  0%|          | 0/100 [00:00<?, ?it/s][A


Generating set took: 0:00:52.18
Training...



  1%|          | 1/100 [00:18<30:40, 18.59s/it][A

5.675316322417486



  2%|▏         | 2/100 [00:37<30:23, 18.60s/it][A

4.841327928361439



  3%|▎         | 3/100 [00:55<30:04, 18.60s/it][A

4.161596928324018



  4%|▍         | 4/100 [01:14<29:45, 18.60s/it][A

3.712532259169079



  5%|▌         | 5/100 [01:33<29:27, 18.60s/it][A

3.4538447118940807



  6%|▌         | 6/100 [01:51<29:08, 18.60s/it][A

3.2618040527616228



  7%|▋         | 7/100 [02:10<28:49, 18.60s/it][A

3.1101167372294833



  8%|▊         | 8/100 [02:28<28:32, 18.62s/it][A

2.9688100871585665



  9%|▉         | 9/100 [02:47<28:14, 18.62s/it][A

2.832395388966515



 10%|█         | 10/100 [03:06<27:55, 18.62s/it][A

2.7138796022960117



 11%|█         | 11/100 [03:24<27:34, 18.59s/it][A

2.612712405976795



 12%|█▏        | 12/100 [03:43<27:14, 18.57s/it][A

2.5464777776173184



 13%|█▎        | 13/100 [04:01<26:54, 18.56s/it][A

2.4762042533783686



 14%|█▍        | 14/100 [04:20<26:35, 18.55s/it][A

2.3898012978690013



 15%|█▌        | 15/100 [04:38<26:15, 18.54s/it][A

2.318768268539792



 16%|█▌        | 16/100 [04:57<25:56, 18.53s/it][A

2.2524713902246383



 17%|█▋        | 17/100 [05:15<25:38, 18.54s/it][A

2.200103751250676



 18%|█▊        | 18/100 [05:34<25:19, 18.53s/it][A

2.153034576347896



 19%|█▉        | 19/100 [05:52<25:00, 18.53s/it][A

2.106910929793403



 20%|██        | 20/100 [06:11<24:41, 18.52s/it][A

2.0564714585031783



 21%|██        | 21/100 [06:30<24:27, 18.57s/it][A

2.0240023732185364



 22%|██▏       | 22/100 [06:48<24:11, 18.61s/it][A

1.9881544765971957



 23%|██▎       | 23/100 [07:07<23:52, 18.61s/it][A

1.961320005712055



 24%|██▍       | 24/100 [07:25<23:35, 18.62s/it][A

1.901008495262691



 25%|██▌       | 25/100 [07:44<23:15, 18.61s/it][A

1.850433817931584



 26%|██▌       | 26/100 [08:03<22:55, 18.58s/it][A

1.8052872561273121



 27%|██▋       | 27/100 [08:21<22:34, 18.55s/it][A

1.7683849930763245



 28%|██▊       | 28/100 [08:40<22:13, 18.53s/it][A

1.7425233977181571



 29%|██▉       | 29/100 [08:58<21:53, 18.51s/it][A

1.7304909513110207



 30%|███       | 30/100 [09:16<21:34, 18.50s/it][A

1.7052570694968814



 31%|███       | 31/100 [09:35<21:16, 18.50s/it][A

1.692984163761139



 32%|███▏      | 32/100 [09:53<20:57, 18.49s/it][A

1.6704472502072651



 33%|███▎      | 33/100 [10:12<20:38, 18.49s/it][A

1.6228821987197513



 34%|███▍      | 34/100 [10:31<20:22, 18.52s/it][A

1.5759145305270241



 35%|███▌      | 35/100 [10:49<20:05, 18.55s/it][A

1.5516304572423298



 36%|███▌      | 36/100 [11:08<19:47, 18.56s/it][A

1.5372886969929649



 37%|███▋      | 37/100 [11:26<19:31, 18.60s/it][A

1.5324072355315799



 38%|███▊      | 38/100 [11:45<19:14, 18.63s/it][A

1.496502989814395



 39%|███▉      | 39/100 [12:04<18:57, 18.65s/it][A

1.4634135450635637



 40%|████      | 40/100 [12:22<18:39, 18.65s/it][A

1.4485207256816683



 41%|████      | 41/100 [12:41<18:20, 18.66s/it][A

1.4342197946139745



 42%|████▏     | 42/100 [13:00<18:02, 18.67s/it][A

1.427327139036996



 43%|████▎     | 43/100 [13:18<17:43, 18.65s/it][A

1.410326063632965



 44%|████▍     | 44/100 [13:37<17:23, 18.63s/it][A

1.3979160842441378



 45%|████▌     | 45/100 [13:56<17:03, 18.61s/it][A

1.3879306628590538



 46%|████▌     | 46/100 [14:14<16:44, 18.60s/it][A

1.367154754343487



 47%|████▋     | 47/100 [14:33<16:25, 18.59s/it][A

1.3422825790586925



 48%|████▊     | 48/100 [14:51<16:05, 18.57s/it][A

1.3156802342051552



 49%|████▉     | 49/100 [15:10<15:47, 18.57s/it][A

1.3026043006352015



 50%|█████     | 50/100 [15:28<15:28, 18.57s/it][A

1.2909240609123593



 51%|█████     | 51/100 [15:47<15:09, 18.56s/it][A

1.2830700562113808



 52%|█████▏    | 52/100 [16:06<14:51, 18.57s/it][A

1.268319357009161



 53%|█████▎    | 53/100 [16:24<14:32, 18.56s/it][A

1.2535731338319325



 54%|█████▍    | 54/100 [16:43<14:14, 18.57s/it][A

1.2373479349272591



 55%|█████▌    | 55/100 [17:01<13:56, 18.58s/it][A

1.222406852812994



 56%|█████▌    | 56/100 [17:20<13:38, 18.61s/it][A

1.2148038092113675



 57%|█████▋    | 57/100 [17:38<13:19, 18.59s/it][A

1.2065764211473011



 58%|█████▊    | 58/100 [17:57<12:59, 18.55s/it][A

1.1930105487505596



 59%|█████▉    | 59/100 [18:15<12:39, 18.52s/it][A

1.19572514295578



 60%|██████    | 60/100 [18:34<12:19, 18.50s/it][A

1.195748703820365



 61%|██████    | 61/100 [18:52<12:01, 18.49s/it][A

1.1806063112758456



 62%|██████▏   | 62/100 [19:11<11:42, 18.49s/it][A

1.151613638514564



 63%|██████▎   | 63/100 [19:29<11:24, 18.50s/it][A

1.138407968339466



 64%|██████▍   | 64/100 [19:48<11:05, 18.49s/it][A

1.124070834545862



 65%|██████▌   | 65/100 [20:06<10:47, 18.49s/it][A

1.1128573247364588



 66%|██████▌   | 66/100 [20:25<10:29, 18.51s/it][A

1.1075365997496105



 67%|██████▋   | 67/100 [20:43<10:10, 18.51s/it][A

1.101896549974169



 68%|██████▊   | 68/100 [21:02<09:52, 18.51s/it][A

1.1004151474861872



 69%|██████▉   | 69/100 [21:21<09:35, 18.56s/it][A

1.0992620161601476



 70%|███████   | 70/100 [21:39<09:18, 18.61s/it][A

1.096957232270922



 71%|███████   | 71/100 [21:58<08:59, 18.61s/it][A

1.0927958970978147



 72%|███████▏  | 72/100 [22:17<08:41, 18.63s/it][A

1.090716872896467



 73%|███████▎  | 73/100 [22:35<08:23, 18.64s/it][A

1.0853179977053689



 74%|███████▍  | 74/100 [22:54<08:04, 18.64s/it][A

1.0801281758717127



 75%|███████▌  | 75/100 [23:13<07:46, 18.65s/it][A

1.0739456358410062



 76%|███████▌  | 76/100 [23:31<07:28, 18.67s/it][A

1.0703642538615636



 77%|███████▋  | 77/100 [23:50<07:09, 18.66s/it][A

1.0712399965240842



 78%|███████▊  | 78/100 [24:09<06:50, 18.67s/it][A

1.0717320498966036



 79%|███████▉  | 79/100 [24:27<06:32, 18.68s/it][A

1.0718286605108351



 80%|████████  | 80/100 [24:46<06:13, 18.66s/it][A

1.0705538590749104



 81%|████████  | 81/100 [25:05<05:54, 18.66s/it][A

1.0568321574301947



 82%|████████▏ | 82/100 [25:23<05:36, 18.68s/it][A

1.0444110802241735



 83%|████████▎ | 83/100 [25:42<05:17, 18.68s/it][A

1.0400676358313787



 84%|████████▍ | 84/100 [26:01<04:58, 18.68s/it][A

1.0379364831107003



 85%|████████▌ | 85/100 [26:19<04:40, 18.68s/it][A

1.034679543404352



 86%|████████▌ | 86/100 [26:38<04:21, 18.67s/it][A

1.031245890117827



 87%|████████▋ | 87/100 [26:57<04:02, 18.69s/it][A

1.0307796086583818



 88%|████████▊ | 88/100 [27:15<03:44, 18.69s/it][A

1.029569923877716



 89%|████████▉ | 89/100 [27:34<03:25, 18.69s/it][A

1.0297322471936543



 90%|█████████ | 90/100 [27:53<03:06, 18.66s/it][A

1.0299853796050662



 91%|█████████ | 91/100 [28:11<02:48, 18.68s/it][A

1.0297354658444722
1.0312229990959167
\Training took: 0:28:30.61
Generating captions...
tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [37]:
model_score4

{'Bleu_1': 0.5876205063518971,
 'Bleu_2': 0.45766534771248313,
 'Bleu_3': 0.37637573464992274,
 'Bleu_4': 0.31924063130965513,
 'METEOR': 0.27339544200377514,
 'ROUGE_L': 0.4988139120291324,
 'CIDEr': 1.6966393538343623,
 'SPICE': 0.34348189816932995,
 'USC_similarity': 0.574212785236237}

In [38]:
caption_model5, model_score5 = cross_validation(cv[4][0], cv[4][1], 5)    

Split 5:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2657 ==> 905
The vocabulary size is 906.
815 out of 906 words are found in the pre-trained matrix.
The size of embedding_matrix is (906, 500)
Preparing dataloader...

Generating set took: 0:03:29.79




  0%|          | 0/100 [00:00<?, ?it/s][A[A


Generating set took: 0:00:52.45
Training...




  1%|          | 1/100 [00:18<30:46, 18.65s/it][A[A

5.7748589515686035




  2%|▏         | 2/100 [00:37<30:24, 18.61s/it][A[A

4.888509784426008




  3%|▎         | 3/100 [00:55<30:04, 18.60s/it][A[A

4.224347296215239




  4%|▍         | 4/100 [01:14<29:45, 18.60s/it][A[A

3.8337473358426775




  5%|▌         | 5/100 [01:33<29:29, 18.63s/it][A[A

3.5753975766045705




  6%|▌         | 6/100 [01:51<29:09, 18.61s/it][A[A

3.3732371159962247




  7%|▋         | 7/100 [02:10<28:52, 18.63s/it][A[A

3.218034511520749




  8%|▊         | 8/100 [02:28<28:34, 18.64s/it][A[A

3.0777378309340704




  9%|▉         | 9/100 [02:47<28:12, 18.59s/it][A[A

2.9621664455958774




 10%|█         | 10/100 [03:06<27:58, 18.65s/it][A[A

2.8431124630428495




 11%|█         | 11/100 [03:24<27:38, 18.64s/it][A[A

2.7388022797448293




 12%|█▏        | 12/100 [03:43<27:19, 18.64s/it][A[A

2.6608681678771973




 13%|█▎        | 13/100 [04:02<27:01, 18.64s/it][A[A

2.599840658051627




 14%|█▍        | 14/100 [04:20<26:44, 18.66s/it][A[A

2.517498180979774




 15%|█▌        | 15/100 [04:39<26:28, 18.69s/it][A[A

2.432633258047558




 16%|█▌        | 16/100 [04:58<26:09, 18.69s/it][A[A

2.363809420948937




 17%|█▋        | 17/100 [05:16<25:48, 18.65s/it][A[A

2.3134600406601313




 18%|█▊        | 18/100 [05:35<25:29, 18.66s/it][A[A

2.269849215235029




 19%|█▉        | 19/100 [05:54<25:11, 18.67s/it][A[A

2.222342814717974




 20%|██        | 20/100 [06:12<24:50, 18.63s/it][A[A

2.1687381664911904




 21%|██        | 21/100 [06:31<24:28, 18.59s/it][A[A

2.1250135330926803




 22%|██▏       | 22/100 [06:49<24:07, 18.56s/it][A[A

2.0887558375086104




 23%|██▎       | 23/100 [07:08<23:48, 18.55s/it][A[A

2.043988369760059




 24%|██▍       | 24/100 [07:26<23:30, 18.56s/it][A[A

1.9986738193602789




 25%|██▌       | 25/100 [07:45<23:12, 18.57s/it][A[A

1.9723142953146071




 26%|██▌       | 26/100 [08:03<22:52, 18.55s/it][A[A

1.955572281564985




 27%|██▋       | 27/100 [08:22<22:33, 18.54s/it][A[A

1.9100457032521565




 28%|██▊       | 28/100 [08:41<22:16, 18.56s/it][A[A

1.8668239343734014




 29%|██▉       | 29/100 [08:59<21:58, 18.57s/it][A[A

1.821322807243892




 30%|███       | 30/100 [09:18<21:40, 18.58s/it][A[A

1.7966350260235013




 31%|███       | 31/100 [09:36<21:20, 18.56s/it][A[A

1.78754270644415




 32%|███▏      | 32/100 [09:55<21:00, 18.54s/it][A[A

1.7583451100758143




 33%|███▎      | 33/100 [10:13<20:41, 18.53s/it][A[A

1.7167134057907831




 34%|███▍      | 34/100 [10:32<20:24, 18.55s/it][A[A

1.6722272350674583




 35%|███▌      | 35/100 [10:51<20:07, 18.58s/it][A[A

1.643847298054468




 36%|███▌      | 36/100 [11:09<19:47, 18.56s/it][A[A

1.6251177872930254




 37%|███▋      | 37/100 [11:27<19:27, 18.53s/it][A[A

1.61086505651474




 38%|███▊      | 38/100 [11:46<19:08, 18.53s/it][A[A

1.606006457692101




 39%|███▉      | 39/100 [12:04<18:49, 18.52s/it][A[A

1.5860771934191387




 40%|████      | 40/100 [12:23<18:30, 18.51s/it][A[A

1.5530336726279486




 41%|████      | 41/100 [12:42<18:15, 18.56s/it][A[A

1.5305948825109572




 42%|████▏     | 42/100 [13:00<17:55, 18.54s/it][A[A

1.5067653939837502




 43%|████▎     | 43/100 [13:19<17:36, 18.53s/it][A[A

1.4852702135131473




 44%|████▍     | 44/100 [13:37<17:17, 18.53s/it][A[A

1.4762326251892817




 45%|████▌     | 45/100 [13:56<16:59, 18.54s/it][A[A

1.454352719443185




 46%|████▌     | 46/100 [14:14<16:44, 18.59s/it][A[A

1.4478379232542855




 47%|████▋     | 47/100 [14:33<16:26, 18.62s/it][A[A

1.4338717318716503




 48%|████▊     | 48/100 [14:52<16:06, 18.59s/it][A[A

1.418226806890397




 49%|████▉     | 49/100 [15:10<15:46, 18.56s/it][A[A

1.3935783193224953




 50%|█████     | 50/100 [15:29<15:26, 18.54s/it][A[A

1.3657359310558863




 51%|█████     | 51/100 [15:47<15:07, 18.53s/it][A[A

1.3353385272480192




 52%|█████▏    | 52/100 [16:06<14:49, 18.52s/it][A[A

1.3100941975911458




 53%|█████▎    | 53/100 [16:24<14:29, 18.51s/it][A[A

1.2908501227696736




 54%|█████▍    | 54/100 [16:43<14:13, 18.56s/it][A[A

1.2754855893907093




 55%|█████▌    | 55/100 [17:01<13:56, 18.59s/it][A[A

1.26716320003782




 56%|█████▌    | 56/100 [17:20<13:37, 18.57s/it][A[A

1.2605976206915719




 57%|█████▋    | 57/100 [17:39<13:17, 18.55s/it][A[A

1.2469689420291357




 58%|█████▊    | 58/100 [17:57<12:58, 18.53s/it][A[A

1.2416337700117202




 59%|█████▉    | 59/100 [18:16<12:39, 18.53s/it][A[A

1.238667261032831




 60%|██████    | 60/100 [18:34<12:21, 18.53s/it][A[A

1.2533366254397802




 61%|██████    | 61/100 [18:53<12:02, 18.52s/it][A[A

1.2397729612532116




 62%|██████▏   | 62/100 [19:11<11:43, 18.50s/it][A[A

1.2114862515812828




 63%|██████▎   | 63/100 [19:30<11:24, 18.50s/it][A[A

1.185162535735539




 64%|██████▍   | 64/100 [19:48<11:05, 18.50s/it][A[A

1.157526277360462




 65%|██████▌   | 65/100 [20:07<10:47, 18.51s/it][A[A

1.1392275066602797




 66%|██████▌   | 66/100 [20:25<10:29, 18.52s/it][A[A

1.1307034151894706




 67%|██████▋   | 67/100 [20:44<10:10, 18.51s/it][A[A

1.1208314526648748




 68%|██████▊   | 68/100 [21:02<09:52, 18.52s/it][A[A

1.1101645969209217




 69%|██████▉   | 69/100 [21:21<09:34, 18.54s/it][A[A

1.10147362947464




 70%|███████   | 70/100 [21:39<09:16, 18.56s/it][A[A

1.0896820851734705




 71%|███████   | 71/100 [21:58<08:58, 18.58s/it][A[A

1.0853212078412373




 72%|███████▏  | 72/100 [22:16<08:40, 18.58s/it][A[A

1.0816172531672887




 73%|███████▎  | 73/100 [22:35<08:21, 18.57s/it][A[A

1.0758276297932579




 74%|███████▍  | 74/100 [22:54<08:03, 18.58s/it][A[A

1.071910662310464




 75%|███████▌  | 75/100 [23:12<07:45, 18.60s/it][A[A

1.0696472355297633




 76%|███████▌  | 76/100 [23:31<07:26, 18.62s/it][A[A

1.0660209967976524




 77%|███████▋  | 77/100 [23:50<07:08, 18.63s/it][A[A

1.0648252566655476




 78%|███████▊  | 78/100 [24:08<06:49, 18.63s/it][A[A

1.0656699197632926




 79%|███████▉  | 79/100 [24:27<06:31, 18.64s/it][A[A

1.0635526606014796




 80%|████████  | 80/100 [24:46<06:13, 18.66s/it][A[A

1.0625227008547102




 81%|████████  | 81/100 [25:04<05:54, 18.68s/it][A[A

1.0657205553281874
1.0695557934897286
\Training took: 0:25:23.35
Generating captions...
tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [39]:
model_score5

{'Bleu_1': 0.5518339162406711,
 'Bleu_2': 0.4196023186597518,
 'Bleu_3': 0.33874750629249994,
 'Bleu_4': 0.2834808867567647,
 'METEOR': 0.24771855422437908,
 'ROUGE_L': 0.46235474520421005,
 'CIDEr': 1.5322677768915776,
 'SPICE': 0.30832705122819426,
 'USC_similarity': 0.5502760397826522}

In [40]:
model_scores = defaultdict(list)
for scores in [model_score1, model_score2, model_score3, model_score4, model_score5]:
    for key, value in scores.items():
        model_scores[key].append(value)

In [41]:
model_scores

defaultdict(list,
            {'Bleu_1': [0.6057266732905516,
              0.5858028945554521,
              0.5854734230715952,
              0.5876205063518971,
              0.5518339162406711],
             'Bleu_2': [0.47195977965507746,
              0.45248459452676837,
              0.4572394421707651,
              0.45766534771248313,
              0.4196023186597518],
             'Bleu_3': [0.3864751508901778,
              0.37005745553483543,
              0.3775559892032205,
              0.37637573464992274,
              0.33874750629249994],
             'Bleu_4': [0.3266986863206117,
              0.3134290887472038,
              0.32208131711482635,
              0.31924063130965513,
              0.2834808867567647],
             'METEOR': [0.27835407428470443,
              0.264380832306559,
              0.2658116557131826,
              0.27339544200377514,
              0.24771855422437908],
             'ROUGE_L': [0.5136516817675681,
              0.491626

In [42]:
tag = '19.1'
with open(f'{root_captioning}/fz_notebooks/cv_n{tag}.json', 'w') as fp:
    json.dump(model_scores, fp)