In [1]:
import os
import pickle

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

%matplotlib inline

from dictionary import Vocabulary,EOS_token,PAD_token,SOS_token,UNK_token
from evaluate import Evaluator

from models.mean_pooling.config import Config,Path
from models.mean_pooling.data import DataHandler
from models.mean_pooling.model import MeanPooling
from models.mean_pooling import utils

# Build the mean-pooling configuration object and select the 'msrvtt' dataset.
cfg = Config()
cfg.dataset = 'msrvtt'
# Project path helper (models.mean_pooling.config.Path, not pathlib.Path),
# constructed from the config and the current working directory.
path = Path(cfg,os.getcwd())
# Fixed seed handed to the project's seeding helper
# (presumably seeds the RNGs used in training — confirm in utils.set_seed).
utils.set_seed(1)
# Optional config overrides, currently disabled:
#cfg.batch_size = 100
#cfg.n_layers = 2



    image feature dimension = 1920 (4096 in original paper)
    
    embedding dimension = 256
    
    word embedding = 256
    
    dropout = 0.4
    
    hidden memory = 256
    
    number of layers in LSTM = 2 (2 in original paper)
    
    video feature fed to = initial hidden memory of the decoder (not same as paper)
    
    Training method = mixture of teacher forcing & sampling
    
    Initial teacher forcing ratio = 0.3
    
    optimizer = Adam
    


# Vocabulary Object Creation and loading

In [2]:
# One-time vocabulary build: uncomment and run only if no vocabulary has been
# saved yet. It merges the train/val/test annotation dicts and adds every
# annotation sentence to the vocabulary before saving it.
# text_dict = {}
# voc = Vocabulary(cfg.dataset)
# data_handler = DataHandler(cfg,path,voc)
# text_dict.update(data_handler.train_dict)
# text_dict.update(data_handler.val_dict)
# text_dict.update(data_handler.test_dict)
# for k,v in text_dict.items():
#     for anno in v:
#         voc.addSentence(anno)

# voc.save()

# Normal path: a vocabulary has already been saved, so construct and load it.
voc = Vocabulary(cfg.dataset)
voc.load()
print('Vocabulary Size : ',voc.num_words)


# Optional: filter rare words from the dictionary (disabled).
# min_count = 2
# voc.trim(min_count=2)
# print('Vocabulary Size : ',voc.num_words)

Vocabulary Size :  29327


# Data Loader

In [3]:
# Build the data handler, then obtain the train/val/test datasets and their
# corresponding loaders from it.
data_handler = DataHandler(cfg,path,voc)
train_dset,val_dset,test_dset = data_handler.getDatasets()
train_loader,val_loader,test_loader = data_handler.getDataloader(train_dset,val_dset,test_dset)

# Debug peek at a single training batch (disabled):
# dataiter = iter(train_loader)
# features, targets, mask, max_length,ides= next(dataiter)
# features.size(), targets[:,5], mask[:,5],max_length,ides

# Evaluator bound to the test loader and the handler's test_dict
# (presumably the reference captions for the test split — confirm in Evaluator).
test_evaluator = Evaluator('test','test',test_loader,cfg,data_handler.test_dict)

  self.feature_dict[key] = f1[key].value.mean(axis=0)


In [4]:
# Instantiate the mean-pooling captioning model from the vocabulary, config and path helper.
model = MeanPooling(voc,cfg,path)

In [None]:
# Hyperparameters for this training run; they are written onto the shared
# config and then pushed into the model before any epoch is run.
cfg.encoder_lr = 1e-4
cfg.decoder_lr = 1e-4
cfg.teacher_forcing_ratio = 1.0
model.update_hyperparameters(cfg)

# Epoch numbering runs from 311 through 510 inclusive; a loss line and a test
# evaluation are printed on every epoch number divisible by 10.
first_epoch, last_epoch, report_every = 311, 510, 10
for epoch in range(first_epoch, last_epoch + 1):
    loss = model.train_epoch(train_loader)
    if epoch % report_every != 0:
        continue
    print('Epoch -- >',epoch,'Loss -->',loss)
    print(test_evaluator.evaluate(model,epoch))

In [5]:
# Reload the whole pickled model object written by the (disabled) torch.save call below.
# NOTE(security): torch.load unpickles arbitrary Python objects — only load
# checkpoint files you trust.
# map_location remaps the stored tensors onto cfg.device, so a checkpoint saved
# on a device that is unavailable here (e.g. saved on GPU, loaded on CPU) still
# loads and matches the device the input features are moved to later.
# NOTE(review): on PyTorch >= 2.6 loading a full pickled object also needs
# weights_only=False; not added here because older releases reject that keyword.
#torch.save(model,'msrvtt_lstm_mp.pt')
model = torch.load('msrvtt_lstm_mp.pt', map_location=cfg.device)

In [10]:

# Iterator over the test loader; later cells pull a single batch from it for inspection.
dataiter = iter(test_loader)

In [11]:
# Pull one batch from the test iterator. The builtin next() is used instead of
# the iterator's .next() method: .next() is not part of the Python 3 iterator
# protocol and is missing from current PyTorch DataLoader iterators.
features, target, mask, max_length,ides= next(dataiter)

# Display the feature batch size, column 5 of the target and mask tensors,
# and the batch's max_length value.
features.size(), target[:,5], mask[:,5],max_length

(torch.Size([10, 1536]),
 tensor([  85, 2248, 1863,    2,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0]),
 tensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 15)

In [12]:
# Greedy-decode captions for the sampled batch. This is inference only, so it
# runs under torch.no_grad() to avoid building an autograd graph; the decoded
# output is unchanged.
with torch.no_grad():
    tsr,txt = model.GreedyDecoding(features.to(cfg.device))
# Display the decoded caption strings.
txt

['a group of people are dancing',
 'a man is singing',
 'a man is singing a song',
 'a person is folding a paper',
 'a group of people are playing a basketball',
 'a man is dancing',
 'a man is riding a motorcycle',
 'a person is folding a piece of paper',
 'a person is cooking a dish',
 'a person is folding a paper airplane']

In [13]:
# Convert the batch's target tensor back to text via the vocabulary, for
# side-by-side comparison with the decoded captions.
utils.target_tensor_to_caption(voc,target)

['a man is singing a song and playing guitar and dancing with other s EOS',
 'a man dances at a wedding EOS PAD PAD PAD PAD PAD PAD PAD PAD',
 'a woman is singing EOS PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD',
 'someone is folding paper EOS PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD',
 'people fighting in a basketball game EOS PAD PAD PAD PAD PAD PAD PAD PAD',
 'people do karate EOS PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD',
 'footage from a monster truck style event followed by a frat party EOS PAD PAD',
 'a person folding a piece of paper into a paper airplane EOS PAD PAD PAD',
 'person lighting a kettle EOS PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD',
 'person folding paper EOS PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD']

In [None]:
# Run the test evaluator on the current model and print its result.
# The second argument occupies the same position as the epoch number passed in
# the training loop; 500 is presumably an epoch tag for this checkpoint — confirm.
print(test_evaluator.evaluate(model,500))