In [1]:
!jupyter nbconvert --to python Dataset.ipynb
from Dataset import GenericDataset

[NbConvertApp] Converting notebook Dataset.ipynb to python
[NbConvertApp] Writing 1778 bytes to Dataset.py


In [2]:
import cv2
import random
import torch
import torchvision.transforms as transforms
import torch.utils.data as data

In [3]:
class videoFrameError(Exception):
    """
        Raised when a requested frame cannot be reached or read in a video.

        Attributes:
            videoName --- video name or path
            frame     --- frame number
    """
    def __init__(self, videoName, frame):
        # Forward both values to Exception so that `args` is populated.
        # copy/pickle rebuild an exception by calling type(exc)(*exc.args);
        # with empty args that call fails because both parameters are required.
        super().__init__(videoName, frame)
        self.videoName = videoName
        self.frame = frame

    def __str__(self):
        # repr of the video name directly followed by repr of the frame number
        return repr(self.videoName) + repr(self.frame)

In [4]:
def ReadRandomSequence(videoName, nbFrame, transform=transforms.ToTensor(), dropFrame=1):
    """
        Read nbFrame frames starting at a random position in a video.

        Arguments:
            videoName --- path (or name) handed to cv2.VideoCapture
            nbFrame   --- number of frames to return
            transform --- callable applied to every decoded frame; it must
                          return tensors of identical shape so they can be
                          stacked
            dropFrame --- sampling stride: one frame is kept out of every
                          dropFrame frames (1 keeps consecutive frames)

        Returns:
            A tensor stacking the nbFrame transformed frames along a new
            first dimension.

        Raises:
            videoFrameError --- if the video holds fewer than
                                nbFrame*dropFrame frames, or if seeking to /
                                reading a frame fails
    """
    cap = cv2.VideoCapture(videoName)
    try:
        # cv2.CAP_PROP_FRAME_COUNT is only a property id; the real count has
        # to be queried from the capture. cap.get returns a float, while
        # random.randint needs integers.
        frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        span = nbFrame * dropFrame

        if frameCount - span < 0:
            # Video too short to hold the requested window.
            raise videoFrameError(videoName, frameCount)
        firstFrame = random.randint(0, frameCount - span)

        ret = cap.set(cv2.CAP_PROP_POS_FRAMES, firstFrame)  # go to first frame
        cframe = firstFrame  # current frame
        if not ret:
            raise videoFrameError(videoName, cframe)

        frames = []
        for i in range(nbFrame):
            ret, frame = cap.read()
            cframe += 1
            if not ret:
                raise videoFrameError(videoName, cframe)
            frames.append(transform(frame))

            # NOTE(review): dropFrame is treated as a stride because the
            # sampling window is sized nbFrame*dropFrame; confirm this is the
            # intended meaning. Skipped frames are grabbed but not decoded.
            if i < nbFrame - 1:
                for _ in range(dropFrame - 1):
                    cframe += 1
                    if not cap.grab():
                        raise videoFrameError(videoName, cframe)
    finally:
        # Always give the decoder / file handle back, even on error.
        cap.release()

    if not frames:
        # nbFrame == 0: keep the historical empty result shape.
        return torch.empty(0, 3, 225, 225)
    # Stacking lets the result follow the transform's output shape instead of
    # a hard-coded 3x225x225 buffer.
    return torch.stack(frames)

In [5]:
def VideoDataset(rep="/video/GestureSequence/", SequenceSize=5, batchSize=4, numWorkers=4):
    """
        Build a shuffled DataLoader over the videos found through
        GenericDataset(rep).

        Arguments:
            rep          --- directory handed to GenericDataset
            SequenceSize --- number of frames read per video
                             (nbFrame of ReadRandomSequence)
            batchSize    --- batch size of the DataLoader
            numWorkers   --- number of DataLoader worker processes

        Returns:
            A torch.utils.data.DataLoader with shuffle=True and
            drop_last=True.

        NOTE(review): GenericDataset is defined in Dataset.py, which is not
        visible here; it is assumed to call fileOperation(path) with the path
        as a single positional argument --- confirm.
    """
    # Local import so this function does not rely on the notebook's import cell.
    import functools

    # Frame pipeline: ndarray -> PIL image -> resize to 225 -> random 225x225
    # crop -> tensor.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(225),
        transforms.RandomCrop(225),
        transforms.ToTensor(),
    ])

    # A partial over a module-level function can be pickled, unlike a lambda.
    # This matters when DataLoader workers are started with the "spawn"
    # method, which pickles the dataset (and this callable) for each worker.
    openFile = functools.partial(ReadRandomSequence,
                                 nbFrame=SequenceSize,
                                 transform=transform)

    videodataset = GenericDataset(rep, fileOperation=openFile)
    data_loader = data.DataLoader(videodataset,
                                  batch_size=batchSize,
                                  shuffle=True,
                                  num_workers=numWorkers,
                                  drop_last=True)
    return data_loader