In [1]:
import os
from datetime import datetime
import pandas as pd
import numpy as np
from sklearn.impute import SimpleImputer
from matplotlib import pyplot as plt
import seaborn as sns

In [2]:
# File paths and separator
#
# Root data directory. Set the KP_MEDINF_DATA_PATH environment variable to run
# the notebook from another machine; the default is the original HPC location.
data_path = os.environ.get(
    "KP_MEDINF_DATA_PATH",
    r"/home/jori152b/DIR/horse/jori152b-medinf/KP_MedInf/model_development/data",
)

# Extracted source tables, all read with SEPARATOR below.
DATA_PATH_stages = os.path.join(data_path, "extracted", "kdigo_stages_measured.csv")
DATA_PATH_labs = os.path.join(data_path, "extracted", "labs_original.csv")
DATA_PATH_labs_extended = os.path.join(data_path, "extracted", "labs_extended.csv")
DATA_PATH_labs_new = os.path.join(data_path, "extracted", "labs_new.csv")
DATA_PATH_vitals = os.path.join(data_path, "extracted", "vitals.csv")
DATA_PATH_vents = os.path.join(data_path, "extracted", "vents_vasopressor_sedatives.csv")
DATA_PATH_detail = os.path.join(data_path, "extracted", "icustay_detail.csv")
DATA_PATH_heightweight = os.path.join(data_path, "extracted", "heightweight.csv")
DATA_PATH_calcium = os.path.join(data_path, "extracted", "calcium.csv")
DATA_PATH_inr_max = os.path.join(data_path, "extracted", "inr_max.csv")

SEPARATOR = ";"

# Constants
# --- read in this notebook ---
MAX_DAYS = 35             # stays spanning more than this many days keep only their last MAX_DAYS days
HOURS_AHEAD = 48          # label horizon in hours; also the minimum stay length that is kept
ADULTS_MIN_AGE = 18       # minimum admission_age kept
MAX_FEATURE_SET = True    # build the merged feature sets and aggregate the sedative/vasopressor/vent flags
TIME_SAMPLING = True      # resample every stay onto a regular time grid
FILL_VALUE = 0            # value used for NaNs remaining after resampling
SAMPLING_INTERVAL = '6H'  # default grid; the resampling cells iterate over their own interval lists
RESAMPLE_LIMIT = 16       # label forward-fill limit in steps: 16 x 6 h = 96 h

# --- NOTE(review): not read in this notebook section; presumably consumed by
# downstream modelling code — verify before removing ---
IMPUTE_EACH_ID = False
IMPUTE_COLUMN = False
TESTING = False
TEST_SIZE = 0.05
SPLIT_SIZE = 0.2
CLASS1 = True
ALL_STAGES = False
FIRST_TURN_POS = True
MOST_COMMON = False
IMPUTE_METHOD = 'most_frequent'
ADULTS_MAX_AGE = 120      # NOTE(review): no upper-age filter is applied in this section
NORMALIZATION = 'min-max'
NORM_TYPE = 'min_max'     # NOTE(review): overlaps NORMALIZATION with a different spelling
RANDOM = 42

def filter_by_length_of_stay(X):
    drop_list = []
    long_stays = X.groupby(['icustay_id']).apply(lambda group: (group['charttime'].max() - group['charttime'].min()).total_seconds() / (24 * 60 * 60) > MAX_DAYS)

    for icustay_id, is_long in long_stays.items():
        if is_long:
            max_time = X[X['icustay_id'] == icustay_id]['charttime'].max() - pd.to_timedelta(MAX_DAYS, unit='D')
            X = X[~((X['icustay_id'] == icustay_id) & (X['charttime'] < max_time))]

    short_stays = X.groupby(['icustay_id']).apply(lambda group: (group['charttime'].max() - group['charttime'].min()).total_seconds() / (24 * 60 * 60) < (HOURS_AHEAD/24))
    drop_list = short_stays[short_stays].index.tolist()

    X = X[~X.icustay_id.isin(drop_list)]
    return X

In [None]:
# Load the KDIGO stage table; the other time-series tables are merged onto it later.
print("Loading datasets...")
X = pd.read_csv(DATA_PATH_stages, sep=SEPARATOR)
print("X", X.columns)
# Only the combined aki_stage is kept; rows with none of the stage inputs are discarded.
X = X.drop(columns=["aki_stage_creat", "aki_stage_uo"])
X = X.dropna(subset=['creat', 'uo_rt_6hr', 'uo_rt_12hr', 'uo_rt_24hr', 'aki_stage'], how='all')
X = X.assign(charttime=pd.to_datetime(X['charttime']))

print(X.shape[0])
print(X['aki_stage'].value_counts())

# Static per-stay details; admission/discharge bookkeeping columns are removed.
dataset_detail = pd.read_csv(DATA_PATH_detail, sep=SEPARATOR)
print("dataset_detail", dataset_detail.columns)
dataset_detail = dataset_detail.drop(columns=[
    'dod', 'admittime', 'dischtime', 'los_hospital', 'ethnicity',
    'hospital_expire_flag', 'hospstay_seq', 'first_hosp_stay', 'intime',
    'outtime', 'los_icu', 'icustay_seq', 'first_icu_stay', 'ethnicity_grouped',
])


# Original lab panel. Glucose/creatinine and several sparse labs are dropped and
# every remaining *_min/*_max pair is collapsed into a single *_mean column.
dataset_labs = pd.read_csv(DATA_PATH_labs, sep=SEPARATOR).drop(
    columns=['glucose_min', 'glucose_max', 'creatinine_min', 'creatinine_max'])
print("dataset_labs", dataset_labs.columns)
# Drop rows without a timestamp and rows where every value column is missing.
# NOTE(review): ``columns[4:]`` assumes the first four columns are identifiers /
# charttime rather than lab values — confirm against labs_original.csv.
dataset_labs = (
    dataset_labs
    .dropna(subset=['charttime'])
    .dropna(subset=dataset_labs.columns[4:], how='all')
)
dataset_labs['charttime'] = pd.to_datetime(dataset_labs['charttime'])
dataset_labs = dataset_labs.sort_values(by=['icustay_id', 'charttime'])
dataset_labs = dataset_labs.drop(columns=[
    'albumin_min', 'albumin_max', 'bilirubin_min', 'bilirubin_max', 'bands_min', 'bands_max',
    'lactate_min', 'lactate_max', 'platelet_min', 'platelet_max', 'ptt_min', 'ptt_max',
    'inr_min', 'inr_max', 'pt_min', 'pt_max',
])
# Calculate mean for each pair and drop original columns
column_pairs = [('aniongap_min', 'aniongap_max'), ('albumin_min', 'albumin_max'), 
                ('bands_min', 'bands_max'), ('bicarbonate_min', 'bicarbonate_max'), 
                ('bilirubin_min', 'bilirubin_max'), 
                ('chloride_min', 'chloride_max'), 
                ('hematocrit_min', 'hematocrit_max'), ('hemoglobin_min', 'hemoglobin_max'), 
                ('lactate_min', 'lactate_max'), ('platelet_min', 'platelet_max'), 
                ('potassium_min', 'potassium_max'), ('ptt_min', 'ptt_max'), 
                ('inr_min', 'inr_max'), ('pt_min', 'pt_max'), ('sodium_min', 'sodium_max'), 
                ('bun_min', 'bun_max'), ('wbc_min', 'wbc_max')]

for min_col, max_col in column_pairs:
    # Some pairs (albumin, bands, bilirubin, lactate, platelet, ptt, inr, pt) were
    # dropped above. Skip them explicitly rather than hiding every possible error
    # (dtype problems, typos) behind a bare ``except``.
    if min_col not in dataset_labs.columns or max_col not in dataset_labs.columns:
        continue
    mean_col = min_col.rsplit('_', 1)[0] + '_mean'
    dataset_labs[mean_col] = dataset_labs[[min_col, max_col]].mean(axis=1)
    dataset_labs = dataset_labs.drop(columns=[min_col, max_col])

# Extended lab panel: the same pairs as the original panel plus GFR, phosphate,
# uric acid and calcium. Each *_min/*_max pair is collapsed into one *_mean column.
dataset_labs_extended = pd.read_csv(DATA_PATH_labs_extended, sep=SEPARATOR).drop(
    columns=['glucose_min', 'glucose_max', 'creatinine_min', 'creatinine_max'])

print("dataset_labs_extended", dataset_labs_extended.columns)
# Keep rows with a timestamp and at least one value in the columns from position 4 on
# (NOTE(review): presumably the lab values — confirm the leading column layout).
dataset_labs_extended = (
    dataset_labs_extended
    .dropna(subset=['charttime'])
    .dropna(subset=dataset_labs_extended.columns[4:], how='all')
)
dataset_labs_extended['charttime'] = pd.to_datetime(dataset_labs_extended['charttime'])
dataset_labs_extended = dataset_labs_extended.sort_values(by=['icustay_id', 'charttime'])

column_pairs_extended = [('aniongap_min', 'aniongap_max'), ('albumin_min', 'albumin_max'), 
                ('bands_min', 'bands_max'), ('bicarbonate_min', 'bicarbonate_max'), 
                ('bilirubin_min', 'bilirubin_max'), 
                ('chloride_min', 'chloride_max'), 
                ('hematocrit_min', 'hematocrit_max'), ('hemoglobin_min', 'hemoglobin_max'), 
                ('lactate_min', 'lactate_max'), ('platelet_min', 'platelet_max'), 
                ('potassium_min', 'potassium_max'), ('ptt_min', 'ptt_max'), 
                ('inr_min', 'inr_max'), ('pt_min', 'pt_max'), ('sodium_min', 'sodium_max'), 
                ('bun_min', 'bun_max'), ('wbc_min', 'wbc_max'), 
                ('gfr_min', 'gfr_max'), ('phosphate_min', 'phosphate_max'),('uric_acid_min', 'uric_acid_max'), 
                ('calcium_min', 'calcium_max')]


for min_col, max_col in column_pairs_extended:
    mean_col = f"{min_col.rsplit('_', 1)[0]}_mean"
    pair_mean = dataset_labs_extended[[min_col, max_col]].mean(axis=1)
    # Dropping the pair before appending the mean yields the same final column order.
    dataset_labs_extended = dataset_labs_extended.drop(columns=[min_col, max_col])
    dataset_labs_extended[mean_col] = pair_mean
# GFR is averaged like every other pair but is not kept as a feature.
dataset_labs_extended = dataset_labs_extended.drop(columns=['gfr_mean'])

# Vital signs plus sedative / vasopressor / ventilation indicators.
dataset_vitals = pd.read_csv(DATA_PATH_vitals, sep=SEPARATOR)
print("dataset_vitals", dataset_vitals.columns)
dataset_vitals = dataset_vitals.drop(columns=["glucose_min", "glucose_max"])
dataset_vents = pd.read_csv(DATA_PATH_vents, sep=SEPARATOR)
print("dataset_vents", dataset_vents.columns)
# Remove the *_min/*_max columns of every vital sign.
dataset_vitals = dataset_vitals.drop(columns=[
    f"{vital}_{bound}"
    for vital in ("heartrate", "sysbp", "diasbp", "meanbp", "tempc", "resprate", "spo2")
    for bound in ("min", "max")
])
dataset_vitals['charttime'] = pd.to_datetime(dataset_vitals['charttime'])
dataset_vents['charttime'] = pd.to_datetime(dataset_vents['charttime'])
# Drop rows where every column from position 4 on is missing, then order by stay/time.
dataset_vitals = (
    dataset_vitals
    .dropna(subset=dataset_vitals.columns[4:], how='all')
    .sort_values(by=['icustay_id', 'charttime'])
)
dataset_vents = dataset_vents.sort_values(by=['icustay_id', 'charttime'])

# Height/weight per ICU stay; rows where icustay_id, height and weight are all missing go.
dataset_heightweight = pd.read_csv(DATA_PATH_heightweight, sep=SEPARATOR)
print("dataset_heightweight", dataset_heightweight.columns)
dataset_heightweight = (
    dataset_heightweight
    .dropna(how='all', subset=['icustay_id', 'height_first', 'weight_first'])
    .sort_values(by=['icustay_id'])
)

# Calcium measurements, keyed by icustay_id and charttime only.
dataset_calcium = pd.read_csv(DATA_PATH_calcium, sep=SEPARATOR)
print("dataset_calcium", dataset_calcium.columns)
dataset_calcium = (
    dataset_calcium
    .drop(columns=["hadm_id"])
    .assign(charttime=lambda frame: pd.to_datetime(frame['charttime']))
    .sort_values(by=['icustay_id', 'charttime'])
    .drop(columns=["subject_id"])
)

# Per-stay max INR, keyed by icustay_id only.
dataset_inr_max = pd.read_csv(DATA_PATH_inr_max, sep=SEPARATOR)
print("dataset_inr_max", dataset_inr_max.columns)
dataset_inr_max = (
    dataset_inr_max
    .drop(columns=["hadm_id", "subject_id"])
    .sort_values(by=['icustay_id'])
)

# Merge datasets
# Outer joins keep every (icustay_id, charttime) that appears in any table.
# NOTE(review): when MAX_FEATURE_SET is False neither X_extended nor X_original
# is defined and the following cells will fail.
if MAX_FEATURE_SET:
    # Extended feature set: extended labs + vitals + vents + calcium.
    X_extended = (
        X.merge(dataset_labs_extended, on=["icustay_id", "charttime"], how="outer")
         .merge(dataset_vitals, on=["icustay_id", "charttime", "subject_id", "hadm_id"], how="outer")
         .merge(dataset_vents, on=["icustay_id", "charttime"], how="outer")
         .merge(dataset_calcium, on=["icustay_id", "charttime"], how="outer")
         .drop(columns=["subject_id"])
    )

    # Original feature set: original labs + vitals + vents.
    # Calcium is deliberately not merged into this set.
    X_original = (
        X.merge(dataset_labs, on=["icustay_id", "charttime"], how="outer")
         .merge(dataset_vitals, on=["icustay_id", "charttime", "subject_id", "hadm_id"], how="outer")
         .merge(dataset_vents, on=["icustay_id", "charttime"], how="outer")
         .drop(columns=["subject_id"])
    )

In [None]:
# Quick look at the merged column sets.
for _name, _frame in (("X_extended", X_extended), ("X_original", X_original)):
    print(_name, _frame.columns)

In [None]:
print("Filtering patients by age and length of stay...")
# Restrict all frames to adult stays (admission_age >= ADULTS_MIN_AGE), ordered by stay and time.
dataset_detail = dataset_detail.loc[dataset_detail['admission_age'] >= ADULTS_MIN_AGE]
adults_icustay_id_list = dataset_detail['icustay_id'].unique()
X, X_extended, X_original = (
    frame.loc[frame['icustay_id'].isin(adults_icustay_id_list)].sort_values(by=['icustay_id', 'charttime'])
    for frame in (X, X_extended, X_original)
)

# Trim long stays / drop short ones, then keep details only for stays that survive in X.
X, X_extended, X_original = (filter_by_length_of_stay(frame) for frame in (X, X_extended, X_original))
dataset_detail = (
    dataset_detail
    .loc[dataset_detail['icustay_id'].isin(X['icustay_id'].unique())]
    .sort_values(by=['icustay_id'])
)

In [5]:
# categorical features
# Static (non-time-dependent) stay details in two variants from icustay_detail.csv:
#   * dataset_detail_theirs keeps ethnicity_grouped (merged into X_original below)
#   * dataset_detail_ours additionally drops ethnicity_grouped (merged into X_extended below)
# The CSV is read once; the "ours" variant is derived from the "theirs" variant.
_detail_drop_cols = ['dod', 'admittime', 'dischtime', 'los_hospital', 'ethnicity',
                     'hospital_expire_flag', 'hospstay_seq', 'first_hosp_stay', 'intime',
                     'outtime', 'los_icu', 'icustay_seq', 'first_icu_stay']
dataset_detail_theirs = pd.read_csv(DATA_PATH_detail, sep=SEPARATOR).drop(columns=_detail_drop_cols)
dataset_detail_ours = dataset_detail_theirs.drop(columns=['ethnicity_grouped'])

# Static height/weight and max-INR tables, processed the same way as in the loading
# cell so this cell can be re-run on its own.
dataset_heightweight = (
    pd.read_csv(DATA_PATH_heightweight, sep=SEPARATOR)
    .dropna(subset=['icustay_id', 'height_first', 'weight_first'], how='all')
    .sort_values(by=['icustay_id'])
)
dataset_inr_max = (
    pd.read_csv(DATA_PATH_inr_max, sep=SEPARATOR)
    .drop(columns=["hadm_id", "subject_id"])
    .sort_values(by=['icustay_id'])
)

In [None]:
# save datasets unresampled for data exploration
# The merged frame uses its own name so this loop no longer overwrites the
# global stage frame ``X`` (which would otherwise depend on cell execution order).
datasets = {'X_original': X_original, 'X_extended': X_extended}
for dataset, X_unresampled in datasets.items():

    if dataset == "X_original":
        print('dataset is X_original')
        # Merging not time-dependent data
        dataset_detail_merging = dataset_detail_theirs.copy()
        print(dataset_detail_merging.columns)
        dataset_detail_merging = dataset_detail_merging[dataset_detail_merging['icustay_id'].isin(X_unresampled['icustay_id'].unique())].sort_values(by=['icustay_id'])
        dataset_detail_merging = pd.get_dummies(dataset_detail_merging, columns=['gender', 'ethnicity_grouped'])
        dataset_detail_merging = dataset_detail_merging.drop(columns=['subject_id', 'hadm_id'])
        X_unresampled = X_unresampled.merge(dataset_detail_merging, on='icustay_id')
        print(X_unresampled.columns)

    elif dataset == "X_extended":
        print('dataset is X_extended')
        dataset_detail_merging = dataset_detail_ours.copy()
        dataset_detail_merging = dataset_detail_merging[dataset_detail_merging['icustay_id'].isin(X_unresampled['icustay_id'].unique())].sort_values(by=['icustay_id'])
        dataset_detail_merging = pd.get_dummies(dataset_detail_merging, columns=['gender'])
        dataset_detail_merging = dataset_detail_merging.drop(columns=['subject_id', 'hadm_id'])
        # Inner joins: stays without height/weight or max-INR rows are dropped.
        X_unresampled = (
            X_unresampled
            .merge(dataset_detail_merging, on='icustay_id')
            .merge(dataset_heightweight, on='icustay_id')
            .merge(dataset_inr_max, on='icustay_id')
        )
        print(X_unresampled.columns)

    # save preprocessed data (uncomment to persist)
    # X_unresampled.to_csv(os.path.join(data_path, 'preprocessed', f'{dataset}.csv'), index=False)

## Dataset creation: resample, label and save each feature set

In [None]:

# Resample each feature set on several time grids, build the AKI label
# HOURS_AHEAD hours ahead, merge static per-stay features and write one CSV per
# (feature set, interval). Loop variables are lower-case so the config constants
# SAMPLING_INTERVAL / RESAMPLE_LIMIT are not overwritten.
datasets = {'X_original':X_original, 'X_extended':X_extended}
SAMPLING_INTERVALS = ['2H', '4H', '6H', '8H', '12H', '24H']

label = ['aki_stage']
skip = ['icustay_id', 'charttime', 'aki_stage']
discrete_feat = ['sedative', 'vasopressor', 'vent']
skip.extend(discrete_feat)

# Forward-fill window for the AKI label, in hours
# (the config cell's RESAMPLE_LIMIT = 16 steps at 6H is the same 96 h).
LABEL_FFILL_HOURS = 96

resampled_dir = 'resampled_correct'
os.makedirs(os.path.join(data_path, resampled_dir), exist_ok=True)

for dataset in datasets:
    numeric_feat = list(datasets[dataset].columns.difference(skip))
    for sampling_interval in SAMPLING_INTERVALS:
        interval_hours = int(sampling_interval[:-1])
        # Steps spanning LABEL_FFILL_HOURS. Note: ``h * 96 // h`` would be 96 for every h.
        resample_limit = LABEL_FFILL_HOURS // interval_hours
        X = datasets[dataset].copy()
        # Resampling
        if TIME_SAMPLING:
            grouped = X.set_index('charttime').groupby('icustay_id').resample(sampling_interval)
            # Numeric features: bin mean; discrete flags: bin max (missing -> FILL_VALUE);
            # label: highest stage in the bin.
            parts = [grouped[numeric_feat].mean()]
            if MAX_FEATURE_SET:
                parts.append(grouped[discrete_feat].max().fillna(FILL_VALUE).astype(np.int64))
            parts.append(grouped['aki_stage'].max())

            print("Merging sampled features")
            X = pd.concat(parts, axis=1).reset_index()

        # Forward fill the label within each stay, then binarise (any stage > 0 -> 1).
        X['aki_stage'] = X.groupby('icustay_id')['aki_stage'].ffill(limit=resample_limit).fillna(0)
        X['aki_stage'] = (X['aki_stage'] > 0).astype(int)

        # Shift labels so each row predicts the state HOURS_AHEAD hours later;
        # rows without a future label are dropped.
        shift_steps = HOURS_AHEAD // interval_hours
        X['aki_stage'] = X.groupby('icustay_id')['aki_stage'].shift(-shift_steps)
        X = X.dropna(subset=['aki_stage'])

        # String equality (not identity) selects the static-feature variant.
        if dataset == "X_original":
            print('dataset is X_original')
            # Merging not time-dependent data
            dataset_detail_merging = dataset_detail_theirs.copy()
            print(dataset_detail_merging.columns)
            dataset_detail_merging = dataset_detail_merging[dataset_detail_merging['icustay_id'].isin(X['icustay_id'].unique())].sort_values(by=['icustay_id'])
            dataset_detail_merging = pd.get_dummies(dataset_detail_merging, columns=['gender', 'ethnicity_grouped'])
            dataset_detail_merging = dataset_detail_merging.drop(columns=['subject_id', 'hadm_id'])
            X = X.merge(dataset_detail_merging, on='icustay_id')

        elif dataset == "X_extended":
            print('dataset is X_extended')
            dataset_detail_merging = dataset_detail_ours.copy()
            dataset_detail_merging = dataset_detail_merging[dataset_detail_merging['icustay_id'].isin(X['icustay_id'].unique())].sort_values(by=['icustay_id'])
            dataset_detail_merging = pd.get_dummies(dataset_detail_merging, columns=['gender'])
            dataset_detail_merging = dataset_detail_merging.drop(columns=['subject_id', 'hadm_id'])
            X = X.merge(dataset_detail_merging, on='icustay_id')
            X = X.merge(dataset_heightweight, on='icustay_id')
            X = X.merge(dataset_inr_max, on='icustay_id')

        X = X.fillna(FILL_VALUE)
        # save the data
        X.to_csv(os.path.join(data_path, resampled_dir, f'aki_stage_{dataset}_{sampling_interval}.csv'), index=False)


In [None]:
# parallel (for hpc)
# Same pipeline as the sequential cell, one joblib task per
# (feature set, sampling interval); this variant also produces the 1H grid.

import os
import pandas as pd
import numpy as np
from joblib import Parallel, delayed

# Number of CPU cores to use
n_jobs = 32  # You can adjust this if necessary

datasets = {'X_original': X_original, 'X_extended': X_extended}
SAMPLING_INTERVALS = ['1H', '2H', '4H', '6H', '8H', '12H', '24H']

label = ['aki_stage']
skip = ['icustay_id', 'charttime', 'aki_stage']
discrete_feat = ['sedative', 'vasopressor', 'vent']
skip.extend(discrete_feat)

# Forward-fill window for the AKI label, in hours
# (the config cell's RESAMPLE_LIMIT = 16 steps at 6H is the same 96 h).
LABEL_FFILL_HOURS = 96

resampled_dir = 'resampled_correct'
os.makedirs(os.path.join(data_path, resampled_dir), exist_ok=True)


def process_resampling(dataset_name, sampling_interval):
    """Resample one feature set on one time grid, label it and write it to CSV.

    Parameters
    ----------
    dataset_name : str
        Key into the module-level ``datasets`` dict ('X_original' or 'X_extended').
    sampling_interval : str
        Hour-based resampling rule of the form '<N>H', e.g. '6H'.

    Returns
    -------
    str
        Path of the CSV that was written.
    """
    dataset = datasets[dataset_name]
    numeric_feat = list(dataset.columns.difference(skip))

    interval_hours = int(sampling_interval[:-1])
    # Steps spanning LABEL_FFILL_HOURS. Note: ``h * 96 // h`` would be 96 for every h.
    resample_limit = LABEL_FFILL_HOURS // interval_hours
    # Copy so the label assignments below never touch the shared frame.
    X = dataset.copy()

    # Resampling
    if TIME_SAMPLING:
        grouped = X.set_index('charttime').groupby('icustay_id').resample(sampling_interval)
        # Numeric features: bin mean; discrete flags: bin max (missing -> FILL_VALUE);
        # label: highest stage in the bin. The explicit MAX_FEATURE_SET branch
        # replaces a bare ``except`` that could hide any concat failure.
        parts = [grouped[numeric_feat].mean()]
        if MAX_FEATURE_SET:
            parts.append(grouped[discrete_feat].max().fillna(FILL_VALUE).astype(np.int64))
        parts.append(grouped['aki_stage'].max())
        X = pd.concat(parts, axis=1).reset_index()

    # Forward fill the label within each stay, then binarise (any stage > 0 -> 1).
    X['aki_stage'] = X.groupby('icustay_id')['aki_stage'].ffill(limit=resample_limit).fillna(0)
    X['aki_stage'] = (X['aki_stage'] > 0).astype(int)

    # Shift labels so each row predicts the state HOURS_AHEAD hours later;
    # rows without a future label are dropped.
    shift_steps = HOURS_AHEAD // interval_hours
    X['aki_stage'] = X.groupby('icustay_id')['aki_stage'].shift(-shift_steps)
    X = X.dropna(subset=['aki_stage'])

    if dataset_name == "X_original":
        # Merging not time-dependent data
        dataset_detail_merging = dataset_detail_theirs.copy()
        dataset_detail_merging = dataset_detail_merging[dataset_detail_merging['icustay_id'].isin(X['icustay_id'].unique())].sort_values(by=['icustay_id'])
        dataset_detail_merging = pd.get_dummies(dataset_detail_merging, columns=['gender', 'ethnicity_grouped'])
        dataset_detail_merging = dataset_detail_merging.drop(columns=['subject_id', 'hadm_id'])
        X = X.merge(dataset_detail_merging, on='icustay_id')

    elif dataset_name == "X_extended":
        dataset_detail_merging = dataset_detail_ours.copy()
        dataset_detail_merging = dataset_detail_merging[dataset_detail_merging['icustay_id'].isin(X['icustay_id'].unique())].sort_values(by=['icustay_id'])
        dataset_detail_merging = pd.get_dummies(dataset_detail_merging, columns=['gender'])
        dataset_detail_merging = dataset_detail_merging.drop(columns=['subject_id', 'hadm_id'])
        # Inner joins: stays without height/weight or max-INR rows are dropped.
        X = X.merge(dataset_detail_merging, on='icustay_id')
        X = X.merge(dataset_heightweight, on='icustay_id')
        X = X.merge(dataset_inr_max, on='icustay_id')

    X = X.fillna(FILL_VALUE)

    # Save the data
    file_path = os.path.join(data_path, resampled_dir, f'aki_stage_{dataset_name}_{sampling_interval}.csv')
    X.to_csv(file_path, index=False)
    return file_path


# Parallel processing of the resampling; the cell output lists the written files.
Parallel(n_jobs=n_jobs)(
    delayed(process_resampling)(dataset_name, sampling_interval)
    for dataset_name in datasets.keys()
    for sampling_interval in SAMPLING_INTERVALS
)