# setup

In [1]:
import os

# When the kernel starts inside a `notebooks` directory, step up one level.
# os.path.basename handles the platform's path separator, unlike splitting on '/'.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

In [2]:
import h5py
import pandas as pd
import numpy as np
import torch

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# class

In [3]:
class CODE():
    """Handle on the CODE data files plus a patient-wise train/val/test split.

    Opens two HDF5 files read-only, loads a metadata CSV, and partitions the
    metadata rows into three frames so that all rows of a given patient end up
    in exactly one partition.

    The object owns the two HDF5 handles. Release them with `close()` or by
    using the instance as a context manager (`with CODE() as db: ...`).

    NOTE(review): the handles are opened once, in the constructing process. If
    a dataset built on this object is fed to a multi-worker DataLoader, confirm
    that sharing these handles across worker processes is safe for the h5py
    build in use.

    Attributes:
        hdf5_file: `h5py.File` opened from `hdf5_path`.
        metadata: DataFrame read from `metadata_path`.
        texth5_file: `h5py.File` opened from `texth5_path`.
        val_size: fraction of patients assigned to validation.
        tst_size: fraction of patients assigned to test.
        trn_metadata, val_metadata, tst_metadata: the three partitions
            returned by `split()`.
    """

    def __init__(self, hdf5_path = '/home/josegfer/datasets/code/output/code15.h5', 
                 metadata_path = '/home/josegfer/datasets/code/output/metadata.csv', 
                 texth5_path = '/home/josegfer/datasets/code/output/BioBERTpt_text_report.h5', 
                 val_size = 0.05, tst_size = 0.05):
        """Open the data files and compute the split.

        Args:
            hdf5_path: path of the HDF5 file stored as `hdf5_file`.
            metadata_path: path of the CSV loaded into `metadata`.
            texth5_path: path of the HDF5 file stored as `texth5_file`.
            val_size: fraction of unique patients for the validation partition.
            tst_size: fraction of unique patients for the test partition.

        Raises:
            ValueError: if the fractions are negative or sum to more than 1.
        """
        self.val_size = val_size
        self.tst_size = tst_size

        # Pre-declare the handles so close() is safe even if opening fails midway.
        self.hdf5_file = None
        self.texth5_file = None
        try:
            self.hdf5_file = h5py.File(hdf5_path, 'r')
            self.metadata = pd.read_csv(metadata_path)
            self.texth5_file = h5py.File(texth5_path, 'r')

            self.trn_metadata, self.val_metadata, self.tst_metadata = self.split()
        except BaseException:
            # Do not leak an already-opened file when a later step fails.
            self.close()
            raise

    def close(self):
        """Close whichever HDF5 handles are currently held by this instance."""
        for attr in ('hdf5_file', 'texth5_file'):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def split(self, patient_id_col = 'patient_id'):
        """Partition `self.metadata` into train/val/test frames by patient.

        Patients are taken in the order `Series.unique()` yields them (order of
        first appearance); nothing is shuffled. The validation and test
        partitions receive `int(n_patients * size)` patients each and the
        training partition receives every remaining patient.

        Args:
            patient_id_col: metadata column holding the patient identifier.

        Returns:
            Tuple `(trn_metadata, val_metadata, tst_metadata)`. Each frame went
            through `reset_index()`, so it has a fresh 0..n-1 index and the
            previous row labels are kept in an `index` column.

        Raises:
            ValueError: if the fractions are negative or sum to more than 1.
            AssertionError: if `check_dataleakage` finds a shared exam ID.
        """
        if self.val_size < 0 or self.tst_size < 0 or self.val_size + self.tst_size > 1:
            raise ValueError(
                "val_size and tst_size must be non-negative and sum to at most 1, "
                "got val_size={} and tst_size={}.".format(self.val_size, self.tst_size)
            )

        patient_ids = self.metadata[patient_id_col].unique()
        num_patients = len(patient_ids)

        # Size the held-out sets first and give the remainder to training.
        # Computing int(n * (1 - tst - val)) instead loses a patient to float
        # error (1 - 0.05 - 0.05 == 0.8999999999999999), inflating the test set.
        num_val = int(num_patients * self.val_size)
        num_tst = int(num_patients * self.tst_size)
        num_trn = max(num_patients - num_val - num_tst, 0)

        trn_ids = set(patient_ids[:num_trn])
        val_ids = set(patient_ids[num_trn : num_trn + num_val])
        tst_ids = set(patient_ids[num_trn + num_val :])

        trn_metadata = self.metadata.loc[self.metadata[patient_id_col].isin(trn_ids)].reset_index()
        val_metadata = self.metadata.loc[self.metadata[patient_id_col].isin(val_ids)].reset_index()
        tst_metadata = self.metadata.loc[self.metadata[patient_id_col].isin(tst_ids)].reset_index()
        self.check_dataleakage(trn_metadata, val_metadata, tst_metadata)

        return trn_metadata, val_metadata, tst_metadata

    def check_dataleakage(self, trn_metadata, val_metadata, tst_metadata, exam_id_col = 'exam_id'):
        """Assert that no exam ID occurs in more than one partition.

        Args:
            trn_metadata, val_metadata, tst_metadata: the three partition frames.
            exam_id_col: column holding the exam identifier.

        Raises:
            AssertionError: if any two partitions share an exam ID.
        """
        trn_ids = set(trn_metadata[exam_id_col].unique())
        val_ids = set(val_metadata[exam_id_col].unique())
        tst_ids = set(tst_metadata[exam_id_col].unique())
        assert (len(trn_ids.intersection(val_ids)) == 0), "Some IDs are present in both train and validation sets."
        assert (len(trn_ids.intersection(tst_ids)) == 0), "Some IDs are present in both train and test sets."
        assert (len(val_ids.intersection(tst_ids)) == 0), "Some IDs are present in both validation and test sets."

In [4]:
class CODEsplit(Dataset):
    """Map-style dataset over one metadata partition of a database object.

    `database` must expose `hdf5_file` and `texth5_file`, both indexable by a
    dataset name and then by a row position. `metadata` is a DataFrame whose
    index labels are the values passed to `__getitem__`.
    """

    def __init__(self, database, metadata,
                 tracing_col = 'tracings', output_col = ['1dAVb', 'RBBB', 'LBBB', 'SB', 'AF', 'ST'], textfeatures_col = 'embeddings', 
                 exam_id_col = 'exam_id', h5_idx_col = 'h5_idx'):
        """Store the data sources and the names used to look items up.

        Args:
            database: object holding the `hdf5_file` and `texth5_file` handles.
            metadata: DataFrame describing the rows of this partition.
            tracing_col: dataset name read from `database.hdf5_file`.
            output_col: metadata columns returned together as `y`.
            textfeatures_col: dataset name read from `database.texth5_file`.
            exam_id_col: metadata column returned as `exam_id`.
            h5_idx_col: metadata column giving the row position in both HDF5 datasets.
        """
        # Data sources.
        self.metadata = metadata
        self.database = database

        # Metadata column names.
        self.h5_idx_col = h5_idx_col
        self.exam_id_col = exam_id_col
        self.output_col = output_col

        # HDF5 dataset names.
        self.tracing_col = tracing_col
        self.textfeatures_col = textfeatures_col

    def __len__(self):
        """Number of rows in the partition's metadata."""
        n_rows = len(self.metadata)
        return n_rows

    def __getitem__(self, idx):
        """Build the sample for metadata index label `idx`.

        Returns:
            Dict with keys `x` (tracing entry), `y` (values of the output
            columns), `exam_id`, and `h` (text-feature entry).
        """
        # The same stored position addresses both HDF5 datasets.
        h5_row = self.metadata[self.h5_idx_col].loc[idx]

        sample = {}
        sample['x'] = self.database.hdf5_file[self.tracing_col][h5_row]
        sample['y'] = self.metadata[self.output_col].loc[idx].values
        sample['exam_id'] = self.metadata[self.exam_id_col].loc[idx]
        sample['h'] = self.database.texth5_file[self.textfeatures_col][h5_row]
        return sample

# check

In [5]:
# Build the database with CODE's default paths and split fractions;
# the constructor opens both HDF5 files, reads the metadata CSV and computes the split.
db = CODE()

In [6]:
# Wrap the training partition as a torch Dataset, keeping CODEsplit's default column names.
trn_ds = CODEsplit(db, db.trn_metadata)

In [7]:
# Shuffled batches of 128 samples, loaded by 6 worker processes.
trn_loader = DataLoader(trn_ds, batch_size = 128, shuffle = True, num_workers = 6)

In [8]:
# Pull a single batch and inspect what the default collation produced for each key.
batch = next(iter(trn_loader))
batch['x'].shape, batch['y'].shape, batch['exam_id'], batch['h'].shape

(torch.Size([128, 4096, 12]),
 torch.Size([128, 6]),
 tensor([1791248, 1572419,  562516,  310531, 1071666,  632537,  500067,   61374,
          943340, 1467955, 1051298,  968516, 1811794, 2929646,  849459, 1116246,
          688719, 1837638, 3086851, 2872443,  336437, 1204787,  326740, 1538981,
          676243,  478226, 1719497, 2966990,  796235, 2890693,  464035, 1357552,
         3096695,  379880, 1714008,  414206, 2905437,  143803, 1260151, 1373160,
         1559696, 3616596,  111887, 3202301,  262549, 1112241, 2533179, 3099585,
           92552, 1901125,  219825,  208715, 2751324,  710433, 4218808,  508430,
          486764, 1856067,  948831,  531820, 1663784, 1493213,  691798, 1000694,
          673612,  536513, 1434839, 1025109,  969884, 2960163, 1058331, 1959184,
          997269, 3789333, 1217086,  671062,  200481, 1059448, 1716340,  242548,
          316693, 1568966, 2869263, 1465339, 2881320, 3209699,  622106, 4390434,
         3228135,   92481,  921953, 1124331,  478221, 13