# setup

In [1]:
import os

# When the notebook is started from the `notebooks/` folder, move up one level to
# the project root so the relative paths used below ('process/...', 'data/...')
# resolve (and, presumably, the project-local imports such as `base`).
# os.path.basename is separator-aware, unlike splitting on '/'.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

In [9]:
import numpy as np
import torch
import h5py

from scipy.io import loadmat
from base import BaseDataLoader
from torch.utils.data import Dataset
from torchvision import transforms

In [None]:
import augmentation.augmentation as module_augmentation

# init: data paths

In [4]:
# Challenge 2020 data directory. Set CHALLENGE2020_DATA_DIR to override it instead of
# editing this absolute, machine-specific default.
# NOTE: label_dir is passed to the loader below, but the loader does not read it.
label_dir = os.environ.get('CHALLENGE2020_DATA_DIR', '/home/josegfer/datasets/challenge2020/data')
# MATLAB file with the train_index / val_index / test_index arrays (relative to the project root).
split_index = 'process/data_split/split1.mat'

# classes

In [36]:
class CustomTensorDataset_BeatAligned_h5(Dataset):
    """Map-style dataset that reads beat-aligned records lazily from an HDF5-like store.

    Every item is resolved through ``split_idx``, so a single store can back the
    train/val/test datasets while each one only exposes its own subset.

    Args:
        database: mapping-like store (e.g. an open ``h5py.File``) exposing the
            entries ``'recording'``, ``'ratio'``, ``'label'`` and ``'weight'``,
            each indexed by record along its first axis.
        split_idx: sequence of record indices (into ``database``) for this split.
        transform: optional callable applied to the recording array ``x``.
        p: the transform runs only when ``torch.rand(1) >= p``, i.e. with
            probability ``1 - p`` for ``p`` in [0, 1].
        leads_index: indices along axis 1 of ``'recording'`` to keep. Defaults to
            ``[0, 1, 6, 7, 8, 9, 10, 11]``. Per the h5py docs, fancy-index lists
            must be increasing, so keep it sorted when ``database`` is an h5py file.
    """

    # Default lead selection; copied into a fresh list per instance.
    _DEFAULT_LEADS_INDEX = (0, 1, 6, 7, 8, 9, 10, 11)

    def __init__(self, database, split_idx, transform=None, p=0.5, leads_index=None):
        # NOTE(review): an open h5py.File shared with DataLoader worker processes
        # (num_workers > 0) may not be safe; consider opening it lazily per worker.
        self.database = database
        self.split_idx = split_idx
        self.transform = transform
        self.p = p
        self.leads_index = list(self._DEFAULT_LEADS_INDEX if leads_index is None else leads_index)

    def __getitem__(self, index):
        """Return ``([x, x2], y, w)`` for the ``index``-th record of this split.

        With ``record = split_idx[index]``: ``x`` is
        ``recording[record, leads_index, :, :]`` (optionally transformed),
        ``x2`` is ``ratio[record]``, ``y`` is ``label[record]`` and ``w`` is
        ``weight[record]``.
        """
        record = self.split_idx[index]  # resolve the store index once
        x = self.database['recording'][record, self.leads_index, :, :]
        x2 = self.database['ratio'][record]

        # torch.rand is only drawn when a transform is configured.
        if self.transform and torch.rand(1) >= self.p:
            x = self.transform(x)

        y = self.database['label'][record]
        w = self.database['weight'][record]
        return [x, x2], y, w

    def __len__(self):
        """Number of records in this split."""
        return len(self.split_idx)

In [37]:
class ChallengeDataLoader_beat_aligned_data_h5(BaseDataLoader):
    """Challenge 2020 loader serving beat-aligned records from a pre-built HDF5 file.

    Split membership comes from the ``train_index`` / ``val_index`` /
    ``test_index`` arrays in the MATLAB file ``split_index``. Records are read
    lazily from ``hdf5_path`` through ``CustomTensorDataset_BeatAligned_h5``.

    Only ``split_index``, ``batch_size``, ``shuffle``, ``num_workers``,
    ``augmentations``, ``p`` and ``hdf5_path`` are used here. The remaining
    parameters (``label_dir``, ``resample_Fs``, ``window_size``, ``n_segment``,
    ``normalization``, ``_25classes``, ``lead_number``, ``save_data``,
    ``load_saved_data``, ``save_dir``, ``seg_with_r``, ``beat_length``) are
    accepted for signature compatibility and ignored. In particular, the lead
    selection is left to the dataset class's default rather than ``lead_number``.

    Args:
        split_index: path to the ``.mat`` file holding the split index arrays.
        batch_size, shuffle, num_workers: forwarded to ``BaseDataLoader``.
        augmentations: optional mapping ``{class_name: {'args': {...}}}``. Each
            entry instantiates ``module_augmentation.<class_name>(**args)``; the
            results are chained with ``transforms.Compose`` and applied to the
            training split only.
        p: forwarded to the training dataset (controls how often augmentation runs).
        hdf5_path: HDF5 file with the records. The default ``'data/challenge2020.h5'``
            is relative to the working directory.
    """

    def __init__(self, label_dir, split_index, batch_size, shuffle=True, num_workers=0, resample_Fs=300,
                 window_size=3000, n_segment=1, normalization=False, augmentations=None, p=0.5, _25classes=False,
                 lead_number=12, save_data=False, load_saved_data=True, save_dir=None, seg_with_r=False, beat_length=400,
                 hdf5_path='data/challenge2020.h5'):
        split = loadmat(split_index)
        # loadmat returns 2-D arrays (e.g. shape (1, N)). Flatten every split so len()
        # counts records. test_index was previously left 2-D, so the test dataset had
        # length 1 and indexed the file with a whole index array.
        train_index, val_index, test_index = (
            np.asarray(split[key]).ravel() for key in ('train_index', 'val_index', 'test_index')
        )

        # Kept open for the lifetime of the datasets; release it with close().
        self.hdf5_file = h5py.File(hdf5_path, 'r')

        train_transform = self._build_transform(augmentations)
        if train_transform is not None:
            self.train_dataset = CustomTensorDataset_BeatAligned_h5(database=self.hdf5_file, split_idx=train_index,
                                                                    transform=train_transform, p=p)
        else:
            self.train_dataset = CustomTensorDataset_BeatAligned_h5(database=self.hdf5_file, split_idx=train_index)
        self.val_dataset = CustomTensorDataset_BeatAligned_h5(database=self.hdf5_file, split_idx=val_index)
        self.test_dataset = CustomTensorDataset_BeatAligned_h5(database=self.hdf5_file, split_idx=test_index)

        super().__init__(self.train_dataset, self.val_dataset, self.test_dataset, batch_size, shuffle, num_workers)

    @staticmethod
    def _build_transform(augmentations):
        """Instantiate the configured augmentations and chain them; None if there are none."""
        if not augmentations:
            return None
        transformers = [getattr(module_augmentation, name)(**dict(cfg['args']))
                        for name, cfg in augmentations.items()]
        return transforms.Compose(transformers)

    def close(self):
        """Close the HDF5 file; the datasets cannot be read afterwards."""
        self.hdf5_file.close()

    def normalization(self, X):
        """Identity: returns ``X`` unchanged."""
        return X

# draft: build the loader and check the shapes of one batch

In [40]:
# Build the train/val/test loaders from the split file, in batches of 16.
database = ChallengeDataLoader_beat_aligned_data_h5(
    label_dir=label_dir,
    split_index=split_index,
    batch_size=16,
)

In [41]:
# Pull a single batch to sanity-check the shapes the loader produces.
# NOTE: with shuffle=True (the loader's default) this batch presumably differs between runs.
[data, info], target, class_weights = next(iter(database))
data.shape, info.shape, target.shape, class_weights.shape

(torch.Size([16, 8, 10, 400]),
 torch.Size([16, 1, 10]),
 torch.Size([16, 108]),
 torch.Size([16, 108]))