# setup

In [1]:
import os

if os.getcwd().split('/')[-1] == 'notebooks':
    os.chdir('..')

In [2]:
import h5py
import pandas as pd
import numpy as np
import torch

from tqdm import tqdm

# init

In [3]:
hdf5_path = '/home/josegfer/datasets/ptbxl/ptbxl.h5'
metadata_path = '/home/josegfer/datasets/ptbxl/metadata.csv'

In [8]:
exam_id_col = 'exam_id'
patient_id_col = 'patient_id'
fold_col = 'fold'

# loader

In [5]:
from torch.utils.data import Dataset

In [6]:
metadata = pd.read_csv(metadata_path)
metadata

Unnamed: 0,exam_id,NORM,MI,STTC,CD,HYP,patient_id,fold
0,1,True,False,False,False,False,15709,3
1,2,True,False,False,False,False,13243,2
2,3,True,False,False,False,False,20372,5
3,4,True,False,False,False,False,17014,3
4,5,True,False,False,False,False,17448,4
...,...,...,...,...,...,...,...,...
21832,21833,False,False,True,False,False,17180,7
21833,21834,True,False,False,False,False,20703,4
21834,21835,False,False,True,False,False,19311,2
21835,21836,True,False,False,False,False,8873,8


In [33]:
val_metadata = metadata.loc[metadata[fold_col] == 9]
val_metadata

Unnamed: 0,exam_id,NORM,MI,STTC,CD,HYP,patient_id,fold
7,8,False,True,False,False,False,11275,9
9,10,True,False,False,False,False,9456,9
16,17,False,False,False,False,False,13619,9
17,18,False,False,False,False,False,13619,9
19,20,False,False,False,False,False,13619,9
...,...,...,...,...,...,...,...,...
21776,21777,True,False,False,False,False,8572,9
21787,21788,False,True,False,False,False,12360,9
21816,21817,False,False,False,False,False,18354,9
21830,21831,True,False,False,False,False,11905,9


In [34]:
tst_metadata = metadata.loc[metadata[fold_col] == 10]
tst_metadata

Unnamed: 0,exam_id,NORM,MI,STTC,CD,HYP,patient_id,fold
8,9,True,False,False,False,False,18792,10
37,38,True,False,False,False,False,17076,10
39,40,True,False,False,False,False,19501,10
56,57,True,False,False,False,False,16063,10
58,59,True,False,False,False,False,19475,10
...,...,...,...,...,...,...,...,...
21808,21809,True,False,False,False,False,12931,10
21811,21812,False,False,False,True,False,20789,10
21817,21818,True,False,False,False,False,19204,10
21818,21819,False,False,False,True,False,9843,10


In [35]:
trn_metadata = metadata.loc[(metadata[fold_col] != 9) * (metadata[fold_col] != 10)]
trn_metadata

Unnamed: 0,exam_id,NORM,MI,STTC,CD,HYP,patient_id,fold
0,1,True,False,False,False,False,15709,3
1,2,True,False,False,False,False,13243,2
2,3,True,False,False,False,False,20372,5
3,4,True,False,False,False,False,17014,3
4,5,True,False,False,False,False,17448,4
...,...,...,...,...,...,...,...,...
21831,21832,False,False,False,True,False,7954,7
21832,21833,False,False,True,False,False,17180,7
21833,21834,True,False,False,False,False,20703,4
21834,21835,False,False,True,False,False,19311,2


In [57]:
hdf5_file.close()

In [49]:
class PTBXL():
    def __init__(self, hdf5_path = '/home/josegfer/datasets/ptbxl/ptbxl.h5', 
                 metadata_path = '/home/josegfer/datasets/ptbxl/metadata.csv'):
        self.hdf5_file = h5py.File(hdf5_path, "r")
        self.metadata = pd.read_csv(metadata_path)

        trn_metadata, val_metadata, tst_metadata = self.split()
        # self.check_dataleakage(trn_metadata, val_metadata, tst_metadata)
        
        self.trn_idx_dict = self.get_idx_dict(trn_metadata)
        self.val_idx_dict = self.get_idx_dict(val_metadata)
        self.tst_idx_dict = self.get_idx_dict(tst_metadata)

    def split(self, fold_col = 'fold'): # authors use this split setup
        trn_metadata = self.metadata.loc[(self.metadata[fold_col] != 9) * (self.metadata[fold_col] != 10)]
        val_metadata = self.metadata.loc[self.metadata[fold_col] == 9]
        tst_metadata = self.metadata.loc[self.metadata[fold_col] == 10]

        return trn_metadata, val_metadata, tst_metadata
    
    def check_dataleakage(self, trn_metadata, val_metadata, tst_metadata, exam_id_col = 'exam_id', patient_id_col = 'patient_id'):
        print('checking exam_id leakage')
        trn_ids = set(trn_metadata[exam_id_col].unique())
        val_ids = set(val_metadata[exam_id_col].unique())
        tst_ids = set(tst_metadata[exam_id_col].unique())
        assert (len(trn_ids.intersection(val_ids)) == 0), "Some IDs are present in both train and validation sets."
        assert (len(trn_ids.intersection(tst_ids)) == 0), "Some IDs are present in both train and test sets."
        assert (len(val_ids.intersection(tst_ids)) == 0), "Some IDs are present in both validation and test sets."

        print('checking patient_id leakage')
        trn_ids = set(trn_metadata[patient_id_col].unique())
        val_ids = set(val_metadata[patient_id_col].unique())
        tst_ids = set(tst_metadata[patient_id_col].unique())
        assert (len(trn_ids.intersection(val_ids)) == 0), "Some patient IDs are present in both train and validation sets."
        assert (len(trn_ids.intersection(tst_ids)) == 0), "Some patient IDs are present in both train and test sets."
        assert (len(val_ids.intersection(tst_ids)) == 0), "Some patient IDs are present in both validation and test sets."

    def get_idx_dict(self, split_metadata, exam_id_col = 'exam_id'):
        split_exams, split_h5_idx, temp = np.intersect1d(self.hdf5_file[exam_id_col], split_metadata[exam_id_col].values, return_indices = True)
        split_csv_idx = split_metadata.iloc[temp].index.values
        split_idx_dict = {exam_id_col: split_exams, 'h5_idx': split_h5_idx, 'csv_idx': split_csv_idx}

        print('checking exam_id consistency in idx dict')
        for idx, exam_id in tqdm(enumerate(split_idx_dict[exam_id_col])):
            assert self.hdf5_file[exam_id_col][split_idx_dict['h5_idx'][idx]] == exam_id
            assert self.metadata[exam_id_col][split_idx_dict['csv_idx'][idx]] == exam_id
        return split_idx_dict

In [50]:
data = PTBXL()

0it [00:00, ?it/s]

796it [00:00, 7957.77it/s]

checking exam_id consistency in idx dict


17441it [00:02, 8071.15it/s]
1605it [00:00, 8030.20it/s]

checking exam_id consistency in idx dict


2193it [00:00, 8050.99it/s]
1599it [00:00, 7997.44it/s]

checking exam_id consistency in idx dict


2203it [00:00, 8028.10it/s]


In [51]:
class PTBXLsplit(Dataset):
    def __init__(self, database, split_idx_dict, 
                 tracing_col = 'tracings', exam_id_col = 'exam_id', output_col = ['NORM', 'MI', 'STTC', 'CD', 'HYP']):
        self.database = database
        self.split_idx_dict = split_idx_dict

        self.tracing_col = tracing_col
        self.exam_id_col = exam_id_col
        self.output_col = output_col
    
    def __len__(self):
        return len(self.split_idx_dict[self.exam_id_col])
    
    def __getitem__(self, idx):
        return {'X': self.database.hdf5_file[self.tracing_col][self.split_idx_dict['h5_idx'][idx]], 
                'y': self.database.metadata[self.output_col].loc[self.split_idx_dict['csv_idx'][idx]].values}

In [52]:
trn_ds = PTBXLsplit(data, data.trn_idx_dict)
val_ds = PTBXLsplit(data, data.val_idx_dict)
tst_ds = PTBXLsplit(data, data.tst_idx_dict)

In [53]:
trn_loader = torch.utils.data.DataLoader(trn_ds, batch_size = 128,
                                          shuffle = True, num_workers = 6)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size = 128,
                                          shuffle = False, num_workers = 6)
tst_loader = torch.utils.data.DataLoader(tst_ds, batch_size = 128,
                                          shuffle = False, num_workers = 6)

In [54]:
for batch in tqdm(trn_loader):
    break
batch['X'].shape, batch['X'], batch['y'].shape, batch['y']

  0%|          | 0/137 [00:09<?, ?it/s]


(torch.Size([128, 4096, 12]),
 tensor([[[ 7.9683e-06, -4.5007e-02, -4.5015e-02,  ...,  5.4008e-02,
            0.0000e+00, -2.7004e-02],
          [ 2.0577e-06, -5.1655e-02, -5.1657e-02,  ...,  6.1985e-02,
            0.0000e+00, -3.0993e-02],
          [ 6.7754e-06, -4.9152e-02, -4.9159e-02,  ...,  5.8982e-02,
            0.0000e+00, -2.9491e-02],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]],
 
         [[ 1.1690e-01,  4.4829e-02, -7.2067e-02,  ...,  5.3792e-02,
           -1.8029e-02,  1.2606e-01],
          [ 1.3445e-01,  5.1747e-02, -8.2704e-02,  ...,  6.2177e-02,
           -2.0713e-02,  1.4447e-01],
          [ 1.2754e-01,  4.8911e-02, -7.8632e-02,  ...,  5.8578e-02,
           -1.9653e-02,  1.

In [55]:
for batch in tqdm(val_loader):
    break
batch['X'].shape, batch['X'], batch['y'].shape, batch['y']

  0%|          | 0/18 [00:09<?, ?it/s]


(torch.Size([128, 4096, 12]),
 tensor([[[-0.3062, -0.2431,  0.0631,  ...,  0.6301, -0.2340, -0.0180],
          [-0.3510, -0.2789,  0.0721,  ...,  0.7235, -0.2686, -0.0206],
          [-0.3346, -0.2655,  0.0691,  ...,  0.6879, -0.2556, -0.0198],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 0.1084, -0.3601, -0.4685,  ...,  0.4949,  0.3152,  0.3061],
          [ 0.1231, -0.4133, -0.5364,  ...,  0.5683,  0.3613,  0.3511],
          [ 0.1200, -0.3932, -0.5132,  ...,  0.5404,  0.3445,  0.3345],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[-0.0088, -0.7112, -0.7024,  ...,  0.6210,  0.4954,  0.4504]

In [56]:
for batch in tqdm(tst_loader):
    break
batch['X'].shape, batch['X'], batch['y'].shape, batch['y']

  0%|          | 0/18 [00:13<?, ?it/s]


(torch.Size([128, 4096, 12]),
 tensor([[[-2.1599e-01, -1.5303e-01,  6.2965e-02,  ..., -6.3001e-01,
           -4.5010e-01, -8.0984e-02],
          [-2.4818e-01, -1.7570e-01,  7.2476e-02,  ..., -7.2377e-01,
           -5.1666e-01, -9.3033e-02],
          [-2.3558e-01, -1.6701e-01,  6.8567e-02,  ..., -6.8670e-01,
           -4.9146e-01, -8.8392e-02],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]],
 
         [[ 4.3186e-01,  5.3199e-01,  1.0012e-01,  ...,  9.3765e-01,
            7.3913e-01,  7.4840e-01],
          [ 4.9687e-01,  6.0972e-01,  1.1285e-01,  ...,  1.0744e+00,
            8.4709e-01,  8.5643e-01],
          [ 4.6993e-01,  5.8191e-01,  1.1198e-01,  ...,  1.0251e+00,
            8.0823e-01,  8.