In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
from torch.utils.data import Dataset,DataLoader
import numpy as np
from tqdm import tqdm


import os
import json
from itertools import accumulate
from random import randint

In [4]:
class Encoder(nn.Module):
    """Batch-first GRU encoder over a sequence of input feature vectors.

    Args:
        ninp: size of each input feature vector.
        nhid: GRU hidden size (per direction).
        nlayers: number of stacked GRU layers.
        bidirectional: run the GRU in both directions when True.
    """

    def __init__(self, ninp, nhid, nlayers=1, bidirectional=False):
        super(Encoder, self).__init__()
        self.nlayers = nlayers
        self.gru = nn.GRU(
                     input_size = ninp,
                     hidden_size = nhid,
                     num_layers = nlayers,
                     batch_first = True,
                     bidirectional = bidirectional
                   )
        self.nhid = nhid
        self.ndirection = 1 + bidirectional

    def forward(self, data, hidden):
        """Run the GRU once over `data`.

        Args:
            data: input of shape (batch, seq, ninp).
            hidden: state of shape (nlayers * ndirection, batch, nhid).

        Returns:
            (output, hidden) exactly as produced by `nn.GRU`.
        """
        # nn.GRU stacks `num_layers` internally, so a single call is a full
        # pass; calling it once per layer would re-read the same input.
        output, hidden = self.gru(data, hidden)
        return output, hidden

    def init_hidden(self, bsz):
        """Return an all-zero initial state for a batch of `bsz` sequences.

        The state has one slice per (layer, direction) pair, as nn.GRU
        expects, and is created with the same tensor type (and therefore
        CPU/GPU placement) as the module's own parameters.
        """
        weight = next(self.parameters())
        result = Variable(torch.zeros(self.nlayers * self.ndirection, bsz, self.nhid))
        return result.type_as(weight)
    
class Decoder(nn.Module):
    """Batch-first GRU decoder with a linear projection onto the vocabulary.

    Args:
        ninp: size of each input vector.
        nhid: GRU hidden size (per direction).
        ntoken: number of output classes (vocabulary size).
        nlayers: number of stacked GRU layers.
        bidirectional: run the GRU in both directions when True.
    """

    def __init__(self, ninp, nhid, ntoken, nlayers=1, bidirectional=False):
        super(Decoder, self).__init__()
        self.nlayers = nlayers
        self.gru = nn.GRU(
                     input_size = ninp,
                     hidden_size = nhid,
                     num_layers = nlayers,
                     batch_first = True,
                     bidirectional = bidirectional
                   )
        self.nhid = nhid
        self.ntoken = ntoken
        self.ndirection = 1 + bidirectional

        # Maps every GRU output vector to one score per vocabulary entry.
        self.hidden2out = nn.Linear(nhid * self.ndirection, ntoken)

    def forward(self, data, hidden):
        """Run the GRU once over `data` and project its outputs.

        Args:
            data: input of shape (batch, seq, ninp).
            hidden: state of shape (nlayers * ndirection, batch, nhid).

        Returns:
            (output, hidden, output1): the raw GRU output, the new state, and
            the vocabulary scores of shape (batch, seq, ntoken).
        """
        # nn.GRU stacks `num_layers` internally, so a single call is a full
        # pass; calling it once per layer would re-read the same input.
        output, hidden = self.gru(data, hidden)
        output1 = self.hidden2out(output)
        return output, hidden, output1

    def init_hidden(self, bsz):
        """Return an all-zero initial state for a batch of `bsz` sequences.

        The state has one slice per (layer, direction) pair, as nn.GRU
        expects, and is created with the same tensor type (and therefore
        CPU/GPU placement) as the module's own parameters.
        """
        weight = next(self.parameters())
        result = Variable(torch.zeros(self.nlayers * self.ndirection, bsz, self.nhid))
        return result.type_as(weight)

In [2]:
# Configuration shared by the cells of this notebook.
VOCAB_SIZE = 6773  # decoder output width; equals the word count printed when the Vocab is built
INPUT_DIM = 4096  # size of one video feature vector fed to the encoder
# NOTE(review): the cells shown in this notebook pass '../data/MLDS_hw2_data/...'
# literals rather than DATA_DIR / LABEL_PATH - confirm which location is current.
DATA_DIR = './MLDS_hw2_data/training_data/feat'
SAVE_DIR = './save'
LABEL_PATH = './MLDS_hw2_data/training_label.json'
# NOTE(review): the DataLoader in this notebook is created with batch_size = 16,
# not with this constant - confirm which value is intended.
BATCH_SIZE = 32
BOS_TOKEN = 1  # index of '<BOS>' in Vocab
EOS_TOKEN = 2  # index of '<EOS>' in Vocab
hidden_size = 256  # GRU hidden size used for both encoder and decoder
BIDIRECTIONAL = False
# NOTE(review): `model` is a plain string here, while a later cell calls
# methods on a `model` object - that cell presumably belongs to another session.
model = 'S2VT'
postfix = 'new'
class Vocab:
    """Two-way word/index mapping built from every caption in a label file.

    Indices 0, 1 and 2 are reserved for '<PAD>', '<BOS>' and '<EOS>'; every
    other word receives the next free index in order of first appearance.
    """

    def __init__(self, label_path):
        print("Building Vocab")
        with open(label_path, 'r') as label_file:
            self.label = json.load(label_file)
        specials = ('<PAD>', '<BOS>', '<EOS>')
        self.vocab2index = {token: position for position, token in enumerate(specials)}
        self.index2vocab = {position: token for position, token in enumerate(specials)}
        self.num_words = len(self.vocab2index)
        self.build()

        print(self.num_words, 'words in the bank.')

    def build(self):
        """Register every whitespace-separated word of every caption."""
        for video in self.label:
            for text in video["caption"]:
                # Drop the punctuation characters before splitting into words.
                for mark in '.!()':
                    text = text.replace(mark, '')
                for word in text.split():
                    if word in self.vocab2index:
                        continue
                    self.vocab2index[word] = self.num_words
                    self.index2vocab[self.num_words] = word
                    self.num_words += 1

    def sen2index(self, sen):
        """Encode a list of words as indices wrapped in BOS ... EOS.

        Raises KeyError for a word that is not in the vocabulary.
        """
        encoded = [BOS_TOKEN]
        encoded.extend(self.vocab2index[word] for word in sen)
        encoded.append(EOS_TOKEN)
        return encoded

class TA_Dataset(Dataset):
    """Training set pairing a video's feature array with one encoded caption.

    Args:
        data_dir: directory holding one '<id>.npy' feature file per video.
        label_path: JSON file with a list of {"id": ..., "caption": [...]}.
        batch_sample: when True each item is one video with a randomly drawn
            caption; when False every (video, caption) pair is its own item.

    Captions are encoded through the notebook-level `vocab` object and
    zero-padded (0 is '<PAD>') to MAX_CAPTION_LEN tokens.
    """

    MAX_CAPTION_LEN = 50

    def __init__(self,data_dir,label_path,batch_sample = True):
        print("Preparing dataset")
        self.data_dir = data_dir
        with open(label_path,'r') as f:
            self.label = json.load(f)
        self.batch_sample = batch_sample

        if batch_sample:
            self.total_len = len(self.label)
        else:
            counts = [len(video['caption']) for video in self.label]
            # Materialise the running totals: `accumulate` alone is a one-shot
            # iterator and could not be re-read on every __getitem__ call.
            self.accu = list(accumulate(counts))
            self.total_len = sum(counts)

    def __len__(self):
        return self.total_len

    def _make_sample(self, video, caption):
        """Load the features of `video` and encode `caption` for it."""
        feat = np.load(os.path.join(self.data_dir, video['id'] + '.npy'))
        for ch in '.!()':
            caption = caption.replace(ch, '')
        tokens = vocab.sen2index(caption.split())
        # A non-positive pad count yields no padding, so captions longer than
        # MAX_CAPTION_LEN are returned at their full length.
        tokens = tokens + [0] * (self.MAX_CAPTION_LEN - len(tokens))
        return torch.FloatTensor(feat), torch.LongTensor(tokens)

    def __getitem__(self, index):
        """Return (features, caption) as a FloatTensor and a LongTensor.

        Raises:
            IndexError: if `index` is outside the dataset.
        """
        if self.batch_sample:
            video = self.label[index]
            captions = video['caption']
            return self._make_sample(video, captions[randint(0, len(captions) - 1)])

        if not 0 <= index < self.total_len:
            raise IndexError('index {} is out of range'.format(index))
        first = 0  # flat index of the current video's first caption
        for video, end in zip(self.label, self.accu):
            if index < end:
                return self._make_sample(video, video['caption'][index - first])
            first = end
        raise IndexError('index {} is out of range'.format(index))

In [10]:
# Training hyper-parameters.
epoch = 100
bsz = 16

# Vocabulary and training set are both built from the same label file.
vocab = Vocab('../data/MLDS_hw2_data/training_label.json')
DS = TA_Dataset(
    '../data/MLDS_hw2_data/training_data/feat',
    '../data/MLDS_hw2_data/training_label.json',
)

# Single-layer encoder/decoder pair. The decoder input is twice the hidden
# size because the training cell concatenates two hidden-sized tensors
# along the feature axis before calling it.
encoder = Encoder(
    ninp=INPUT_DIM,
    nhid=hidden_size,
    nlayers=1,
    bidirectional=BIDIRECTIONAL,
).cuda()
decoder = Decoder(
    ninp=hidden_size + hidden_size,
    nhid=hidden_size,
    ntoken=VOCAB_SIZE,
    nlayers=1,
    bidirectional=BIDIRECTIONAL,
).cuda()

criterion = nn.CrossEntropyLoss().cuda()

# One Adam optimiser per module, both with the same learning rate.
e_optimizer = optim.Adam(encoder.parameters(), lr=1e-3)
d_optimizer = optim.Adam(decoder.parameters(), lr=1e-3)

Building Vocab
6773 words in the bank.
Preparing dataset


In [11]:
# Shuffled mini-batches of 16 samples, loaded by 4 worker processes.
train_loader = DataLoader(DS, batch_size=16, shuffle=True, num_workers=4)

In [12]:
#############################################
#   Training
#############################################
# Every batch is processed in two phases that share the recurrent states:
#   1. encoding - the encoder reads the video features while the decoder is
#      fed zeros concatenated with the encoder outputs;
#   2. decoding - the encoder is fed one all-zero frame per step while the
#      decoder receives its own previous output concatenated with the encoder
#      output and is scored against the caption token of that step.
# NOTE(review): `criterion` is applied at every caption position, so the
# zero-padded positions produced by the dataset are scored as well - confirm
# whether they should be ignored.

for i in range(epoch):
    total_loss = 0
    total_len = 0
    for feat, caption in tqdm(train_loader):
        e_optimizer.zero_grad()
        d_optimizer.zero_grad()

        e_hid, d_hid = encoder.init_hidden(feat.size(0)), decoder.init_hidden(feat.size(0))
        # Size the zero padding from the batch itself (as the testing cell
        # does) instead of assuming exactly 80 frames per video.
        epad = Variable(torch.zeros(feat.size(0), feat.size(1), hidden_size*(1+BIDIRECTIONAL))).cuda()
        dpad = Variable(torch.zeros(feat.size(0), 1, INPUT_DIM)).cuda()

        # Phase 1: encode the whole clip.
        feat = Variable(feat).cuda()
        e_out, e_hid = encoder(feat, e_hid)
        d_out, d_hid, dn_out = decoder(torch.cat((epad, e_out), 2), d_hid)

        last_var = torch.zeros(feat.size(0), 1, hidden_size * (1+BIDIRECTIONAL))
        last_var = Variable(last_var).cuda()
        loss = 0

        # Phase 2: one decoding step per caption position.
        for index in range(caption.size(-1)):
            out1, e_hid = encoder(dpad, e_hid)
            out2, d_hid, dn_out = decoder(torch.cat((last_var, out1),2), d_hid)
            last_var = out2
            ans = Variable(caption[:,index]).cuda()
            loss += criterion(dn_out.view(-1, VOCAB_SIZE), ans)
        # Flatten before indexing so the scalar can be read whether the loss
        # is a 1-element or a 0-dimensional tensor.
        total_loss += float(loss.data.view(-1)[0])
        total_len += feat.size(0)
        loss.backward()
        e_optimizer.step()
        d_optimizer.step()
    # NOTE(review): this divides the summed per-batch losses by the number of
    # samples, so it is not a per-token average - confirm the intended metric.
    print ("epoch", i+1,"loss : ", total_loss/total_len)
    

100%|██████████| 91/91 [00:14<00:00,  6.39it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 1 loss :  4.796872919016871


100%|██████████| 91/91 [00:13<00:00,  6.58it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 2 loss :  3.0229291284495385


100%|██████████| 91/91 [00:14<00:00,  6.40it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 3 loss :  2.9638055183147562


100%|██████████| 91/91 [00:14<00:00,  6.26it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 4 loss :  2.86524468520592


100%|██████████| 91/91 [00:14<00:00,  6.17it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 5 loss :  2.9276112812963024


100%|██████████| 91/91 [00:14<00:00,  6.27it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 6 loss :  2.8199515191439923


100%|██████████| 91/91 [00:15<00:00,  5.93it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 7 loss :  2.8711292332616347


100%|██████████| 91/91 [00:14<00:00,  6.11it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 8 loss :  2.7445101060538457


100%|██████████| 91/91 [00:14<00:00,  6.26it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 9 loss :  2.8143863520128973


100%|██████████| 91/91 [00:15<00:00,  5.74it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 10 loss :  2.737545581686086


100%|██████████| 91/91 [00:15<00:00,  5.74it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 11 loss :  2.6954489609290815


100%|██████████| 91/91 [00:15<00:00,  5.81it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 12 loss :  2.641676633111362


100%|██████████| 91/91 [00:15<00:00,  5.84it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 13 loss :  2.638446256703344


100%|██████████| 91/91 [00:15<00:00,  5.79it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 14 loss :  2.6292021244969863


100%|██████████| 91/91 [00:15<00:00,  5.97it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 15 loss :  2.496824739390406


100%|██████████| 91/91 [00:14<00:00,  6.30it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 16 loss :  2.6425691736155543


100%|██████████| 91/91 [00:14<00:00,  6.32it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 17 loss :  2.559693065511769


100%|██████████| 91/91 [00:14<00:00,  6.24it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 18 loss :  2.5255504239838698


100%|██████████| 91/91 [00:14<00:00,  6.26it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 19 loss :  2.5710223428134262


100%|██████████| 91/91 [00:15<00:00,  5.92it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 20 loss :  2.507055892944336


100%|██████████| 91/91 [00:14<00:00,  6.12it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 21 loss :  2.3954088395217368


100%|██████████| 91/91 [00:14<00:00,  6.29it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 22 loss :  2.5156780545464876


100%|██████████| 91/91 [00:15<00:00,  5.94it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 23 loss :  2.5132491144640694


100%|██████████| 91/91 [00:16<00:00,  5.63it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 24 loss :  2.4879088710916455


100%|██████████| 91/91 [00:16<00:00,  5.58it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 25 loss :  2.4368191791402882


100%|██████████| 91/91 [00:16<00:00,  5.56it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 26 loss :  2.355327340487776


100%|██████████| 91/91 [00:15<00:00,  5.74it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 27 loss :  2.387327152120656


100%|██████████| 91/91 [00:15<00:00,  5.90it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 28 loss :  2.359730252233045


100%|██████████| 91/91 [00:16<00:00,  5.59it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 29 loss :  2.3273862957132274


100%|██████████| 91/91 [00:16<00:00,  5.59it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 30 loss :  2.4242064653593918


100%|██████████| 91/91 [00:16<00:00,  5.56it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 31 loss :  2.3606816535160458


100%|██████████| 91/91 [00:16<00:00,  5.38it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 32 loss :  2.410198272178913


100%|██████████| 91/91 [00:15<00:00,  5.78it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 33 loss :  2.26796600867962


100%|██████████| 91/91 [00:15<00:00,  5.83it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 34 loss :  2.3224379164597084


100%|██████████| 91/91 [00:15<00:00,  5.89it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 35 loss :  2.432056875557735


100%|██████████| 91/91 [00:14<00:00,  6.47it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 36 loss :  2.281541820394582


100%|██████████| 91/91 [00:14<00:00,  6.25it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 37 loss :  2.283194648479593


100%|██████████| 91/91 [00:14<00:00,  6.24it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 38 loss :  2.2349780418132914


100%|██████████| 91/91 [00:13<00:00,  6.51it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 39 loss :  2.2751129965946593


100%|██████████| 91/91 [00:16<00:00,  5.68it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 40 loss :  2.2223189912993333


100%|██████████| 91/91 [00:16<00:00,  5.68it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 41 loss :  2.25972953927928


100%|██████████| 91/91 [00:15<00:00,  5.75it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 42 loss :  2.118805010565396


100%|██████████| 91/91 [00:16<00:00,  5.67it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 43 loss :  2.1901306349655676


100%|██████████| 91/91 [00:16<00:00,  5.63it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 44 loss :  2.2397022063156653


100%|██████████| 91/91 [00:14<00:00,  6.09it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 45 loss :  2.2441217330406453


100%|██████████| 91/91 [00:14<00:00,  6.37it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 46 loss :  2.2489501505884633


100%|██████████| 91/91 [00:15<00:00,  5.93it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 47 loss :  2.1296568692963698


100%|██████████| 91/91 [00:15<00:00,  5.96it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 48 loss :  2.0985507136377795


100%|██████████| 91/91 [00:15<00:00,  5.98it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 49 loss :  2.127141580252812


100%|██████████| 91/91 [00:14<00:00,  6.34it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 50 loss :  2.1042049868353483


100%|██████████| 91/91 [00:14<00:00,  6.10it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 51 loss :  2.10038984496018


100%|██████████| 91/91 [00:16<00:00,  5.59it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 52 loss :  2.118941684591359


100%|██████████| 91/91 [00:14<00:00,  6.15it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 53 loss :  2.1165977530643856


100%|██████████| 91/91 [00:14<00:00,  6.27it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 54 loss :  2.0372555673533475


100%|██████████| 91/91 [00:15<00:00,  6.04it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 55 loss :  2.0310452323124326


100%|██████████| 91/91 [00:14<00:00,  6.27it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 56 loss :  2.0473309931261787


100%|██████████| 91/91 [00:14<00:00,  6.13it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 57 loss :  2.0285599373126852


100%|██████████| 91/91 [00:15<00:00,  5.96it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 58 loss :  2.0779944728982858


100%|██████████| 91/91 [00:14<00:00,  6.22it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 59 loss :  2.054936608939335


100%|██████████| 91/91 [00:14<00:00,  6.38it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 60 loss :  2.024508409171269


100%|██████████| 91/91 [00:14<00:00,  6.39it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 61 loss :  1.9863517432377256


100%|██████████| 91/91 [00:15<00:00,  5.73it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 62 loss :  1.918145734852758


100%|██████████| 91/91 [00:15<00:00,  5.71it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 63 loss :  2.052486331545073


100%|██████████| 91/91 [00:14<00:00,  6.17it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 64 loss :  1.9816465483040646


100%|██████████| 91/91 [00:14<00:00,  6.44it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 65 loss :  2.0343687675739157


100%|██████████| 91/91 [00:14<00:00,  6.41it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 66 loss :  1.9489605042030071


100%|██████████| 91/91 [00:14<00:00,  6.10it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 67 loss :  2.0658445082039667


100%|██████████| 91/91 [00:14<00:00,  6.16it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 68 loss :  1.95071182908683


100%|██████████| 91/91 [00:14<00:00,  6.20it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 69 loss :  1.9547638215689824


100%|██████████| 91/91 [00:14<00:00,  6.37it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 70 loss :  1.9261830047081256


100%|██████████| 91/91 [00:13<00:00,  6.52it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 71 loss :  1.8620912078331258


100%|██████████| 91/91 [00:13<00:00,  6.61it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 72 loss :  1.9327720615781587


100%|██████████| 91/91 [00:13<00:00,  6.81it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 73 loss :  1.9302186097769902


100%|██████████| 91/91 [00:13<00:00,  6.85it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 74 loss :  1.8885094978069437


100%|██████████| 91/91 [00:13<00:00,  6.69it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 75 loss :  1.9701028534461713


100%|██████████| 91/91 [00:13<00:00,  6.72it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 76 loss :  1.8891925154061153


100%|██████████| 91/91 [00:13<00:00,  6.78it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 77 loss :  1.850044875309385


100%|██████████| 91/91 [00:13<00:00,  6.59it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 78 loss :  1.851461690705398


100%|██████████| 91/91 [00:13<00:00,  6.71it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 79 loss :  1.800219533196811


100%|██████████| 91/91 [00:13<00:00,  6.78it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 80 loss :  1.8879962697522394


100%|██████████| 91/91 [00:13<00:00,  6.52it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 81 loss :  1.828589869532092


100%|██████████| 91/91 [00:13<00:00,  6.79it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 82 loss :  1.839430449913288


100%|██████████| 91/91 [00:13<00:00,  6.73it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 83 loss :  1.8531535799749965


100%|██████████| 91/91 [00:13<00:00,  6.74it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 84 loss :  1.880437864106277


100%|██████████| 91/91 [00:13<00:00,  6.82it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 85 loss :  1.8707146874789533


100%|██████████| 91/91 [00:13<00:00,  6.80it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 86 loss :  1.8794891607350317


100%|██████████| 91/91 [00:13<00:00,  6.78it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 87 loss :  1.837823511320969


100%|██████████| 91/91 [00:13<00:00,  6.58it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 88 loss :  1.8543542704088936


100%|██████████| 91/91 [00:14<00:00,  6.49it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 89 loss :  1.8030151696040713


100%|██████████| 91/91 [00:13<00:00,  6.67it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 90 loss :  1.785056695609257


100%|██████████| 91/91 [00:13<00:00,  6.75it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 91 loss :  1.7835150120176118


100%|██████████| 91/91 [00:13<00:00,  6.69it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 92 loss :  1.7927738492242222


100%|██████████| 91/91 [00:13<00:00,  6.77it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 93 loss :  1.850563280829068


100%|██████████| 91/91 [00:13<00:00,  6.74it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 94 loss :  1.7239202907167632


100%|██████████| 91/91 [00:13<00:00,  6.74it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 95 loss :  1.7193121219503469


100%|██████████| 91/91 [00:13<00:00,  6.78it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 96 loss :  1.7807588577270508


100%|██████████| 91/91 [00:13<00:00,  6.72it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 97 loss :  1.7832908367288525


100%|██████████| 91/91 [00:13<00:00,  6.69it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 98 loss :  1.7681344176983012


100%|██████████| 91/91 [00:13<00:00,  6.76it/s]
  0%|          | 0/91 [00:00<?, ?it/s]

epoch 99 loss :  1.776755503950448


100%|██████████| 91/91 [00:13<00:00,  6.73it/s]

epoch 100 loss :  1.7744299290097993





In [15]:
#############################################
#   Testing
#############################################
class Testset(Dataset):
    """Test videos listed one id per line, paired with their feature arrays.

    Args:
        data: directory holding one '<id>.npy' feature file per video.
        ids: text file with one video id per line.
    """

    def __init__(self, data, ids):
        self.data_dir = data
        with open(ids, 'r') as f:
            # Skip blank lines (e.g. a trailing newline); they would otherwise
            # become an empty id and a lookup of the non-existent '.npy'.
            self.label = [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        """Return (video id, float32 feature tensor) for position `index`.

        Raises IndexError past the last id, which is what ends a plain
        `for ... in dataset` loop.
        """
        video_id = self.label[index]
        data = np.load(os.path.join(self.data_dir, video_id + '.npy'))
        return video_id, torch.from_numpy(data.astype(np.float32))

def sentence(l):
    """Turn a sequence of token ids into a caption string.

    Every ordinary word is followed by a space; '<PAD>' (0) and '<BOS>' are
    skipped; the first '<EOS>' appends '.' and ends the sentence.

    Args:
        l: iterable of token ids - plain ints or single-element tensors.

    Returns:
        The caption text (without a final '.' if no '<EOS>' was seen).
    """
    words = []
    for i in l:
        # Normalise to a plain int: a tensor scalar compares fine but is not
        # a usable key for the index2vocab dictionary.
        token = int(i)
        if token == EOS_TOKEN:
            return ''.join(words) + '.'
        if token == 0 or token == BOS_TOKEN:
            continue
        words.append(vocab.index2vocab[token] + ' ')
    return ''.join(words)
    

# with open('output_model.ffs', 'rb') as f:
#     model = torch.load(f)
    
DS = Testset('../data/MLDS_hw2_data/testing_data/feat', '../data/MLDS_hw2_data/testing_id.txt')

# Put both modules in evaluation mode before generating captions.
encoder.eval()
decoder.eval()

for idx, feat in DS:
    feat = Variable(feat).view(1, -1, INPUT_DIM).cuda()
    e_hid = encoder.init_hidden(1)
    d_hid = decoder.init_hidden(1)

    epad = Variable(torch.zeros(1, feat.size(1), hidden_size * (1+BIDIRECTIONAL))).cuda()
    dpad = Variable(torch.zeros(1, 1, INPUT_DIM)).cuda()

    # Encoding phase: read the whole clip.
    e_out, e_hid = encoder(feat, e_hid)
    d_out, d_hid, dn_out = decoder(torch.cat((epad, e_out), 2), d_hid)

    # NOTE(review): the training cell starts decoding from an all-zero
    # vector, whereas a 1 is written into the first component here - confirm
    # which start vector is intended.
    last_var = torch.zeros(1, 1, hidden_size * (1+BIDIRECTIONAL))
    last_var[0][0][0] = 1 # BOS
    last_var = Variable(last_var).cuda()

    # Decoding phase: greedy, at most 50 tokens.
    ans = []
    for i in range(50):
        out1, e_hid = encoder(dpad, e_hid)
        out2, d_hid, dn_out = decoder(torch.cat((last_var, out1), 2), d_hid)
        last_var = out2
        # Store plain ints so they can be compared with the token constants
        # and used as dictionary keys by sentence().
        token = int(torch.max(dn_out, 2)[1].data[0][0])
        ans.append(token)
        if token == EOS_TOKEN:
            # sentence() stops at the first EOS, so later steps cannot change
            # the printed caption.
            break

    print(idx + ',' + sentence(ans))

ScdUht-pM6s_53_63.avi,A woman is a a a a 
wkgGxsuNVSg_34_41.avi,A man is running a a .
BtQtRGI0F2Q_15_20.avi,A man is doing .
k06Ge9ANKM8_5_16.avi,A little is is with a .
sZf3VDsdDPM_107_114.avi,A man is singing .
shPymuahrsc_5_12.avi,A slow is is a a .
XOAgUVVwKEA_8_20.avi,A girl is is a .
ufFT2BWh3BQ_0_8.avi,A baby panda laying on .
5YJaS2Eswg0_22_26.avi,A man is a a a .
lw7pTwpx0K0_38_48.avi,A man is a a .
UbmZAe5u5FI_132_141.avi,A person is cutting .
xCFCXzDUGjY_5_9.avi,A man is a a .
He7Ge7Sogrk_47_70.avi,A person is a a .
tJHUH9tpqPg_113_118.avi,A man is a a a .
n016q1w8Q30_2_11.avi,A person is folding an .
RjpbFlOHFps_8_25.avi,Two are are .
6JnGBs88sL0_4_10.avi,A girl girl on on a 
EpMuCrbxE8A_107_115.avi,A man is a a a .
HAjwXjwN9-A_16_24.avi,Two are are a a .
4xVGpDmA4lE_23_33.avi,A man is walking a a .
k5OKBX2e7xA_19_32.avi,A man is riding a a .
Jag7oTemldY_12_25.avi,A man is firing a a .
8MVo7fje_oE_125_130.avi,A man is a a plastic a a .
bqMmyY1ImkI_0_14.avi,A woman is is a 

In [16]:
# Persist the trained encoder and decoder as whole module objects; a later
# cell reads the same two files back with torch.load.
# NOTE(review): saving `state_dict()` instead is the commonly recommended
# alternative, but it would require changing the loading cell as well.
with open('MODEL_ENCODER', 'wb') as f:
    torch.save(encoder, f)
with open('MODEL_DECODER', 'wb') as f:
    torch.save(decoder, f)

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [30]:

# NOTE(review): this cell uses `optimizer`, `vocabs` and a `model` object with
# `init_hidden` / `parameters`, none of which are defined by the cells shown
# in this notebook (the only visible `model` is the string 'S2VT'). It
# presumably belongs to an earlier session - confirm before re-running.
hid1, hid2 = model.init_hidden(1), model.init_hidden(1)
for i in range(epoch):
    total_loss = 0
    for data in tqdm(train_loader):
        optimizer.zero_grad()
        feat = data[0]
        feat = Variable(torch.from_numpy(feat).view(1, feat.shape[0], -1).float()).cuda()
        caption = data[1][0]
        
        # Strip the punctuation characters before splitting into words.
        for ch in '.!()':
            if ch in caption:
                caption = caption.replace(ch, '')

        # Target = word indices followed by EOS, shaped (1, length).
        tar = Variable(torch.LongTensor([vocabs.vocab2index[word] for word in caption.split()]+[EOS_TOKEN])).view(1, -1).cuda()
        output = model(feat, tar.size(1))
        
        # NOTE(review): 6772 is hard-coded here while VOCAB_SIZE is 6773 - confirm.
        loss = criterion(output.view(-1, 6772), tar.view(-1))
        # Accumulates a tensor rather than a Python number, which is why the
        # recorded output prints a FloatTensor.
        total_loss += loss.data
        loss.backward()
        
        torch.nn.utils.clip_grad_norm(model.parameters(), 0.25)
        optimizer.step()
        # hid1, hid2 = repackage_hidden(hid1), repackage_hidden(hid2)
    print('Epoch {} Loss: {}'.format(i, total_loss/len(DS)))
    
with open('output_model.ffs', 'wb') as f:
    torch.save(model, f)


100%|██████████| 1450/1450 [01:45<00:00, 13.74it/s]
  0%|          | 2/1450 [00:00<01:42, 14.15it/s]

Epoch 0 Loss: 
 5.1666
[torch.cuda.FloatTensor of size 1 (GPU 0)]



100%|██████████| 1450/1450 [01:51<00:00, 13.03it/s]
  0%|          | 2/1450 [00:00<01:39, 14.62it/s]

Epoch 1 Loss: 
 4.4180
[torch.cuda.FloatTensor of size 1 (GPU 0)]



100%|██████████| 1450/1450 [01:44<00:00, 13.91it/s]
  0%|          | 2/1450 [00:00<01:37, 14.83it/s]

Epoch 2 Loss: 
 4.2193
[torch.cuda.FloatTensor of size 1 (GPU 0)]



100%|██████████| 1450/1450 [01:44<00:00, 13.91it/s]
  0%|          | 2/1450 [00:00<01:38, 14.63it/s]

Epoch 3 Loss: 
 4.0765
[torch.cuda.FloatTensor of size 1 (GPU 0)]



100%|██████████| 1450/1450 [01:43<00:00, 13.99it/s]

Epoch 4 Loss: 
 3.9658
[torch.cuda.FloatTensor of size 1 (GPU 0)]




  "type " + obj.__name__ + ". It won't be checked "


In [5]:
# Rebuild the vocabulary from the training labels; sentence() below uses it
# to map predicted indices back to words.
vocab = Vocab('../data/MLDS_hw2_data/training_label.json')

# Restore the encoder and decoder from the files written by the saving cell.
# NOTE(review): torch.load unpickles whole objects, so only load files you
# produced yourself; the Encoder / Decoder classes presumably have to be
# defined before this cell runs - confirm.
with open('MODEL_ENCODER', 'rb') as f:
    encoder = torch.load(f)

with open('MODEL_DECODER', 'rb') as f:
    decoder = torch.load(f)


#############################################
#   Testing
#############################################
class Testset(Dataset):
    """Test videos listed one id per line, paired with their feature arrays.

    Args:
        data: directory holding one '<id>.npy' feature file per video.
        ids: text file with one video id per line.
    """

    def __init__(self, data, ids):
        self.data_dir = data
        with open(ids, 'r') as f:
            # Skip blank lines (e.g. a trailing newline); they would otherwise
            # become an empty id and a lookup of the non-existent '.npy'.
            self.label = [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        """Return (video id, float32 feature tensor) for position `index`.

        Raises IndexError past the last id, which is what ends a plain
        `for ... in dataset` loop.
        """
        video_id = self.label[index]
        data = np.load(os.path.join(self.data_dir, video_id + '.npy'))
        return video_id, torch.from_numpy(data.astype(np.float32))

def sentence(l):
    """Turn a sequence of token ids into a caption string.

    Every ordinary word is followed by a space; '<PAD>' (0) and '<BOS>' are
    skipped; the first '<EOS>' appends '.' and ends the sentence.

    Args:
        l: iterable of token ids - plain ints or single-element tensors.

    Returns:
        The caption text (without a final '.' if no '<EOS>' was seen).
    """
    words = []
    for i in l:
        # Normalise to a plain int: a tensor scalar compares fine but is not
        # a usable key for the index2vocab dictionary.
        token = int(i)
        if token == EOS_TOKEN:
            return ''.join(words) + '.'
        if token == 0 or token == BOS_TOKEN:
            continue
        words.append(vocab.index2vocab[token] + ' ')
    return ''.join(words)
    

# with open('output_model.ffs', 'rb') as f:
#     model = torch.load(f)
    
DS = Testset('../data/MLDS_hw2_data/testing_data/feat', '../data/MLDS_hw2_data/testing_id.txt')

# Put both modules in evaluation mode before generating captions.
encoder.eval()
decoder.eval()

for idx, feat in DS:
    feat = Variable(feat).view(1, -1, INPUT_DIM).cuda()
    e_hid = encoder.init_hidden(1)
    d_hid = decoder.init_hidden(1)

    epad = Variable(torch.zeros(1, feat.size(1), hidden_size * (1+BIDIRECTIONAL))).cuda()
    dpad = Variable(torch.zeros(1, 1, INPUT_DIM)).cuda()

    # Encoding phase: read the whole clip.
    e_out, e_hid = encoder(feat, e_hid)
    d_out, d_hid, dn_out = decoder(torch.cat((epad, e_out), 2), d_hid)

    # NOTE(review): the training cell starts decoding from an all-zero
    # vector, whereas a 1 is written into the first component here - confirm
    # which start vector is intended.
    last_var = torch.zeros(1, 1, hidden_size * (1+BIDIRECTIONAL))
    last_var[0][0][0] = 1 # BOS
    last_var = Variable(last_var).cuda()

    # Decoding phase: greedy, at most 50 tokens.
    ans = []
    for i in range(50):
        out1, e_hid = encoder(dpad, e_hid)
        out2, d_hid, dn_out = decoder(torch.cat((last_var, out1), 2), d_hid)
        last_var = out2
        # Store plain ints so they can be compared with the token constants
        # and used as dictionary keys by sentence().
        token = int(torch.max(dn_out, 2)[1].data[0][0])
        ans.append(token)
        if token == EOS_TOKEN:
            # sentence() stops at the first EOS, so later steps cannot change
            # the printed caption.
            break

    print(idx + ',' + sentence(ans))

Building Vocab
6773 words in the bank.




ScdUht-pM6s_53_63.avi,A woman is a a a a 
wkgGxsuNVSg_34_41.avi,A man is running a a .
BtQtRGI0F2Q_15_20.avi,A man is doing .
k06Ge9ANKM8_5_16.avi,A little is is with a .
sZf3VDsdDPM_107_114.avi,A man is singing .
shPymuahrsc_5_12.avi,A slow is is a a .
XOAgUVVwKEA_8_20.avi,A girl is is a .
ufFT2BWh3BQ_0_8.avi,A baby panda laying on .
5YJaS2Eswg0_22_26.avi,A man is a a a .
lw7pTwpx0K0_38_48.avi,A man is a a .
UbmZAe5u5FI_132_141.avi,A person is cutting .
xCFCXzDUGjY_5_9.avi,A man is a a .
He7Ge7Sogrk_47_70.avi,A person is a a .
tJHUH9tpqPg_113_118.avi,A man is a a a .
n016q1w8Q30_2_11.avi,A person is folding an .
RjpbFlOHFps_8_25.avi,Two are are .
6JnGBs88sL0_4_10.avi,A girl girl on on a 
EpMuCrbxE8A_107_115.avi,A man is a a a .
HAjwXjwN9-A_16_24.avi,Two are are a a .
4xVGpDmA4lE_23_33.avi,A man is walking a a .
k5OKBX2e7xA_19_32.avi,A man is riding a a .
Jag7oTemldY_12_25.avi,A man is firing a a .
8MVo7fje_oE_125_130.avi,A man is a a plastic a a .
bqMmyY1ImkI_0_14.avi,A woman is is a