## test

In [2]:
from dataloaders.baseline import CODE, CODEsplit

In [3]:
database = CODE()

832it [00:00, 8312.28it/s]

checking exam_id consistency in idx dict


273026it [00:32, 8330.61it/s]
1634it [00:00, 8166.23it/s]

checking exam_id consistency in idx dict


23430it [00:02, 8221.26it/s]
1665it [00:00, 8322.88it/s]

checking exam_id consistency in idx dict


11184it [00:01, 8306.08it/s]


In [4]:
trn_ds = CODEsplit(database, database.trn_idx_dict)
val_ds = CODEsplit(database, database.val_idx_dict)
tst_ds = CODEsplit(database, database.tst_idx_dict)

In [5]:
import torch

trn_loader = torch.utils.data.DataLoader(trn_ds, batch_size = 128,
                                          shuffle = True, num_workers = 6)

In [6]:
from tqdm import tqdm

for batch in tqdm(trn_loader):
    pass

100%|██████████| 2134/2134 [46:35<00:00,  1.31s/it]


# setup

In [None]:
import os

if os.getcwd().split('/')[-1] == 'notebooks':
    os.chdir('..')

In [2]:
import h5py
import pandas as pd
import numpy as np

from tqdm import tqdm

# init

In [3]:
hdf5_path = '/home/josegfer/code/code14/code14.h5'
metadata_path = '/home/josegfer/code/code14/exams.csv'
reports_csv_path = '/home/josegfer/code/code14/BioBERTpt_text_report_crop.h5'

In [4]:
random_seed = 0
val_size = 0.10
tst_size = 0.05

In [5]:
patient_id_col = 'patient_id'
exam_id_col = 'exam_id'

# loader

In [31]:
import torch

from torch.utils.data import Dataset, DataLoader

In [15]:
class CODE():
    def __init__(self, hdf5_path, metadata_path, val_size, tst_size):
        self.hdf5_file = h5py.File(hdf5_path, "r")
        self.metadata = pd.read_csv(metadata_path)

        self.val_size = val_size
        self.tst_size = tst_size

        trn_metadata, val_metadata, tst_metadata = self.split()
        self.check_dataleakage(trn_metadata, val_metadata, tst_metadata)
        
        self.trn_idx_dict = self.get_idx_dict(trn_metadata)
        self.val_idx_dict = self.get_idx_dict(val_metadata)
        self.tst_idx_dict = self.get_idx_dict(tst_metadata)

    def split(self, patient_id_col = 'patient_id'):
        patient_ids = self.metadata[patient_id_col].unique()

        num_trn = int(len(patient_ids) * (1 - self.tst_size - self.val_size))
        num_val = int(len(patient_ids) * self.val_size)

        trn_ids = set(patient_ids[:num_trn])
        val_ids = set(patient_ids[num_trn : num_trn + num_val])
        tst_ids = set(patient_ids[num_trn + num_val :])

        trn_metadata = self.metadata.loc[self.metadata[patient_id_col].isin(trn_ids)]
        val_metadata = self.metadata.loc[self.metadata[patient_id_col].isin(val_ids)]
        tst_metadata = self.metadata.loc[self.metadata[patient_id_col].isin(tst_ids)]

        return trn_metadata, val_metadata, tst_metadata
    
    def check_dataleakage(self, trn_metadata, val_metadata, tst_metadata, exam_id_col = 'exam_id'):
        trn_ids = set(trn_metadata[exam_id_col].unique())
        val_ids = set(val_metadata[exam_id_col].unique())
        tst_ids = set(tst_metadata[exam_id_col].unique())
        assert (len(trn_ids.intersection(val_ids)) == 0), "Some IDs are present in both train and validation sets."
        assert (len(trn_ids.intersection(tst_ids)) == 0), "Some IDs are present in both train and test sets."
        assert (len(val_ids.intersection(tst_ids)) == 0), "Some IDs are present in both validation and test sets."

    def get_idx_dict(self, split_metadata, exam_id_col = 'exam_id'):
        split_exams, split_h5_idx, temp = np.intersect1d(self.hdf5_file[exam_id_col], split_metadata[exam_id_col].values, return_indices = True)
        split_csv_idx = split_metadata.iloc[temp].index.values
        split_idx_dict = {exam_id_col: split_exams, 'h5_idx': split_h5_idx, 'csv_idx': split_csv_idx}

        print('checking exam_id consistency in idx dict')
        for idx, exam_id in tqdm(enumerate(split_idx_dict[exam_id_col])):
            assert self.hdf5_file[exam_id_col][split_idx_dict['h5_idx'][idx]] == exam_id
            assert self.metadata[exam_id_col][split_idx_dict['csv_idx'][idx]] == exam_id
        return split_idx_dict

In [16]:
data = CODE(hdf5_path, metadata_path, val_size, tst_size)

804it [00:00, 8036.71it/s]

checking exam_id consistency in idx dict


273026it [00:33, 8110.67it/s]
817it [00:00, 8161.25it/s]

checking exam_id consistency in idx dict


23430it [00:02, 8248.85it/s]
812it [00:00, 8112.48it/s]

checking exam_id consistency in idx dict


11184it [00:01, 8150.12it/s]


In [17]:
split_idx_dict = data.trn_idx_dict

In [20]:
len(split_idx_dict['exam_id']), len(split_idx_dict['h5_idx']), len(split_idx_dict['csv_idx'])

(273026, 273026, 273026)

In [25]:
data.hdf5_file['exam_id'][0], data.hdf5_file['tracings'][0]

(590673,
 array([[-4.87810344e-01, -2.66771287e-01,  2.21039057e-01, ...,
         -1.17379367e+00, -5.56408703e-01, -4.87810344e-01],
        [-4.81065780e-01, -2.60196328e-01,  2.20869452e-01, ...,
         -1.16749966e+00, -5.51513910e-01, -4.82329220e-01],
        [-4.79793221e-01, -2.57917106e-01,  2.21876070e-01, ...,
         -1.15992427e+00, -5.42299151e-01, -4.77297604e-01],
        ...,
        [-1.37249243e+00, -1.28117001e+00,  9.13222730e-02, ...,
         -8.78245056e-01, -5.03962398e-01,  8.59386753e-04],
        [-1.36670578e+00, -1.27589798e+00,  9.08078253e-02, ...,
         -8.68614078e-01, -4.97095346e-01,  1.69357355e-03],
        [-1.36146402e+00, -1.26946747e+00,  9.19967964e-02, ...,
         -8.57375383e-01, -4.82351691e-01,  1.05662365e-02]]))

In [28]:
data.metadata['exam_id'][0]

1169160

In [70]:
data.trn_idx_dict['csv_idx'][0]

70599

In [73]:
data.metadata[["1dAVb", "RBBB", "LBBB", "SB", "AF", "ST"]].loc[data.trn_idx_dict['csv_idx'][0:3]].values

array([[False, False, False, False,  True, False],
       [False, False, False, False, False, False],
       [False, False, False,  True, False, False]])

In [79]:
class CODEsplit(Dataset):
    def __init__(self, database, split_idx_dict, 
                 tracing_col = 'tracings', exam_id_col = 'exam_id', output_col = ["1dAVb", "RBBB", "LBBB", "SB", "AF", "ST"]):
        self.database = database
        self.split_idx_dict = split_idx_dict

        self.tracing_col = tracing_col
        self.exam_id_col = exam_id_col
        self.output_col = output_col
    
    def __len__(self):
        return len(self.split_idx_dict[exam_id_col])
    
    def __getitem__(self, idx):
        return {'X': self.database.hdf5_file[self.tracing_col][self.split_idx_dict['h5_idx'][idx]], 
                'y': self.database.metadata[self.output_col].loc[self.split_idx_dict['csv_idx'][idx]].values}

In [80]:
trn_ds = CODEsplit(data, data.trn_idx_dict)

In [81]:
trn_loader = torch.utils.data.DataLoader(trn_ds, batch_size = 128,
                                          shuffle = True, num_workers = 6)

In [82]:
for batch in tqdm(trn_loader):
    break

  0%|          | 0/2134 [00:13<?, ?it/s]


# read

In [14]:
hdf5_file = h5py.File(hdf5_path, "r")
metadata = pd.read_csv(metadata_path)
reports = h5py.File(reports_csv_path, "r")

# split

In [15]:
patient_ids = metadata[patient_id_col].unique()

num_trn = int(len(patient_ids) * (1 - tst_size - val_size))
num_val = int(len(patient_ids) * val_size)

trn_ids = set(patient_ids[:num_trn])
val_ids = set(patient_ids[num_trn : num_trn + num_val])
tst_ids = set(patient_ids[num_trn + num_val :])

trn_metadata = metadata.loc[metadata[patient_id_col].isin(trn_ids)]
val_metadata = metadata.loc[metadata[patient_id_col].isin(val_ids)]
tst_metadata = metadata.loc[metadata[patient_id_col].isin(tst_ids)]

# data leakage

In [16]:
trn_ids = set(trn_metadata[exam_id_col].unique())
val_ids = set(val_metadata[exam_id_col].unique())
tst_ids = set(tst_metadata[exam_id_col].unique())

In [17]:
assert (len(trn_ids.intersection(val_ids)) == 0), "Some IDs are present in both train and validation sets."

In [18]:
assert (len(trn_ids.intersection(tst_ids)) == 0), "Some IDs are present in both train and test sets."
assert (len(val_ids.intersection(tst_ids)) == 0), "Some IDs are present in both validation and test sets."

# idx

In [30]:
trn_exams, trn_h5_idx, temp = np.intersect1d(hdf5_file['exam_id'], trn_metadata['exam_id'].values, return_indices = True)
trn_csv_idx = trn_metadata.iloc[temp].index.values
trn_idx_dict = {'exam_id': trn_exams, 'h5_idx': trn_h5_idx, 'csv_idx': trn_csv_idx}

In [33]:
len(trn_idx_dict['exam_id']), len(trn_idx_dict['h5_idx']), len(trn_idx_dict['csv_idx'])

(273026, 273026, 273026)

In [37]:
print('checking exam_id consistency in idx dict')
for idx, exam_id in tqdm(enumerate(trn_idx_dict['exam_id'])):
    assert hdf5_file[exam_id_col][trn_idx_dict['h5_idx'][idx]] == exam_id
    assert metadata[exam_id_col][trn_idx_dict['csv_idx'][idx]] == exam_id

868it [00:00, 8671.04it/s]

checking exam_id consistency in idx dict


273026it [00:29, 9365.08it/s] 
