# setup

In [1]:
import os

if os.getcwd().split('/')[-1] == 'notebooks':
    os.chdir('..')

In [2]:
import h5py
import pandas as pd
import numpy as np
import torch

from tqdm import tqdm

# init

In [3]:
hdf5_path = '/home/josegfer/datasets/cpsc2018/cpsc2018.h5'
metadata_path = '/home/josegfer/datasets/cpsc2018/metadata.csv'

In [4]:
val_size = 0.10
tst_size = 0.05

In [5]:
exam_id_col = 'exam_id'

# loader

In [6]:
from torch.utils.data import Dataset

In [7]:
class CPSC2018():
    def __init__(self, hdf5_path = '/home/josegfer/datasets/cpsc2018/cpsc2018.h5', 
                 metadata_path = '/home/josegfer/datasets/cpsc2018/metadata.csv', 
                 val_size = 0.1, tst_size = 0.05):
        self.hdf5_file = h5py.File(hdf5_path, "r")
        self.metadata = pd.read_csv(metadata_path)

        self.val_size = val_size
        self.tst_size = tst_size

        trn_metadata, val_metadata, tst_metadata = self.split()
        self.check_dataleakage(trn_metadata, val_metadata, tst_metadata)
        
        self.trn_idx_dict = self.get_idx_dict(trn_metadata)
        self.val_idx_dict = self.get_idx_dict(val_metadata)
        self.tst_idx_dict = self.get_idx_dict(tst_metadata)

    def split(self, patient_id_col = 'exam_id'): # im assuming all pacient are unique for this dataset!
        patient_ids = self.metadata[patient_id_col].unique()

        num_trn = int(len(patient_ids) * (1 - self.tst_size - self.val_size))
        num_val = int(len(patient_ids) * self.val_size)

        trn_ids = set(patient_ids[:num_trn])
        val_ids = set(patient_ids[num_trn : num_trn + num_val])
        tst_ids = set(patient_ids[num_trn + num_val :])

        trn_metadata = self.metadata.loc[self.metadata[patient_id_col].isin(trn_ids)]
        val_metadata = self.metadata.loc[self.metadata[patient_id_col].isin(val_ids)]
        tst_metadata = self.metadata.loc[self.metadata[patient_id_col].isin(tst_ids)]

        return trn_metadata, val_metadata, tst_metadata
    
    def check_dataleakage(self, trn_metadata, val_metadata, tst_metadata, exam_id_col = 'exam_id'):
        trn_ids = set(trn_metadata[exam_id_col].unique())
        val_ids = set(val_metadata[exam_id_col].unique())
        tst_ids = set(tst_metadata[exam_id_col].unique())
        assert (len(trn_ids.intersection(val_ids)) == 0), "Some IDs are present in both train and validation sets."
        assert (len(trn_ids.intersection(tst_ids)) == 0), "Some IDs are present in both train and test sets."
        assert (len(val_ids.intersection(tst_ids)) == 0), "Some IDs are present in both validation and test sets."

    def get_idx_dict(self, split_metadata, exam_id_col = 'exam_id'):
        split_exams, split_h5_idx, temp = np.intersect1d(self.hdf5_file[exam_id_col], split_metadata[exam_id_col].values, return_indices = True)
        split_csv_idx = split_metadata.iloc[temp].index.values
        split_idx_dict = {exam_id_col: split_exams, 'h5_idx': split_h5_idx, 'csv_idx': split_csv_idx}

        print('checking exam_id consistency in idx dict')
        for idx, exam_id in tqdm(enumerate(split_idx_dict[exam_id_col])):
            assert self.hdf5_file[exam_id_col][split_idx_dict['h5_idx'][idx]] == exam_id
            assert self.metadata[exam_id_col][split_idx_dict['csv_idx'][idx]] == exam_id
        return split_idx_dict


In [8]:
data = CPSC2018()

561it [00:00, 5606.46it/s]

checking exam_id consistency in idx dict


5845it [00:00, 8097.38it/s]
687it [00:00, 8572.87it/s]
345it [00:00, 8554.64it/s]

checking exam_id consistency in idx dict
checking exam_id consistency in idx dict





In [9]:
class CPSC2018split(Dataset):
    def __init__(self, database, split_idx_dict, 
                 tracing_col = 'tracings', exam_id_col = 'exam_id', output_col = ['AF', 'I-AVB', 'LBBB', 'RBBB', 'PAC', 'PVC', 'STD', 'STE']):
        self.database = database
        self.split_idx_dict = split_idx_dict

        self.tracing_col = tracing_col
        self.exam_id_col = exam_id_col
        self.output_col = output_col
    
    def __len__(self):
        return len(self.split_idx_dict[self.exam_id_col])
    
    def __getitem__(self, idx):
        return {'X': self.database.hdf5_file[self.tracing_col][self.split_idx_dict['h5_idx'][idx]], 
                'y': self.database.metadata[self.output_col].loc[self.split_idx_dict['csv_idx'][idx]].values}

In [10]:
trn_ds = CPSC2018split(data, data.trn_idx_dict)
val_ds = CPSC2018split(data, data.val_idx_dict)
tst_ds = CPSC2018split(data, data.tst_idx_dict)

In [11]:
trn_loader = torch.utils.data.DataLoader(trn_ds, batch_size = 128,
                                          shuffle = True, num_workers = 6)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size = 128,
                                          shuffle = False, num_workers = 6)
tst_loader = torch.utils.data.DataLoader(tst_ds, batch_size = 128,
                                          shuffle = False, num_workers = 6)

In [12]:
for batch in tqdm(trn_loader):
    break
batch['X'].shape, batch['X'], batch['y'].shape, batch['y']

  0%|          | 0/46 [00:08<?, ?it/s]


(torch.Size([128, 4096, 12]),
 tensor([[[ 0.0916,  0.2334,  0.1418,  ...,  0.0552,  0.0077, -0.1428],
          [ 0.0558,  0.2210,  0.1652,  ...,  0.0249, -0.0091, -0.1754],
          [ 0.0132,  0.1663,  0.1531,  ...,  0.0007, -0.0285, -0.1809],
          ...,
          [ 0.0055,  0.0444,  0.0390,  ..., -0.0392, -0.0229, -0.1960],
          [ 0.0086,  0.0463,  0.0377,  ..., -0.0508, -0.0407, -0.2167],
          [-0.0025,  0.0479,  0.0504,  ..., -0.0590, -0.0536, -0.2134]],
 
         [[-0.0855, -0.1844, -0.0989,  ..., -0.1340, -0.1940, -0.1748],
          [-0.0992, -0.2194, -0.1202,  ..., -0.1233, -0.1804, -0.1678],
          [-0.1149, -0.2320, -0.1171,  ..., -0.1526, -0.2002, -0.1807],
          ...,
          [-0.1032, -0.1653, -0.0621,  ..., -0.1012, -0.1973, -0.1618],
          [-0.0374, -0.1184, -0.0810,  ..., -0.1248, -0.2250, -0.1823],
          [-0.0417, -0.1472, -0.1055,  ..., -0.0918, -0.1894, -0.1567]],
 
         [[-0.1125,  0.0535,  0.1661,  ...,  0.0948, -0.0942, -0.1234]

In [13]:
for batch in tqdm(val_loader):
    break
batch['X'].shape, batch['X'], batch['y'].shape, batch['y']

  0%|          | 0/6 [00:04<?, ?it/s]


(torch.Size([128, 4096, 12]),
 tensor([[[-8.2722e-03, -6.9901e-02, -6.1629e-02,  ..., -5.9233e-02,
           -2.9944e-02, -1.3844e-01],
          [-2.7614e-03, -6.0083e-02, -5.7322e-02,  ..., -5.5806e-02,
           -2.7615e-02, -1.3302e-01],
          [ 1.1273e-03, -4.9353e-02, -5.0480e-02,  ..., -5.2080e-02,
           -2.4260e-02, -1.3152e-01],
          ...,
          [-1.1659e-01, -1.7738e-01, -6.0785e-02,  ..., -5.0219e-02,
           -5.6028e-02, -8.9997e-02],
          [-1.1820e-01, -1.7607e-01, -5.7866e-02,  ..., -4.9025e-02,
           -5.6014e-02, -9.0067e-02],
          [-1.1803e-01, -1.7583e-01, -5.7795e-02,  ..., -4.7815e-02,
           -5.6048e-02, -8.9954e-02]],
 
         [[ 2.0451e-01,  1.6684e-01, -3.7666e-02,  ...,  1.7393e-01,
            3.1393e-01,  3.7568e-01],
          [ 2.2731e-01,  1.8720e-01, -4.0111e-02,  ...,  1.8000e-01,
            3.2325e-01,  4.0576e-01],
          [ 2.4771e-01,  2.0193e-01, -4.5775e-02,  ...,  1.8846e-01,
            3.3473e-01,  4.

In [14]:
for batch in tqdm(tst_loader):
    break
batch['X'].shape, batch['X'], batch['y'].shape, batch['y']

  0%|          | 0/3 [00:01<?, ?it/s]


(torch.Size([128, 4096, 12]),
 tensor([[[ 0.0040,  0.0080,  0.0040,  ...,  0.0080,  0.0020,  0.0060],
          [ 0.0040,  0.0080,  0.0040,  ...,  0.0080,  0.0020,  0.0060],
          [ 0.0040,  0.0080,  0.0040,  ...,  0.0080,  0.0020,  0.0060],
          ...,
          [-0.0456,  0.0416,  0.0872,  ..., -0.0185, -0.0647, -0.0548],
          [-0.0460,  0.0332,  0.0791,  ..., -0.0131, -0.0536, -0.0543],
          [-0.0334,  0.0239,  0.0573,  ...,  0.0124, -0.0467, -0.0500]],
 
         [[ 0.0834,  0.0504, -0.0329,  ...,  0.0371,  0.0463,  0.0486],
          [ 0.0322,  0.0836,  0.0514,  ...,  0.0298,  0.0409,  0.0219],
          [ 0.1092,  0.1542,  0.0450,  ...,  0.0439,  0.0456,  0.0460],
          ...,
          [-0.0237, -0.1663, -0.1426,  ..., -0.0525, -0.1115, -0.1521],
          [-0.0079, -0.1177, -0.1098,  ..., -0.0255, -0.1050, -0.1467],
          [-0.0659, -0.2061, -0.1401,  ..., -0.0650, -0.1419, -0.1692]],
 
         [[ 0.3449,  0.4817,  0.1367,  ...,  0.8018,  0.6890,  0.6276]