## Image Captioning with Pytorch

The following contents are modified from MDS DSCI 575 lecture 8 demo

In [1]:
import os, sys, json
from collections import defaultdict
from tqdm import tqdm
import pickle
from time import time
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from itertools import chain
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models, transforms, datasets
from torchsummary import summary
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

from nltk.translate import bleu_score
from sklearn.model_selection import KFold

sys.path.append('../../scr/evaluation')
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.usc_sim.usc_sim import usc_sim
import subprocess


START = "startseq"
STOP = "endseq"
EPOCHS = 10
AWS = True


In [2]:
torch.manual_seed(123)
np.random.seed(123)

In [3]:
# torch.cuda.empty_cache()
# import gc 
# gc.collect()

In [4]:
# Nicely formatted time string
def hms_string(sec_elapsed):
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60
    return f"{h}:{m:>02}:{s:>05.2f}"
        
if AWS:
    root_captioning = "../../data"
else:
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
        root_captioning = "/content/drive/My Drive/data"
        COLAB = True
        print("Note: using Google CoLab")
    except:
        print("Note: not using Google CoLab")
        COLAB = False

### Clean/Build Dataset

- Read captions
- Preprocess captions


In [5]:
def get_img_info(name, num=np.inf):
    """
    Returns img paths and captions

    Parameters:
    -----------
    name: str
        the json file name
    num: int (default: np.inf)
        the number of observations to get

    Return:
    --------
    list, dict, int
        img paths, corresponding captions, max length of captions
    """
    img_path = []
    caption = [] 
    max_length = 0
    if AWS:
        with open(f'{root_captioning}/json/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for filename in data.keys():
                if num is not None and len(caption) == num:
                    break
                img_path.append(
                    f'{root_captioning}/{name}/{filename}'
                )
                sen_list = []
                for sentence in data[filename]['sentences']:
                    max_length = max(max_length, len(sentence['tokens']))
                    sen_list.append(sentence['raw'])

                caption.append(sen_list)    
    else:            
        with open(f'{root_captioning}/interim/{name}.json', 'r') as json_data:
            data = json.load(json_data)
            for set_name in ['rsicd', 'ucm']:
                for filename in data[set_name].keys():
                    if num is not None and len(caption) == num:
                        break

                    img_path.append(
                        f'{root_captioning}/raw/imgs/{set_name}/{filename}'
                    )
                    sen_list = []
                    for sentence in data[set_name][filename]['sentences']:
                        max_length = max(max_length, len(sentence['tokens']))
                        sen_list.append(sentence['raw'])

                    caption.append(sen_list)
    
    return img_path, caption, max_length            


In [6]:
# get img path and caption list
# # only test 800 train samples and 200 valid samples
# train_paths, train_descriptions, max_length_train = get_img_info('train', 800)
# test_paths, test_descriptions, max_length_test = get_img_info('valid', 200)

train_paths, train_descriptions, max_length_train = get_img_info('train')
test_paths, test_descriptions, max_length_test = get_img_info('valid')
max_length = max(max_length_train, max_length_test)



In [7]:
all_paths = train_paths.copy()
all_paths.extend(test_paths.copy())
all_paths = np.array(all_paths)

all_descriptions = train_descriptions.copy()
all_descriptions.extend(test_descriptions.copy())
all_descriptions = np.array(all_descriptions)

captions = all_descriptions.copy()
max_length_all = max(max_length_train, max_length_test)
max_length = max_length_all + 2
      
lex = set()
for sen in all_descriptions:
    [lex.update(d.split()) for d in sen]
    
# add a start and stop token at the beginning/end
for v in all_descriptions:
    for d in range(len(v)):
        v[d] = f'{START} {v[d]} {STOP}'
        
print(f'There are {len(all_paths)} images') 
print(f'There are {len(lex)} unique words (vocab)')
print(f'The maximum length of captions with start and stop token is {max_length}.')


There are 10416 images
There are 2912 unique words (vocab)
The maximum length of captions with start and stop token is 36.


In [8]:
all_paths[-1]

'../../s3/valid/rsicd_park_33.jpg'

In [9]:
all_descriptions[-1]

array(['startseq a vast artificial lake was built in the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq there are many residential areas near the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq',
       'startseq a vast artificial lake was built in the park . endseq'],
      dtype='<U184')

### Loading Glove Embeddings

In [10]:
embeddings_index = {} 
path = os.path.join(root_captioning, 'glove.6B.200d.txt') if AWS\
else os.path.join(root_captioning, 'raw', 'glove.6B.200d.txt')

f = open(
    path, 
    encoding="utf-8"
)

for line in tqdm(f):
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs

f.close()
print(f'Found {len(embeddings_index)} word vectors.')

400000it [00:22, 17923.76it/s]

Found 400000 word vectors.





In [11]:
def get_vocab(descriptions, word_count_threshold=10):

    captions = []
    for val in descriptions:
        for cap in val:
            captions.append(cap)
    print(f'There are {len(captions)} captions')
    
    word_counts = {}
    nsents = 0
    for sent in captions:
        nsents += 1
        for w in sent.split(' '):
            word_counts[w] = word_counts.get(w, 0) + 1

    vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
    print('preprocessed words %d ==> %d' % (len(word_counts), len(vocab)))
    return vocab

def get_word_dict(vocab):
    
    idxtoword = {}
    wordtoidx = {}

    ix = 1
    for w in vocab:
        wordtoidx[w] = ix
        idxtoword[ix] = w
        ix += 1

    return idxtoword, wordtoidx

def get_vocab_size(idxtoword):
    
    print(f'The vocabulary size is {len(idxtoword) + 1}.')
    return len(idxtoword) + 1


def get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx):

    embedding_matrix = np.zeros((vocab_size, embedding_dim))
    count =0

    for word, i in wordtoidx.items():

        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            count += 1
            # Words not found in the embedding index will be all zeros
            embedding_matrix[i] = embedding_vector
            
    print(f'{count} out of {vocab_size} words are found in the pre-trained matrix.')            
    print(f'The size of embedding_matrix is {embedding_matrix.shape}')
    return embedding_matrix

### Building the Neural Network

An embedding matrix is built from Glove.  This will be directly copied to the weight matrix of the neural network.

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [13]:
class CNNModel(nn.Module):

    def __init__(self, cnn_type, pretrained=True):
        """
        Initializes a CNNModel

        Parameters:
        -----------
        cnn_type: str
            the CNN type, either 'vgg16' or 'inception_v3'
        pretrained: bool (default: True)
            use pretrained model if True

        """

        super(CNNModel, self).__init__()

        if cnn_type == 'vgg16':
            self.model = models.vgg16(pretrained=pretrained)

            # remove the last two layers in classifier
            self.model.classifier = nn.Sequential(
              *list(self.model.classifier.children())[:-2]
            )
            self.input_size = 224     

        # inception v3 expects (299, 299) sized images
        elif cnn_type == 'inception_v3':
            self.model = models.inception_v3(pretrained=pretrained)
            # remove the classification layer
            self.model.fc = nn.Identity()

            # turn off auxiliary output
            self.model.aux_logits = False
            self.input_size = 299

        else:
            raise Exception("Please choose between 'vgg16' and 'inception_v3'.")

    def forward(self, img_input, train=False):
        """
        forward of the CNNModel

        Parameters:
        -----------
        img_input: torch.Tensor
            the image matrix
        train: bool (default: False)
            use the model only for feature extraction if False

        Return:
        --------
        torch.Tensor
            image feature matrix
        """
        if not train:
            # set the model to evaluation model
            self.model.eval()

        return self.model(img_input)

In [14]:
class RNNModel(nn.Module):

    def __init__(
        self, 
        vocab_size, 
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):
      
        """
        Initializes a RNNModel

        Parameters:
        -----------
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """

        super(RNNModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        if embedding_matrix is not None:

            self.embedding.load_state_dict({
              'weight': torch.FloatTensor(embedding_matrix)
            })
            self.embedding.weight.requires_grad = embedding_train

        self.dropout = nn.Dropout(p=0.5)

        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
 

    def forward(self, captions):
        """
        forward of the RNNModel

        Parameters:
        -----------
        captions: torch.Tensor
            the padded caption matrix

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        # embed the captions
        embedding = self.dropout(self.embedding(captions))

        outputs, (h, c) = self.lstm(embedding)

        return outputs, (h, c)



In [15]:
class CaptionModel(nn.Module):

    def __init__(
        self, 
        cnn_type, 
        vocab_size, 
        embedding_dim, 
        hidden_size=256,
        embedding_matrix=None, 
        embedding_train=False
    ):

        """
        Initializes a CaptionModel

        Parameters:
        -----------
        cnn_type: str
            the CNN type, either 'vgg16' or 'inception_v3'
        vocab_size: int
            the size of the vocabulary
        embedding_dim: int
            the number of features in the embedding matrix
        feature_size: int
            the number of features in the image matrix
        hidden_size: int (default: 256)
            the size of the hidden state in LSTM
        embedding_matrix: torch.Tensor (default: None)
            if not None, use this matrix as the embedding matrix
        embedding_train: bool (default: False)
            not train the embedding matrix if False
        """    
        super(CaptionModel, self).__init__() 

        # set feature_size based on cnn_type
        if cnn_type == 'vgg16':
            self.feature_size = 4096
        elif cnn_type == 'inception_v3':
            self.feature_size = 2048
        else:
            raise Exception("Please choose between 'vgg16' and 'inception_v3'.")  

        self.decoder = RNNModel(
            vocab_size, 
            embedding_dim,
            hidden_size,
            embedding_matrix,
            embedding_train
        )
        
        self.dropout = nn.Dropout(p=0.5)
        self.dense1 = nn.Linear(self.feature_size, hidden_size) 
        self.relu1 = nn.ReLU()
          
        self.dense2 = nn.Linear(hidden_size, hidden_size) 
        self.relu2 = nn.ReLU()
        self.dense3 = nn.Linear(hidden_size, vocab_size) 

    def forward(self, img_features, captions):
        """
        forward of the CaptionModel

        Parameters:
        -----------
        img_features: torch.Tensor
            the image feature matrix
        captions: torch.Tensor
            the padded caption matrix

        Return:
        --------
        torch.Tensor
            word probabilities for each position
        """

        img_features =\
        self.relu1(
            self.dense1(
                self.dropout(
                    img_features
                )
            )
        )

        decoder_out, _ = self.decoder(captions)

        # add up decoder outputs and image features
        outputs =\
        self.dense3(
            self.relu2(
                self.dense2(
                    decoder_out.add(
                        (img_features.view(img_features.size(0), 1, -1))\
                        .repeat(1, decoder_out.size(1), 1)
                    )
                )
            )
        )

        return outputs

### Train the Neural Network

In [16]:
def train(model, iterator, optimizer, criterion, clip, vocab_size):
    """
    train the CaptionModel

    Parameters:
    -----------
    model: CaptionModel
        a CaptionModel instance
    iterator: torch.utils.data.dataloader
        a PyTorch dataloader
    optimizer: torch.optim
        a PyTorch optimizer 
    criterion: nn.CrossEntropyLoss
        a PyTorch criterion 

    Return:
    --------
    float
        average loss
    """
    model.train()    
    epoch_loss = 0
    
    for img_features, captions in iterator:
        
        optimizer.zero_grad()

        # for each caption, the end word is not passed for training
        outputs = model(
            img_features.to(device),
            captions[:, :-1].to(device)
        )

        loss = criterion(
            outputs.view(-1, vocab_size), 
            captions[:, 1:].flatten().to(device)
        )
        epoch_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        
    return epoch_loss / len(iterator)

In [17]:
class SampleDataset(Dataset):
    def __init__(
        self,
        descriptions,
        imgs,
        wordtoidx,
        max_length
        ):
        """
        Initializes a SampleDataset

        Parameters:
        -----------
        descriptions: list
            a list of captions
        imgs: numpy.ndarray
            the image features
        wordtoidx: dict
            the dict to get word index
        max_length: int
            all captions will be padded to this size
        """        
        self.imgs = imgs
        self.descriptions = descriptions
        self.wordtoidx = wordtoidx
        self.max_length = max_length

    def __len__(self):
        """
        Returns the batch size

        Return:
        --------
        int
            the batch size
        """
        return len(self.imgs)

    def __getitem__(self, idx):
        """
        Prepare data for each image

        Parameters:
        -----------
        idx: int
          the index of the image to process

        Return:
        --------
        list, list, list
            [5 x image feature matrix],
            [five padded captions for this image]
            [the length of each caption]
        """

        img = self.imgs[idx]
        img_features, captions = [], []
        for desc in self.descriptions[idx]:
            # convert each word into a list of sequences.
            seq = [self.wordtoidx[word] for word in desc.split(' ')
                  if word in self.wordtoidx]
            # pad the sequence with 0 on the right side
            in_seq = np.pad(
                seq, 
                (0, max_length - len(seq)),
                mode='constant',
                constant_values=(0, 0)
                )

            img_features.append(img)
            captions.append(in_seq)
    
        return (img_features, captions)


In [18]:
def my_collate(batch):
    """
    Processes the batch to return from the dataloader

    Parameters:
    -----------
    batch: tuple
      a batch from the Dataset

    Return:
    --------
    list
        [image feature matrix, captions, the length of each caption]
    """  

    img_features = [item[0] for item in batch]
    captions = [item[1] for item in batch]

    img_features = torch.FloatTensor(list(chain(*img_features)))
    captions = torch.LongTensor(list(chain(*captions)))

    return [img_features, captions]

In [19]:
def init_weights(model, embedding_pretrained=True):
    """
    Initialize weights and bias in the model

    Parameters:
    -----------
    model: CaptionModel
      a CaptionModel instance
    embedding_pretrained: bool (default: True)
        not initialize the embedding matrix if True
    """  
  
    for name, param in model.named_parameters():
        if embedding_pretrained and 'embedding' in name:
            continue
        elif 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            


In [20]:
def encode_image(model, img_path):
    """
    Process the images to extract features

    Parameters:
    -----------
    model: CNNModel
      a CNNModel instance
    img_path: str
        the path of the image
 
    Return:
    --------
    torch.Tensor
        the extracted feature matrix from CNNModel
    """  

    img = Image.open(img_path)

    # Perform preprocessing needed by pre-trained models
    preprocessor = transforms.Compose([
        transforms.Resize(model.input_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    img = preprocessor(img)
    # Expand to 2D array
    img = img.view(1, *img.shape)
    # Call model to extract the smaller feature set for the image.
    x = model(img.to(device), False) 
    # Shape to correct form to be accepted by LSTM captioning network.
    x = np.squeeze(x)
    return x

In [21]:
def extract_img_features(img_paths, model):
    """
    Extracts, stores and returns image features

    Parameters:
    -----------
    img_paths: list
        the paths of images
    model: CNNModel (default: None)
      a CNNModel instance

    Return:
    --------
    numpy.ndarray
        the extracted image feature matrix from CNNModel
    """ 

    start = time()
    img_features = []

    for image_path in img_paths:
        img_features.append(
            encode_image(model, image_path).cpu().data.numpy()
    )

    print(f"\nGenerating set took: {hms_string(time()-start)}")

    return img_features

In [22]:
def get_train_test(
    encoder,
    train_paths,
    test_paths,
    cnn_type='inception_v3',
):

    train_img_features = extract_img_features(
        train_paths,
        encoder
    )

    test_img_features = extract_img_features(
        test_paths,
        encoder
    )
    return train_img_features, test_img_features

def get_train_dataloader(
    train_descriptions, 
    train_img_features,
    wordtoidx,
    max_length,
    batch_size=200
):
    train_dataset = SampleDataset(
        train_descriptions,
        train_img_features,
        wordtoidx,
        max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size,
        collate_fn=my_collate
    )
    
    return train_loader

def train_model(
    train_loader,
    vocab_size,
    embedding_dim, 
    embedding_matrix,
    cnn_type='inception_v3',
    hidden_size=256,
):

    caption_model = CaptionModel(
        cnn_type, 
        vocab_size, 
        embedding_dim, 
        hidden_size=hidden_size,
        embedding_matrix=embedding_matrix, 
        embedding_train=True
    )

    init_weights(
        caption_model,
        embedding_pretrained=True
    )

    caption_model.to(device)

    # we will ignore the pad token in true target set
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    optimizer = torch.optim.Adam(
        caption_model.parameters(), 
        lr=0.01
    )

    clip = 1
    start = time()

    for i in tqdm(range(EPOCHS * 6)):

        loss = train(caption_model, train_loader, optimizer, criterion, clip, vocab_size)
        print(loss)

    # reduce the learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = 1e-4

    for i in tqdm(range(EPOCHS * 6)):

        loss = train(caption_model, train_loader, optimizer, criterion, clip, vocab_size)
        print(loss)
    return caption_model

In [23]:
def generateCaption(
    model, 
    img_features,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    in_text = START

    for i in range(max_length):

        sequence = [wordtoidx[w] for w in in_text.split() if w in wordtoidx]
        sequence = np.pad(sequence, (0, max_length - len(sequence)),
                          mode='constant', constant_values=(0, 0))
        model.eval()
        yhat = model(
            torch.FloatTensor(img_features)\
            .view(-1, model.feature_size).to(device),
            torch.LongTensor(sequence).view(-1, max_length).to(device)
        )

        yhat = yhat.view(-1, vocab_size).argmax(1)
        word = idxtoword[yhat.cpu().data.numpy()[i]]
        in_text += ' ' + word
        if word == STOP:
            break
    final = in_text.split()
    final = final[1 : -1]
    final = ' '.join(final)
    return final

### Evaluation

In [24]:

def eval_model(ref_data, results):
    """
    Computes evaluation metrics of the model results against the human annotated captions
    
    Parameters:
    ------------
    ref_data: dict
        a dictionary containing human annotated captions, with image name as key and a list of human annotated captions as values
    
    results: dict
        a dictionary containing model generated caption, with image name as key and a generated caption as value
        
    Returns:
    ------------
    score_dict: a dictionary containing the overall average score for the model
    """
    # download stanford nlp library
    subprocess.call(['../../scr/evaluation/get_stanford_models.sh'])
    
    # format the inputs
    gts = {}
    res = {}

    for imgId in range(len(ref_data)):
        caption_list_sel = []
        for i in range(5):
            lst = {}
            lst['caption'] = ref_data[imgId][i]
            lst['image_id'] = imgId
            lst['id'] = i
            caption_list_sel.append(lst)
        gts[imgId] = caption_list_sel

        res[imgId] = [{'caption': results[imgId]}]
        
    # tokenize
    print('tokenization...')
    tokenizer = PTBTokenizer()
    gts  = tokenizer.tokenize(gts)
    res = tokenizer.tokenize(res)
    
    # compute scores
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(),"METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
        (Spice(), "SPICE"),
        (usc_sim(), "USC_similarity"),  
        ]
    score_dict = {}
    for scorer, method in scorers:
        print('computing %s score...'%(scorer.method()))
        score, scores = scorer.compute_score(gts, res)
        if type(method) == list:
            for sc, scs, m in zip(score, scores, method):
                score_dict[m] = sc
        else:
            score_dict[method] = score
            
    return score_dict


In [25]:
def evaluate_results(
    test_img_features, 
    model,
    ref,
    max_length,
    vocab_size,
    wordtoidx,
    idxtoword
):
    # generate results
    print('Generating captions...')
    results = {}
    for n in range(len(test_img_features)):
        img_features = test_img_features[n]
        generated = generateCaption(
            model, 
            img_features,
            max_length,
            vocab_size,
            wordtoidx,
            idxtoword
        )
        results[n] = generated
        
    model_score = eval_model(ref, results)

    return model_score

### Cross validation

In [None]:
cnn_type = 'inception_v3'
encoder = CNNModel(cnn_type, pretrained=True)
encoder.to(device)

In [27]:
def cross_validation(train_index, test_index, count):
    print('=' * 60)
    print(f'Split {count}:')
    print(f'Splitting data...')
    
    train_paths, test_paths = all_paths[train_index], all_paths[test_index]
    train_descriptions, test_descriptions = all_descriptions[train_index], all_descriptions[test_index]
    print(f'{len(train_paths)} images for training and {len(test_paths)} images for testing.')
    
    vocab = get_vocab(train_descriptions, word_count_threshold=10)
    idxtoword, wordtoidx = get_word_dict(vocab)
    vocab_size = get_vocab_size(idxtoword)
    embedding_dim = 200
    embedding_matrix = get_embeddings(embeddings_index, vocab_size, embedding_dim, wordtoidx) 

    print(f'Preparing dataloader...')
    train_img_features, test_img_features = get_train_test(encoder, train_paths, test_paths)

    train_loader = get_train_dataloader(
        train_descriptions, 
        train_img_features,
        wordtoidx,
        max_length,
        batch_size=200
    )

    print(f'Training...')
    caption_model = train_model(
        train_loader,
        vocab_size,
        embedding_dim, 
        embedding_matrix
    )

    
    ref = captions[test_index]
    model_score = evaluate_results(
        test_img_features, 
        caption_model,
        ref,
        max_length,
        vocab_size,
        wordtoidx,
        idxtoword
    )
    
    return caption_model, model_score

In [28]:
cv = KFold(n_splits=5, random_state=123, shuffle=True)
cv = [(train_index, test_index) for train_index, test_index in cv.split(all_paths)]  

In [29]:
caption_model1, model_score1 = cross_validation(cv[0][0], cv[0][1], 1)    

Split 1:
Splitting data...
8332 images for training and 2084 images for testing.
There are 41660 captions
preprocessed words 2659 ==> 884
The vocabulary size is 885.
796 out of 885 words are found in the pre-trained matrix.
The size of embedding_matrix is (885, 200)
Preparing dataloader...

Generating set took: 0:03:27.32


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:52.02
Training...


  2%|▏         | 1/60 [00:12<12:11, 12.41s/it]

4.182570758320036


  3%|▎         | 2/60 [00:24<11:59, 12.40s/it]

2.3737904684884206


  5%|▌         | 3/60 [00:37<11:47, 12.41s/it]

1.788533179532914


  7%|▋         | 4/60 [00:49<11:34, 12.41s/it]

1.5672798327037267


  8%|▊         | 5/60 [01:02<11:22, 12.41s/it]

1.454969389098031


 10%|█         | 6/60 [01:14<11:09, 12.41s/it]

1.3811053860755194


 12%|█▏        | 7/60 [01:26<10:58, 12.42s/it]

1.3303789411272322


 13%|█▎        | 8/60 [01:39<10:45, 12.42s/it]

1.290237824122111


 15%|█▌        | 9/60 [01:51<10:32, 12.41s/it]

1.251978073801313


 17%|█▋        | 10/60 [02:04<10:20, 12.42s/it]

1.2223059336344402


 18%|█▊        | 11/60 [02:16<10:08, 12.42s/it]

1.1980939223652793


 20%|██        | 12/60 [02:28<09:56, 12.42s/it]

1.1809004488445463


 22%|██▏       | 13/60 [02:41<09:43, 12.42s/it]

1.1674254877226693


 23%|██▎       | 14/60 [02:53<09:31, 12.42s/it]

1.150037598042261


 25%|██▌       | 15/60 [03:06<09:18, 12.41s/it]

1.1362430424917311


 27%|██▋       | 16/60 [03:18<09:05, 12.41s/it]

1.1253308086168199


 28%|██▊       | 17/60 [03:30<08:53, 12.40s/it]

1.1132948994636536


 30%|███       | 18/60 [03:43<08:41, 12.40s/it]

1.100418896902175


 32%|███▏      | 19/60 [03:55<08:28, 12.40s/it]

1.0886233846346538


 33%|███▎      | 20/60 [04:08<08:16, 12.40s/it]

1.0782432967708224


 35%|███▌      | 21/60 [04:20<08:03, 12.41s/it]

1.0661887058189936


 37%|███▋      | 22/60 [04:33<07:51, 12.41s/it]

1.0544327000776927


 38%|███▊      | 23/60 [04:45<07:39, 12.41s/it]

1.0471297005812328


 40%|████      | 24/60 [04:57<07:26, 12.41s/it]

1.0423909042562758


 42%|████▏     | 25/60 [05:10<07:14, 12.41s/it]

1.0359007091749282


 43%|████▎     | 26/60 [05:22<07:02, 12.42s/it]

1.0364965966769628


 45%|████▌     | 27/60 [05:35<06:50, 12.43s/it]

1.0418959529626937


 47%|████▋     | 28/60 [05:47<06:37, 12.43s/it]

1.0347329860641843


 48%|████▊     | 29/60 [06:00<06:25, 12.44s/it]

1.0303222253209068


 50%|█████     | 30/60 [06:12<06:13, 12.44s/it]

1.0147268190270378


 52%|█████▏    | 31/60 [06:24<06:00, 12.44s/it]

1.0013734556379772


 53%|█████▎    | 32/60 [06:37<05:48, 12.44s/it]

0.9940886256240663


 55%|█████▌    | 33/60 [06:49<05:35, 12.44s/it]

0.9893362678232647


 57%|█████▋    | 34/60 [07:02<05:23, 12.44s/it]

0.9864191966397422


 58%|█████▊    | 35/60 [07:14<05:10, 12.44s/it]

0.9852754033747173


 60%|██████    | 36/60 [07:27<04:58, 12.43s/it]

0.9841237422965822


 62%|██████▏   | 37/60 [07:39<04:45, 12.43s/it]

0.980148834841592


 63%|██████▎   | 38/60 [07:51<04:33, 12.43s/it]

0.9854782294659388


 65%|██████▌   | 39/60 [08:04<04:20, 12.43s/it]

0.9848228693008423


 67%|██████▋   | 40/60 [08:16<04:08, 12.42s/it]

0.9733674001126063


 68%|██████▊   | 41/60 [08:29<03:56, 12.43s/it]

0.966656506061554


 70%|███████   | 42/60 [08:41<03:43, 12.43s/it]

0.9679365299996876


 72%|███████▏  | 43/60 [08:54<03:31, 12.43s/it]

0.9655981361865997


 73%|███████▎  | 44/60 [09:06<03:18, 12.42s/it]

0.9662466162726993


 75%|███████▌  | 45/60 [09:18<03:06, 12.43s/it]

0.9666620265869867


 77%|███████▋  | 46/60 [09:31<02:54, 12.45s/it]

0.9658543495904832


 78%|███████▊  | 47/60 [09:43<02:41, 12.44s/it]

0.9636464629854474


 80%|████████  | 48/60 [09:56<02:29, 12.44s/it]

0.9605093044894082


 82%|████████▏ | 49/60 [10:08<02:16, 12.44s/it]

0.9638115196000963


 83%|████████▎ | 50/60 [10:21<02:04, 12.44s/it]

0.9616620384511494


 85%|████████▌ | 51/60 [10:33<01:51, 12.43s/it]

0.9608509143193563


 87%|████████▋ | 52/60 [10:46<01:39, 12.43s/it]

0.9571442348616463


 88%|████████▊ | 53/60 [10:58<01:27, 12.43s/it]

0.9481313938186282


 90%|█████████ | 54/60 [11:10<01:14, 12.43s/it]

0.9453218806357611


 92%|█████████▏| 55/60 [11:23<01:02, 12.43s/it]

0.9418367842833201


 93%|█████████▎| 56/60 [11:35<00:49, 12.45s/it]

0.944669246673584


 95%|█████████▌| 57/60 [11:48<00:37, 12.44s/it]

0.9434159568377903


 97%|█████████▋| 58/60 [12:00<00:24, 12.44s/it]

0.9433150603657677


 98%|█████████▊| 59/60 [12:13<00:12, 12.43s/it]

0.9367710947990417


100%|██████████| 60/60 [12:25<00:00, 12.42s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.9339388722465152


  2%|▏         | 1/60 [00:12<12:14, 12.46s/it]

0.8980559068066734


  3%|▎         | 2/60 [00:24<12:02, 12.45s/it]

0.8758343770390465


  5%|▌         | 3/60 [00:37<11:49, 12.45s/it]

0.8644798539933705


  7%|▋         | 4/60 [00:49<11:36, 12.44s/it]

0.8602245932533628


  8%|▊         | 5/60 [01:02<11:23, 12.43s/it]

0.8539146227496011


 10%|█         | 6/60 [01:14<11:11, 12.43s/it]

0.8521424829959869


 12%|█▏        | 7/60 [01:27<10:58, 12.42s/it]

0.8494699725082943


 13%|█▎        | 8/60 [01:39<10:46, 12.42s/it]

0.8461907832395463


 15%|█▌        | 9/60 [01:51<10:33, 12.42s/it]

0.8436059795674824


 17%|█▋        | 10/60 [02:04<10:21, 12.42s/it]

0.8413603547073546


 18%|█▊        | 11/60 [02:16<10:08, 12.42s/it]

0.8406376654193515


 20%|██        | 12/60 [02:29<09:56, 12.43s/it]

0.8386440518356505


 22%|██▏       | 13/60 [02:41<09:44, 12.43s/it]

0.8373598640873319


 23%|██▎       | 14/60 [02:53<09:31, 12.43s/it]

0.8350274037747156


 25%|██▌       | 15/60 [03:06<09:19, 12.43s/it]

0.8341100442977178


 27%|██▋       | 16/60 [03:18<09:06, 12.43s/it]

0.8323320916720799


 28%|██▊       | 17/60 [03:31<08:54, 12.43s/it]

0.8306259981223515


 30%|███       | 18/60 [03:43<08:41, 12.42s/it]

0.828258481763658


 32%|███▏      | 19/60 [03:56<08:29, 12.43s/it]

0.8281600063755399


 33%|███▎      | 20/60 [04:08<08:16, 12.42s/it]

0.8288232897009168


 35%|███▌      | 21/60 [04:20<08:04, 12.42s/it]

0.8270593753882817


 37%|███▋      | 22/60 [04:33<07:52, 12.42s/it]

0.8260565442698342


 38%|███▊      | 23/60 [04:45<07:39, 12.43s/it]

0.824696211587815


 40%|████      | 24/60 [04:58<07:27, 12.43s/it]

0.8263537500585828


 42%|████▏     | 25/60 [05:10<07:15, 12.43s/it]

0.8231681031840188


 43%|████▎     | 26/60 [05:23<07:02, 12.43s/it]

0.8216385756220136


 45%|████▌     | 27/60 [05:35<06:50, 12.43s/it]

0.8224901599543435


 47%|████▋     | 28/60 [05:47<06:38, 12.44s/it]

0.8209827442963918


 48%|████▊     | 29/60 [06:00<06:25, 12.44s/it]

0.8201577464739481


 50%|█████     | 30/60 [06:12<06:13, 12.44s/it]

0.8189452077661242


 52%|█████▏    | 31/60 [06:25<06:00, 12.44s/it]

0.8195005683671861


 53%|█████▎    | 32/60 [06:37<05:48, 12.44s/it]

0.8188972387995038


 55%|█████▌    | 33/60 [06:50<05:35, 12.43s/it]

0.8169329904374623


 57%|█████▋    | 34/60 [07:02<05:23, 12.44s/it]

0.8176535282816205


 58%|█████▊    | 35/60 [07:15<05:10, 12.44s/it]

0.8164116669268835


 60%|██████    | 36/60 [07:27<04:58, 12.44s/it]

0.8138195517517272


 62%|██████▏   | 37/60 [07:39<04:46, 12.44s/it]

0.8149354997135344


 63%|██████▎   | 38/60 [07:52<04:33, 12.44s/it]

0.815101492972601


 65%|██████▌   | 39/60 [08:04<04:21, 12.45s/it]

0.814050725528172


 67%|██████▋   | 40/60 [08:17<04:08, 12.45s/it]

0.8132522035212744


 68%|██████▊   | 41/60 [08:29<03:56, 12.45s/it]

0.8124964662960598


 70%|███████   | 42/60 [08:42<03:44, 12.45s/it]

0.8112550377845764


 72%|███████▏  | 43/60 [08:54<03:31, 12.44s/it]

0.8120198462690625


 73%|███████▎  | 44/60 [09:07<03:19, 12.45s/it]

0.8092736417338962


 75%|███████▌  | 45/60 [09:19<03:06, 12.46s/it]

0.8093039350850242


 77%|███████▋  | 46/60 [09:32<02:54, 12.46s/it]

0.8092445958228338


 78%|███████▊  | 47/60 [09:44<02:41, 12.45s/it]

0.8080305584839412


 80%|████████  | 48/60 [09:56<02:29, 12.45s/it]

0.8088401257991791


 82%|████████▏ | 49/60 [10:09<02:16, 12.44s/it]

0.8072196003936586


 83%|████████▎ | 50/60 [10:21<02:04, 12.45s/it]

0.8071638757274264


 85%|████████▌ | 51/60 [10:34<01:52, 12.45s/it]

0.8071197328113374


 87%|████████▋ | 52/60 [10:46<01:39, 12.44s/it]

0.8066972096761068


 88%|████████▊ | 53/60 [10:59<01:27, 12.44s/it]

0.8063559787614005


 90%|█████████ | 54/60 [11:11<01:14, 12.45s/it]

0.804940409603573


 92%|█████████▏| 55/60 [11:23<01:02, 12.44s/it]

0.805299828449885


 93%|█████████▎| 56/60 [11:36<00:49, 12.43s/it]

0.8044457520757403


 95%|█████████▌| 57/60 [11:48<00:37, 12.43s/it]

0.803552644593375


 97%|█████████▋| 58/60 [12:01<00:24, 12.42s/it]

0.8038398736999148


 98%|█████████▊| 59/60 [12:13<00:12, 12.42s/it]

0.8032894446736291


100%|██████████| 60/60 [12:26<00:00, 12.43s/it]

0.8012692488375164
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [30]:
model_score1

{'Bleu_1': 0.6374539978750372,
 'Bleu_2': 0.5097287325539306,
 'Bleu_3': 0.42603430037193946,
 'Bleu_4': 0.3651897571307535,
 'METEOR': 0.29132363875100775,
 'ROUGE_L': 0.5424600890989629,
 'CIDEr': 2.039759266574518,
 'SPICE': 0.3875159970259289,
 'USC_similarity': 0.600431677240578}

In [31]:
caption_model2, model_score2 = cross_validation(cv[1][0], cv[1][1], 2)    

Split 2:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2688 ==> 916
The vocabulary size is 917.
822 out of 917 words are found in the pre-trained matrix.
The size of embedding_matrix is (917, 200)
Preparing dataloader...

Generating set took: 0:03:29.32


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:52.13
Training...


  2%|▏         | 1/60 [00:12<12:12, 12.42s/it]

5.08402114822751


  3%|▎         | 2/60 [00:24<12:00, 12.42s/it]

2.8801963953744796


  5%|▌         | 3/60 [00:37<11:47, 12.42s/it]

1.994551195984795


  7%|▋         | 4/60 [00:49<11:35, 12.41s/it]

1.6586340296836126


  8%|▊         | 5/60 [01:02<11:22, 12.42s/it]

1.509123560928163


 10%|█         | 6/60 [01:14<11:10, 12.42s/it]

1.425741019703093


 12%|█▏        | 7/60 [01:26<10:58, 12.43s/it]

1.3715726222310747


 13%|█▎        | 8/60 [01:39<10:46, 12.43s/it]

1.3228611548741658


 15%|█▌        | 9/60 [01:51<10:33, 12.42s/it]

1.2845574021339417


 17%|█▋        | 10/60 [02:04<10:21, 12.43s/it]

1.2555282768749056


 18%|█▊        | 11/60 [02:16<10:09, 12.43s/it]

1.2314988772074382


 20%|██        | 12/60 [02:29<09:56, 12.43s/it]

1.215278060663314


 22%|██▏       | 13/60 [02:41<09:44, 12.43s/it]

1.1995048239117576


 23%|██▎       | 14/60 [02:53<09:31, 12.43s/it]

1.1809315993672325


 25%|██▌       | 15/60 [03:06<09:19, 12.44s/it]

1.1634812752405803


 27%|██▋       | 16/60 [03:18<09:07, 12.44s/it]

1.1494551925432115


 28%|██▊       | 17/60 [03:31<08:54, 12.44s/it]

1.1328152120113373


 30%|███       | 18/60 [03:43<08:42, 12.44s/it]

1.1248246899672918


 32%|███▏      | 19/60 [03:56<08:30, 12.44s/it]

1.1197352068764823


 33%|███▎      | 20/60 [04:08<08:18, 12.46s/it]

1.1097600772267295


 35%|███▌      | 21/60 [04:21<08:06, 12.46s/it]

1.0983218720981054


 37%|███▋      | 22/60 [04:33<07:53, 12.47s/it]

1.0932298245884122


 38%|███▊      | 23/60 [04:46<07:41, 12.46s/it]

1.0946208096685863


 40%|████      | 24/60 [04:58<07:28, 12.45s/it]

1.081929284901846


 42%|████▏     | 25/60 [05:10<07:15, 12.45s/it]

1.0717235619113559


 43%|████▎     | 26/60 [05:23<07:03, 12.45s/it]

1.0610216401872181


 45%|████▌     | 27/60 [05:35<06:50, 12.44s/it]

1.0548183478060222


 47%|████▋     | 28/60 [05:48<06:37, 12.44s/it]

1.0514032329831804


 48%|████▊     | 29/60 [06:00<06:25, 12.43s/it]

1.0453860135305495


 50%|█████     | 30/60 [06:13<06:13, 12.44s/it]

1.043930124668848


 52%|█████▏    | 31/60 [06:25<06:00, 12.43s/it]

1.0404107826096671


 53%|█████▎    | 32/60 [06:38<05:48, 12.44s/it]

1.0376664698123932


 55%|█████▌    | 33/60 [06:50<05:35, 12.44s/it]

1.028432616165706


 57%|█████▋    | 34/60 [07:02<05:23, 12.44s/it]

1.0269813040892284


 58%|█████▊    | 35/60 [07:15<05:10, 12.44s/it]

1.0255334632737296


 60%|██████    | 36/60 [07:27<04:58, 12.44s/it]

1.0200171442258925


 62%|██████▏   | 37/60 [07:40<04:45, 12.43s/it]

1.0163421233495076


 63%|██████▎   | 38/60 [07:52<04:33, 12.43s/it]

1.0135828384331294


 65%|██████▌   | 39/60 [08:05<04:21, 12.43s/it]

1.0108714997768402


 67%|██████▋   | 40/60 [08:17<04:08, 12.43s/it]

1.0056992641517095


 68%|██████▊   | 41/60 [08:29<03:56, 12.43s/it]

1.0007821449211665


 70%|███████   | 42/60 [08:42<03:43, 12.42s/it]

0.998654686269306


 72%|███████▏  | 43/60 [08:54<03:31, 12.43s/it]

0.9975147375038692


 73%|███████▎  | 44/60 [09:07<03:18, 12.43s/it]

0.9967692905948276


 75%|███████▌  | 45/60 [09:19<03:06, 12.43s/it]

0.9901414300714221


 77%|███████▋  | 46/60 [09:32<02:54, 12.44s/it]

0.9857585259846279


 78%|███████▊  | 47/60 [09:44<02:41, 12.44s/it]

0.9825409480503627


 80%|████████  | 48/60 [09:56<02:29, 12.44s/it]

0.9786754960105533


 82%|████████▏ | 49/60 [10:09<02:16, 12.45s/it]

0.9769554989678519


 83%|████████▎ | 50/60 [10:21<02:04, 12.44s/it]

0.9738933883962178


 85%|████████▌ | 51/60 [10:34<01:52, 12.45s/it]

0.9731203729198092


 87%|████████▋ | 52/60 [10:46<01:39, 12.44s/it]

0.9694723742348808


 88%|████████▊ | 53/60 [10:59<01:27, 12.46s/it]

0.9666808176608312


 90%|█████████ | 54/60 [11:11<01:14, 12.46s/it]

0.9693839564209893


 92%|█████████▏| 55/60 [11:24<01:02, 12.45s/it]

0.9694560794603257


 93%|█████████▎| 56/60 [11:36<00:49, 12.44s/it]

0.9644463700907571


 95%|█████████▌| 57/60 [11:49<00:37, 12.44s/it]

0.9603814241432008


 97%|█████████▋| 58/60 [12:01<00:24, 12.44s/it]

0.9582350438549405


 98%|█████████▊| 59/60 [12:13<00:12, 12.45s/it]

0.959266551903316


100%|██████████| 60/60 [12:26<00:00, 12.44s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.956176004239491


  2%|▏         | 1/60 [00:12<12:14, 12.46s/it]

0.9220919892901466


  3%|▎         | 2/60 [00:24<12:02, 12.46s/it]

0.9034506323791686


  5%|▌         | 3/60 [00:37<11:49, 12.45s/it]

0.8949057885578701


  7%|▋         | 4/60 [00:49<11:37, 12.45s/it]

0.8883367507230668


  8%|▊         | 5/60 [01:02<11:24, 12.45s/it]

0.8831552721205211


 10%|█         | 6/60 [01:14<11:12, 12.45s/it]

0.8790597518285116


 12%|█▏        | 7/60 [01:27<10:59, 12.45s/it]

0.877114053283419


 13%|█▎        | 8/60 [01:39<10:47, 12.45s/it]

0.8753604761191777


 15%|█▌        | 9/60 [01:52<10:35, 12.46s/it]

0.8713900588807606


 17%|█▋        | 10/60 [02:04<10:22, 12.46s/it]

0.8687995771567026


 18%|█▊        | 11/60 [02:17<10:10, 12.46s/it]

0.8672203833148593


 20%|██        | 12/60 [02:29<09:57, 12.45s/it]

0.8652308441343761


 22%|██▏       | 13/60 [02:41<09:45, 12.45s/it]

0.8650118779568445


 23%|██▎       | 14/60 [02:54<09:32, 12.45s/it]

0.8622052272160848


 25%|██▌       | 15/60 [03:06<09:20, 12.45s/it]

0.8614650780246371


 27%|██▋       | 16/60 [03:19<09:07, 12.45s/it]

0.8610604958874839


 28%|██▊       | 17/60 [03:31<08:55, 12.46s/it]

0.8596803801400321


 30%|███       | 18/60 [03:44<08:43, 12.46s/it]

0.8584636648495992


 32%|███▏      | 19/60 [03:56<08:30, 12.45s/it]

0.8585622381596338


 33%|███▎      | 20/60 [04:09<08:18, 12.45s/it]

0.8558258272352672


 35%|███▌      | 21/60 [04:21<08:05, 12.45s/it]

0.8546919439520154


 37%|███▋      | 22/60 [04:33<07:53, 12.46s/it]

0.8542437340532031


 38%|███▊      | 23/60 [04:46<07:41, 12.46s/it]

0.8539585414386931


 40%|████      | 24/60 [04:58<07:28, 12.46s/it]

0.8528694723333631


 42%|████▏     | 25/60 [05:11<07:15, 12.46s/it]

0.8514463560921806


 43%|████▎     | 26/60 [05:23<07:03, 12.46s/it]

0.8499794332754045


 45%|████▌     | 27/60 [05:36<06:51, 12.47s/it]

0.8513129779270717


 47%|████▋     | 28/60 [05:48<06:39, 12.49s/it]

0.8496763507525126


 48%|████▊     | 29/60 [06:01<06:26, 12.48s/it]

0.8471317376409259


 50%|█████     | 30/60 [06:13<06:13, 12.46s/it]

0.8473122985590071


 52%|█████▏    | 31/60 [06:26<06:01, 12.45s/it]

0.8481544809682029


 53%|█████▎    | 32/60 [06:38<05:48, 12.44s/it]

0.8468825008187976


 55%|█████▌    | 33/60 [06:50<05:35, 12.44s/it]

0.845466666278385


 57%|█████▋    | 34/60 [07:03<05:23, 12.44s/it]

0.8452956321693602


 58%|█████▊    | 35/60 [07:15<05:10, 12.44s/it]

0.8437117920035407


 60%|██████    | 36/60 [07:28<04:58, 12.44s/it]

0.8445922874269032


 62%|██████▏   | 37/60 [07:40<04:45, 12.43s/it]

0.8448778092861176


 63%|██████▎   | 38/60 [07:53<04:33, 12.43s/it]

0.843281070391337


 65%|██████▌   | 39/60 [08:05<04:21, 12.43s/it]

0.8421953036671593


 67%|██████▋   | 40/60 [08:18<04:08, 12.43s/it]

0.8410532034578777


 68%|██████▊   | 41/60 [08:30<03:56, 12.42s/it]

0.8412971794605255


 70%|███████   | 42/60 [08:42<03:43, 12.42s/it]

0.8412044587589446


 72%|███████▏  | 43/60 [08:55<03:31, 12.43s/it]

0.8401697050957453


 73%|███████▎  | 44/60 [09:07<03:19, 12.44s/it]

0.8391065895557404


 75%|███████▌  | 45/60 [09:20<03:06, 12.43s/it]

0.8407781209264483


 77%|███████▋  | 46/60 [09:32<02:53, 12.43s/it]

0.8404676843257177


 78%|███████▊  | 47/60 [09:45<02:41, 12.43s/it]

0.8376400286243075


 80%|████████  | 48/60 [09:57<02:29, 12.43s/it]

0.8381303222406478


 82%|████████▏ | 49/60 [10:09<02:16, 12.43s/it]

0.8367427786191305


 83%|████████▎ | 50/60 [10:22<02:04, 12.43s/it]

0.8377511359396435


 85%|████████▌ | 51/60 [10:34<01:51, 12.44s/it]

0.8366231222947439


 87%|████████▋ | 52/60 [10:47<01:39, 12.44s/it]

0.8351912853263673


 88%|████████▊ | 53/60 [10:59<01:27, 12.44s/it]

0.8357150270825341


 90%|█████████ | 54/60 [11:12<01:14, 12.44s/it]

0.8361598352591196


 92%|█████████▏| 55/60 [11:24<01:02, 12.44s/it]

0.8342445648851848


 93%|█████████▎| 56/60 [11:36<00:49, 12.44s/it]

0.8345076881703877


 95%|█████████▌| 57/60 [11:49<00:37, 12.44s/it]

0.8340944222041539


 97%|█████████▋| 58/60 [12:01<00:24, 12.45s/it]

0.8335570082778022


 98%|█████████▊| 59/60 [12:14<00:12, 12.45s/it]

0.8321406103315807


100%|██████████| 60/60 [12:26<00:00, 12.45s/it]

0.8324610221953619
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [32]:
model_score2

{'Bleu_1': 0.6232993589447667,
 'Bleu_2': 0.4915525450592119,
 'Bleu_3': 0.4058396978483188,
 'Bleu_4': 0.34568139609959103,
 'METEOR': 0.2809060472919057,
 'ROUGE_L': 0.5196458203981928,
 'CIDEr': 1.9593216133220972,
 'SPICE': 0.3776395545575529,
 'USC_similarity': 0.5946840881589985}

In [33]:
caption_model3, model_score3 = cross_validation(cv[2][0], cv[2][1], 3)    

Split 3:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2714 ==> 890
The vocabulary size is 891.
804 out of 891 words are found in the pre-trained matrix.
The size of embedding_matrix is (891, 200)
Preparing dataloader...

Generating set took: 0:03:30.51


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:52.21
Training...


  2%|▏         | 1/60 [00:12<12:15, 12.47s/it]

5.129976845922924


  3%|▎         | 2/60 [00:24<12:02, 12.46s/it]

2.6352923881439936


  5%|▌         | 3/60 [00:37<11:49, 12.45s/it]

1.8709557851155598


  7%|▋         | 4/60 [00:49<11:37, 12.46s/it]

1.6080446072987147


  8%|▊         | 5/60 [01:02<11:24, 12.45s/it]

1.485854943593343


 10%|█         | 6/60 [01:14<11:11, 12.44s/it]

1.4084124167760212


 12%|█▏        | 7/60 [01:27<10:59, 12.44s/it]

1.352670499256679


 13%|█▎        | 8/60 [01:39<10:47, 12.45s/it]

1.3113455289886111


 15%|█▌        | 9/60 [01:52<10:35, 12.46s/it]

1.281318037282853


 17%|█▋        | 10/60 [02:04<10:22, 12.45s/it]

1.2549243172009785


 18%|█▊        | 11/60 [02:16<10:10, 12.45s/it]

1.231392653215499


 20%|██        | 12/60 [02:29<09:57, 12.46s/it]

1.2093775045304072


 22%|██▏       | 13/60 [02:41<09:45, 12.45s/it]

1.1913392969540186


 23%|██▎       | 14/60 [02:54<09:32, 12.44s/it]

1.1756759541375297


 25%|██▌       | 15/60 [03:06<09:19, 12.43s/it]

1.1654229249273027


 27%|██▋       | 16/60 [03:19<09:07, 12.44s/it]

1.1643128565379552


 28%|██▊       | 17/60 [03:31<08:55, 12.46s/it]

1.1450757810047694


 30%|███       | 18/60 [03:44<08:43, 12.46s/it]

1.1239518849622636


 32%|███▏      | 19/60 [03:56<08:30, 12.46s/it]

1.1108570127260118


 33%|███▎      | 20/60 [04:09<08:18, 12.47s/it]

1.1008361138048626


 35%|███▌      | 21/60 [04:21<08:06, 12.46s/it]

1.0892404899710701


 37%|███▋      | 22/60 [04:33<07:53, 12.46s/it]

1.0789118622030531


 38%|███▊      | 23/60 [04:46<07:41, 12.46s/it]

1.0730501782326471


 40%|████      | 24/60 [04:58<07:28, 12.47s/it]

1.0692016283671062


 42%|████▏     | 25/60 [05:11<07:16, 12.46s/it]

1.0687812467416127


 43%|████▎     | 26/60 [05:23<07:03, 12.45s/it]

1.0706790004457747


 45%|████▌     | 27/60 [05:36<06:50, 12.43s/it]

1.0739876996903193


 47%|████▋     | 28/60 [05:48<06:38, 12.44s/it]

1.0623810915719896


 48%|████▊     | 29/60 [06:01<06:26, 12.45s/it]

1.048188832544145


 50%|█████     | 30/60 [06:13<06:13, 12.45s/it]

1.0418623302664076


 52%|█████▏    | 31/60 [06:25<06:00, 12.44s/it]

1.0375002991585505


 53%|█████▎    | 32/60 [06:38<05:48, 12.45s/it]

1.0339627549761818


 55%|█████▌    | 33/60 [06:50<05:37, 12.49s/it]

1.0307996273040771


 57%|█████▋    | 34/60 [07:03<05:24, 12.48s/it]

1.026764540445237


 58%|█████▊    | 35/60 [07:15<05:11, 12.48s/it]

1.02229810044879


 60%|██████    | 36/60 [07:28<04:59, 12.46s/it]

1.0153045867170607


 62%|██████▏   | 37/60 [07:40<04:46, 12.45s/it]

1.0138718372299558


 63%|██████▎   | 38/60 [07:53<04:33, 12.45s/it]

1.0076842634450822


 65%|██████▌   | 39/60 [08:05<04:21, 12.46s/it]

1.006029920918601


 67%|██████▋   | 40/60 [08:18<04:09, 12.46s/it]

1.0031062009788694


 68%|██████▊   | 41/60 [08:30<03:56, 12.45s/it]

1.0001801678112574


 70%|███████   | 42/60 [08:43<03:44, 12.44s/it]

0.9965316369420006


 72%|███████▏  | 43/60 [08:55<03:31, 12.44s/it]

0.9915829414413089


 73%|███████▎  | 44/60 [09:07<03:19, 12.45s/it]

0.9909186278070722


 75%|███████▌  | 45/60 [09:20<03:06, 12.44s/it]

0.9883001120317549


 77%|███████▋  | 46/60 [09:32<02:54, 12.45s/it]

0.9834977629638854


 78%|███████▊  | 47/60 [09:45<02:41, 12.46s/it]

0.98433613493329


 80%|████████  | 48/60 [09:57<02:29, 12.45s/it]

0.9929761404082889


 82%|████████▏ | 49/60 [10:10<02:16, 12.45s/it]

0.9919636944929758


 83%|████████▎ | 50/60 [10:22<02:04, 12.44s/it]

0.9888607930569422


 85%|████████▌ | 51/60 [10:35<01:51, 12.44s/it]

0.9924122478280749


 87%|████████▋ | 52/60 [10:47<01:39, 12.44s/it]

0.990456644977842


 88%|████████▊ | 53/60 [10:59<01:27, 12.45s/it]

0.9822497339475722


 90%|█████████ | 54/60 [11:12<01:14, 12.44s/it]

0.9757189779054551


 92%|█████████▏| 55/60 [11:24<01:02, 12.44s/it]

0.9723517923128038


 93%|█████████▎| 56/60 [11:37<00:49, 12.44s/it]

0.9703089083944049


 95%|█████████▌| 57/60 [11:49<00:37, 12.44s/it]

0.9699232961450305


 97%|█████████▋| 58/60 [12:02<00:24, 12.43s/it]

0.9704673403785342


 98%|█████████▊| 59/60 [12:14<00:12, 12.44s/it]

0.9691362196490878


100%|██████████| 60/60 [12:27<00:00, 12.45s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.9697241726375762


  2%|▏         | 1/60 [00:12<12:13, 12.43s/it]

0.9349783786705562


  3%|▎         | 2/60 [00:24<12:01, 12.43s/it]

0.9082580577759516


  5%|▌         | 3/60 [00:37<11:49, 12.44s/it]

0.897694875796636


  7%|▋         | 4/60 [00:49<11:36, 12.44s/it]

0.8902328525270734


  8%|▊         | 5/60 [01:02<11:23, 12.43s/it]

0.8868971098036993


 10%|█         | 6/60 [01:14<11:11, 12.43s/it]

0.8835585117340088


 12%|█▏        | 7/60 [01:27<10:59, 12.44s/it]

0.8808554737340837


 13%|█▎        | 8/60 [01:39<10:47, 12.44s/it]

0.8776982625325521


 15%|█▌        | 9/60 [01:51<10:34, 12.44s/it]

0.8766997626849583


 17%|█▋        | 10/60 [02:04<10:21, 12.44s/it]

0.8731288086800348


 18%|█▊        | 11/60 [02:16<10:09, 12.43s/it]

0.8711788739476886


 20%|██        | 12/60 [02:29<09:56, 12.43s/it]

0.8694308130514055


 22%|██▏       | 13/60 [02:41<09:44, 12.43s/it]

0.8681636523632776


 23%|██▎       | 14/60 [02:54<09:31, 12.42s/it]

0.867511632896605


 25%|██▌       | 15/60 [03:06<09:19, 12.42s/it]

0.8650758323215303


 27%|██▋       | 16/60 [03:18<09:06, 12.42s/it]

0.8641471876984551


 28%|██▊       | 17/60 [03:31<08:54, 12.42s/it]

0.8637098826113201


 30%|███       | 18/60 [03:43<08:41, 12.42s/it]

0.8626090983549753


 32%|███▏      | 19/60 [03:56<08:29, 12.42s/it]

0.8606143579596565


 33%|███▎      | 20/60 [04:08<08:16, 12.42s/it]

0.8600857101735615


 35%|███▌      | 21/60 [04:21<08:05, 12.46s/it]

0.8592687717505864


 37%|███▋      | 22/60 [04:33<07:53, 12.45s/it]

0.8581525683403015


 38%|███▊      | 23/60 [04:46<07:40, 12.45s/it]

0.8561326534975142


 40%|████      | 24/60 [04:58<07:28, 12.47s/it]

0.8566004804202488


 42%|████▏     | 25/60 [05:10<07:16, 12.47s/it]

0.8549493778319586


 43%|████▎     | 26/60 [05:23<07:03, 12.46s/it]

0.8546434782800221


 45%|████▌     | 27/60 [05:35<06:51, 12.46s/it]

0.8535962785993304


 47%|████▋     | 28/60 [05:48<06:38, 12.45s/it]

0.8516514840580168


 48%|████▊     | 29/60 [06:00<06:25, 12.45s/it]

0.8509184604599362


 50%|█████     | 30/60 [06:13<06:13, 12.44s/it]

0.8510868975094387


 52%|█████▏    | 31/60 [06:25<06:00, 12.44s/it]

0.8494143244766054


 53%|█████▎    | 32/60 [06:38<05:48, 12.44s/it]

0.8496061818940299


 55%|█████▌    | 33/60 [06:50<05:35, 12.44s/it]

0.8491092835153852


 57%|█████▋    | 34/60 [07:02<05:23, 12.44s/it]

0.8485792108944484


 58%|█████▊    | 35/60 [07:15<05:11, 12.44s/it]

0.8462718725204468


 60%|██████    | 36/60 [07:27<04:58, 12.44s/it]

0.8456621170043945


 62%|██████▏   | 37/60 [07:40<04:45, 12.43s/it]

0.8461239905584426


 63%|██████▎   | 38/60 [07:52<04:33, 12.44s/it]

0.8440118119830177


 65%|██████▌   | 39/60 [08:05<04:21, 12.44s/it]

0.8460045896825337


 67%|██████▋   | 40/60 [08:17<04:08, 12.44s/it]

0.8448374243009658


 68%|██████▊   | 41/60 [08:30<03:56, 12.44s/it]

0.8432750276156834


 70%|███████   | 42/60 [08:42<03:43, 12.44s/it]

0.8425240403129941


 72%|███████▏  | 43/60 [08:54<03:31, 12.43s/it]

0.8435572910876501


 73%|███████▎  | 44/60 [09:07<03:18, 12.43s/it]

0.8407397681758517


 75%|███████▌  | 45/60 [09:19<03:06, 12.44s/it]

0.8414544534115564


 77%|███████▋  | 46/60 [09:32<02:54, 12.44s/it]

0.8421194226968856


 78%|███████▊  | 47/60 [09:44<02:41, 12.43s/it]

0.8404258489608765


 80%|████████  | 48/60 [09:57<02:29, 12.44s/it]

0.8404197210357303


 82%|████████▏ | 49/60 [10:09<02:16, 12.43s/it]

0.8390609508468991


 83%|████████▎ | 50/60 [10:21<02:04, 12.42s/it]

0.8398087663309914


 85%|████████▌ | 51/60 [10:34<01:51, 12.43s/it]

0.8380925428299677


 87%|████████▋ | 52/60 [10:46<01:39, 12.43s/it]

0.8375817977246784


 88%|████████▊ | 53/60 [10:59<01:26, 12.43s/it]

0.836236488251459


 90%|█████████ | 54/60 [11:11<01:14, 12.43s/it]

0.8371279665402004


 92%|█████████▏| 55/60 [11:24<01:02, 12.44s/it]

0.8375012079874674


 93%|█████████▎| 56/60 [11:36<00:49, 12.44s/it]

0.8364103535811106


 95%|█████████▌| 57/60 [11:48<00:37, 12.44s/it]

0.835013713155474


 97%|█████████▋| 58/60 [12:01<00:24, 12.44s/it]

0.8348964835916247


 98%|█████████▊| 59/60 [12:13<00:12, 12.44s/it]

0.8348967234293619


100%|██████████| 60/60 [12:26<00:00, 12.44s/it]

0.8345382539998918
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [34]:
model_score3

{'Bleu_1': 0.6346993011263479,
 'Bleu_2': 0.5067599152598399,
 'Bleu_3': 0.423500759390165,
 'Bleu_4': 0.3635006872983065,
 'METEOR': 0.28896062185124893,
 'ROUGE_L': 0.5368304034660417,
 'CIDEr': 2.06688548454554,
 'SPICE': 0.38560690616565724,
 'USC_similarity': 0.602588351841424}

In [35]:
caption_model4, model_score4 = cross_validation(cv[3][0], cv[3][1], 4)    

Split 4:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2680 ==> 894
The vocabulary size is 895.
809 out of 895 words are found in the pre-trained matrix.
The size of embedding_matrix is (895, 200)
Preparing dataloader...

Generating set took: 0:03:28.59


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:52.05
Training...


  2%|▏         | 1/60 [00:12<12:11, 12.40s/it]

4.710105055854434


  3%|▎         | 2/60 [00:24<11:59, 12.40s/it]

2.433500954083034


  5%|▌         | 3/60 [00:37<11:47, 12.40s/it]

1.7926395790917533


  7%|▋         | 4/60 [00:49<11:34, 12.41s/it]

1.5621429426329476


  8%|▊         | 5/60 [01:02<11:23, 12.42s/it]

1.4463493852388292


 10%|█         | 6/60 [01:14<11:11, 12.43s/it]

1.3723340006101699


 12%|█▏        | 7/60 [01:26<10:59, 12.43s/it]

1.3202649780682154


 13%|█▎        | 8/60 [01:39<10:46, 12.44s/it]

1.2838689968699502


 15%|█▌        | 9/60 [01:51<10:34, 12.44s/it]

1.2516410975229173


 17%|█▋        | 10/60 [02:04<10:21, 12.44s/it]

1.2176026815459842


 18%|█▊        | 11/60 [02:16<10:09, 12.43s/it]

1.1942871241342454


 20%|██        | 12/60 [02:29<09:56, 12.43s/it]

1.1728773656345548


 22%|██▏       | 13/60 [02:41<09:43, 12.42s/it]

1.1595100050880796


 23%|██▎       | 14/60 [02:53<09:31, 12.42s/it]

1.1417292384874254


 25%|██▌       | 15/60 [03:06<09:19, 12.42s/it]

1.128582877772195


 27%|██▋       | 16/60 [03:18<09:06, 12.42s/it]

1.1212619997206188


 28%|██▊       | 17/60 [03:31<08:54, 12.42s/it]

1.1056971535796212


 30%|███       | 18/60 [03:43<08:41, 12.42s/it]

1.0905448638257527


 32%|███▏      | 19/60 [03:56<08:29, 12.42s/it]

1.0724080148197355


 33%|███▎      | 20/60 [04:08<08:16, 12.42s/it]

1.0632259207112449


 35%|███▌      | 21/60 [04:20<08:04, 12.42s/it]

1.0563615305083138


 37%|███▋      | 22/60 [04:33<07:51, 12.42s/it]

1.0560574630896251


 38%|███▊      | 23/60 [04:45<07:39, 12.42s/it]

1.0445942140760875


 40%|████      | 24/60 [04:58<07:27, 12.42s/it]

1.0391503317015511


 42%|████▏     | 25/60 [05:10<07:14, 12.43s/it]

1.0335163403125036


 43%|████▎     | 26/60 [05:23<07:02, 12.42s/it]

1.0301575022084373


 45%|████▌     | 27/60 [05:35<06:49, 12.42s/it]

1.029051423072815


 47%|████▋     | 28/60 [05:47<06:37, 12.41s/it]

1.0239761769771576


 48%|████▊     | 29/60 [06:00<06:24, 12.41s/it]

1.0166866992201125


 50%|█████     | 30/60 [06:12<06:12, 12.41s/it]

1.010642300049464


 52%|█████▏    | 31/60 [06:25<05:59, 12.41s/it]

1.0100897153218586


 53%|█████▎    | 32/60 [06:37<05:47, 12.41s/it]

1.0059295225711096


 55%|█████▌    | 33/60 [06:49<05:35, 12.41s/it]

0.9993539466744378


 57%|█████▋    | 34/60 [07:02<05:22, 12.41s/it]

0.9989523788293203


 58%|█████▊    | 35/60 [07:14<05:10, 12.41s/it]

0.9907059839793614


 60%|██████    | 36/60 [07:27<04:57, 12.41s/it]

0.9829332260858445


 62%|██████▏   | 37/60 [07:39<04:45, 12.41s/it]

0.9812400241692861


 63%|██████▎   | 38/60 [07:51<04:33, 12.42s/it]

0.9761856936273121


 65%|██████▌   | 39/60 [08:04<04:20, 12.42s/it]

0.9716450557822273


 67%|██████▋   | 40/60 [08:16<04:08, 12.43s/it]

0.9661568289711362


 68%|██████▊   | 41/60 [08:29<03:56, 12.43s/it]

0.9606773881685167


 70%|███████   | 42/60 [08:41<03:43, 12.42s/it]

0.9587236259664808


 72%|███████▏  | 43/60 [08:54<03:31, 12.42s/it]

0.9541039608773731


 73%|███████▎  | 44/60 [09:06<03:18, 12.42s/it]

0.9554754112448011


 75%|███████▌  | 45/60 [09:18<03:06, 12.42s/it]

0.9655640352339971


 77%|███████▋  | 46/60 [09:31<02:53, 12.43s/it]

0.9782209311212812


 78%|███████▊  | 47/60 [09:43<02:41, 12.43s/it]

0.972030588558742


 80%|████████  | 48/60 [09:56<02:29, 12.44s/it]

0.9588965404601324


 82%|████████▏ | 49/60 [10:08<02:16, 12.44s/it]

0.9479495471432096


 83%|████████▎ | 50/60 [10:21<02:04, 12.45s/it]

0.9410500327746073


 85%|████████▌ | 51/60 [10:33<01:52, 12.46s/it]

0.9358720282713572


 87%|████████▋ | 52/60 [10:46<01:39, 12.46s/it]

0.9328097119217827


 88%|████████▊ | 53/60 [10:58<01:27, 12.45s/it]

0.9331837367443812


 90%|█████████ | 54/60 [11:11<01:14, 12.45s/it]

0.9305304189523061


 92%|█████████▏| 55/60 [11:23<01:02, 12.46s/it]

0.9305258975142524


 93%|█████████▎| 56/60 [11:36<00:49, 12.48s/it]

0.9290963751929147


 95%|█████████▌| 57/60 [11:48<00:37, 12.46s/it]

0.9297347182319278


 97%|█████████▋| 58/60 [12:00<00:24, 12.46s/it]

0.9281992685227167


 98%|█████████▊| 59/60 [12:13<00:12, 12.45s/it]

0.9245906968911489


100%|██████████| 60/60 [12:25<00:00, 12.43s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.9252161255904606


  2%|▏         | 1/60 [00:12<12:14, 12.46s/it]

0.8905430592241741


  3%|▎         | 2/60 [00:24<12:02, 12.45s/it]

0.8666346484706515


  5%|▌         | 3/60 [00:37<11:49, 12.45s/it]

0.8556046727157774


  7%|▋         | 4/60 [00:49<11:37, 12.45s/it]

0.8501715191773006


  8%|▊         | 5/60 [01:02<11:25, 12.46s/it]

0.8464192010107494


 10%|█         | 6/60 [01:14<11:12, 12.45s/it]

0.8422766299474806


 12%|█▏        | 7/60 [01:27<11:00, 12.45s/it]

0.8379917059625898


 13%|█▎        | 8/60 [01:39<10:47, 12.45s/it]

0.8357255515598115


 15%|█▌        | 9/60 [01:52<10:36, 12.49s/it]

0.8323054398809161


 17%|█▋        | 10/60 [02:04<10:23, 12.47s/it]

0.8310184705825079


 18%|█▊        | 11/60 [02:17<10:10, 12.46s/it]

0.8296576809315455


 20%|██        | 12/60 [02:29<09:58, 12.47s/it]

0.8281785732223874


 22%|██▏       | 13/60 [02:42<09:45, 12.47s/it]

0.8267851315793537


 23%|██▎       | 14/60 [02:54<09:33, 12.47s/it]

0.8243965747810545


 25%|██▌       | 15/60 [03:06<09:20, 12.46s/it]

0.8233230043025244


 27%|██▋       | 16/60 [03:19<09:09, 12.49s/it]

0.8220333329268864


 28%|██▊       | 17/60 [03:31<08:56, 12.48s/it]

0.8200168112913767


 30%|███       | 18/60 [03:44<08:43, 12.47s/it]

0.8205416685058957


 32%|███▏      | 19/60 [03:56<08:30, 12.46s/it]

0.8167814342748552


 33%|███▎      | 20/60 [04:09<08:18, 12.45s/it]

0.8183642753532955


 35%|███▌      | 21/60 [04:21<08:05, 12.45s/it]

0.8161772603080386


 37%|███▋      | 22/60 [04:34<07:53, 12.46s/it]

0.8156205387342543


 38%|███▊      | 23/60 [04:46<07:41, 12.47s/it]

0.8157559760979244


 40%|████      | 24/60 [04:59<07:29, 12.48s/it]

0.8130481058642978


 42%|████▏     | 25/60 [05:11<07:16, 12.47s/it]

0.8110865510645366


 43%|████▎     | 26/60 [05:24<07:04, 12.47s/it]

0.8099093096596854


 45%|████▌     | 27/60 [05:36<06:51, 12.47s/it]

0.8094735542933146


 47%|████▋     | 28/60 [05:49<06:38, 12.47s/it]

0.8100155521006811


 48%|████▊     | 29/60 [06:01<06:26, 12.46s/it]

0.8087238499096462


 50%|█████     | 30/60 [06:13<06:13, 12.46s/it]

0.8067663822855268


 52%|█████▏    | 31/60 [06:26<06:01, 12.46s/it]

0.806460759469441


 53%|█████▎    | 32/60 [06:38<05:48, 12.45s/it]

0.8071606357892355


 55%|█████▌    | 33/60 [06:51<05:36, 12.45s/it]

0.8060628076394399


 57%|█████▋    | 34/60 [07:03<05:23, 12.45s/it]

0.8039015503156752


 58%|█████▊    | 35/60 [07:16<05:12, 12.48s/it]

0.8045995732148489


 60%|██████    | 36/60 [07:28<04:59, 12.47s/it]

0.8038031160831451


 62%|██████▏   | 37/60 [07:41<04:46, 12.46s/it]

0.8026276344344729


 63%|██████▎   | 38/60 [07:53<04:34, 12.46s/it]

0.8022025128205618


 65%|██████▌   | 39/60 [08:06<04:21, 12.46s/it]

0.8006049153350648


 67%|██████▋   | 40/60 [08:18<04:09, 12.45s/it]

0.8031875845931825


 68%|██████▊   | 41/60 [08:30<03:56, 12.45s/it]

0.8016402778171358


 70%|███████   | 42/60 [08:43<03:44, 12.45s/it]

0.8010655343532562


 72%|███████▏  | 43/60 [08:55<03:31, 12.44s/it]

0.8002068854513622


 73%|███████▎  | 44/60 [09:08<03:19, 12.45s/it]

0.7990801944619134


 75%|███████▌  | 45/60 [09:20<03:06, 12.45s/it]

0.7982888164974394


 77%|███████▋  | 46/60 [09:33<02:54, 12.45s/it]

0.797758800642831


 78%|███████▊  | 47/60 [09:45<02:41, 12.45s/it]

0.7982397334916251


 80%|████████  | 48/60 [09:58<02:29, 12.44s/it]

0.7953313305264428


 82%|████████▏ | 49/60 [10:10<02:16, 12.44s/it]

0.7959190493538266


 83%|████████▎ | 50/60 [10:22<02:04, 12.44s/it]

0.7962074591999962


 85%|████████▌ | 51/60 [10:35<01:52, 12.45s/it]

0.7952355912753514


 87%|████████▋ | 52/60 [10:47<01:39, 12.45s/it]

0.794240436383656


 88%|████████▊ | 53/60 [11:00<01:27, 12.44s/it]

0.795652911776588


 90%|█████████ | 54/60 [11:12<01:14, 12.44s/it]

0.7933982766809917


 92%|█████████▏| 55/60 [11:25<01:02, 12.44s/it]

0.7928669878414699


 93%|█████████▎| 56/60 [11:37<00:49, 12.44s/it]

0.792652625413168


 95%|█████████▌| 57/60 [11:50<00:37, 12.44s/it]

0.7908003259272802


 97%|█████████▋| 58/60 [12:02<00:24, 12.44s/it]

0.7912927255744026


 98%|█████████▊| 59/60 [12:14<00:12, 12.44s/it]

0.7917830248673757


100%|██████████| 60/60 [12:27<00:00, 12.46s/it]

0.7913799158164433
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [36]:
model_score4

{'Bleu_1': 0.6297681376215113,
 'Bleu_2': 0.5045738328765402,
 'Bleu_3': 0.4225585731564296,
 'Bleu_4': 0.3630833438964103,
 'METEOR': 0.28928258538798995,
 'ROUGE_L': 0.5347993826502163,
 'CIDEr': 1.9861205791804195,
 'SPICE': 0.37538831143312207,
 'USC_similarity': 0.5941384401227842}

In [37]:
caption_model5, model_score5 = cross_validation(cv[4][0], cv[4][1], 5)    

Split 5:
Splitting data...
8333 images for training and 2083 images for testing.
There are 41665 captions
preprocessed words 2657 ==> 905
The vocabulary size is 906.
818 out of 906 words are found in the pre-trained matrix.
The size of embedding_matrix is (906, 200)
Preparing dataloader...

Generating set took: 0:03:33.08


  0%|          | 0/60 [00:00<?, ?it/s]


Generating set took: 0:00:53.23
Training...


  2%|▏         | 1/60 [00:12<12:16, 12.48s/it]

4.379349487168448


  3%|▎         | 2/60 [00:24<12:03, 12.48s/it]

2.38930272488367


  5%|▌         | 3/60 [00:37<11:50, 12.47s/it]

1.8136864418075198


  7%|▋         | 4/60 [00:49<11:37, 12.46s/it]

1.567824519815899


  8%|▊         | 5/60 [01:02<11:25, 12.46s/it]

1.4427709806533087


 10%|█         | 6/60 [01:14<11:12, 12.46s/it]

1.3636500040690105


 12%|█▏        | 7/60 [01:27<11:00, 12.46s/it]

1.3223009024347578


 13%|█▎        | 8/60 [01:39<10:47, 12.46s/it]

1.2866435902459281


 15%|█▌        | 9/60 [01:52<10:35, 12.46s/it]

1.2545127215839567


 17%|█▋        | 10/60 [02:04<10:24, 12.48s/it]

1.2281261341912406


 18%|█▊        | 11/60 [02:17<10:11, 12.47s/it]

1.2027420032592047


 20%|██        | 12/60 [02:29<09:58, 12.46s/it]

1.172298184462956


 22%|██▏       | 13/60 [02:42<09:45, 12.46s/it]

1.1544020147550673


 23%|██▎       | 14/60 [02:54<09:32, 12.45s/it]

1.1425253947575886


 25%|██▌       | 15/60 [03:06<09:20, 12.45s/it]

1.1280766810689653


 27%|██▋       | 16/60 [03:19<09:08, 12.46s/it]

1.1132818857828777


 28%|██▊       | 17/60 [03:31<08:55, 12.46s/it]

1.1036523779233296


 30%|███       | 18/60 [03:44<08:42, 12.45s/it]

1.0951522106216067


 32%|███▏      | 19/60 [03:56<08:30, 12.45s/it]

1.0885810795284452


 33%|███▎      | 20/60 [04:09<08:17, 12.45s/it]

1.077813900652386


 35%|███▌      | 21/60 [04:21<08:05, 12.45s/it]

1.0722151228359766


 37%|███▋      | 22/60 [04:34<07:52, 12.44s/it]

1.0675264157000042


 38%|███▊      | 23/60 [04:46<07:40, 12.44s/it]

1.0652504747822171


 40%|████      | 24/60 [04:58<07:27, 12.44s/it]

1.056295775231861


 42%|████▏     | 25/60 [05:11<07:15, 12.44s/it]

1.0488878573690141


 43%|████▎     | 26/60 [05:23<07:03, 12.45s/it]

1.0449349809260595


 45%|████▌     | 27/60 [05:36<06:51, 12.47s/it]

1.039203100261234


 47%|████▋     | 28/60 [05:48<06:38, 12.47s/it]

1.0382240017255147


 48%|████▊     | 29/60 [06:01<06:26, 12.47s/it]

1.0272207245940255


 50%|█████     | 30/60 [06:13<06:14, 12.48s/it]

1.0202869006565638


 52%|█████▏    | 31/60 [06:26<06:01, 12.48s/it]

1.0176583911691393


 53%|█████▎    | 32/60 [06:38<05:49, 12.47s/it]

1.005128951299758


 55%|█████▌    | 33/60 [06:51<05:36, 12.47s/it]

0.9986215148653302


 57%|█████▋    | 34/60 [07:03<05:24, 12.47s/it]

0.9974624883560907


 58%|█████▊    | 35/60 [07:16<05:11, 12.47s/it]

0.9964098802634648


 60%|██████    | 36/60 [07:28<04:59, 12.47s/it]

0.9991165513084048


 62%|██████▏   | 37/60 [07:41<04:46, 12.47s/it]

0.9952517066683088


 63%|██████▎   | 38/60 [07:53<04:34, 12.48s/it]

0.995566596587499


 65%|██████▌   | 39/60 [08:06<04:22, 12.49s/it]

0.9883129412219638


 67%|██████▋   | 40/60 [08:18<04:09, 12.47s/it]

0.9781376066662016


 68%|██████▊   | 41/60 [08:30<03:56, 12.46s/it]

0.9751831874960945


 70%|███████   | 42/60 [08:43<03:44, 12.46s/it]

0.9684001193160102


 72%|███████▏  | 43/60 [08:55<03:31, 12.46s/it]

0.9700285820733934


 73%|███████▎  | 44/60 [09:08<03:19, 12.47s/it]

0.969545209691638


 75%|███████▌  | 45/60 [09:20<03:07, 12.47s/it]

0.966312818583988


 77%|███████▋  | 46/60 [09:33<02:54, 12.47s/it]

0.971175784156436


 78%|███████▊  | 47/60 [09:45<02:42, 12.47s/it]

0.9734680979024797


 80%|████████  | 48/60 [09:58<02:29, 12.47s/it]

0.9632811191536131


 82%|████████▏ | 49/60 [10:10<02:17, 12.50s/it]

0.9623949953487941


 83%|████████▎ | 50/60 [10:23<02:04, 12.48s/it]

0.957572014558883


 85%|████████▌ | 51/60 [10:35<01:52, 12.47s/it]

0.9555972913901011


 87%|████████▋ | 52/60 [10:48<01:39, 12.47s/it]

0.9564438646747953


 88%|████████▊ | 53/60 [11:00<01:27, 12.46s/it]

0.9502592399006798


 90%|█████████ | 54/60 [11:13<01:14, 12.46s/it]

0.9483686203048343


 92%|█████████▏| 55/60 [11:25<01:02, 12.46s/it]

0.9488138059775034


 93%|█████████▎| 56/60 [11:37<00:49, 12.46s/it]

0.9504723761762891


 95%|█████████▌| 57/60 [11:50<00:37, 12.45s/it]

0.9495455409799304


 97%|█████████▋| 58/60 [12:02<00:24, 12.45s/it]

0.9497196504047939


 98%|█████████▊| 59/60 [12:15<00:12, 12.44s/it]

0.9479275998615083


100%|██████████| 60/60 [12:27<00:00, 12.46s/it]
  0%|          | 0/60 [00:00<?, ?it/s]

0.942500501871109


  2%|▏         | 1/60 [00:12<12:13, 12.44s/it]

0.9012825858025324


  3%|▎         | 2/60 [00:24<12:01, 12.43s/it]

0.8800567502067203


  5%|▌         | 3/60 [00:37<11:49, 12.44s/it]

0.8708955986159188


  7%|▋         | 4/60 [00:49<11:36, 12.44s/it]

0.8667961671238854


  8%|▊         | 5/60 [01:02<11:24, 12.44s/it]

0.8594556266353244


 10%|█         | 6/60 [01:14<11:11, 12.44s/it]

0.8555926212242672


 12%|█▏        | 7/60 [01:27<10:59, 12.45s/it]

0.8537148364952633


 13%|█▎        | 8/60 [01:39<10:47, 12.44s/it]

0.8519189897037688


 15%|█▌        | 9/60 [01:51<10:34, 12.44s/it]

0.8509318161578405


 17%|█▋        | 10/60 [02:04<10:22, 12.44s/it]

0.846211070106143


 18%|█▊        | 11/60 [02:16<10:09, 12.44s/it]

0.8463726072084337


 20%|██        | 12/60 [02:29<09:56, 12.43s/it]

0.8436004420121511


 22%|██▏       | 13/60 [02:41<09:44, 12.44s/it]

0.8413482989583697


 23%|██▎       | 14/60 [02:54<09:32, 12.44s/it]

0.8416756732123238


 25%|██▌       | 15/60 [03:06<09:19, 12.44s/it]

0.8391300198577699


 27%|██▋       | 16/60 [03:19<09:07, 12.44s/it]

0.8359051304204124


 28%|██▊       | 17/60 [03:31<08:54, 12.43s/it]

0.8358501536505563


 30%|███       | 18/60 [03:43<08:42, 12.43s/it]

0.8349973360697428


 32%|███▏      | 19/60 [03:56<08:29, 12.43s/it]

0.8338946316923413


 33%|███▎      | 20/60 [04:08<08:17, 12.43s/it]

0.8326146673588526


 35%|███▌      | 21/60 [04:21<08:04, 12.43s/it]

0.8317257052376157


 37%|███▋      | 22/60 [04:33<07:52, 12.43s/it]

0.8307972153027853


 38%|███▊      | 23/60 [04:46<07:40, 12.44s/it]

0.8304892849354517


 40%|████      | 24/60 [04:58<07:27, 12.44s/it]

0.828382609855561


 42%|████▏     | 25/60 [05:10<07:15, 12.43s/it]

0.830090290024167


 43%|████▎     | 26/60 [05:23<07:02, 12.43s/it]

0.8274444199743725


 45%|████▌     | 27/60 [05:35<06:50, 12.43s/it]

0.8267303194318499


 47%|████▋     | 28/60 [05:48<06:37, 12.43s/it]

0.8257691065470377


 48%|████▊     | 29/60 [06:00<06:25, 12.43s/it]

0.8249758709044683


 50%|█████     | 30/60 [06:13<06:12, 12.43s/it]

0.8232698227678027


 52%|█████▏    | 31/60 [06:25<06:00, 12.44s/it]

0.8242327173550924


 53%|█████▎    | 32/60 [06:37<05:48, 12.44s/it]

0.8240674847648257


 55%|█████▌    | 33/60 [06:50<05:35, 12.44s/it]

0.8214345687911624


 57%|█████▋    | 34/60 [07:02<05:23, 12.45s/it]

0.8208440045515696


 58%|█████▊    | 35/60 [07:15<05:11, 12.46s/it]

0.8206126477037158


 60%|██████    | 36/60 [07:27<04:59, 12.47s/it]

0.8179182254132771


 62%|██████▏   | 37/60 [07:40<04:46, 12.47s/it]

0.8205037060238066


 63%|██████▎   | 38/60 [07:52<04:35, 12.50s/it]

0.8194659919965834


 65%|██████▌   | 39/60 [08:05<04:22, 12.49s/it]

0.8158955857867286


 67%|██████▋   | 40/60 [08:17<04:09, 12.47s/it]

0.8181667327880859


 68%|██████▊   | 41/60 [08:30<03:56, 12.47s/it]

0.8170691615059262


 70%|███████   | 42/60 [08:42<03:44, 12.46s/it]

0.8164865800312587


 72%|███████▏  | 43/60 [08:55<03:31, 12.45s/it]

0.815773396264939


 73%|███████▎  | 44/60 [09:07<03:19, 12.47s/it]

0.8142596965744382


 75%|███████▌  | 45/60 [09:20<03:07, 12.48s/it]

0.8153577446937561


 77%|███████▋  | 46/60 [09:32<02:54, 12.48s/it]

0.8144903211366563


 78%|███████▊  | 47/60 [09:45<02:42, 12.47s/it]

0.8133819301923116


 80%|████████  | 48/60 [09:57<02:29, 12.46s/it]

0.8142089701834179


 82%|████████▏ | 49/60 [10:09<02:17, 12.46s/it]

0.8122166012014661


 83%|████████▎ | 50/60 [10:22<02:04, 12.46s/it]

0.8109706100963411


 85%|████████▌ | 51/60 [10:34<01:52, 12.46s/it]

0.8104643623034159


 87%|████████▋ | 52/60 [10:47<01:39, 12.45s/it]

0.8106990229515803


 88%|████████▊ | 53/60 [10:59<01:27, 12.46s/it]

0.8114625158764067


 90%|█████████ | 54/60 [11:12<01:14, 12.44s/it]

0.8098916014035543


 92%|█████████▏| 55/60 [11:24<01:02, 12.44s/it]

0.8116852399848756


 93%|█████████▎| 56/60 [11:37<00:49, 12.43s/it]

0.8088807719094413


 95%|█████████▌| 57/60 [11:49<00:37, 12.44s/it]

0.8094908140954518


 97%|█████████▋| 58/60 [12:01<00:24, 12.44s/it]

0.8091630566687811


 98%|█████████▊| 59/60 [12:14<00:12, 12.44s/it]

0.8075773602440244


100%|██████████| 60/60 [12:26<00:00, 12.45s/it]

0.8067337714490437
Generating captions...





tokenization...
computing Bleu score...
computing METEOR score...
computing Rouge score...
computing CIDEr score...
computing SPICE score...
computing Universal_Sentence_Encoder_Similarity score...


In [38]:
model_score5

{'Bleu_1': 0.6383266486334932,
 'Bleu_2': 0.5099921353509047,
 'Bleu_3': 0.42475045055064564,
 'Bleu_4': 0.36282966723852916,
 'METEOR': 0.2903262588930655,
 'ROUGE_L': 0.5347564638318366,
 'CIDEr': 1.9885872084732155,
 'SPICE': 0.3811500037292482,
 'USC_similarity': 0.5988246693868285}

In [39]:
model_scores = defaultdict(list)
for scores in [model_score1, model_score2, model_score3, model_score4, model_score5]:
    for key, value in scores.items():
        model_scores[key].append(value)

In [40]:
model_scores

defaultdict(list,
            {'Bleu_1': [0.6374539978750372,
              0.6232993589447667,
              0.6346993011263479,
              0.6297681376215113,
              0.6383266486334932],
             'Bleu_2': [0.5097287325539306,
              0.4915525450592119,
              0.5067599152598399,
              0.5045738328765402,
              0.5099921353509047],
             'Bleu_3': [0.42603430037193946,
              0.4058396978483188,
              0.423500759390165,
              0.4225585731564296,
              0.42475045055064564],
             'Bleu_4': [0.3651897571307535,
              0.34568139609959103,
              0.3635006872983065,
              0.3630833438964103,
              0.36282966723852916],
             'METEOR': [0.29132363875100775,
              0.2809060472919057,
              0.28896062185124893,
              0.28928258538798995,
              0.2903262588930655],
             'ROUGE_L': [0.5424600890989629,
              0.5196458203

In [41]:
tag = '9.1.1.1'
with open(f'{root_captioning}/fz_notebooks/cv_n{tag}.json', 'w') as fp:
    json.dump(model_scores, fp)