In [1]:
import os
import math
import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
base_dir = '<path_to_repo>/datasets/trauma_icu_resuscitation/preprocessed_cohort'

# create of a mapping of patient_id to trajectory length
patient_id_to_trajectory_length = dict()
p_ids = {int(f.name.split('_')[-1][:-3]) for f in os.scandir(base_dir) if f.is_file() and not f.name.startswith('.')} # filter out .DS_Store
print(f'Number of patients: {len(p_ids)}')
for p_id in p_ids:
    # dones = torch.load(os.path.join(base_dir, f'dones_{p_id}.pt'))
    missing = torch.load(os.path.join(base_dir, f'missing_{p_id}.pt'))
    patient_id_to_trajectory_length[p_id] = missing.logical_not().sum().item()
print(f'Number of patients: {len(patient_id_to_trajectory_length)}')

Number of patients: 4305
Number of patients: 4305


In [3]:
lengths_array = np.array(list(patient_id_to_trajectory_length.values()))
strat_quantiles = np.quantile(lengths_array, [0.15, 0.3, 0.45, 0.6, 0.75, 0.9])

In [4]:
strat_quantiles

array([ 2., 11., 25., 40., 51., 53.])

In [5]:
# now build the split df - NOTE: I know this is super inefficient - idc - lets just get it done and move on
data_dict = dict(traj=list(patient_id_to_trajectory_length.keys()), traj_length=list(patient_id_to_trajectory_length.values()))
resuscitated = list()
for traj in data_dict['traj']:
    missing_data_mask = torch.load(os.path.join(base_dir, f'missing_{traj}.pt')).squeeze(0)
    rewards = torch.load(os.path.join(base_dir, f'resuscitated_w_time_penalty_rewards_{traj}.pt')).squeeze(0)
    rewards = rewards[~missing_data_mask]
    resuscitated.append((rewards[-1] > 0).item()) # resuscitated if last reward is positive

In [7]:
# resus_quantiles = np.quantile(np.array(resuscitated), [0.0, 0.15, 0.3, 0.6, 0.75, 0.9])
# resus_quantiles
resuscitated

[True,
 True,
 False,
 False,
 True,
 True,
 True,
 True,
 False,
 True,
 True,
 True,
 False,
 True,
 False,
 False,
 True,
 True,
 True,
 False,
 True,
 True,
 True,
 True,
 True,
 True,
 False,
 True,
 False,
 True,
 True,
 True,
 True,
 False,
 False,
 True,
 True,
 True,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 True,
 False,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 False,
 False,
 True,
 True,
 True,
 True,
 False,
 False,
 True,
 True,
 True,
 True,
 True,
 True,
 False,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 False,
 False,
 False,
 False,
 True,
 True,
 True,
 False,
 False,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 False,
 False,
 True,
 True,
 True,
 False,
 False,
 True,
 True,
 False,
 False,
 True,
 False,
 True,
 True,
 False,
 True,
 True,
 True,
 False,
 True,
 False,
 False,
 False,
 True,
 False,
 False,
 True,
 True,


In [6]:
sum(resuscitated)

3027

In [8]:
data_dict['resuscitated'] = resuscitated
strat = list()
for p_resus, p_length in zip(data_dict['resuscitated'], data_dict['traj_length']):
    # match p_vent_free_days:
    #     case _ as v if v <= resus_quantiles[0] or np.isclose(v, resus_quantiles[0]):
    #         strat_num = 0
    #     case _ as v if v <= resus_quantiles[1] or np.isclose(v, resus_quantiles[1]):
    #         strat_num = 1
    #     case _ as v if v <= resus_quantiles[2] or np.isclose(v, resus_quantiles[2]):
    #         strat_num = 2
    #     case _ as v if v <= resus_quantiles[3] or np.isclose(v, resus_quantiles[3]):
    #         strat_num = 3
    #     case _ as v if v <= resus_quantiles[4] or np.isclose(v, resus_quantiles[4]):
    #         strat_num = 4
    #     case _:
    #         strat_num = 5
    if p_length <= strat_quantiles[0]:
        strat_num = 0
    elif p_length <= strat_quantiles[1]:
        strat_num = 1
    elif p_length <= strat_quantiles[2]:
        strat_num = 2
    elif p_length <= strat_quantiles[3]:
        strat_num = 3
    elif p_length <= strat_quantiles[4]:
        strat_num = 4
    elif p_length <= strat_quantiles[5]:
        strat_num = 5
    else:
        strat_num = 6
    strat.append(f'{strat_num}_{int(p_resus)}')
data_dict['strat_col'] = strat
df = pd.DataFrame.from_dict(data_dict)
df.head()

Unnamed: 0,traj,traj_length,resuscitated,strat_col
0,8194,2,True,0_1
1,8195,26,True,3_1
2,16387,52,False,5_0
3,8,63,False,6_0
4,8201,2,True,0_1


In [9]:
df.shape[0]

4305

In [10]:
# Now compute split
train_split = 0.7
val_split = 0.1
test_split = 0.2
num_splits = 10

split_save_path = '<path_to_repo>/datasets/trauma_icu_resuscitation/stratified_splits/'
if not os.path.isdir(split_save_path):
    os.makedirs(split_save_path)
    
for i in range(num_splits):
    remainder_data, test_data = train_test_split(df, test_size=math.ceil(df.shape[0] * test_split), stratify=df['strat_col'])
    train_data, val_data = train_test_split(remainder_data, test_size=math.ceil(df.shape[0] * val_split), stratify=remainder_data['strat_col'])
    train_data.reset_index(drop=True, inplace=True)
    val_data.reset_index(drop=True, inplace=True)
    test_data.reset_index(drop=True, inplace=True)
    train_data['split'] = 'train'
    val_data['split'] = 'val'
    test_data['split'] = 'test'
    data = pd.concat([train_data, val_data, test_data])
    data.drop('strat_col', axis=1, inplace=True)
    assert (data['split'] == 'train').sum() + (data['split'] == 'val').sum() + (data['split'] == 'test').sum() == data.shape[0] == df.shape[0], 'Error in split'
    data.to_csv(os.path.join(split_save_path, f'split_{i}.csv'))
    loaded_data = pd.read_csv(os.path.join(split_save_path, f'split_{i}.csv'), index_col=0)
    assert data.equals(loaded_data), f'Error in saving dataframe {i}'