# setup

In [1]:
import os

# When the kernel starts inside a directory named `notebooks`, move up one
# level so that relative paths resolve from the parent directory.
# os.path.basename honours the platform's path separator, unlike a manual
# split on '/'.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

In [2]:
import h5py
import pandas as pd
import numpy as np
import torch

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# dataset

In [3]:
class LUDB(Dataset):
    """Map-style dataset that serves rows of three HDF5 datasets side by side.

    Each item is a dict with the keys ``'x'`` (row ``idx`` of ``ecg_col``),
    ``'y'`` (row ``idx`` of ``output_col``) and ``'exam_id'`` (row ``idx`` of
    ``id_col``).

    The HDF5 file is opened lazily and once per process: an ``h5py.File``
    handle must not be shared between a parent process and forked
    ``DataLoader`` workers, and it cannot be pickled for spawned workers.

    Args:
        hdf5_path: path of the HDF5 file to read.
        ecg_col: name of the HDF5 dataset returned under ``'x'``; it also
            defines the length of this dataset.
        output_col: name of the HDF5 dataset returned under ``'y'``.
        id_col: name of the HDF5 dataset returned under ``'exam_id'``.
    """

    def __init__(self, hdf5_path = '/home/josegfer/datasets/ludb/output/ludb.h5', 
                 ecg_col = 'tracings', output_col = 'annotations', id_col = 'exam_id', ):
        self.hdf5_path = hdf5_path

        self.ecg_col = ecg_col
        self.output_col = output_col
        self.id_col = id_col

        # Handle owned by the current process, created on first access.
        self._hdf5_file = None
        self._owner_pid = None

        # Read the length through a short-lived handle so that __len__ (which
        # the DataLoader's sampler calls in the parent process) never leaves
        # an open handle behind to be inherited by worker processes.
        with h5py.File(hdf5_path, 'r') as hdf5_file:
            self._length = len(hdf5_file[ecg_col])

    @property
    def hdf5_file(self):
        """Read-only ``h5py.File`` handle belonging to the calling process."""
        pid = os.getpid()
        if self._hdf5_file is None or self._owner_pid != pid:
            # Either nothing is open yet, or the stored handle was inherited
            # from another process through fork: open a fresh one.
            self._hdf5_file = h5py.File(self.hdf5_path, 'r')
            self._owner_pid = pid
        return self._hdf5_file

    def __getstate__(self):
        """Return the picklable state, leaving the open file handle out."""
        state = self.__dict__.copy()
        state['_hdf5_file'] = None
        state['_owner_pid'] = None
        return state

    def close(self):
        """Close the handle opened by this process, if any."""
        if self._hdf5_file is not None and self._owner_pid == os.getpid():
            self._hdf5_file.close()
        self._hdf5_file = None
        self._owner_pid = None

    def __len__(self):
        """Number of rows in ``ecg_col``, as read at construction time."""
        return self._length

    def __getitem__(self, idx):
        """Return the ``idx``-th row of each configured HDF5 dataset as a dict."""
        hdf5_file = self.hdf5_file
        return {'x': hdf5_file[self.ecg_col][idx],
                'y': hdf5_file[self.output_col][idx], 
                'exam_id': hdf5_file[self.id_col][idx], }

# fetch one batch

In [4]:
# Dataset built with its default HDF5 path and column names, wrapped in a
# shuffling loader that serves batches of 16 using 6 worker processes.
ds = LUDB()
loader = DataLoader(
    dataset = ds,
    batch_size = 16,
    shuffle = True,
    num_workers = 6,
)

In [5]:
# Take exactly one batch from the loader for inspection; next(iter(...)) fails
# explicitly on an empty loader instead of leaving `batch` undefined.
batch = next(iter(loader))
batch['exam_id'], batch['x'].shape, batch['y'].shape, batch['x'], batch['y']

(tensor([142,   4,  93, 174, 149,  37, 158, 103,  86, 177, 133,  53,  34, 170,
          69, 104], dtype=torch.int32),
 torch.Size([16, 12, 5000]),
 torch.Size([16, 12, 3, 5000]),
 tensor([[[-0.0134,  0.0184,  0.0117,  ..., -0.1011, -0.1236, -0.0593],
          [-0.0271, -0.0073, -0.0059,  ..., -0.0026, -0.0271, -0.0086],
          [-0.0316, -0.0422, -0.0290,  ...,  0.1555,  0.1423,  0.0777],
          ...,
          [ 0.1645,  0.1695,  0.1735,  ..., -0.1249, -0.1174, -0.0466],
          [ 0.2253,  0.2286,  0.2328,  ..., -0.1510, -0.1477, -0.0570],
          [ 0.1257,  0.1335,  0.1366,  ..., -0.0792, -0.0701, -0.0278]],
 
         [[-0.0892, -0.0462, -0.0834,  ..., -0.0586, -0.0479, -0.0223],
          [ 0.4373,  0.5258,  0.4861,  ..., -0.2034, -0.2985, -0.1532],
          [ 0.3333,  0.3447,  0.3561,  ..., -0.0630, -0.1276, -0.0676],
          ...,
          [-0.1545, -0.1573, -0.1514,  ...,  0.0588,  0.0799,  0.0334],
          [-0.0800, -0.0850, -0.0786,  ...,  0.0262,  0.0049, -0.00