In [80]:
from __future__ import print_function, division
import os, sys
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import h5py
import json
from random import shuffle
from Dataset import *

# Data sequence dataset class

In [81]:
class Data_Sequences(Dataset):
    """Map-style dataset that yields one h5 gesture sequence per index.

    Item ``i`` is read from ``<root_dir>/<partition[i]>.h5``. The four datasets
    ``ch0``..``ch3`` are stacked along a new trailing (channel) axis and are
    returned together with the file's ``label`` dataset.
    """

    def __init__(self, root_dir, partition, transform=None):
        """
        Args:
            root_dir (string): Directory with all the h5 sequences.
            partition (sequence of str): Sequence names (file stems, without
                the ``.h5`` extension) that make up this dataset; index ``i``
                loads ``partition[i]``.
            transform (callable, optional): Optional transform applied to the
                ``(data, label)`` tuple of each sample; it must return a
                ``(data, label)`` tuple as well.
        """
        self.root_dir = root_dir
        self.partition = partition
        self.transform = transform

    def __len__(self):
        """Number of sequences in the partition."""
        return len(self.partition)

    def __getitem__(self, index):
        """Load sequence ``index`` and return ``(data, label)``, transformed if configured."""
        sequence = self.partition[index]
        path = os.path.join(self.root_dir, sequence + ".h5")
        with h5py.File(path, 'r') as f:
            # np.stack materialises every channel dataset while the file is
            # still open; the new last axis indexes the channel.
            data = np.stack([f['ch{}'.format(i)] for i in range(4)], axis=-1)
            label = f['label'][()]
        if self.transform is not None:
            data, label = self.transform((data, label))
        return data, label
    
    
    
    

# Transforms

In [82]:
            
class Reshape(object):
    """Unflatten every frame of a sequence into a ``height x width x channels`` grid.

    The first axis of ``data`` (one entry per frame) is kept. Everything after it
    is reshaped to ``(height, width, channels)``, so the per-frame element
    count must equal ``height * width * channels``. The defaults (32, 32, 4)
    reproduce the original fixed 32 x 32, 4-channel layout.

    Args:
        height (int): Rows of each output frame.
        width (int): Columns of each output frame.
        channels (int): Channels of each output frame.
    """

    def __init__(self, height=32, width=32, channels=4):
        self.height = height
        self.width = width
        self.channels = channels

    def __call__(self, sample):
        data, label = sample
        # Whole sequence at once: the frame axis is preserved.
        data = data.reshape((data.shape[0], self.height, self.width, self.channels))
        return data, label

In [83]:
class Rescale(object):
    """Resize the spatial (H, W) axes of every frame in a sequence.

    ``data`` is indexed as 4-D ``(frames, H, W, channels)``, for example the
    output of ``Reshape``. The frame and channel axes are kept, and only H and
    W are resized with ``skimage.transform.resize``.

    Args:
        output_size (tuple or int): Desired ``(height, width)``. An int is
            interpreted as the square size ``(output_size, output_size)``.

    Raises:
        ValueError: if ``output_size`` is not an int or a 2-element sequence.
    """

    def __init__(self, output_size):
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        size = tuple(output_size)
        # Fail at construction instead of inside a DataLoader worker later.
        if len(size) != 2:
            raise ValueError("output_size must be an int or a (height, width) pair")
        self.output_size = size

    def __call__(self, sample):
        new_h, new_w = self.output_size
        data, label = sample
        # NOTE(review): per the skimage docs, resize returns floats and, with
        # its default preserve_range=False, may rescale integer input ranges —
        # confirm this is desired for the radar data.
        data = transform.resize(data, (data.shape[0], new_h, new_w, data.shape[3]))
        return data, label

In [84]:
class ToTensor(object):
    """Convert a ``(data, label)`` sample of ndarrays into torch tensors.

    ``data`` is transposed from ``(frames, H, W, C)`` (numpy/skimage layout)
    to ``(frames, C, H, W)`` (channel-first per frame, as torch image ops
    expect). ``label`` is converted without reshaping. Scalars and other
    array-likes are first wrapped with ``np.asarray``.
    """

    def __call__(self, sample):
        data, label = sample
        # Move each frame's channel axis in front of its spatial axes:
        # (T, H, W, C) -> (T, C, H, W). The frame axis 0 is untouched.
        data = data.transpose((0, 3, 1, 2))
        return torch.from_numpy(data), torch.from_numpy(np.asarray(label))


In [112]:
class SameSize(object):
    """Force a sequence and its labels to exactly ``size`` frames.

    Sequences shorter than ``size`` are zero-padded at the *front*, so the real
    frames end up at the tail. Longer sequences are truncated to their first
    ``size`` frames; no subsampling is performed. Data and labels are fitted
    independently. Both are returned as float64 arrays, because the output
    buffers are created with ``np.zeros``.

    Args:
        size (int): Number of frames in every output sequence.
    """

    def __init__(self, size):
        self.size = size

    @staticmethod
    def _fit(array, size):
        """Return a float64 copy of ``array`` whose first axis is front-padded/truncated to ``size``."""
        out = np.zeros((size, *array.shape[1:]))
        kept = min(array.shape[0], size)
        # With kept == 0, out[-0:] would be the whole buffer and assigning an
        # empty array into it fails to broadcast, so only copy real frames.
        if kept > 0:
            out[size - kept:] = array[:kept]
        return out

    def __call__(self, sample):
        data, labels = sample
        return self._fit(data, self.size), self._fit(labels, self.size)
        

In [113]:
# DataLoader settings. shuffle=False keeps batch order deterministic for this
# quick smoke test.
params = {'batch_size': 2,
          'shuffle': False,
          'num_workers': 6}

max_epochs = 1  # set to 1 for testing, to keep the run fast
# Root of the preprocessed h5 sequences. Set the SOLI_DATA_DIR environment
# variable instead of editing this machine-specific default path.
directory = os.environ.get("SOLI_DATA_DIR", "/home/chekirou/Documents/SOLI/SoliData/dsp")
N_TRAIN_SEQUENCES = 6  # only a handful of sequences for the smoke test

# Datasets
# `split` comes from the local Dataset module (star import in the first cell).
# NOTE(review): an earlier comment said this returns an 80% train / 20% test
# split "for 80% of the data", but use=0.2 suggests only 20% is used —
# confirm against Dataset.split.
train, test = split(directory, frames=False, percentage=0.8, use=0.2)

# Generators
# Per sample: reshape frames to 32x32x4 -> resize each frame to 224x224 ->
# pad/truncate to 40 frames -> convert to (T, C, H, W) tensors.
t = transforms.Compose([Reshape(), Rescale((224, 224)), SameSize(40), ToTensor()])
training_set = Data_Sequences(directory, train[:N_TRAIN_SEQUENCES], transform=t)
training_generator = DataLoader(training_set, **params)


In [114]:
# Grab one transformed sample for inspection (e.g. with show_gesture).
# Distinct names keep the `test` split from the setup cell intact.
sample_data, sample_label = training_set[0]


In [115]:
from IPython.display import clear_output
def show_gesture(test, channel):
    """Play back one channel of a gesture sequence as an in-notebook animation.

    Each frame ``test[i, channel]`` is drawn with ``imshow`` and shown, and
    then the cell output is cleared with ``wait=True``. Per the IPython docs,
    this defers clearing until the next output arrives, so the frames replace
    each other and the last one stays visible.

    Args:
        test: Sequence indexed as ``test[frame, channel]`` -> 2-D image,
            e.g. a ``(T, C, H, W)`` tensor produced by ``ToTensor``.
        channel (int): Channel index to display.
    """
    n_frames = test.shape[0]
    for i in range(n_frames):
        fig, ax = plt.subplots()
        ax.imshow(test[i, channel])
        ax.set_title("frame {} / {}, channel {}".format(i + 1, n_frames, channel))
        plt.show()
        # Release the figure so non-inline backends don't keep one per frame.
        plt.close(fig)
        clear_output(wait=True)

In [116]:
# Smoke test: run the transform pipeline on every sequence and print the shapes.
# Explicit indexing avoids the legacy __getitem__ iteration protocol, and the
# loop names no longer shadow the transform pipeline `t`.
for i in range(len(training_set)):
    seq_data, seq_labels = training_set[i]
    print(i, seq_data.shape, seq_labels.shape)

0 torch.Size([40, 4, 224, 224]) torch.Size([40, 1])
1 torch.Size([40, 4, 224, 224]) torch.Size([40, 1])
2 torch.Size([40, 4, 224, 224]) torch.Size([40, 1])
3 torch.Size([40, 4, 224, 224]) torch.Size([40, 1])
4 torch.Size([40, 4, 224, 224]) torch.Size([40, 1])
5 torch.Size([40, 4, 224, 224]) torch.Size([40, 1])


In [117]:
# One pass over the DataLoader: the default collate stacks `batch_size`
# samples on a new leading axis. Loop names don't clobber the `test` split.
for i, (batch_data, batch_labels) in enumerate(training_generator):
    print(i, batch_data.shape, batch_labels.shape)

0 torch.Size([2, 40, 4, 224, 224]) torch.Size([2, 40, 1])
1 torch.Size([2, 40, 4, 224, 224]) torch.Size([2, 40, 1])
2 torch.Size([2, 40, 4, 224, 224]) torch.Size([2, 40, 1])
