## Image Captioning with Pytorch

The following contents are modified from MDS DSCI 575 lecture 8 demo

In [1]:
import os, sys, json
from collections import defaultdict
from tqdm import tqdm
import pickle
from time import time
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from itertools import chain
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models, transforms, datasets
from torchsummary import summary
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

from nltk.translate import bleu_score
from sklearn.model_selection import KFold

sys.path.append('../../scr/evaluation')
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.usc_sim.usc_sim import usc_sim
import subprocess


START = "startseq"
STOP = "endseq"
EPOCHS = 10
AWS = True


In [2]:
torch.manual_seed(123)
np.random.seed(123)

In [3]:
# torch.cuda.empty_cache()
# import gc 
# gc.collect()

In [4]:
# Nicely formatted time string
def hms_string(sec_elapsed):
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60
    return f"{h}:{m:>02}:{s:>05.2f}"
        
if AWS:
    root_captioning = "../../data"
else:
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
        root_captioning = "/content/drive/My Drive/data"
        COLAB = True
        print("Note: using Google CoLab")
    except:
        print("Note: not using Google CoLab")
        COLAB = False

### Clean/Build Dataset

- Read captions
- Preprocess captions


In [5]:
def get_img_info(name, num=np.inf):
    """
    Returns img paths and captions

    Parameters:
    -----------
    name: str
        the json file name
    num: int (default: np.inf)
        the number of observations to get

    Return:
    --------
    list, dict, int
        img paths, corresponding captions, max length of captions
    """
    img_path = []
    caption = [] 
    max_length = 0
    if AWS:
        with open(f'{root_captioning}/json/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for filename in data.keys():
                if num is not None and len(caption) == num:
                    break
                img_path.append(
                    f'{root_captioning}/{name}/{filename}'
                )
                sen_list = []
                for sentence in data[filename]['sentences']:
                    max_length = max(max_length, len(sentence['tokens']))
                    sen_list.append(sentence['raw'])

                caption.append(sen_list)    
    else:            
        with open(f'{root_captioning}/interim/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for set_name in ['rsicd', 'ucm']:
                for filename in data[set_name].keys():
                    if num is not None and len(caption) == num:
                        break

                    img_path.append(
                        f'{root_captioning}/raw/imgs/{set_name}/{filename}'
                    )
                    sen_list = []
                    for sentence in data[set_name][filename]['sentences']:
                        max_length = max(max_length, len(sentence['tokens']))
                        sen_list.append(sentence['raw'])

                    caption.append(sen_list)
    
    return img_path, caption, max_length            


In [6]:
# get img path and caption list
# # only test 800 train samples and 200 valid samples
# train_paths, train_descriptions, max_length_train = get_img_info('train', 800)
# test_paths, test_descriptions, max_length_test = get_img_info('valid', 200)

train_paths, train_descriptions, max_length_train = get_img_info('train')
test_paths, test_descriptions, max_length_test = get_img_info('valid')
max_length = max(max_length_train, max_length_test)



In [7]:
all_paths = train_paths.copy()
all_paths.extend(test_paths.copy())
all_paths = np.array(all_paths)

all_descriptions = train_descriptions.copy()
all_descriptions.extend(test_descriptions.copy())
all_descriptions = np.array(all_descriptions)

captions = all_descriptions.copy()
max_length_all = max(max_length_train, max_length_test)
max_length = max_length_all + 2
      
lex = set()
for sen in all_descriptions:
    [lex.update(d.split()) for d in sen]
    
# add a start and stop token at the beginning/end
for v in all_descriptions:
    for d in range(len(v)):
        v[d] = f'{START} {v[d]} {STOP}'
        
print(f'There are {len(all_paths)} images') 
print(f'There are {len(lex)} unique words (vocab)')
print(f'The maximum length of captions with start and stop token is {max_length}.')


There are 10416 images
There are 2912 unique words (vocab)
The maximum length of captions with start and stop token is 36.


In [8]:
all_paths[-1]

'../../s3/valid/rsicd_park_33.jpg'

In [9]:
all_descriptions[-1]

array(['startseq a vast artificial lake was built in the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq'],
      dtype='<U184')

### Loading Glove Embeddings

In [10]:
embeddings_index = {} 
path = os.path.join(root_captioning, 'glove.6B.200d.txt') if AWS\
else os.path.join(root_captioning, 'raw', 'glove.6B.200d.txt')

f = open(
    path, 
    encoding="utf-8"
)

for line in tqdm(f):
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs

f.close()
print(f'Found {len(embeddings_index)} word vectors.')

400000it [00:23, 17385.23it/s]

Found 400000 word vectors.





In [11]:
def get_vocab(descriptions, word_count_threshold=10):

    captions = []
    for val in descriptions:
        for cap in val:
            captions.append(cap)
    print(f'There are {len(captions)} captions')
    
    word_counts = {}
    nsents = 0
    for sent in captions:
        nsents += 1
        for w in sent.split(' '):
            word_counts[w] = word_counts.get(w, 0) + 1

    vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
    print('preprocessed words %d ==> %d' % (len(word_counts), len(vocab)))
    return vocab

def get_word_dict(vocab):
    
    idxtoword = {}
    wordtoidx = {}

    ix = 1
    for w in vocab:
        wordtoidx[w] = ix
        idxtoword[ix] = w
        ix += 1

    return idxtoword, wordtoidx

def get_vocab_size(idxtoword):
    
    print(f'The vocabulary size is {len(idxtoword) + 1}.')
    return len(idxtoword) + 1


def get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx):

    embedding_matrix = np.zeros((vocab_size, embedding_dim))
    count =0

    for word, i in wordtoidx.items():

        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            count += 1
            # Words not found in the embedding index will be all zeros
            embedding_matrix[i] = embedding_vector
            
    print(f'{count} out of {vocab_size} words are found in the pre-trained matrix.')            
    print(f'The size of embedding_matrix is {embedding_matrix.shape}')
    return embedding_matrix

### Building the Neural Network

An embedding matrix is built from Glove.  This will be directly copied to the weight matrix of the neural network.

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [13]:
class CNNModel(nn.Module):

    def __init__(self, cnn_type, pretrained=True):
        """
        Initializes a CNNModel

        Parameters:
        -----------
        cnn_type: str
            the CNN type, either 'vgg16' or 'inception_v3'
        pretrained: bool (default: True)
            use pretrained model if True

        """

        super(CNNModel, self).__init__()

        if cnn_type == 'vgg16':
            self.model = models.vgg16(pretrained=pretrained)

            # remove the last two layers in classifier
            self.model.classifier = nn.Sequential(
              *list(self.model.classifier.children())[:-2]
            )
            self.input_size = 224     

        # inception v3 expects (299, 299) sized images
        elif cnn_type == 'inception_v3':
            self.model = models.inception_v3(pretrained=pretrained)
            # remove the classification layer
            self.model.fc = nn.Identity()

            # turn off auxiliary output
            self.model.aux_logits = False
            self.input_size = 299

        else:
            raise Exception("Please choose between 'vgg16' and 'inception_v3'.")

    def forward(self, img_input, train=False):
        """
        forward of the CNNModel

        Parameters:
        -----------
        img_input: torch.Tensor
            the image matrix
        train: bool (default: False)
            use the model only for feature extraction if False

        Return:
        --------
        torch.Tensor
            image feature matrix
        """
        if not train:
            # set the model to evaluation model
            self.model.eval()

        return self.model(img_input)

In [14]:
class RNNModel(nn.Module):

    def __init__(
        self, 
        vocab_size, 
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):
      
        """
        Initializes a RNNModel

        Parameters:
        -----------
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """

        super(RNNModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        if embedding_matrix is not None:

            self.embedding.load_state_dict({
              'weight': torch.FloatTensor(embedding_matrix)
            })
            self.embedding.weight.requires_grad = embedding_train

        self.dropout = nn.Dropout(p=0.5)

        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
 

    def forward(self, captions):
        """
        forward of the RNNModel

        Parameters:
        -----------
        captions: torch.Tensor
            the padded caption matrix

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        # embed the captions
        embedding = self.dropout(self.embedding(captions))

        outputs, (h, c) = self.lstm(embedding)

        return outputs, (h, c)



In [15]:
class CaptionModel(nn.Module):

    def __init__(
        self, 
        cnn_type, 
        vocab_size, 
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):

        """
        Initializes a CaptionModel

        Parameters:
        -----------
        cnn_type: str
            the CNN type, either 'vgg16' or 'inception_v3'
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        feature_size: int
            the number of features in the image matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """    
        super(CaptionModel, self).__init__() 

        # set feature_size based on cnn_type
        if cnn_type == 'vgg16':
            self.feature_size = 4096
        elif cnn_type == 'inception_v3':
            self.feature_size = 2048
        else:
            raise Exception("Please choose between 'vgg16' and 'inception_v3'.")  

        self.decoder = RNNModel(
            vocab_size, 
            embedding_dim,
            hidden_size,
            embedding_matrix,
            embedding_train
        )
        
        self.dropout = nn.Dropout(p=0.5)
        self.dense1 = nn.Linear(self.feature_size, hidden_size) 
        self.relu1 = nn.ReLU()
          
        self.dense2 = nn.Linear(hidden_size, hidden_size) 
        self.relu2 = nn.ReLU()
        self.dense3 = nn.Linear(hidden_size, vocab_size) 

    def forward(self, img_features, captions):
        """
        forward of the CaptionModel

        Parameters:
        -----------
        img_features: torch.Tensor
            the image feature matrix
        captions: torch.Tensor
            the padded caption matrix

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        img_features =\
        self.relu1(
            self.dense1(
                self.dropout(
                    img_features
                )
            )
        )

        decoder_out, _ = self.decoder(captions)

        # add up decoder outputs and image features
        outputs =\
        self.dense3(
            self.relu2(
                self.dense2(
                    decoder_out.add(
                        (img_features.view(img_features.size(0), 1, -1))\
                        .repeat(1, decoder_out.size(1), 1)
                    )
                )
            )
        )

        return outputs

### Train the Neural Network

In [16]:
def train(model, iterator, optimizer, criterion, clip, vocab_size):
    """
    train the CaptionModel

    Parameters:
    -----------
    model: CaptionModel
        a CaptionModel instance
    iterator: torch.utils.data.dataloader
        a PyTorch dataloader
    optimizer: torch.optim
        a PyTorch optimizer 
    criterion: nn.CrossEntropyLoss
        a PyTorch criterion 

    Return:
    --------
    float
        average loss
    """
    model.train()    
    epoch_loss = 0
    
    for img_features, captions in iterator:
        
        optimizer.zero_grad()

        # for each caption, the end word is not passed for training
        outputs = model(
            img_features.to(device),
            captions[:, :-1].to(device)
        )

        loss = criterion(
            outputs.view(-1, vocab_size), 
            captions[:, 1:].flatten().to(device)
        )
        epoch_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        
    return epoch_loss / len(iterator)

In [17]:
class SampleDataset(Dataset):
    def __init__(
        self,
        descriptions,
        imgs,
        wordtoidx,
        max_length
        ):
        """
        Initializes a SampleDataset

        Parameters:
        -----------
        descriptions: list
            a list of captions
        imgs: numpy.ndarray
            the image features
        wordtoidx: dict
            the dict to get word index
        max_length: int
            all captions will be padded to this size
        """        
        self.imgs = imgs
        self.descriptions = descriptions
        self.wordtoidx = wordtoidx
        self.max_length = max_length

    def __len__(self):
        """
        Returns the batch size

        Return:
        --------
        int
            the batch size
        """
        return len(self.imgs)

    def __getitem__(self, idx):
        """
        Prepare data for each image

        Parameters:
        -----------
        idx: int
          the index of the image to process

        Return:
        --------
        list, list, list
            [5 x image feature matrix],
            [five padded captions for this image]
            [the length of each caption]
        """

        img = self.imgs[idx]
        img_features, captions = [], []
        for desc in self.descriptions[idx]:
            # convert each word into a list of sequences.
            seq = [self.wordtoidx[word] for word in desc.split(' ')
                  if word in self.wordtoidx]
            # pad the sequence with 0 on the right side
            in_seq = np.pad(
                seq, 
                (0, max_length - len(seq)),
                mode='constant',
                constant_values=(0, 0)
                )

            img_features.append(img)
            captions.append(in_seq)
    
        return (img_features, captions)


In [18]:
def my_collate(batch):
    """
    Processes the batch to return from the dataloader

    Parameters:
    -----------
    batch: tuple
      a batch from the Dataset

    Return:
    --------
    list
        [image feature matrix, captions, the length of each caption]
    """  

    img_features = [item[0] for item in batch]
    captions = [item[1] for item in batch]

    img_features = torch.FloatTensor(list(chain(*img_features)))
    captions = torch.LongTensor(list(chain(*captions)))

    return [img_features, captions]

In [19]:
def init_weights(model, embedding_pretrained=True):
    """
    Initialize weights and bias in the model

    Parameters:
    -----------
    model: CaptionModel
      a CaptionModel instance
    embedding_pretrained: bool (default: True)
        not initialize the embedding matrix if True
    """  
  
    for name, param in model.named_parameters():
        if embedding_pretrained and 'embedding' in name:
            continue
        elif 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            


In [20]:
def encode_image(model, img_path):
    """
    Process the images to extract features

    Parameters:
    -----------
    model: CNNModel
      a CNNModel instance
    img_path: str
        the path of the image
 
    Return:
    --------
    torch.Tensor
        the extracted feature matrix from CNNModel
    """  

    img = Image.open(img_path)

    # Perform preprocessing needed by pre-trained models
    preprocessor = transforms.Compose([
        transforms.Resize(model.input_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    img = preprocessor(img)
    # Expand to 2D array
    img = img.view(1, *img.shape)
    # Call model to extract the smaller feature set for the image.
    x = model(img.to(device), False) 
    # Shape to correct form to be accepted by LSTM captioning network.
    x = np.squeeze(x)
    return x

In [21]:
def extract_img_features(img_paths, model):
    """
    Extracts, stores and returns image features

    Parameters:
    -----------
    img_paths: list
        the paths of images
    model: CNNModel (default: None)
      a CNNModel instance

    Return:
    --------
    numpy.ndarray
        the extracted image feature matrix from CNNModel
    """ 

    start = time()
    img_features = []

    for image_path in img_paths:
        img_features.append(
            encode_image(model, image_path).cpu().data.numpy()
    )

    print(f"\nGenerating set took: {hms_string(time()-start)}")

    return img_features

In [22]:
def get_train_test(
    encoder,
    train_paths,
    test_paths
):

    train_img_features = extract_img_features(
        train_paths,
        encoder
    )

    test_img_features = extract_img_features(
        test_paths,
        encoder
    )
    return train_img_features, test_img_features

def get_train_dataloader(
    train_descriptions, 
    train_img_features,
    wordtoidx,
    max_length,
    batch_size=200
):
    train_dataset = SampleDataset(
        train_descriptions,
        train_img_features,
        wordtoidx,
        max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size,
        collate_fn=my_collate
    )
    
    return train_loader

def train_model(
    train_loader,
    vocab_size,
    embedding_dim, 
    embedding_matrix,
    cnn_type='inception_v3',
    hidden_size=256,
):

    caption_model = CaptionModel(
        cnn_type, 
        vocab_size, 
        embedding_dim, 
        hidden_size=hidden_size,
        embedding_matrix=embedding_matrix, 
        embedding_train=True
    )

    init_weights(
        caption_model,
        embedding_pretrained=True
    )

    caption_model.to(device)

    # we will ignore the pad token in true target set
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    optimizer = torch.optim.Adam(
        caption_model.parameters(), 
        lr=0.01
    )

    clip = 1
    start = time()

    for i in tqdm(range(EPOCHS * 6)):

        loss = train(caption_model, train_loader, optimizer, criterion, clip, vocab_size)
        print(loss)

    # reduce the learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = 1e-4

    for i in tqdm(range(EPOCHS * 6)):

        loss = train(caption_model, train_loader, optimizer, criterion, clip, vocab_size)
        print(loss)
    return caption_model

In [23]:
def generateCaption(
    model, 
    img_features,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    in_text = START

    for i in range(max_length):

        sequence = [wordtoidx[w] for w in in_text.split() if w in wordtoidx]
        sequence = np.pad(sequence, (0, max_length - len(sequence)),
                          mode='constant', constant_values=(0, 0))
        model.eval()
        yhat = model(
            torch.FloatTensor(img_features)\
            .view(-1, model.feature_size).to(device),
            torch.LongTensor(sequence).view(-1, max_length).to(device)
        )

        yhat = yhat.view(-1, vocab_size).argmax(1)
        word = idxtoword[yhat.cpu().data.numpy()[i]]
        in_text += ' ' + word
        if word == STOP:
            break
    final = in_text.split()
    final = final[1 : -1]
    final = ' '.join(final)
    return final

### Evaluation

In [24]:

def eval_model(ref_data, results):
    """
    Computes evaluation metrics of the model results against the human annotated captions
    
    Parameters:
    ------------
    ref_data: dict
        a dictionary containing human annotated captions, with image name as key and a list of human annotated captions as values
    
    results: dict
        a dictionary containing model generated caption, with image name as key and a generated caption as value
        
    Returns:
    ------------
    score_dict: a dictionary containing the overall average score for the model
    """
    # download stanford nlp library
    subprocess.call(['../../scr/evaluation/get_stanford_models.sh'])
    
    # format the inputs
    gts = {}
    res = {}

    for imgId in range(len(ref_data)):
        caption_list_sel = []
        for i in range(5):
            lst = {}
            lst['caption'] = ref_data[imgId][i]
            lst['image_id'] = imgId
            lst['id'] = i
            caption_list_sel.append(lst)
        gts[imgId] = caption_list_sel

        res[imgId] = [{'caption': results[imgId]}]
        
    # tokenize
    print('tokenization...')
    tokenizer = PTBTokenizer()
    gts  = tokenizer.tokenize(gts)
    res = tokenizer.tokenize(res)
    
    # compute scores
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(),"METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
        (Spice(), "SPICE"),
        (usc_sim(), "USC_similarity"),  
        ]
    score_dict = {}
    for scorer, method in scorers:
        print('computing %s score...'%(scorer.method()))
        score, scores = scorer.compute_score(gts, res)
        if type(method) == list:
            for sc, scs, m in zip(score, scores, method):
                score_dict[m] = sc
        else:
            score_dict[method] = score
            
    return score_dict


In [25]:
def evaluate_results(
    test_img_features, 
    model,
    ref,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    # generate results
    print('Generating captions...')
    results = {}
    for n in range(len(test_img_features)):
        img_features = test_img_features[n]
        generated = generateCaption(
            model, 
            img_features,
            max_length,
            vocab_size,
            wordtoidx,
            idxtoword
        )
        results[n] = generated
        
    model_score = eval_model(ref, results)

    return model_score

### Cross validation

In [26]:
cnn_type = 'vgg16'
encoder = CNNModel(cnn_type, pretrained=True)
encoder.to(device)

CNNModel(
  (model): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      (16):

In [27]:
def cross_validation(train_index, test_index, count, cnn_type):
    print('=' * 60)
    print(f'Split {count}:')
    print(f'Splitting data...')
    
    train_paths, test_paths = all_paths[train_index], all_paths[test_index]
    train_descriptions, test_descriptions = all_descriptions[train_index], all_descriptions[test_index]
    print(f'{len(train_paths)} images for training and {len(test_paths)} images for testing.')
    
    vocab = get_vocab(train_descriptions, word_count_threshold=10)
    idxtoword, wordtoidx = get_word_dict(vocab)
    vocab_size = get_vocab_size(idxtoword)
    embedding_dim = 200
    embedding_matrix = get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx) 

    print(f'Preparing dataloader...')
    train_img_features, test_img_features = get_train_test(encoder, train_paths, test_paths)

    train_loader = get_train_dataloader(
        train_descriptions, 
        train_img_features,
        wordtoidx,
        max_length,
        batch_size=200
    )

    print(f'Training...')
    caption_model = train_model(
        train_loader,
        vocab_size,
        embedding_dim, 
        embedding_matrix,
        cnn_type=cnn_type
    )

    
    ref = captions[test_index]
    model_score = evaluate_results(
        test_img_features, 
        caption_model,
        ref,
        max_length,
        vocab_size,
        wordtoidx,
        idxtoword
    )
    
    return caption_model, model_score

In [28]:
cv = KFold(n_splits=5, random_state=123, shuffle=True)
cv = [(train_index, test_index) for train_index, test_index in cv.split(all_paths)]  

In [29]:
caption_model1, model_score1 = cross_validation(cv[0][0], cv[0][1], 1, cnn_type)    

Split 1:
Splitting data...
8332 images for training and 2084 images for testing.
There are 41660 captions
preprocessed words 2659 ==> 884
The vocabulary size is 885.
796 out of 885 words are found in the pre-trained matrix.
The size of embedding_matrix is (885, 200)
Preparing dataloader...

Generating set took: 0:00:57.03


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:14.21
Training...


  2%|▏         | 1/60 [00:22<22:03, 22.43s/it]

4.4197990440187


  3%|▎         | 2/60 [00:44<21:40, 22.42s/it]

2.3923566171101163


  5%|▌         | 3/60 [01:07<21:17, 22.42s/it]

1.7507163626807076


  7%|▋         | 4/60 [01:29<20:54, 22.41s/it]

1.5221842243557884


  8%|▊         | 5/60 [01:52<20:32, 22.41s/it]

1.4062400119645255


 10%|█         | 6/60 [02:14<20:10, 22.41s/it]

1.3372377469426109


 12%|█▏        | 7/60 [02:36<19:47, 22.40s/it]

1.2827462213379996


 13%|█▎        | 8/60 [02:59<19:24, 22.40s/it]

1.242016008922032


 15%|█▌        | 9/60 [03:21<19:02, 22.39s/it]

1.2089008263179235


 17%|█▋        | 10/60 [03:43<18:39, 22.39s/it]

1.1849924269176664


 18%|█▊        | 11/60 [04:06<18:16, 22.38s/it]

1.1640289142018272


 20%|██        | 12/60 [04:28<17:53, 22.37s/it]

1.1428360059147789


 22%|██▏       | 13/60 [04:51<17:31, 22.38s/it]

1.1261998556909107


 23%|██▎       | 14/60 [05:13<17:09, 22.37s/it]

1.1084483620666323


 25%|██▌       | 15/60 [05:35<16:47, 22.38s/it]

1.091893276997975


 27%|██▋       | 16/60 [05:58<16:24, 22.38s/it]

1.0753181676069896


 28%|██▊       | 17/60 [06:20<16:02, 22.39s/it]

1.0635668365728288


 30%|███       | 18/60 [06:43<15:40, 22.39s/it]

1.0527594628788175


 32%|███▏      | 19/60 [07:05<15:18, 22.40s/it]

1.0441187038308097


 33%|███▎      | 20/60 [07:27<14:55, 22.40s/it]

1.037542917898723


 35%|███▌      | 21/60 [07:50<14:33, 22.40s/it]

1.0299441587357294


 37%|███▋      | 22/60 [08:12<14:11, 22.40s/it]

1.0215936516012465


 38%|███▊      | 23/60 [08:35<13:48, 22.40s/it]

1.01761913725308


 40%|████      | 24/60 [08:57<13:26, 22.39s/it]

1.0119928050608862


 42%|████▏     | 25/60 [09:19<13:03, 22.40s/it]

1.0069334663095928


 43%|████▎     | 26/60 [09:42<12:41, 22.40s/it]

1.0008427188509987


 45%|████▌     | 27/60 [10:04<12:19, 22.40s/it]

0.9958772190979549


 47%|████▋     | 28/60 [10:27<11:56, 22.39s/it]

0.9926019438675472


 48%|████▊     | 29/60 [10:49<11:34, 22.40s/it]

0.9880017737547556


 50%|█████     | 30/60 [11:11<11:12, 22.41s/it]

0.9847841972396487


 52%|█████▏    | 31/60 [11:34<10:49, 22.41s/it]

0.9788200855255127


 53%|█████▎    | 32/60 [11:56<10:28, 22.44s/it]

0.9770951909678323


 55%|█████▌    | 33/60 [12:19<10:06, 22.47s/it]

0.9735751989341918


 57%|█████▋    | 34/60 [12:41<09:44, 22.49s/it]

0.967977150565102


 58%|█████▊    | 35/60 [13:04<09:22, 22.49s/it]

0.9646732892308917


 60%|██████    | 36/60 [13:26<08:59, 22.47s/it]

0.9612152391955966


 62%|██████▏   | 37/60 [13:49<08:36, 22.45s/it]

0.9561838706334432


 63%|██████▎   | 38/60 [14:11<08:13, 22.44s/it]

0.9536050089768001


 65%|██████▌   | 39/60 [14:34<07:51, 22.43s/it]

0.9516612745466686


 67%|██████▋   | 40/60 [14:56<07:28, 22.43s/it]

0.949853682801837


 68%|██████▊   | 41/60 [15:18<07:05, 22.42s/it]

0.9449337181590852


 70%|███████   | 42/60 [15:41<06:43, 22.40s/it]

0.9431110875947135


 72%|███████▏  | 43/60 [16:03<06:20, 22.40s/it]

0.940031478802363


 73%|███████▎  | 44/60 [16:26<05:58, 22.40s/it]

0.9405927161375681


 75%|███████▌  | 45/60 [16:48<05:36, 22.43s/it]

0.9378781730220431


 77%|███████▋  | 46/60 [17:10<05:14, 22.44s/it]

0.9364220798015594


 78%|███████▊  | 47/60 [17:33<04:51, 22.45s/it]

0.9344595287527356


 80%|████████  | 48/60 [17:55<04:29, 22.46s/it]

0.9339190111273811


 82%|████████▏ | 49/60 [18:18<04:07, 22.47s/it]

0.9328034775597709


 83%|████████▎ | 50/60 [18:40<03:44, 22.47s/it]

0.9304328362147013


 85%|████████▌ | 51/60 [19:03<03:21, 22.44s/it]

0.9293933539163499


 87%|████████▋ | 52/60 [19:25<02:59, 22.42s/it]

0.9328095416227976


 88%|████████▊ | 53/60 [19:47<02:36, 22.40s/it]

0.9236910144488016


 90%|█████████ | 54/60 [20:10<02:14, 22.39s/it]

0.9215610240186963


 92%|█████████▏| 55/60 [20:32<01:51, 22.38s/it]

0.9239656981967744


 93%|█████████▎| 56/60 [20:55<01:29, 22.37s/it]

0.9234454248632703


 95%|█████████▌| 57/60 [21:17<01:07, 22.37s/it]

0.9205109079678854


 97%|█████████▋| 58/60 [21:39<00:44, 22.36s/it]

0.9209251034827459


 98%|█████████▊| 59/60 [22:02<00:22, 22.37s/it]

0.9232332209746043


100%|██████████| 60/60 [22:24<00:00, 22.41s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.9175575872262319


  2%|▏         | 1/60 [00:22<22:01, 22.40s/it]

0.8930607296171642


  3%|▎         | 2/60 [00:44<21:38, 22.40s/it]

0.8643099651450202


  5%|▌         | 3/60 [01:07<21:16, 22.39s/it]

0.8522990629786537


  7%|▋         | 4/60 [01:29<20:53, 22.38s/it]

0.8448544456845238


  8%|▊         | 5/60 [01:51<20:31, 22.39s/it]

0.8378566361608959


 10%|█         | 6/60 [02:14<20:09, 22.40s/it]

0.8348675015426817


 12%|█▏        | 7/60 [02:36<19:46, 22.40s/it]

0.8301908813771748


 13%|█▎        | 8/60 [02:59<19:24, 22.40s/it]

0.8287086103643689


 15%|█▌        | 9/60 [03:21<19:02, 22.39s/it]

0.8264825500193096


 17%|█▋        | 10/60 [03:44<18:41, 22.43s/it]

0.8231716014090038


 18%|█▊        | 11/60 [04:06<18:20, 22.46s/it]

0.8210374798093524


 20%|██        | 12/60 [04:29<17:58, 22.47s/it]

0.8195567485832033


 22%|██▏       | 13/60 [04:51<17:36, 22.47s/it]

0.8173747630346389


 23%|██▎       | 14/60 [05:14<17:14, 22.48s/it]

0.816011640287581


 25%|██▌       | 15/60 [05:36<16:52, 22.49s/it]

0.8144501589593434


 27%|██▋       | 16/60 [05:59<16:29, 22.49s/it]

0.8129216389996665


 28%|██▊       | 17/60 [06:21<16:06, 22.49s/it]

0.8117425157910302


 30%|███       | 18/60 [06:44<15:44, 22.49s/it]

0.8103842607566288


 32%|███▏      | 19/60 [07:06<15:22, 22.50s/it]

0.8071131904919943


 33%|███▎      | 20/60 [07:29<14:59, 22.50s/it]

0.8080165485541025


 35%|███▌      | 21/60 [07:51<14:37, 22.50s/it]

0.8060930087452843


 37%|███▋      | 22/60 [08:14<14:14, 22.50s/it]

0.8057982197829655


 38%|███▊      | 23/60 [08:36<13:52, 22.50s/it]

0.8034258215200334


 40%|████      | 24/60 [08:59<13:30, 22.50s/it]

0.801928909051986


 42%|████▏     | 25/60 [09:21<13:07, 22.50s/it]

0.8014299940495264


 43%|████▎     | 26/60 [09:44<12:45, 22.51s/it]

0.8023254942326319


 45%|████▌     | 27/60 [10:06<12:23, 22.52s/it]

0.8012053767840067


 47%|████▋     | 28/60 [10:29<12:00, 22.53s/it]

0.8006139511153811


 48%|████▊     | 29/60 [10:51<11:37, 22.51s/it]

0.797954553649539


 50%|█████     | 30/60 [11:14<11:14, 22.48s/it]

0.7973329410666511


 52%|█████▏    | 31/60 [11:36<10:51, 22.47s/it]

0.7985682430721465


 53%|█████▎    | 32/60 [11:58<10:28, 22.46s/it]

0.7968130948997679


 55%|█████▌    | 33/60 [12:21<10:07, 22.49s/it]

0.7944737871487936


 57%|█████▋    | 34/60 [12:44<09:44, 22.50s/it]

0.7950472945258731


 58%|█████▊    | 35/60 [13:06<09:21, 22.48s/it]

0.7927986525353932


 60%|██████    | 36/60 [13:28<08:58, 22.45s/it]

0.7932708504654112


 62%|██████▏   | 37/60 [13:51<08:35, 22.43s/it]

0.7925069814636594


 63%|██████▎   | 38/60 [14:13<08:13, 22.43s/it]

0.7918588547479539


 65%|██████▌   | 39/60 [14:36<07:50, 22.41s/it]

0.7915583068416232


 67%|██████▋   | 40/60 [14:58<07:28, 22.40s/it]

0.7908003755978176


 68%|██████▊   | 41/60 [15:20<07:05, 22.40s/it]

0.7901124400751931


 70%|███████   | 42/60 [15:43<06:43, 22.40s/it]

0.7894110012622106


 72%|███████▏  | 43/60 [16:05<06:20, 22.40s/it]

0.7883891931601933


 73%|███████▎  | 44/60 [16:28<05:58, 22.40s/it]

0.7871943541935512


 75%|███████▌  | 45/60 [16:50<05:36, 22.40s/it]

0.7879661875111716


 77%|███████▋  | 46/60 [17:12<05:13, 22.40s/it]

0.7875518146015349


 78%|███████▊  | 47/60 [17:35<04:51, 22.39s/it]

0.7854914537497929


 80%|████████  | 48/60 [17:57<04:28, 22.40s/it]

0.787009135598228


 82%|████████▏ | 49/60 [18:20<04:06, 22.40s/it]

0.7857562104860941


 83%|████████▎ | 50/60 [18:42<03:43, 22.39s/it]

0.7839220591953823


 85%|████████▌ | 51/60 [19:04<03:21, 22.39s/it]

0.7830143925689516


 87%|████████▋ | 52/60 [19:27<02:59, 22.39s/it]

0.7829261592456273


 88%|████████▊ | 53/60 [19:49<02:36, 22.42s/it]

0.7827845343521663


 90%|█████████ | 54/60 [20:12<02:14, 22.43s/it]

0.7825924555460612


 92%|█████████▏| 55/60 [20:34<01:52, 22.42s/it]

0.7823396864391509


 93%|█████████▎| 56/60 [20:56<01:29, 22.42s/it]

0.7828218440214793


 95%|█████████▌| 57/60 [21:19<01:07, 22.42s/it]

0.7804044655391148


 97%|█████████▋| 58/60 [21:41<00:44, 22.43s/it]

0.7799256798766908


 98%|█████████▊| 59/60 [22:04<00:22, 22.41s/it]

0.7792992492516836


100%|██████████| 60/60 [22:26<00:00, 22.44s/it]

0.7798938155174255
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [30]:
model_score1

{'Bleu_1': 0.6571629575623202,
 'Bleu_2': 0.5362381576774322,
 'Bleu_3': 0.45560464775875736,
 'Bleu_4': 0.3964644970967075,
 'METEOR': 0.3079667441267819,
 'ROUGE_L': 0.56041962993465,
 'CIDEr': 2.1887565073910538,
 'SPICE': 0.4062008205782018,
 'USC_similarity': 0.6191158124401221}

In [31]:
caption_model2, model_score2 = cross_validation(cv[1][0], cv[1][1], 2, cnn_type)    

Split 2:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2688 ==> 916
The vocabulary size is 917.
822 out of 917 words are found in the pre-trained matrix.
The size of embedding_matrix is (917, 200)
Preparing dataloader...

Generating set took: 0:00:57.64


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:14.35
Training...


  2%|▏         | 1/60 [00:21<21:34, 21.95s/it]

4.903898000717163


  3%|▎         | 2/60 [00:43<21:12, 21.94s/it]

2.4080131479672024


  5%|▌         | 3/60 [01:05<20:50, 21.94s/it]

1.7491285744167508


  7%|▋         | 4/60 [01:27<20:28, 21.94s/it]

1.5273662067594982


  8%|▊         | 5/60 [01:49<20:06, 21.94s/it]

1.4109641398702348


 10%|█         | 6/60 [02:11<19:43, 21.93s/it]

1.333940083072299


 12%|█▏        | 7/60 [02:33<19:22, 21.93s/it]

1.279664138952891


 13%|█▎        | 8/60 [02:55<18:59, 21.92s/it]

1.2439169855344863


 15%|█▌        | 9/60 [03:17<18:38, 21.92s/it]

1.2120892036528814


 17%|█▋        | 10/60 [03:39<18:16, 21.94s/it]

1.1862955604280745


 18%|█▊        | 11/60 [04:01<17:54, 21.94s/it]

1.1620740095774333


 20%|██        | 12/60 [04:23<17:32, 21.92s/it]

1.1447600438481285


 22%|██▏       | 13/60 [04:45<17:09, 21.91s/it]

1.128550532318297


 23%|██▎       | 14/60 [05:06<16:47, 21.91s/it]

1.1110524152006422


 25%|██▌       | 15/60 [05:28<16:25, 21.91s/it]

1.096833477417628


 27%|██▋       | 16/60 [05:50<16:04, 21.91s/it]

1.0871606540112269


 28%|██▊       | 17/60 [06:12<15:42, 21.91s/it]

1.0709551558608101


 30%|███       | 18/60 [06:34<15:20, 21.91s/it]

1.0616304519630613


 32%|███▏      | 19/60 [06:56<14:58, 21.91s/it]

1.049428220306124


 33%|███▎      | 20/60 [07:18<14:36, 21.91s/it]

1.039554442678179


 35%|███▌      | 21/60 [07:40<14:14, 21.90s/it]

1.0333011476766496


 37%|███▋      | 22/60 [08:02<13:52, 21.90s/it]

1.0330157762482053


 38%|███▊      | 23/60 [08:24<13:30, 21.91s/it]

1.027718335390091


 40%|████      | 24/60 [08:46<13:08, 21.91s/it]

1.0227269473529996


 42%|████▏     | 25/60 [09:07<12:47, 21.92s/it]

1.013504147529602


 43%|████▎     | 26/60 [09:29<12:25, 21.92s/it]

1.0075056893484933


 45%|████▌     | 27/60 [09:51<12:04, 21.95s/it]

1.0012047063736689


 47%|████▋     | 28/60 [10:14<11:44, 22.00s/it]

0.9936284820238749


 48%|████▊     | 29/60 [10:36<11:23, 22.04s/it]

0.9856703465893155


 50%|█████     | 30/60 [10:58<11:01, 22.05s/it]

0.9773070939949581


 52%|█████▏    | 31/60 [11:20<10:39, 22.07s/it]

0.9738257513159797


 53%|█████▎    | 32/60 [11:42<10:17, 22.07s/it]

0.969068937358402


 55%|█████▌    | 33/60 [12:04<09:55, 22.07s/it]

0.9651489612602052


 57%|█████▋    | 34/60 [12:26<09:33, 22.07s/it]

0.965179925873166


 58%|█████▊    | 35/60 [12:48<09:11, 22.07s/it]

0.9627461376644316


 60%|██████    | 36/60 [13:10<08:49, 22.06s/it]

0.9583965426399594


 62%|██████▏   | 37/60 [13:32<08:27, 22.06s/it]

0.9562349546523321


 63%|██████▎   | 38/60 [13:54<08:05, 22.07s/it]

0.9577630006131672


 65%|██████▌   | 39/60 [14:16<07:43, 22.06s/it]

0.9529507614317394


 67%|██████▋   | 40/60 [14:38<07:21, 22.07s/it]

0.9509241737070537


 68%|██████▊   | 41/60 [15:01<06:59, 22.07s/it]

0.9500471083890825


 70%|███████   | 42/60 [15:23<06:37, 22.07s/it]

0.9518695899418422


 72%|███████▏  | 43/60 [15:45<06:15, 22.07s/it]

0.9525817363035112


 73%|███████▎  | 44/60 [16:07<05:53, 22.08s/it]

0.9461322880926586


 75%|███████▌  | 45/60 [16:29<05:31, 22.08s/it]

0.9390094010602861


 77%|███████▋  | 46/60 [16:51<05:08, 22.07s/it]

0.934578959430967


 78%|███████▊  | 47/60 [17:13<04:46, 22.06s/it]

0.9346925417582194


 80%|████████  | 48/60 [17:35<04:24, 22.07s/it]

0.9305059909820557


 82%|████████▏ | 49/60 [17:57<04:02, 22.07s/it]

0.9308363199234009


 83%|████████▎ | 50/60 [18:19<03:40, 22.07s/it]

0.9324791956515539


 85%|████████▌ | 51/60 [18:41<03:18, 22.04s/it]

0.9258883552891868


 87%|████████▋ | 52/60 [19:03<02:56, 22.01s/it]

0.9229043665386382


 88%|████████▊ | 53/60 [19:25<02:33, 21.97s/it]

0.9179828209536416


 90%|█████████ | 54/60 [19:47<02:11, 21.96s/it]

0.918117439463025


 92%|█████████▏| 55/60 [20:09<01:49, 21.95s/it]

0.9159676688058036


 93%|█████████▎| 56/60 [20:31<01:27, 21.95s/it]

0.9123384370690301


 95%|█████████▌| 57/60 [20:53<01:05, 21.94s/it]

0.9118188903445289


 97%|█████████▋| 58/60 [21:15<00:43, 21.93s/it]

0.9142261956419263


 98%|█████████▊| 59/60 [21:37<00:21, 21.93s/it]

0.9154911665689378


100%|██████████| 60/60 [21:58<00:00, 21.98s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.914925849153882


  2%|▏         | 1/60 [00:22<21:39, 22.02s/it]

0.8887057517256055


  3%|▎         | 2/60 [00:43<21:16, 22.01s/it]

0.8575540781021118


  5%|▌         | 3/60 [01:05<20:53, 22.00s/it]

0.8449020995980218


  7%|▋         | 4/60 [01:27<20:31, 21.99s/it]

0.8391740421454111


  8%|▊         | 5/60 [01:49<20:08, 21.97s/it]

0.8318495736235664


 10%|█         | 6/60 [02:11<19:46, 21.96s/it]

0.8282436969734374


 12%|█▏        | 7/60 [02:33<19:24, 21.96s/it]

0.8250225455988021


 13%|█▎        | 8/60 [02:55<19:02, 21.96s/it]

0.820684137798491


 15%|█▌        | 9/60 [03:17<18:40, 21.97s/it]

0.8186311068988982


 17%|█▋        | 10/60 [03:39<18:18, 21.98s/it]

0.8166941063744682


 18%|█▊        | 11/60 [04:01<17:56, 21.97s/it]

0.8146049465451922


 20%|██        | 12/60 [04:23<17:34, 21.96s/it]

0.8117113908131918


 22%|██▏       | 13/60 [04:45<17:12, 21.96s/it]

0.8107565740744272


 23%|██▎       | 14/60 [05:07<16:50, 21.97s/it]

0.8097700689520154


 25%|██▌       | 15/60 [05:29<16:28, 21.96s/it]

0.807491877249309


 27%|██▋       | 16/60 [05:51<16:06, 21.96s/it]

0.8062193705922082


 28%|██▊       | 17/60 [06:13<15:45, 21.98s/it]

0.8044400328681582


 30%|███       | 18/60 [06:35<15:23, 21.98s/it]

0.8025944658688137


 32%|███▏      | 19/60 [06:57<15:01, 21.98s/it]

0.8012943069140116


 33%|███▎      | 20/60 [07:19<14:38, 21.97s/it]

0.8001655737559


 35%|███▌      | 21/60 [07:41<14:16, 21.96s/it]

0.7981320477667309


 37%|███▋      | 22/60 [08:03<13:54, 21.96s/it]

0.7968504272756123


 38%|███▊      | 23/60 [08:25<13:32, 21.96s/it]

0.7979962541943505


 40%|████      | 24/60 [08:47<13:10, 21.96s/it]

0.7967261970043182


 42%|████▏     | 25/60 [09:09<12:48, 21.95s/it]

0.7935922131651924


 43%|████▎     | 26/60 [09:31<12:26, 21.96s/it]

0.7932028458231971


 45%|████▌     | 27/60 [09:53<12:04, 21.96s/it]

0.7934376739320301


 47%|████▋     | 28/60 [10:15<11:42, 21.96s/it]

0.7915727056208111


 48%|████▊     | 29/60 [10:36<11:20, 21.95s/it]

0.7922814956733158


 50%|█████     | 30/60 [10:58<10:58, 21.96s/it]

0.7900963581743694


 52%|█████▏    | 31/60 [11:20<10:36, 21.95s/it]

0.7904623108250755


 53%|█████▎    | 32/60 [11:42<10:14, 21.94s/it]

0.7893822973682767


 55%|█████▌    | 33/60 [12:04<09:53, 21.99s/it]

0.7887844273022243


 57%|█████▋    | 34/60 [12:26<09:32, 22.02s/it]

0.7869127208278293


 58%|█████▊    | 35/60 [12:49<09:10, 22.04s/it]

0.7871642935843695


 60%|██████    | 36/60 [13:11<08:49, 22.06s/it]

0.7865222096443176


 62%|██████▏   | 37/60 [13:33<08:26, 22.03s/it]

0.7862332661946615


 63%|██████▎   | 38/60 [13:55<08:04, 22.00s/it]

0.7842578887939453


 65%|██████▌   | 39/60 [14:17<07:41, 21.99s/it]

0.7833142167045957


 67%|██████▋   | 40/60 [14:38<07:19, 21.97s/it]

0.7833952832789648


 68%|██████▊   | 41/60 [15:00<06:57, 21.96s/it]

0.7825048026584444


 70%|███████   | 42/60 [15:22<06:35, 21.96s/it]

0.7819563008490062


 72%|███████▏  | 43/60 [15:44<06:13, 21.96s/it]

0.7808868374143328


 73%|███████▎  | 44/60 [16:06<05:51, 21.96s/it]

0.7802971175738743


 75%|███████▌  | 45/60 [16:28<05:29, 21.97s/it]

0.781708958603087


 77%|███████▋  | 46/60 [16:50<05:07, 21.97s/it]

0.7806191302481151


 78%|███████▊  | 47/60 [17:12<04:45, 21.96s/it]

0.7775852807930538


 80%|████████  | 48/60 [17:34<04:23, 21.96s/it]

0.7771955799488794


 82%|████████▏ | 49/60 [17:56<04:01, 21.97s/it]

0.7772037798450107


 83%|████████▎ | 50/60 [18:18<03:39, 21.96s/it]

0.7768264980543227


 85%|████████▌ | 51/60 [18:40<03:17, 21.96s/it]

0.776029151110422


 87%|████████▋ | 52/60 [19:02<02:55, 21.95s/it]

0.7759950317087627


 88%|████████▊ | 53/60 [19:24<02:33, 21.95s/it]

0.7758287886778513


 90%|█████████ | 54/60 [19:46<02:11, 21.94s/it]

0.7737265725930532


 92%|█████████▏| 55/60 [20:08<01:49, 21.96s/it]

0.7737659343651363


 93%|█████████▎| 56/60 [20:30<01:27, 21.95s/it]

0.7744525869687399


 95%|█████████▌| 57/60 [20:52<01:05, 21.94s/it]

0.77377336365836


 97%|█████████▋| 58/60 [21:14<00:43, 21.95s/it]

0.7730508702141898


 98%|█████████▊| 59/60 [21:36<00:21, 21.97s/it]

0.7729916600953965


100%|██████████| 60/60 [21:58<00:00, 21.97s/it]

0.772167831659317
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [32]:
model_score2

{'Bleu_1': 0.6352914180816597,
 'Bleu_2': 0.5113443206509108,
 'Bleu_3': 0.4287394882419803,
 'Bleu_4': 0.36894999565248604,
 'METEOR': 0.2968264250708283,
 'ROUGE_L': 0.5391833235204637,
 'CIDEr': 2.0850938434843065,
 'SPICE': 0.39573879922216904,
 'USC_similarity': 0.6125458522518695}

In [33]:
caption_model3, model_score3 = cross_validation(cv[2][0], cv[2][1], 3, cnn_type)    

Split 3:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2714 ==> 890
The vocabulary size is 891.
804 out of 891 words are found in the pre-trained matrix.
The size of embedding_matrix is (891, 200)
Preparing dataloader...

Generating set took: 0:00:57.58


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:14.31
Training...


  2%|▏         | 1/60 [00:21<21:32, 21.91s/it]

4.749859395481291


  3%|▎         | 2/60 [00:43<21:10, 21.90s/it]

2.473125923247564


  5%|▌         | 3/60 [01:05<20:47, 21.89s/it]

1.8171191641262598


  7%|▋         | 4/60 [01:27<20:25, 21.88s/it]

1.574152566137768


  8%|▊         | 5/60 [01:49<20:03, 21.87s/it]

1.4582518424306596


 10%|█         | 6/60 [02:11<19:41, 21.88s/it]

1.3848630019596644


 12%|█▏        | 7/60 [02:33<19:20, 21.89s/it]

1.3327372272809346


 13%|█▎        | 8/60 [02:55<18:58, 21.90s/it]

1.2909652931349618


 15%|█▌        | 9/60 [03:17<18:37, 21.90s/it]

1.2591571978160314


 17%|█▋        | 10/60 [03:38<18:15, 21.90s/it]

1.2276463366690136


 18%|█▊        | 11/60 [04:00<17:53, 21.91s/it]

1.2032254764011927


 20%|██        | 12/60 [04:22<17:31, 21.91s/it]

1.1834406483741033


 22%|██▏       | 13/60 [04:44<17:09, 21.91s/it]

1.1641527442705064


 23%|██▎       | 14/60 [05:06<16:48, 21.92s/it]

1.1465019583702087


 25%|██▌       | 15/60 [05:28<16:26, 21.93s/it]

1.1322658956050873


 27%|██▋       | 16/60 [05:50<16:04, 21.92s/it]

1.1245978673299153


 28%|██▊       | 17/60 [06:12<15:44, 21.96s/it]

1.1158661743005116


 30%|███       | 18/60 [06:34<15:23, 22.00s/it]

1.1065023995581127


 32%|███▏      | 19/60 [06:56<15:03, 22.03s/it]

1.090550986074266


 33%|███▎      | 20/60 [07:18<14:41, 22.04s/it]

1.0826620218299685


 35%|███▌      | 21/60 [07:40<14:19, 22.04s/it]

1.07398902234577


 37%|███▋      | 22/60 [08:02<13:57, 22.05s/it]

1.0643739359719413


 38%|███▊      | 23/60 [08:24<13:36, 22.06s/it]

1.0555398748034523


 40%|████      | 24/60 [08:46<13:13, 22.05s/it]

1.053871208713168


 42%|████▏     | 25/60 [09:09<12:51, 22.05s/it]

1.0437879179205214


 43%|████▎     | 26/60 [09:31<12:29, 22.05s/it]

1.0375916390191942


 45%|████▌     | 27/60 [09:53<12:07, 22.05s/it]

1.0264592411972226


 47%|████▋     | 28/60 [10:15<11:45, 22.05s/it]

1.022162532522565


 48%|████▊     | 29/60 [10:37<11:23, 22.05s/it]

1.0131623432749794


 50%|█████     | 30/60 [10:59<11:01, 22.05s/it]

1.0052278070222764


 52%|█████▏    | 31/60 [11:21<10:39, 22.05s/it]

1.00098374911717


 53%|█████▎    | 32/60 [11:43<10:17, 22.05s/it]

0.9956932692300706


 55%|█████▌    | 33/60 [12:05<09:55, 22.05s/it]

0.9949054278078533


 57%|█████▋    | 34/60 [12:27<09:33, 22.05s/it]

0.9926101508594695


 58%|█████▊    | 35/60 [12:49<09:11, 22.06s/it]

0.9851952720256079


 60%|██████    | 36/60 [13:11<08:49, 22.07s/it]

0.9816905685833522


 62%|██████▏   | 37/60 [13:33<08:27, 22.06s/it]

0.9801673931734902


 63%|██████▎   | 38/60 [13:55<08:04, 22.03s/it]

0.978140454916727


 65%|██████▌   | 39/60 [14:17<07:41, 22.00s/it]

0.9786452878089178


 67%|██████▋   | 40/60 [14:39<07:19, 21.98s/it]

0.973362576393854


 68%|██████▊   | 41/60 [15:01<06:56, 21.95s/it]

0.969047106447674


 70%|███████   | 42/60 [15:23<06:34, 21.93s/it]

0.9668077613626208


 72%|███████▏  | 43/60 [15:45<06:12, 21.93s/it]

0.9646973184176854


 73%|███████▎  | 44/60 [16:07<05:50, 21.92s/it]

0.9610017963818142


 75%|███████▌  | 45/60 [16:29<05:28, 21.93s/it]

0.9596473929427919


 77%|███████▋  | 46/60 [16:50<05:06, 21.93s/it]

0.9598740098022279


 78%|███████▊  | 47/60 [17:12<04:44, 21.91s/it]

0.9545187056064606


 80%|████████  | 48/60 [17:34<04:22, 21.92s/it]

0.9516917012986683


 82%|████████▏ | 49/60 [17:56<04:01, 21.92s/it]

0.949849511895861


 83%|████████▎ | 50/60 [18:18<03:39, 21.91s/it]

0.9472983777523041


 85%|████████▌ | 51/60 [18:40<03:17, 21.90s/it]

0.9464851447514125


 87%|████████▋ | 52/60 [19:02<02:55, 21.90s/it]

0.9416541528134119


 88%|████████▊ | 53/60 [19:24<02:33, 21.90s/it]

0.9383704052084968


 90%|█████████ | 54/60 [19:46<02:11, 21.91s/it]

0.9381255805492401


 92%|█████████▏| 55/60 [20:08<01:49, 21.91s/it]

0.9361832496665773


 93%|█████████▎| 56/60 [20:29<01:27, 21.91s/it]

0.9352504625206902


 95%|█████████▌| 57/60 [20:51<01:05, 21.90s/it]

0.9305114504836854


 97%|█████████▋| 58/60 [21:13<00:43, 21.93s/it]

0.9297141404378981


 98%|█████████▊| 59/60 [21:35<00:21, 21.92s/it]

0.9296242296695709


100%|██████████| 60/60 [21:57<00:00, 21.96s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.9305388388179597


  2%|▏         | 1/60 [00:21<21:28, 21.84s/it]

0.9008874268758864


  3%|▎         | 2/60 [00:43<21:07, 21.85s/it]

0.8747224083968571


  5%|▌         | 3/60 [01:05<20:45, 21.85s/it]

0.8610311704022544


  7%|▋         | 4/60 [01:27<20:24, 21.86s/it]

0.8539035107408252


  8%|▊         | 5/60 [01:49<20:02, 21.87s/it]

0.8487513434319269


 10%|█         | 6/60 [02:11<19:40, 21.87s/it]

0.8434487240655082


 12%|█▏        | 7/60 [02:33<19:19, 21.87s/it]

0.8389440916833424


 13%|█▎        | 8/60 [02:54<18:57, 21.88s/it]

0.8357718657879603


 15%|█▌        | 9/60 [03:16<18:35, 21.87s/it]

0.8333643249103001


 17%|█▋        | 10/60 [03:38<18:13, 21.87s/it]

0.831424913236073


 18%|█▊        | 11/60 [04:00<17:52, 21.88s/it]

0.828879985071364


 20%|██        | 12/60 [04:22<17:30, 21.88s/it]

0.8272895770413535


 22%|██▏       | 13/60 [04:44<17:08, 21.87s/it]

0.8249103483699617


 23%|██▎       | 14/60 [05:06<16:47, 21.90s/it]

0.8219308952490488


 25%|██▌       | 15/60 [05:28<16:27, 21.93s/it]

0.821539336726779


 27%|██▋       | 16/60 [05:50<16:06, 21.96s/it]

0.8204955401874724


 28%|██▊       | 17/60 [06:12<15:44, 21.97s/it]

0.8185286933467502


 30%|███       | 18/60 [06:34<15:22, 21.97s/it]

0.8160626249653953


 32%|███▏      | 19/60 [06:56<15:01, 21.98s/it]

0.8153878762608483


 33%|███▎      | 20/60 [07:18<14:38, 21.97s/it]

0.8131035977885837


 35%|███▌      | 21/60 [07:40<14:16, 21.97s/it]

0.8116364393915448


 37%|███▋      | 22/60 [08:02<13:54, 21.96s/it]

0.8131488519055503


 38%|███▊      | 23/60 [08:24<13:32, 21.96s/it]

0.8089778366542998


 40%|████      | 24/60 [08:46<13:10, 21.97s/it]

0.8093410929044088


 42%|████▏     | 25/60 [09:08<12:48, 21.96s/it]

0.8081538066977546


 43%|████▎     | 26/60 [09:30<12:26, 21.97s/it]

0.8060482172738939


 45%|████▌     | 27/60 [09:51<12:03, 21.94s/it]

0.8060308269092015


 47%|████▋     | 28/60 [10:13<11:41, 21.92s/it]

0.8060106564135778


 48%|████▊     | 29/60 [10:35<11:18, 21.89s/it]

0.8047124998910087


 50%|█████     | 30/60 [10:57<10:56, 21.87s/it]

0.804511425041017


 52%|█████▏    | 31/60 [11:19<10:34, 21.87s/it]

0.8028279003642854


 53%|█████▎    | 32/60 [11:41<10:12, 21.89s/it]

0.8006198860350109


 55%|█████▌    | 33/60 [12:03<09:51, 21.89s/it]

0.8006663450172969


 57%|█████▋    | 34/60 [12:25<09:29, 21.90s/it]

0.7996082192375547


 58%|█████▊    | 35/60 [12:46<09:07, 21.91s/it]

0.7992022917384193


 60%|██████    | 36/60 [13:08<08:45, 21.92s/it]

0.797108109508242


 62%|██████▏   | 37/60 [13:30<08:24, 21.91s/it]

0.7968949519452595


 63%|██████▎   | 38/60 [13:52<08:02, 21.92s/it]

0.7972913795993442


 65%|██████▌   | 39/60 [14:14<07:40, 21.92s/it]

0.7956229405743735


 67%|██████▋   | 40/60 [14:36<07:18, 21.92s/it]

0.7956672594660804


 68%|██████▊   | 41/60 [14:58<06:56, 21.91s/it]

0.7936321880136218


 70%|███████   | 42/60 [15:20<06:34, 21.90s/it]

0.7936343011401948


 72%|███████▏  | 43/60 [15:42<06:12, 21.89s/it]

0.7918576158228374


 73%|███████▎  | 44/60 [16:04<05:50, 21.89s/it]

0.7930205946876889


 75%|███████▌  | 45/60 [16:26<05:28, 21.89s/it]

0.7919625512191227


 77%|███████▋  | 46/60 [16:47<05:06, 21.89s/it]

0.7911140478792644


 78%|███████▊  | 47/60 [17:09<04:44, 21.88s/it]

0.7908810291971479


 80%|████████  | 48/60 [17:31<04:22, 21.89s/it]

0.7892475199131739


 82%|████████▏ | 49/60 [17:53<04:00, 21.89s/it]

0.7897665287767138


 83%|████████▎ | 50/60 [18:15<03:38, 21.88s/it]

0.7885633664471763


 85%|████████▌ | 51/60 [18:37<03:16, 21.88s/it]

0.7881681762990498


 87%|████████▋ | 52/60 [18:59<02:55, 21.88s/it]

0.7865066485745567


 88%|████████▊ | 53/60 [19:21<02:33, 21.87s/it]

0.7846101848852067


 90%|█████████ | 54/60 [19:42<02:11, 21.88s/it]

0.7865396085239592


 92%|█████████▏| 55/60 [20:04<01:49, 21.89s/it]

0.7859937264805749


 93%|█████████▎| 56/60 [20:26<01:27, 21.89s/it]

0.7862327297528585


 95%|█████████▌| 57/60 [20:48<01:05, 21.89s/it]

0.7850077521233332


 97%|█████████▋| 58/60 [21:10<00:43, 21.89s/it]

0.7836205420039949


 98%|█████████▊| 59/60 [21:32<00:21, 21.88s/it]

0.7827000986962092


100%|██████████| 60/60 [21:54<00:00, 21.91s/it]

0.782227506240209
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [34]:
model_score3

{'Bleu_1': 0.6566475401825943,
 'Bleu_2': 0.5375875865786676,
 'Bleu_3': 0.45615388839730103,
 'Bleu_4': 0.3955415587750201,
 'METEOR': 0.3077360764321484,
 'ROUGE_L': 0.5643944449873991,
 'CIDEr': 2.2777023214350534,
 'SPICE': 0.40944634268215796,
 'USC_similarity': 0.6232238787242027}

In [35]:
caption_model4, model_score4 = cross_validation(cv[3][0], cv[3][1], 4, cnn_type)    

Split 4:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2680 ==> 894
The vocabulary size is 895.
809 out of 895 words are found in the pre-trained matrix.
The size of embedding_matrix is (895, 200)
Preparing dataloader...

Generating set took: 0:00:59.05


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:14.70
Training...


  2%|▏         | 1/60 [00:22<21:46, 22.14s/it]

4.524656778290158


  3%|▎         | 2/60 [00:44<21:23, 22.13s/it]

2.3696714355832054


  5%|▌         | 3/60 [01:06<20:59, 22.10s/it]

1.753268871988569


  7%|▋         | 4/60 [01:28<20:37, 22.10s/it]

1.538925031820933


  8%|▊         | 5/60 [01:50<20:15, 22.09s/it]

1.436170518398285


 10%|█         | 6/60 [02:12<19:52, 22.09s/it]

1.3641809366998219


 12%|█▏        | 7/60 [02:34<19:29, 22.07s/it]

1.3119671259607588


 13%|█▎        | 8/60 [02:56<19:07, 22.06s/it]

1.2748460343905859


 15%|█▌        | 9/60 [03:18<18:44, 22.05s/it]

1.2456788959957303


 17%|█▋        | 10/60 [03:40<18:22, 22.05s/it]

1.2150958606175013


 18%|█▊        | 11/60 [04:02<18:00, 22.05s/it]

1.1891446028436934


 20%|██        | 12/60 [04:24<17:38, 22.05s/it]

1.1689676358586265


 22%|██▏       | 13/60 [04:46<17:15, 22.04s/it]

1.1511814423969813


 23%|██▎       | 14/60 [05:08<16:53, 22.04s/it]

1.1371679788544065


 25%|██▌       | 15/60 [05:30<16:31, 22.04s/it]

1.126939838840848


 27%|██▋       | 16/60 [05:52<16:09, 22.03s/it]

1.1153036526271276


 28%|██▊       | 17/60 [06:14<15:47, 22.03s/it]

1.1050325291497367


 30%|███       | 18/60 [06:36<15:25, 22.03s/it]

1.0942930011522203


 32%|███▏      | 19/60 [06:58<15:03, 22.03s/it]

1.0821788154897236


 33%|███▎      | 20/60 [07:21<14:41, 22.04s/it]

1.070698312350682


 35%|███▌      | 21/60 [07:43<14:19, 22.04s/it]

1.0615025276229495


 37%|███▋      | 22/60 [08:05<13:57, 22.04s/it]

1.057881045909155


 38%|███▊      | 23/60 [08:27<13:35, 22.04s/it]

1.0519448121388753


 40%|████      | 24/60 [08:49<13:13, 22.04s/it]

1.0454213746956416


 42%|████▏     | 25/60 [09:11<12:51, 22.04s/it]

1.0400770604610443


 43%|████▎     | 26/60 [09:33<12:29, 22.04s/it]

1.0392840930393763


 45%|████▌     | 27/60 [09:55<12:07, 22.04s/it]

1.0335444921538943


 47%|████▋     | 28/60 [10:17<11:44, 22.03s/it]

1.026696210815793


 48%|████▊     | 29/60 [10:39<11:23, 22.04s/it]

1.027829871291206


 50%|█████     | 30/60 [11:01<11:01, 22.04s/it]

1.0236217691784812


 52%|█████▏    | 31/60 [11:23<10:39, 22.05s/it]

1.0228924793856484


 53%|█████▎    | 32/60 [11:45<10:17, 22.05s/it]

1.0155309438705444


 55%|█████▌    | 33/60 [12:07<09:55, 22.05s/it]

1.0113423551831926


 57%|█████▋    | 34/60 [12:29<09:33, 22.05s/it]

1.0118511730716342


 58%|█████▊    | 35/60 [12:51<09:11, 22.04s/it]

1.0063917892319816


 60%|██████    | 36/60 [13:13<08:49, 22.04s/it]

0.9998581679094405


 62%|██████▏   | 37/60 [13:35<08:27, 22.04s/it]

0.9935833769185203


 63%|██████▎   | 38/60 [13:57<08:04, 22.04s/it]

0.9932391785439991


 65%|██████▌   | 39/60 [14:19<07:42, 22.03s/it]

0.9913092411699749


 67%|██████▋   | 40/60 [14:41<07:20, 22.04s/it]

0.9859524653071449


 68%|██████▊   | 41/60 [15:03<06:58, 22.04s/it]

0.981565405925115


 70%|███████   | 42/60 [15:25<06:36, 22.04s/it]

0.9755377414680663


 72%|███████▏  | 43/60 [15:47<06:14, 22.04s/it]

0.973867879027412


 73%|███████▎  | 44/60 [16:10<05:52, 22.04s/it]

0.973363329966863


 75%|███████▌  | 45/60 [16:32<05:30, 22.05s/it]

0.9711949002175104


 77%|███████▋  | 46/60 [16:54<05:08, 22.05s/it]

0.9661914010842642


 78%|███████▊  | 47/60 [17:16<04:46, 22.06s/it]

0.9664839889322009


 80%|████████  | 48/60 [17:38<04:24, 22.07s/it]

0.966542260987418


 82%|████████▏ | 49/60 [18:00<04:02, 22.01s/it]

0.9659648935000101


 83%|████████▎ | 50/60 [18:22<03:39, 21.98s/it]

0.964216560125351


 85%|████████▌ | 51/60 [18:44<03:17, 21.96s/it]

0.9607230510030474


 87%|████████▋ | 52/60 [19:05<02:55, 21.94s/it]

0.9577195204439617


 88%|████████▊ | 53/60 [19:27<02:33, 21.93s/it]

0.9587147377786183


 90%|█████████ | 54/60 [19:49<02:11, 21.94s/it]

0.9583637118339539


 92%|█████████▏| 55/60 [20:11<01:49, 21.93s/it]

0.9560202047938392


 93%|█████████▎| 56/60 [20:33<01:27, 21.92s/it]

0.9541247742516654


 95%|█████████▌| 57/60 [20:55<01:05, 21.93s/it]

0.9539612063339779


 97%|█████████▋| 58/60 [21:17<00:43, 21.93s/it]

0.9540443675858634


 98%|█████████▊| 59/60 [21:39<00:21, 21.92s/it]

0.9524621906734648


100%|██████████| 60/60 [22:01<00:00, 22.02s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.9499630757740566


  2%|▏         | 1/60 [00:21<21:34, 21.95s/it]

0.9228922994363875


  3%|▎         | 2/60 [00:44<21:14, 21.98s/it]

0.8978870965185619


  5%|▌         | 3/60 [01:06<20:54, 22.00s/it]

0.886759224392119


  7%|▋         | 4/60 [01:28<20:33, 22.03s/it]

0.8800969507013049


  8%|▊         | 5/60 [01:50<20:12, 22.04s/it]

0.8739892130806333


 10%|█         | 6/60 [02:12<19:50, 22.04s/it]

0.8720885685511998


 12%|█▏        | 7/60 [02:34<19:28, 22.05s/it]

0.868764457248506


 13%|█▎        | 8/60 [02:56<19:06, 22.05s/it]

0.864698079370317


 15%|█▌        | 9/60 [03:18<18:44, 22.05s/it]

0.8630407381625402


 17%|█▋        | 10/60 [03:40<18:20, 22.00s/it]

0.8616183570453099


 18%|█▊        | 11/60 [04:02<17:56, 21.97s/it]

0.8591486669722057


 20%|██        | 12/60 [04:24<17:33, 21.95s/it]

0.8572620891389393


 22%|██▏       | 13/60 [04:46<17:10, 21.93s/it]

0.8552719893909636


 23%|██▎       | 14/60 [05:07<16:48, 21.92s/it]

0.8531312020052046


 25%|██▌       | 15/60 [05:29<16:25, 21.91s/it]

0.8513490089348384


 27%|██▋       | 16/60 [05:51<16:04, 21.91s/it]

0.8504473396709987


 28%|██▊       | 17/60 [06:13<15:41, 21.90s/it]

0.8489376860005515


 30%|███       | 18/60 [06:35<15:19, 21.90s/it]

0.8473248439175742


 32%|███▏      | 19/60 [06:57<14:57, 21.89s/it]

0.8460568459261031


 33%|███▎      | 20/60 [07:19<14:35, 21.89s/it]

0.8477120385283515


 35%|███▌      | 21/60 [07:41<14:14, 21.92s/it]

0.8451005362329029


 37%|███▋      | 22/60 [08:03<13:53, 21.93s/it]

0.8423351517745427


 38%|███▊      | 23/60 [08:25<13:31, 21.92s/it]

0.8421039822555724


 40%|████      | 24/60 [08:46<13:08, 21.92s/it]

0.8411560555299123


 42%|████▏     | 25/60 [09:08<12:46, 21.90s/it]

0.8410232734112513


 43%|████▎     | 26/60 [09:30<12:24, 21.90s/it]

0.8400691350301107


 45%|████▌     | 27/60 [09:52<12:02, 21.89s/it]

0.8388725306306567


 47%|████▋     | 28/60 [10:14<11:40, 21.88s/it]

0.8375506897767385


 48%|████▊     | 29/60 [10:36<11:18, 21.89s/it]

0.8386031885941824


 50%|█████     | 30/60 [10:58<10:56, 21.89s/it]

0.8372993511813027


 52%|█████▏    | 31/60 [11:20<10:34, 21.90s/it]

0.8357438033535367


 53%|█████▎    | 32/60 [11:42<10:14, 21.96s/it]

0.8357048829396566


 55%|█████▌    | 33/60 [12:04<09:53, 22.00s/it]

0.8331249370461419


 57%|█████▋    | 34/60 [12:26<09:32, 22.03s/it]

0.8343332565966106


 58%|█████▊    | 35/60 [12:48<09:11, 22.04s/it]

0.8328560363678705


 60%|██████    | 36/60 [13:10<08:49, 22.06s/it]

0.8322514494260153


 62%|██████▏   | 37/60 [13:32<08:27, 22.06s/it]

0.832904577255249


 63%|██████▎   | 38/60 [13:54<08:05, 22.06s/it]

0.8315117274011884


 65%|██████▌   | 39/60 [14:16<07:43, 22.07s/it]

0.8308784450803485


 67%|██████▋   | 40/60 [14:38<07:21, 22.06s/it]

0.8302007544608343


 68%|██████▊   | 41/60 [15:00<06:58, 22.01s/it]

0.8301921912602016


 70%|███████   | 42/60 [15:22<06:35, 21.97s/it]

0.8298340113390059


 72%|███████▏  | 43/60 [15:44<06:13, 21.96s/it]

0.8284914663859776


 73%|███████▎  | 44/60 [16:06<05:51, 21.94s/it]

0.8280922302177974


 75%|███████▌  | 45/60 [16:28<05:28, 21.92s/it]

0.8274092901320684


 77%|███████▋  | 46/60 [16:50<05:06, 21.90s/it]

0.8263722615582603


 78%|███████▊  | 47/60 [17:12<04:44, 21.88s/it]

0.8261301602636065


 80%|████████  | 48/60 [17:33<04:22, 21.87s/it]

0.8260313513733092


 82%|████████▏ | 49/60 [17:55<04:00, 21.86s/it]

0.8247554358981904


 83%|████████▎ | 50/60 [18:17<03:38, 21.84s/it]

0.8251616940611884


 85%|████████▌ | 51/60 [18:39<03:16, 21.84s/it]

0.824188483612878


 87%|████████▋ | 52/60 [19:01<02:54, 21.85s/it]

0.8231259030955178


 88%|████████▊ | 53/60 [19:23<02:32, 21.86s/it]

0.8226834109851292


 90%|█████████ | 54/60 [19:45<02:11, 21.89s/it]

0.823384648277646


 92%|█████████▏| 55/60 [20:06<01:49, 21.87s/it]

0.8215176519893465


 93%|█████████▎| 56/60 [20:28<01:27, 21.85s/it]

0.8207750916481018


 95%|█████████▌| 57/60 [20:50<01:05, 21.84s/it]

0.8200745582580566


 97%|█████████▋| 58/60 [21:12<00:43, 21.83s/it]

0.8215714352471488


 98%|█████████▊| 59/60 [21:34<00:21, 21.83s/it]

0.8192393864904132


100%|██████████| 60/60 [21:56<00:00, 21.93s/it]

0.821579518772307
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [36]:
model_score4

{'Bleu_1': 0.6382132200017796,
 'Bleu_2': 0.5153082287312921,
 'Bleu_3': 0.43369952262237005,
 'Bleu_4': 0.3739999240790734,
 'METEOR': 0.3023109077695267,
 'ROUGE_L': 0.5484960378579745,
 'CIDEr': 2.0821031037948603,
 'SPICE': 0.39612244416504777,
 'USC_similarity': 0.6113672645024837}

In [37]:
caption_model5, model_score5 = cross_validation(cv[4][0], cv[4][1], 5, cnn_type)    

Split 5:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2657 ==> 905
The vocabulary size is 906.
818 out of 906 words are found in the pre-trained matrix.
The size of embedding_matrix is (906, 200)
Preparing dataloader...

Generating set took: 0:00:57.92


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:14.36
Training...


  2%|▏         | 1/60 [00:22<21:46, 22.14s/it]

4.838801452091762


  3%|▎         | 2/60 [00:44<21:23, 22.12s/it]

2.3379714517366317


  5%|▌         | 3/60 [01:06<21:00, 22.12s/it]

1.7333938451040358


  7%|▋         | 4/60 [01:28<20:38, 22.12s/it]

1.515297878356207


  8%|▊         | 5/60 [01:50<20:16, 22.12s/it]

1.4037101637749445


 10%|█         | 6/60 [02:12<19:54, 22.12s/it]

1.337876095658257


 12%|█▏        | 7/60 [02:34<19:32, 22.13s/it]

1.2925049747739519


 13%|█▎        | 8/60 [02:56<19:10, 22.13s/it]

1.2567259073257446


 15%|█▌        | 9/60 [03:19<18:48, 22.14s/it]

1.222508035954975


 17%|█▋        | 10/60 [03:41<18:26, 22.14s/it]

1.1904663074584234


 18%|█▊        | 11/60 [04:03<18:04, 22.13s/it]

1.1677946533475603


 20%|██        | 12/60 [04:25<17:42, 22.13s/it]

1.148976865268889


 22%|██▏       | 13/60 [04:47<17:19, 22.12s/it]

1.134375396228972


 23%|██▎       | 14/60 [05:09<16:57, 22.11s/it]

1.1149283406280337


 25%|██▌       | 15/60 [05:31<16:35, 22.12s/it]

1.1028622488180797


 27%|██▋       | 16/60 [05:53<16:13, 22.12s/it]

1.0888485369228182


 28%|██▊       | 17/60 [06:16<15:51, 22.13s/it]

1.080500093244371


 30%|███       | 18/60 [06:38<15:27, 22.10s/it]

1.0694863696893055


 32%|███▏      | 19/60 [07:00<15:04, 22.06s/it]

1.0618281123184024


 33%|███▎      | 20/60 [07:22<14:41, 22.04s/it]

1.0492907265822093


 35%|███▌      | 21/60 [07:44<14:18, 22.02s/it]

1.040775090456009


 37%|███▋      | 22/60 [08:06<13:56, 22.00s/it]

1.030882587035497


 38%|███▊      | 23/60 [08:28<13:33, 22.00s/it]

1.0233816986992246


 40%|████      | 24/60 [08:50<13:11, 22.00s/it]

1.0187850991884868


 42%|████▏     | 25/60 [09:12<12:49, 22.00s/it]

1.0148646221274422


 43%|████▎     | 26/60 [09:34<12:27, 21.99s/it]

1.014155136687415


 45%|████▌     | 27/60 [09:55<12:05, 21.99s/it]

1.0095178655215673


 47%|████▋     | 28/60 [10:17<11:43, 21.98s/it]

1.0079318072114671


 48%|████▊     | 29/60 [10:39<11:21, 21.98s/it]

1.0062846555596305


 50%|█████     | 30/60 [11:01<10:59, 21.98s/it]

0.9982536704767317


 52%|█████▏    | 31/60 [11:23<10:37, 21.97s/it]

0.9970422912211645


 53%|█████▎    | 32/60 [11:45<10:15, 21.97s/it]

0.9929664816175189


 55%|█████▌    | 33/60 [12:07<09:53, 21.97s/it]

0.981127115942183


 57%|█████▋    | 34/60 [12:29<09:31, 21.97s/it]

0.976909122296742


 58%|█████▊    | 35/60 [12:51<09:09, 21.98s/it]

0.9717638265518915


 60%|██████    | 36/60 [13:13<08:47, 21.98s/it]

0.9661996889682043


 62%|██████▏   | 37/60 [13:35<08:25, 21.99s/it]

0.9646347179299309


 63%|██████▎   | 38/60 [13:57<08:03, 21.99s/it]

0.9581449187937237


 65%|██████▌   | 39/60 [14:19<07:42, 22.01s/it]

0.9542715294020516


 67%|██████▋   | 40/60 [14:41<07:19, 21.99s/it]

0.9514662779512859


 68%|██████▊   | 41/60 [15:03<06:57, 21.98s/it]

0.9460805342310951


 70%|███████   | 42/60 [15:25<06:35, 21.96s/it]

0.9433580353146508


 72%|███████▏  | 43/60 [15:47<06:13, 21.96s/it]

0.9372789093426296


 73%|███████▎  | 44/60 [16:09<05:51, 21.96s/it]

0.9359422155788967


 75%|███████▌  | 45/60 [16:31<05:29, 21.95s/it]

0.9320831511701856


 77%|███████▋  | 46/60 [16:53<05:07, 21.94s/it]

0.9309984374613989


 78%|███████▊  | 47/60 [17:15<04:45, 21.94s/it]

0.9289289585181645


 80%|████████  | 48/60 [17:37<04:23, 21.94s/it]

0.92461116257168


 82%|████████▏ | 49/60 [17:59<04:01, 21.96s/it]

0.9260012989952451


 83%|████████▎ | 50/60 [18:21<03:40, 22.03s/it]

0.9239618650504521


 85%|████████▌ | 51/60 [18:43<03:18, 22.05s/it]

0.9183791208834875


 87%|████████▋ | 52/60 [19:05<02:56, 22.04s/it]

0.9192380408445994


 88%|████████▊ | 53/60 [19:27<02:33, 22.00s/it]

0.9160275757312775


 90%|█████████ | 54/60 [19:49<02:11, 21.97s/it]

0.9180990102745238


 92%|█████████▏| 55/60 [20:11<01:49, 21.96s/it]

0.9167508553890955


 93%|█████████▎| 56/60 [20:33<01:27, 21.95s/it]

0.9142692614169348


 95%|█████████▌| 57/60 [20:55<01:05, 21.94s/it]

0.911185010558083


 97%|█████████▋| 58/60 [21:17<00:43, 21.93s/it]

0.9083816025938306


 98%|█████████▊| 59/60 [21:38<00:21, 21.92s/it]

0.9084105392297109


100%|██████████| 60/60 [22:00<00:00, 22.01s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.9092126034554981


  2%|▏         | 1/60 [00:21<21:34, 21.95s/it]

0.8821543057759603


  3%|▎         | 2/60 [00:43<21:12, 21.94s/it]

0.8528683057853154


  5%|▌         | 3/60 [01:05<20:50, 21.94s/it]

0.8419675968942189


  7%|▋         | 4/60 [01:27<20:28, 21.94s/it]

0.8349616570132119


  8%|▊         | 5/60 [01:49<20:07, 21.95s/it]

0.8288190606094542


 10%|█         | 6/60 [02:11<19:44, 21.94s/it]

0.8236469328403473


 12%|█▏        | 7/60 [02:33<19:23, 21.95s/it]

0.821737847157887


 13%|█▎        | 8/60 [02:55<19:01, 21.95s/it]

0.8161639627956209


 15%|█▌        | 9/60 [03:17<18:39, 21.95s/it]

0.8147064575127193


 17%|█▋        | 10/60 [03:39<18:17, 21.94s/it]

0.8131235142548879


 18%|█▊        | 11/60 [04:01<17:55, 21.95s/it]

0.8098390442984444


 20%|██        | 12/60 [04:23<17:33, 21.95s/it]

0.8057273101238978


 22%|██▏       | 13/60 [04:45<17:12, 21.98s/it]

0.8072127217338199


 23%|██▎       | 14/60 [05:07<16:50, 21.97s/it]

0.803169744355338


 25%|██▌       | 15/60 [05:29<16:28, 21.96s/it]

0.8020815366790408


 27%|██▋       | 16/60 [05:51<16:05, 21.95s/it]

0.8003229413713727


 28%|██▊       | 17/60 [06:13<15:46, 22.01s/it]

0.7991844302132016


 30%|███       | 18/60 [06:35<15:25, 22.05s/it]

0.7976007773762658


 32%|███▏      | 19/60 [06:57<15:05, 22.08s/it]

0.7960302758784521


 33%|███▎      | 20/60 [07:19<14:44, 22.10s/it]

0.7951581137520927


 35%|███▌      | 21/60 [07:41<14:21, 22.10s/it]

0.7941734265713465


 37%|███▋      | 22/60 [08:03<13:58, 22.06s/it]

0.7937051483563015


 38%|███▊      | 23/60 [08:25<13:35, 22.03s/it]

0.7909420913174039


 40%|████      | 24/60 [08:47<13:12, 22.01s/it]

0.7898180598304385


 42%|████▏     | 25/60 [09:09<12:50, 22.00s/it]

0.7895770072937012


 43%|████▎     | 26/60 [09:31<12:27, 21.98s/it]

0.7895492670081911


 45%|████▌     | 27/60 [09:53<12:04, 21.96s/it]

0.7876414571489606


 47%|████▋     | 28/60 [10:15<11:42, 21.96s/it]

0.7873603375185103


 48%|████▊     | 29/60 [10:37<11:20, 21.94s/it]

0.7861098973524003


 50%|█████     | 30/60 [10:59<10:58, 21.94s/it]

0.7855975869156065


 52%|█████▏    | 31/60 [11:21<10:36, 21.94s/it]

0.7841481807686034


 53%|█████▎    | 32/60 [11:43<10:14, 21.93s/it]

0.785013564995357


 55%|█████▌    | 33/60 [12:05<09:52, 21.93s/it]

0.7828057819888705


 57%|█████▋    | 34/60 [12:27<09:30, 21.94s/it]

0.7809745598407019


 58%|█████▊    | 35/60 [12:49<09:08, 21.94s/it]

0.7801248317673093


 60%|██████    | 36/60 [13:11<08:46, 21.93s/it]

0.7803583684421721


 62%|██████▏   | 37/60 [13:32<08:24, 21.93s/it]

0.7795035910038721


 63%|██████▎   | 38/60 [13:54<08:02, 21.93s/it]

0.7778993092832112


 65%|██████▌   | 39/60 [14:16<07:40, 21.93s/it]

0.7783557176589966


 67%|██████▋   | 40/60 [14:38<07:18, 21.93s/it]

0.7779331391765958


 68%|██████▊   | 41/60 [15:00<06:56, 21.95s/it]

0.7768807751791817


 70%|███████   | 42/60 [15:22<06:35, 21.96s/it]

0.7752052744229635


 72%|███████▏  | 43/60 [15:44<06:13, 21.97s/it]

0.7760604591596694


 73%|███████▎  | 44/60 [16:06<05:51, 21.97s/it]

0.773012447924841


 75%|███████▌  | 45/60 [16:28<05:29, 21.99s/it]

0.7746290521962302


 77%|███████▋  | 46/60 [16:50<05:07, 21.98s/it]

0.7738249713466281


 78%|███████▊  | 47/60 [17:12<04:45, 21.98s/it]

0.7719533372493017


 80%|████████  | 48/60 [17:34<04:23, 21.97s/it]

0.7727061779726119


 82%|████████▏ | 49/60 [17:56<04:01, 21.97s/it]

0.7736426904087975


 83%|████████▎ | 50/60 [18:18<03:39, 21.98s/it]

0.7707710791201818


 85%|████████▌ | 51/60 [18:40<03:17, 21.98s/it]

0.7708293469179244


 87%|████████▋ | 52/60 [19:02<02:55, 21.97s/it]

0.7703996215547834


 88%|████████▊ | 53/60 [19:24<02:33, 21.98s/it]

0.7680891227154505


 90%|█████████ | 54/60 [19:46<02:11, 21.98s/it]

0.7689832463150933


 92%|█████████▏| 55/60 [20:08<01:49, 21.98s/it]

0.7680144636403947


 93%|█████████▎| 56/60 [20:30<01:27, 21.97s/it]

0.7683932852177393


 95%|█████████▌| 57/60 [20:52<01:05, 21.96s/it]

0.7678832738172441


 97%|█████████▋| 58/60 [21:14<00:43, 21.97s/it]

0.7672300948983147


 98%|█████████▊| 59/60 [21:36<00:21, 21.97s/it]

0.766958067814509


100%|██████████| 60/60 [21:58<00:00, 21.97s/it]

0.7648609975973765
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [38]:
model_score5

{'Bleu_1': 0.647497239602473,
 'Bleu_2': 0.5274467806388617,
 'Bleu_3': 0.44626571637087664,
 'Bleu_4': 0.3860139510836923,
 'METEOR': 0.30408768113263684,
 'ROUGE_L': 0.5520025665847783,
 'CIDEr': 2.144824139337835,
 'SPICE': 0.39725981423602585,
 'USC_similarity': 0.6132966706494385}

In [39]:
model_scores = defaultdict(list)
for scores in [model_score1, model_score2, model_score3, model_score4, model_score5]:
    for key, value in scores.items():
        model_scores[key].append(value)

In [40]:
model_scores

defaultdict(list,
            {'Bleu_1': [0.6571629575623202,
              0.6352914180816597,
              0.6566475401825943,
              0.6382132200017796,
              0.647497239602473],
             'Bleu_2': [0.5362381576774322,
              0.5113443206509108,
              0.5375875865786676,
              0.5153082287312921,
              0.5274467806388617],
             'Bleu_3': [0.45560464775875736,
              0.4287394882419803,
              0.45615388839730103,
              0.43369952262237005,
              0.44626571637087664],
             'Bleu_4': [0.3964644970967075,
              0.36894999565248604,
              0.3955415587750201,
              0.3739999240790734,
              0.3860139510836923],
             'METEOR': [0.3079667441267819,
              0.2968264250708283,
              0.3077360764321484,
              0.3023109077695267,
              0.30408768113263684],
             'ROUGE_L': [0.56041962993465,
              0.5391833235204

In [41]:
tag = '9.1.3.1'
with open(f'{root_captioning}/fz_notebooks/cv_n{tag}.json', 'w') as fp:
    json.dump(model_scores, fp)