In [19]:
from __future__ import print_function

In [3]:
#!jupyter nbconvert --to python Dataset.ipynb
from Dataset import GenericDataset

In [4]:
import cv2
import random
import torch
import torchvision.transforms as transforms
import torch.utils.data as data

In [5]:
class videoFrameError(Exception):
    """
        Raise when reading a non-existing frame in a video
        Attribute : 
            videoName --- video name or path
            frame     --- frame number
    """
    def __init__(self, videoName, frame):
        self.videoName = videoName
        self.frame = frame
    def __str__(self):
        return repr(self.videoName) + repr(self.frame)

In [6]:
def ReadRandomSequence(videoName, nbFrame, transform=transforms.ToTensor(), dropFrame=1):
    """
        Read a nbFrame number of frame randomly in the video
    """
    cap = cv2.VideoCapture(videoName)
    if cv2.CAP_PROP_FRAME_COUNT - (nbFrame*dropFrame) < 0:
        firstFrame=cv2.CAP_PROP_FRAME_COUNT
    else:
        firstFrame = random.randint(0, cap.get(cv2.CAP_PROP_FRAME_COUNT)- (nbFrame*dropFrame))
    
    ret = cap.set(cv2.CAP_PROP_POS_FRAMES,firstFrame) #go to first frame
    cframe = firstFrame #current frame

    if not ret:
        raise videoFrameError(videoName, cframe)
    t = torch.Tensor(nbFrame,3,225,225)
    for i in range(nbFrame):
        ret, frame = cap.read()
        cframe += 1
        if not ret:
            raise videoFrameError(videoName, cframe)    
        t[i] = transform(frame)
    return t

In [13]:
def ReadNFrame(videoName, nbFrame, transform=transforms.ToTensor(), dropFrame=1):
    """
        Read a nbFrame frame and concatenate them
        
        Input : videoName: path to video
                nbFrame : number of frame to read
                transform : transform to apply to each frame
                dropFrame : number of frame to drop between each frame
        
    """
    cap = cv2.VideoCapture(videoName)
    
    #verify video length
    if cap.get(cv2.CAP_PROP_FRAME_COUNT) < nbFrame*dropFrame:
        raise videoFrameError(videoName, nbFrame)
    
    #find possible first frame
    firstFrame = random.randint(0, cap.get(cv2.CAP_PROP_FRAME_COUNT)- (nbFrame*dropFrame))
    
    ret = cap.set(cv2.CAP_PROP_POS_FRAMES,firstFrame) #go to first frame
    cframe = firstFrame #current frame

    if not ret:
        raise videoFrameError(videoName, cframe)
        
    t = torch.Tensor(nbFrame,3,128,128)
    for i in range(nbFrame):
        ret, frame = cap.read()
        cframe += 1
        if not ret:
            raise videoFrameError(videoName, cframe)    
        t[i] = transform(frame)
    return t.view(t.size(0)*t.size(1), t.size(2), t.size(3))

In [16]:
def VideoDataset(rep="/video/GestureSequence/", SequenceSize=5, batchSize=4, transform=transforms.Compose(
            (transforms.ToPILImage(),
            transforms.Resize(128),
            transforms.RandomCrop(128),
            transforms.ToTensor())), 
            concat=False):
    
    
    if concat:
        openFile = lambda x:ReadNFrame(x,SequenceSize,transform)
    else:
        openFile = lambda x:ReadRandomSequence(x,SequenceSize,transform)
    
    targetOperation = lambda x:int(x)
    
    videodataset = GenericDataset(rep, fileOperation=openFile, targetOperation=targetOperation)
    data_loader = data.DataLoader(videodataset,
                                          batch_size=batchSize,
                                          shuffle=True,
                                          num_workers=4,
                                          drop_last=True
                                )
    return data_loader

In [20]:
if __name__=="__main__":
    vd = VideoDataset(concat=True)
    for j,i in enumerate(vd):
        a = i
        print(j, end="\r")
    print(a)

120

TypeError: __init__() takes exactly 3 arguments (2 given)