In [14]:
import os
import torch

from get_loader import (VideoDataset_to_VideoCaptionsLoader, Vocabulary,
                        get_loader)
from trainer import Trainer
import pandas as pd

In [17]:
# --- Runtime configuration -------------------------------------------------
# Index of the GPU this notebook is allowed to use (kept as a string because
# it is written straight into an environment variable).
gpu = '0'
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # see issue #152
    "CUDA_VISIBLE_DEVICES": gpu,
})

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Directory scanned for model checkpoints / directory receiving the CSV results.
checkpoints_dir = "checkpoints"
output_csv_dir = "results"

# Batch size handed to every data loader built in this notebook.
batch_size = 128

In [5]:
# Locate the MSVD dataset and load its pickled vocabulary.
dataset_folder = os.path.join("datasets", "MSVD")
vocab_pkl = os.path.join(dataset_folder, "metadata", "vocab.pkl")
vocab = Vocabulary.load(vocab_pkl)

# Build the validation and test loaders (in that order) with identical
# settings; only the first element returned by get_loader is kept.
(val_loader, _), (test_loader, _) = (
    get_loader(
        root_dir=dataset_folder,
        split=split_name,
        batch_size=batch_size,
        vocab_pkl=vocab_pkl,
    )
    for split_name in ("val", "test")
)

# Wrap each split's underlying dataset in a VideoCaptions loader.
val_vidCap_loader, test_vidCap_loader = (
    VideoDataset_to_VideoCaptionsLoader(split_loader.dataset, batch_size)
    for split_loader in (val_loader, test_loader)
)

Before integrity check: 3720
After integrity check: 2562
After removing unverified: 1055
Loading Vocab: datasets\MSVD\metadata\vocab.pkl 
Before integrity check: 24452
After integrity check: 19061
After removing unverified: 7671
Loading Vocab: datasets\MSVD\metadata\vocab.pkl 


In [21]:
# Get a trainer for model evaluation
tr = Trainer(checkpoint_name='test.ckpt', log_dir='trash')
tr.device = device

# Load every "*_best.pt" checkpoint, run it on the validation and test splits,
# and write the generated / reference captions to
# <output_csv_dir>/<phase>/<model_name>.csv
for ckpt in os.listdir(checkpoints_dir):
    if not ckpt.endswith('_best.pt'):
        continue

    print("\nLoading model from checkpoint:", ckpt)
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    model = torch.load(os.path.join(checkpoints_dir, ckpt), map_location=device)

    # The model name depends only on the checkpoint, not on the phase.
    model_name = ckpt.split('_best')[0]

    for loader, phase in zip([val_vidCap_loader, test_vidCap_loader], ['val', 'test']):
        # Evaluate on the loader that belongs to this phase. Passing the
        # validation loader for both phases would silently write validation
        # predictions into the "test" results.
        _, true_captions, generated_captions = tr.eval(
            model, loader, training_phase=phase, epoch=0, get_scores=False
        )

        model_results = {
            'generated_captions': generated_captions,
            'true_captions': true_captions,
        }

        # exist_ok avoids the separate isdir() check.
        phase_dir = os.path.join(output_csv_dir, phase)
        os.makedirs(phase_dir, exist_ok=True)

        df = pd.DataFrame(model_results)
        df.to_csv(
            os.path.join(phase_dir, f"{model_name}.csv"),
            header=True,
            columns=["generated_captions", "true_captions"],
        )

> [a man is slicing a loaf of bread .] (a man is filling up a bag with meat .)
bSrpvMSuhPM_17_31 >> [a woman is putting a potato .] (a man is talking to a woman .)
bXsKw3TOQXs_30_55 >> [a man is playing with a ball .] (a small piece of paper is being folded .)
b_BuSVZwq6M_1_9 >> [a man is playing a goal of wood .] (a man makes a great play in a cricket game .)
bb6V0Grtub4_174_185 >> [a man is riding a horse .] (a man is playing on drums .)
bkazguPsusc_74_85 >> [a man is lifting a car .] (a cat is sliding under a couch .)
bmxIurBrW5s_51_70 >> [a man is riding a horse .] (a woman practicing a volleyball)
bruzcOyIGeg_4_12 >> [a boy is playing with a ball .] (a man drives a remote control car .)
btuxO-C2IzE_64_72 >> [a woman is riding a horse .] (a lion jumps up and is hugged and petted by two long - <UNK> men .)
buJ5HDCinrM_150_166 >> [a woman is chopping a potato .] (a woman is putting make up on her face .)
bxDlC7YV5is_0_12 >> [a man is playing a guitar .] (a boy is playing a key - boar

In [26]:
# Build loaders and datasets for all three splits, in train -> val -> test
# order, using the same settings for each.
# NOTE: a `normalize=exp['normalize_dataset']` argument was commented out in
# the original calls and is therefore not passed here either.
(
    (train_loader, train_dataset),
    (val_loader, val_dataset),
    (test_loader, test_dataset),
) = (
    get_loader(
        root_dir=dataset_folder,
        split=split_name,
        batch_size=batch_size,
        vocab_pkl=vocab_pkl,
    )
    for split_name in ("train", "val", "test")
)

# Report how many distinct VideoID values each split's metadata contains.
print(f"Train: {len(train_dataset.metadata['VideoID'].unique())}")
print(f"Validation: {len(val_dataset.metadata['VideoID'].unique())}")
print(f"Test: {len(test_dataset.metadata['VideoID'].unique())}")

Before integrity check: 42990
After integrity check: 32509
After removing unverified: 13291
Loading Vocab: datasets\MSVD\metadata\vocab.pkl 
Before integrity check: 3720
After integrity check: 2562
After removing unverified: 1055
Loading Vocab: datasets\MSVD\metadata\vocab.pkl 
Before integrity check: 24452
After integrity check: 19061
After removing unverified: 7671
Loading Vocab: datasets\MSVD\metadata\vocab.pkl 
Train: 624
Validation: 60
Test: 392
