# In [6]:
import json
import os
import pickle
import random

import cv2
import numpy as np
import pandas
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

class DatasetManager(Dataset):
    """Dataset pairing fixed-length video clips with numericalized captions.

    The annotation JSON must contain a ``'sentences'`` list whose entries
    carry ``'video_id'`` and ``'caption'`` keys.  The entries are shuffled,
    a slice of them is kept according to ``split``, and every kept caption
    is preprocessed and numericalized with ``tokenizer``.  Videos are read
    lazily in ``__getitem__`` from ``<video_file_path>/<video_id>.mp4``.

    The tokenizer must provide ``preprocess``, ``process``, ``build_vocab``,
    ``vocab.stoi`` and ``batch_first``.
    """

    # Slice of the shuffled caption list that each supported split draws from.
    _SPLIT_SLICES = {
        "train": slice(0, 20000),
        "validation": slice(20000, 25000),
    }

    def __init__(self, tokenizer,
                 split: str,
                 json_file_path: str, # location of the json file
                 video_file_path: str, # location of the video files
                 vocabulary_size: int,
                 frames: int,
                 video_transform = None,
                 seed = 0):
        """Load the annotations, select the split and numericalize its captions.

        Args:
            tokenizer: Text field used to preprocess/numericalize captions.
                Its vocabulary is built here when ``split == "train"`` and
                it has no ``vocab`` attribute yet.
            split: ``"train"`` or ``"validation"``.
            json_file_path: Path of the annotation JSON file.
            video_file_path: Directory holding the ``.mp4`` files.
            vocabulary_size: Maximum vocabulary size passed to
                ``tokenizer.build_vocab``.
            frames: Number of leading frames returned per video.
            video_transform: Optional callable invoked as
                ``video_transform(image=frame)['image']`` on every RGB frame.
                When ``None`` the frames are used untransformed.
            seed: Seed of the shuffle applied before slicing.  Instances
                built with the same seed see the same ordering, so their
                "train" and "validation" slices do not overlap.  Pass
                ``None`` to shuffle with the global ``random`` state; splits
                built separately that way can share entries.

        Raises:
            ValueError: If ``split`` is not a supported split name.
        """
        super().__init__()

        # Fail fast: an unknown split used to surface later as an
        # UnboundLocalError.
        if split not in self._SPLIT_SLICES:
            raise ValueError(
                f"split must be one of {sorted(self._SPLIT_SLICES)}, got {split!r}")

        assert os.path.exists(json_file_path), json_file_path
        assert os.path.exists(video_file_path), video_file_path
        assert vocabulary_size > 0
        assert frames > 0, frames

        self.video_file_path = video_file_path
        self.textTokenizer = tokenizer
        self.video_transform = video_transform
        self.frames = frames
        self.split = split

        with open(json_file_path, "r") as f:
            json_data = json.load(f)

        videoInfo = json_data['sentences']

        # A private, seeded generator keeps the ordering identical across
        # instances so the split slices partition the same permutation.
        if seed is None:
            random.shuffle(videoInfo)
        else:
            random.Random(seed).shuffle(videoInfo)

        splittedVideoInfo = videoInfo[self._SPLIT_SLICES[split]]

        self.videos = [entry['video_id'] for entry in splittedVideoInfo]
        self.texts = [tokenizer.preprocess(entry['caption'])
                      for entry in splittedVideoInfo]

        if self.split == "train" and not hasattr(tokenizer, "vocab"):
            # The vocabulary is built from every caption in the file,
            # not only from the selected slice.
            self.buildVocab(videoInfo, vocabulary_size)

        # Numericalize all texts.
        self.texts = [tokenizer.process([tokens]) for tokens in tqdm(self.texts)]

    def buildVocab(self, videoInfo, vocabulary_size):
        """Build the tokenizer vocabulary from the captions in ``videoInfo``.

        Args:
            videoInfo: Iterable of entries with a ``'caption'`` key.
            vocabulary_size: Forwarded as ``max_size`` to ``build_vocab``.
        """
        texts = [self.textTokenizer.preprocess(entry['caption'])
                 for entry in videoInfo]
        self.textTokenizer.build_vocab(texts, max_size = vocabulary_size)

    def __len__(self):
        """Return the number of caption entries in the selected split."""
        return len(self.texts)

    def __getitem__(self, i):
        """Return the first ``self.frames`` frames of video ``i`` and its caption.

        Returns:
            A ``(frames_tensor, caption)`` tuple.  ``frames_tensor`` is a
            float32 tensor of the stacked frames with axis 3 moved to
            position 1; ``caption`` is the numericalized caption with
            singleton dimensions squeezed out.

        Raises:
            RuntimeError: If the video cannot be opened or yields fewer than
                ``self.frames`` frames.
        """
        path = os.path.join(self.video_file_path, self.videos[i] + '.mp4')
        vidFile = cv2.VideoCapture(path)

        try:
            if not vidFile.isOpened():
                raise RuntimeError(
                    f"Error while trying to read video {path}. Please check path again")

            clips = []
            while len(clips) < self.frames:
                ret, frame = vidFile.read()
                if not ret:
                    # Without this check a short video would spin forever:
                    # read() keeps failing while the capture stays open.
                    raise RuntimeError(
                        f"Video {path} yielded only {len(clips)} frames, "
                        f"{self.frames} are required")
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                if self.video_transform is not None:
                    frame = self.video_transform(image=frame)['image']
                clips.append(frame)
        finally:
            # Always hand the decoder back, also on the error paths.
            vidFile.release()

        # Assumes each frame is still channel-last after the transform, so
        # the stacked array becomes (frames, channels, height, width).
        input_frames = np.transpose(np.array(clips), (0, 3, 1, 2))
        input_frames = torch.tensor(input_frames, dtype=torch.float32)

        return input_frames, self.texts[i].squeeze()

    # To be used in the Data Loader collate_fn parameter.
    def create_batch(self, batch):
        """Collate ``(video, text)`` samples into batched tensors.

        Args:
            batch: Sequence of ``(video_tensor, text_tensor)`` pairs as
                produced by ``__getitem__``.

        Returns:
            ``(stacked_videos, padded_texts, text_lengths)`` where the texts
            are padded with the tokenizer's ``"<pad>"`` index and
            ``text_lengths`` lists the unpadded length of each text.
        """
        videos, texts = zip(*batch)

        # Compute text lengths for Pytorch's RNN library.
        text_lengths = [len(text) for text in texts]

        # Stack videos and pad text.
        stacked_videos = torch.stack(videos)
        padded_texts = pad_sequence(texts, batch_first = self.textTokenizer.batch_first,
                                    padding_value = self.textTokenizer.vocab.stoi["<pad>"])

        return stacked_videos, padded_texts, text_lengths
    