In [18]:
from itertools import combinations, islice

import pandas as pd
import torch
from torch.utils.data import DataLoader

import module

Data

In [None]:
# Load the preprocessed inputs: `stage` (time-stamped rows carrying stay_id and
# charttime) and `demog` (attributes incl. subject_id, stay_id, anchor_year_group).
stage, demog = (
    pd.read_parquet("processed/stage.parquet"),
    pd.read_parquet("processed/demog.parquet"),
)

In [None]:
# Zero-fill every missing value, then order rows by stay and, within a stay, by
# charttime. Later cells rely on this ordering (e.g. keep='last' per stay).
stage = (
    stage
    .fillna(0)
    .sort_values(by=['stay_id', 'charttime'])
    .reset_index(drop=True)
)

In [None]:
# Primary cohort: demographic rows whose anchor_year_group is anything other
# than '2020 - 2022' (those rows feed the temporal hold-out instead).
primary_id = demog.loc[demog['anchor_year_group'].ne('2020 - 2022')]
primary_id

In [None]:
# Temporal hold-out: '2020 - 2022' rows, restricted to subjects that never occur
# in the primary cohort so no subject is shared between the two cohorts.
temporal_id = demog.loc[
    demog['anchor_year_group'].eq('2020 - 2022')
    & ~demog['subject_id'].isin(primary_id['subject_id'])
]
temporal_id

# Split: stratified 80% / 5% / 5% / 10% (train / valid / calibration / test)

In [None]:
def _first_mode(values):
    """Return the first entry of ``values.mode()`` (a most frequent value)."""
    return values.mode().iloc[0]


# Rows of `stage` that belong to subjects of the primary cohort.
stage_primary = stage.loc[stage['subject_id'].isin(primary_id['subject_id'])]

# Columns summarised per subject and per stay. The categorical columns in
# `mode_cols` are summarised by their mode; every other column by its maximum.
target_cols = [
    'hadm_id', 'stay_id',
    'Liver_disease', 'Dehydration / hypovolemia', 'Hypertension', 'Renal_disease',
    'Myocardial_infarction', 'los', 'Diabetes', 'Vascular_disease',
    'Congestive_heart_failure', 'Chronic_pulmonary_disease',
    'race', 'age', 'gender', 'first_careunit',
    'RRT', 'RRT_icu_history', 'RRT_hosp_history',
    'method', 'baseline', 'max_stage', 'length',
]
mode_cols = {'race', 'gender', 'first_careunit', 'method'}

# One row per subject (subject_id first, then target_cols in order).
target_subject_id = stage_primary.groupby('subject_id', as_index=False).agg(
    {col: (_first_mode if col in mode_cols else 'max') for col in target_cols}
)

# One row per stay: its last row. This relies on `stage` having been sorted by
# ['stay_id', 'charttime'] beforehand.
target_stay_id = stage_primary.drop_duplicates(subset='stay_id', keep='last')[
    ['subject_id', *target_cols]
]

In [None]:
# Split the primary cohort with module.split_and_prepare, stratifying on the
# columns below. The split is driven by the one-row-per-subject table, so no
# subject may appear in more than one split.
stratify_cols = ['max_stage','race','gender','Dehydration / hypovolemia','Diabetes','first_careunit']
splits_stay, splits_stage = module.split_and_prepare(target_subject_id, target_stay_id, stage_primary, stratify_cols)

train = splits_stage['train']
valid = splits_stage['valid']
calibration = splits_stage['calibration']
test = splits_stage['test']

# Temporal hold-out rows. The merge key is de-duplicated first: an inner merge
# repeats every matching `stage` row once per duplicate key on the right side.
temporal = pd.merge(stage, temporal_id[['stay_id']].drop_duplicates(), on='stay_id', how='inner')

# Leakage check: collect every pair of splits that shares a subject_id.
subject_id_sets = {
    'Train': set(train['subject_id'].drop_duplicates()),
    'Validation': set(valid['subject_id'].drop_duplicates()),
    'Calibration': set(calibration['subject_id'].drop_duplicates()),
    'Test': set(test['subject_id'].drop_duplicates()),
    'Temporal': set(temporal['subject_id'].drop_duplicates())
}

common_subject_ids = []

for (name1, ids1), (name2, ids2) in combinations(subject_id_sets.items(), 2):
    common = ids1 & ids2
    if common:
        common_subject_ids.append({'Split 1': name1, 'Split 2': name2, 'Common subject_id': list(common)})

if common_subject_ids:
    common_df = pd.DataFrame(common_subject_ids)
    print("Common subject_id between splits:")
    print(common_df)
else:
    print("No common subject_id between any splits.")

# Stop before persisting: a shared subject would leak patient data across
# splits, and every downstream dataset is built from these splits.
assert not common_subject_ids, "subject_id overlap between splits (see table above); splits not saved"

torch.save(splits_stay, "processed/splits_stay.pt")
torch.save(splits_stage, "processed/splits_stage.pt")

# Datasets

Features

In [None]:
def _is_binary_indicator(series):
    """True when every non-null value of `series` is in {0, 1} and exactly two distinct values occur."""
    values = series.dropna()
    return values.isin([0, 1]).all() and values.nunique() == 2


# Groups scanned from `stage.columns` keep the table's column order.

# Columns whose name contains 'mask' or 'presence' (no GT columns, not RRT).
Mask = [
    col for col in stage.columns
    if any(tag in col for tag in ('mask', 'presence'))
    and 'GT' not in col
    and col != 'RRT'
]

# Columns containing 'diff', except mask columns and 'charttime_diff' itself.
Diff = [
    col for col in stage.columns
    if 'diff' in col
    and not ('mask' in col or col == 'charttime_diff')
]

# 0/1-valued columns (no GT columns, not RRT).
Binary = [
    col for col in stage.columns
    if _is_binary_indicator(stage[col])
    and 'GT' not in col
    and col != 'RRT'
]

Comorb = [
    'Liver_disease', 'Dehydration / hypovolemia', 'Hypertension',
    'Renal_disease', 'Myocardial_infarction', 'Diabetes',
    'Vascular_disease', 'Congestive_heart_failure', 'Chronic_pulmonary_disease',
]

Demog = [
    'BLACK', 'gender', 'Surgical', 'Medical', 'Medical/Surgical', 'Other',
    'RRT_hosp_history', 'RRT_icu_history', 'method',
]

# '<stat>_<drug>' names: all median_* columns first, then all max_* columns.
Vaso = [
    f"{stat}_{drug}"
    for stat in ["median", "max"]
    for drug in ['norad', 'dobut', 'epi', 'vasopressin', 'phenyl', 'dopa', 'milri']
]

# Every column whose name contains 'input'.
Input = list(filter(lambda col: 'input' in col, stage.columns))

# SCr-named columns (no mask/GT/presence) plus fixed summary columns.
SCr = [
    col for col in stage.columns
    if 'SCr' in col
    and not any(tag in col for tag in ('mask', 'GT', 'presence'))
] + ['baseline', 'max', 'min', 'median', 'mean', 'ratio']

# Urine-named columns (no mask/GT/presence/diff) plus the cumulative value.
Urine = [
    col for col in stage.columns
    if 'Urine' in col
    and not any(tag in col for tag in ('mask', 'GT', 'presence', 'diff'))
    and col != 'charttime_diff'
] + ['cum_value']

Weight = ['Weight']

Vital = ['temperature', 'heartrate', 'sbp', 'dbp', 'resprate', 'o2sat']

Lab = [
    'Hemoglobin', 'Hemoglobin_diff', 'WBC', 'Sodium', 'Potassium', 'BUN', 'Platelet',
    'Glucose', 'HCO3', 'Chloride', 'Hematocrit', 'AnionGap', 'Calcium',
]

Time = ['age', 'current_charttime']

Dataset

In [None]:
# Variant ABC: every existence-type and numeric feature group (deduplicated, sorted).
existence_ABC = sorted({*Mask, *Comorb, *Binary, *Demog})
numeric_ABC = sorted({*SCr, *Urine, *Weight, *Diff, *Vaso, *Input, *Vital, *Lab, *Time})
GT_presence_ABC = [f'GT_presence_{h}' for h in (6, 12, 18, 24, 30, 36, 42, 48)]
GT_stage_ABC = ['GT_stage_3D', 'GT_stage_3', 'GT_stage_2', 'GT_stage_1']

exclude_feats = {
    '6h', '12h', '24h', '6h_mask', '12h_mask', '24h_mask',
    'Anuria_12h', 'Anuria_12h_mask', 'cum_value_mask',
    'charttime_diff_mask', 'cum_time_diff_mask'
}

exclude_numeric = {'charttime_diff', 'cum_time_diff', 'cum_value'}

# Variant D: the ABC features minus anything urine-related and the exclusions
# above. Filtering the already sorted/deduplicated ABC lists keeps them sorted.
# The only groups in numeric_ABC but not in D's sources are the Urine columns,
# which all contain 'Urine' or are 'cum_value', so both are dropped here.
existence_D = [col for col in existence_ABC if col not in exclude_feats and 'Urine' not in col]
numeric_D = [col for col in numeric_ABC if col not in exclude_numeric and 'Urine' not in col]

GT_presence_D = GT_presence_ABC
GT_stage_D = GT_stage_ABC

# Variant E: D's features with the SCr-suffixed targets.
GT_presence_E = [f"GT_presence_{h}_SCr" for h in (6, 12, 18, 24, 30, 36, 42, 48)]
GT_stage_E = ["GT_stage_3D_SCr", "GT_stage_3_SCr", "GT_stage_2_SCr", "GT_stage_1_SCr"]

variant_specs = [
    ("ABC", existence_ABC, numeric_ABC, GT_presence_ABC, GT_stage_ABC),
    ("D", existence_D, numeric_D, GT_presence_D, GT_stage_D),
    ("E", existence_D, numeric_D, GT_presence_E, GT_stage_E),
]

datasets = {}

for model_name, existence, numeric, GT_p, GT_s in variant_specs:
    # Training-side splits use module.Dataset, held-out splits use module.Dataset_test.
    per_split = {}
    for split_name, split_frame, dataset_cls in (
        ("train", train, module.Dataset),
        ("valid", valid, module.Dataset),
        ("calibration", calibration, module.Dataset_test),
        ("test", test, module.Dataset_test),
    ):
        per_split[split_name] = dataset_cls(split_frame, numeric, existence, GT_p, GT_s)
    datasets[model_name] = per_split

torch.save(datasets, "processed/datasets.pt")
print("✅ Saved to processed/datasets.pt")

Shape

In [None]:
# DataLoaders over the ABC variant. Only the training loader may drop a final
# partial batch. Evaluation loaders keep drop_last=False so no validation,
# calibration or test sample is silently discarded if batch_size is raised
# (at batch_size=1 the flag has no effect).
train_dataloader = DataLoader(datasets['ABC']['train'], batch_size=1, shuffle=False, drop_last=True)
valid_dataloader = DataLoader(datasets['ABC']['valid'], batch_size=1, shuffle=False, drop_last=False)
calibration_dataloader = DataLoader(datasets['ABC']['calibration'], batch_size=1, shuffle=False, drop_last=False)
test_dataloader = DataLoader(datasets['ABC']['test'], batch_size=1, shuffle=False, drop_last=False)

In [None]:
# Shape check on the first N_PREVIEW training samples only. Looping over the
# whole training dataset would print five lines for every sample.
# Each item is unpacked from `.tensors` into five tensors below. Presumably a
# TensorDataset per stay; TODO confirm against module.Dataset.
N_PREVIEW = 1

for batch in islice(train_dataloader.dataset, N_PREVIEW):
    X_numeric, X_presence, Y_main, Y_sub, mask = batch.tensors
    print("X_numeric shape:", X_numeric.shape)
    print("X_presence shape:", X_presence.shape)
    print("Y_main shape:", Y_main.shape)
    print("Y_sub shape:", Y_sub.shape)
    print("mask shape:", mask.shape)