## Image Captioning with Pytorch

The following contents are modified from MDS DSCI 575 lecture 8 demo

In [1]:
import os, sys, json
from collections import defaultdict
from tqdm import tqdm
import pickle
from time import time
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from itertools import chain
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models, transforms, datasets
from torchsummary import summary
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

from nltk.translate import bleu_score
from sklearn.model_selection import KFold

sys.path.append('../../scr/evaluation')
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.usc_sim.usc_sim import usc_sim
import subprocess

START = "startseq"
STOP = "endseq"
EPOCHS = 10
AWS = True


In [2]:
torch.manual_seed(123)
np.random.seed(123)

In [3]:
# torch.cuda.empty_cache()
# import gc 
# gc.collect()

In [4]:
# Nicely formatted time string
def hms_string(sec_elapsed):
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60
    return f"{h}:{m:>02}:{s:>05.2f}"
        
if AWS:
    root_captioning = "../../data"
else:
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
        root_captioning = "/content/drive/My Drive/data"
        COLAB = True
        print("Note: using Google CoLab")
    except:
        print("Note: not using Google CoLab")
        COLAB = False

### Clean/Build Dataset

- Read captions
- Preprocess captions


In [5]:
def get_img_info(name, num=np.inf):
    """
    Returns img paths and captions

    Parameters:
    -----------
    name: str
        the json file name
    num: int (default: np.inf)
        the number of observations to get

    Return:
    --------
    list, dict, int
        img paths, corresponding captions, max length of captions
    """
    img_path = []
    caption = [] 
    max_length = 0
    if AWS:
        with open(f'{root_captioning}/json/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for filename in data.keys():
                if num is not None and len(caption) == num:
                    break
                img_path.append(
                    f'{root_captioning}/{name}/{filename}'
                )
                sen_list = []
                for sentence in data[filename]['sentences']:
                    max_length = max(max_length, len(sentence['tokens']))
                    sen_list.append(sentence['raw'])

                caption.append(sen_list)    
    else:            
        with open(f'{root_captioning}/interim/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for set_name in ['rsicd', 'ucm']:
                for filename in data[set_name].keys():
                    if num is not None and len(caption) == num:
                        break

                    img_path.append(
                        f'{root_captioning}/raw/imgs/{set_name}/{filename}'
                    )
                    sen_list = []
                    for sentence in data[set_name][filename]['sentences']:
                        max_length = max(max_length, len(sentence['tokens']))
                        sen_list.append(sentence['raw'])

                    caption.append(sen_list)
    
    return img_path, caption, max_length            


In [6]:
# get img path and caption list
# # only test 800 train samples and 200 valid samples
# train_paths, train_descriptions, max_length_train = get_img_info('train', 800)
# test_paths, test_descriptions, max_length_test = get_img_info('valid', 200)

train_paths, train_descriptions, max_length_train = get_img_info('train')
test_paths, test_descriptions, max_length_test = get_img_info('valid')
max_length = max(max_length_train, max_length_test)



In [7]:
all_paths = train_paths.copy()
all_paths.extend(test_paths.copy())
all_paths = np.array(all_paths)

all_descriptions = train_descriptions.copy()
all_descriptions.extend(test_descriptions.copy())
all_descriptions = np.array(all_descriptions)

captions = all_descriptions.copy()
max_length_all = max(max_length_train, max_length_test)
max_length = max_length_all + 2
      
lex = set()
for sen in all_descriptions:
    [lex.update(d.split()) for d in sen]
    
# add a start and stop token at the beginning/end
for v in all_descriptions:
    for d in range(len(v)):
        v[d] = f'{START} {v[d]} {STOP}'
        
print(f'There are {len(all_paths)} images') 
print(f'There are {len(lex)} unique words (vocab)')
print(f'The maximum length of captions with start and stop token is {max_length}.')


There are 10416 images
There are 2912 unique words (vocab)
The maximum length of captions with start and stop token is 36.


In [8]:
all_paths[-1]

'../../s3/valid/rsicd_park_33.jpg'

In [9]:
all_descriptions[-1]

array(['startseq a vast artificial lake was built in the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq'],
      dtype='<U184')

### Loading Glove Embeddings

In [10]:
embeddings_index = {} 
path = os.path.join(root_captioning, 'glove.6B.200d.txt') if AWS\
else os.path.join(root_captioning, 'raw', 'glove.6B.200d.txt')

f = open(
    path, 
    encoding="utf-8"
)

for line in tqdm(f):
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs

f.close()
print(f'Found {len(embeddings_index)} word vectors.')

400000it [00:22, 17754.82it/s]

Found 400000 word vectors.





In [11]:
def get_vocab(descriptions, word_count_threshold=10):

    captions = []
    for val in descriptions:
        for cap in val:
            captions.append(cap)
    print(f'There are {len(captions)} captions')
    
    word_counts = {}
    nsents = 0
    for sent in captions:
        nsents += 1
        for w in sent.split(' '):
            word_counts[w] = word_counts.get(w, 0) + 1

    vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
    print('preprocessed words %d ==> %d' % (len(word_counts), len(vocab)))
    return vocab

def get_word_dict(vocab):
    
    idxtoword = {}
    wordtoidx = {}

    ix = 1
    for w in vocab:
        wordtoidx[w] = ix
        idxtoword[ix] = w
        ix += 1

    return idxtoword, wordtoidx

def get_vocab_size(idxtoword):
    
    print(f'The vocabulary size is {len(idxtoword) + 1}.')
    return len(idxtoword) + 1


def get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx):

    embedding_matrix = np.zeros((vocab_size, embedding_dim))
    count = 0

    for word, i in wordtoidx.items():

        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            count += 1
            # Words not found in the embedding index will be all zeros
            embedding_matrix[i] = embedding_vector
            
    print(f'{count} out of {vocab_size} words are found in the pre-trained matrix.')            
    print(f'The size of embedding_matrix is {embedding_matrix.shape}')
    return embedding_matrix

### Building the Neural Network

An embedding matrix is built from Glove.  This will be directly copied to the weight matrix of the neural network.

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [13]:
class CNNModel(nn.Module):

    def __init__(self, cnn_type, pretrained=True):
        """
        Initializes a CNNModel

        Parameters:
        -----------
        cnn_type: str
            the CNN type, either 'vgg16' or 'inception_v3'
        pretrained: bool (default: True)
            use pretrained model if True

        """

        super(CNNModel, self).__init__()

        if cnn_type == 'vgg16':
            self.model = models.vgg16(pretrained=pretrained)

            # remove the last two layers in classifier
            self.model.classifier = nn.Sequential(
              *list(self.model.classifier.children())[:-2]
            )
            self.input_size = 224     

        # inception v3 expects (299, 299) sized images
        elif cnn_type == 'inception_v3':
            self.model = models.inception_v3(pretrained=pretrained)
            # remove the classification layer
            self.model.fc = nn.Identity()

            # turn off auxiliary output
            self.model.aux_logits = False
            self.input_size = 299

        else:
            raise Exception("Please choose between 'vgg16' and 'inception_v3'.")

    def forward(self, img_input, train=False):
        """
        forward of the CNNModel

        Parameters:
        -----------
        img_input: torch.Tensor
            the image matrix
        train: bool (default: False)
            use the model only for feature extraction if False

        Return:
        --------
        torch.Tensor
            image feature matrix
        """
        if not train:
            # set the model to evaluation model
            self.model.eval()

        return self.model(img_input)

In [14]:
class RNNModel(nn.Module):

    def __init__(
        self, 
        vocab_size, 
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):
      
        """
        Initializes a RNNModel

        Parameters:
        -----------
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """

        super(RNNModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        if embedding_matrix is not None:

            self.embedding.load_state_dict({
              'weight': torch.FloatTensor(embedding_matrix)
            })
            self.embedding.weight.requires_grad = embedding_train

        self.dropout = nn.Dropout(p=0.5)

        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
 

    def forward(self, captions):
        """
        forward of the RNNModel

        Parameters:
        -----------
        captions: torch.Tensor
            the padded caption matrix

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        # embed the captions
        embedding = self.dropout(self.embedding(captions))

        outputs, (h, c) = self.lstm(embedding)

        return outputs, (h, c)



In [15]:
class CaptionModel(nn.Module):

    def __init__(
        self, 
        cnn_type, 
        vocab_size, 
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):

        """
        Initializes a CaptionModel

        Parameters:
        -----------
        cnn_type: str
            the CNN type, either 'vgg16' or 'inception_v3'
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        feature_size: int
            the number of features in the image matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """    
        super(CaptionModel, self).__init__() 

        # set feature_size based on cnn_type
        if cnn_type == 'vgg16':
            self.feature_size = 4096
        elif cnn_type == 'inception_v3':
            self.feature_size = 2048
        else:
            raise Exception("Please choose between 'vgg16' and 'inception_v3'.")  

        self.decoder = RNNModel(
            vocab_size, 
            embedding_dim,
            hidden_size,
            embedding_matrix,
            embedding_train
        )
        
        self.dropout = nn.Dropout(p=0.5)
        self.dense1 = nn.Linear(self.feature_size, hidden_size) 
        self.relu1 = nn.ReLU()
          
        self.dense2 = nn.Linear(hidden_size, hidden_size) 
        self.relu2 = nn.ReLU()
        self.dense3 = nn.Linear(hidden_size, vocab_size) 

    def forward(self, img_features, captions):
        """
        forward of the CaptionModel

        Parameters:
        -----------
        img_features: torch.Tensor
            the image feature matrix
        captions: torch.Tensor
            the padded caption matrix

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        img_features =\
        self.relu1(
            self.dense1(
                self.dropout(
                    img_features
                )
            )
        )

        decoder_out, _ = self.decoder(captions)

        # add up decoder outputs and image features
        outputs =\
        self.dense3(
            self.relu2(
                self.dense2(
                    decoder_out.add(
                        (img_features.view(img_features.size(0), 1, -1))\
                        .repeat(1, decoder_out.size(1), 1)
                    )
                )
            )
        )

        return outputs

### Train the Neural Network

In [16]:
def train(model, iterator, optimizer, criterion, clip, vocab_size):
    """
    train the CaptionModel

    Parameters:
    -----------
    model: CaptionModel
        a CaptionModel instance
    iterator: torch.utils.data.dataloader
        a PyTorch dataloader
    optimizer: torch.optim
        a PyTorch optimizer 
    criterion: nn.CrossEntropyLoss
        a PyTorch criterion 

    Return:
    --------
    float
        average loss
    """
    model.train()    
    epoch_loss = 0
    
    for img_features, captions in iterator:
        
        optimizer.zero_grad()

        # for each caption, the end word is not passed for training
        outputs = model(
            img_features.to(device),
            captions[:, :-1].to(device)
        )

        loss = criterion(
            outputs.view(-1, vocab_size), 
            captions[:, 1:].flatten().to(device)
        )
        epoch_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        
    return epoch_loss / len(iterator)

In [17]:
class SampleDataset(Dataset):
    def __init__(
        self,
        descriptions,
        imgs,
        wordtoidx,
        max_length
    ):
        """
        Initializes a SampleDataset

        Parameters:
        -----------
        descriptions: list
            a list of captions
        imgs: numpy.ndarray
            the image features
        wordtoidx: dict
            the dict to get word index
        max_length: int
            all captions will be padded to this size
        """        
        self.imgs = imgs
        self.descriptions = descriptions
        self.wordtoidx = wordtoidx
        self.max_length = max_length

    def __len__(self):
        """
        Returns the batch size

        Return:
        --------
        int
            the batch size
        """
        # return len(self.descriptions)
        return len(self.imgs)

    def __getitem__(self, idx):
        """
        Prepare data for each image

        Parameters:
        -----------
        idx: int
          the index of the image to process

        Return:
        --------
        list, list, list
            [5 x image feature matrix],
            [five padded captions for this image]
            [the length of each caption]
        """

        img = self.imgs[idx // 5]
        # convert each word into a list of sequences.
        seq = [self.wordtoidx[word] for word 
               in self.descriptions[idx // 5][idx % 5].split(' ')
               if word in self.wordtoidx]
        # pad the sequence with 0 on the right side
        in_seq = np.pad(
            seq, 
            (0, max_length - len(seq)),
            mode='constant',
            constant_values=(0, 0)
            )

        return img, in_seq


In [18]:
def init_weights(model, embedding_pretrained=True):
    """
    Initialize weights and bias in the model

    Parameters:
    -----------
    model: CaptionModel
      a CaptionModel instance
    embedding_pretrained: bool (default: True)
        not initialize the embedding matrix if True
    """  
  
    for name, param in model.named_parameters():
        if embedding_pretrained and 'embedding' in name:
            continue
        elif 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            


In [19]:
def encode_image(model, img_path):
    """
    Process the images to extract features

    Parameters:
    -----------
    model: CNNModel
      a CNNModel instance
    img_path: str
        the path of the image
 
    Return:
    --------
    torch.Tensor
        the extracted feature matrix from CNNModel
    """  

    img = Image.open(img_path)

    # Perform preprocessing needed by pre-trained models
    preprocessor = transforms.Compose([
        transforms.Resize(model.input_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    img = preprocessor(img)
    # Expand to 2D array
    img = img.view(1, *img.shape)
    # Call model to extract the smaller feature set for the image.
    x = model(img.to(device), False) 
    # Shape to correct form to be accepted by LSTM captioning network.
    x = np.squeeze(x)
    return x

In [20]:
def extract_img_features(img_paths, model):
    """
    Extracts, stores and returns image features

    Parameters:
    -----------
    img_paths: list
        the paths of images
    model: CNNModel (default: None)
      a CNNModel instance

    Return:
    --------
    numpy.ndarray
        the extracted image feature matrix from CNNModel
    """ 

    start = time()
    img_features = []

    for image_path in img_paths:
        img_features.append(
            encode_image(model, image_path).cpu().data.numpy()
    )

    print(f"\nGenerating set took: {hms_string(time()-start)}")

    return img_features

In [21]:
def get_train_test(
    encoder,
    train_paths,
    test_paths,
    cnn_type='inception_v3',
):

    train_img_features = extract_img_features(
        train_paths,
        encoder
    )

    test_img_features = extract_img_features(
        test_paths,
        encoder
    )
    return train_img_features, test_img_features

def get_train_dataloader(
    train_descriptions, 
    train_img_features,
    wordtoidx,
    max_length,
    batch_size=200
):
    train_dataset = SampleDataset(
        train_descriptions,
        train_img_features,
        wordtoidx,
        max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size
    )
    
    return train_loader

def train_model(
    train_loader,
    vocab_size,
    embedding_dim, 
    embedding_matrix,
    cnn_type='inception_v3',
    hidden_size=256,
):

    caption_model = CaptionModel(
        cnn_type, 
        vocab_size, 
        embedding_dim, 
        hidden_size=hidden_size,
        embedding_matrix=embedding_matrix, 
        embedding_train=False
    )

    init_weights(
        caption_model,
        embedding_pretrained=True
    )

    caption_model.to(device)

    # we will ignore the pad token in true target set
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    optimizer = torch.optim.Adam(
        caption_model.parameters(), 
        lr=0.01
    )

    clip = 1
    start = time()

    for i in tqdm(range(EPOCHS * 6)):

        loss = train(caption_model, train_loader, optimizer, criterion, clip, vocab_size)
        print(loss)

    # reduce the learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = 1e-4

    for i in tqdm(range(EPOCHS * 6)):

        loss = train(caption_model, train_loader, optimizer, criterion, clip, vocab_size)
        print(loss)
    return caption_model

In [22]:
def generateCaption(
    model, 
    img_features,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    in_text = START

    for i in range(max_length):

        sequence = [wordtoidx[w] for w in in_text.split() if w in wordtoidx]
        sequence = np.pad(sequence, (0, max_length - len(sequence)),
                          mode='constant', constant_values=(0, 0))
        model.eval()
        yhat = model(
            torch.FloatTensor(img_features)\
            .view(-1, model.feature_size).to(device),
            torch.LongTensor(sequence).view(-1, max_length).to(device)
        )

        yhat = yhat.view(-1, vocab_size).argmax(1)
        word = idxtoword[yhat.cpu().data.numpy()[i]]
        in_text += ' ' + word
        if word == STOP:
            break
    final = in_text.split()
    final = final[1 : -1]
    final = ' '.join(final)
    return final

### Evaluation

In [23]:
def eval_model(ref_data, results):
    """
    Computes evaluation metrics of the model results against the human annotated captions
    
    Parameters:
    ------------
    ref_data: dict
        a dictionary containing human annotated captions, with image name as key and a list of human annotated captions as values
    
    results: dict
        a dictionary containing model generated caption, with image name as key and a generated caption as value
        
    Returns:
    ------------
    score_dict: a dictionary containing the overall average score for the model
    """
    # download stanford nlp library
    subprocess.call(['../../scr/evaluation/get_stanford_models.sh'])
    
    # format the inputs
    gts = {}
    res = {}

    for imgId in range(len(ref_data)):
        caption_list_sel = []
        for i in range(5):
            lst = {}
            lst['caption'] = ref_data[imgId][i]
            lst['image_id'] = imgId
            lst['id'] = i
            caption_list_sel.append(lst)
        gts[imgId] = caption_list_sel

        res[imgId] = [{'caption': results[imgId]}]
        
    # tokenize
    print('tokenization...')
    tokenizer = PTBTokenizer()
    gts  = tokenizer.tokenize(gts)
    res = tokenizer.tokenize(res)
    
    # compute scores
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(),"METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
        (Spice(), "SPICE"),
        (usc_sim(), "USC_similarity"),  
        ]
    score_dict = {}
    for scorer, method in scorers:
        print('computing %s score...'%(scorer.method()))
        score, scores = scorer.compute_score(gts, res)
        if type(method) == list:
            for sc, scs, m in zip(score, scores, method):
                score_dict[m] = sc
        else:
            score_dict[method] = score
            
    return score_dict


In [24]:
def evaluate_results(
    test_img_features, 
    model,
    ref,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    # generate results
    print('Generating captions...')
    results = {}
    for n in range(len(test_img_features)):
        img_features = test_img_features[n]
        generated = generateCaption(
            model, 
            img_features,
            max_length,
            vocab_size,
            wordtoidx,
            idxtoword
        )
        results[n] = generated
        
    model_score = eval_model(ref, results)

    return model_score

### Cross validation

In [25]:
cnn_type = 'inception_v3'
encoder = CNNModel(cnn_type, pretrained=True)
encoder.to(device)

CNNModel(
  (model): Inception3(
    (Conv2d_1a_3x3): BasicConv2d(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_2a_3x3): BasicConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_2b_3x3): BasicConv2d(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_3b_1x1): BasicConv2d(
      (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_4a_3x3): BasicConv2d(
      (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn

In [26]:
def cross_validation(train_index, test_index, count):
    print('=' * 60)
    print(f'Split {count}:')
    print(f'Splitting data...')
    
    train_paths, test_paths = all_paths[train_index], all_paths[test_index]
    train_descriptions, test_descriptions = all_descriptions[train_index], all_descriptions[test_index]
    print(f'{len(train_paths)} images for training and {len(test_paths)} images for testing.')
    
    vocab = get_vocab(train_descriptions, word_count_threshold=10)
    idxtoword, wordtoidx = get_word_dict(vocab)
    vocab_size = get_vocab_size(idxtoword)
    embedding_dim = 200
    embedding_matrix = get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx) 

    print(f'Preparing dataloader...')
    train_img_features, test_img_features = get_train_test(encoder, train_paths, test_paths)

    train_loader = get_train_dataloader(
        train_descriptions, 
        train_img_features,
        wordtoidx,
        max_length,
        batch_size=1000
    )

    print(f'Training...')
    caption_model = train_model(
        train_loader,
        vocab_size,
        embedding_dim, 
        embedding_matrix
    )

    
    ref = captions[test_index]
    model_score = evaluate_results(
        test_img_features, 
        caption_model,
        ref,
        max_length,
        vocab_size,
        wordtoidx,
        idxtoword
    )
    
    return caption_model, model_score

In [27]:
cv = KFold(n_splits=5, random_state=123, shuffle=True)
cv = [(train_index, test_index) for train_index, test_index in cv.split(all_paths)]  

In [28]:
caption_model1, model_score1 = cross_validation(cv[0][0], cv[0][1], 1)    

Split 1:
Splitting data...
8332 images for training and 2084 images for testing.
There are 41660 captions
preprocessed words 2659 ==> 884
The vocabulary size is 885.
796 out of 885 words are found in the pre-trained matrix.
The size of embedding_matrix is (885, 200)
Preparing dataloader...

Generating set took: 0:03:28.85


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:52.44
Training...


  2%|▏         | 1/60 [00:00<00:42,  1.39it/s]

5.514229350619846


  3%|▎         | 2/60 [00:01<00:40,  1.42it/s]

4.502158482869466


  5%|▌         | 3/60 [00:02<00:39,  1.43it/s]

4.324753708309597


  7%|▋         | 4/60 [00:02<00:39,  1.43it/s]

3.915895832909478


  8%|▊         | 5/60 [00:03<00:38,  1.42it/s]

3.4492679701911078


 10%|█         | 6/60 [00:04<00:37,  1.44it/s]

2.977285703023275


 12%|█▏        | 7/60 [00:04<00:36,  1.45it/s]

2.6207107173071966


 13%|█▎        | 8/60 [00:05<00:35,  1.45it/s]

2.368301921420627


 15%|█▌        | 9/60 [00:06<00:35,  1.42it/s]

2.1794495979944863


 17%|█▋        | 10/60 [00:06<00:34,  1.45it/s]

2.0354794793658786


 18%|█▊        | 11/60 [00:07<00:33,  1.46it/s]

1.919841832584805


 20%|██        | 12/60 [00:08<00:32,  1.47it/s]

1.8270450035730998


 22%|██▏       | 13/60 [00:08<00:32,  1.46it/s]

1.7231415377722845


 23%|██▎       | 14/60 [00:09<00:30,  1.49it/s]

1.6210342645645142


 25%|██▌       | 15/60 [00:10<00:31,  1.41it/s]

1.5294113688998752


 27%|██▋       | 16/60 [00:11<00:30,  1.43it/s]

1.450250718328688


 28%|██▊       | 17/60 [00:11<00:29,  1.46it/s]

1.386241528722975


 30%|███       | 18/60 [00:12<00:28,  1.46it/s]

1.3329377174377441


 32%|███▏      | 19/60 [00:13<00:27,  1.48it/s]

1.2893135680092707


 33%|███▎      | 20/60 [00:13<00:26,  1.49it/s]

1.2587706645329793


 35%|███▌      | 21/60 [00:14<00:26,  1.48it/s]

1.2247504194577534


 37%|███▋      | 22/60 [00:15<00:25,  1.48it/s]

1.1857063240475125


 38%|███▊      | 23/60 [00:15<00:24,  1.50it/s]

1.157996694246928


 40%|████      | 24/60 [00:16<00:24,  1.49it/s]

1.1310215592384338


 42%|████▏     | 25/60 [00:17<00:23,  1.51it/s]

1.1142139037450154


 43%|████▎     | 26/60 [00:17<00:22,  1.50it/s]

1.080123033788469


 45%|████▌     | 27/60 [00:18<00:21,  1.51it/s]

1.0639443198839824


 47%|████▋     | 28/60 [00:19<00:21,  1.50it/s]

1.0502040452427335


 48%|████▊     | 29/60 [00:19<00:20,  1.49it/s]

1.0310747424761455


 50%|█████     | 30/60 [00:20<00:20,  1.48it/s]

1.004809856414795


 52%|█████▏    | 31/60 [00:21<00:19,  1.49it/s]

0.9879322912957933


 53%|█████▎    | 32/60 [00:21<00:18,  1.49it/s]

0.9695126546753777


 55%|█████▌    | 33/60 [00:22<00:18,  1.45it/s]

0.9758769671122233


 57%|█████▋    | 34/60 [00:23<00:17,  1.48it/s]

0.9641458988189697


 58%|█████▊    | 35/60 [00:23<00:17,  1.43it/s]

0.9783777859475877


 60%|██████    | 36/60 [00:24<00:17,  1.40it/s]

0.9521188272370232


 62%|██████▏   | 37/60 [00:25<00:15,  1.44it/s]

0.9097718132866753


 63%|██████▎   | 38/60 [00:26<00:15,  1.39it/s]

0.8669042388598124


 65%|██████▌   | 39/60 [00:26<00:14,  1.41it/s]

0.8519044319788615


 67%|██████▋   | 40/60 [00:27<00:14,  1.42it/s]

0.8277674118677775


 68%|██████▊   | 41/60 [00:28<00:13,  1.41it/s]

0.8147961298624674


 70%|███████   | 42/60 [00:28<00:12,  1.44it/s]

0.8030693895286984


 72%|███████▏  | 43/60 [00:29<00:11,  1.45it/s]

0.801741811964247


 73%|███████▎  | 44/60 [00:30<00:11,  1.45it/s]

0.7976549996270074


 75%|███████▌  | 45/60 [00:30<00:10,  1.47it/s]

0.8074642717838287


 77%|███████▋  | 46/60 [00:31<00:09,  1.45it/s]

0.8124683234426711


 78%|███████▊  | 47/60 [00:32<00:08,  1.47it/s]

0.7955210043324364


 80%|████████  | 48/60 [00:32<00:08,  1.44it/s]

0.7782979177104102


 82%|████████▏ | 49/60 [00:33<00:07,  1.40it/s]

0.7718973954518636


 83%|████████▎ | 50/60 [00:34<00:07,  1.38it/s]

0.771431671248542


 85%|████████▌ | 51/60 [00:35<00:06,  1.41it/s]

0.7654221554597219


 87%|████████▋ | 52/60 [00:35<00:05,  1.38it/s]

0.7584561208883921


 88%|████████▊ | 53/60 [00:36<00:05,  1.34it/s]

0.7306431002087064


 90%|█████████ | 54/60 [00:37<00:04,  1.37it/s]

0.7008335027429793


 92%|█████████▏| 55/60 [00:38<00:03,  1.42it/s]

0.6869644754462771


 93%|█████████▎| 56/60 [00:38<00:02,  1.45it/s]

0.67476487159729


 95%|█████████▌| 57/60 [00:39<00:02,  1.45it/s]

0.6579834785726335


 97%|█████████▋| 58/60 [00:40<00:01,  1.47it/s]

0.6545898814996084


 98%|█████████▊| 59/60 [00:40<00:00,  1.48it/s]

0.6534211999840207


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

0.6679081850581698


  2%|▏         | 1/60 [00:00<00:41,  1.41it/s]

0.6729084220197465


  3%|▎         | 2/60 [00:01<00:40,  1.45it/s]

0.642169800069597


  5%|▌         | 3/60 [00:02<00:38,  1.47it/s]

0.6194452775849236


  7%|▋         | 4/60 [00:02<00:37,  1.49it/s]

0.6069211496247185


  8%|▊         | 5/60 [00:03<00:37,  1.48it/s]

0.6019570198323991


 10%|█         | 6/60 [00:04<00:36,  1.48it/s]

0.5935240023665957


 12%|█▏        | 7/60 [00:04<00:35,  1.47it/s]

0.5897727443112267


 13%|█▎        | 8/60 [00:05<00:35,  1.48it/s]

0.5885518855518765


 15%|█▌        | 9/60 [00:06<00:34,  1.47it/s]

0.5844520827134451


 17%|█▋        | 10/60 [00:06<00:33,  1.49it/s]

0.5835145976808336


 18%|█▊        | 11/60 [00:07<00:33,  1.48it/s]

0.5839147137271034


 20%|██        | 12/60 [00:08<00:33,  1.44it/s]

0.5855883393022749


 22%|██▏       | 13/60 [00:08<00:32,  1.47it/s]

0.5801976720492045


 23%|██▎       | 14/60 [00:09<00:31,  1.47it/s]

0.5823647942807939


 25%|██▌       | 15/60 [00:10<00:30,  1.47it/s]

0.5796844296985202


 27%|██▋       | 16/60 [00:10<00:29,  1.49it/s]

0.5790471666389041


 28%|██▊       | 17/60 [00:11<00:28,  1.51it/s]

0.5801929765277438


 30%|███       | 18/60 [00:12<00:27,  1.51it/s]

0.575655010011461


 32%|███▏      | 19/60 [00:12<00:26,  1.52it/s]

0.5744518207179176


 33%|███▎      | 20/60 [00:13<00:26,  1.53it/s]

0.5762001011106703


 35%|███▌      | 21/60 [00:14<00:25,  1.53it/s]

0.5755659706062741


 37%|███▋      | 22/60 [00:14<00:24,  1.54it/s]

0.5752431021796333


 38%|███▊      | 23/60 [00:15<00:23,  1.54it/s]

0.5716677407423655


 40%|████      | 24/60 [00:16<00:23,  1.51it/s]

0.5711139539877573


 42%|████▏     | 25/60 [00:16<00:22,  1.53it/s]

0.5697775118880801


 43%|████▎     | 26/60 [00:17<00:22,  1.54it/s]

0.5687599745061662


 45%|████▌     | 27/60 [00:18<00:21,  1.51it/s]

0.569854564136929


 47%|████▋     | 28/60 [00:18<00:21,  1.49it/s]

0.5720549987422096


 48%|████▊     | 29/60 [00:19<00:21,  1.47it/s]

0.5642567111386193


 50%|█████     | 30/60 [00:20<00:20,  1.47it/s]

0.5673549440171983


 52%|█████▏    | 31/60 [00:20<00:19,  1.46it/s]

0.5671780639224582


 53%|█████▎    | 32/60 [00:21<00:18,  1.48it/s]

0.5688295265038809


 55%|█████▌    | 33/60 [00:22<00:18,  1.42it/s]

0.5647474461131625


 57%|█████▋    | 34/60 [00:22<00:17,  1.45it/s]

0.5639675590727065


 58%|█████▊    | 35/60 [00:23<00:16,  1.47it/s]

0.5643688440322876


 60%|██████    | 36/60 [00:24<00:16,  1.49it/s]

0.5649922721915774


 62%|██████▏   | 37/60 [00:24<00:15,  1.49it/s]

0.565628707408905


 63%|██████▎   | 38/60 [00:25<00:14,  1.48it/s]

0.5607301460372077


 65%|██████▌   | 39/60 [00:26<00:14,  1.47it/s]

0.564148293601142


 67%|██████▋   | 40/60 [00:26<00:13,  1.47it/s]

0.5623478922579024


 68%|██████▊   | 41/60 [00:27<00:13,  1.46it/s]

0.5619058741463555


 70%|███████   | 42/60 [00:28<00:12,  1.48it/s]

0.5608940455648634


 72%|███████▏  | 43/60 [00:28<00:11,  1.48it/s]

0.5613431533177694


 73%|███████▎  | 44/60 [00:29<00:10,  1.48it/s]

0.561685260799196


 75%|███████▌  | 45/60 [00:30<00:10,  1.47it/s]

0.5602678491009606


 77%|███████▋  | 46/60 [00:30<00:09,  1.47it/s]

0.557510687245263


 78%|███████▊  | 47/60 [00:31<00:08,  1.48it/s]

0.5578052004178365


 80%|████████  | 48/60 [00:32<00:08,  1.49it/s]

0.558052831225925


 82%|████████▏ | 49/60 [00:32<00:07,  1.51it/s]

0.5594621102015177


 83%|████████▎ | 50/60 [00:33<00:06,  1.50it/s]

0.5567679372098711


 85%|████████▌ | 51/60 [00:34<00:05,  1.51it/s]

0.5572213100062476


 87%|████████▋ | 52/60 [00:34<00:05,  1.50it/s]

0.556281202369266


 88%|████████▊ | 53/60 [00:35<00:04,  1.50it/s]

0.5558694998423258


 90%|█████████ | 54/60 [00:36<00:04,  1.49it/s]

0.5580971042315165


 92%|█████████▏| 55/60 [00:36<00:03,  1.50it/s]

0.5547243621614244


 93%|█████████▎| 56/60 [00:37<00:02,  1.49it/s]

0.5555216636922624


 95%|█████████▌| 57/60 [00:38<00:01,  1.51it/s]

0.5522981319162581


 97%|█████████▋| 58/60 [00:38<00:01,  1.48it/s]

0.5561354094081454


 98%|█████████▊| 59/60 [00:39<00:00,  1.48it/s]

0.5543147524197897


100%|██████████| 60/60 [00:40<00:00,  1.49it/s]

0.5512393712997437
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [29]:
model_score1

{'Bleu_1': 0.563616964523848,
 'Bleu_2': 0.4294634886655052,
 'Bleu_3': 0.3459804433826117,
 'Bleu_4': 0.2886627182799963,
 'METEOR': 0.25157413391451316,
 'ROUGE_L': 0.46926991145831937,
 'CIDEr': 1.4617400972966967,
 'SPICE': 0.3152712089117034,
 'USC_similarity': 0.5413237524817693}

In [30]:
caption_model2, model_score2 = cross_validation(cv[1][0], cv[1][1], 2)    

Split 2:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2688 ==> 916
The vocabulary size is 917.
822 out of 917 words are found in the pre-trained matrix.
The size of embedding_matrix is (917, 200)
Preparing dataloader...

Generating set took: 0:03:32.13


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.25
Training...


  2%|▏         | 1/60 [00:00<00:40,  1.46it/s]

6.968498335944282


  3%|▎         | 2/60 [00:01<00:39,  1.46it/s]

4.839290248023139


  5%|▌         | 3/60 [00:02<00:38,  1.48it/s]

4.595922152201335


  7%|▋         | 4/60 [00:02<00:37,  1.49it/s]

4.352124479081896


  8%|▊         | 5/60 [00:03<00:39,  1.38it/s]

3.9605274200439453


 10%|█         | 6/60 [00:04<00:38,  1.40it/s]

3.4368532763587103


 12%|█▏        | 7/60 [00:04<00:37,  1.41it/s]

2.9931491216023765


 13%|█▎        | 8/60 [00:05<00:36,  1.42it/s]

2.6446221669514975


 15%|█▌        | 9/60 [00:06<00:35,  1.43it/s]

2.3780778778923883


 17%|█▋        | 10/60 [00:06<00:34,  1.43it/s]

2.1875472995969982


 18%|█▊        | 11/60 [00:07<00:33,  1.46it/s]

2.0354400475819907


 20%|██        | 12/60 [00:08<00:32,  1.46it/s]

1.9134620163175795


 22%|██▏       | 13/60 [00:09<00:32,  1.47it/s]

1.8153112729390461


 23%|██▎       | 14/60 [00:09<00:31,  1.47it/s]

1.7259795798195734


 25%|██▌       | 15/60 [00:10<00:30,  1.46it/s]

1.648133118947347


 27%|██▋       | 16/60 [00:11<00:30,  1.45it/s]

1.580843898985121


 28%|██▊       | 17/60 [00:11<00:29,  1.47it/s]

1.5311067899068196


 30%|███       | 18/60 [00:12<00:28,  1.48it/s]

1.4703310992982652


 32%|███▏      | 19/60 [00:13<00:27,  1.47it/s]

1.4295155074861314


 33%|███▎      | 20/60 [00:13<00:26,  1.49it/s]

1.3791209724214342


 35%|███▌      | 21/60 [00:14<00:26,  1.48it/s]

1.3315604262881808


 37%|███▋      | 22/60 [00:15<00:25,  1.46it/s]

1.2865571909480624


 38%|███▊      | 23/60 [00:15<00:25,  1.46it/s]

1.2602758142683241


 40%|████      | 24/60 [00:16<00:24,  1.45it/s]

1.2380896144443088


 42%|████▏     | 25/60 [00:17<00:23,  1.47it/s]

1.2103419635030959


 43%|████▎     | 26/60 [00:17<00:23,  1.46it/s]

1.19101482629776


 45%|████▌     | 27/60 [00:18<00:22,  1.44it/s]

1.1754205624262493


 47%|████▋     | 28/60 [00:19<00:22,  1.43it/s]

1.1514366004202101


 48%|████▊     | 29/60 [00:19<00:21,  1.46it/s]

1.1409890453020732


 50%|█████     | 30/60 [00:20<00:20,  1.48it/s]

1.1250199410650465


 52%|█████▏    | 31/60 [00:21<00:19,  1.45it/s]

1.0977101392216153


 53%|█████▎    | 32/60 [00:22<00:19,  1.43it/s]

1.0712124307950337


 55%|█████▌    | 33/60 [00:22<00:19,  1.36it/s]

1.059319535891215


 57%|█████▋    | 34/60 [00:23<00:18,  1.39it/s]

1.0570904943678114


 58%|█████▊    | 35/60 [00:24<00:17,  1.41it/s]

1.0529725419150457


 60%|██████    | 36/60 [00:24<00:16,  1.43it/s]

1.026157193713718


 62%|██████▏   | 37/60 [00:25<00:15,  1.46it/s]

0.9960033694903055


 63%|██████▎   | 38/60 [00:26<00:15,  1.46it/s]

0.9735353920194838


 65%|██████▌   | 39/60 [00:26<00:14,  1.46it/s]

0.964208722114563


 67%|██████▋   | 40/60 [00:27<00:13,  1.46it/s]

0.9555155436197916


 68%|██████▊   | 41/60 [00:28<00:13,  1.46it/s]

0.9527316159672208


 70%|███████   | 42/60 [00:28<00:12,  1.47it/s]

0.9436991612116495


 72%|███████▏  | 43/60 [00:29<00:11,  1.47it/s]

0.9421046508683099


 73%|███████▎  | 44/60 [00:30<00:10,  1.48it/s]

0.9445252087381151


 75%|███████▌  | 45/60 [00:30<00:10,  1.49it/s]

0.957838859823015


 77%|███████▋  | 46/60 [00:31<00:09,  1.50it/s]

0.9604425430297852


 78%|███████▊  | 47/60 [00:32<00:08,  1.50it/s]

0.9716223875681559


 80%|████████  | 48/60 [00:32<00:07,  1.51it/s]

0.9823964436848959


 82%|████████▏ | 49/60 [00:33<00:07,  1.51it/s]

0.9558375080426534


 83%|████████▎ | 50/60 [00:34<00:06,  1.51it/s]

0.9139946103096008


 85%|████████▌ | 51/60 [00:34<00:06,  1.50it/s]

0.8841617041163974


 87%|████████▋ | 52/60 [00:35<00:05,  1.49it/s]

0.860255648692449


 88%|████████▊ | 53/60 [00:36<00:04,  1.47it/s]

0.8506466150283813


 90%|█████████ | 54/60 [00:37<00:04,  1.48it/s]

0.8487492236826155


 92%|█████████▏| 55/60 [00:37<00:03,  1.49it/s]

0.8607657816674974


 93%|█████████▎| 56/60 [00:38<00:02,  1.48it/s]

0.8769311573770311


 95%|█████████▌| 57/60 [00:39<00:02,  1.47it/s]

0.8880375358793471


 97%|█████████▋| 58/60 [00:39<00:01,  1.46it/s]

0.8735153939988878


 98%|█████████▊| 59/60 [00:40<00:00,  1.45it/s]

0.8362249996927049


100%|██████████| 60/60 [00:41<00:00,  1.46it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

0.8041796154446073


  2%|▏         | 1/60 [00:00<00:39,  1.50it/s]

0.7878153853946261


  3%|▎         | 2/60 [00:01<00:38,  1.49it/s]

0.7663825253645579


  5%|▌         | 3/60 [00:02<00:38,  1.48it/s]

0.7498825225565169


  7%|▋         | 4/60 [00:02<00:37,  1.48it/s]

0.7390130029784309


  8%|▊         | 5/60 [00:03<00:36,  1.49it/s]

0.7324656711684333


 10%|█         | 6/60 [00:04<00:36,  1.48it/s]

0.7261454198095534


 12%|█▏        | 7/60 [00:04<00:35,  1.49it/s]

0.723550041516622


 13%|█▎        | 8/60 [00:05<00:35,  1.48it/s]

0.7235908110936483


 15%|█▌        | 9/60 [00:06<00:34,  1.47it/s]

0.719998362991545


 17%|█▋        | 10/60 [00:06<00:34,  1.47it/s]

0.7155844734774696


 18%|█▊        | 11/60 [00:07<00:33,  1.48it/s]

0.715645025173823


 20%|██        | 12/60 [00:08<00:32,  1.47it/s]

0.7181217504872216


 22%|██▏       | 13/60 [00:08<00:31,  1.48it/s]

0.7131497661272684


 23%|██▎       | 14/60 [00:09<00:31,  1.48it/s]

0.7140128645632002


 25%|██▌       | 15/60 [00:10<00:30,  1.47it/s]

0.7120694782998827


 27%|██▋       | 16/60 [00:10<00:30,  1.46it/s]

0.7099617256058587


 28%|██▊       | 17/60 [00:11<00:29,  1.48it/s]

0.7085078226195441


 30%|███       | 18/60 [00:12<00:28,  1.47it/s]

0.7095502250724368


 32%|███▏      | 19/60 [00:12<00:27,  1.47it/s]

0.708434498972363


 33%|███▎      | 20/60 [00:13<00:26,  1.48it/s]

0.7076157795058357


 35%|███▌      | 21/60 [00:14<00:26,  1.48it/s]

0.7009152505132887


 37%|███▋      | 22/60 [00:14<00:25,  1.47it/s]

0.7072079016102685


 38%|███▊      | 23/60 [00:15<00:24,  1.49it/s]

0.702832493517134


 40%|████      | 24/60 [00:16<00:24,  1.48it/s]

0.7048796481556363


 42%|████▏     | 25/60 [00:16<00:23,  1.47it/s]

0.7029887735843658


 43%|████▎     | 26/60 [00:17<00:23,  1.48it/s]

0.7060040997134315


 45%|████▌     | 27/60 [00:18<00:22,  1.47it/s]

0.6999369727240669


 47%|████▋     | 28/60 [00:19<00:21,  1.46it/s]

0.701446039809121


 48%|████▊     | 29/60 [00:19<00:22,  1.35it/s]

0.7001146442360349


 50%|█████     | 30/60 [00:20<00:21,  1.39it/s]

0.7002322971820831


 52%|█████▏    | 31/60 [00:21<00:20,  1.39it/s]

0.6969302694002787


 53%|█████▎    | 32/60 [00:21<00:19,  1.43it/s]

0.6977107061280144


 55%|█████▌    | 33/60 [00:22<00:18,  1.46it/s]

0.6970432698726654


 57%|█████▋    | 34/60 [00:23<00:17,  1.48it/s]

0.6967293918132782


 58%|█████▊    | 35/60 [00:23<00:16,  1.49it/s]

0.6953974631097581


 60%|██████    | 36/60 [00:24<00:16,  1.50it/s]

0.6945486002498202


 62%|██████▏   | 37/60 [00:25<00:15,  1.50it/s]

0.697061022122701


 63%|██████▎   | 38/60 [00:25<00:14,  1.51it/s]

0.6946071286996206


 65%|██████▌   | 39/60 [00:26<00:13,  1.51it/s]

0.6946765416198306


 67%|██████▋   | 40/60 [00:27<00:13,  1.51it/s]

0.6954101887014177


 68%|██████▊   | 41/60 [00:27<00:12,  1.51it/s]

0.6953517099221548


 70%|███████   | 42/60 [00:28<00:12,  1.48it/s]

0.6913073129124112


 72%|███████▏  | 43/60 [00:29<00:11,  1.49it/s]

0.6925005184279548


 73%|███████▎  | 44/60 [00:29<00:10,  1.49it/s]

0.6912099619706472


 75%|███████▌  | 45/60 [00:30<00:10,  1.49it/s]

0.6921149988969167


 77%|███████▋  | 46/60 [00:31<00:09,  1.50it/s]

0.6916845871342553


 78%|███████▊  | 47/60 [00:31<00:08,  1.50it/s]

0.6890014542473687


 80%|████████  | 48/60 [00:32<00:07,  1.51it/s]

0.6887686120139228


 82%|████████▏ | 49/60 [00:33<00:07,  1.51it/s]

0.6877278155750699


 83%|████████▎ | 50/60 [00:33<00:06,  1.50it/s]

0.6885522140396966


 85%|████████▌ | 51/60 [00:34<00:05,  1.50it/s]

0.6872084968619876


 87%|████████▋ | 52/60 [00:35<00:05,  1.48it/s]

0.6872152388095856


 88%|████████▊ | 53/60 [00:35<00:04,  1.48it/s]

0.6892295148637559


 90%|█████████ | 54/60 [00:36<00:04,  1.49it/s]

0.6865871051947275


 92%|█████████▏| 55/60 [00:37<00:03,  1.50it/s]

0.6870051556163363


 93%|█████████▎| 56/60 [00:37<00:02,  1.51it/s]

0.6850443151262071


 95%|█████████▌| 57/60 [00:38<00:01,  1.51it/s]

0.6858474877145555


 97%|█████████▋| 58/60 [00:39<00:01,  1.52it/s]

0.6823625597688887


 98%|█████████▊| 59/60 [00:39<00:00,  1.50it/s]

0.6844095720185174


100%|██████████| 60/60 [00:40<00:00,  1.48it/s]

0.6845011280642616
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [31]:
model_score2

{'Bleu_1': 0.5785459859347621,
 'Bleu_2': 0.4448593124302879,
 'Bleu_3': 0.36235505770352855,
 'Bleu_4': 0.3058452890340524,
 'METEOR': 0.2560589259488988,
 'ROUGE_L': 0.4823575766155836,
 'CIDEr': 1.6056737477331315,
 'SPICE': 0.3298222080502241,
 'USC_similarity': 0.5566665730939439}

In [32]:
caption_model3, model_score3 = cross_validation(cv[2][0], cv[2][1], 3)    

Split 3:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2714 ==> 890
The vocabulary size is 891.
804 out of 891 words are found in the pre-trained matrix.
The size of embedding_matrix is (891, 200)
Preparing dataloader...

Generating set took: 0:03:31.64


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:52.99
Training...


  2%|▏         | 1/60 [00:00<00:40,  1.45it/s]

7.562453534868029


  3%|▎         | 2/60 [00:01<00:39,  1.47it/s]

4.861868699391683


  5%|▌         | 3/60 [00:02<00:38,  1.46it/s]

4.553022808498806


  7%|▋         | 4/60 [00:02<00:37,  1.49it/s]

4.230864153967963


  8%|▊         | 5/60 [00:03<00:37,  1.47it/s]

3.7795336511400013


 10%|█         | 6/60 [00:04<00:37,  1.46it/s]

3.2158266968197293


 12%|█▏        | 7/60 [00:04<00:39,  1.34it/s]

2.8176759348975287


 13%|█▎        | 8/60 [00:05<00:37,  1.37it/s]

2.5043775770399304


 15%|█▌        | 9/60 [00:06<00:36,  1.39it/s]

2.2829039096832275


 17%|█▋        | 10/60 [00:07<00:35,  1.43it/s]

2.1068099207348294


 18%|█▊        | 11/60 [00:07<00:33,  1.45it/s]

1.9665739668740168


 20%|██        | 12/60 [00:08<00:32,  1.47it/s]

1.8497486379411485


 22%|██▏       | 13/60 [00:08<00:31,  1.49it/s]

1.7526458236906264


 23%|██▎       | 14/60 [00:09<00:30,  1.50it/s]

1.6667975982030232


 25%|██▌       | 15/60 [00:10<00:29,  1.51it/s]

1.5894825723436143


 27%|██▋       | 16/60 [00:10<00:29,  1.51it/s]

1.5254672765731812


 28%|██▊       | 17/60 [00:11<00:28,  1.52it/s]

1.4642915195888944


 30%|███       | 18/60 [00:12<00:27,  1.52it/s]

1.407977991633945


 32%|███▏      | 19/60 [00:12<00:26,  1.52it/s]

1.343541635407342


 33%|███▎      | 20/60 [00:13<00:26,  1.52it/s]

1.2971638043721516


 35%|███▌      | 21/60 [00:14<00:25,  1.52it/s]

1.2592858341005113


 37%|███▋      | 22/60 [00:14<00:24,  1.52it/s]

1.237511846754286


 38%|███▊      | 23/60 [00:15<00:24,  1.50it/s]

1.2010177440113492


 40%|████      | 24/60 [00:16<00:24,  1.49it/s]

1.164326720767551


 42%|████▏     | 25/60 [00:16<00:23,  1.50it/s]

1.138938652144538


 43%|████▎     | 26/60 [00:17<00:22,  1.51it/s]

1.1263063881132338


 45%|████▌     | 27/60 [00:18<00:22,  1.48it/s]

1.121654252211253


 47%|████▋     | 28/60 [00:18<00:21,  1.49it/s]

1.091657002766927


 48%|████▊     | 29/60 [00:19<00:20,  1.48it/s]

1.06939344935947


 50%|█████     | 30/60 [00:20<00:20,  1.48it/s]

1.0397120515505474


 52%|█████▏    | 31/60 [00:20<00:19,  1.49it/s]

1.0233182311058044


 53%|█████▎    | 32/60 [00:21<00:18,  1.50it/s]

1.0087597171465557


 55%|█████▌    | 33/60 [00:22<00:18,  1.48it/s]

0.9967208703358968


 57%|█████▋    | 34/60 [00:22<00:17,  1.49it/s]

0.988489932484097


 58%|█████▊    | 35/60 [00:23<00:16,  1.48it/s]

0.9597715139389038


 60%|██████    | 36/60 [00:24<00:16,  1.46it/s]

0.9444025556246439


 62%|██████▏   | 37/60 [00:25<00:15,  1.46it/s]

0.9356394873725044


 63%|██████▎   | 38/60 [00:25<00:15,  1.46it/s]

0.9229292935795255


 65%|██████▌   | 39/60 [00:26<00:14,  1.46it/s]

0.8997549878226386


 67%|██████▋   | 40/60 [00:27<00:13,  1.46it/s]

0.8687831362088522


 68%|██████▊   | 41/60 [00:27<00:12,  1.47it/s]

0.8441657854451073


 70%|███████   | 42/60 [00:28<00:12,  1.49it/s]

0.8180964257982042


 72%|███████▏  | 43/60 [00:29<00:11,  1.48it/s]

0.804929667048984


 73%|███████▎  | 44/60 [00:29<00:10,  1.49it/s]

0.800736837916904


 75%|███████▌  | 45/60 [00:30<00:10,  1.50it/s]

0.7850649456183115


 77%|███████▋  | 46/60 [00:31<00:09,  1.49it/s]

0.7909561230076684


 78%|███████▊  | 47/60 [00:31<00:08,  1.46it/s]

0.7985611293050978


 80%|████████  | 48/60 [00:32<00:08,  1.46it/s]

0.7928850750128428


 82%|████████▏ | 49/60 [00:33<00:07,  1.48it/s]

0.7991997500260671


 83%|████████▎ | 50/60 [00:33<00:06,  1.49it/s]

0.8089614477422502


 85%|████████▌ | 51/60 [00:34<00:05,  1.50it/s]

0.8206935425599416


 87%|████████▋ | 52/60 [00:35<00:05,  1.50it/s]

0.8250993225309584


 88%|████████▊ | 53/60 [00:35<00:04,  1.51it/s]

0.8162096606360542


 90%|█████████ | 54/60 [00:36<00:03,  1.51it/s]

0.7929454147815704


 92%|█████████▏| 55/60 [00:37<00:03,  1.51it/s]

0.7676043444209628


 93%|█████████▎| 56/60 [00:37<00:02,  1.52it/s]

0.739208416806327


 95%|█████████▌| 57/60 [00:38<00:01,  1.52it/s]

0.7194038530190786


 97%|█████████▋| 58/60 [00:39<00:01,  1.52it/s]

0.7093147933483124


 98%|█████████▊| 59/60 [00:39<00:00,  1.49it/s]

0.6903580990102556


100%|██████████| 60/60 [00:40<00:00,  1.48it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

0.6901587115393745


  2%|▏         | 1/60 [00:00<00:38,  1.52it/s]

0.7005951205889384


  3%|▎         | 2/60 [00:01<00:38,  1.52it/s]

0.6700332462787628


  5%|▌         | 3/60 [00:01<00:37,  1.51it/s]

0.6377748648325602


  7%|▋         | 4/60 [00:02<00:37,  1.49it/s]

0.6230705877145132


  8%|▊         | 5/60 [00:03<00:37,  1.48it/s]

0.6119989487859938


 10%|█         | 6/60 [00:04<00:36,  1.48it/s]

0.6077804002496932


 12%|█▏        | 7/60 [00:04<00:36,  1.47it/s]

0.6026480926407708


 13%|█▎        | 8/60 [00:05<00:35,  1.45it/s]

0.5994108551078372


 15%|█▌        | 9/60 [00:06<00:35,  1.45it/s]

0.5981231033802032


 17%|█▋        | 10/60 [00:06<00:34,  1.45it/s]

0.5941537519296011


 18%|█▊        | 11/60 [00:07<00:33,  1.45it/s]

0.5928024252255758


 20%|██        | 12/60 [00:08<00:32,  1.48it/s]

0.5895608961582184


 22%|██▏       | 13/60 [00:08<00:33,  1.40it/s]

0.586419423421224


 23%|██▎       | 14/60 [00:09<00:34,  1.35it/s]

0.589569005701277


 25%|██▌       | 15/60 [00:10<00:32,  1.38it/s]

0.5869067509969076


 27%|██▋       | 16/60 [00:11<00:30,  1.42it/s]

0.5834913982285393


 28%|██▊       | 17/60 [00:11<00:29,  1.45it/s]

0.5846842858526442


 30%|███       | 18/60 [00:12<00:28,  1.47it/s]

0.5807293885284


 32%|███▏      | 19/60 [00:13<00:27,  1.48it/s]

0.5785483850373162


 33%|███▎      | 20/60 [00:13<00:27,  1.47it/s]

0.5781066715717316


 35%|███▌      | 21/60 [00:14<00:26,  1.48it/s]

0.5789741145239936


 37%|███▋      | 22/60 [00:15<00:25,  1.48it/s]

0.576576421658198


 38%|███▊      | 23/60 [00:15<00:24,  1.50it/s]

0.5758328404691484


 40%|████      | 24/60 [00:16<00:23,  1.51it/s]

0.5744055112202963


 42%|████▏     | 25/60 [00:17<00:23,  1.51it/s]

0.5695056948396895


 43%|████▎     | 26/60 [00:17<00:22,  1.52it/s]

0.5712289578384824


 45%|████▌     | 27/60 [00:18<00:21,  1.51it/s]

0.5740855468644036


 47%|████▋     | 28/60 [00:19<00:21,  1.52it/s]

0.5701034135288663


 48%|████▊     | 29/60 [00:19<00:20,  1.52it/s]

0.571451723575592


 50%|█████     | 30/60 [00:20<00:19,  1.51it/s]

0.5711054305235544


 52%|█████▏    | 31/60 [00:21<00:19,  1.51it/s]

0.567142085896598


 53%|█████▎    | 32/60 [00:21<00:18,  1.52it/s]

0.5699804458353255


 55%|█████▌    | 33/60 [00:22<00:17,  1.52it/s]

0.5647822419802347


 57%|█████▋    | 34/60 [00:22<00:17,  1.53it/s]

0.5666676031218635


 58%|█████▊    | 35/60 [00:23<00:16,  1.50it/s]

0.5680877301428053


 60%|██████    | 36/60 [00:24<00:15,  1.50it/s]

0.5648586683803134


 62%|██████▏   | 37/60 [00:25<00:15,  1.51it/s]

0.566739430030187


 63%|██████▎   | 38/60 [00:25<00:14,  1.49it/s]

0.5652131074004703


 65%|██████▌   | 39/60 [00:26<00:14,  1.47it/s]

0.5634404189056821


 67%|██████▋   | 40/60 [00:27<00:13,  1.49it/s]

0.5626689791679382


 68%|██████▊   | 41/60 [00:27<00:12,  1.47it/s]

0.565568271610472


 70%|███████   | 42/60 [00:28<00:12,  1.48it/s]

0.5617318020926582


 72%|███████▏  | 43/60 [00:29<00:11,  1.49it/s]

0.5642505817943149


 73%|███████▎  | 44/60 [00:29<00:10,  1.47it/s]

0.5636290609836578


 75%|███████▌  | 45/60 [00:30<00:10,  1.49it/s]

0.5624396469857957


 77%|███████▋  | 46/60 [00:31<00:09,  1.48it/s]

0.5570059286223518


 78%|███████▊  | 47/60 [00:31<00:08,  1.47it/s]

0.5601380235619016


 80%|████████  | 48/60 [00:32<00:08,  1.48it/s]

0.5607609781954024


 82%|████████▏ | 49/60 [00:33<00:07,  1.49it/s]

0.5578591922918955


 83%|████████▎ | 50/60 [00:33<00:06,  1.49it/s]

0.5611722701125674


 85%|████████▌ | 51/60 [00:34<00:05,  1.50it/s]

0.5585713850127326


 87%|████████▋ | 52/60 [00:35<00:05,  1.51it/s]

0.5597295165061951


 88%|████████▊ | 53/60 [00:35<00:04,  1.52it/s]

0.5553351375791762


 90%|█████████ | 54/60 [00:36<00:03,  1.52it/s]

0.5585647622744242


 92%|█████████▏| 55/60 [00:37<00:03,  1.50it/s]

0.5554764866828918


 93%|█████████▎| 56/60 [00:37<00:02,  1.47it/s]

0.5582125749852922


 95%|█████████▌| 57/60 [00:38<00:02,  1.48it/s]

0.552367604441113


 97%|█████████▋| 58/60 [00:39<00:01,  1.49it/s]

0.5556019047896067


 98%|█████████▊| 59/60 [00:40<00:00,  1.34it/s]

0.5545858773920271


100%|██████████| 60/60 [00:40<00:00,  1.47it/s]

0.5561814374393887
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [33]:
model_score3

{'Bleu_1': 0.574418193080233,
 'Bleu_2': 0.4421688182439931,
 'Bleu_3': 0.35924982053440085,
 'Bleu_4': 0.3017466121523458,
 'METEOR': 0.2624485798782999,
 'ROUGE_L': 0.48238566719182535,
 'CIDEr': 1.6307152663418596,
 'SPICE': 0.332317953874064,
 'USC_similarity': 0.5571499854480948}

In [34]:
caption_model4, model_score4 = cross_validation(cv[3][0], cv[3][1], 4)    

Split 4:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2680 ==> 894
The vocabulary size is 895.
809 out of 895 words are found in the pre-trained matrix.
The size of embedding_matrix is (895, 200)
Preparing dataloader...

Generating set took: 0:03:29.64


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:52.32
Training...


  2%|▏         | 1/60 [00:00<00:41,  1.44it/s]

6.674317995707194


  3%|▎         | 2/60 [00:01<00:40,  1.45it/s]

4.94692325592041


  5%|▌         | 3/60 [00:02<00:39,  1.45it/s]

4.7411514388190374


  7%|▋         | 4/60 [00:02<00:38,  1.45it/s]

4.389904552035862


  8%|▊         | 5/60 [00:03<00:37,  1.48it/s]

3.9909693400065103


 10%|█         | 6/60 [00:04<00:36,  1.47it/s]

3.4395955668555365


 12%|█▏        | 7/60 [00:04<00:35,  1.49it/s]

3.0308054288228354


 13%|█▎        | 8/60 [00:05<00:35,  1.47it/s]

2.738634983698527


 15%|█▌        | 9/60 [00:06<00:34,  1.46it/s]

2.4888557857937283


 17%|█▋        | 10/60 [00:06<00:34,  1.46it/s]

2.2869721253712973


 18%|█▊        | 11/60 [00:07<00:33,  1.46it/s]

2.1194562779532538


 20%|██        | 12/60 [00:08<00:32,  1.46it/s]

1.9883422586652968


 22%|██▏       | 13/60 [00:08<00:32,  1.45it/s]

1.8768106698989868


 23%|██▎       | 14/60 [00:09<00:31,  1.45it/s]

1.7803514003753662


 25%|██▌       | 15/60 [00:10<00:31,  1.42it/s]

1.7022605074776544


 27%|██▋       | 16/60 [00:10<00:30,  1.45it/s]

1.6256775591108534


 28%|██▊       | 17/60 [00:11<00:29,  1.47it/s]

1.5616307788425021


 30%|███       | 18/60 [00:12<00:28,  1.46it/s]

1.502992762459649


 32%|███▏      | 19/60 [00:13<00:28,  1.46it/s]

1.4354997873306274


 33%|███▎      | 20/60 [00:13<00:27,  1.45it/s]

1.3872435887654622


 35%|███▌      | 21/60 [00:14<00:27,  1.41it/s]

1.3423565493689642


 37%|███▋      | 22/60 [00:15<00:26,  1.42it/s]

1.3008806374337938


 38%|███▊      | 23/60 [00:15<00:26,  1.40it/s]

1.2595281998316448


 40%|████      | 24/60 [00:16<00:25,  1.44it/s]

1.2205838097466364


 42%|████▏     | 25/60 [00:17<00:23,  1.46it/s]

1.1934135092629328


 43%|████▎     | 26/60 [00:17<00:23,  1.47it/s]

1.178222934405009


 45%|████▌     | 27/60 [00:18<00:22,  1.48it/s]

1.1690627535184224


 47%|████▋     | 28/60 [00:19<00:22,  1.45it/s]

1.1556638214323256


 48%|████▊     | 29/60 [00:19<00:21,  1.44it/s]

1.1271476878060236


 50%|█████     | 30/60 [00:20<00:20,  1.46it/s]

1.0870091848903232


 52%|█████▏    | 31/60 [00:21<00:19,  1.46it/s]

1.0379336542553372


 53%|█████▎    | 32/60 [00:21<00:19,  1.46it/s]

1.0159527990553114


 55%|█████▌    | 33/60 [00:22<00:18,  1.47it/s]

0.9934971597459581


 57%|█████▋    | 34/60 [00:23<00:17,  1.49it/s]

0.9775836931334602


 58%|█████▊    | 35/60 [00:23<00:16,  1.49it/s]

0.957735644446479


 60%|██████    | 36/60 [00:24<00:16,  1.48it/s]

0.9293661846054925


 62%|██████▏   | 37/60 [00:25<00:15,  1.47it/s]

0.9036832518047757


 63%|██████▎   | 38/60 [00:26<00:14,  1.49it/s]

0.8793012102444967


 65%|██████▌   | 39/60 [00:26<00:14,  1.47it/s]

0.8572797311676873


 67%|██████▋   | 40/60 [00:27<00:13,  1.48it/s]

0.8430519567595588


 68%|██████▊   | 41/60 [00:28<00:12,  1.47it/s]

0.8427747322453393


 70%|███████   | 42/60 [00:28<00:12,  1.49it/s]

0.8430986172623105


 72%|███████▏  | 43/60 [00:29<00:12,  1.32it/s]

0.8472269905938042


 73%|███████▎  | 44/60 [00:30<00:11,  1.36it/s]

0.8563219507535299


 75%|███████▌  | 45/60 [00:31<00:10,  1.39it/s]

0.8392560680707296


 77%|███████▋  | 46/60 [00:31<00:09,  1.42it/s]

0.8073639902803633


 78%|███████▊  | 47/60 [00:32<00:08,  1.45it/s]

0.7810443176163567


 80%|████████  | 48/60 [00:33<00:08,  1.47it/s]

0.7793405652046204


 82%|████████▏ | 49/60 [00:33<00:07,  1.48it/s]

0.7610534528891245


 83%|████████▎ | 50/60 [00:34<00:06,  1.48it/s]

0.731797436873118


 85%|████████▌ | 51/60 [00:35<00:06,  1.46it/s]

0.7194017271200815


 87%|████████▋ | 52/60 [00:35<00:05,  1.46it/s]

0.7038402689827813


 88%|████████▊ | 53/60 [00:36<00:04,  1.48it/s]

0.6981623437669542


 90%|█████████ | 54/60 [00:37<00:04,  1.47it/s]

0.7018850876225365


 92%|█████████▏| 55/60 [00:37<00:03,  1.46it/s]

0.6970381902323829


 93%|█████████▎| 56/60 [00:38<00:02,  1.48it/s]

0.7064537935786777


 95%|█████████▌| 57/60 [00:39<00:02,  1.49it/s]

0.7096636460887061


 97%|█████████▋| 58/60 [00:39<00:01,  1.48it/s]

0.7296372618940141


 98%|█████████▊| 59/60 [00:40<00:00,  1.49it/s]

0.7236628267500136


100%|██████████| 60/60 [00:41<00:00,  1.46it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

0.7278102768792046


  2%|▏         | 1/60 [00:00<00:46,  1.27it/s]

0.7034206125471327


  3%|▎         | 2/60 [00:01<00:43,  1.32it/s]

0.6740801864200168


  5%|▌         | 3/60 [00:02<00:41,  1.37it/s]

0.6496889127625359


  7%|▋         | 4/60 [00:02<00:40,  1.39it/s]

0.6323864327536689


  8%|▊         | 5/60 [00:03<00:39,  1.41it/s]

0.622483491897583


 10%|█         | 6/60 [00:04<00:37,  1.44it/s]

0.6182578537199233


 12%|█▏        | 7/60 [00:04<00:36,  1.44it/s]

0.6118407845497131


 13%|█▎        | 8/60 [00:05<00:36,  1.44it/s]

0.6095757650004493


 15%|█▌        | 9/60 [00:06<00:36,  1.41it/s]

0.606552435292138


 17%|█▋        | 10/60 [00:06<00:35,  1.43it/s]

0.6035411291652255


 18%|█▊        | 11/60 [00:07<00:34,  1.41it/s]

0.5995110472043356


 20%|██        | 12/60 [00:08<00:33,  1.43it/s]

0.6006060043970743


 22%|██▏       | 13/60 [00:09<00:32,  1.44it/s]

0.598889582686954


 23%|██▎       | 14/60 [00:09<00:31,  1.45it/s]

0.5955742299556732


 25%|██▌       | 15/60 [00:10<00:30,  1.45it/s]

0.5958088835080465


 27%|██▋       | 16/60 [00:11<00:30,  1.42it/s]

0.591465049319797


 28%|██▊       | 17/60 [00:11<00:30,  1.41it/s]

0.5896102653609382


 30%|███       | 18/60 [00:12<00:29,  1.42it/s]

0.591935031943851


 32%|███▏      | 19/60 [00:13<00:28,  1.45it/s]

0.5898158649603525


 33%|███▎      | 20/60 [00:13<00:27,  1.45it/s]

0.5875891149044037


 35%|███▌      | 21/60 [00:14<00:26,  1.45it/s]

0.5876754621664683


 37%|███▋      | 22/60 [00:15<00:26,  1.45it/s]

0.5871107975641886


 38%|███▊      | 23/60 [00:16<00:25,  1.45it/s]

0.5834018488725027


 40%|████      | 24/60 [00:16<00:24,  1.47it/s]

0.5823549826939901


 42%|████▏     | 25/60 [00:17<00:23,  1.46it/s]

0.5825100044409434


 43%|████▎     | 26/60 [00:18<00:23,  1.43it/s]

0.5838722586631775


 45%|████▌     | 27/60 [00:18<00:23,  1.43it/s]

0.5784650776121352


 47%|████▋     | 28/60 [00:19<00:21,  1.46it/s]

0.5800622569190131


 48%|████▊     | 29/60 [00:20<00:21,  1.47it/s]

0.5794006519847446


 50%|█████     | 30/60 [00:20<00:20,  1.49it/s]

0.5775232050153944


 52%|█████▏    | 31/60 [00:21<00:19,  1.49it/s]

0.5759829812579684


 53%|█████▎    | 32/60 [00:22<00:18,  1.49it/s]

0.5750317871570587


 55%|█████▌    | 33/60 [00:22<00:18,  1.47it/s]

0.5780500736501482


 57%|█████▋    | 34/60 [00:23<00:17,  1.48it/s]

0.574063198433982


 58%|█████▊    | 35/60 [00:24<00:16,  1.47it/s]

0.5722175306744046


 60%|██████    | 36/60 [00:24<00:16,  1.48it/s]

0.5748936500814226


 62%|██████▏   | 37/60 [00:25<00:15,  1.49it/s]

0.5731510586208768


 63%|██████▎   | 38/60 [00:26<00:14,  1.48it/s]

0.5717971258693271


 65%|██████▌   | 39/60 [00:26<00:14,  1.47it/s]

0.5722685390048556


 67%|██████▋   | 40/60 [00:27<00:13,  1.49it/s]

0.572830448547999


 68%|██████▊   | 41/60 [00:28<00:12,  1.50it/s]

0.5715024471282959


 70%|███████   | 42/60 [00:28<00:11,  1.50it/s]

0.5669505662388272


 72%|███████▏  | 43/60 [00:29<00:11,  1.49it/s]

0.5696987013022105


 73%|███████▎  | 44/60 [00:30<00:10,  1.49it/s]

0.5682326091660393


 75%|███████▌  | 45/60 [00:30<00:10,  1.47it/s]

0.5680545932716794


 77%|███████▋  | 46/60 [00:31<00:09,  1.46it/s]

0.5683516793780856


 78%|███████▊  | 47/60 [00:32<00:08,  1.47it/s]

0.5664929979377322


 80%|████████  | 48/60 [00:32<00:08,  1.47it/s]

0.5675447881221771


 82%|████████▏ | 49/60 [00:33<00:07,  1.46it/s]

0.5666383935345544


 83%|████████▎ | 50/60 [00:34<00:06,  1.47it/s]

0.5676566958427429


 85%|████████▌ | 51/60 [00:34<00:06,  1.47it/s]

0.5658147798644172


 87%|████████▋ | 52/60 [00:35<00:05,  1.47it/s]

0.5617484019862281


 88%|████████▊ | 53/60 [00:36<00:04,  1.46it/s]

0.5662431485123105


 90%|█████████ | 54/60 [00:37<00:04,  1.35it/s]

0.5622819662094116


 92%|█████████▏| 55/60 [00:37<00:03,  1.40it/s]

0.565586096710629


 93%|█████████▎| 56/60 [00:38<00:02,  1.42it/s]

0.5627955827448103


 95%|█████████▌| 57/60 [00:39<00:02,  1.45it/s]

0.559676875670751


 97%|█████████▋| 58/60 [00:39<00:01,  1.44it/s]

0.5642122560077243


 98%|█████████▊| 59/60 [00:40<00:00,  1.43it/s]

0.559569431675805


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]

0.5612466004159715
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [35]:
model_score4

{'Bleu_1': 0.569833634879283,
 'Bleu_2': 0.43622849826275784,
 'Bleu_3': 0.3534908746607434,
 'Bleu_4': 0.296716296814266,
 'METEOR': 0.26063582167232985,
 'ROUGE_L': 0.4822058687055563,
 'CIDEr': 1.5984258627891395,
 'SPICE': 0.33053591535303134,
 'USC_similarity': 0.5547139884228449}

In [36]:
caption_model5, model_score5 = cross_validation(cv[4][0], cv[4][1], 5)    

Split 5:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2657 ==> 905
The vocabulary size is 906.
818 out of 906 words are found in the pre-trained matrix.
The size of embedding_matrix is (906, 200)
Preparing dataloader...

Generating set took: 0:03:32.90


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.84
Training...


  2%|▏         | 1/60 [00:00<00:40,  1.47it/s]

5.819000932905409


  3%|▎         | 2/60 [00:01<00:38,  1.49it/s]

4.623998271094428


  5%|▌         | 3/60 [00:01<00:38,  1.49it/s]

4.311259534623888


  7%|▋         | 4/60 [00:02<00:38,  1.45it/s]

3.6859467559390597


  8%|▊         | 5/60 [00:03<00:38,  1.44it/s]

3.1441366937425403


 10%|█         | 6/60 [00:04<00:37,  1.43it/s]

2.7388337983025446


 12%|█▏        | 7/60 [00:04<00:38,  1.39it/s]

2.434957159890069


 13%|█▎        | 8/60 [00:05<00:37,  1.39it/s]

2.2209160327911377


 15%|█▌        | 9/60 [00:06<00:36,  1.38it/s]

2.0404233932495117


 17%|█▋        | 10/60 [00:07<00:35,  1.42it/s]

1.8972665866216023


 18%|█▊        | 11/60 [00:07<00:33,  1.45it/s]

1.7812832858827379


 20%|██        | 12/60 [00:08<00:33,  1.45it/s]

1.6727632284164429


 22%|██▏       | 13/60 [00:09<00:32,  1.44it/s]

1.5825706058078342


 23%|██▎       | 14/60 [00:09<00:31,  1.47it/s]

1.4916062884860568


 25%|██▌       | 15/60 [00:10<00:30,  1.48it/s]

1.433678600523207


 27%|██▋       | 16/60 [00:11<00:30,  1.44it/s]

1.391125480333964


 28%|██▊       | 17/60 [00:11<00:29,  1.46it/s]

1.3424060212241278


 30%|███       | 18/60 [00:12<00:28,  1.48it/s]

1.3033683233790927


 32%|███▏      | 19/60 [00:13<00:27,  1.47it/s]

1.2415456308258905


 33%|███▎      | 20/60 [00:13<00:27,  1.47it/s]

1.1984360218048096


 35%|███▌      | 21/60 [00:14<00:26,  1.47it/s]

1.1614536841710408


 37%|███▋      | 22/60 [00:15<00:25,  1.46it/s]

1.1331197685665555


 38%|███▊      | 23/60 [00:15<00:25,  1.46it/s]

1.1291412247551813


 40%|████      | 24/60 [00:16<00:24,  1.46it/s]

1.1109963986608717


 42%|████▏     | 25/60 [00:17<00:23,  1.48it/s]

1.068782091140747


 43%|████▎     | 26/60 [00:17<00:23,  1.47it/s]

1.0253342125150893


 45%|████▌     | 27/60 [00:18<00:22,  1.46it/s]

1.0007941193050809


 47%|████▋     | 28/60 [00:19<00:21,  1.46it/s]

0.9677770733833313


 48%|████▊     | 29/60 [00:19<00:21,  1.46it/s]

0.9487965967920091


 50%|█████     | 30/60 [00:20<00:20,  1.47it/s]

0.9391911692089505


 52%|█████▏    | 31/60 [00:21<00:20,  1.41it/s]

0.9170370234383477


 53%|█████▎    | 32/60 [00:22<00:20,  1.34it/s]

0.9060399532318115


 55%|█████▌    | 33/60 [00:22<00:19,  1.38it/s]

0.9086577428711785


 57%|█████▋    | 34/60 [00:23<00:18,  1.42it/s]

0.8814406130048964


 58%|█████▊    | 35/60 [00:24<00:17,  1.43it/s]

0.8507642878426446


 60%|██████    | 36/60 [00:24<00:16,  1.46it/s]

0.8361554642518362


 62%|██████▏   | 37/60 [00:25<00:15,  1.47it/s]

0.8221549457973905


 63%|██████▎   | 38/60 [00:26<00:17,  1.29it/s]

0.8147957623004913


 65%|██████▌   | 39/60 [00:27<00:15,  1.33it/s]

0.8197076386875577


 67%|██████▋   | 40/60 [00:27<00:14,  1.39it/s]

0.8233094844553206


 68%|██████▊   | 41/60 [00:28<00:13,  1.41it/s]

0.8205106258392334


 70%|███████   | 42/60 [00:29<00:12,  1.44it/s]

0.8177879055341085


 72%|███████▏  | 43/60 [00:30<00:12,  1.41it/s]

0.8108910255961947


 73%|███████▎  | 44/60 [00:30<00:11,  1.44it/s]

0.7774438990486993


 75%|███████▌  | 45/60 [00:31<00:10,  1.46it/s]

0.7562845746676127


 77%|███████▋  | 46/60 [00:31<00:09,  1.48it/s]

0.7454724311828613


 78%|███████▊  | 47/60 [00:32<00:08,  1.50it/s]

0.7435076004929013


 80%|████████  | 48/60 [00:33<00:08,  1.50it/s]

0.763325528966056


 82%|████████▏ | 49/60 [00:33<00:07,  1.50it/s]

0.7673898273044162


 83%|████████▎ | 50/60 [00:34<00:06,  1.51it/s]

0.7557051910294427


 85%|████████▌ | 51/60 [00:35<00:05,  1.50it/s]

0.705045246415668


 87%|████████▋ | 52/60 [00:35<00:05,  1.51it/s]

0.6787815557585822


 88%|████████▊ | 53/60 [00:36<00:04,  1.46it/s]

0.6517044537597232


 90%|█████████ | 54/60 [00:37<00:04,  1.47it/s]

0.6314957870377435


 92%|█████████▏| 55/60 [00:38<00:03,  1.49it/s]

0.624818334976832


 93%|█████████▎| 56/60 [00:38<00:02,  1.48it/s]

0.6147537463241153


 95%|█████████▌| 57/60 [00:39<00:02,  1.49it/s]

0.613374776310391


 97%|█████████▋| 58/60 [00:40<00:01,  1.50it/s]

0.6141041318575541


 98%|█████████▊| 59/60 [00:40<00:00,  1.45it/s]

0.5990649031268226


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

0.5990864005353715


  2%|▏         | 1/60 [00:00<00:38,  1.52it/s]

0.585053284962972


  3%|▎         | 2/60 [00:01<00:39,  1.49it/s]

0.5650495853688982


  5%|▌         | 3/60 [00:02<00:39,  1.46it/s]

0.547608474890391


  7%|▋         | 4/60 [00:02<00:38,  1.46it/s]

0.535837451616923


  8%|▊         | 5/60 [00:03<00:37,  1.45it/s]

0.5296599600050185


 10%|█         | 6/60 [00:04<00:38,  1.41it/s]

0.5225343505541483


 12%|█▏        | 7/60 [00:04<00:36,  1.44it/s]

0.5248219337728288


 13%|█▎        | 8/60 [00:05<00:36,  1.44it/s]

0.5195172660880618


 15%|█▌        | 9/60 [00:06<00:36,  1.39it/s]

0.5184169511000315


 17%|█▋        | 10/60 [00:07<00:35,  1.41it/s]

0.5163183079825507


 18%|█▊        | 11/60 [00:07<00:34,  1.42it/s]

0.5187278158134885


 20%|██        | 12/60 [00:08<00:33,  1.43it/s]

0.5123660067717234


 22%|██▏       | 13/60 [00:09<00:32,  1.44it/s]

0.5098487966590457


 23%|██▎       | 14/60 [00:09<00:31,  1.45it/s]

0.5137540168232388


 25%|██▌       | 15/60 [00:10<00:31,  1.45it/s]

0.5103020204438103


 27%|██▋       | 16/60 [00:11<00:30,  1.45it/s]

0.5069649186399248


 28%|██▊       | 17/60 [00:11<00:29,  1.45it/s]

0.5080334577295516


 30%|███       | 18/60 [00:12<00:28,  1.45it/s]

0.505882653925154


 32%|███▏      | 19/60 [00:13<00:28,  1.45it/s]

0.5045655734009213


 33%|███▎      | 20/60 [00:13<00:27,  1.46it/s]

0.5063358247280121


 35%|███▌      | 21/60 [00:14<00:27,  1.44it/s]

0.5045105086432563


 37%|███▋      | 22/60 [00:15<00:26,  1.45it/s]

0.504467142952813


 38%|███▊      | 23/60 [00:15<00:25,  1.47it/s]

0.5040562219089932


 40%|████      | 24/60 [00:16<00:24,  1.49it/s]

0.5013259914186265


 42%|████▏     | 25/60 [00:17<00:23,  1.47it/s]

0.5008742147021823


 43%|████▎     | 26/60 [00:17<00:22,  1.49it/s]

0.4980105360349019


 45%|████▌     | 27/60 [00:18<00:22,  1.47it/s]

0.49995884630415177


 47%|████▋     | 28/60 [00:19<00:21,  1.47it/s]

0.5012862748569913


 48%|████▊     | 29/60 [00:20<00:21,  1.45it/s]

0.4975113521019618


 50%|█████     | 30/60 [00:20<00:20,  1.45it/s]

0.49820635716120404


 52%|█████▏    | 31/60 [00:21<00:20,  1.42it/s]

0.4964798208740022


 53%|█████▎    | 32/60 [00:22<00:19,  1.45it/s]

0.4980522692203522


 55%|█████▌    | 33/60 [00:22<00:18,  1.47it/s]

0.49767551488346523


 57%|█████▋    | 34/60 [00:23<00:17,  1.49it/s]

0.4964868956142002


 58%|█████▊    | 35/60 [00:24<00:16,  1.50it/s]

0.4963143567244212


 60%|██████    | 36/60 [00:24<00:15,  1.51it/s]

0.4947066522306866


 62%|██████▏   | 37/60 [00:25<00:15,  1.51it/s]

0.49243129293123883


 63%|██████▎   | 38/60 [00:26<00:15,  1.46it/s]

0.49398382504781085


 65%|██████▌   | 39/60 [00:26<00:14,  1.47it/s]

0.491820494333903


 67%|██████▋   | 40/60 [00:27<00:13,  1.49it/s]

0.49333399203088546


 68%|██████▊   | 41/60 [00:28<00:12,  1.50it/s]

0.49191172586547005


 70%|███████   | 42/60 [00:28<00:12,  1.50it/s]

0.49000973337226444


 72%|███████▏  | 43/60 [00:29<00:11,  1.50it/s]

0.4914623134666019


 73%|███████▎  | 44/60 [00:30<00:10,  1.49it/s]

0.4901556985245811


 75%|███████▌  | 45/60 [00:30<00:09,  1.50it/s]

0.48912012245919967


 77%|███████▋  | 46/60 [00:31<00:09,  1.51it/s]

0.4899021089076996


 78%|███████▊  | 47/60 [00:32<00:08,  1.51it/s]

0.4880598551697201


 80%|████████  | 48/60 [00:32<00:07,  1.51it/s]

0.49085312750604415


 82%|████████▏ | 49/60 [00:33<00:07,  1.52it/s]

0.48683376444710624


 83%|████████▎ | 50/60 [00:34<00:06,  1.51it/s]

0.48894432849354214


 85%|████████▌ | 51/60 [00:34<00:05,  1.51it/s]

0.4878729283809662


 87%|████████▋ | 52/60 [00:35<00:05,  1.52it/s]

0.4882989128430684


 88%|████████▊ | 53/60 [00:36<00:04,  1.52it/s]

0.4860113925404019


 90%|█████████ | 54/60 [00:36<00:03,  1.52it/s]

0.4857048971785439


 92%|█████████▏| 55/60 [00:37<00:03,  1.52it/s]

0.48816466828187305


 93%|█████████▎| 56/60 [00:38<00:02,  1.53it/s]

0.48808714085155064


 95%|█████████▌| 57/60 [00:38<00:01,  1.52it/s]

0.4876735938919915


 97%|█████████▋| 58/60 [00:39<00:01,  1.52it/s]

0.4876762330532074


 98%|█████████▊| 59/60 [00:40<00:00,  1.51it/s]

0.48566734459665084


100%|██████████| 60/60 [00:40<00:00,  1.47it/s]

0.4829261435402764
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [37]:
model_score5

{'Bleu_1': 0.5691472799252449,
 'Bleu_2': 0.4307540418705746,
 'Bleu_3': 0.34516250875639776,
 'Bleu_4': 0.2866749467156631,
 'METEOR': 0.2550328192561611,
 'ROUGE_L': 0.47416165913321595,
 'CIDEr': 1.5328609906231703,
 'SPICE': 0.3200072907793861,
 'USC_similarity': 0.5509068637888698}

In [38]:
model_scores = defaultdict(list)
for scores in [model_score1, model_score2, model_score3, model_score4, model_score5]:
    for key, value in scores.items():
        model_scores[key].append(value)

In [39]:
model_scores

defaultdict(list,
            {'Bleu_1': [0.563616964523848,
              0.5785459859347621,
              0.574418193080233,
              0.569833634879283,
              0.5691472799252449],
             'Bleu_2': [0.4294634886655052,
              0.4448593124302879,
              0.4421688182439931,
              0.43622849826275784,
              0.4307540418705746],
             'Bleu_3': [0.3459804433826117,
              0.36235505770352855,
              0.35924982053440085,
              0.3534908746607434,
              0.34516250875639776],
             'Bleu_4': [0.2886627182799963,
              0.3058452890340524,
              0.3017466121523458,
              0.296716296814266,
              0.2866749467156631],
             'METEOR': [0.25157413391451316,
              0.2560589259488988,
              0.2624485798782999,
              0.26063582167232985,
              0.2550328192561611],
             'ROUGE_L': [0.46926991145831937,
              0.4823575766155

In [40]:
tag = '9.1.2.1'
with open(f'{root_captioning}/fz_notebooks/cv_n{tag}.json', 'w') as fp:
    json.dump(model_scores, fp)