# In [1]:
import torch.nn.functional as F
import torch.nn as nn
import torch
import time
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pickle
import re
import json
import os
import sys
from scipy.special import expit
import random
import numpy as np
from torch.autograd import Variable

def data_preprocess(filepath='data/', min_count=5):
    """Build the caption vocabulary from ``training_label.json``.

    Each caption of each entry in the label file is split on whitespace
    after the punctuation characters ``. ! , ; ?`` have been replaced by
    spaces. Words that occur at least ``min_count`` times are kept and are
    numbered, in first-seen order, after the four reserved tokens
    ``<PAD>`` (0), ``<SOS>`` (1), ``<EOS>`` (2) and ``<UNK>`` (3).

    Args:
        filepath: Directory that contains ``training_label.json``.
        min_count: Minimum number of occurrences for a word to be kept.
            The default of 5 keeps words seen more than four times.

    Returns:
        A tuple ``(i2w, w2i, word_dict)`` where ``i2w`` maps index to word,
        ``w2i`` maps word to index (both include the reserved tokens), and
        ``word_dict`` maps each kept word to its occurrence count.
    """
    with open(os.path.join(filepath, 'training_label.json'), 'r') as f:
        entries = json.load(f)

    # Count every whitespace-separated word once punctuation is blanked out.
    counts = {}
    for entry in entries:
        for caption in entry['caption']:
            for word in re.sub(r'[.!,;?]', ' ', caption).split():
                counts[word] = counts.get(word, 0) + 1

    # Dict order is insertion order, so indices below follow first-seen order.
    word_dict = {word: n for word, n in counts.items() if n >= min_count}

    special_tokens = [('<PAD>', 0), ('<SOS>', 1), ('<EOS>', 2), ('<UNK>', 3)]
    offset = len(special_tokens)
    i2w = {}
    w2i = {}
    for position, word in enumerate(word_dict):
        i2w[position + offset] = word
        w2i[word] = position + offset

    for token, index in special_tokens:
        i2w[index] = token
        w2i[token] = index

    return i2w, w2i, word_dict

def s_split(sentence, word_dict, w2i):
    """Encode a caption string as a list of vocabulary indices.

    The punctuation characters ``. ! , ; ?`` are replaced by spaces and the
    result is split on whitespace. A word found in ``word_dict`` is mapped
    through ``w2i``; any other word becomes 3. The list starts with 1 (the
    SOS token) and ends with 2 (the EOS token).
    """
    token_ids = [1]  # SOS token at the beginning
    for word in re.sub(r'[.!,;?]', ' ', sentence).split():
        if word in word_dict:
            token_ids.append(w2i[word])
        else:
            token_ids.append(3)
    token_ids.append(2)  # EOS token at the end
    return token_ids

def annotate(label_file, word_dict, w2i):
    """Load a label JSON from ``data/`` and encode every caption.

    Returns a flat list of ``(id, encoded_caption)`` tuples, one per caption,
    where ``encoded_caption`` is the index list produced by ``s_split``.
    """
    with open('data/' + label_file, 'r') as f:
        entries = json.load(f)

    return [
        (entry['id'], s_split(text, word_dict, w2i))
        for entry in entries
        for text in entry['caption']
    ]

def avi(files_dir):
    """Load every file in ``data/<files_dir>`` with ``np.load``.

    Returns a dict keyed by the file name up to its first ``.npy``; the
    running file index is printed as each file is read.
    """
    feature_root = 'data/' + files_dir
    features = {}

    for index, filename in enumerate(os.listdir(feature_root)):
        print(index)
        array = np.load(os.path.join(feature_root, filename))
        video_id = filename.split('.npy')[0]
        features[video_id] = array

    return features

def minibatch(data):
    """Collate ``(features, caption)`` pairs into one padded batch.

    ``data`` is sorted in place by caption length, longest first. Returns
    the stacked features, a long tensor of captions zero-padded on the right
    to the longest caption, and the list of caption lengths.
    """
    data.sort(key=lambda pair: len(pair[1]), reverse=True)
    features, captions = zip(*data)
    features = torch.stack(features, 0)

    lengths = [len(caption) for caption in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (caption, size) in enumerate(zip(captions, lengths)):
        targets[row, :size] = caption[:size]
    return features, targets, lengths

class training_data(Dataset):
    """Dataset of ``(video features, encoded caption)`` pairs.

    Features are loaded once by ``avi`` and kept in ``self.avi``; captions
    are encoded once by ``annotate`` and kept in ``self.data_pair`` as
    ``(id, index_list)`` tuples.

    NOTE(review): ``label_file`` is handed to ``avi`` (which lists a
    directory) and ``files_dir`` to ``annotate`` (which opens a JSON file),
    so the two parameter names look swapped relative to those helpers.
    The wiring is left as-is because callers may rely on it -- confirm
    against the call site.
    """

    def __init__(self, label_file, files_dir, word_dict, w2i):
        self.label_file = label_file
        self.word_dict = word_dict
        self.files_dir = files_dir
        self.avi = avi(label_file)
        self.w2i = w2i
        self.data_pair = annotate(files_dir, word_dict, w2i)

    def __len__(self):
        """Return the number of (id, caption) pairs."""
        return len(self.data_pair)

    def __getitem__(self, idx):
        """Return ``(noisy_features, caption)`` for pair ``idx``.

        Noise drawn from ``{0, 1, ..., 1999} / 10000`` is added to the
        features on every call. The caption is returned as a float tensor.

        Raises:
            IndexError: if ``idx`` is not smaller than ``len(self)``.
        """
        # Raise IndexError rather than assert: asserts vanish under -O, and
        # IndexError is what Python's sequence iteration protocol expects.
        if idx >= len(self):
            raise IndexError('training_data index out of range')
        avi_file_name, sentence = self.data_pair[idx]
        data = torch.Tensor(self.avi[avi_file_name])
        noise = torch.empty_like(data).random_(0, 2000) / 10000.
        # Out-of-place add: the cached array in self.avi must never be
        # modified, even if the tensor constructor shares memory with it.
        return data + noise, torch.Tensor(sentence)

class lstm_attn_decode(nn.Module):
    def __init__(self, hidden_size, output_size, vocab_size, word_dim, dropout_percentage=0.35):
        """Create the decoder's sub-modules.

        Args:
            hidden_size: Width of the LSTM hidden state; also passed to the
                attention module and used as the input width of the final
                linear layer.
            output_size: Number of embedding rows and width of the final
                linear layer's output.
            vocab_size: Stored on the instance; not otherwise used here.
            word_dim: Width of a word embedding. The LSTM input width is
                ``hidden_size + word_dim``.
            dropout_percentage: Probability given to ``nn.Dropout``.
        """
        super(lstm_attn_decode, self).__init__()

        # Honour the constructor arguments instead of fixed 512 / 1024 / 0.35
        # so the embedding width always matches what the LSTM was sized for.
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.vocab_size = vocab_size
        self.word_dim = word_dim

        self.embedding = nn.Embedding(output_size, word_dim)
        self.dropout = nn.Dropout(dropout_percentage)
        self.lstm = nn.LSTM(hidden_size+word_dim, hidden_size, batch_first=True)
        self.attention = attention(hidden_size)
        self.to_final_output = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_last_hidden_state, encoder_output, targets=None, mode='train', tr_steps=None):
        _, batch_size, _ = encoder_last_hidden_state.size()
        
        decoder_current_hidden_state = None if encoder_last_hidden_state is None else encoder_last_hidden_state
        decoder_cxt = torch.zeros(decoder_current_hidden_state.size())

        decoder_current_input_word = Variable(torch.ones(batch_size, 1)).long()
        seq_logProb = []
        seq_predictions = []

        targets = self.embedding(targets)
        _, seq_len, _ = targets.size()

        i = 0
        while i < seq_len - 1:
            threshold = self.teacher_forcing_ratio(training_steps=tr_steps)
            if random.uniform(0.05, 0.995) > threshold: 
                current_input_word = targets[:, i]  
            else: 
                current_input_word = self.embedding(decoder_current_input_word).squeeze(1)

            context = self.attention(decoder_current_hidden_state, encoder_output)
            lstm_input = torch.cat([current_input_word, context], dim=1).unsqueeze(1)
            lstm_output, t = self.lstm(lstm_input, (decoder_current_hidden_state, decoder_cxt))
            decoder_current_hidden_state = t[0]
            logprob = self.to_final_output(lstm_output.squeeze(1))
            seq_logProb.append(logprob.unsqueeze(1))
            decoder_current_input_word = logprob.unsqueeze(1).max(2)[1]
            
            i += 1

        seq_logProb = torch.cat(seq_logProb, dim=1)
        seq_predictions = seq_logProb.max(2)[1]
        return seq_logProb, seq_predictions

    def infer(self, encoder_last_hidden_state, encoder_output):
        _, batch_size, _ = encoder_last_hidden_state.size()
        decoder_current_hidden_state = None if encoder_last_hidden_state is None else encoder_last_hidden_state
        decoder_current_input_word = Variable(torch.ones(batch_size, 1)).long()
        decoder_c
