## Image Captioning with Pytorch

The following contents are modified from MDS DSCI 575 lecture 8 demo

In [1]:
import os, sys, json
from collections import defaultdict
from tqdm import tqdm
import pickle
from time import time
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from itertools import chain
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models, transforms, datasets
from torchsummary import summary
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

from nltk.translate import bleu_score
from sklearn.model_selection import KFold

sys.path.append('../../scr/evaluation')
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.usc_sim.usc_sim import usc_sim
import subprocess


START = "startseq"
STOP = "endseq"
EPOCHS = 10
AWS = True


In [2]:
torch.manual_seed(123)
np.random.seed(123)

In [3]:
# torch.cuda.empty_cache()
# import gc 
# gc.collect()

In [4]:
# Nicely formatted time string
def hms_string(sec_elapsed):
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60
    return f"{h}:{m:>02}:{s:>05.2f}"
        
if AWS:
    root_captioning = "../../data"
else:
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
        root_captioning = "/content/drive/My Drive/data"
        COLAB = True
        print("Note: using Google CoLab")
    except:
        print("Note: not using Google CoLab")
        COLAB = False

### Clean/Build Dataset

- Read captions
- Preprocess captions


In [5]:
def get_img_info(name, num=np.inf):
    """
    Returns img paths and captions

    Parameters:
    -----------
    name: str
        the json file name
    num: int (default: np.inf)
        the number of observations to get

    Return:
    --------
    list, dict, int
        img paths, corresponding captions, max length of captions
    """
    img_path = []
    caption = [] 
    max_length = 0
    if AWS:
        with open(f'{root_captioning}/json/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for filename in data.keys():
                if num is not None and len(caption) == num:
                    break
                img_path.append(
                    f'{root_captioning}/{name}/{filename}'
                )
                sen_list = []
                for sentence in data[filename]['sentences']:
                    max_length = max(max_length, len(sentence['tokens']))
                    sen_list.append(sentence['raw'])

                caption.append(sen_list)    
    else:            
        with open(f'{root_captioning}/interim/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for set_name in ['rsicd', 'ucm']:
                for filename in data[set_name].keys():
                    if num is not None and len(caption) == num:
                        break

                    img_path.append(
                        f'{root_captioning}/raw/imgs/{set_name}/{filename}'
                    )
                    sen_list = []
                    for sentence in data[set_name][filename]['sentences']:
                        max_length = max(max_length, len(sentence['tokens']))
                        sen_list.append(sentence['raw'])

                    caption.append(sen_list)
    
    return img_path, caption, max_length            


In [6]:
# get img path and caption list
# # only test 800 train samples and 200 valid samples
# train_paths, train_descriptions, max_length_train = get_img_info('train', 800)
# test_paths, test_descriptions, max_length_test = get_img_info('valid', 200)

train_paths, train_descriptions, max_length_train = get_img_info('train')
test_paths, test_descriptions, max_length_test = get_img_info('valid')
max_length = max(max_length_train, max_length_test)



In [7]:
all_paths = train_paths.copy()
all_paths.extend(test_paths.copy())
all_paths = np.array(all_paths)

all_descriptions = train_descriptions.copy()
all_descriptions.extend(test_descriptions.copy())
all_descriptions = np.array(all_descriptions)

captions = all_descriptions.copy()
max_length_all = max(max_length_train, max_length_test)
max_length = max_length_all + 2
      
lex = set()
for sen in all_descriptions:
    [lex.update(d.split()) for d in sen]
    
# add a start and stop token at the beginning/end
for v in all_descriptions:
    for d in range(len(v)):
        v[d] = f'{START} {v[d]} {STOP}'
        
print(f'There are {len(all_paths)} images') 
print(f'There are {len(lex)} unique words (vocab)')
print(f'The maximum length of captions with start and stop token is {max_length}.')


There are 10416 images
There are 2912 unique words (vocab)
The maximum length of captions with start and stop token is 36.


In [8]:
all_paths[-1]

'../../s3/valid/rsicd_park_33.jpg'

In [9]:
all_descriptions[-1]

array(['startseq a vast artificial lake was built in the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq'],
      dtype='<U184')

### Loading Wikipedia2vec Embeddings

In [10]:
# read the embedding matrix 
with open(f'{root_captioning}/enwiki_20180420_2338_words_500d.json', 'r', encoding='utf-8') as file:
    embeddings_index = json.load(file)

In [11]:
def get_vocab(descriptions, word_count_threshold=10):

    captions = []
    for val in descriptions:
        for cap in val:
            captions.append(cap)
    print(f'There are {len(captions)} captions')
    
    word_counts = {}
    nsents = 0
    for sent in captions:
        nsents += 1
        for w in sent.split(' '):
            word_counts[w] = word_counts.get(w, 0) + 1

    vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
    print('preprocessed words %d ==> %d' % (len(word_counts), len(vocab)))
    return vocab

def get_word_dict(vocab):
    
    idxtoword = {}
    wordtoidx = {}

    ix = 1
    for w in vocab:
        wordtoidx[w] = ix
        idxtoword[ix] = w
        ix += 1

    return idxtoword, wordtoidx

def get_vocab_size(idxtoword):
    
    print(f'The vocabulary size is {len(idxtoword) + 1}.')
    return len(idxtoword) + 1


def get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx):

    embedding_matrix = np.zeros((vocab_size, embedding_dim))
    count = 0

    for word, i in wordtoidx.items():

        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            count += 1
            # Words not found in the embedding index will be all zeros
            embedding_matrix[i] = embedding_vector
            
    print(f'{count} out of {vocab_size} words are found in the pre-trained matrix.')            
    print(f'The size of embedding_matrix is {embedding_matrix.shape}')
    return embedding_matrix

### Building the Neural Network

An embedding matrix is built from Glove.  This will be directly copied to the weight matrix of the neural network.

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [13]:
class CNNModel(nn.Module):

    def __init__(self, pretrained=True):
        """
        Initializes a CNNModel

        Parameters:
        -----------
        pretrained: bool (default: True)
            use pretrained model if True

        """

        super(CNNModel, self).__init__()

        # inception v3 expects (299, 299) sized images
        self.model = models.inception_v3(pretrained=pretrained, aux_logits=False)
        # remove the classification layer
        self.model =\
        nn.Sequential(
            *(list(self.model.children())[: 3]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            *(list(self.model.children())[3: 5]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            *(list(self.model.children())[5: -1])
        )

        self.input_size = 299

    def forward(self, img_input, train=False):
        """
        forward of the CNNModel

        Parameters:
        -----------
        img_input: torch.Tensor
            the image matrix
        train: bool (default: False)
            use the model only for feature extraction if False

        Return:
        --------
        torch.Tensor
            image feature matrix
        """
        if not train:
          # set the model to evaluation model
          self.model.eval()

        # N x 3 x 299 x 299
        features = self.model(img_input)
        # N x 2048 x 8 x 8

        return features

In [14]:
class AttentionModel(nn.Module):

    def __init__(self, feature_size, hidden_size=256):
        """
        Initializes a AttentionModel

        Parameters:
        -----------
        cnn_type: str
            the CNN type, either 'vgg16' or 'inception_v3'
        pretrained: bool (default: True)
            use pretrained model if True

        """

        super(AttentionModel, self).__init__()

        self.W_a = nn.Linear(hidden_size, hidden_size)
        self.v_a = nn.Parameter(torch.rand(hidden_size))

    def forward(self, img_features, h):
        """
        forward of the AttentionModel

        Parameters:
        -----------
        img_input: torch.Tensor
            the image matrix
        train: bool (default: False)
            use the model only for feature extraction if False

        Return:
        --------
        torch.Tensor
            image feature matrix
        """

        # N = batch_size
        batch_size = img_features.size(0)

        # 1 x N x hidden_size
        h_a = h.repeat(img_features.size(1), 1, 1).permute(1, 0, 2)
        # N x 64 x hidden_size

        # attention scoring function v_a(tanh(W_a[s;h]))
        # tanh(W_a[s;h])
        energy =\
        torch.tanh(
            self.W_a(
                h_a + img_features
            )
        ).permute(0, 2, 1)
        # N x hidden_size x 64

        # hidden_size
        v = self.v_a.repeat(batch_size, 1).unsqueeze(1)
        # N x 1 x hidden_size

        # v_a(tanh(W_a[s;h]))
        # torch.bmm takes 3D tensors
        attention = torch.bmm(v, energy)
        # N x 1 x 64

        attention_weights = F.softmax(attention, dim=2)
        # N x 1 x 64

        return attention_weights

In [15]:
class RNNModel(nn.Module):

    def __init__(
        self, 
        feature_size,
        vocab_size,
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):
      
        """
        Initializes a RNNModel

        Parameters:
        -----------
        feature_size: int
            the number of features in the image matrix
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """

        super(RNNModel, self).__init__()

        self.feature_size = feature_size
        self.hidden_size = hidden_size

        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()

        self.out_dense = nn.Linear(hidden_size, hidden_size)
        self.h_dense = nn.Linear(feature_size, hidden_size)
        self.c_dense = nn.Linear(feature_size, hidden_size)  
        self.img_dense = nn.Linear(feature_size, hidden_size)
        
        self.embedding =\
        nn.Embedding(vocab_size,embedding_dim, padding_idx=0)

        if embedding_matrix is not None:

            self.embedding.load_state_dict({
                'weight': torch.FloatTensor(embedding_matrix)
            })
            self.embedding.weight.requires_grad = embedding_train

        self.attention =\
        AttentionModel(feature_size, hidden_size)
        
        self.lstm =\
        nn.LSTM(hidden_size, hidden_size, batch_first=True)
      

    def forward(self, img_features, captions):
        """
        forward of the RNNModel

        Parameters:
        -----------
        img_features: torch.Tensor 
            the image feature matrix
            (N x feature_size(2048) x 8 x 8)
        captions: torch.Tensor 
            the padded caption matrix
            (N x seq_len)

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        # N = batch_size
        batch_size = captions.size(0)
        seq_len = captions.size(1)

        # N x feature_size(2048) x 8 x 8
        img_features =\
        img_features.view(
            batch_size, self.feature_size, -1
        ).permute(0, 2, 1)
        # N x 64 x feature_size(2048)

        # N x 64 x feature_size(2048)
        h = self.h_dense(img_features.mean(dim=1)).unsqueeze(0)
        c = self.c_dense(img_features.mean(dim=1)).unsqueeze(0)
        # 1 x N x hidden_size
       
        # N x feature_size(2048) x 8 x 8
        img_features =\
        self.relu(
            self.img_dense(
                self.dropout(
                    img_features
                )
            )
        )  
        # N x 64 x hidden_size

        # N x seq_len
        embed =\
        self.dropout(
            self.embedding(
                captions
            )
        )
        # N x seq_len x embedding_dim
        
        outputs =\
        torch.zeros(
            batch_size,
            seq_len, 
            self.hidden_size
        ).to(device)

        all_attention_weights =\
        torch.zeros(
            batch_size,
            seq_len, 
            img_features.shape[1]
        ).to(device)
        
        for i in range(seq_len):

            attention_weights = self.attention(img_features, h)
            # N x 1 x 64

            # weighted sum of img_features
            weighted = torch.bmm(attention_weights, img_features)
            # N x 1 x hidden_size

            output, (h, c) =\
            self.lstm(
                embed[:, i, :].unsqueeze(1) + weighted,
                (h, c)
            )
            # outputs: N x 1 x hidden_size
            # h: 1 x N x hidden_size
            # c: 1 x N x hidden_size

            output =\
            self.out_dense(
                output.squeeze(1) + weighted.squeeze(1) + embed[:, i, :]
            )
            # N x hidden_size
 
            outputs[:, i, :] = output.squeeze()
            all_attention_weights[:, i, :] = attention_weights.squeeze()

        return outputs, all_attention_weights



In [16]:
class CaptionModel(nn.Module):

    def __init__(
        self, 
        vocab_size, 
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):

        """
        Initializes a CaptionModel

        Parameters:
        -----------
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """    
        super(CaptionModel, self).__init__() 

        # set feature_size based on cnn_type
        self.feature_size = 2048

        self.decoder = RNNModel(
            self.feature_size,
            vocab_size, 
            embedding_dim,
            hidden_size,
            embedding_matrix,
            embedding_train
        )

        self.relu = nn.ReLU()
        self.dense = nn.Linear(hidden_size, vocab_size) 

    # def forward(self, captions):
    def forward(self, img_features, captions):
        """
        forward of the CaptionModel

        Parameters:
        -----------
        img_features: torch.Tensor 
            the image feature matrix
            (N x feature_size(2048) x 8 x 8)
        captions: torch.Tensor 
            the padded caption matrix
            (N x seq_len)

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        decoder_out, all_attention_weights = self.decoder(img_features, captions)

        # add up decoder outputs and image features
        outputs =\
        self.dense(
            self.relu(
                decoder_out
            )
        )

        return outputs, all_attention_weights

### Train the Neural Network

In [17]:
def train(model, iterator, optimizer, criterion, clip, vocab_size):
    """
    train the CaptionModel

    Parameters:
    -----------
    model: CaptionModel
        a CaptionModel instance
    iterator: torch.utils.data.dataloader
        a PyTorch dataloader
    optimizer: torch.optim
        a PyTorch optimizer 
    criterion: nn.CrossEntropyLoss
        a PyTorch criterion 

    Return:
    --------
    float
        average loss
    """
    model.train()    
    epoch_loss = 0
    
    for img_features, captions in iterator:
        
        optimizer.zero_grad()

        # for each caption, the end word is not passed for training
        outputs, all_attention_weights = model(
            img_features.to(device),
            captions[:, :-1].to(device)
        )

        loss = criterion(
            outputs.view(-1, vocab_size), 
            captions[:, 1:].flatten().to(device)
        ) + ((1. - all_attention_weights.sum(dim=1)) ** 2).mean()
        epoch_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        
    return epoch_loss / len(iterator)

In [18]:
class SampleDataset(Dataset):
    def __init__(
        self,
        descriptions,
        imgs,
        wordtoidx,
        max_length
    ):
        """
        Initializes a SampleDataset

        Parameters:
        -----------
        descriptions: list
            a list of captions
        imgs: numpy.ndarray
            the image features
        wordtoidx: dict
            the dict to get word index
        max_length: int
            all captions will be padded to this size
        """        
        self.imgs = imgs
        self.descriptions = descriptions
        self.wordtoidx = wordtoidx
        self.max_length = max_length

    def __len__(self):
        """
        Returns the batch size

        Return:
        --------
        int
            the batch size
        """
        # return len(self.descriptions)
        return len(self.imgs)

    def __getitem__(self, idx):
        """
        Prepare data for each image

        Parameters:
        -----------
        idx: int
          the index of the image to process

        Return:
        --------
        list, list, list
            [5 x image feature matrix],
            [five padded captions for this image]
            [the length of each caption]
        """

        img = self.imgs[idx // 5]
        # convert each word into a list of sequences.
        seq = [self.wordtoidx[word] for word 
               in self.descriptions[idx // 5][idx % 5].split(' ')
               if word in self.wordtoidx]
        # pad the sequence with 0 on the right side
        in_seq = np.pad(
            seq, 
            (0, max_length - len(seq)),
            mode='constant',
            constant_values=(0, 0)
            )

        return img, in_seq


In [19]:
def init_weights(model, embedding_pretrained=True):
    """
    Initialize weights and bias in the model

    Parameters:
    -----------
    model: CaptionModel
      a CaptionModel instance
    embedding_pretrained: bool (default: True)
        not initialize the embedding matrix if True
    """  
  
    for name, param in model.named_parameters():
        if embedding_pretrained and 'embedding' in name:
            continue
        elif 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            


In [20]:
def encode_image(model, img_path):
    """
    Process the images to extract features

    Parameters:
    -----------
    model: CNNModel
      a CNNModel instance
    img_path: str
        the path of the image
 
    Return:
    --------
    torch.Tensor
        the extracted feature matrix from CNNModel
    """  

    img = Image.open(img_path)

    # Perform preprocessing needed by pre-trained models
    preprocessor = transforms.Compose([
        transforms.Resize(model.input_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    img = preprocessor(img)
    # Expand to 2D array
    img = img.view(1, *img.shape)
    # Call model to extract the smaller feature set for the image.
    x = model(img.to(device), False) 
    # Shape to correct form to be accepted by LSTM captioning network.
    x = np.squeeze(x)
    return x

In [21]:
def extract_img_features(img_paths, model):
    """
    Extracts, stores and returns image features

    Parameters:
    -----------
    img_paths: list
        the paths of images
    model: CNNModel (default: None)
      a CNNModel instance

    Return:
    --------
    numpy.ndarray
        the extracted image feature matrix from CNNModel
    """ 

    start = time()
    img_features = []

    for image_path in tqdm(img_paths):
        img_features.append(
            encode_image(model, image_path).cpu().data.numpy()
        )

    print(f"\nGenerating set took: {hms_string(time()-start)}")

    return img_features

In [22]:
def get_train_test(
    encoder,
    train_paths,
    test_paths,
    cnn_type='inception_v3',
):

    train_img_features = extract_img_features(
        train_paths,
        encoder
    )

    test_img_features = extract_img_features(
        test_paths,
        encoder
    )
    return train_img_features, test_img_features

def get_train_dataloader(
    train_descriptions, 
    train_img_features,
    wordtoidx,
    max_length,
    batch_size=200
):
    train_dataset = SampleDataset(
        train_descriptions,
        train_img_features,
        wordtoidx,
        max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size
    )
    
    return train_loader

def train_model(
    train_loader,
    vocab_size,
    embedding_dim, 
    embedding_matrix,
    hidden_size=256,
):

    caption_model = CaptionModel(
        vocab_size, 
        embedding_dim, 
        hidden_size=hidden_size,
        embedding_matrix=embedding_matrix, 
        embedding_train=True
    )

    init_weights(
        caption_model,
        embedding_pretrained=True
    )

    caption_model.to(device)

    # we will ignore the pad token in true target set
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    optimizer = torch.optim.Adam(
        caption_model.parameters(), 
        lr=0.01
    )

    clip = 1
    start = time()

    for i in tqdm(range(EPOCHS * 6)):

        loss = train(caption_model, train_loader, optimizer, criterion, clip, vocab_size)
        print(loss)

    # reduce the learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = 1e-4

    for i in tqdm(range(EPOCHS * 6)):

        loss = train(caption_model, train_loader, optimizer, criterion, clip, vocab_size)
        print(loss)
    return caption_model

In [23]:
def generateCaption(
    model, 
    img_features,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    in_text = START

    for i in range(max_length):

        sequence = [wordtoidx[w] for w in in_text.split() if w in wordtoidx]
        sequence = np.pad(sequence, (0, max_length - len(sequence)),
                          mode='constant', constant_values=(0, 0))
        model.eval()
        yhat, _ = model(
            torch.FloatTensor(img_features)\
            .view(-1, model.feature_size).to(device),
            torch.LongTensor(sequence).view(-1, max_length).to(device)
        )

        yhat = yhat.view(-1, vocab_size).argmax(1)
        word = idxtoword[yhat.cpu().data.numpy()[i]]
        in_text += ' ' + word
        if word == STOP:
            break
    final = in_text.split()
    final = final[1 : -1]
    final = ' '.join(final)
    return final

### Evaluation

In [24]:
def eval_model(ref_data, results):
    """
    Computes evaluation metrics of the model results against the human annotated captions
    
    Parameters:
    ------------
    ref_data: dict
        a dictionary containing human annotated captions, with image name as key and a list of human annotated captions as values
    
    results: dict
        a dictionary containing model generated caption, with image name as key and a generated caption as value
        
    Returns:
    ------------
    score_dict: a dictionary containing the overall average score for the model
    """
    # download stanford nlp library
    subprocess.call(['../../scr/evaluation/get_stanford_models.sh'])
    
    # format the inputs
    gts = {}
    res = {}

    for imgId in range(len(ref_data)):
        caption_list_sel = []
        for i in range(5):
            lst = {}
            lst['caption'] = ref_data[imgId][i]
            lst['image_id'] = imgId
            lst['id'] = i
            caption_list_sel.append(lst)
        gts[imgId] = caption_list_sel

        res[imgId] = [{'caption': results[imgId]}]
        
    # tokenize
    print('tokenization...')
    tokenizer = PTBTokenizer()
    gts  = tokenizer.tokenize(gts)
    res = tokenizer.tokenize(res)
    
    # compute scores
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(),"METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
        (Spice(), "SPICE"),
        (usc_sim(), "USC_similarity"),  
        ]
    score_dict = {}
    for scorer, method in scorers:
        print('computing %s score...'%(scorer.method()))
        score, scores = scorer.compute_score(gts, res)
        if type(method) == list:
            for sc, scs, m in zip(score, scores, method):
                score_dict[m] = sc
        else:
            score_dict[method] = score
            
    return score_dict


In [25]:
def evaluate_results(
    test_img_features, 
    model,
    ref,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    # generate results
    print('Generating captions...')
    results = {}
    for n in range(len(test_img_features)):
        img_features = test_img_features[n]
        generated = generateCaption(
            model, 
            img_features,
            max_length,
            vocab_size,
            wordtoidx,
            idxtoword
        )
        results[n] = generated
        
    model_score = eval_model(ref, results)

    return model_score

### Cross validation

In [26]:
cnn_type = 'inception_v3'
encoder = CNNModel(pretrained=True)
encoder.to(device)

CNNModel(
  (model): Sequential(
    (0): BasicConv2d(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicConv2d(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): BasicConv2d(
      (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): BasicConv2d(
      (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1

In [27]:
def cross_validation(train_index, test_index, count):
    print('=' * 60)
    print(f'Split {count}:')
    print(f'Splitting data...')
    
    train_paths, test_paths = all_paths[train_index], all_paths[test_index]
    train_descriptions, test_descriptions = all_descriptions[train_index], all_descriptions[test_index]
    print(f'{len(train_paths)} images for training and {len(test_paths)} images for testing.')
    
    vocab = get_vocab(train_descriptions, word_count_threshold=10)
    idxtoword, wordtoidx = get_word_dict(vocab)
    vocab_size = get_vocab_size(idxtoword)
    embedding_dim = 500
    embedding_matrix = get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx) 

    print(f'Preparing dataloader...')
    train_img_features, test_img_features = get_train_test(encoder, train_paths, test_paths)

    train_loader = get_train_dataloader(
        train_descriptions, 
        train_img_features,
        wordtoidx,
        max_length,
        batch_size=1000
    )

    print(f'Training...')
    caption_model = train_model(
        train_loader,
        vocab_size,
        embedding_dim, 
        embedding_matrix,
        hidden_size=500
    )

    
    ref = captions[test_index]
    model_score = evaluate_results(
        test_img_features, 
        caption_model,
        ref,
        max_length,
        vocab_size,
        wordtoidx,
        idxtoword
    )
    
    return caption_model, model_score

In [28]:
cv = KFold(n_splits=5, random_state=123, shuffle=True)
cv = [(train_index, test_index) for train_index, test_index in cv.split(all_paths)]  

In [29]:
caption_model1, model_score1 = cross_validation(cv[0][0], cv[0][1], 1)    

Split 1:
Splitting data...
8332 images for training and 2084 images for testing.
There are 41660 captions


  0%|          | 0/8332 [00:00<?, ?it/s]

preprocessed words 2659 ==> 884
The vocabulary size is 885.
793 out of 885 words are found in the pre-trained matrix.
The size of embedding_matrix is (885, 500)
Preparing dataloader...


100%|██████████| 8332/8332 [03:31<00:00, 39.48it/s]
  0%|          | 5/2084 [00:00<00:51, 40.27it/s]


Generating set took: 0:03:31.06


100%|██████████| 2084/2084 [00:53<00:00, 39.29it/s]
  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.04
Training...


  2%|▏         | 1/60 [00:07<07:00,  7.13s/it]

10.92996237013075


  3%|▎         | 2/60 [00:14<06:52,  7.11s/it]

4.890611860487196


  5%|▌         | 3/60 [00:21<06:45,  7.11s/it]

3.9659491380055747


  7%|▋         | 4/60 [00:28<06:37,  7.10s/it]

3.265617397096422


  8%|▊         | 5/60 [00:35<06:30,  7.09s/it]

2.854031377368503


 10%|█         | 6/60 [00:42<06:22,  7.09s/it]

2.6024077203538685


 12%|█▏        | 7/60 [00:49<06:15,  7.09s/it]

2.4020527998606362


 13%|█▎        | 8/60 [00:56<06:08,  7.09s/it]

2.2307625346713595


 15%|█▌        | 9/60 [01:03<06:01,  7.10s/it]

2.091664883825514


 17%|█▋        | 10/60 [01:10<05:54,  7.10s/it]

1.9773075845506456


 18%|█▊        | 11/60 [01:18<05:47,  7.10s/it]

1.8994710710313585


 20%|██        | 12/60 [01:25<05:40,  7.10s/it]

1.8286912706163194


 22%|██▏       | 13/60 [01:32<05:33,  7.10s/it]

1.7351057926813762


 23%|██▎       | 14/60 [01:39<05:26,  7.10s/it]

1.6522583034303453


 25%|██▌       | 15/60 [01:46<05:19,  7.11s/it]

1.5696636968188815


 27%|██▋       | 16/60 [01:53<05:14,  7.14s/it]

1.4893906248940363


 28%|██▊       | 17/60 [02:00<05:06,  7.13s/it]

1.4305594099892511


 30%|███       | 18/60 [02:07<04:59,  7.12s/it]

1.3763963911268446


 32%|███▏      | 19/60 [02:14<04:51,  7.11s/it]

1.3405957221984863


 33%|███▎      | 20/60 [02:22<04:44,  7.11s/it]

1.317326029141744


 35%|███▌      | 21/60 [02:29<04:36,  7.10s/it]

1.297496259212494


 37%|███▋      | 22/60 [02:36<04:30,  7.11s/it]

1.2462259531021118


 38%|███▊      | 23/60 [02:43<04:23,  7.11s/it]

1.2143969337145488


 40%|████      | 24/60 [02:50<04:15,  7.11s/it]

1.1984957522816129


 42%|████▏     | 25/60 [02:57<04:08,  7.11s/it]

1.18360526031918


 43%|████▎     | 26/60 [03:04<04:01,  7.11s/it]

1.163129727045695


 45%|████▌     | 27/60 [03:11<03:54,  7.11s/it]

1.1473117536968656


 47%|████▋     | 28/60 [03:18<03:47,  7.11s/it]

1.1488724019792345


 48%|████▊     | 29/60 [03:26<03:40,  7.11s/it]

1.0910406443807814


 50%|█████     | 30/60 [03:33<03:33,  7.10s/it]

1.0622521771325006


 52%|█████▏    | 31/60 [03:40<03:25,  7.10s/it]

1.0281426045629714


 53%|█████▎    | 32/60 [03:47<03:18,  7.10s/it]

0.9969927072525024


 55%|█████▌    | 33/60 [03:54<03:11,  7.10s/it]

0.9843948682149252


 57%|█████▋    | 34/60 [04:01<03:04,  7.11s/it]

0.98071112897661


 58%|█████▊    | 35/60 [04:08<02:58,  7.13s/it]

0.9693286551369561


 60%|██████    | 36/60 [04:15<02:51,  7.13s/it]

0.9610758953624301


 62%|██████▏   | 37/60 [04:22<02:43,  7.12s/it]

0.951280931631724


 63%|██████▎   | 38/60 [04:30<02:36,  7.12s/it]

0.9101386268933614


 65%|██████▌   | 39/60 [04:37<02:29,  7.11s/it]

0.9024313555823432


 67%|██████▋   | 40/60 [04:44<02:22,  7.14s/it]

0.8788293070263333


 68%|██████▊   | 41/60 [04:51<02:15,  7.15s/it]

0.8670786950323317


 70%|███████   | 42/60 [04:58<02:08,  7.14s/it]

0.8581056925985548


 72%|███████▏  | 43/60 [05:05<02:01,  7.13s/it]

0.8665529290835062


 73%|███████▎  | 44/60 [05:12<01:53,  7.12s/it]

0.8566286563873291


 75%|███████▌  | 45/60 [05:19<01:46,  7.12s/it]

0.8469975392023722


 77%|███████▋  | 46/60 [05:27<01:39,  7.11s/it]

0.8102770381503634


 78%|███████▊  | 47/60 [05:34<01:32,  7.11s/it]

0.7885245482126871


 80%|████████  | 48/60 [05:41<01:25,  7.11s/it]

0.7631232009993659


 82%|████████▏ | 49/60 [05:48<01:18,  7.11s/it]

0.7520327170689901


 83%|████████▎ | 50/60 [05:55<01:11,  7.11s/it]

0.7549524505933126


 85%|████████▌ | 51/60 [06:02<01:03,  7.11s/it]

0.7528064582082961


 87%|████████▋ | 52/60 [06:09<00:56,  7.11s/it]

0.7432083917988671


 88%|████████▊ | 53/60 [06:16<00:49,  7.10s/it]

0.7349435753292508


 90%|█████████ | 54/60 [06:23<00:42,  7.10s/it]

0.7179477148585849


 92%|█████████▏| 55/60 [06:31<00:35,  7.11s/it]

0.7105458676815033


 93%|█████████▎| 56/60 [06:38<00:28,  7.10s/it]

0.7015411754449209


 95%|█████████▌| 57/60 [06:45<00:21,  7.10s/it]

0.6948098937670389


 97%|█████████▋| 58/60 [06:52<00:14,  7.10s/it]

0.696942968500985


 98%|█████████▊| 59/60 [06:59<00:07,  7.10s/it]

0.6825484500990974


100%|██████████| 60/60 [07:06<00:00,  7.11s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.6783451967769198


  2%|▏         | 1/60 [00:07<06:58,  7.09s/it]

0.6865550544526842


  3%|▎         | 2/60 [00:14<06:51,  7.10s/it]

0.6475772923893399


  5%|▌         | 3/60 [00:21<06:45,  7.11s/it]

0.6155605316162109


  7%|▋         | 4/60 [00:28<06:37,  7.10s/it]

0.6002097196049161


  8%|▊         | 5/60 [00:35<06:31,  7.11s/it]

0.5928822656472524


 10%|█         | 6/60 [00:42<06:24,  7.11s/it]

0.5868506365352206


 12%|█▏        | 7/60 [00:49<06:16,  7.11s/it]

0.5815166168742709


 13%|█▎        | 8/60 [00:56<06:09,  7.11s/it]

0.5777962903181711


 15%|█▌        | 9/60 [01:03<06:01,  7.10s/it]

0.5757747656769223


 17%|█▋        | 10/60 [01:11<05:54,  7.09s/it]

0.5737589730156792


 18%|█▊        | 11/60 [01:18<05:47,  7.09s/it]

0.5732982092433505


 20%|██        | 12/60 [01:25<05:42,  7.13s/it]

0.5697944627867805


 22%|██▏       | 13/60 [01:32<05:34,  7.12s/it]

0.5690208971500397


 23%|██▎       | 14/60 [01:39<05:27,  7.11s/it]

0.5668455395433638


 25%|██▌       | 15/60 [01:46<05:19,  7.10s/it]

0.5651732484499613


 27%|██▋       | 16/60 [01:53<05:12,  7.11s/it]

0.5655471682548523


 28%|██▊       | 17/60 [02:00<05:05,  7.11s/it]

0.5649978849622939


 30%|███       | 18/60 [02:07<04:58,  7.10s/it]

0.563184357351727


 32%|███▏      | 19/60 [02:15<04:51,  7.11s/it]

0.5617933803134494


 33%|███▎      | 20/60 [02:22<04:44,  7.11s/it]

0.5603227383560605


 35%|███▌      | 21/60 [02:29<04:37,  7.11s/it]

0.5596885747379727


 37%|███▋      | 22/60 [02:36<04:29,  7.10s/it]

0.5605442855093214


 38%|███▊      | 23/60 [02:43<04:23,  7.11s/it]

0.5573394960827298


 40%|████      | 24/60 [02:50<04:15,  7.11s/it]

0.5577602320247226


 42%|████▏     | 25/60 [02:57<04:09,  7.11s/it]

0.5575199259652032


 43%|████▎     | 26/60 [03:04<04:01,  7.11s/it]

0.5577275819248624


 45%|████▌     | 27/60 [03:11<03:54,  7.11s/it]

0.5540774332152473


 47%|████▋     | 28/60 [03:19<03:47,  7.11s/it]

0.5549507505363889


 48%|████▊     | 29/60 [03:26<03:40,  7.11s/it]

0.5542619493272569


 50%|█████     | 30/60 [03:33<03:33,  7.11s/it]

0.5531823138395945


 52%|█████▏    | 31/60 [03:40<03:26,  7.11s/it]

0.5526387790838877


 53%|█████▎    | 32/60 [03:47<03:18,  7.10s/it]

0.5516817106140984


 55%|█████▌    | 33/60 [03:54<03:12,  7.11s/it]

0.5530256628990173


 57%|█████▋    | 34/60 [04:01<03:04,  7.11s/it]

0.5514235893885294


 58%|█████▊    | 35/60 [04:08<02:57,  7.11s/it]

0.5496642490228018


 60%|██████    | 36/60 [04:15<02:50,  7.11s/it]

0.54938938220342


 62%|██████▏   | 37/60 [04:23<02:43,  7.12s/it]

0.5485293666521708


 63%|██████▎   | 38/60 [04:30<02:36,  7.11s/it]

0.549669451183743


 65%|██████▌   | 39/60 [04:37<02:29,  7.12s/it]

0.5470463534196218


 67%|██████▋   | 40/60 [04:44<02:23,  7.17s/it]

0.547193232509825


 68%|██████▊   | 41/60 [04:51<02:16,  7.16s/it]

0.5469686686992645


 70%|███████   | 42/60 [04:58<02:08,  7.15s/it]

0.547355936633216


 72%|███████▏  | 43/60 [05:05<02:01,  7.14s/it]

0.5465981960296631


 73%|███████▎  | 44/60 [05:13<01:54,  7.13s/it]

0.5457552936342027


 75%|███████▌  | 45/60 [05:20<01:46,  7.12s/it]

0.5445904963546329


 77%|███████▋  | 46/60 [05:27<01:39,  7.11s/it]

0.5443403687742021


 78%|███████▊  | 47/60 [05:34<01:32,  7.11s/it]

0.5457012587123447


 80%|████████  | 48/60 [05:41<01:25,  7.11s/it]

0.5436302224795023


 82%|████████▏ | 49/60 [05:48<01:18,  7.10s/it]

0.5438521305720011


 83%|████████▎ | 50/60 [05:55<01:11,  7.11s/it]

0.5426276955339644


 85%|████████▌ | 51/60 [06:02<01:03,  7.11s/it]

0.541952927907308


 87%|████████▋ | 52/60 [06:09<00:56,  7.11s/it]

0.5419294966591729


 88%|████████▊ | 53/60 [06:17<00:49,  7.12s/it]

0.5412415729628669


 90%|█████████ | 54/60 [06:24<00:42,  7.11s/it]

0.5406237244606018


 92%|█████████▏| 55/60 [06:31<00:35,  7.12s/it]

0.5390174720022414


 93%|█████████▎| 56/60 [06:38<00:28,  7.11s/it]

0.5389558540450202


 95%|█████████▌| 57/60 [06:45<00:21,  7.12s/it]

0.5403624739911821


 97%|█████████▋| 58/60 [06:52<00:14,  7.12s/it]

0.5380562278959486


 98%|█████████▊| 59/60 [06:59<00:07,  7.11s/it]

0.5394076969888475


100%|██████████| 60/60 [07:06<00:00,  7.11s/it]

0.5365480515691969
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [30]:
model_score1

{'Bleu_1': 0.5430511873831741,
 'Bleu_2': 0.4040510387044148,
 'Bleu_3': 0.3202125719013248,
 'Bleu_4': 0.26499074763786395,
 'METEOR': 0.23442467563859665,
 'ROUGE_L': 0.44063030412906334,
 'CIDEr': 1.284062242860915,
 'SPICE': 0.2835413805275309,
 'USC_similarity': 0.520579584497483}

In [31]:
caption_model2, model_score2 = cross_validation(cv[1][0], cv[1][1], 2)    

Split 2:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions


  0%|          | 4/8333 [00:00<03:43, 37.26it/s]

preprocessed words 2688 ==> 916
The vocabulary size is 917.
819 out of 917 words are found in the pre-trained matrix.
The size of embedding_matrix is (917, 500)
Preparing dataloader...


100%|██████████| 8333/8333 [03:35<00:00, 38.64it/s]
  0%|          | 4/2083 [00:00<00:53, 38.64it/s]


Generating set took: 0:03:35.68


100%|██████████| 2083/2083 [00:53<00:00, 38.59it/s]
  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.98
Training...


  2%|▏         | 1/60 [00:07<07:11,  7.32s/it]

10.061019155714247


  3%|▎         | 2/60 [00:14<07:00,  7.25s/it]

4.97698269950019


  5%|▌         | 3/60 [00:21<06:50,  7.21s/it]

4.67575650744968


  7%|▋         | 4/60 [00:28<06:41,  7.17s/it]

4.293290932973226


  8%|▊         | 5/60 [00:35<06:33,  7.15s/it]

3.836088842815823


 10%|█         | 6/60 [00:42<06:25,  7.14s/it]

3.4894101884629993


 12%|█▏        | 7/60 [00:49<06:17,  7.13s/it]

3.264280663596259


 13%|█▎        | 8/60 [00:57<06:10,  7.12s/it]

3.1338774893018932


 15%|█▌        | 9/60 [01:04<06:02,  7.12s/it]

3.0516369342803955


 17%|█▋        | 10/60 [01:11<05:55,  7.11s/it]

2.982200860977173


 18%|█▊        | 11/60 [01:18<05:48,  7.11s/it]

2.8813842402564154


 20%|██        | 12/60 [01:25<05:41,  7.12s/it]

2.780631833606296


 22%|██▏       | 13/60 [01:32<05:34,  7.11s/it]

2.708916743596395


 23%|██▎       | 14/60 [01:39<05:27,  7.11s/it]

2.628864367802938


 25%|██▌       | 15/60 [01:46<05:19,  7.11s/it]

2.588158210118612


 27%|██▋       | 16/60 [01:53<05:12,  7.11s/it]

2.504541529549493


 28%|██▊       | 17/60 [02:01<05:06,  7.12s/it]

2.4657600190904407


 30%|███       | 18/60 [02:08<04:59,  7.13s/it]

2.4253409968482122


 32%|███▏      | 19/60 [02:15<04:52,  7.13s/it]

2.365854342778524


 33%|███▎      | 20/60 [02:22<04:45,  7.13s/it]

2.2726147174835205


 35%|███▌      | 21/60 [02:29<04:37,  7.12s/it]

2.2081063985824585


 37%|███▋      | 22/60 [02:36<04:30,  7.13s/it]

2.1380242904027305


 38%|███▊      | 23/60 [02:43<04:23,  7.13s/it]

2.085267278883192


 40%|████      | 24/60 [02:51<04:18,  7.19s/it]

2.0485790967941284


 42%|████▏     | 25/60 [02:58<04:11,  7.18s/it]

2.00264839331309


 43%|████▎     | 26/60 [03:05<04:04,  7.18s/it]

1.9476550949944391


 45%|████▌     | 27/60 [03:12<03:56,  7.16s/it]

1.9261966149012248


 47%|████▋     | 28/60 [03:19<03:49,  7.16s/it]

1.8933593299653795


 48%|████▊     | 29/60 [03:26<03:41,  7.16s/it]

1.8686589267518785


 50%|█████     | 30/60 [03:34<03:34,  7.15s/it]

1.799387468232049


 52%|█████▏    | 31/60 [03:41<03:27,  7.16s/it]

1.7451519171396892


 53%|█████▎    | 32/60 [03:48<03:20,  7.16s/it]

1.703747206264072


 55%|█████▌    | 33/60 [03:55<03:13,  7.15s/it]

1.675542950630188


 57%|█████▋    | 34/60 [04:02<03:05,  7.15s/it]

1.6456293132570055


 58%|█████▊    | 35/60 [04:09<02:58,  7.16s/it]

1.6300941970613267


 60%|██████    | 36/60 [04:17<02:51,  7.16s/it]

1.6030148267745972


 62%|██████▏   | 37/60 [04:24<02:44,  7.15s/it]

1.5712102121777005


 63%|██████▎   | 38/60 [04:31<02:37,  7.15s/it]

1.5353734890619914


 65%|██████▌   | 39/60 [04:38<02:30,  7.14s/it]

1.510686331325107


 67%|██████▋   | 40/60 [04:45<02:22,  7.14s/it]

1.4834795925352309


 68%|██████▊   | 41/60 [04:52<02:15,  7.14s/it]

1.4692067172792223


 70%|███████   | 42/60 [04:59<02:08,  7.15s/it]

1.4612450467215643


 72%|███████▏  | 43/60 [05:06<02:01,  7.14s/it]

1.4335623184839885


 73%|███████▎  | 44/60 [05:14<01:54,  7.14s/it]

1.4346629116270277


 75%|███████▌  | 45/60 [05:21<01:47,  7.14s/it]

1.4114372200436063


 77%|███████▋  | 46/60 [05:28<01:39,  7.13s/it]

1.362103521823883


 78%|███████▊  | 47/60 [05:35<01:32,  7.14s/it]

1.3346266878975763


 80%|████████  | 48/60 [05:42<01:25,  7.14s/it]

1.2985942562421162


 82%|████████▏ | 49/60 [05:49<01:18,  7.14s/it]

1.2668334907955594


 83%|████████▎ | 50/60 [05:56<01:11,  7.14s/it]

1.2348406910896301


 85%|████████▌ | 51/60 [06:04<01:04,  7.14s/it]

1.2217312057813008


 87%|████████▋ | 52/60 [06:11<00:57,  7.14s/it]

1.2258669402864244


 88%|████████▊ | 53/60 [06:18<00:50,  7.14s/it]

1.21789464685652


 90%|█████████ | 54/60 [06:25<00:42,  7.14s/it]

1.204172068172031


 92%|█████████▏| 55/60 [06:32<00:35,  7.14s/it]

1.1989552709791396


 93%|█████████▎| 56/60 [06:39<00:28,  7.16s/it]

1.1946379807260301


 95%|█████████▌| 57/60 [06:47<00:21,  7.15s/it]

1.160498234960768


 97%|█████████▋| 58/60 [06:54<00:14,  7.17s/it]

1.1178750263320074


 98%|█████████▊| 59/60 [07:01<00:07,  7.16s/it]

1.0933061639467876


100%|██████████| 60/60 [07:08<00:00,  7.14s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

1.0811361140675015


  2%|▏         | 1/60 [00:07<07:01,  7.14s/it]

1.06335668431388


  3%|▎         | 2/60 [00:14<06:54,  7.14s/it]

1.0308435691727533


  5%|▌         | 3/60 [00:21<06:49,  7.19s/it]

1.0071429477797613


  7%|▋         | 4/60 [00:28<06:42,  7.19s/it]

0.9948008987638686


  8%|▊         | 5/60 [00:35<06:34,  7.17s/it]

0.9885703590181139


 10%|█         | 6/60 [00:43<06:26,  7.16s/it]

0.9831070039007399


 12%|█▏        | 7/60 [00:50<06:21,  7.20s/it]

0.9794879224565294


 13%|█▎        | 8/60 [00:57<06:13,  7.18s/it]

0.9765990111562941


 15%|█▌        | 9/60 [01:04<06:06,  7.18s/it]

0.9704205791155497


 17%|█▋        | 10/60 [01:11<05:58,  7.16s/it]

0.9677737951278687


 18%|█▊        | 11/60 [01:18<05:50,  7.16s/it]

0.9644533660676744


 20%|██        | 12/60 [01:26<05:44,  7.17s/it]

0.9636462065908644


 22%|██▏       | 13/60 [01:33<05:36,  7.16s/it]

0.9627143873108758


 23%|██▎       | 14/60 [01:40<05:29,  7.16s/it]

0.9583558241526285


 25%|██▌       | 15/60 [01:47<05:22,  7.16s/it]

0.9574338263935513


 27%|██▋       | 16/60 [01:54<05:14,  7.15s/it]

0.9567767845259773


 28%|██▊       | 17/60 [02:01<05:07,  7.15s/it]

0.9547592533959283


 30%|███       | 18/60 [02:08<05:00,  7.14s/it]

0.9535760482152303


 32%|███▏      | 19/60 [02:16<04:52,  7.14s/it]

0.9533300730917189


 33%|███▎      | 20/60 [02:23<04:45,  7.14s/it]

0.9502197768953111


 35%|███▌      | 21/60 [02:30<04:38,  7.14s/it]

0.9508706132570902


 37%|███▋      | 22/60 [02:37<04:31,  7.13s/it]

0.9502550694677565


 38%|███▊      | 23/60 [02:44<04:23,  7.13s/it]

0.9474346306588914


 40%|████      | 24/60 [02:51<04:16,  7.13s/it]

0.949292348490821


 42%|████▏     | 25/60 [02:58<04:09,  7.12s/it]

0.9448418087429471


 43%|████▎     | 26/60 [03:06<04:02,  7.13s/it]

0.9441897008154128


 45%|████▌     | 27/60 [03:13<03:55,  7.13s/it]

0.9416997167799208


 47%|████▋     | 28/60 [03:20<03:48,  7.13s/it]

0.9429974357287089


 48%|████▊     | 29/60 [03:27<03:41,  7.13s/it]

0.9423646728197733


 50%|█████     | 30/60 [03:34<03:33,  7.12s/it]

0.9412905971209208


 52%|█████▏    | 31/60 [03:41<03:26,  7.12s/it]

0.9419406652450562


 53%|█████▎    | 32/60 [03:48<03:19,  7.12s/it]

0.9385478165414598


 55%|█████▌    | 33/60 [03:55<03:12,  7.12s/it]

0.9389697644445631


 57%|█████▋    | 34/60 [04:03<03:05,  7.13s/it]

0.938461012310452


 58%|█████▊    | 35/60 [04:10<02:58,  7.13s/it]

0.9349179267883301


 60%|██████    | 36/60 [04:17<02:50,  7.12s/it]

0.9375128083758884


 62%|██████▏   | 37/60 [04:24<02:43,  7.13s/it]

0.9356165197160509


 63%|██████▎   | 38/60 [04:31<02:36,  7.13s/it]

0.9348322086864047


 65%|██████▌   | 39/60 [04:38<02:29,  7.12s/it]

0.9346691370010376


 67%|██████▋   | 40/60 [04:45<02:22,  7.13s/it]

0.9343856308195326


 68%|██████▊   | 41/60 [04:52<02:15,  7.13s/it]

0.93421381049686


 70%|███████   | 42/60 [05:00<02:09,  7.18s/it]

0.9293533431159126


 72%|███████▏  | 43/60 [05:07<02:01,  7.16s/it]

0.9319884843296475


 73%|███████▎  | 44/60 [05:14<01:54,  7.15s/it]

0.9305525554551018


 75%|███████▌  | 45/60 [05:21<01:47,  7.14s/it]

0.9295956359969245


 77%|███████▋  | 46/60 [05:28<01:40,  7.14s/it]

0.9306544264157613


 78%|███████▊  | 47/60 [05:35<01:32,  7.15s/it]

0.928139156765408


 80%|████████  | 48/60 [05:43<01:25,  7.14s/it]

0.927635027302636


 82%|████████▏ | 49/60 [05:50<01:18,  7.14s/it]

0.9265252815352546


 83%|████████▎ | 50/60 [05:57<01:11,  7.14s/it]

0.9249901109271579


 85%|████████▌ | 51/60 [06:04<01:04,  7.14s/it]

0.928083168135749


 87%|████████▋ | 52/60 [06:11<00:57,  7.14s/it]

0.924973275926378


 88%|████████▊ | 53/60 [06:18<00:50,  7.14s/it]

0.9204473627938164


 90%|█████████ | 54/60 [06:25<00:42,  7.14s/it]

0.9256411989529928


 92%|█████████▏| 55/60 [06:32<00:35,  7.14s/it]

0.9225681159231398


 93%|█████████▎| 56/60 [06:40<00:28,  7.14s/it]

0.9217709898948669


 95%|█████████▌| 57/60 [06:47<00:21,  7.14s/it]

0.9204380644692315


 97%|█████████▋| 58/60 [06:54<00:14,  7.14s/it]

0.9219896660910712


 98%|█████████▊| 59/60 [07:01<00:07,  7.14s/it]

0.9202675157123141


100%|██████████| 60/60 [07:08<00:00,  7.14s/it]

0.9200836022694906
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [32]:
model_score2

{'Bleu_1': 0.5223053064684876,
 'Bleu_2': 0.39087007653399974,
 'Bleu_3': 0.3113853373714817,
 'Bleu_4': 0.2582108151703191,
 'METEOR': 0.22031991810802834,
 'ROUGE_L': 0.4218176742259456,
 'CIDEr': 1.313764382035548,
 'SPICE': 0.28175420506525184,
 'USC_similarity': 0.5201317327842832}

In [33]:
caption_model3, model_score3 = cross_validation(cv[2][0], cv[2][1], 3)    

Split 3:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions


  0%|          | 4/8333 [00:00<03:45, 36.90it/s]

preprocessed words 2714 ==> 890
The vocabulary size is 891.
800 out of 891 words are found in the pre-trained matrix.
The size of embedding_matrix is (891, 500)
Preparing dataloader...


100%|██████████| 8333/8333 [03:38<00:00, 38.22it/s]
  0%|          | 4/2083 [00:00<00:54, 38.30it/s]


Generating set took: 0:03:38.01


100%|██████████| 2083/2083 [00:55<00:00, 37.79it/s]
  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:55.13
Training...


  2%|▏         | 1/60 [00:07<07:07,  7.25s/it]

10.967783821953667


  3%|▎         | 2/60 [00:14<06:57,  7.20s/it]

4.665954377916124


  5%|▌         | 3/60 [00:21<06:49,  7.18s/it]

3.8092906210157604


  7%|▋         | 4/60 [00:28<06:40,  7.16s/it]

3.157911698023478


  8%|▊         | 5/60 [00:35<06:32,  7.15s/it]

2.8202331595950656


 10%|█         | 6/60 [00:42<06:25,  7.13s/it]

2.5875430372026234


 12%|█▏        | 7/60 [00:49<06:17,  7.13s/it]

2.424020687739054


 13%|█▎        | 8/60 [00:57<06:10,  7.12s/it]

2.2768927415211997


 15%|█▌        | 9/60 [01:04<06:03,  7.13s/it]

2.1611568133036294


 17%|█▋        | 10/60 [01:11<05:56,  7.13s/it]

2.0800361500846014


 18%|█▊        | 11/60 [01:18<05:49,  7.13s/it]

1.990999115837945


 20%|██        | 12/60 [01:25<05:41,  7.12s/it]

1.9158542023764715


 22%|██▏       | 13/60 [01:32<05:35,  7.13s/it]

1.8553718593385484


 23%|██▎       | 14/60 [01:39<05:27,  7.13s/it]

1.8110025458865695


 25%|██▌       | 15/60 [01:46<05:20,  7.12s/it]

1.7574989133410983


 27%|██▋       | 16/60 [01:54<05:13,  7.12s/it]

1.6782658629947238


 28%|██▊       | 17/60 [02:01<05:06,  7.13s/it]

1.5955474774042766


 30%|███       | 18/60 [02:08<04:59,  7.13s/it]

1.5365025732252333


 32%|███▏      | 19/60 [02:15<04:52,  7.12s/it]

1.4827142026689317


 33%|███▎      | 20/60 [02:22<04:44,  7.12s/it]

1.4539751609166462


 35%|███▌      | 21/60 [02:29<04:37,  7.12s/it]

1.4253953695297241


 37%|███▋      | 22/60 [02:36<04:30,  7.13s/it]

1.4091228114234076


 38%|███▊      | 23/60 [02:43<04:23,  7.13s/it]

1.408366322517395


 40%|████      | 24/60 [02:51<04:16,  7.13s/it]

1.3612469302283392


 42%|████▏     | 25/60 [02:58<04:09,  7.13s/it]

1.3383422957526312


 43%|████▎     | 26/60 [03:05<04:02,  7.14s/it]

1.3146217399173312


 45%|████▌     | 27/60 [03:12<03:55,  7.14s/it]

1.283638834953308


 47%|████▋     | 28/60 [03:19<03:48,  7.14s/it]

1.2575253778033786


 48%|████▊     | 29/60 [03:26<03:41,  7.14s/it]

1.2423403925365872


 50%|█████     | 30/60 [03:33<03:34,  7.14s/it]

1.2165038188298543


 52%|█████▏    | 31/60 [03:41<03:27,  7.14s/it]

1.188429421848721


 53%|█████▎    | 32/60 [03:48<03:19,  7.14s/it]

1.1640075776312087


 55%|█████▌    | 33/60 [03:55<03:12,  7.13s/it]

1.1325975126690335


 57%|█████▋    | 34/60 [04:02<03:05,  7.13s/it]

1.1179425716400146


 58%|█████▊    | 35/60 [04:09<02:58,  7.15s/it]

1.094046539730496


 60%|██████    | 36/60 [04:16<02:51,  7.14s/it]

1.0842645035849676


 62%|██████▏   | 37/60 [04:23<02:44,  7.14s/it]

1.0773605240715876


 63%|██████▎   | 38/60 [04:31<02:37,  7.15s/it]

1.089201807975769


 65%|██████▌   | 39/60 [04:38<02:29,  7.14s/it]

1.0990979274113972


 67%|██████▋   | 40/60 [04:45<02:23,  7.18s/it]

1.0946520566940308


 68%|██████▊   | 41/60 [04:52<02:16,  7.18s/it]

1.0656696028179593


 70%|███████   | 42/60 [04:59<02:08,  7.16s/it]

1.0145684745576646


 72%|███████▏  | 43/60 [05:06<02:01,  7.14s/it]

1.0016344388326008


 73%|███████▎  | 44/60 [05:13<01:54,  7.13s/it]

0.9699206352233887


 75%|███████▌  | 45/60 [05:21<01:46,  7.13s/it]

0.9511749744415283


 77%|███████▋  | 46/60 [05:28<01:39,  7.13s/it]

0.9441843827565511


 78%|███████▊  | 47/60 [05:35<01:32,  7.13s/it]

0.9357587960031297


 80%|████████  | 48/60 [05:42<01:25,  7.13s/it]

0.9476006163491143


 82%|████████▏ | 49/60 [05:49<01:18,  7.14s/it]

0.9500263465775384


 83%|████████▎ | 50/60 [05:56<01:11,  7.16s/it]

0.9334319167666965


 85%|████████▌ | 51/60 [06:03<01:04,  7.15s/it]

0.9031633337338766


 87%|████████▋ | 52/60 [06:11<00:57,  7.14s/it]

0.8784092466036478


 88%|████████▊ | 53/60 [06:18<00:49,  7.14s/it]

0.8636360036002265


 90%|█████████ | 54/60 [06:25<00:42,  7.14s/it]

0.8538455698225234


 92%|█████████▏| 55/60 [06:32<00:35,  7.14s/it]

0.8388936718304952


 93%|█████████▎| 56/60 [06:39<00:28,  7.14s/it]

0.8349580830997891


 95%|█████████▌| 57/60 [06:46<00:21,  7.14s/it]

0.8397775424851311


 97%|█████████▋| 58/60 [06:53<00:14,  7.14s/it]

0.8376123441590203


 98%|█████████▊| 59/60 [07:01<00:07,  7.14s/it]

0.8326010637813144


100%|██████████| 60/60 [07:08<00:00,  7.14s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.8095423844125536


  2%|▏         | 1/60 [00:07<06:59,  7.12s/it]

0.7743820283148024


  3%|▎         | 2/60 [00:14<06:52,  7.12s/it]

0.74727373652988


  5%|▌         | 3/60 [00:21<06:45,  7.12s/it]

0.7230889134936862


  7%|▋         | 4/60 [00:28<06:38,  7.12s/it]

0.7055499917931027


  8%|▊         | 5/60 [00:35<06:31,  7.12s/it]

0.6967931787172953


 10%|█         | 6/60 [00:42<06:24,  7.12s/it]

0.6872832510206435


 12%|█▏        | 7/60 [00:49<06:16,  7.11s/it]

0.6804658406310611


 13%|█▎        | 8/60 [00:56<06:09,  7.11s/it]

0.6802761720286475


 15%|█▌        | 9/60 [01:04<06:02,  7.11s/it]

0.6756836175918579


 17%|█▋        | 10/60 [01:11<05:55,  7.11s/it]

0.6700089507632785


 18%|█▊        | 11/60 [01:18<05:49,  7.13s/it]

0.6687697933779823


 20%|██        | 12/60 [01:25<05:42,  7.14s/it]

0.6687204274866316


 22%|██▏       | 13/60 [01:32<05:35,  7.14s/it]

0.6642493440045251


 23%|██▎       | 14/60 [01:39<05:28,  7.14s/it]

0.6627717746628655


 25%|██▌       | 15/60 [01:46<05:20,  7.13s/it]

0.6609650651613871


 27%|██▋       | 16/60 [01:53<05:13,  7.12s/it]

0.6634736624028947


 28%|██▊       | 17/60 [02:01<05:06,  7.12s/it]

0.6585219005743662


 30%|███       | 18/60 [02:08<05:01,  7.18s/it]

0.655113140741984


 32%|███▏      | 19/60 [02:15<04:53,  7.16s/it]

0.6566536558998955


 33%|███▎      | 20/60 [02:22<04:45,  7.14s/it]

0.6534667677349515


 35%|███▌      | 21/60 [02:29<04:37,  7.13s/it]

0.6560260454813639


 37%|███▋      | 22/60 [02:36<04:30,  7.12s/it]

0.6522942582766215


 38%|███▊      | 23/60 [02:43<04:23,  7.12s/it]

0.6504793498251173


 40%|████      | 24/60 [02:51<04:16,  7.11s/it]

0.6473434368769327


 42%|████▏     | 25/60 [02:58<04:09,  7.12s/it]

0.6489959458510081


 43%|████▎     | 26/60 [03:05<04:02,  7.12s/it]

0.6483749945958456


 45%|████▌     | 27/60 [03:12<03:55,  7.12s/it]

0.6453098654747009


 47%|████▋     | 28/60 [03:19<03:47,  7.12s/it]

0.6450166867838966


 48%|████▊     | 29/60 [03:26<03:40,  7.12s/it]

0.6446490188439687


 50%|█████     | 30/60 [03:33<03:33,  7.12s/it]

0.6441992587513394


 52%|█████▏    | 31/60 [03:40<03:26,  7.13s/it]

0.6419478356838226


 53%|█████▎    | 32/60 [03:48<03:19,  7.12s/it]

0.6414603392283121


 55%|█████▌    | 33/60 [03:55<03:12,  7.13s/it]

0.6417674620946249


 57%|█████▋    | 34/60 [04:02<03:05,  7.14s/it]

0.6399615605672201


 58%|█████▊    | 35/60 [04:09<02:58,  7.13s/it]

0.6397140887048509


 60%|██████    | 36/60 [04:16<02:51,  7.13s/it]

0.6388987965053983


 62%|██████▏   | 37/60 [04:23<02:44,  7.13s/it]

0.6377967794736227


 63%|██████▎   | 38/60 [04:30<02:36,  7.13s/it]

0.6377261413468255


 65%|██████▌   | 39/60 [04:37<02:29,  7.14s/it]

0.6371789044804044


 67%|██████▋   | 40/60 [04:45<02:22,  7.13s/it]

0.6347541941536797


 68%|██████▊   | 41/60 [04:52<02:15,  7.13s/it]

0.6344867547353109


 70%|███████   | 42/60 [04:59<02:08,  7.12s/it]

0.6347079873085022


 72%|███████▏  | 43/60 [05:06<02:01,  7.12s/it]

0.6335808700985379


 73%|███████▎  | 44/60 [05:13<01:54,  7.13s/it]

0.6319718526469337


 75%|███████▌  | 45/60 [05:20<01:46,  7.13s/it]

0.6317791210280524


 77%|███████▋  | 46/60 [05:27<01:39,  7.13s/it]

0.6297262443436517


 78%|███████▊  | 47/60 [05:34<01:32,  7.12s/it]

0.630444758468204


 80%|████████  | 48/60 [05:42<01:25,  7.12s/it]

0.6288344264030457


 82%|████████▏ | 49/60 [05:49<01:18,  7.12s/it]

0.628329359822803


 83%|████████▎ | 50/60 [05:56<01:11,  7.12s/it]

0.6281141440073649


 85%|████████▌ | 51/60 [06:03<01:04,  7.13s/it]

0.6277099781566196


 87%|████████▋ | 52/60 [06:10<00:57,  7.13s/it]

0.6278564466370476


 88%|████████▊ | 53/60 [06:17<00:49,  7.13s/it]

0.6261135008600023


 90%|█████████ | 54/60 [06:24<00:42,  7.12s/it]

0.6251435412300957


 92%|█████████▏| 55/60 [06:31<00:35,  7.13s/it]

0.6283265186680688


 93%|█████████▎| 56/60 [06:39<00:28,  7.13s/it]

0.6231034133169386


 95%|█████████▌| 57/60 [06:46<00:21,  7.12s/it]

0.6256010399924384


 97%|█████████▋| 58/60 [06:53<00:14,  7.12s/it]

0.6240205069382986


 98%|█████████▊| 59/60 [07:00<00:07,  7.12s/it]

0.6247304214371575


100%|██████████| 60/60 [07:07<00:00,  7.13s/it]

0.6237954066859351
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [34]:
model_score3

{'Bleu_1': 0.5630024468217638,
 'Bleu_2': 0.4286591057147059,
 'Bleu_3': 0.34603280111232027,
 'Bleu_4': 0.28954717484367753,
 'METEOR': 0.24443921129255886,
 'ROUGE_L': 0.46283298458735955,
 'CIDEr': 1.5244663658940862,
 'SPICE': 0.3099243035870373,
 'USC_similarity': 0.5423565199348738}

In [35]:
caption_model4, model_score4 = cross_validation(cv[3][0], cv[3][1], 4)    

Split 4:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions


  0%|          | 4/8333 [00:00<04:05, 33.99it/s]

preprocessed words 2680 ==> 894
The vocabulary size is 895.
806 out of 895 words are found in the pre-trained matrix.
The size of embedding_matrix is (895, 500)
Preparing dataloader...


100%|██████████| 8333/8333 [03:36<00:00, 38.47it/s]
  0%|          | 4/2083 [00:00<00:54, 37.90it/s]


Generating set took: 0:03:36.62


100%|██████████| 2083/2083 [00:54<00:00, 38.26it/s]
  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:54.45
Training...


  2%|▏         | 1/60 [00:07<07:14,  7.36s/it]

10.558717727661133


  3%|▎         | 2/60 [00:14<07:02,  7.29s/it]

4.734496593475342


  5%|▌         | 3/60 [00:21<06:52,  7.23s/it]

4.00533209906684


  7%|▋         | 4/60 [00:28<06:42,  7.19s/it]

3.3375267452663846


  8%|▊         | 5/60 [00:35<06:34,  7.17s/it]

2.8710520532396107


 10%|█         | 6/60 [00:42<06:26,  7.16s/it]

2.5733842055002847


 12%|█▏        | 7/60 [00:50<06:18,  7.15s/it]

2.3699067963494196


 13%|█▎        | 8/60 [00:57<06:11,  7.14s/it]

2.2061988512674966


 15%|█▌        | 9/60 [01:04<06:03,  7.13s/it]

2.0480176210403442


 17%|█▋        | 10/60 [01:11<05:56,  7.13s/it]

1.9400874376296997


 18%|█▊        | 11/60 [01:18<05:49,  7.13s/it]

1.8312826421525743


 20%|██        | 12/60 [01:25<05:41,  7.12s/it]

1.7576369576983981


 22%|██▏       | 13/60 [01:32<05:35,  7.14s/it]

1.699012491438124


 23%|██▎       | 14/60 [01:39<05:28,  7.14s/it]

1.5855909983317058


 25%|██▌       | 15/60 [01:47<05:21,  7.14s/it]

1.5195782979329426


 27%|██▋       | 16/60 [01:54<05:14,  7.14s/it]

1.4710496796502008


 28%|██▊       | 17/60 [02:01<05:07,  7.14s/it]

1.4318938652674358


 30%|███       | 18/60 [02:08<04:59,  7.14s/it]

1.3781080312199063


 32%|███▏      | 19/60 [02:15<04:52,  7.13s/it]

1.3410282002554998


 33%|███▎      | 20/60 [02:22<04:45,  7.15s/it]

1.2910880910025702


 35%|███▌      | 21/60 [02:29<04:38,  7.14s/it]

1.2791367305649652


 37%|███▋      | 22/60 [02:37<04:31,  7.15s/it]

1.2375072307056851


 38%|███▊      | 23/60 [02:44<04:24,  7.14s/it]

1.2137591905064053


 40%|████      | 24/60 [02:51<04:17,  7.14s/it]

1.1873091061909993


 42%|████▏     | 25/60 [02:58<04:09,  7.13s/it]

1.1650785472657945


 43%|████▎     | 26/60 [03:05<04:02,  7.13s/it]

1.1576528681649103


 45%|████▌     | 27/60 [03:12<03:54,  7.12s/it]

1.134478489557902


 47%|████▋     | 28/60 [03:19<03:47,  7.12s/it]

1.0830030375056796


 48%|████▊     | 29/60 [03:26<03:41,  7.13s/it]

1.025813102722168


 50%|█████     | 30/60 [03:34<03:33,  7.13s/it]

0.9863102502293057


 52%|█████▏    | 31/60 [03:41<03:26,  7.13s/it]

0.9546332226859199


 53%|█████▎    | 32/60 [03:48<03:19,  7.13s/it]

0.9421185387505425


 55%|█████▌    | 33/60 [03:55<03:12,  7.13s/it]

0.9221247434616089


 57%|█████▋    | 34/60 [04:02<03:05,  7.14s/it]

0.9046211176448398


 58%|█████▊    | 35/60 [04:09<02:59,  7.18s/it]

0.8899806870354546


 60%|██████    | 36/60 [04:17<02:52,  7.17s/it]

0.878796398639679


 62%|██████▏   | 37/60 [04:24<02:44,  7.16s/it]

0.868840754032135


 63%|██████▎   | 38/60 [04:31<02:37,  7.15s/it]

0.8409995635350546


 65%|██████▌   | 39/60 [04:38<02:29,  7.13s/it]

0.8252517117394341


 67%|██████▋   | 40/60 [04:45<02:22,  7.13s/it]

0.8163842227723863


 68%|██████▊   | 41/60 [04:52<02:15,  7.13s/it]

0.8127738899654813


 70%|███████   | 42/60 [04:59<02:08,  7.13s/it]

0.8140745494100783


 72%|███████▏  | 43/60 [05:06<02:01,  7.13s/it]

0.8061912655830383


 73%|███████▎  | 44/60 [05:14<01:54,  7.13s/it]

0.7787592278586494


 75%|███████▌  | 45/60 [05:21<01:47,  7.15s/it]

0.7688031196594238


 77%|███████▋  | 46/60 [05:28<01:40,  7.15s/it]

0.741871522532569


 78%|███████▊  | 47/60 [05:35<01:32,  7.14s/it]

0.7455833322472043


 80%|████████  | 48/60 [05:42<01:25,  7.13s/it]

0.7574591802226173


 82%|████████▏ | 49/60 [05:49<01:18,  7.14s/it]

0.7620042430029975


 83%|████████▎ | 50/60 [05:56<01:11,  7.14s/it]

0.7560190955797831


 85%|████████▌ | 51/60 [06:04<01:04,  7.15s/it]

0.7623262537850274


 87%|████████▋ | 52/60 [06:11<00:57,  7.14s/it]

0.7352078424559699


 88%|████████▊ | 53/60 [06:18<00:50,  7.15s/it]

0.7344183292653825


 90%|█████████ | 54/60 [06:25<00:42,  7.15s/it]

0.7351264225112067


 92%|█████████▏| 55/60 [06:32<00:35,  7.15s/it]

0.7379176947805617


 93%|█████████▎| 56/60 [06:39<00:28,  7.14s/it]

0.7533587283558316


 95%|█████████▌| 57/60 [06:46<00:21,  7.13s/it]

0.7550855941242642


 97%|█████████▋| 58/60 [06:54<00:14,  7.15s/it]

0.7614167398876615


 98%|█████████▊| 59/60 [07:01<00:07,  7.14s/it]

0.7399618162049187


100%|██████████| 60/60 [07:08<00:00,  7.14s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.7234849731127421


  2%|▏         | 1/60 [00:07<07:00,  7.13s/it]

0.7218359775013394


  3%|▎         | 2/60 [00:14<06:53,  7.13s/it]

0.6787008742491404


  5%|▌         | 3/60 [00:21<06:46,  7.12s/it]

0.6422889298862882


  7%|▋         | 4/60 [00:28<06:39,  7.14s/it]

0.6200775570339627


  8%|▊         | 5/60 [00:35<06:32,  7.13s/it]

0.6058431930012174


 10%|█         | 6/60 [00:42<06:25,  7.13s/it]

0.5955073104964362


 12%|█▏        | 7/60 [00:49<06:18,  7.14s/it]

0.5887560513284471


 13%|█▎        | 8/60 [00:57<06:11,  7.15s/it]

0.5811407499843173


 15%|█▌        | 9/60 [01:04<06:04,  7.14s/it]

0.5800105697578855


 17%|█▋        | 10/60 [01:11<05:57,  7.14s/it]

0.5760508113437228


 18%|█▊        | 11/60 [01:18<05:49,  7.14s/it]

0.5726202858818902


 20%|██        | 12/60 [01:25<05:42,  7.13s/it]

0.5706797904438443


 22%|██▏       | 13/60 [01:32<05:35,  7.13s/it]

0.5691327452659607


 23%|██▎       | 14/60 [01:39<05:28,  7.15s/it]

0.5659917725457085


 25%|██▌       | 15/60 [01:47<05:21,  7.14s/it]

0.5637553168667687


 27%|██▋       | 16/60 [01:54<05:14,  7.14s/it]

0.5618370307816399


 28%|██▊       | 17/60 [02:01<05:06,  7.13s/it]

0.5615773300329844


 30%|███       | 18/60 [02:08<04:59,  7.13s/it]

0.5604540705680847


 32%|███▏      | 19/60 [02:15<04:52,  7.13s/it]

0.5584575004047818


 33%|███▎      | 20/60 [02:22<04:44,  7.12s/it]

0.5576412810219659


 35%|███▌      | 21/60 [02:29<04:37,  7.12s/it]

0.5553350249926249


 37%|███▋      | 22/60 [02:36<04:31,  7.15s/it]

0.5541395280096266


 38%|███▊      | 23/60 [02:44<04:24,  7.14s/it]

0.5533766282929314


 40%|████      | 24/60 [02:51<04:17,  7.14s/it]

0.5508276422818502


 42%|████▏     | 25/60 [02:58<04:10,  7.14s/it]

0.550103935930464


 43%|████▎     | 26/60 [03:05<04:02,  7.14s/it]

0.5484791729185317


 45%|████▌     | 27/60 [03:13<04:00,  7.28s/it]

0.5494528743955824


 47%|████▋     | 28/60 [03:20<03:51,  7.25s/it]

0.54745086034139


 48%|████▊     | 29/60 [03:27<03:43,  7.20s/it]

0.5480189124743143


 50%|█████     | 30/60 [03:34<03:35,  7.18s/it]

0.5459070404370626


 52%|█████▏    | 31/60 [03:41<03:27,  7.16s/it]

0.5454565684000651


 53%|█████▎    | 32/60 [03:48<03:20,  7.15s/it]

0.5447299016846551


 55%|█████▌    | 33/60 [03:55<03:12,  7.14s/it]

0.5404144525527954


 57%|█████▋    | 34/60 [04:03<03:05,  7.14s/it]

0.542678071392907


 58%|█████▊    | 35/60 [04:10<02:58,  7.14s/it]

0.5400670634375678


 60%|██████    | 36/60 [04:17<02:51,  7.14s/it]

0.541306283738878


 62%|██████▏   | 37/60 [04:24<02:44,  7.14s/it]

0.5400120218594869


 63%|██████▎   | 38/60 [04:31<02:36,  7.13s/it]

0.5392592814233568


 65%|██████▌   | 39/60 [04:38<02:29,  7.13s/it]

0.5388489034440782


 67%|██████▋   | 40/60 [04:45<02:22,  7.13s/it]

0.5375481645266215


 68%|██████▊   | 41/60 [04:52<02:15,  7.14s/it]

0.5363888111379411


 70%|███████   | 42/60 [05:00<02:08,  7.14s/it]

0.5355384349822998


 72%|███████▏  | 43/60 [05:07<02:01,  7.14s/it]

0.5351604521274567


 73%|███████▎  | 44/60 [05:14<01:54,  7.13s/it]

0.5356493956512876


 75%|███████▌  | 45/60 [05:21<01:47,  7.13s/it]

0.5328725543287065


 77%|███████▋  | 46/60 [05:28<01:39,  7.13s/it]

0.5320179528660245


 78%|███████▊  | 47/60 [05:35<01:32,  7.13s/it]

0.5314905511008369


 80%|████████  | 48/60 [05:42<01:25,  7.14s/it]

0.5322798093159994


 82%|████████▏ | 49/60 [05:50<01:18,  7.15s/it]

0.5312678615252177


 83%|████████▎ | 50/60 [05:57<01:11,  7.15s/it]

0.530590769317415


 85%|████████▌ | 51/60 [06:04<01:04,  7.15s/it]

0.5299335651927524


 87%|████████▋ | 52/60 [06:11<00:57,  7.15s/it]

0.5295970241228739


 88%|████████▊ | 53/60 [06:18<00:49,  7.14s/it]

0.5293578637970818


 90%|█████████ | 54/60 [06:25<00:42,  7.14s/it]

0.5272243387169309


 92%|█████████▏| 55/60 [06:32<00:35,  7.14s/it]

0.5260422627131144


 93%|█████████▎| 56/60 [06:40<00:28,  7.14s/it]

0.5262830787234836


 95%|█████████▌| 57/60 [06:47<00:21,  7.14s/it]

0.527308315038681


 97%|█████████▋| 58/60 [06:54<00:14,  7.14s/it]

0.5246015787124634


 98%|█████████▊| 59/60 [07:01<00:07,  7.15s/it]

0.5261783831649356


100%|██████████| 60/60 [07:08<00:00,  7.15s/it]

0.5248396297295889
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [36]:
model_score4

{'Bleu_1': 0.5388989990900573,
 'Bleu_2': 0.4044291059247442,
 'Bleu_3': 0.3259029084781842,
 'Bleu_4': 0.2738528763642468,
 'METEOR': 0.23856259654355447,
 'ROUGE_L': 0.4506431454397961,
 'CIDEr': 1.4425988826222083,
 'SPICE': 0.2982330863775588,
 'USC_similarity': 0.5311438714013227}

In [29]:
caption_model5, model_score5 = cross_validation(cv[4][0], cv[4][1], 5)    

Split 5:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions


  0%|          | 0/8333 [00:00<?, ?it/s]

preprocessed words 2657 ==> 905
The vocabulary size is 906.
815 out of 906 words are found in the pre-trained matrix.
The size of embedding_matrix is (906, 500)
Preparing dataloader...


100%|██████████| 8333/8333 [03:32<00:00, 39.15it/s]
  0%|          | 4/2083 [00:00<00:52, 39.59it/s]


Generating set took: 0:03:32.86


100%|██████████| 2083/2083 [00:53<00:00, 38.91it/s]
  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.54
Training...


  2%|▏         | 1/60 [00:07<07:00,  7.13s/it]

9.144343694051107


  3%|▎         | 2/60 [00:14<06:52,  7.11s/it]

5.109705183241102


  5%|▌         | 3/60 [00:21<06:44,  7.09s/it]

4.809932337866889


  7%|▋         | 4/60 [00:28<06:36,  7.08s/it]

4.566242535909017


  8%|▊         | 5/60 [00:35<06:28,  7.07s/it]

4.247168805864122


 10%|█         | 6/60 [00:42<06:21,  7.07s/it]

3.8562396897210016


 12%|█▏        | 7/60 [00:49<06:14,  7.06s/it]

3.581740140914917


 13%|█▎        | 8/60 [00:56<06:07,  7.06s/it]

3.440867450502184


 15%|█▌        | 9/60 [01:03<06:00,  7.07s/it]

3.3417934046851263


 17%|█▋        | 10/60 [01:10<05:53,  7.07s/it]

3.257337146335178


 18%|█▊        | 11/60 [01:17<05:46,  7.07s/it]

3.1640970706939697


 20%|██        | 12/60 [01:24<05:39,  7.07s/it]

3.0562586784362793


 22%|██▏       | 13/60 [01:31<05:32,  7.07s/it]

2.937298854192098


 23%|██▎       | 14/60 [01:38<05:25,  7.07s/it]

2.8644200695885553


 25%|██▌       | 15/60 [01:45<05:17,  7.06s/it]

2.768002165688409


 27%|██▋       | 16/60 [01:53<05:10,  7.06s/it]

2.7074325349595814


 28%|██▊       | 17/60 [02:00<05:04,  7.08s/it]

2.6062147352430554


 30%|███       | 18/60 [02:07<04:57,  7.08s/it]

2.525383154551188


 32%|███▏      | 19/60 [02:14<04:51,  7.11s/it]

2.4517191780938044


 33%|███▎      | 20/60 [02:21<04:43,  7.10s/it]

2.416715145111084


 35%|███▌      | 21/60 [02:28<04:37,  7.11s/it]

2.3538781801859536


 37%|███▋      | 22/60 [02:35<04:29,  7.10s/it]

2.3012255562676325


 38%|███▊      | 23/60 [02:42<04:22,  7.09s/it]

2.237253612942166


 40%|████      | 24/60 [02:49<04:14,  7.08s/it]

2.204681820339627


 42%|████▏     | 25/60 [02:56<04:07,  7.08s/it]

2.168993737962511


 43%|████▎     | 26/60 [03:03<04:00,  7.08s/it]

2.117914358774821


 45%|████▌     | 27/60 [03:11<03:53,  7.08s/it]

2.0690981811947293


 47%|████▋     | 28/60 [03:18<03:46,  7.09s/it]

2.0350513458251953


 48%|████▊     | 29/60 [03:25<03:39,  7.08s/it]

1.978740202056037


 50%|█████     | 30/60 [03:32<03:32,  7.08s/it]

1.9345980087916057


 52%|█████▏    | 31/60 [03:39<03:25,  7.08s/it]

1.923423343234592


 53%|█████▎    | 32/60 [03:46<03:18,  7.09s/it]

1.91128761238522


 55%|█████▌    | 33/60 [03:53<03:11,  7.09s/it]

1.8573675950368245


 57%|█████▋    | 34/60 [04:00<03:04,  7.09s/it]

1.8057643704944186


 58%|█████▊    | 35/60 [04:07<02:57,  7.09s/it]

1.769099513689677


 60%|██████    | 36/60 [04:14<02:50,  7.09s/it]

1.7401983473036025


 62%|██████▏   | 37/60 [04:21<02:42,  7.09s/it]

1.7130545642640855


 63%|██████▎   | 38/60 [04:29<02:35,  7.09s/it]

1.6855633523729112


 65%|██████▌   | 39/60 [04:36<02:28,  7.09s/it]

1.6512061887317233


 67%|██████▋   | 40/60 [04:43<02:21,  7.09s/it]

1.6476503610610962


 68%|██████▊   | 41/60 [04:50<02:14,  7.09s/it]

1.6180489195717707


 70%|███████   | 42/60 [04:57<02:07,  7.09s/it]

1.578249414761861


 72%|███████▏  | 43/60 [05:04<02:01,  7.13s/it]

1.538805537753635


 73%|███████▎  | 44/60 [05:11<01:53,  7.12s/it]

1.5060696999231975


 75%|███████▌  | 45/60 [05:18<01:46,  7.10s/it]

1.4866662952635024


 77%|███████▋  | 46/60 [05:25<01:39,  7.10s/it]

1.4723752472135756


 78%|███████▊  | 47/60 [05:32<01:32,  7.09s/it]

1.4703008201387193


 80%|████████  | 48/60 [05:40<01:25,  7.09s/it]

1.4549934996498957


 82%|████████▏ | 49/60 [05:47<01:17,  7.09s/it]

1.4175766838921442


 83%|████████▎ | 50/60 [05:54<01:10,  7.10s/it]

1.396935264269511


 85%|████████▌ | 51/60 [06:01<01:03,  7.10s/it]

1.3624948130713568


 87%|████████▋ | 52/60 [06:08<00:56,  7.10s/it]

1.3456732167137995


 88%|████████▊ | 53/60 [06:15<00:49,  7.09s/it]

1.346183353000217


 90%|█████████ | 54/60 [06:22<00:42,  7.09s/it]

1.3416554530461628


 92%|█████████▏| 55/60 [06:29<00:35,  7.09s/it]

1.3213997019661798


 93%|█████████▎| 56/60 [06:36<00:28,  7.08s/it]

1.3082656727896795


 95%|█████████▌| 57/60 [06:43<00:21,  7.08s/it]

1.305000126361847


 97%|█████████▋| 58/60 [06:50<00:14,  7.08s/it]

1.28587587012185


 98%|█████████▊| 59/60 [06:57<00:07,  7.08s/it]

1.2478602131207783


100%|██████████| 60/60 [07:05<00:00,  7.08s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

1.2224047713809543


  2%|▏         | 1/60 [00:07<06:57,  7.07s/it]

1.2379004028108385


  3%|▎         | 2/60 [00:14<06:50,  7.07s/it]

1.1882722510231867


  5%|▌         | 3/60 [00:21<06:44,  7.09s/it]

1.1564662125375536


  7%|▋         | 4/60 [00:28<06:37,  7.09s/it]

1.1388231847021315


  8%|▊         | 5/60 [00:35<06:29,  7.09s/it]

1.1279587480756972


 10%|█         | 6/60 [00:42<06:23,  7.10s/it]

1.122513731320699


 12%|█▏        | 7/60 [00:49<06:16,  7.10s/it]

1.1152786082691617


 13%|█▎        | 8/60 [00:56<06:08,  7.09s/it]

1.1118472218513489


 15%|█▌        | 9/60 [01:03<06:02,  7.10s/it]

1.108205305205451


 17%|█▋        | 10/60 [01:10<05:55,  7.10s/it]

1.105185992187924


 18%|█▊        | 11/60 [01:18<05:48,  7.10s/it]

1.1054913931422763


 20%|██        | 12/60 [01:25<05:40,  7.10s/it]

1.101044283972846


 22%|██▏       | 13/60 [01:32<05:33,  7.10s/it]

1.0983809298939176


 23%|██▎       | 14/60 [01:39<05:26,  7.10s/it]

1.0965974728266399


 25%|██▌       | 15/60 [01:46<05:19,  7.11s/it]

1.096397426393297


 27%|██▋       | 16/60 [01:53<05:12,  7.10s/it]

1.0917123688591852


 28%|██▊       | 17/60 [02:00<05:05,  7.10s/it]

1.0916646652751498


 30%|███       | 18/60 [02:07<04:58,  7.10s/it]

1.0888843999968634


 32%|███▏      | 19/60 [02:14<04:50,  7.09s/it]

1.0892223185963101


 33%|███▎      | 20/60 [02:21<04:43,  7.10s/it]

1.0867644614643521


 35%|███▌      | 21/60 [02:29<04:37,  7.10s/it]

1.088236755794949


 37%|███▋      | 22/60 [02:36<04:29,  7.10s/it]

1.08707508775923


 38%|███▊      | 23/60 [02:43<04:22,  7.10s/it]

1.0849636064635382


 40%|████      | 24/60 [02:50<04:15,  7.09s/it]

1.079836818906996


 42%|████▏     | 25/60 [02:57<04:08,  7.09s/it]

1.082279900709788


 43%|████▎     | 26/60 [03:04<04:00,  7.09s/it]

1.0796051820119221


 45%|████▌     | 27/60 [03:11<03:53,  7.08s/it]

1.078540199332767


 47%|████▋     | 28/60 [03:18<03:46,  7.08s/it]

1.07850091987186


 48%|████▊     | 29/60 [03:25<03:39,  7.08s/it]

1.0787378019756741


 50%|█████     | 30/60 [03:32<03:32,  7.09s/it]

1.0776830977863736


 52%|█████▏    | 31/60 [03:39<03:25,  7.09s/it]

1.075628187921312


 53%|█████▎    | 32/60 [03:47<03:18,  7.10s/it]

1.0745004216829936


 55%|█████▌    | 33/60 [03:54<03:11,  7.11s/it]

1.073340568277571


 57%|█████▋    | 34/60 [04:01<03:04,  7.11s/it]

1.072467075453864


 58%|█████▊    | 35/60 [04:08<02:57,  7.10s/it]

1.072117547194163


 60%|██████    | 36/60 [04:15<02:50,  7.10s/it]

1.069969978597429


 62%|██████▏   | 37/60 [04:22<02:43,  7.09s/it]

1.0728927387131586


 63%|██████▎   | 38/60 [04:29<02:36,  7.10s/it]

1.0702733331256442


 65%|██████▌   | 39/60 [04:36<02:29,  7.13s/it]

1.0693546798494127


 67%|██████▋   | 40/60 [04:43<02:22,  7.12s/it]

1.0694426894187927


 68%|██████▊   | 41/60 [04:51<02:15,  7.12s/it]

1.0680675241682265


 70%|███████   | 42/60 [04:58<02:08,  7.11s/it]

1.066617853111691


 72%|███████▏  | 43/60 [05:05<02:01,  7.12s/it]

1.0670632322629292


 73%|███████▎  | 44/60 [05:12<01:53,  7.11s/it]

1.0658184157477484


 75%|███████▌  | 45/60 [05:19<01:46,  7.10s/it]

1.0651385121875339


 77%|███████▋  | 46/60 [05:26<01:39,  7.10s/it]

1.0655506054560344


 78%|███████▊  | 47/60 [05:33<01:32,  7.10s/it]

1.0660792986551921


 80%|████████  | 48/60 [05:40<01:25,  7.13s/it]

1.0617514120207892


 82%|████████▏ | 49/60 [05:48<01:18,  7.13s/it]

1.0590561893251207


 83%|████████▎ | 50/60 [05:55<01:11,  7.12s/it]

1.0619060860739813


 85%|████████▌ | 51/60 [06:02<01:03,  7.11s/it]

1.0615310470263164


 87%|████████▋ | 52/60 [06:09<00:56,  7.12s/it]

1.0616169836786058


 88%|████████▊ | 53/60 [06:16<00:49,  7.12s/it]

1.0570208297835455


 90%|█████████ | 54/60 [06:23<00:42,  7.12s/it]

1.0577477481630113


 92%|█████████▏| 55/60 [06:30<00:35,  7.11s/it]

1.0580286449856229


 93%|█████████▎| 56/60 [06:37<00:28,  7.10s/it]

1.056614292992486


 95%|█████████▌| 57/60 [06:44<00:21,  7.09s/it]

1.055959039264255


 97%|█████████▋| 58/60 [06:51<00:14,  7.09s/it]

1.0564703808890448


 98%|█████████▊| 59/60 [06:58<00:07,  7.09s/it]

1.0556742350260417


100%|██████████| 60/60 [07:06<00:00,  7.10s/it]

1.0554755793677435
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [30]:
model_score5

{'Bleu_1': 0.5251268081908451,
 'Bleu_2': 0.3894817198962167,
 'Bleu_3': 0.3059120071700011,
 'Bleu_4': 0.24930890329749703,
 'METEOR': 0.22045867964786875,
 'ROUGE_L': 0.42268399019753317,
 'CIDEr': 1.259938772058676,
 'SPICE': 0.27510864518769956,
 'USC_similarity': 0.518327823850544}

In [32]:
tag = '11.1.2-2'
with open(f'{root_captioning}/fz_notebooks/cv_n{tag}.json', 'r') as fp:
    model_scores = json.load(fp)

In [33]:
for key, value in model_score5.items():
    model_scores[key].append(value)

In [34]:
model_scores

{'Bleu_1': [0.5430511873831741,
  0.5223053064684876,
  0.5630024468217638,
  0.5388989990900573,
  0.5251268081908451],
 'Bleu_2': [0.4040510387044148,
  0.39087007653399974,
  0.4286591057147059,
  0.4044291059247442,
  0.3894817198962167],
 'Bleu_3': [0.3202125719013248,
  0.3113853373714817,
  0.34603280111232027,
  0.3259029084781842,
  0.3059120071700011],
 'Bleu_4': [0.26499074763786395,
  0.2582108151703191,
  0.28954717484367753,
  0.2738528763642468,
  0.24930890329749703],
 'METEOR': [0.23442467563859665,
  0.22031991810802834,
  0.24443921129255886,
  0.23856259654355447,
  0.22045867964786875],
 'ROUGE_L': [0.44063030412906334,
  0.4218176742259456,
  0.46283298458735955,
  0.4506431454397961,
  0.42268399019753317],
 'CIDEr': [1.284062242860915,
  1.313764382035548,
  1.5244663658940862,
  1.4425988826222083,
  1.259938772058676],
 'SPICE': [0.2835413805275309,
  0.28175420506525184,
  0.3099243035870373,
  0.2982330863775588,
  0.27510864518769956],
 'USC_similarity': [0

In [35]:
tag = '11.1.1'
with open(f'{root_captioning}/fz_notebooks/cv_n{tag}.json', 'w') as fp:
    json.dump(model_scores, fp)