In [1]:
import os
import pickle
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import warnings
warnings.filterwarnings('ignore')

# Import Path,Vocabulary, utility, evaluator and datahandler module
from config import Path
from dictionary import Vocabulary
from utils import Utils
from evaluate import Evaluator
from data import DataHandler

# Create the shared project utility helper and fix the random seed for reproducibility.
# `utils` is passed to the training and evaluation calls in later cells, so this
# cell must run before them.
utils = Utils()
utils.set_seed(1)  # which RNGs get seeded is decided inside Utils.set_seed (not visible here)

# Mean Pooling

In [2]:
#Import configuration and model 
from config import ConfigMP
from models.mean_pooling.model import MeanPooling


# NOTE(review): the two SA-LSTM names below are not used in this cell; they are
# imported again at the top of the SA-LSTM section.
from config import ConfigSALSTM
from models.SA_LSTM.model import SALSTM

# Create the Mean Pooling configuration object (the model itself is built in a later cell)
cfg = ConfigMP()
# specifying the dataset in configuration object from {'msvd','msrvtt'}
cfg.dataset = 'msrvtt'
# Create the path object from the configuration and the current working directory.
# It is built after cfg.dataset is set and before the hyperparameters below are changed.
path = Path(cfg,os.getcwd())

#Changing the hyperparameters in configuration object
cfg.batch_size = 100 #training batch size
cfg.n_layers = 2    # number of layers in decoder rnn
cfg.decoder_type = 'lstm'  # from {'lstm','gru'}

Vocabulary creation or loading

In [3]:
# Vocabulary object built from the current configuration
voc = Vocabulary(cfg)
# Load a previously saved (or downloaded) vocabulary file
voc.load() #comment this if using vocabulary for the first time or with no saved file
print('Vocabulary Size : ',voc.num_words) 


# # Uncomment this block if using vocabulary for the first time or if there is no saved file.
# # It collects the train/val/test annotation dicts, adds every caption to the
# # vocabulary, and saves the result.
# text_dict = {}
# voc = Vocabulary(cfg)
# data_handler = DataHandler(cfg,path,voc)
# import json
# print(path.feature_file)
# json.load(open(path.feature_file))
# text_dict.update(data_handler.train_dict)
# text_dict.update(data_handler.val_dict)
# text_dict.update(data_handler.test_dict)
# for k,v in text_dict.items():
#     for anno in v:
#         voc.addSentence(anno)
# voc.save()


## Uncomment this block for filtering rare words from the dictionary
# min_count = 2 #remove all words below count min_count
# voc.trim(min_count=min_count)
# print('Vocabulary Size : ',voc.num_words)

Vocabulary Size :  29327


Dataloaders, model, and evaluator

In [4]:
# Datasets and dataloaders
# DataHandler is built from the config, path and vocabulary created in the earlier cells.
data_handler = DataHandler(cfg,path,voc)
train_dset,val_dset,test_dset = data_handler.getDatasets()
train_loader,val_loader,test_loader = data_handler.getDataloader(train_dset,val_dset,test_dset)

#Model object (Mean Pooling captioning model)
model = MeanPooling(voc,cfg,path)
#Evaluator object on test data; it receives the test loader and the test annotation dict
test_evaluator = Evaluator(model,test_loader,path,cfg,data_handler.test_dict)

Training loop

In [5]:
#Training Loop
# Write the learning rates and teacher-forcing ratio onto the config, then push
# the updated config into the model before training starts.
cfg.encoder_lr = 1e-4
cfg.decoder_lr = 1e-4
cfg.teacher_forcing_ratio = 1.0
model.update_hyperparameters(cfg)
# Epochs are numbered 1..510 so that `e%50 == 0` fires at 50, 100, ..., 500.
for e in range(1,511):
    loss = model.train_epoch(train_loader,utils)
    if e%50 == 0 :
        # Every 50th epoch: report the training loss and the evaluator's result on the test loader.
        print('Epoch -- >',e,'Loss -->',loss)
        print(test_evaluator.evaluate(utils,model,e,loss))

Epoch -- > 50 Loss --> 5.024275640105086
{'testlen': 17083, 'reflen': 19616, 'guess': [17083, 14093, 11103, 8113], 'correct': [11704, 5201, 2235, 480]}
ratio: 0.8708707177813585
{'Bleu_1': 0.5907105071688321, 'Bleu_2': 0.43354213369805633, 'Bleu_3': 0.31952249902913477, 'Bleu_4': 0.20197245965556046, 'METEOR': 0.18997056861640454, 'ROUGE_L': 0.4979413218606429, 'CIDEr': 0.09071018569955135}
Epoch -- > 100 Loss --> 4.670244797724767
{'testlen': 20704, 'reflen': 21822, 'guess': [20704, 17714, 14724, 11734], 'correct': [15758, 7718, 3443, 1205]}
ratio: 0.948767299055955
{'Bleu_1': 0.7210996267312337, 'Bleu_2': 0.5455893938392339, 'Bleu_3': 0.40401469516501004, 'Bleu_4': 0.283022016279368, 'METEOR': 0.23239376477749701, 'ROUGE_L': 0.5481253000758799, 'CIDEr': 0.21986267297381412}
Epoch -- > 150 Loss --> 4.375256969156208
{'testlen': 20305, 'reflen': 21482, 'guess': [20305, 17315, 14325, 11335], 'correct': [15972, 8102, 3653, 1322]}
ratio: 0.9452099432082234
{'Bleu_1': 0.7423043162637192, '

# SA-LSTM

In [None]:
#Import configuration and model 

from config import ConfigSALSTM
from models.SA_LSTM.model import SALSTM

# Create the SA-LSTM configuration object (this rebinds `cfg` from the Mean Pooling section)
cfg = ConfigSALSTM()
# specifying the dataset in configuration object from {'msvd','msrvtt'}
# NOTE(review): the Mean Pooling section above uses 'msrvtt'; this section switches to 'msvd'.
cfg.dataset = 'msvd'
# Create the path object from the configuration and the current working directory
path = Path(cfg,os.getcwd())

#Changing the hyperparameters in configuration object
cfg.batch_size = 128 #training batch size
cfg.n_layers = 2    # number of layers in decoder rnn
cfg.decoder_type = 'lstm'  # from {'lstm','gru'}


# Vocabulary object built from the SA-LSTM configuration
voc = Vocabulary(cfg)
# Load a previously saved (or downloaded) vocabulary file
voc.load() #comment this if using vocabulary for the first time or with no saved file
print('Vocabulary Size : ',voc.num_words) 

In [None]:
# Datasets and dataloaders
# Rebuilt here because cfg, path and voc were re-created for the SA-LSTM section.
data_handler = DataHandler(cfg,path,voc)
train_dset,val_dset,test_dset = data_handler.getDatasets()
train_loader,val_loader,test_loader = data_handler.getDataloader(train_dset,val_dset,test_dset)

#Model object (SA-LSTM captioning model; replaces the Mean Pooling `model`)
model = SALSTM(voc,cfg,path)
#Evaluator object on test data; it receives the test loader and the test annotation dict
test_evaluator = Evaluator(model,test_loader,path,cfg,data_handler.test_dict)

In [None]:
#Training Loop
# Write the decoder learning rate and teacher-forcing ratio onto the config,
# then push the updated config into the model before training starts.
# NOTE(review): unlike the Mean Pooling loop, cfg.encoder_lr is not set here, so
# it keeps whatever value ConfigSALSTM provides -- confirm this is intentional.
cfg.decoder_lr = 1e-5
cfg.teacher_forcing_ratio = 1.0
model.update_hyperparameters(cfg)
# Epochs are numbered 1..510 so that `e%50 == 0` fires at 50, 100, ..., 500.
for e in range(1,511):
    loss = model.train_epoch(train_loader,utils)
    if e%50 == 0 :
        # Every 50th epoch: report the training loss and the evaluator's result on the test loader.
        print('Epoch -- >',e,'Loss -->',loss)
        # Pass the epoch loss as well, matching the Mean Pooling loop's call to
        # the same Evaluator.evaluate method.
        print(test_evaluator.evaluate(utils,model,e,loss))

In [None]:
# Optional whole-model checkpointing (disabled). torch.load unpickles the file,
# so only load checkpoints from a trusted source.
# NOTE(review): the load line names a different file ('msrvtt_lstm_mp.pt') than
# the save line ('SA_LSTM_lstm_2000.pt') -- confirm which checkpoint is intended.
#torch.save(model,'SA_LSTM_lstm_2000.pt')
#model = torch.load('msrvtt_lstm_mp.pt')

In [None]:
# Pull a single batch from the test loader for manual inspection.
# The built-in next() is used instead of `dataiter.next()`: Python 3 iterators
# only guarantee __next__, and DataLoader iterators in recent PyTorch releases
# no longer expose a .next() alias, so the method call raises AttributeError.
dataiter = iter(test_loader)
# Each batch unpacks into five items; `ides` is presumably the sample/video ids -- TODO confirm.
features, target, mask, max_length,ides= next(dataiter)

# Display the feature tensor size, index 5 along the second axis of target and
# mask, and max_length (presumably index 5 selects one sample of the batch -- TODO confirm layout).
features.size(), target[:,5], mask[:,5],max_length

In [None]:
# Run greedy decoding on the inspected batch after moving the features to the configured device.
# The call returns two values; `txt` is displayed as the cell output
# (presumably the decoded captions, with `tsr` the matching token tensor -- TODO confirm).
tsr,txt = model.GreedyDecoding(features.to(cfg.device))
txt

In [None]:
# Convert the batch's target tensor back to captions using the vocabulary and display
# the result, for side-by-side comparison with the decoded output of the previous cell.
utils.target_tensor_to_caption(voc,target)