In [19]:
import os
import csv
import sys
import numpy as np
from torch.utils.data import Dataset, DataLoader
sys.path.insert(1, '../')
from lxrt.SlowFast.slowfast.config.defaults import get_cfg

In [3]:
from lxrt.SlowFast.slowfast.datasets.tgif_direct import TGIF

In [4]:
cfg = get_cfg()
cfg_file = "../lxrt/SlowFast/configs/Kinetics/c2/SLOWFAST_8x8_R50.yaml"
cfg.merge_from_file(cfg_file)

In [74]:

class TGIFDataset(Dataset):
    def __init__(self, dataset_name='train', data_type=None, dataframe_dir=None, vocab_dir=None):
        self.dataframe_dir = dataframe_dir # of the form data/tgif/vocabulary
        self.vocab_dir = vocab_dir # of the form data/tgif/dataframe
        self.data_type = data_type # 'TRANS'
        self.dataset_name = dataset_name # 'train' or 'val' or 'test'

        self.csv = self.read_from_csvfile()
        self.header2idx = self.header2idx()
        self.gif_names = self.csv[:,self.header2idx['gif_name']]
        self.gif_tensor = None
        self.questions = self.csv[:,self.header2idx['question']]
        self.answers = self.csv[:,self.header2idx['answer']]
        self.mc_options = self.csv[:,self.header2idx['a1']:self.header2idx['a5']+1]
        ## GIF LOADER ##
        ## NOTE: May have to change the relative path of gif dir as 
        ## an extra argument to TGIF class init
        loader  = TGIF(cfg, "train")
        self.get_gif_tensor = loader.__getitem__
        
    def __getitem__(self, i): # whats the argument for this
    	#gif_path = os.path.join(self.dataframe_dir, 'gif_tensors')
    	#pick up ith gif_tensor
        #NOTE: gif_path is only the gif name, not the relative path
        # REturn value: tuple (slow frames, fast frames) where frame -> (t, 3, h, w)
        self.gif_tensor = self.get_gif_tensor(self.gif_names[i])
        return self.gif_tensor, self.questions[i], list(self.mc_options[i]), self.answers[i]
    
    def __len__(self):
        return len(self.questions)

    def header2idx(self):
    	return {'gif_name':0,'question':1,'a1':2,'a2':3,'a3':4,'a4':5,'a5':6,'answer':7,'vid_id':8,'key':9}

    def read_from_csvfile(self):
        assert self.data_type in ['TRANS', 'ACTION'] # ACTION just for starting, will be using TRANS finally

        self.total_q=[]
        if self.data_type=='TRANS':
            train_data_path = os.path.join(self.dataframe_dir, 'Train_transition_question.csv')
            test_data_path = os.path.join(self.dataframe_dir, 'Test_transition_question.csv')

        elif self.data_type=='ACTION':
            train_data_path = os.path.join(self.dataframe_dir, 'Train_action_question.csv')
            test_data_path = os.path.join(self.dataframe_dir, 'Test_action_question.csv')

        if not (os.path.exists(train_data_path) and os.path.exists(test_data_path)):
            print("Does not exist")

        csv_data=[]
        if self.dataset_name=='train':
        	with open(train_data_path) as file:
        		csv_reader = csv.reader(file, delimiter='\t')
        		for row in csv_reader:
        			csv_data.append(row)
        elif self.dataset_name=='test':
        	with open(test_data_path) as file:
        		csv_reader = csv.reader(file, delimiter='\t')
        		for row in csv_reader:
        			csv_data.append(row)
        csv_data.pop(0)

        return np.asarray(csv_data)

In [75]:
data_file = "../../../../../../IDL/project/tgif-qa/dataset/"
dataset = TGIFDataset(dataset_name='train', data_type='TRANS', dataframe_dir=data_file, vocab_dir=None)
dataset[0]

/users/cdwivedi/RL_EXP/IDL/project/tgif-qa/code/dataset/tgif/gifs/tumblr_nkgxdo19En1qf37ejo1_250.gi/*
torch.Size([3, 8, 256, 256]) torch.Size([3, 32, 256, 256])


([tensor([[[[-2.0000, -2.0000, -2.0000,  ..., -1.9998, -1.9998, -1.9998],
            [-2.0000, -2.0000, -2.0000,  ..., -1.9998, -1.9998, -1.9998],
            [-1.9999, -1.9999, -1.9990,  ..., -1.9994, -1.9998, -1.9998],
            ...,
            [-1.9999, -1.9999, -1.9997,  ..., -1.9975, -2.0000, -2.0000],
            [-1.9999, -1.9999, -1.9997,  ..., -1.9975, -2.0000, -2.0000],
            [-1.9999, -1.9999, -1.9997,  ..., -1.9975, -2.0000, -2.0000]],
  
           [[-1.9999, -1.9999, -1.9999,  ..., -1.9998, -1.9997, -1.9997],
            [-1.9999, -1.9999, -1.9999,  ..., -1.9998, -1.9997, -1.9997],
            [-1.9999, -1.9999, -1.9992,  ..., -1.9993, -1.9998, -1.9998],
            ...,
            [-2.0000, -2.0000, -1.9998,  ..., -1.9974, -2.0000, -2.0000],
            [-2.0000, -2.0000, -1.9998,  ..., -1.9975, -2.0000, -2.0000],
            [-2.0000, -2.0000, -1.9998,  ..., -1.9975, -1.9999, -1.9999]],
  
           [[-1.9999, -1.9999, -1.9999,  ..., -1.9999, -1.9999, -1.999

In [76]:
dataloader = DataLoader(dataset, batch_size=32)
for tensor, question, options, answer in dataloader:
    print("tensor: ", tensor)
    print("question: ", question)
    print("options: ", options)
    print("answer: ", answer)
    break

/users/cdwivedi/RL_EXP/IDL/project/tgif-qa/code/dataset/tgif/gifs/tumblr_nkgxdo19En1qf37ejo1_250.gi/*
torch.Size([3, 8, 256, 256]) torch.Size([3, 32, 256, 256])
/users/cdwivedi/RL_EXP/IDL/project/tgif-qa/code/dataset/tgif/gifs/tumblr_nkgxdo19En1qf37ejo1_250.gi/*
torch.Size([3, 8, 256, 256]) torch.Size([3, 32, 256, 256])
/users/cdwivedi/RL_EXP/IDL/project/tgif-qa/code/dataset/tgif/gifs/tumblr_nqc2mbmU2J1uxhtnwo1_400.gi/*
torch.Size([3, 8, 256, 256]) torch.Size([3, 32, 256, 256])
/users/cdwivedi/RL_EXP/IDL/project/tgif-qa/code/dataset/tgif/gifs/tumblr_nqc2mbmU2J1uxhtnwo1_400.gi/*
torch.Size([3, 8, 256, 256]) torch.Size([3, 32, 256, 256])
/users/cdwivedi/RL_EXP/IDL/project/tgif-qa/code/dataset/tgif/gifs/tumblr_nesl84pTnp1tmddexo1_400.gi/*
torch.Size([3, 8, 256, 256]) torch.Size([3, 32, 256, 256])
/users/cdwivedi/RL_EXP/IDL/project/tgif-qa/code/dataset/tgif/gifs/tumblr_nesl84pTnp1tmddexo1_400.gi/*
torch.Size([3, 8, 256, 256]) torch.Size([3, 32, 256, 256])
/users/cdwivedi/RL_EXP/IDL/project