## Image Captioning with Pytorch

The following contents are modified from MDS DSCI 575 lecture 8 demo

In [1]:
import os, sys, json
from collections import defaultdict
from tqdm import tqdm
import pickle
from time import time
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from itertools import chain
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models, transforms, datasets
from torchsummary import summary
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

from nltk.translate import bleu_score
from sklearn.model_selection import KFold

sys.path.append('../../scr/evaluation')
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.usc_sim.usc_sim import usc_sim
import subprocess


START = "startseq"
STOP = "endseq"
EPOCHS = 10
AWS = True


In [2]:
torch.manual_seed(123)
np.random.seed(123)

In [3]:
# torch.cuda.empty_cache()
# import gc 
# gc.collect()

In [4]:
# Nicely formatted time string
def hms_string(sec_elapsed):
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60
    return f"{h}:{m:>02}:{s:>05.2f}"
        
if AWS:
    root_captioning = "../../data"
else:
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
        root_captioning = "/content/drive/My Drive/data"
        COLAB = True
        print("Note: using Google CoLab")
    except:
        print("Note: not using Google CoLab")
        COLAB = False

### Clean/Build Dataset

- Read captions
- Preprocess captions


In [5]:
def get_img_info(name, num=np.inf):
    """
    Returns img paths and captions

    Parameters:
    -----------
    name: str
        the json file name
    num: int (default: np.inf)
        the number of observations to get

    Return:
    --------
    list, dict, int
        img paths, corresponding captions, max length of captions
    """
    img_path = []
    caption = [] 
    max_length = 0
    if AWS:
        with open(f'{root_captioning}/json/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for filename in data.keys():
                if num is not None and len(caption) == num:
                    break
                img_path.append(
                    f'{root_captioning}/{name}/{filename}'
                )
                sen_list = []
                for sentence in data[filename]['sentences']:
                    max_length = max(max_length, len(sentence['tokens']))
                    sen_list.append(sentence['raw'])

                caption.append(sen_list)    
    else:            
        with open(f'{root_captioning}/interim/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for set_name in ['rsicd', 'ucm']:
                for filename in data[set_name].keys():
                    if num is not None and len(caption) == num:
                        break

                    img_path.append(
                        f'{root_captioning}/raw/imgs/{set_name}/{filename}'
                    )
                    sen_list = []
                    for sentence in data[set_name][filename]['sentences']:
                        max_length = max(max_length, len(sentence['tokens']))
                        sen_list.append(sentence['raw'])

                    caption.append(sen_list)
    
    return img_path, caption, max_length            


In [6]:
# get img path and caption list
# # only test 800 train samples and 200 valid samples
# train_paths, train_descriptions, max_length_train = get_img_info('train', 800)
# test_paths, test_descriptions, max_length_test = get_img_info('valid', 200)

train_paths, train_descriptions, max_length_train = get_img_info('train')
test_paths, test_descriptions, max_length_test = get_img_info('valid')
max_length = max(max_length_train, max_length_test)



In [7]:
all_paths = train_paths.copy()
all_paths.extend(test_paths.copy())
all_paths = np.array(all_paths)

all_descriptions = train_descriptions.copy()
all_descriptions.extend(test_descriptions.copy())
all_descriptions = np.array(all_descriptions)

captions = all_descriptions.copy()
max_length_all = max(max_length_train, max_length_test)
max_length = max_length_all + 2
      
lex = set()
for sen in all_descriptions:
    [lex.update(d.split()) for d in sen]
    
# add a start and stop token at the beginning/end
for v in all_descriptions:
    for d in range(len(v)):
        v[d] = f'{START} {v[d]} {STOP}'
        
print(f'There are {len(all_paths)} images') 
print(f'There are {len(lex)} unique words (vocab)')
print(f'The maximum length of captions with start and stop token is {max_length}.')


There are 10416 images
There are 2912 unique words (vocab)
The maximum length of captions with start and stop token is 36.


In [8]:
all_paths[-1]

'../../s3/valid/rsicd_park_33.jpg'

In [9]:
all_descriptions[-1]

array(['startseq a vast artificial lake was built in the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq'],
      dtype='<U184')

### Loading Wikipedia2vec Embeddings

In [10]:
# read the embedding matrix 
with open(f'{root_captioning}/enwiki_20180420_2338_words_500d.json', 'r', encoding='utf-8') as file:
    embeddings_index = json.load(file)

In [11]:
def get_vocab(descriptions, word_count_threshold=10):

    captions = []
    for val in descriptions:
        for cap in val:
            captions.append(cap)
    print(f'There are {len(captions)} captions')
    
    word_counts = {}
    nsents = 0
    for sent in captions:
        nsents += 1
        for w in sent.split(' '):
            word_counts[w] = word_counts.get(w, 0) + 1

    vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
    print('preprocessed words %d ==> %d' % (len(word_counts), len(vocab)))
    return vocab

def get_word_dict(vocab):
    
    idxtoword = {}
    wordtoidx = {}

    ix = 1
    for w in vocab:
        wordtoidx[w] = ix
        idxtoword[ix] = w
        ix += 1

    return idxtoword, wordtoidx

def get_vocab_size(idxtoword):
    
    print(f'The vocabulary size is {len(idxtoword) + 1}.')
    return len(idxtoword) + 1


def get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx):

    embedding_matrix = np.zeros((vocab_size, embedding_dim))
    count = 0

    for word, i in wordtoidx.items():

        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            count += 1
            # Words not found in the embedding index will be all zeros
            embedding_matrix[i] = embedding_vector
            
    print(f'{count} out of {vocab_size} words are found in the pre-trained matrix.')            
    print(f'The size of embedding_matrix is {embedding_matrix.shape}')
    return embedding_matrix

### Building the Neural Network

An embedding matrix is built from Glove.  This will be directly copied to the weight matrix of the neural network.

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [13]:
class CNNModel(nn.Module):

    def __init__(self, pretrained=True):
        """
        Initializes a CNNModel

        Parameters:
        -----------
        pretrained: bool (default: True)
            use pretrained model if True

        """

        super(CNNModel, self).__init__()

        # inception v3 expects (299, 299) sized images
        self.model = models.inception_v3(pretrained=pretrained, aux_logits=False)
        # remove the classification layer
        self.model =\
        nn.Sequential(
            *(list(self.model.children())[: 3]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            *(list(self.model.children())[3: 5]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            *(list(self.model.children())[5: -1])
        )

        self.input_size = 299

    def forward(self, img_input, train=False):
        """
        forward of the CNNModel

        Parameters:
        -----------
        img_input: torch.Tensor
            the image matrix
        train: bool (default: False)
            use the model only for feature extraction if False

        Return:
        --------
        torch.Tensor
            image feature matrix
        """
        if not train:
          # set the model to evaluation model
          self.model.eval()

        # N x 3 x 299 x 299
        features = self.model(img_input)
        # N x 2048 x 8 x 8

        return features

In [14]:
class AttentionModel(nn.Module):

    def __init__(self, feature_size, hidden_size=256):
        """
        Initializes a AttentionModel

        Parameters:
        -----------
        cnn_type: str
            the CNN type, either 'vgg16' or 'inception_v3'
        pretrained: bool (default: True)
            use pretrained model if True

        """

        super(AttentionModel, self).__init__()


    def forward(self, img_features, h):
        """
        forward of the AttentionModel

        Parameters:
        -----------
        img_input: torch.Tensor
            the image matrix
        train: bool (default: False)
            use the model only for feature extraction if False

        Return:
        --------
        torch.Tensor
            image feature matrix
        """

        # 1 x N x hidden_size
        h_a = h.repeat(img_features.size(1), 1, 1).permute(1, 0, 2)
        # N x 64 x hidden_size

        cos = nn.CosineSimilarity(dim=1, eps=1e-6)

        # attention scoring with cosine similarity
        attention = cos(h_a.permute(0, 2, 1), img_features.permute(0, 2, 1))
        # N x 64

        attention_weights = F.softmax(attention.unsqueeze(1), dim=2)
        # N x 1 x 64

        return attention_weights

In [15]:
class RNNModel(nn.Module):

    def __init__(
        self, 
        feature_size,
        vocab_size,
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):
      
        """
        Initializes a RNNModel

        Parameters:
        -----------
        feature_size: int
            the number of features in the image matrix
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """

        super(RNNModel, self).__init__()

        self.feature_size = feature_size
        self.hidden_size = hidden_size

        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()

        self.out_dense = nn.Linear(hidden_size, hidden_size)
        self.h_dense = nn.Linear(feature_size, hidden_size)
        self.c_dense = nn.Linear(feature_size, hidden_size)  
        self.img_dense = nn.Linear(feature_size, hidden_size)
        
        self.embedding =\
        nn.Embedding(vocab_size,embedding_dim, padding_idx=0)

        if embedding_matrix is not None:

            self.embedding.load_state_dict({
                'weight': torch.FloatTensor(embedding_matrix)
            })
            self.embedding.weight.requires_grad = embedding_train

        self.attention =\
        AttentionModel(feature_size, hidden_size)
        
        self.lstm =\
        nn.LSTM(hidden_size, hidden_size, batch_first=True)
      

    def forward(self, img_features, captions):
        """
        forward of the RNNModel

        Parameters:
        -----------
        img_features: torch.Tensor 
            the image feature matrix
            (N x feature_size(2048) x 8 x 8)
        captions: torch.Tensor 
            the padded caption matrix
            (N x seq_len)

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        # N = batch_size
        batch_size = captions.size(0)
        seq_len = captions.size(1)

        # N x feature_size(2048) x 8 x 8
        img_features =\
        img_features.view(
            batch_size, self.feature_size, -1
        ).permute(0, 2, 1)
        # N x 64 x feature_size(2048)

        # N x 64 x feature_size(2048)
        h = self.h_dense(img_features.mean(dim=1)).unsqueeze(0)
        c = self.c_dense(img_features.mean(dim=1)).unsqueeze(0)
        # 1 x N x hidden_size
       
        # N x feature_size(2048) x 8 x 8
        img_features =\
        self.relu(
            self.img_dense(
                    img_features
            )
        )  
        # N x 64 x hidden_size

        # N x seq_len
        embed =\
            self.embedding(
                captions
            )
        # N x seq_len x embedding_dim
        
        outputs =\
        torch.zeros(
            batch_size,
            seq_len, 
            self.hidden_size
        ).to(device)

        all_attention_weights =\
        torch.zeros(
            batch_size,
            seq_len, 
            img_features.shape[1]
        ).to(device)
        
        for i in range(seq_len):

            attention_weights = self.attention(img_features, h)
            # N x 1 x 64

            # weighted sum of img_features
            weighted = torch.bmm(attention_weights, img_features)
            # N x 1 x hidden_size

            output, (h, c) =\
            self.lstm(
                embed[:, i, :].unsqueeze(1) + weighted,
                (h, c)
            )
            # outputs: N x 1 x hidden_size
            # h: 1 x N x hidden_size
            # c: 1 x N x hidden_size

            output =\
            self.out_dense(
                output.squeeze(1) + weighted.squeeze(1) + embed[:, i, :]
            )
            # N x hidden_size
 
            outputs[:, i, :] = output.squeeze()
            all_attention_weights[:, i, :] = attention_weights.squeeze()

        return outputs, all_attention_weights



In [16]:
class CaptionModel(nn.Module):

    def __init__(
        self, 
        vocab_size, 
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):

        """
        Initializes a CaptionModel

        Parameters:
        -----------
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """    
        super(CaptionModel, self).__init__() 

        # set feature_size based on cnn_type
        self.feature_size = 2048

        self.decoder = RNNModel(
            self.feature_size,
            vocab_size, 
            embedding_dim,
            hidden_size,
            embedding_matrix,
            embedding_train
        )

        self.relu = nn.ReLU()
        self.dense = nn.Linear(hidden_size, vocab_size) 

    # def forward(self, captions):
    def forward(self, img_features, captions):
        """
        forward of the CaptionModel

        Parameters:
        -----------
        img_features: torch.Tensor 
            the image feature matrix
            (N x feature_size(2048) x 8 x 8)
        captions: torch.Tensor 
            the padded caption matrix
            (N x seq_len)

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        decoder_out, all_attention_weights = self.decoder(img_features, captions)

        # add up decoder outputs and image features
        outputs =\
        self.dense(
            self.relu(
                decoder_out
            )
        )

        return outputs, all_attention_weights

### Train the Neural Network

In [17]:
def train(model, iterator, optimizer, criterion, clip, vocab_size):
    """
    train the CaptionModel

    Parameters:
    -----------
    model: CaptionModel
        a CaptionModel instance
    iterator: torch.utils.data.dataloader
        a PyTorch dataloader
    optimizer: torch.optim
        a PyTorch optimizer 
    criterion: nn.CrossEntropyLoss
        a PyTorch criterion 

    Return:
    --------
    float
        average loss
    """
    model.train()    
    epoch_loss = 0
    
    for img_features, captions in iterator:
        
        optimizer.zero_grad()

        # for each caption, the end word is not passed for training
        outputs, all_attention_weights = model(
            img_features.to(device),
            captions[:, :-1].to(device)
        )

        loss = criterion(
            outputs.view(-1, vocab_size), 
            captions[:, 1:].flatten().to(device)
        ) + ((1. - all_attention_weights.sum(dim=1)) ** 2).mean()
        epoch_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        
    return epoch_loss / len(iterator)

In [18]:
class SampleDataset(Dataset):
    def __init__(
        self,
        descriptions,
        imgs,
        wordtoidx,
        max_length
    ):
        """
        Initializes a SampleDataset

        Parameters:
        -----------
        descriptions: list
            a list of captions
        imgs: numpy.ndarray
            the image features
        wordtoidx: dict
            the dict to get word index
        max_length: int
            all captions will be padded to this size
        """        
        self.imgs = imgs
        self.descriptions = descriptions
        self.wordtoidx = wordtoidx
        self.max_length = max_length

    def __len__(self):
        """
        Returns the batch size

        Return:
        --------
        int
            the batch size
        """
        # return len(self.descriptions)
        return len(self.imgs)

    def __getitem__(self, idx):
        """
        Prepare data for each image

        Parameters:
        -----------
        idx: int
          the index of the image to process

        Return:
        --------
        list, list, list
            [5 x image feature matrix],
            [five padded captions for this image]
            [the length of each caption]
        """

        img = self.imgs[idx // 5]
        # convert each word into a list of sequences.
        seq = [self.wordtoidx[word] for word 
               in self.descriptions[idx // 5][idx % 5].split(' ')
               if word in self.wordtoidx]
        # pad the sequence with 0 on the right side
        in_seq = np.pad(
            seq, 
            (0, max_length - len(seq)),
            mode='constant',
            constant_values=(0, 0)
            )

        return img, in_seq


In [19]:
def init_weights(model, embedding_pretrained=True):
    """
    Initialize weights and bias in the model

    Parameters:
    -----------
    model: CaptionModel
      a CaptionModel instance
    embedding_pretrained: bool (default: True)
        not initialize the embedding matrix if True
    """  
  
    for name, param in model.named_parameters():
        if embedding_pretrained and 'embedding' in name:
            continue
        elif 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            


In [20]:
def encode_image(model, img_path):
    """
    Process the images to extract features

    Parameters:
    -----------
    model: CNNModel
      a CNNModel instance
    img_path: str
        the path of the image
 
    Return:
    --------
    torch.Tensor
        the extracted feature matrix from CNNModel
    """  

    img = Image.open(img_path)

    # Perform preprocessing needed by pre-trained models
    preprocessor = transforms.Compose([
        transforms.Resize(model.input_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    img = preprocessor(img)
    # Expand to 2D array
    img = img.view(1, *img.shape)
    # Call model to extract the smaller feature set for the image.
    x = model(img.to(device), False) 
    # Shape to correct form to be accepted by LSTM captioning network.
    x = np.squeeze(x)
    return x

In [21]:
def extract_img_features(img_paths, model):
    """
    Extracts, stores and returns image features

    Parameters:
    -----------
    img_paths: list
        the paths of images
    model: CNNModel (default: None)
      a CNNModel instance

    Return:
    --------
    numpy.ndarray
        the extracted image feature matrix from CNNModel
    """ 

    start = time()
    img_features = []

    for image_path in tqdm(img_paths):
        img_features.append(
            encode_image(model, image_path).cpu().data.numpy()
        )

    print(f"\nGenerating set took: {hms_string(time()-start)}")

    return img_features

In [22]:
def get_train_test(
    encoder,
    train_paths,
    test_paths
):

    train_img_features = extract_img_features(
        train_paths,
        encoder
    )

    test_img_features = extract_img_features(
        test_paths,
        encoder
    )
    return train_img_features, test_img_features

def get_train_dataloader(
    train_descriptions, 
    train_img_features,
    wordtoidx,
    max_length,
    batch_size=200
):
    train_dataset = SampleDataset(
        train_descriptions,
        train_img_features,
        wordtoidx,
        max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size
    )
    
    return train_loader

def train_model(
    train_loader,
    vocab_size,
    embedding_dim, 
    embedding_matrix,
    hidden_size=256,
):

    caption_model = CaptionModel(
        vocab_size, 
        embedding_dim, 
        hidden_size=hidden_size,
        embedding_matrix=embedding_matrix, 
        embedding_train=True
    )

    init_weights(
        caption_model,
        embedding_pretrained=True
    )

    caption_model.to(device)

    # we will ignore the pad token in true target set
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    optimizer = torch.optim.Adam(
        caption_model.parameters(), 
        lr=0.01
    )

    clip = 1
    start = time()

    for i in tqdm(range(EPOCHS * 6)):

        loss = train(caption_model, train_loader, optimizer, criterion, clip, vocab_size)
        print(loss)

    # reduce the learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = 1e-4

    for i in tqdm(range(EPOCHS * 6)):

        loss = train(caption_model, train_loader, optimizer, criterion, clip, vocab_size)
        print(loss)
    return caption_model

In [23]:
def generateCaption(
    model, 
    img_features,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    in_text = START

    for i in range(max_length):

        sequence = [wordtoidx[w] for w in in_text.split() if w in wordtoidx]
        sequence = np.pad(sequence, (0, max_length - len(sequence)),
                          mode='constant', constant_values=(0, 0))
        model.eval()
        yhat, _ = model(
            torch.FloatTensor(img_features)\
            .view(-1, model.feature_size).to(device),
            torch.LongTensor(sequence).view(-1, max_length).to(device)
        )

        yhat = yhat.view(-1, vocab_size).argmax(1)
        word = idxtoword[yhat.cpu().data.numpy()[i]]
        in_text += ' ' + word
        if word == STOP:
            break
    final = in_text.split()
    final = final[1 : -1]
    final = ' '.join(final)
    return final

### Evaluation

In [24]:
def eval_model(ref_data, results):
    """
    Computes evaluation metrics of the model results against the human annotated captions
    
    Parameters:
    ------------
    ref_data: dict
        a dictionary containing human annotated captions, with image name as key and a list of human annotated captions as values
    
    results: dict
        a dictionary containing model generated caption, with image name as key and a generated caption as value
        
    Returns:
    ------------
    score_dict: a dictionary containing the overall average score for the model
    """
    # download stanford nlp library
    subprocess.call(['../../scr/evaluation/get_stanford_models.sh'])
    
    # format the inputs
    gts = {}
    res = {}

    for imgId in range(len(ref_data)):
        caption_list_sel = []
        for i in range(5):
            lst = {}
            lst['caption'] = ref_data[imgId][i]
            lst['image_id'] = imgId
            lst['id'] = i
            caption_list_sel.append(lst)
        gts[imgId] = caption_list_sel

        res[imgId] = [{'caption': results[imgId]}]
        
    # tokenize
    print('tokenization...')
    tokenizer = PTBTokenizer()
    gts  = tokenizer.tokenize(gts)
    res = tokenizer.tokenize(res)
    
    # compute scores
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(),"METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
        (Spice(), "SPICE"),
        (usc_sim(), "USC_similarity"),  
        ]
    score_dict = {}
    for scorer, method in scorers:
        print('computing %s score...'%(scorer.method()))
        score, scores = scorer.compute_score(gts, res)
        if type(method) == list:
            for sc, scs, m in zip(score, scores, method):
                score_dict[m] = sc
        else:
            score_dict[method] = score
            
    return score_dict


In [25]:
def evaluate_results(
    test_img_features, 
    model,
    ref,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    # generate results
    print('Generating captions...')
    results = {}
    for n in range(len(test_img_features)):
        img_features = test_img_features[n]
        generated = generateCaption(
            model, 
            img_features,
            max_length,
            vocab_size,
            wordtoidx,
            idxtoword
        )
        results[n] = generated
        
    model_score = eval_model(ref, results)

    return model_score

### Cross validation

In [26]:
cnn_type = 'inception_v3'
encoder = CNNModel(pretrained=True)
encoder.to(device)

CNNModel(
  (model): Sequential(
    (0): BasicConv2d(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicConv2d(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): BasicConv2d(
      (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): BasicConv2d(
      (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1

In [27]:
def cross_validation(train_index, test_index, count):
    print('=' * 60)
    print(f'Split {count}:')
    print(f'Splitting data...')
    
    train_paths, test_paths = all_paths[train_index], all_paths[test_index]
    train_descriptions, test_descriptions = all_descriptions[train_index], all_descriptions[test_index]
    print(f'{len(train_paths)} images for training and {len(test_paths)} images for testing.')
    
    vocab = get_vocab(train_descriptions, word_count_threshold=10)
    idxtoword, wordtoidx = get_word_dict(vocab)
    vocab_size = get_vocab_size(idxtoword)
    embedding_dim = 500
    embedding_matrix = get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx) 

    print(f'Preparing dataloader...')
    train_img_features, test_img_features = get_train_test(encoder, train_paths, test_paths)

    train_loader = get_train_dataloader(
        train_descriptions, 
        train_img_features,
        wordtoidx,
        max_length,
        batch_size=1000
    )

    print(f'Training...')
    caption_model = train_model(
        train_loader,
        vocab_size,
        embedding_dim, 
        embedding_matrix,
        hidden_size=500
    )

    
    ref = captions[test_index]
    model_score = evaluate_results(
        test_img_features, 
        caption_model,
        ref,
        max_length,
        vocab_size,
        wordtoidx,
        idxtoword
    )
    
    return caption_model, model_score

In [28]:
cv = KFold(n_splits=5, random_state=123, shuffle=True)
cv = [(train_index, test_index) for train_index, test_index in cv.split(all_paths)]  

In [29]:
caption_model1, model_score1 = cross_validation(cv[0][0], cv[0][1], 1)    

Split 1:
Splitting data...
8332 images for training and 2084 images for testing.
There are 41660 captions


  0%|          | 0/8332 [00:00<?, ?it/s]

preprocessed words 2659 ==> 884
The vocabulary size is 885.
793 out of 885 words are found in the pre-trained matrix.
The size of embedding_matrix is (885, 500)
Preparing dataloader...


100%|██████████| 8332/8332 [03:31<00:00, 39.41it/s]
  0%|          | 5/2084 [00:00<00:51, 40.28it/s]


Generating set took: 0:03:31.42


100%|██████████| 2084/2084 [00:53<00:00, 39.30it/s]
  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.03
Training...


  2%|▏         | 1/60 [00:05<05:39,  5.75s/it]

7.616918828752306


  3%|▎         | 2/60 [00:11<05:33,  5.74s/it]

5.1310256852044


  5%|▌         | 3/60 [00:17<05:27,  5.75s/it]

4.627918879191081


  7%|▋         | 4/60 [00:23<05:22,  5.76s/it]

4.276014222039117


  8%|▊         | 5/60 [00:28<05:16,  5.75s/it]

3.8751528528001575


 10%|█         | 6/60 [00:34<05:10,  5.75s/it]

3.5285346772935657


 12%|█▏        | 7/60 [00:40<05:05,  5.76s/it]

3.27301197581821


 13%|█▎        | 8/60 [00:46<04:59,  5.76s/it]

3.089503394232856


 15%|█▌        | 9/60 [00:51<04:53,  5.75s/it]

2.956931538052029


 17%|█▋        | 10/60 [00:57<04:47,  5.76s/it]

2.8751759794023304


 18%|█▊        | 11/60 [01:03<04:41,  5.75s/it]

2.8071044286092124


 20%|██        | 12/60 [01:09<04:36,  5.75s/it]

2.7689571380615234


 22%|██▏       | 13/60 [01:14<04:30,  5.75s/it]

2.7263804276784263


 23%|██▎       | 14/60 [01:20<04:24,  5.74s/it]

2.695460663901435


 25%|██▌       | 15/60 [01:26<04:18,  5.74s/it]

2.6513435575697155


 27%|██▋       | 16/60 [01:31<04:12,  5.74s/it]

2.622993442747328


 28%|██▊       | 17/60 [01:37<04:07,  5.75s/it]

2.5840039253234863


 30%|███       | 18/60 [01:43<04:01,  5.75s/it]

2.5572635332743325


 32%|███▏      | 19/60 [01:49<03:57,  5.78s/it]

2.5172086556752524


 33%|███▎      | 20/60 [01:55<03:50,  5.77s/it]

2.4580078125


 35%|███▌      | 21/60 [02:00<03:45,  5.77s/it]

2.4312569300333657


 37%|███▋      | 22/60 [02:06<03:39,  5.78s/it]

2.4074583848317466


 38%|███▊      | 23/60 [02:12<03:33,  5.77s/it]

2.3515859180026584


 40%|████      | 24/60 [02:18<03:27,  5.76s/it]

2.3085551261901855


 42%|████▏     | 25/60 [02:23<03:21,  5.76s/it]

2.273088561164008


 43%|████▎     | 26/60 [02:29<03:15,  5.76s/it]

2.2368261019388833


 45%|████▌     | 27/60 [02:35<03:09,  5.76s/it]

2.2426627741919622


 47%|████▋     | 28/60 [02:41<03:04,  5.76s/it]

2.2255051136016846


 48%|████▊     | 29/60 [02:46<02:58,  5.75s/it]

2.1529294782214694


 50%|█████     | 30/60 [02:52<02:52,  5.75s/it]

2.118993043899536


 52%|█████▏    | 31/60 [02:58<02:46,  5.75s/it]

2.1009145047929554


 53%|█████▎    | 32/60 [03:04<02:40,  5.75s/it]

2.075723820262485


 55%|█████▌    | 33/60 [03:09<02:35,  5.76s/it]

2.0477795203526816


 57%|█████▋    | 34/60 [03:15<02:29,  5.76s/it]

2.031914326879713


 58%|█████▊    | 35/60 [03:21<02:24,  5.77s/it]

2.014588157335917


 60%|██████    | 36/60 [03:27<02:18,  5.76s/it]

1.9899974796507094


 62%|██████▏   | 37/60 [03:32<02:12,  5.76s/it]

1.9483275810877483


 63%|██████▎   | 38/60 [03:38<02:06,  5.75s/it]

1.9467533694373236


 65%|██████▌   | 39/60 [03:44<02:00,  5.74s/it]

1.9220922920438979


 67%|██████▋   | 40/60 [03:50<01:54,  5.74s/it]

1.9355811807844374


 68%|██████▊   | 41/60 [03:55<01:49,  5.75s/it]

1.8942553599675496


 70%|███████   | 42/60 [04:01<01:43,  5.74s/it]

1.8511459959877863


 72%|███████▏  | 43/60 [04:07<01:38,  5.78s/it]

1.8322514692942302


 73%|███████▎  | 44/60 [04:13<01:32,  5.77s/it]

1.8035007980134752


 75%|███████▌  | 45/60 [04:19<01:26,  5.76s/it]

1.7937702205446031


 77%|███████▋  | 46/60 [04:24<01:20,  5.75s/it]

1.8088755475150213


 78%|███████▊  | 47/60 [04:30<01:14,  5.75s/it]

1.7970403300391302


 80%|████████  | 48/60 [04:36<01:08,  5.75s/it]

1.7829824950959947


 82%|████████▏ | 49/60 [04:42<01:03,  5.75s/it]

1.755581922001309


 83%|████████▎ | 50/60 [04:47<00:57,  5.75s/it]

1.7476002640194364


 85%|████████▌ | 51/60 [04:53<00:51,  5.75s/it]

1.714683969815572


 87%|████████▋ | 52/60 [04:59<00:45,  5.75s/it]

1.7002938985824585


 88%|████████▊ | 53/60 [05:05<00:40,  5.74s/it]

1.689014858669705


 90%|█████████ | 54/60 [05:10<00:34,  5.74s/it]

1.6743806070751615


 92%|█████████▏| 55/60 [05:16<00:28,  5.74s/it]

1.6585018634796143


 93%|█████████▎| 56/60 [05:22<00:22,  5.75s/it]

1.6373550494511921


 95%|█████████▌| 57/60 [05:27<00:17,  5.74s/it]

1.6139874458312988


 97%|█████████▋| 58/60 [05:33<00:11,  5.74s/it]

1.5753997696770563


 98%|█████████▊| 59/60 [05:39<00:05,  5.75s/it]

1.5486021836598713


100%|██████████| 60/60 [05:45<00:00,  5.75s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

1.5596655739678278


  2%|▏         | 1/60 [00:05<05:38,  5.73s/it]

1.5685193406211004


  3%|▎         | 2/60 [00:11<05:32,  5.73s/it]

1.523201094733344


  5%|▌         | 3/60 [00:17<05:26,  5.73s/it]

1.4839669730928209


  7%|▋         | 4/60 [00:22<05:21,  5.73s/it]

1.4606261518266466


  8%|▊         | 5/60 [00:28<05:16,  5.76s/it]

1.444788442717658


 10%|█         | 6/60 [00:34<05:10,  5.75s/it]

1.434902098443773


 12%|█▏        | 7/60 [00:40<05:04,  5.75s/it]

1.427845385339525


 13%|█▎        | 8/60 [00:45<04:58,  5.74s/it]

1.4222661786609225


 15%|█▌        | 9/60 [00:51<04:53,  5.75s/it]

1.417781286769443


 17%|█▋        | 10/60 [00:57<04:47,  5.74s/it]

1.4139391051398382


 18%|█▊        | 11/60 [01:03<04:42,  5.76s/it]

1.4105088975694444


 20%|██        | 12/60 [01:09<04:38,  5.80s/it]

1.4073748456107245


 22%|██▏       | 13/60 [01:14<04:31,  5.78s/it]

1.4044678741031222


 23%|██▎       | 14/60 [01:20<04:25,  5.76s/it]

1.4017931487825182


 25%|██▌       | 15/60 [01:26<04:18,  5.75s/it]

1.399301356739468


 27%|██▋       | 16/60 [01:32<04:12,  5.75s/it]

1.3969778882132635


 28%|██▊       | 17/60 [01:37<04:06,  5.74s/it]

1.3948189814885457


 30%|███       | 18/60 [01:43<04:01,  5.74s/it]

1.392803841167026


 32%|███▏      | 19/60 [01:49<03:55,  5.74s/it]

1.3908914460076227


 33%|███▎      | 20/60 [01:54<03:49,  5.73s/it]

1.389080868826972


 35%|███▌      | 21/60 [02:00<03:43,  5.73s/it]

1.3873762951956854


 37%|███▋      | 22/60 [02:06<03:37,  5.73s/it]

1.3857605457305908


 38%|███▊      | 23/60 [02:12<03:32,  5.73s/it]

1.384209977255927


 40%|████      | 24/60 [02:17<03:26,  5.74s/it]

1.3827325767940946


 42%|████▏     | 25/60 [02:23<03:21,  5.75s/it]

1.3813016679551866


 43%|████▎     | 26/60 [02:29<03:15,  5.76s/it]

1.3799143499798245


 45%|████▌     | 27/60 [02:35<03:09,  5.75s/it]

1.3785984913508098


 47%|████▋     | 28/60 [02:40<03:03,  5.74s/it]

1.3773116403155856


 48%|████▊     | 29/60 [02:46<02:57,  5.74s/it]

1.3760629759894476


 50%|█████     | 30/60 [02:52<02:52,  5.73s/it]

1.3748479684193928


 52%|█████▏    | 31/60 [02:58<02:46,  5.73s/it]

1.3736676110161676


 53%|█████▎    | 32/60 [03:03<02:40,  5.73s/it]

1.3725189235475328


 55%|█████▌    | 33/60 [03:09<02:34,  5.73s/it]

1.371403111351861


 57%|█████▋    | 34/60 [03:15<02:29,  5.73s/it]

1.3702910873625014


 58%|█████▊    | 35/60 [03:21<02:23,  5.74s/it]

1.3692141771316528


 60%|██████    | 36/60 [03:26<02:17,  5.74s/it]

1.3681521283255682


 62%|██████▏   | 37/60 [03:32<02:12,  5.75s/it]

1.3671058019002278


 63%|██████▎   | 38/60 [03:38<02:06,  5.74s/it]

1.3660738335715399


 65%|██████▌   | 39/60 [03:44<02:01,  5.78s/it]

1.3650758531358507


 67%|██████▋   | 40/60 [03:49<01:55,  5.77s/it]

1.3640924559699164


 68%|██████▊   | 41/60 [03:55<01:49,  5.76s/it]

1.3631289137734308


 70%|███████   | 42/60 [04:01<01:43,  5.75s/it]

1.3621710273954604


 72%|███████▏  | 43/60 [04:07<01:37,  5.76s/it]

1.3612319893307157


 73%|███████▎  | 44/60 [04:12<01:32,  5.75s/it]

1.3603091637293498


 75%|███████▌  | 45/60 [04:18<01:26,  5.75s/it]

1.359404378467136


 77%|███████▋  | 46/60 [04:24<01:20,  5.74s/it]

1.3584948380788167


 78%|███████▊  | 47/60 [04:30<01:14,  5.74s/it]

1.35761382844713


 80%|████████  | 48/60 [04:35<01:08,  5.74s/it]

1.3567364348305597


 82%|████████▏ | 49/60 [04:41<01:03,  5.74s/it]

1.355862882402208


 83%|████████▎ | 50/60 [04:47<00:57,  5.74s/it]

1.3550088140699599


 85%|████████▌ | 51/60 [04:53<00:51,  5.74s/it]

1.3541631566153631


 87%|████████▋ | 52/60 [04:58<00:45,  5.74s/it]

1.3533132076263428


 88%|████████▊ | 53/60 [05:04<00:40,  5.74s/it]

1.3524827162424724


 90%|█████████ | 54/60 [05:10<00:34,  5.74s/it]

1.3516443967819214


 92%|█████████▏| 55/60 [05:15<00:28,  5.73s/it]

1.3508118391036987


 93%|█████████▎| 56/60 [05:21<00:22,  5.74s/it]

1.3499872816933527


 95%|█████████▌| 57/60 [05:27<00:17,  5.74s/it]

1.3491695324579875


 97%|█████████▋| 58/60 [05:33<00:11,  5.74s/it]

1.3483563396665785


 98%|█████████▊| 59/60 [05:38<00:05,  5.74s/it]

1.347549729877048


100%|██████████| 60/60 [05:44<00:00,  5.75s/it]

1.3467535045411851
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [30]:
model_score1

{'Bleu_1': 0.544703279616429,
 'Bleu_2': 0.4134876658158496,
 'Bleu_3': 0.32693856576369873,
 'Bleu_4': 0.2657789178164309,
 'METEOR': 0.2273272223806558,
 'ROUGE_L': 0.4559596838533356,
 'CIDEr': 1.3522398368352868,
 'SPICE': 0.28368736929451943,
 'USC_similarity': 0.5264723999027141}

In [31]:
caption_model2, model_score2 = cross_validation(cv[1][0], cv[1][1], 2)    

Split 2:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions


  0%|          | 4/8333 [00:00<03:41, 37.61it/s]

preprocessed words 2688 ==> 916
The vocabulary size is 917.
819 out of 917 words are found in the pre-trained matrix.
The size of embedding_matrix is (917, 500)
Preparing dataloader...


100%|██████████| 8333/8333 [03:34<00:00, 38.90it/s]
  0%|          | 4/2083 [00:00<00:52, 39.67it/s]


Generating set took: 0:03:34.23


100%|██████████| 2083/2083 [00:53<00:00, 38.69it/s]
  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.84
Training...


  2%|▏         | 1/60 [00:05<05:41,  5.79s/it]

8.731728924645317


  3%|▎         | 2/60 [00:11<05:35,  5.78s/it]

5.019426080915663


  5%|▌         | 3/60 [00:17<05:29,  5.78s/it]

4.617237038082546


  7%|▋         | 4/60 [00:23<05:23,  5.77s/it]

4.047823323143853


  8%|▊         | 5/60 [00:28<05:17,  5.77s/it]

3.541182518005371


 10%|█         | 6/60 [00:34<05:12,  5.79s/it]

3.189673105875651


 12%|█▏        | 7/60 [00:40<05:06,  5.79s/it]

2.982234107123481


 13%|█▎        | 8/60 [00:46<05:00,  5.78s/it]

2.8356569608052573


 15%|█▌        | 9/60 [00:51<04:54,  5.77s/it]

2.741465436087714


 17%|█▋        | 10/60 [00:57<04:48,  5.77s/it]

2.6863244904412165


 18%|█▊        | 11/60 [01:03<04:42,  5.77s/it]

2.612360874811808


 20%|██        | 12/60 [01:09<04:37,  5.77s/it]

2.5729580455356174


 22%|██▏       | 13/60 [01:15<04:32,  5.80s/it]

2.505357344945272


 23%|██▎       | 14/60 [01:20<04:26,  5.80s/it]

2.4621348645952015


 25%|██▌       | 15/60 [01:26<04:20,  5.79s/it]

2.411519686381022


 27%|██▋       | 16/60 [01:32<04:14,  5.79s/it]

2.3725251886579724


 28%|██▊       | 17/60 [01:38<04:08,  5.78s/it]

2.353493054707845


 30%|███       | 18/60 [01:44<04:03,  5.79s/it]

2.2979845735761852


 32%|███▏      | 19/60 [01:49<03:56,  5.78s/it]

2.247919718424479


 33%|███▎      | 20/60 [01:55<03:51,  5.78s/it]

2.20911349190606


 35%|███▌      | 21/60 [02:01<03:45,  5.79s/it]

2.1845916112264


 37%|███▋      | 22/60 [02:07<03:40,  5.80s/it]

2.152595387564765


 38%|███▊      | 23/60 [02:13<03:34,  5.79s/it]

2.1272957722345986


 40%|████      | 24/60 [02:18<03:28,  5.79s/it]

2.0937360657585993


 42%|████▏     | 25/60 [02:24<03:22,  5.79s/it]

2.070405787891812


 43%|████▎     | 26/60 [02:30<03:16,  5.78s/it]

2.0431081453959146


 45%|████▌     | 27/60 [02:36<03:10,  5.77s/it]

1.9736746814515855


 47%|████▋     | 28/60 [02:41<03:04,  5.77s/it]

1.9427100287543402


 48%|████▊     | 29/60 [02:47<02:58,  5.77s/it]

1.908035847875807


 50%|█████     | 30/60 [02:53<02:54,  5.83s/it]

1.8928749296400282


 52%|█████▏    | 31/60 [02:59<02:48,  5.82s/it]

1.8384249607721965


 53%|█████▎    | 32/60 [03:05<02:42,  5.80s/it]

1.7903878556357489


 55%|█████▌    | 33/60 [03:10<02:36,  5.79s/it]

1.7761812210083008


 57%|█████▋    | 34/60 [03:16<02:30,  5.78s/it]

1.7528272999657526


 58%|█████▊    | 35/60 [03:22<02:24,  5.79s/it]

1.7213465770085652


 60%|██████    | 36/60 [03:28<02:18,  5.79s/it]

1.7022945218616061


 62%|██████▏   | 37/60 [03:34<02:13,  5.80s/it]

1.6747280756632488


 63%|██████▎   | 38/60 [03:39<02:07,  5.81s/it]

1.680275559425354


 65%|██████▌   | 39/60 [03:45<02:01,  5.79s/it]

1.6623714632458158


 67%|██████▋   | 40/60 [03:51<01:55,  5.79s/it]

1.6628319422403972


 68%|██████▊   | 41/60 [03:57<01:49,  5.79s/it]

1.6638477113511827


 70%|███████   | 42/60 [04:03<01:44,  5.79s/it]

1.6432699229982164


 72%|███████▏  | 43/60 [04:08<01:38,  5.78s/it]

1.611459771792094


 73%|███████▎  | 44/60 [04:14<01:32,  5.78s/it]

1.5777140855789185


 75%|███████▌  | 45/60 [04:20<01:26,  5.78s/it]

1.5429312917921278


 77%|███████▋  | 46/60 [04:26<01:20,  5.78s/it]

1.5121707783804998


 78%|███████▊  | 47/60 [04:31<01:15,  5.77s/it]

1.4684166378445096


 80%|████████  | 48/60 [04:37<01:09,  5.77s/it]

1.4349995851516724


 82%|████████▏ | 49/60 [04:43<01:03,  5.77s/it]

1.4354075060950384


 83%|████████▎ | 50/60 [04:49<00:57,  5.77s/it]

1.42782195409139


 85%|████████▌ | 51/60 [04:54<00:51,  5.77s/it]

1.4018570449617174


 87%|████████▋ | 52/60 [05:00<00:46,  5.76s/it]

1.384338332547082


 88%|████████▊ | 53/60 [05:06<00:40,  5.76s/it]

1.3548463185628254


 90%|█████████ | 54/60 [05:12<00:34,  5.77s/it]

1.3423299259609647


 92%|█████████▏| 55/60 [05:18<00:28,  5.77s/it]

1.3415475289026897


 93%|█████████▎| 56/60 [05:23<00:23,  5.77s/it]

1.3539606134096782


 95%|█████████▌| 57/60 [05:29<00:17,  5.77s/it]

1.3354855179786682


 97%|█████████▋| 58/60 [05:35<00:11,  5.77s/it]

1.3301742805374994


 98%|█████████▊| 59/60 [05:41<00:05,  5.77s/it]

1.3060190743870206


100%|██████████| 60/60 [05:46<00:00,  5.78s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

1.270358231332567


  2%|▏         | 1/60 [00:05<05:39,  5.75s/it]

1.2598576479487948


  3%|▎         | 2/60 [00:11<05:34,  5.77s/it]

1.21578930483924


  5%|▌         | 3/60 [00:17<05:28,  5.76s/it]

1.1822808583577473


  7%|▋         | 4/60 [00:23<05:22,  5.76s/it]

1.1638949513435364


  8%|▊         | 5/60 [00:28<05:17,  5.77s/it]

1.1525346835454304


 10%|█         | 6/60 [00:34<05:11,  5.77s/it]

1.1444486843215094


 12%|█▏        | 7/60 [00:40<05:05,  5.77s/it]

1.1382370856073167


 13%|█▎        | 8/60 [00:46<05:00,  5.78s/it]

1.1332520378960504


 15%|█▌        | 9/60 [00:51<04:54,  5.77s/it]

1.1290757921006944


 17%|█▋        | 10/60 [00:57<04:48,  5.77s/it]

1.1254559755325317


 18%|█▊        | 11/60 [01:03<04:42,  5.77s/it]

1.1222733590337965


 20%|██        | 12/60 [01:09<04:39,  5.81s/it]

1.1194592316945393


 22%|██▏       | 13/60 [01:15<04:32,  5.80s/it]

1.1169043249554105


 23%|██▎       | 14/60 [01:20<04:26,  5.79s/it]

1.1145877639452617


 25%|██▌       | 15/60 [01:26<04:20,  5.79s/it]

1.112426883644528


 27%|██▋       | 16/60 [01:32<04:14,  5.79s/it]

1.1104293134477403


 28%|██▊       | 17/60 [01:38<04:08,  5.78s/it]

1.1085530254575942


 30%|███       | 18/60 [01:44<04:02,  5.77s/it]

1.1068191793229845


 32%|███▏      | 19/60 [01:49<03:56,  5.78s/it]

1.1051403946346707


 33%|███▎      | 20/60 [01:55<03:50,  5.77s/it]

1.1035483479499817


 35%|███▌      | 21/60 [02:01<03:45,  5.77s/it]

1.102037747701009


 37%|███▋      | 22/60 [02:07<03:39,  5.77s/it]

1.1005658970938788


 38%|███▊      | 23/60 [02:12<03:33,  5.77s/it]

1.0991424255900912


 40%|████      | 24/60 [02:18<03:27,  5.77s/it]

1.097791764471266


 42%|████▏     | 25/60 [02:24<03:21,  5.77s/it]

1.0964896414015028


 43%|████▎     | 26/60 [02:30<03:16,  5.77s/it]

1.0952074726422627


 45%|████▌     | 27/60 [02:35<03:10,  5.78s/it]

1.0939820210138957


 47%|████▋     | 28/60 [02:41<03:04,  5.78s/it]

1.0927881399790447


 48%|████▊     | 29/60 [02:47<02:59,  5.78s/it]

1.0916417439778645


 50%|█████     | 30/60 [02:53<02:53,  5.77s/it]

1.0905176003774006


 52%|█████▏    | 31/60 [02:59<02:47,  5.77s/it]

1.0894232326083713


 53%|█████▎    | 32/60 [03:04<02:41,  5.77s/it]

1.0883565081490412


 55%|█████▌    | 33/60 [03:10<02:35,  5.77s/it]

1.0873073538144429


 57%|█████▋    | 34/60 [03:16<02:29,  5.77s/it]

1.086280345916748


 58%|█████▊    | 35/60 [03:22<02:24,  5.76s/it]

1.0852826436360676


 60%|██████    | 36/60 [03:27<02:18,  5.77s/it]

1.0842978888087802


 62%|██████▏   | 37/60 [03:33<02:12,  5.77s/it]

1.0833292802174885


 63%|██████▎   | 38/60 [03:39<02:07,  5.78s/it]

1.08237800333235


 65%|██████▌   | 39/60 [03:45<02:01,  5.78s/it]

1.0814358327123854


 67%|██████▋   | 40/60 [03:51<01:55,  5.77s/it]

1.080518815252516


 68%|██████▊   | 41/60 [03:56<01:49,  5.78s/it]

1.0796053740713332


 70%|███████   | 42/60 [04:02<01:43,  5.78s/it]

1.0787017610337999


 72%|███████▏  | 43/60 [04:08<01:38,  5.77s/it]

1.0778236322932773


 73%|███████▎  | 44/60 [04:14<01:32,  5.77s/it]

1.0769474771287706


 75%|███████▌  | 45/60 [04:19<01:26,  5.77s/it]

1.0760851966010199


 77%|███████▋  | 46/60 [04:25<01:20,  5.78s/it]

1.075235452916887


 78%|███████▊  | 47/60 [04:31<01:15,  5.82s/it]

1.074388649728563


 80%|████████  | 48/60 [04:37<01:09,  5.80s/it]

1.0735509461826749


 82%|████████▏ | 49/60 [04:43<01:03,  5.79s/it]

1.0727240708139207


 83%|████████▎ | 50/60 [04:48<00:57,  5.78s/it]

1.0719067056973774


 85%|████████▌ | 51/60 [04:54<00:51,  5.78s/it]

1.071089850531684


 87%|████████▋ | 52/60 [05:00<00:46,  5.77s/it]

1.0702857573827107


 88%|████████▊ | 53/60 [05:06<00:40,  5.78s/it]

1.069490995672014


 90%|█████████ | 54/60 [05:11<00:34,  5.77s/it]

1.0686909821298387


 92%|█████████▏| 55/60 [05:17<00:28,  5.77s/it]

1.0679067108366225


 93%|█████████▎| 56/60 [05:23<00:23,  5.77s/it]

1.0671279033025105


 95%|█████████▌| 57/60 [05:29<00:17,  5.76s/it]

1.0663485924402873


 97%|█████████▋| 58/60 [05:35<00:11,  5.77s/it]

1.0655805932150946


 98%|█████████▊| 59/60 [05:40<00:05,  5.77s/it]

1.0648158921135797


100%|██████████| 60/60 [05:46<00:00,  5.78s/it]

1.0640634960598416
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [32]:
model_score2

{'Bleu_1': 0.5497447219538867,
 'Bleu_2': 0.4210379521981906,
 'Bleu_3': 0.33840349424084093,
 'Bleu_4': 0.2798526176173945,
 'METEOR': 0.24107196500751146,
 'ROUGE_L': 0.46510956222385835,
 'CIDEr': 1.501778968876335,
 'SPICE': 0.3124203594280626,
 'USC_similarity': 0.5471561965057327}

In [33]:
caption_model3, model_score3 = cross_validation(cv[2][0], cv[2][1], 3)    

Split 3:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions


  0%|          | 4/8333 [00:00<03:38, 38.10it/s]

preprocessed words 2714 ==> 890
The vocabulary size is 891.
800 out of 891 words are found in the pre-trained matrix.
The size of embedding_matrix is (891, 500)
Preparing dataloader...


100%|██████████| 8333/8333 [03:32<00:00, 39.16it/s]
  0%|          | 4/2083 [00:00<00:52, 39.49it/s]


Generating set took: 0:03:32.79


100%|██████████| 2083/2083 [00:53<00:00, 38.64it/s]
  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.91
Training...


  2%|▏         | 1/60 [00:05<05:41,  5.79s/it]

8.162103282080757


  3%|▎         | 2/60 [00:11<05:35,  5.79s/it]

4.942037794325087


  5%|▌         | 3/60 [00:17<05:30,  5.79s/it]

4.373266167110867


  7%|▋         | 4/60 [00:23<05:24,  5.79s/it]

3.6144460572136774


  8%|▊         | 5/60 [00:28<05:18,  5.78s/it]

3.0630548265245228


 10%|█         | 6/60 [00:34<05:12,  5.78s/it]

2.744372102949354


 12%|█▏        | 7/60 [00:40<05:06,  5.78s/it]

2.529179334640503


 13%|█▎        | 8/60 [00:46<05:00,  5.78s/it]

2.375764396455553


 15%|█▌        | 9/60 [00:52<04:54,  5.77s/it]

2.245656622780694


 17%|█▋        | 10/60 [00:57<04:48,  5.77s/it]

2.1475389268663196


 18%|█▊        | 11/60 [01:03<04:42,  5.77s/it]

2.065673444006178


 20%|██        | 12/60 [01:09<04:36,  5.77s/it]

1.9874887731340196


 22%|██▏       | 13/60 [01:15<04:31,  5.77s/it]

1.9061799711651273


 23%|██▎       | 14/60 [01:20<04:25,  5.78s/it]

1.8271953927146063


 25%|██▌       | 15/60 [01:26<04:20,  5.79s/it]

1.774930066532559


 27%|██▋       | 16/60 [01:32<04:14,  5.79s/it]

1.7260530127419367


 28%|██▊       | 17/60 [01:38<04:08,  5.78s/it]

1.6895092858208551


 30%|███       | 18/60 [01:44<04:02,  5.78s/it]

1.6443317068947687


 32%|███▏      | 19/60 [01:49<03:56,  5.77s/it]

1.6107610993915134


 33%|███▎      | 20/60 [01:55<03:50,  5.77s/it]

1.6005475785997179


 35%|███▌      | 21/60 [02:01<03:45,  5.77s/it]

1.5228671232859294


 37%|███▋      | 22/60 [02:07<03:39,  5.78s/it]

1.480649709701538


 38%|███▊      | 23/60 [02:12<03:34,  5.79s/it]

1.4225087430742052


 40%|████      | 24/60 [02:18<03:28,  5.79s/it]

1.3721070422066584


 42%|████▏     | 25/60 [02:24<03:22,  5.79s/it]

1.3322096665700276


 43%|████▎     | 26/60 [02:30<03:17,  5.80s/it]

1.310085071457757


 45%|████▌     | 27/60 [02:36<03:11,  5.79s/it]

1.2995417250527277


 47%|████▋     | 28/60 [02:41<03:05,  5.79s/it]

1.3177783224317763


 48%|████▊     | 29/60 [02:47<02:59,  5.79s/it]

1.3314525683720906


 50%|█████     | 30/60 [02:53<02:53,  5.78s/it]

1.267541805903117


 52%|█████▏    | 31/60 [02:59<02:47,  5.79s/it]

1.2104799018965826


 53%|█████▎    | 32/60 [03:05<02:42,  5.80s/it]

1.1850617660416498


 55%|█████▌    | 33/60 [03:10<02:36,  5.80s/it]

1.1776865786976285


 57%|█████▋    | 34/60 [03:16<02:30,  5.79s/it]

1.1776213645935059


 58%|█████▊    | 35/60 [03:22<02:24,  5.79s/it]

1.1550910075505574


 60%|██████    | 36/60 [03:28<02:20,  5.85s/it]

1.1521736118528578


 62%|██████▏   | 37/60 [03:34<02:14,  5.83s/it]

1.1530582639906142


 63%|██████▎   | 38/60 [03:40<02:08,  5.82s/it]

1.1242805851830378


 65%|██████▌   | 39/60 [03:45<02:02,  5.81s/it]

1.0908565653695002


 67%|██████▋   | 40/60 [03:51<01:56,  5.80s/it]

1.068896969159444


 68%|██████▊   | 41/60 [03:57<01:50,  5.79s/it]

1.070668637752533


 70%|███████   | 42/60 [04:03<01:44,  5.79s/it]

1.1266917785008748


 72%|███████▏  | 43/60 [04:08<01:38,  5.79s/it]

1.082028289635976


 73%|███████▎  | 44/60 [04:14<01:32,  5.79s/it]

1.0728533930248685


 75%|███████▌  | 45/60 [04:20<01:26,  5.78s/it]

1.0715731117460463


 77%|███████▋  | 46/60 [04:26<01:20,  5.78s/it]

1.0634859667883978


 78%|███████▊  | 47/60 [04:32<01:15,  5.78s/it]

1.076148165596856


 80%|████████  | 48/60 [04:37<01:09,  5.78s/it]

1.092936482694414


 82%|████████▏ | 49/60 [04:43<01:03,  5.78s/it]

1.0986499786376953


 83%|████████▎ | 50/60 [04:49<00:57,  5.78s/it]

1.0678069856431749


 85%|████████▌ | 51/60 [04:55<00:52,  5.78s/it]

1.0237783855862088


 87%|████████▋ | 52/60 [05:00<00:46,  5.77s/it]

0.9765933553377787


 88%|████████▊ | 53/60 [05:06<00:40,  5.77s/it]

0.9368891848458184


 90%|█████████ | 54/60 [05:12<00:34,  5.77s/it]

0.9208787083625793


 92%|█████████▏| 55/60 [05:18<00:28,  5.77s/it]

0.9224594235420227


 93%|█████████▎| 56/60 [05:24<00:23,  5.79s/it]

0.9498744342062209


 95%|█████████▌| 57/60 [05:29<00:17,  5.79s/it]

0.9853618211216397


 97%|█████████▋| 58/60 [05:35<00:11,  5.79s/it]

1.0186885330412123


 98%|█████████▊| 59/60 [05:41<00:05,  5.78s/it]

0.9336071279313829


100%|██████████| 60/60 [05:47<00:00,  5.79s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.8796227243211534


  2%|▏         | 1/60 [00:05<05:40,  5.77s/it]

0.847172756989797


  3%|▎         | 2/60 [00:11<05:35,  5.78s/it]

0.8110146522521973


  5%|▌         | 3/60 [00:17<05:31,  5.82s/it]

0.7804104685783386


  7%|▋         | 4/60 [00:23<05:25,  5.80s/it]

0.7616623242696127


  8%|▊         | 5/60 [00:29<05:18,  5.79s/it]

0.7499747806125217


 10%|█         | 6/60 [00:34<05:12,  5.78s/it]

0.7420084608925713


 12%|█▏        | 7/60 [00:40<05:05,  5.77s/it]

0.7361733648512099


 13%|█▎        | 8/60 [00:46<04:59,  5.76s/it]

0.7316266695658366


 15%|█▌        | 9/60 [00:52<04:54,  5.77s/it]

0.727887299325731


 17%|█▋        | 10/60 [00:57<04:48,  5.78s/it]

0.7246722314092848


 18%|█▊        | 11/60 [01:03<04:42,  5.77s/it]

0.721819367673662


 20%|██        | 12/60 [01:09<04:36,  5.76s/it]

0.7192395461930169


 22%|██▏       | 13/60 [01:15<04:30,  5.76s/it]

0.7168840302361382


 23%|██▎       | 14/60 [01:20<04:24,  5.76s/it]

0.7147264016999139


 25%|██▌       | 15/60 [01:26<04:19,  5.77s/it]

0.7127281692292955


 27%|██▋       | 16/60 [01:32<04:13,  5.77s/it]

0.7108527686860826


 28%|██▊       | 17/60 [01:38<04:08,  5.78s/it]

0.7090885506735908


 30%|███       | 18/60 [01:44<04:02,  5.78s/it]

0.7074154085583158


 32%|███▏      | 19/60 [01:49<03:56,  5.78s/it]

0.7058253023359511


 33%|███▎      | 20/60 [01:55<03:50,  5.77s/it]

0.7043048540751139


 35%|███▌      | 21/60 [02:01<03:45,  5.77s/it]

0.7028446992238363


 37%|███▋      | 22/60 [02:07<03:39,  5.78s/it]

0.7014388574494256


 38%|███▊      | 23/60 [02:12<03:34,  5.79s/it]

0.7000819378428988


 40%|████      | 24/60 [02:18<03:28,  5.78s/it]

0.6987691455417209


 42%|████▏     | 25/60 [02:24<03:22,  5.77s/it]

0.6974996957514021


 43%|████▎     | 26/60 [02:30<03:16,  5.77s/it]

0.696269260512458


 45%|████▌     | 27/60 [02:36<03:10,  5.78s/it]

0.6950738297568427


 47%|████▋     | 28/60 [02:41<03:04,  5.78s/it]

0.6939101616541544


 48%|████▊     | 29/60 [02:47<02:58,  5.77s/it]

0.692777693271637


 50%|█████     | 30/60 [02:53<02:53,  5.77s/it]

0.691673692729738


 52%|█████▏    | 31/60 [02:59<02:47,  5.77s/it]

0.6905968421035342


 53%|█████▎    | 32/60 [03:04<02:41,  5.77s/it]

0.6895419855912527


 55%|█████▌    | 33/60 [03:10<02:35,  5.77s/it]

0.6885055303573608


 57%|█████▋    | 34/60 [03:16<02:30,  5.78s/it]

0.6874894797801971


 58%|█████▊    | 35/60 [03:22<02:24,  5.78s/it]

0.686494208044476


 60%|██████    | 36/60 [03:27<02:18,  5.77s/it]

0.6855152183108859


 62%|██████▏   | 37/60 [03:33<02:12,  5.78s/it]

0.6845564610428281


 63%|██████▎   | 38/60 [03:39<02:06,  5.77s/it]

0.6836152838336097


 65%|██████▌   | 39/60 [03:45<02:01,  5.77s/it]

0.6826865805519952


 67%|██████▋   | 40/60 [03:51<01:55,  5.77s/it]

0.681768579615487


 68%|██████▊   | 41/60 [03:56<01:49,  5.77s/it]

0.6808668076992035


 70%|███████   | 42/60 [04:02<01:43,  5.77s/it]

0.6799704730510712


 72%|███████▏  | 43/60 [04:08<01:38,  5.77s/it]

0.6790923145082262


 73%|███████▎  | 44/60 [04:14<01:32,  5.76s/it]

0.6782253450817533


 75%|███████▌  | 45/60 [04:19<01:26,  5.76s/it]

0.6773644288380941


 77%|███████▋  | 46/60 [04:25<01:20,  5.76s/it]

0.6765193707413144


 78%|███████▊  | 47/60 [04:31<01:14,  5.76s/it]

0.6756830414136251


 80%|████████  | 48/60 [04:37<01:09,  5.76s/it]

0.6748608284526401


 82%|████████▏ | 49/60 [04:42<01:03,  5.76s/it]

0.6740391651789347


 83%|████████▎ | 50/60 [04:48<00:57,  5.76s/it]

0.6732312473985884


 85%|████████▌ | 51/60 [04:54<00:51,  5.76s/it]

0.6724366082085503


 87%|████████▋ | 52/60 [05:00<00:46,  5.77s/it]

0.6716434955596924


 88%|████████▊ | 53/60 [05:05<00:40,  5.76s/it]

0.6708580917782254


 90%|█████████ | 54/60 [05:11<00:34,  5.77s/it]

0.6700661248630948


 92%|█████████▏| 55/60 [05:17<00:28,  5.76s/it]

0.669298940234714


 93%|█████████▎| 56/60 [05:23<00:23,  5.76s/it]

0.6685356563991971


 95%|█████████▌| 57/60 [05:29<00:17,  5.76s/it]

0.6677813529968262


 97%|█████████▋| 58/60 [05:34<00:11,  5.77s/it]

0.6670402487119039


 98%|█████████▊| 59/60 [05:40<00:05,  5.77s/it]

0.6662996345096164


100%|██████████| 60/60 [05:46<00:00,  5.77s/it]

0.6655637356970046
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [34]:
model_score3

{'Bleu_1': 0.5700041165439066,
 'Bleu_2': 0.4369369710710409,
 'Bleu_3': 0.3533912607678642,
 'Bleu_4': 0.29594824649102636,
 'METEOR': 0.2555919304348583,
 'ROUGE_L': 0.4801735354230924,
 'CIDEr': 1.6275426688343462,
 'SPICE': 0.32982691926523755,
 'USC_similarity': 0.5577943182805668}

In [35]:
caption_model4, model_score4 = cross_validation(cv[3][0], cv[3][1], 4)    

Split 4:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions


  0%|          | 4/8333 [00:00<03:42, 37.41it/s]

preprocessed words 2680 ==> 894
The vocabulary size is 895.
806 out of 895 words are found in the pre-trained matrix.
The size of embedding_matrix is (895, 500)
Preparing dataloader...


100%|██████████| 8333/8333 [03:32<00:00, 39.22it/s]
  0%|          | 4/2083 [00:00<00:52, 39.87it/s]


Generating set took: 0:03:32.45


100%|██████████| 2083/2083 [00:53<00:00, 38.87it/s]
  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.59
Training...


  2%|▏         | 1/60 [00:05<05:41,  5.78s/it]

8.461740069919163


  3%|▎         | 2/60 [00:11<05:34,  5.77s/it]

4.91348049375746


  5%|▌         | 3/60 [00:17<05:29,  5.77s/it]

3.848512199189928


  7%|▋         | 4/60 [00:23<05:23,  5.78s/it]

3.0519198841518826


  8%|▊         | 5/60 [00:28<05:17,  5.77s/it]

2.652053621080187


 10%|█         | 6/60 [00:34<05:11,  5.77s/it]

2.4052148395114474


 12%|█▏        | 7/60 [00:40<05:06,  5.77s/it]

2.206917021009657


 13%|█▎        | 8/60 [00:46<05:00,  5.77s/it]

2.037035279803806


 15%|█▌        | 9/60 [00:51<04:54,  5.77s/it]

1.8913064267900255


 17%|█▋        | 10/60 [00:57<04:48,  5.78s/it]

1.7818907499313354


 18%|█▊        | 11/60 [01:03<04:43,  5.79s/it]

1.70110195212894


 20%|██        | 12/60 [01:09<04:37,  5.78s/it]

1.6460962030622694


 22%|██▏       | 13/60 [01:15<04:31,  5.78s/it]

1.5687317185931735


 23%|██▎       | 14/60 [01:20<04:26,  5.78s/it]

1.5029477808210585


 25%|██▌       | 15/60 [01:26<04:20,  5.78s/it]

1.442217191060384


 27%|██▋       | 16/60 [01:32<04:14,  5.79s/it]

1.3669826189676921


 28%|██▊       | 17/60 [01:38<04:08,  5.79s/it]

1.3065085344844394


 30%|███       | 18/60 [01:44<04:03,  5.79s/it]

1.275232387913598


 32%|███▏      | 19/60 [01:49<03:57,  5.79s/it]

1.2851889928181965


 33%|███▎      | 20/60 [01:55<03:51,  5.78s/it]

1.2749559349483914


 35%|███▌      | 21/60 [02:01<03:46,  5.80s/it]

1.209752897421519


 37%|███▋      | 22/60 [02:07<03:40,  5.80s/it]

1.1735732754071553


 38%|███▊      | 23/60 [02:13<03:34,  5.80s/it]

1.1086836391025119


 40%|████      | 24/60 [02:18<03:28,  5.80s/it]

1.0547824369536505


 42%|████▏     | 25/60 [02:24<03:22,  5.79s/it]

1.0295558306905959


 43%|████▎     | 26/60 [02:30<03:16,  5.79s/it]

1.0172764195336237


 45%|████▌     | 27/60 [02:36<03:10,  5.79s/it]

0.9906156924035814


 47%|████▋     | 28/60 [02:41<03:05,  5.78s/it]

0.9621729387177361


 48%|████▊     | 29/60 [02:47<02:59,  5.78s/it]

0.9528205063607957


 50%|█████     | 30/60 [02:53<02:53,  5.77s/it]

0.9094448222054375


 52%|█████▏    | 31/60 [02:59<02:47,  5.77s/it]

0.8749817808469137


 53%|█████▎    | 32/60 [03:05<02:41,  5.78s/it]

0.8358248472213745


 55%|█████▌    | 33/60 [03:10<02:35,  5.77s/it]

0.824052717950609


 57%|█████▋    | 34/60 [03:16<02:30,  5.77s/it]

0.8313903874821134


 58%|█████▊    | 35/60 [03:22<02:24,  5.77s/it]

0.8521653479999967


 60%|██████    | 36/60 [03:28<02:18,  5.79s/it]

0.8788441353374057


 62%|██████▏   | 37/60 [03:33<02:13,  5.79s/it]

0.9005804128117032


 63%|██████▎   | 38/60 [03:39<02:07,  5.78s/it]

0.8697399629486932


 65%|██████▌   | 39/60 [03:45<02:01,  5.78s/it]

0.8358429935243394


 67%|██████▋   | 40/60 [03:51<01:55,  5.79s/it]

0.7948687672615051


 68%|██████▊   | 41/60 [03:57<01:50,  5.79s/it]

0.7696060405837165


 70%|███████   | 42/60 [04:02<01:44,  5.79s/it]

0.7651933895217048


 72%|███████▏  | 43/60 [04:08<01:38,  5.79s/it]

0.7630234956741333


 73%|███████▎  | 44/60 [04:14<01:33,  5.87s/it]

0.7727916571829054


 75%|███████▌  | 45/60 [04:20<01:27,  5.85s/it]

0.7711790468957689


 77%|███████▋  | 46/60 [04:26<01:21,  5.82s/it]

0.7872408363554213


 78%|███████▊  | 47/60 [04:32<01:15,  5.81s/it]

0.7767393853929307


 80%|████████  | 48/60 [04:37<01:09,  5.80s/it]

0.7875312831666734


 82%|████████▏ | 49/60 [04:43<01:03,  5.80s/it]

0.8303369349903531


 83%|████████▎ | 50/60 [04:49<00:57,  5.79s/it]

0.8357981509632535


 85%|████████▌ | 51/60 [04:55<00:52,  5.79s/it]

0.7774167524443732


 87%|████████▋ | 52/60 [05:01<00:46,  5.80s/it]

0.7219082448217604


 88%|████████▊ | 53/60 [05:06<00:40,  5.79s/it]

0.688430494732327


 90%|█████████ | 54/60 [05:12<00:34,  5.79s/it]

0.6487686501608955


 92%|█████████▏| 55/60 [05:18<00:28,  5.79s/it]

0.6513824363549551


 93%|█████████▎| 56/60 [05:24<00:23,  5.78s/it]

0.6619447204801772


 95%|█████████▌| 57/60 [05:29<00:17,  5.78s/it]

0.6905623740620084


 97%|█████████▋| 58/60 [05:35<00:11,  5.78s/it]

0.7376706600189209


 98%|█████████▊| 59/60 [05:41<00:05,  5.79s/it]

0.7648661997583177


100%|██████████| 60/60 [05:47<00:00,  5.79s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.7215563986036513


  2%|▏         | 1/60 [00:05<05:43,  5.81s/it]

0.6706432965066698


  3%|▎         | 2/60 [00:11<05:36,  5.80s/it]

0.6407448781861199


  5%|▌         | 3/60 [00:17<05:30,  5.80s/it]

0.6128365331225925


  7%|▋         | 4/60 [00:23<05:24,  5.80s/it]

0.5953060852156745


  8%|▊         | 5/60 [00:28<05:18,  5.79s/it]

0.5841803683174981


 10%|█         | 6/60 [00:34<05:12,  5.78s/it]

0.5762490232785543


 12%|█▏        | 7/60 [00:40<05:07,  5.80s/it]

0.5700464944044749


 13%|█▎        | 8/60 [00:46<05:01,  5.79s/it]

0.5649613638718923


 15%|█▌        | 9/60 [00:52<04:55,  5.79s/it]

0.560698515839047


 17%|█▋        | 10/60 [00:57<04:49,  5.80s/it]

0.5570427179336548


 18%|█▊        | 11/60 [01:03<04:44,  5.80s/it]

0.5538577006922828


 20%|██        | 12/60 [01:09<04:38,  5.79s/it]

0.5510340399212308


 22%|██▏       | 13/60 [01:15<04:32,  5.79s/it]

0.5485091540548537


 23%|██▎       | 14/60 [01:21<04:26,  5.79s/it]

0.5462375150786506


 25%|██▌       | 15/60 [01:26<04:20,  5.79s/it]

0.544171412785848


 27%|██▋       | 16/60 [01:32<04:14,  5.79s/it]

0.5422561599148644


 28%|██▊       | 17/60 [01:38<04:08,  5.78s/it]

0.5404734743965997


 30%|███       | 18/60 [01:44<04:02,  5.78s/it]

0.5388089484638638


 32%|███▏      | 19/60 [01:49<03:56,  5.78s/it]

0.5372433331277635


 33%|███▎      | 20/60 [01:55<03:51,  5.78s/it]

0.5357583694987826


 35%|███▌      | 21/60 [02:01<03:45,  5.78s/it]

0.5343494481510587


 37%|███▋      | 22/60 [02:07<03:39,  5.78s/it]

0.5330100556214651


 38%|███▊      | 23/60 [02:13<03:34,  5.79s/it]

0.5317305591371324


 40%|████      | 24/60 [02:18<03:28,  5.80s/it]

0.5305043425824907


 42%|████▏     | 25/60 [02:24<03:22,  5.80s/it]

0.5293257567617629


 43%|████▎     | 26/60 [02:30<03:17,  5.80s/it]

0.528193367852105


 45%|████▌     | 27/60 [02:36<03:11,  5.80s/it]

0.5271013246642219


 47%|████▋     | 28/60 [02:42<03:05,  5.79s/it]

0.5260420474741194


 48%|████▊     | 29/60 [02:47<02:59,  5.78s/it]

0.5250177913241916


 50%|█████     | 30/60 [02:53<02:53,  5.78s/it]

0.5240205360783471


 52%|█████▏    | 31/60 [02:59<02:47,  5.78s/it]

0.5230491823620267


 53%|█████▎    | 32/60 [03:05<02:41,  5.78s/it]

0.5221084422535367


 55%|█████▌    | 33/60 [03:11<02:36,  5.79s/it]

0.5211853881676992


 57%|█████▋    | 34/60 [03:16<02:30,  5.78s/it]

0.5202821029557122


 58%|█████▊    | 35/60 [03:22<02:24,  5.78s/it]

0.5194054113494025


 60%|██████    | 36/60 [03:28<02:18,  5.78s/it]

0.5185477468702528


 62%|██████▏   | 37/60 [03:34<02:12,  5.78s/it]

0.5177072021696303


 63%|██████▎   | 38/60 [03:39<02:07,  5.78s/it]

0.5168889363606771


 65%|██████▌   | 39/60 [03:45<02:01,  5.78s/it]

0.5160863995552063


 67%|██████▋   | 40/60 [03:51<01:55,  5.78s/it]

0.5152994394302368


 68%|██████▊   | 41/60 [03:57<01:49,  5.79s/it]

0.5145330131053925


 70%|███████   | 42/60 [04:03<01:45,  5.86s/it]

0.5137853655550215


 72%|███████▏  | 43/60 [04:09<01:39,  5.84s/it]

0.5130493177307976


 73%|███████▎  | 44/60 [04:14<01:33,  5.82s/it]

0.5123302406734891


 75%|███████▌  | 45/60 [04:20<01:27,  5.82s/it]

0.5116255945629544


 77%|███████▋  | 46/60 [04:26<01:21,  5.81s/it]

0.5109348297119141


 78%|███████▊  | 47/60 [04:32<01:15,  5.80s/it]

0.5102575752470229


 80%|████████  | 48/60 [04:38<01:09,  5.79s/it]

0.5095902217759026


 82%|████████▏ | 49/60 [04:43<01:03,  5.79s/it]

0.5089325375027127


 83%|████████▎ | 50/60 [04:49<00:57,  5.79s/it]

0.5082872145705752


 85%|████████▌ | 51/60 [04:55<00:52,  5.79s/it]

0.5076575809054904


 87%|████████▋ | 52/60 [05:01<00:46,  5.79s/it]

0.5070304340786405


 88%|████████▊ | 53/60 [05:06<00:40,  5.79s/it]

0.5064141816563077


 90%|█████████ | 54/60 [05:12<00:34,  5.79s/it]

0.5058101680543687


 92%|█████████▏| 55/60 [05:18<00:28,  5.79s/it]

0.5052125023470985


 93%|█████████▎| 56/60 [05:24<00:23,  5.79s/it]

0.5046213666598002


 95%|█████████▌| 57/60 [05:30<00:17,  5.79s/it]

0.5040430459711287


 97%|█████████▋| 58/60 [05:35<00:11,  5.79s/it]

0.5034679704242282


 98%|█████████▊| 59/60 [05:41<00:05,  5.79s/it]

0.5029030243555704


100%|██████████| 60/60 [05:47<00:00,  5.79s/it]

0.5023428963290321
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [36]:
model_score4

{'Bleu_1': 0.5403621443225163,
 'Bleu_2': 0.4060537687226018,
 'Bleu_3': 0.3252016955962572,
 'Bleu_4': 0.27036573542650877,
 'METEOR': 0.2434218549564466,
 'ROUGE_L': 0.45846613436735706,
 'CIDEr': 1.4245641106535625,
 'SPICE': 0.2994223468752407,
 'USC_similarity': 0.5270081463157665}

In [37]:
caption_model5, model_score5 = cross_validation(cv[4][0], cv[4][1], 5)    

Split 5:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions


  0%|          | 4/8333 [00:00<04:10, 33.27it/s]

preprocessed words 2657 ==> 905
The vocabulary size is 906.
815 out of 906 words are found in the pre-trained matrix.
The size of embedding_matrix is (906, 500)
Preparing dataloader...


100%|██████████| 8333/8333 [03:33<00:00, 39.06it/s]
  0%|          | 4/2083 [00:00<00:52, 39.83it/s]


Generating set took: 0:03:33.36


100%|██████████| 2083/2083 [00:53<00:00, 38.96it/s]
  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.46
Training...


  2%|▏         | 1/60 [00:05<05:39,  5.76s/it]

9.021933131747776


  3%|▎         | 2/60 [00:11<05:34,  5.76s/it]

4.7782781389024525


  5%|▌         | 3/60 [00:17<05:28,  5.76s/it]

4.166260480880737


  7%|▋         | 4/60 [00:23<05:22,  5.76s/it]

3.4765849908192954


  8%|▊         | 5/60 [00:28<05:17,  5.77s/it]

2.998911672168308


 10%|█         | 6/60 [00:34<05:12,  5.78s/it]

2.698606994416979


 12%|█▏        | 7/60 [00:40<05:06,  5.78s/it]

2.4963524871402316


 13%|█▎        | 8/60 [00:46<05:00,  5.79s/it]

2.3593112097846136


 15%|█▌        | 9/60 [00:52<04:55,  5.79s/it]

2.2511481708950467


 17%|█▋        | 10/60 [00:57<04:49,  5.79s/it]

2.164932754304674


 18%|█▊        | 11/60 [01:03<04:43,  5.79s/it]

2.111777146657308


 20%|██        | 12/60 [01:09<04:38,  5.80s/it]

2.047886676258511


 22%|██▏       | 13/60 [01:15<04:32,  5.80s/it]

1.967561403910319


 23%|██▎       | 14/60 [01:20<04:26,  5.79s/it]

1.8889070351918538


 25%|██▌       | 15/60 [01:26<04:20,  5.78s/it]

1.8132432169384427


 27%|██▋       | 16/60 [01:32<04:14,  5.79s/it]

1.7548438972897


 28%|██▊       | 17/60 [01:38<04:08,  5.78s/it]

1.7110981278949313


 30%|███       | 18/60 [01:44<04:02,  5.78s/it]

1.6829409069485135


 32%|███▏      | 19/60 [01:49<03:56,  5.77s/it]

1.6464601357777913


 33%|███▎      | 20/60 [01:55<03:50,  5.77s/it]

1.6302643616994221


 35%|███▌      | 21/60 [02:01<03:45,  5.77s/it]

1.580920894940694


 37%|███▋      | 22/60 [02:07<03:39,  5.77s/it]

1.5336406230926514


 38%|███▊      | 23/60 [02:12<03:33,  5.77s/it]

1.4979359176423814


 40%|████      | 24/60 [02:18<03:27,  5.77s/it]

1.449572245279948


 42%|████▏     | 25/60 [02:24<03:21,  5.77s/it]

1.4191792143715753


 43%|████▎     | 26/60 [02:30<03:16,  5.77s/it]

1.3891386058595445


 45%|████▌     | 27/60 [02:35<03:10,  5.77s/it]

1.3825919893052843


 47%|████▋     | 28/60 [02:41<03:04,  5.77s/it]

1.3803765641318426


 48%|████▊     | 29/60 [02:47<02:59,  5.78s/it]

1.3632965352800157


 50%|█████     | 30/60 [02:53<02:53,  5.78s/it]

1.3358926508161757


 52%|█████▏    | 31/60 [02:59<02:47,  5.79s/it]

1.2927907175487943


 53%|█████▎    | 32/60 [03:04<02:42,  5.79s/it]

1.248068590958913


 55%|█████▌    | 33/60 [03:10<02:36,  5.78s/it]

1.2310758895344205


 57%|█████▋    | 34/60 [03:16<02:30,  5.79s/it]

1.2029305299123128


 58%|█████▊    | 35/60 [03:22<02:24,  5.78s/it]

1.1943136718538072


 60%|██████    | 36/60 [03:28<02:18,  5.78s/it]

1.2001610464519925


 62%|██████▏   | 37/60 [03:33<02:12,  5.78s/it]

1.18000508679284


 63%|██████▎   | 38/60 [03:39<02:07,  5.78s/it]

1.154807488123576


 65%|██████▌   | 39/60 [03:45<02:01,  5.77s/it]

1.1485070056385465


 67%|██████▋   | 40/60 [03:51<01:55,  5.78s/it]

1.1173140207926433


 68%|██████▊   | 41/60 [03:56<01:49,  5.77s/it]

1.0978441370858087


 70%|███████   | 42/60 [04:02<01:43,  5.76s/it]

1.0695792900191412


 72%|███████▏  | 43/60 [04:08<01:37,  5.76s/it]

1.0349441038237677


 73%|███████▎  | 44/60 [04:14<01:32,  5.76s/it]

1.0067984660466511


 75%|███████▌  | 45/60 [04:19<01:26,  5.76s/it]

1.0015782316525776


 77%|███████▋  | 46/60 [04:25<01:20,  5.76s/it]

0.9925876326031156


 78%|███████▊  | 47/60 [04:31<01:14,  5.76s/it]

0.9927711221906874


 80%|████████  | 48/60 [04:37<01:09,  5.76s/it]

0.9802420934041342


 82%|████████▏ | 49/60 [04:42<01:03,  5.76s/it]

0.9886764023039076


 83%|████████▎ | 50/60 [04:48<00:57,  5.76s/it]

1.006128317779965


 85%|████████▌ | 51/60 [04:54<00:51,  5.76s/it]

1.0335886809560988


 87%|████████▋ | 52/60 [05:00<00:46,  5.76s/it]

1.0237671401765611


 88%|████████▊ | 53/60 [05:06<00:40,  5.76s/it]

1.0257420473628573


 90%|█████████ | 54/60 [05:11<00:34,  5.77s/it]

1.0340071717898052


 92%|█████████▏| 55/60 [05:17<00:28,  5.77s/it]

1.0308063096470303


 93%|█████████▎| 56/60 [05:23<00:23,  5.77s/it]

1.0473548902405634


 95%|█████████▌| 57/60 [05:29<00:17,  5.76s/it]

1.088747759660085


 97%|█████████▋| 58/60 [05:34<00:11,  5.76s/it]

1.0869959327909682


 98%|█████████▊| 59/60 [05:40<00:05,  5.77s/it]

1.054475724697113


100%|██████████| 60/60 [05:46<00:00,  5.77s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

1.0144091182284884


  2%|▏         | 1/60 [00:05<05:43,  5.83s/it]

0.9858366648356119


  3%|▎         | 2/60 [00:11<05:36,  5.81s/it]

0.9435030619303385


  5%|▌         | 3/60 [00:17<05:30,  5.80s/it]

0.9036296473609077


  7%|▋         | 4/60 [00:23<05:24,  5.80s/it]

0.8759935432010226


  8%|▊         | 5/60 [00:28<05:18,  5.80s/it]

0.8573426604270935


 10%|█         | 6/60 [00:34<05:12,  5.79s/it]

0.8443585104412503


 12%|█▏        | 7/60 [00:40<05:06,  5.79s/it]

0.8348467747370402


 13%|█▎        | 8/60 [00:46<05:00,  5.78s/it]

0.8274772034751045


 15%|█▌        | 9/60 [00:52<04:55,  5.79s/it]

0.8215103877915276


 17%|█▋        | 10/60 [00:57<04:49,  5.78s/it]

0.816509399149153


 18%|█▊        | 11/60 [01:03<04:42,  5.77s/it]

0.8122090962198045


 20%|██        | 12/60 [01:09<04:36,  5.77s/it]

0.8084256516562568


 22%|██▏       | 13/60 [01:15<04:32,  5.79s/it]

0.8050495717260573


 23%|██▎       | 14/60 [01:20<04:26,  5.79s/it]

0.8019782702128092


 25%|██▌       | 15/60 [01:26<04:20,  5.78s/it]

0.7991821832127042


 27%|██▋       | 16/60 [01:32<04:14,  5.78s/it]

0.7966004676289029


 28%|██▊       | 17/60 [01:38<04:08,  5.79s/it]

0.7942004402478536


 30%|███       | 18/60 [01:44<04:02,  5.78s/it]

0.7919599016507467


 32%|███▏      | 19/60 [01:49<03:56,  5.77s/it]

0.7898523807525635


 33%|███▎      | 20/60 [01:55<03:50,  5.77s/it]

0.787858956389957


 35%|███▌      | 21/60 [02:01<03:44,  5.77s/it]

0.7859627604484558


 37%|███▋      | 22/60 [02:07<03:39,  5.77s/it]

0.7841540575027466


 38%|███▊      | 23/60 [02:12<03:33,  5.77s/it]

0.7824217743343778


 40%|████      | 24/60 [02:18<03:27,  5.77s/it]

0.780757831202613


 42%|████▏     | 25/60 [02:24<03:22,  5.77s/it]

0.7791488700442843


 43%|████▎     | 26/60 [02:30<03:16,  5.77s/it]

0.7775990830527412


 45%|████▌     | 27/60 [02:36<03:10,  5.77s/it]

0.776099443435669


 47%|████▋     | 28/60 [02:41<03:04,  5.77s/it]

0.7746520770920647


 48%|████▊     | 29/60 [02:47<02:58,  5.77s/it]

0.7732487453354729


 50%|█████     | 30/60 [02:53<02:53,  5.77s/it]

0.7718914681010776


 52%|█████▏    | 31/60 [02:59<02:50,  5.86s/it]

0.7705676952997843


 53%|█████▎    | 32/60 [03:05<02:43,  5.83s/it]

0.769269605477651


 55%|█████▌    | 33/60 [03:10<02:37,  5.82s/it]

0.7680009603500366


 57%|█████▋    | 34/60 [03:16<02:31,  5.83s/it]

0.7667573955323961


 58%|█████▊    | 35/60 [03:22<02:25,  5.82s/it]

0.7655441231197782


 60%|██████    | 36/60 [03:28<02:19,  5.81s/it]

0.7643510500590006


 62%|██████▏   | 37/60 [03:34<02:13,  5.81s/it]

0.7631861898634169


 63%|██████▎   | 38/60 [03:40<02:07,  5.81s/it]

0.7620412177509732


 65%|██████▌   | 39/60 [03:45<02:01,  5.79s/it]

0.760920074250963


 67%|██████▋   | 40/60 [03:51<01:55,  5.79s/it]

0.7598201433817545


 68%|██████▊   | 41/60 [03:57<01:49,  5.79s/it]

0.7587438424428304


 70%|███████   | 42/60 [04:03<01:44,  5.78s/it]

0.7576789723502265


 72%|███████▏  | 43/60 [04:08<01:38,  5.78s/it]

0.7566320697466532


 73%|███████▎  | 44/60 [04:14<01:32,  5.78s/it]

0.7555999424722459


 75%|███████▌  | 45/60 [04:20<01:26,  5.77s/it]

0.7545873257848952


 77%|███████▋  | 46/60 [04:26<01:20,  5.77s/it]

0.7535884380340576


 78%|███████▊  | 47/60 [04:31<01:15,  5.77s/it]

0.752605077293184


 80%|████████  | 48/60 [04:37<01:09,  5.77s/it]

0.7516361441877153


 82%|████████▏ | 49/60 [04:43<01:03,  5.78s/it]

0.7506796055369906


 83%|████████▎ | 50/60 [04:49<00:57,  5.78s/it]

0.7497302061981626


 85%|████████▌ | 51/60 [04:55<00:52,  5.78s/it]

0.7487912343608009


 87%|████████▋ | 52/60 [05:00<00:46,  5.78s/it]

0.7478731804423862


 88%|████████▊ | 53/60 [05:06<00:40,  5.77s/it]

0.7469622625244988


 90%|█████████ | 54/60 [05:12<00:34,  5.77s/it]

0.7460601859622531


 92%|█████████▏| 55/60 [05:18<00:28,  5.79s/it]

0.7451672355333964


 93%|█████████▎| 56/60 [05:23<00:23,  5.78s/it]

0.7442866894933913


 95%|█████████▌| 57/60 [05:29<00:17,  5.77s/it]

0.7434080839157104


 97%|█████████▋| 58/60 [05:35<00:11,  5.78s/it]

0.7425484094354842


 98%|█████████▊| 59/60 [05:41<00:05,  5.79s/it]

0.7416860626803504


100%|██████████| 60/60 [05:47<00:00,  5.79s/it]

0.7408414747979906
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [38]:
model_score5

{'Bleu_1': 0.5608120606733,
 'Bleu_2': 0.42800387628455505,
 'Bleu_3': 0.3419038275680396,
 'Bleu_4': 0.28095786848263316,
 'METEOR': 0.2447877714265427,
 'ROUGE_L': 0.46303065203538796,
 'CIDEr': 1.4683687771537384,
 'SPICE': 0.3105531153128059,
 'USC_similarity': 0.5450782572881883}

In [39]:
model_scores = defaultdict(list)
for scores in [model_score1, model_score2, model_score3, model_score4, model_score5]:
    for key, value in scores.items():
        model_scores[key].append(value)

In [40]:
model_scores

defaultdict(list,
            {'Bleu_1': [0.544703279616429,
              0.5497447219538867,
              0.5700041165439066,
              0.5403621443225163,
              0.5608120606733],
             'Bleu_2': [0.4134876658158496,
              0.4210379521981906,
              0.4369369710710409,
              0.4060537687226018,
              0.42800387628455505],
             'Bleu_3': [0.32693856576369873,
              0.33840349424084093,
              0.3533912607678642,
              0.3252016955962572,
              0.3419038275680396],
             'Bleu_4': [0.2657789178164309,
              0.2798526176173945,
              0.29594824649102636,
              0.27036573542650877,
              0.28095786848263316],
             'METEOR': [0.2273272223806558,
              0.24107196500751146,
              0.2555919304348583,
              0.2434218549564466,
              0.2447877714265427],
             'ROUGE_L': [0.4559596838533356,
              0.4651095622238

In [41]:
tag = '11.2.1'
with open(f'{root_captioning}/fz_notebooks/cv_n{tag}.json', 'w') as fp:
    json.dump(model_scores, fp)