# Create ECO DataLoader

#### This notebook uses video data from the Kinetics dataset to build a DataLoader for ECO.

In [1]:
import os
from PIL import Image
import csv
import numpy as np

import torch
import torch.utils.data
from torch import nn

import torchvision

In [2]:
# Show the installed PyTorch version (this notebook was run with 0.4.1).
torch.__version__

'0.4.1'

In [3]:
def make_datapath_list(root_path):
    """
    Create a list of paths to the per-video frame folders under root_path.

    The expected layout is ``root_path/<class name>/<video id>/``, where each
    ``<video id>`` folder holds the image frames of one video. The source
    ``<video id>.mp4`` files (and any other loose files) are ignored.

    root_path: str, the root path to the data folder
    Returns: video_list: list of str, one path per frame folder, in
        os.listdir order (not sorted).
    """
    video_list = list()

    for class_name in os.listdir(root_path):
        class_path = os.path.join(root_path, class_name)
        # Skip stray files at the class level (e.g. .DS_Store, README):
        # calling os.listdir on them would raise NotADirectoryError.
        if not os.path.isdir(class_path):
            continue

        for entry in os.listdir(class_path):
            video_img_directory_path = os.path.join(class_path, entry)
            # Only sub-directories hold extracted frames; this skips the .mp4
            # sources and any other file. Using the full entry name (rather
            # than os.path.splitext) keeps folder names containing dots intact.
            if not os.path.isdir(video_img_directory_path):
                continue
            video_list.append(video_img_directory_path)

    return video_list

In [4]:
class VideoTransform():
    """
    Preprocess a group of video frames into a single stacked tensor.

    Both phases use the same deterministic pipeline:
    resize -> center crop -> tensor on the 0-255 scale -> normalize -> stack
    (Stack also reverses the channel axis).

    NOTE: normalization is applied per frame *before* Stack reverses the
    channel axis, so ``mean`` and ``std`` are matched against the channel
    order produced by GroupToTensor, not against the final stacked order.
    """

    def __init__(self, resize, crop_size, mean, std):
        def build_pipeline():
            # Each call yields a fresh Compose, exactly as if spelled out twice.
            return torchvision.transforms.Compose([
                GroupResize(int(resize)),
                GroupCenterCrop(crop_size),
                GroupToTensor(),
                GroupImgNormalize(mean, std),
                Stack(),
            ])

        self.data_transform = {
            phase_name: build_pipeline() for phase_name in ('train', 'val')
        }

    def __call__(self, img_group, phase):
        """
        Run the pipeline registered for ``phase`` on ``img_group``.

        Parameters
        ----------
        img_group : sequence of frames (PIL images when called from VideoDataset)
        phase : 'train' or 'val'
        """
        pipeline = self.data_transform[phase]
        return pipeline(img_group)

In [5]:

class GroupResize():
    """Resize every frame of a group with one shared torchvision Resize."""

    def __init__(self, resize, interpolation=Image.BILINEAR):
        # Built once, then reused for each frame of every group.
        self.rescaler = torchvision.transforms.Resize(resize, interpolation)

    def __call__(self, img_group):
        rescale = self.rescaler
        return list(map(rescale, img_group))


class GroupCenterCrop():
    """Center-crop every frame of a group to the same crop_size."""

    def __init__(self, crop_size):
        self.ccrop = torchvision.transforms.CenterCrop(crop_size)

    def __call__(self, img_group):
        cropped = []
        for frame in img_group:
            cropped.append(self.ccrop(frame))
        return cropped


class GroupToTensor():
    """Convert every frame to a float tensor kept on the 0-255 pixel scale."""

    def __init__(self):
        self.to_tensor = torchvision.transforms.ToTensor()

    def __call__(self, img_group):
        # ToTensor rescales pixel values into [0, 1]; multiplying by 255
        # restores the raw 0-255 range (presumably the scale the downstream
        # mean/std are expressed in).
        convert = self.to_tensor
        return [convert(frame) * 255 for frame in img_group]


class GroupImgNormalize():
    """Apply one torchvision Normalize(mean, std) to every frame of a group."""

    def __init__(self, mean, std):
        # Attribute name (sic) kept unchanged so existing references still work.
        self.normlize = torchvision.transforms.Normalize(mean, std)

    def __call__(self, img_group):
        normalize = self.normlize
        return list(map(normalize, img_group))


class Stack():
    """Stack a group of channel-first frame tensors into one (T, C, H, W) tensor."""

    def __call__(self, img_group):
        # Flipping dim 0 reverses the channel order (e.g. RGB -> BGR) of each
        # frame before the frames are stacked along a new leading time axis.
        reordered = [frame.flip(dims=[0]) for frame in img_group]
        return torch.stack(reordered, dim=0)

In [6]:
class VideoDataset(torch.utils.data.Dataset):
    """
    Dataset returning a fixed number of frames sampled from each video folder.

    Each entry of ``video_list`` is a directory of frames extracted from one
    video, located at ``.../<class name>/<video id>``. The class name is taken
    from the parent directory and mapped to an id through ``label_id_dict``.
    """

    def __init__(self, video_list, label_id_dict, num_segments, phase, transform, img_tmpl='image_{:05d}.jpg'):
        self.video_list = video_list  # list of frame-folder paths
        self.label_id_dict = label_id_dict  # class name -> integer id
        self.num_segments = num_segments  # number of frames sampled per video
        self.phase = phase  # train or val; forwarded to transform
        self.transform = transform  # callable(img_group, phase=...)
        self.img_tmpl = img_tmpl  # frame file-name template, formatted with a 1-based index

    def __len__(self):
        """Return the number of videos."""
        return len(self.video_list)

    def __getitem__(self, index):
        """Return (imgs_transformed, label, label_id, dir_path) for one video."""
        imgs_transformed, label, label_id, dir_path = self.pull_item(index)
        return imgs_transformed, label, label_id, dir_path

    def pull_item(self, index):
        """Load, label and transform the sampled frames of video ``index``."""
        dir_path = self.video_list[index]
        indices = self._get_indices(dir_path)
        img_group = self._load_imgs(dir_path, self.img_tmpl, indices)

        # The class name is the folder that contains the video folder. Reading
        # it from the path structure (instead of a fixed position in
        # dir_path.split('/')) works for any root depth, absolute paths,
        # trailing separators and OS-specific separators.
        label = os.path.basename(os.path.dirname(os.path.normpath(dir_path)))
        label_id = self.label_id_dict[label]

        imgs_transformed = self.transform(img_group, phase=self.phase)

        return imgs_transformed, label, label_id, dir_path

    def _load_imgs(self, dir_path, img_tmpl, indices):
        """Open the frames at ``indices`` as RGB PIL images, in order."""
        img_group = []

        for idx in indices:
            file_path = os.path.join(dir_path, img_tmpl.format(idx))
            # convert() returns an independent, fully loaded image, so the
            # source file can be closed immediately instead of waiting for
            # garbage collection (avoids "Too many open files" when many
            # frames or DataLoader workers are in flight).
            with Image.open(file_path) as img:
                img_group.append(img.convert('RGB'))

        return img_group

    def _get_indices(self, dir_path):
        """
        Return ``num_segments`` 1-based frame indices spread evenly over a video.

        The frames are split into ``num_segments`` equal-length segments and the
        centre of each segment is taken. For example, 250 frames / 16 segments
        gives a segment length of 15.625 and the indices
        [8 24 40 55 71 86 102 118 133 149 165 180 196 211 227 243].
        """
        # NOTE(review): every entry in dir_path is counted as a frame; extra
        # non-frame files inflate the count and can push the last index past
        # the final frame.
        num_frames = len(os.listdir(dir_path))

        tick = num_frames / float(self.num_segments)

        indices = np.array([int(tick / 2.0 + tick * x)
                            for x in range(self.num_segments)]) + 1

        return indices

In [7]:
# Collect one frame-folder path per video under the dataset root.
root_path = './data/kinetics_videos/'
video_list = make_datapath_list(root_path=root_path)

In [8]:
# Inspect the collected frame-folder paths.
video_list

['./data/kinetics_videos/bungee jumping/zkXOcxGnUhs_000025_000035',
 './data/kinetics_videos/bungee jumping/dAeUFSdYG1I_000010_000020',
 './data/kinetics_videos/bungee jumping/TUvSX0pYu4o_000002_000012',
 './data/kinetics_videos/bungee jumping/b6yQZjPE26c_000023_000033',
 './data/kinetics_videos/arm wrestling/BdMiTo_OtnU_000024_000034',
 './data/kinetics_videos/arm wrestling/5JzkrOVhPOw_000027_000037',
 './data/kinetics_videos/arm wrestling/ehLnj7pXnYE_000027_000037',
 './data/kinetics_videos/arm wrestling/C4lCVBZ3ux0_000028_000038']

In [9]:
resize, crop_size = 224, 224
# GroupImgNormalize runs per frame while the channels are still in RGB order
# (frames are opened with convert('RGB')); Stack reverses them to BGR only
# afterwards. 104 / 117 / 123 are the usual Caffe BGR channel means
# (B=104, G=117, R=123), so they are listed here in RGB order. That way each
# mean is subtracted from its own channel, and the final BGR tensor is
# [B-104, G-117, R-123].
mean, std = [123, 117, 104], [1, 1, 1]
video_transform = VideoTransform(resize, crop_size, mean, std)

In [11]:
def get_label_id_dictionary(label_dicitionary_path):
    """
    Read the class-label CSV into two lookup dictionaries.

    label_dicitionary_path: str, path to a CSV with (at least) the columns
        "class_label" and "label_id". One is subtracted from each label_id,
        so ids in the file are presumably 1-based.
    Returns: (label_id_dict, id_label_dict)
        label_id_dict: dict, class name -> 0-based id
        id_label_dict: dict, 0-based id -> class name
    If a class name or an id appears more than once, the first row wins.
    """
    label_id_dict = {}
    id_label_dict = {}

    # newline='' is what the csv module requires so that quoted fields with
    # embedded newlines are parsed correctly; the explicit encoding makes the
    # result independent of the platform's default locale.
    with open(label_dicitionary_path, newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f, delimiter=",", quotechar='"')
        for row in reader:
            class_label = row["class_label"]
            label_id = int(row["label_id"]) - 1
            label_id_dict.setdefault(class_label, label_id)
            id_label_dict.setdefault(label_id, class_label)

    return label_id_dict, id_label_dict

In [12]:
label_dicitionary_path = 'kinetics_400_label_dicitionary.csv'
# Two lookup tables: class name -> 0-based id, and 0-based id -> class name.
label_id_dict, id_label_dict = get_label_id_dictionary(
    label_dicitionary_path=label_dicitionary_path)
label_id_dict

{'abseiling': 0,
 'air drumming': 1,
 'answering questions': 2,
 'applauding': 3,
 'applying cream': 4,
 'archery': 5,
 'arm wrestling': 6,
 'arranging flowers': 7,
 'assembling computer': 8,
 'auctioning': 9,
 'baby waking up': 10,
 'baking cookies': 11,
 'balloon blowing': 12,
 'bandaging': 13,
 'barbequing': 14,
 'bartending': 15,
 'beatboxing': 16,
 'bee keeping': 17,
 'belly dancing': 18,
 'bench pressing': 19,
 'bending back': 20,
 'bending metal': 21,
 'biking through snow': 22,
 'blasting sand': 23,
 'blowing glass': 24,
 'blowing leaves': 25,
 'blowing nose': 26,
 'blowing out candles': 27,
 'bobsledding': 28,
 'bookbinding': 29,
 'bouncing on trampoline': 30,
 'bowling': 31,
 'braiding hair': 32,
 'breading or breadcrumbing': 33,
 'breakdancing': 34,
 'brush painting': 35,
 'brushing hair': 36,
 'brushing teeth': 37,
 'building cabinet': 38,
 'building shed': 39,
 'bungee jumping': 40,
 'busking': 41,
 'canoeing or kayaking': 42,
 'capoeira': 43,
 'carrying baby': 44,
 'cartw

In [13]:
# Validation dataset: 16 evenly spaced frames per video.
val_dataset = VideoDataset(
    video_list=video_list,
    label_id_dict=label_id_dict,
    num_segments=16,
    phase="val",
    transform=video_transform,
    img_tmpl='image_{:05d}.jpg',
)

In [14]:
# Notebook display of the dataset object; VideoDataset defines no __repr__,
# so only the default object repr is shown.
val_dataset

<__main__.VideoDataset at 0x7f0e769c0ad0>

In [15]:
index = 0  # 0-7: one of the 8 videos in video_list
# Fetch the sample once: every dataset access re-reads and re-transforms all
# sampled frames, so indexing four times did four times the work.
imgs_transformed, label, label_id, dir_path = val_dataset[index]
print(imgs_transformed.shape)  # (num_segments, channels, height, width)
print(label)  # class name
print(label_id)  # integer class id
print(dir_path)  # frame folder of this video


torch.Size([16, 3, 224, 224])
bungee jumping
40
./data/kinetics_videos/bungee jumping/zkXOcxGnUhs_000025_000035


In [16]:
batch_size = 8
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,  # keep video_list order
)

# Pull only the first mini-batch to check the collated shapes:
# (batch, num_segments, channels, height, width).
batch_iterator = iter(val_dataloader)
imgs_transformeds, labels, label_ids, dir_path = next(batch_iterator)
print(imgs_transformeds.shape)

torch.Size([8, 16, 3, 224, 224])
