# setup

In [1]:
import os

if os.getcwd().split('/')[-1] == 'notebooks':
    os.chdir('..')

In [7]:
import pandas as pd
import pickle

from scipy.io import loadmat, savemat
from tqdm import tqdm

# draft

In [3]:
split_index = 'process/data_split/split1_clean.mat'

In [4]:
split_idx = loadmat(split_index)

train_index, val_index, test_index = split_idx['train_index'], split_idx['val_index'], split_idx['test_index']
train_index = train_index.reshape((train_index.shape[1],))
val_index = val_index.reshape((val_index.shape[1],))
test_index = test_index.reshape((test_index.shape[1],))

In [5]:
split_idx

{'__header__': b'MATLAB 5.0 MAT-file Platform: posix, Created on: Sat Jun 22 14:30:00 2024',
 '__version__': '1.0',
 '__globals__': [],
 'train_index': array([[    0,     1,     2, ..., 42973, 42974, 42975]]),
 'val_index': array([[   12,    14,    16, ..., 42954, 42961, 42965]]),
 'test_index': array([[   10,    11,    13, ..., 42960, 42963, 42966]])}

# code

In [8]:
metadata = pd.read_csv('/home/josegfer/datasets/code/data/exams.csv')

In [9]:
with open('data/remove_id', 'rb') as fp:
    remove = pickle.load(fp)

In [10]:
metadata_clean = metadata.copy()
for exam_id in tqdm(remove):
    metadata_clean = metadata_clean.drop(index = metadata_clean[metadata_clean['exam_id'] == exam_id].index)
metadata_clean = metadata_clean.reset_index(drop = True)

100%|██████████| 1677/1677 [00:32<00:00, 51.36it/s]


In [11]:
metadata_clean

Unnamed: 0,exam_id,age,is_male,nn_predicted_age,1dAVb,RBBB,LBBB,SB,ST,AF,patient_id,death,timey,normal_ecg,trace_file
0,1169160,38,True,40.160484,False,False,False,False,False,False,523632,False,2.098628,True,exams_part13.hdf5
1,2873686,73,True,67.059440,False,False,False,False,False,False,1724173,False,6.657529,False,exams_part13.hdf5
2,168405,67,True,79.621740,False,False,False,False,False,True,51421,False,4.282188,False,exams_part13.hdf5
3,271011,41,True,69.750260,False,False,False,False,False,False,1737282,False,4.038353,True,exams_part13.hdf5
4,384368,73,True,78.873460,False,False,False,False,False,False,331652,False,3.786298,False,exams_part13.hdf5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
344115,1123951,33,True,35.893005,False,False,False,False,False,False,770553,False,2.189039,True,exams_part2.hdf5
344116,954704,73,False,68.169136,False,False,False,False,False,False,1044781,False,2.520546,False,exams_part2.hdf5
344117,589697,75,False,78.080810,False,False,False,False,False,False,1020589,False,3.304107,False,exams_part2.hdf5
344118,2780563,44,False,73.120636,False,False,False,False,False,False,178,False,7.339720,False,exams_part2.hdf5


In [12]:
def check_dataleakage(trn_metadata, val_metadata, tst_metadata, exam_id_col = 'exam_id'):
    trn_ids = set(trn_metadata[exam_id_col].unique())
    val_ids = set(val_metadata[exam_id_col].unique())
    tst_ids = set(tst_metadata[exam_id_col].unique())
    assert (len(trn_ids.intersection(val_ids)) == 0), "Some IDs are present in both train and validation sets."
    assert (len(trn_ids.intersection(tst_ids)) == 0), "Some IDs are present in both train and test sets."
    assert (len(val_ids.intersection(tst_ids)) == 0), "Some IDs are present in both validation and test sets."

In [16]:
def split(metadata, val_size = 0.05, tst_size = 0.05, patient_id_col = 'patient_id'):
    patient_ids = metadata[patient_id_col].unique()

    num_trn = int(len(patient_ids) * (1 - tst_size - val_size))
    num_val = int(len(patient_ids) * val_size)

    trn_ids = set(patient_ids[:num_trn])
    val_ids = set(patient_ids[num_trn : num_trn + num_val])
    tst_ids = set(patient_ids[num_trn + num_val :])

    trn_metadata = metadata.loc[metadata[patient_id_col].isin(trn_ids)]
    val_metadata = metadata.loc[metadata[patient_id_col].isin(val_ids)]
    tst_metadata = metadata.loc[metadata[patient_id_col].isin(tst_ids)]
    check_dataleakage(trn_metadata, val_metadata, tst_metadata)

    return trn_metadata, val_metadata, tst_metadata

In [17]:
trn_metadata, val_metadata, tst_metadata = split(metadata_clean)

In [21]:
trn_metadata.index.values, val_metadata.index.values, tst_metadata.index.values

(array([     0,      1,      2, ..., 344112, 344113, 344118]),
 array([297169, 297170, 297172, ..., 344000, 344024, 344067]),
 array([319905, 319906, 319907, ..., 344116, 344117, 344119]))

In [25]:
split_idx['train_index'] = trn_metadata.index.values
split_idx['val_index'] = val_metadata.index.values
split_idx['test_index'] = tst_metadata.index.values
split_idx

{'__header__': b'MATLAB 5.0 MAT-file Platform: posix, Created on: Sat Jun 22 14:30:00 2024',
 '__version__': '1.0',
 '__globals__': [],
 'train_index': array([     0,      1,      2, ..., 344112, 344113, 344118]),
 'val_index': array([297169, 297170, 297172, ..., 344000, 344024, 344067]),
 'test_index': array([319905, 319906, 319907, ..., 344116, 344117, 344119])}

In [26]:
savemat('data/split_code.mat', split_idx)