In [1]:
import os
import glob
import pickle
import sys
sys.path.append("../modules/")

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import skimage
from tqdm.notebook import tqdm
from PIL import Image
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader

from utils import set_random_seed

# Per-channel normalisation statistics used by the dataset's default transform
# and for the padding value in its collate_fn. These are the values
# conventionally used with ImageNet-pretrained torchvision models.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# NOTE(review): not passed to the dataset in the cells visible here, so the
# dataset's own default (1.0) is what actually applies.
rescale_factor = 1.0
random_seed = 42

# Project helper imported from ../modules/utils; its implementation is not
# shown here -- presumably seeds the RNGs used by the notebook (TODO confirm).
set_random_seed(random_seed)
# Root directory of the full-length video data; the cells below join
# 'videos/<split>' and 'labels/<split>' onto it.
full_path = "../h4_data/FullLengthVideos/"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Random seed: 42


In [6]:
class FullLengthVideosDataset(Dataset):
    """In-memory dataset of clips cut from full-length videos.

    ``video_path`` must contain one entry per video (a "category"); every file
    inside ``video_path/<category>/`` is read as one frame, in sorted filename
    order. Unless ``test`` is set, ``video_label_path/<category>.txt`` must
    hold one integer label per frame, one per line.

    Each video is split into chunks by :meth:`trim_frames`; every chunk is one
    dataset item, a ``(frames, labels)`` tuple of numpy arrays. All frames are
    decoded eagerly in ``__init__`` and kept in memory.

    Args:
        video_path: Directory containing one sub-directory of frames per video.
        video_label_path: Directory containing ``<category>.txt`` label files.
        rescale_factor: Stored on the instance only; no rescaling is applied.
        length: Nominal chunk length, see :meth:`trim_frames`.
        overlap: Number of trailing frames of the previous chunk that are
            prepended to each chunk after the first.
        test: If true, label files are not read (all-zero labels are used
            internally) and ``collate_fn`` omits ``'labels'`` from the batch.
        sorting: If true, ``collate_fn`` orders each batch by descending
            clip length.
        transform: Callable applied to every frame in ``collate_fn``. Defaults
            to ``ToTensor`` followed by ``Normalize(MEAN, STD)``.

    Raises:
        ValueError: If ``length <= overlap`` or a label file does not contain
            exactly one label per frame.
    """

    def __init__(self, video_path, video_label_path, rescale_factor=1.0, length=100, overlap=15, test=False, sorting=False, transform=None):
        self.video_path = video_path
        self.video_label_path = video_label_path
        self.rescale_factor = rescale_factor
        self.length = length
        self.overlap = overlap
        self.test = test
        self.sorting = sorting

        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD)])

        print('preparing data...')
        self.categories = sorted(os.listdir(self.video_path))
        self.datas = []
        for category in tqdm(self.categories, total=len(self.categories)):
            # Sorted filename order defines the temporal order of the frames.
            frames_path = sorted(glob.glob(os.path.join(self.video_path, category, '*')))
            frames = np.stack([self._load_frame(path) for path in frames_path])
            labels = self._load_labels(category, len(frames_path))
            self.datas += self.trim_frames(frames, labels)

    @staticmethod
    def _load_frame(path):
        """Decode one image file into a uint8 numpy array, closing the file afterwards."""
        with Image.open(path) as frame:
            # np.array forces the pixel data to be read before the file is closed.
            return np.array(frame).astype(np.uint8)

    def _load_labels(self, category, num_frames):
        """Return the per-frame label array for ``category`` (all zeros in test mode)."""
        if self.test:
            return np.array([0] * num_frames)

        labels_path = os.path.join(self.video_label_path, '{}.txt'.format(category))
        with open(labels_path, 'r') as fin:
            # Blank lines (e.g. a trailing newline at end of file) carry no label.
            labels = np.array([int(line.strip()) for line in fin if line.strip()])

        # A mismatch would make the frame and label chunks drift apart silently.
        if len(labels) != num_frames:
            raise ValueError(
                '{}: found {} labels for {} frames'.format(labels_path, len(labels), num_frames)
            )
        return labels

    def __len__(self):
        """Number of chunks across all videos."""
        return len(self.datas)

    def __getitem__(self, index):
        """Return the ``(frames, labels)`` numpy tuple of chunk ``index``."""
        return self.datas[index]

    def collate_fn(self, datas):
        """Pad a list of ``(frames, labels)`` chunks into one batch.

        Returns:
            dict with
            ``'frames_len'``: list of the unpadded clip lengths,
            ``'frames'``: float tensor ``(batch, max_len, channel, H, W)``,
            ``'labels'``: long tensor ``(batch, max_len)``, zero padded;
            omitted when ``self.test`` is true.
        """
        if self.sorting:
            # Longest clip first.
            lens = [data[0].shape[0] for data in datas]
            sorted_idx = np.argsort(lens)[::-1]
            datas = [datas[idx] for idx in sorted_idx]

        batch = {}

        frames_len = [data[0].shape[0] for data in datas]
        padding_len = max(frames_len)
        batch['frames_len'] = frames_len

        # Stored frames are (num_frames, H, W, C); the transform yields (C, H, W).
        batch_size = len(datas)
        height, width, channel = datas[0][0].shape[1:]
        frames = torch.zeros((batch_size, padding_len, channel, height, width))
        # Padded steps hold the value an all-zero pixel gets under
        # Normalize(MEAN, STD). This uses MEAN/STD even when a custom
        # transform was supplied.
        for c, (mean, std) in enumerate(zip(MEAN, STD)):
            frames[:, :, c, :, :] = (frames[:, :, c, :, :] - mean) / std
        for idx, (clip, _) in enumerate(datas):
            for step, frame in enumerate(clip):
                frames[idx, step] = self.transform(frame)
        batch['frames'] = frames.float()

        if not self.test:
            labels = np.zeros((batch_size, padding_len), dtype=np.int64)
            for idx, (_, clip_labels) in enumerate(datas):
                labels[idx, :clip_labels.shape[0]] = clip_labels
            batch['labels'] = torch.tensor(labels).long()

        return batch

    def trim_frames(self, frames, labels):
        """Split one video into chunks, each overlapping the previous one.

        The video is cut with ``np.array_split`` into
        ``max(1, num_frames // (length - overlap))`` nearly equal parts; every
        part after the first is then prefixed with the last ``overlap`` frames
        (and labels) of the preceding part. ``length`` is therefore nominal:
        because the part count is rounded down, a chunk can be longer than
        ``length``.

        Args:
            frames: Array whose first axis indexes frames.
            labels: Array of per-frame labels, same first-axis length.

        Returns:
            List of ``(frame_chunk, label_chunk)`` tuples; empty if the video
            has no frames.

        Raises:
            ValueError: If ``self.length <= self.overlap``.
        """
        stride = self.length - self.overlap
        if stride <= 0:
            raise ValueError(
                'length ({}) must be greater than overlap ({})'.format(self.length, self.overlap)
            )

        num_frames = frames.shape[0]
        if num_frames == 0:
            return []

        # A video shorter than one stride used to give 0 sections, which
        # np.array_split rejects; keep such a video as a single chunk instead.
        chunk_size = max(1, num_frames // stride)

        frame_chunks = np.array_split(frames, chunk_size)
        label_chunks = np.array_split(labels, chunk_size)

        final_chunks = [(frame_chunks[0], label_chunks[0])]
        for i in range(1, chunk_size):
            if self.overlap > 0:
                frame_chunk = np.concatenate((frame_chunks[i-1][-self.overlap:], frame_chunks[i]))
                label_chunk = np.concatenate((label_chunks[i-1][-self.overlap:], label_chunks[i]))
                final_chunks.append((frame_chunk, label_chunk))
            else:
                final_chunks.append((frame_chunks[i], label_chunks[i]))

        return final_chunks

In [7]:
# Validation split: frames live under videos/valid, labels under labels/valid.
valid_video_dir = os.path.join(full_path, 'videos', 'valid')
valid_label_dir = os.path.join(full_path, 'labels', 'valid')
dataset = FullLengthVideosDataset(
    valid_video_dir,
    valid_label_dir,
    length=80,
    overlap=20,
    sorting=True,
)

preparing data...


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))




In [10]:
# Batches are padded and assembled by the dataset's own collate_fn.
dataloader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=32,
    shuffle=False,
    collate_fn=dataset.collate_fn,
)

In [11]:
# Drain the loader once; the batches themselves are discarded.
num_batches = len(dataloader)
progress = tqdm(dataloader, total=num_batches)
for _ in progress:
    pass

HBox(children=(FloatProgress(value=0.0, max=37.0), HTML(value='')))


