In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import h5py

import os
# import wandb
import warnings

from config import *
from utils.eda_functions import *
from src.data_preprocessing.demographics import *
from src.data_preprocessing.vitals_labs import *
from src.data_preprocessing.split_dataset import *
from src.data_imputation.new_impute import (calculate_new_values, new_imputer)
from src.data_imputation.hybrid_impute import hybrid_imputer
from utils.safe_display import blind_display

%load_ext autoreload
%autoreload 2

In [2]:
# Silence every warning for the rest of the session
warnings.filterwarnings(action='ignore')

# Safety switch: set to True to keep sensitive data out of displayed output
BLINDED = False

In [3]:
# Peek into the HDF5 store (read-only) and list its top-level keys
with h5py.File(DATA_FILE_PATH, mode='r') as hdf_store:
    print(list(hdf_store.keys()))

# Load each stored table into its own DataFrame
patients = pd.read_hdf(DATA_FILE_PATH, key='patients')
vitals_labs = pd.read_hdf(DATA_FILE_PATH, key='vitals_labs')
vitals_labs_mean = pd.read_hdf(DATA_FILE_PATH, key='vitals_labs_mean')
interventions = pd.read_hdf(DATA_FILE_PATH, key='interventions')
codes = pd.read_hdf(DATA_FILE_PATH, key='codes')

['codes', 'interventions', 'patients', 'vitals_labs', 'vitals_labs_mean']


In [4]:
# Report (rows, columns) of the freshly loaded tables
for loaded_frame in (patients, vitals_labs_mean, interventions):
    print(loaded_frame.shape)

(34472, 28)
(2200954, 104)
(2200954, 14)


In [5]:
# Subjects that have a row at hours_in == 47
has_hour_47 = vitals_labs.index.get_level_values('hours_in') == 47
subjects_48 = vitals_labs.index[has_hour_47].get_level_values('subject_id')
len(subjects_48)

17530

In [6]:
def _restrict_to_cohort(frame):
    """Return the rows of `frame` whose subject_id index level is in `subjects_48`."""
    in_cohort = frame.index.get_level_values('subject_id').isin(subjects_48)
    return frame[in_cohort]


# Keep only the subjects_48 cohort in every table
patients = _restrict_to_cohort(patients)
vitals_labs = _restrict_to_cohort(vitals_labs)
vitals_labs_mean = _restrict_to_cohort(vitals_labs_mean)
interventions = _restrict_to_cohort(interventions)

In [7]:
def _first_48_hours(frame):
    """Return the rows of `frame` whose hours_in index level is below 48."""
    within_window = frame.index.get_level_values('hours_in') < 48
    return frame[within_window]


# Truncate the time-indexed tables to the first 48 hours of admission
vitals_labs = _first_48_hours(vitals_labs)
vitals_labs_mean = _first_48_hours(vitals_labs_mean)
interventions = _first_48_hours(interventions)

In [9]:
# Report (rows, columns) of every table after the cohort and time-window filters
for filtered_frame in (patients, vitals_labs, vitals_labs_mean, interventions):
    print(filtered_frame.shape)

(17530, 28)
(841440, 312)
(841440, 104)
(841440, 14)


## Patients

In [10]:
# Replace each age with the category returned by categorize_age, then tally the categories
age_groups = patients['age'].apply(categorize_age)
patients['age'] = age_groups
patients['age'].value_counts(dropna=False)

age
>70      7800
51-70    6448
31-50    2502
<31       780
Name: count, dtype: int64

In [11]:
# Collapse similar ethnicity labels via categorize_ethnicity, then tally the groups
ethnicity_groups = patients['ethnicity'].apply(categorize_ethnicity)
patients['ethnicity'] = ethnicity_groups
patients['ethnicity'].value_counts(dropna=False)

ethnicity
WHITE              12416
OTHER/UNKNOWN       2779
BLACK               1333
HISPANIC             538
ASIAN                448
ISLANDER               9
NATIVE AMERICAN        7
Name: count, dtype: int64

In [12]:
# Map admission types through group_admission_type (EMERGENCY / ELECTIVE), then tally them
admission_groups = patients['admission_type'].apply(group_admission_type)
patients['admission_type'] = admission_groups
patients['admission_type'].value_counts(dropna=False)

admission_type
EMERGENCY    14815
ELECTIVE      2715
Name: count, dtype: int64

In [13]:
# Columns considered irrelevant or duplicated
columns_to_drop = [
    'admittime', 'dischtime', 'intime', 'outtime', 'deathtime',
    'discharge_location', 'dnr_first_charttime', 'diagnosis_at_admission',
    'insurance', 'hospstay_seq', 'hospital_expire_flag', 'los_icu',
    'dnr', 'fullcode', 'cmo', 'cmo_last', 'mort_icu',
]

# Categorical columns to expand into one-hot indicator columns
categorical_cols = ['gender', 'age', 'ethnicity', 'admission_type', 'first_careunit']

patients = patients.drop(columns=columns_to_drop)
patients = pd.get_dummies(patients, columns=categorical_cols)

# Drop one indicator for each of the two binary variables, then zero-fill missing values
patients = patients.drop(columns=['gender_F', 'admission_type_ELECTIVE']).fillna(0)

# Cast boolean indicator columns to integers
bool_cols = [name for name, dtype in patients.dtypes.items() if dtype == 'bool']
patients[bool_cols] = patients[bool_cols].astype(int)

blind_display(patients, blinded=BLINDED)

patients.shape:  (17530, 24)


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,fullcode_first,dnr_first,cmo_first,mort_hosp,readmission_30,max_hours,gender_M,age_31-50,age_51-70,age_<31,...,ethnicity_ISLANDER,ethnicity_NATIVE AMERICAN,ethnicity_OTHER/UNKNOWN,ethnicity_WHITE,admission_type_EMERGENCY,first_careunit_CCU,first_careunit_CSRU,first_careunit_MICU,first_careunit_SICU,first_careunit_TSICU
subject_id,hadm_id,icustay_id,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
3,145834,211552,1.0,0.0,0.0,0,0,145,1,0,0,0,...,0,0,0,1,1,0,0,1,0,0
6,107064,228232,1.0,0.0,0.0,0,0,88,0,0,1,0,...,0,0,0,1,0,0,0,0,1,0
9,150750,220597,1.0,0.0,0.0,1,0,127,1,1,0,0,...,0,0,1,0,1,0,0,1,0,0
12,112213,232669,1.0,0.0,0.0,1,0,183,1,0,0,0,...,0,0,0,1,0,0,0,0,1,0
13,143045,263738,1.0,0.0,0.0,0,0,87,0,1,0,0,...,0,0,0,1,1,1,0,0,0,0


## Vitals


In [None]:
# Build a per-measurement summary table: averaged aggregation values, percentage of
# missing rows, and the valid low/high range read from a reference CSV.

# Average every column per ICU stay, then discard the averaged hours_in column
reset_df = vitals_labs.reset_index()
reduced_df = reset_df.groupby(['subject_id', 'hadm_id', 'icustay_id']).mean().reset_index()
reduced_df.drop(columns=['hours_in'], inplace=True)

# Long format: one row per (stay, LEVEL2, Aggregation Function); stay identifiers are then dropped
melted_df = pd.melt(reduced_df, id_vars=['subject_id', 'hadm_id', 'icustay_id'], var_name=['LEVEL2', 'Aggregation Function'], value_name='Value')
melted_df.drop(columns = ['subject_id', 'hadm_id', 'icustay_id'], inplace=True)

# Average across stays, then pivot to one row per LEVEL2 with one column per aggregation function
vitals_summary = melted_df.groupby(['LEVEL2', 'Aggregation Function']).mean().reset_index()
vitals_summary = vitals_summary.pivot_table(index='LEVEL2', columns='Aggregation Function', values='Value')
vitals_summary.drop(columns=['count'], inplace=True)

# Percentage of null rows per column of vitals_labs_mean.
# NOTE(review): dropping 'Aggregation Function' below requires vitals_labs_mean to still
# carry that column level when this cell runs.
vitals_missing = vitals_labs_mean.isnull().sum() / vitals_labs_mean.shape[0] * 100
vitals_missing = vitals_missing.reset_index()
vitals_missing.drop(columns=['Aggregation Function'], inplace=True)
vitals_missing.rename(columns={0: 'missing percent'}, inplace=True)

# Join summary and missingness on LEVEL2, rename the key to 'measurement', sort by missingness
vitals_pivot = pd.merge(vitals_summary, vitals_missing, on=['LEVEL2'])
vitals_pivot.rename(columns={'LEVEL2': 'measurement'}, inplace=True)
vitals_pivot.sort_values(by='missing percent', ascending=True, inplace=True)
vitals_pivot.reset_index(drop=True, inplace=True)

# Reference ranges; header names are lower-cased so the column lookups below match
vitals_ranges_df = pd.read_csv('../resources/vitals_labs_ranges.csv')
vitals_ranges_df.columns = vitals_ranges_df.columns.str.lower()

# Left join keeps measurements that have no entry in the ranges file
merged_vitals = pd.merge(vitals_pivot, vitals_ranges_df[['measurement', 'valid low', 'valid high']], on=['measurement'], how='left')

merged_vitals

In [14]:
# Reference table whose 'measurement' column lists the vitals/labs to keep
COMMON_VITALS_CSV = '../resources/common_vitals_labs.csv'
common_vitals = pd.read_csv(COMMON_VITALS_CSV)

In [15]:
# Flatten the column index by removing its last level. Guarded so that re-running
# the cell does not raise once the columns are already single-level.
if vitals_labs_mean.columns.nlevels > 1:
    vitals_labs_mean.columns = vitals_labs_mean.columns.droplevel(-1)

# Keep only the measurements listed in common_vitals; a set gives constant-time lookups
measurements_to_keep = set(common_vitals['measurement'])
columns_to_drop = [col for col in vitals_labs_mean.columns if col not in measurements_to_keep]
vitals_labs_mean = vitals_labs_mean.drop(columns=columns_to_drop)
vitals_labs_mean.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,LEVEL2,anion gap,bicarbonate,blood urea nitrogen,co2,calcium,central venous pressure,chloride,creatinine,diastolic blood pressure,fraction inspired oxygen set,...,prothrombin time pt,pulmonary artery pressure systolic,red blood cell count,respiratory rate,sodium,systolic blood pressure,temperature,tidal volume observed,white blood cell count,ph
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
3,145834,211552,0,20.666667,16.333333,44.2,12.0,6.92,,109.5,2.6,39.666667,1.0,...,15.22,,2.884,16.0,142.666667,95.166667,,690.0,14.842857,7.4
3,145834,211552,1,,,,,,,,,44.125,,...,,,,15.5,,81.0,,,,
3,145834,211552,2,,,,,,,,,47.333333,,...,,,,7.0,139.0,90.666667,,800.0,,7.26
3,145834,211552,3,,,,,,13.0,,,64.5,,...,,19.0,,5.25,,117.0,,,,
3,145834,211552,4,,,,,,16.0,,,63.0,0.8,...,,40.0,,13.666667,,102.0,,680.0,,


In [26]:
# subject_id of every patient row with mort_hosp == 1
died_in_hospital = patients['mort_hosp'] == 1
dead_patients = patients.index[died_in_hospital].get_level_values('subject_id')
len(dead_patients)

3805

In [27]:
# Class balance of the mort_hosp label, as proportions.
# The column is 'mort_hosp'; the previous key 'mort_8hosp' was a typo that raises KeyError.
patients['mort_hosp'].value_counts(normalize=True)

mort_hosp
0    0.782944
1    0.217056
Name: proportion, dtype: float64

In [23]:
# One filter per (gender_M value, age indicator column) stratum.
# Key order is significant: mort_hosp first, then gender_M, then the age column.
_strata = [
    (0, 'age_31-50'),
    (0, 'age_>70'),
    (1, 'age_31-50'),
    (1, 'age_51-70'),
    (1, 'age_>70'),
]
parameters = [
    {'mort_hosp': 1, 'gender_M': is_male, age_column: 1}
    for is_male, age_column in _strata
]

In [24]:
# For every filter in `parameters`, overwrite the rows of surviving patients with the
# values of the patients matching the filter, in the patient table and in each
# time-indexed table.
# NOTE(review): the positional access arr[0..2] relies on the insertion order of each
# dict in `parameters`; arr[0] is used only for the "dead" selection.
for p in parameters:
    arr = list(p.items())
    # Patients matching all three criteria of this filter
    dead_patients = patients[(patients[arr[0][0]] == arr[0][1]) & (patients[arr[1][0]] == arr[1][1]) & (patients[arr[2][0]] == arr[2][1])]
    dead_patients_indices = dead_patients.index.get_level_values('subject_id')
    
    # Patients with mort_hosp == 0 that match the 2nd and 3rd criteria, truncated to
    # the first N rows where N is the number of matching dead patients
    alive_patients = patients[(patients['mort_hosp'] == 0) & (patients[arr[1][0]] == arr[1][1]) & (patients[arr[2][0]] == arr[2][1])]
    alive_patients = alive_patients.iloc[:dead_patients.shape[0]]
    alive_patients_indices = alive_patients.index.get_level_values('subject_id')
    
    # Resample Patients
    # NOTE(review): assigning raw .values needs equal row counts on both sides;
    # presumably each stratum has at least as many alive as dead patients - confirm.
    patients.loc[alive_patients_indices] = dead_patients.values
    
    # Resample vitals_labs
    # NOTE(review): assumes every subject contributes the same number of rows to the
    # time-indexed tables, otherwise the shapes will not match - confirm.
    vitals_labs.loc[alive_patients_indices] = vitals_labs.loc[dead_patients_indices].values
    
    # Resample vitals_labs_mean
    vitals_labs_mean.loc[alive_patients_indices] = vitals_labs_mean.loc[dead_patients_indices].values
    
    # Resample interventions
    interventions.loc[alive_patients_indices] = interventions.loc[dead_patients_indices].values


In [25]:
# Report (rows, columns) of every table
for current_frame in (patients, vitals_labs, vitals_labs_mean, interventions):
    print(current_frame.shape)

(17530, 24)
(841440, 312)
(841440, 39)
(841440, 14)


In [28]:
# Display the reference table of common vitals/labs
common_vitals

Unnamed: 0,measurement,normal min,normal high,risk min,risk high,unit of measurement
0,heart rate,60.0,100.0,0,100.0,bpm
1,respiratory rate,12.0,16.0,0,20.0,breaths per minute
2,systolic blood pressure,90.0,120.0,0,140.0,mmHg
3,diastolic blood pressure,60.0,80.0,0,90.0,mmHg
4,mean blood pressure,70.0,100.0,0,110.0,mmHg
5,oxygen saturation,95.0,100.0,0,90.0,%
6,temperature,36.1,37.2,0,37.5,°C
7,glucose,70.0,100.0,0,200.0,mg/dL
8,central venous pressure,2.0,6.0,0,6.0,mmHg
9,glasgow coma scale total,13.0,15.0,0,13.0,


In [29]:
# Display the hourly vitals/labs table (pandas truncates the rendered rows and columns)
vitals_labs_mean

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,LEVEL2,anion gap,bicarbonate,blood urea nitrogen,co2,calcium,central venous pressure,chloride,creatinine,diastolic blood pressure,fraction inspired oxygen set,...,prothrombin time pt,pulmonary artery pressure systolic,red blood cell count,respiratory rate,sodium,systolic blood pressure,temperature,tidal volume observed,white blood cell count,ph
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
3,145834,211552,0,,10.0,,,,,112.0,,,,...,16.4,,,,139.0,,,,,7.08375
3,145834,211552,1,,,,,,,,,,,...,,,,,,,,,,7.22000
3,145834,211552,2,28.0,11.0,28.0,11.0,8.3,,111.0,1.3,72.5,0.6,...,15.5,,4.4,10.000000,145.0,136.25,35.648141,776.0,8.4,7.18000
3,145834,211552,3,,,,,,15.0,,,61.0,0.5,...,,,,12.666667,,105.75,35.500000,576.0,,
3,145834,211552,4,,,,,,13.0,,,68.0,0.5,...,,,,12.800000,,127.00,36.166682,586.0,,7.26000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,137810,229633,43,,,,,,,,,55.0,,...,,,,23.000000,,150.00,,,,
99995,137810,229633,44,,,,,,,,,60.0,,...,,,,23.000000,,151.00,,,,
99995,137810,229633,45,,,,,,,,,51.5,,...,,,,19.000000,,122.50,,,,
99995,137810,229633,46,,,,,,,,,60.0,,...,,,,22.000000,,124.00,36.777778,,,


In [31]:
# Labels: the mort_hosp column as a one-column float DataFrame.
# astype returns a new object, so the result must be assigned; the previous bare
# `Ys.astype(float)` displayed a converted copy and left Ys unchanged.
Ys = patients[['mort_hosp']].astype(float)
Ys

In [53]:
%%time
dead_global_means, dead_icustay_means, alive_global_means, alive_icustay_means = calculate_new_values(vitals_labs, Ys)

CPU times: total: 938 ms
Wall time: 1.48 s


In [None]:
%%time
vitals_labs_imputed = new_imputer(vitals_labs, Ys, dead_global_means, dead_icustay_means, alive_global_means, alive_icustay_means)

In [84]:
from src.data_preprocessing.split_dataset import *

In [88]:
%%time
datasets = train_test_dev_split(patients, vitals_labs_imputed, interventions, Ys)

CPU times: total: 812 ms
Wall time: 1.21 s


In [89]:
%%time
# Define keys and corresponding variable names
keys_varnames = [
    ('patients', 'patients_train', 'patients_dev', 'patients_test'),
    ('vitals', 'vitals_train', 'vitals_dev', 'vitals_test'),
    ('interv', 'interv_train', 'interv_dev', 'interv_test'),
    ('Ys', 'Ys_train', 'Ys_dev', 'Ys_test')
]

# Loop through each key and variable name to extract data and print shapes
for key, train_var, dev_var, test_var in keys_varnames:
    train_data, dev_data, test_data = datasets[key]
    print(f'\n{key.capitalize()}:')
    print(train_data.shape)
    print(dev_data.shape)
    print(test_data.shape)

    # # log the datasets
    # save_to_pickle(train_data, os.path.join(LOG_DATA_DIR, f'{train_var}_split.pkl'))
    # save_to_pickle(dev_data, os.path.join(LOG_DATA_DIR, f'{dev_var}_split.pkl'))
    # save_to_pickle(test_data, os.path.join(LOG_DATA_DIR, f'{test_var}_split.pkl'))

    # set the variables
    globals()[train_var], globals()[dev_var], globals()[test_var] = train_data, dev_data, test_data


Patients:
(12271, 22)
(1753, 22)
(3506, 22)

Vitals:
(589008, 117)
(84144, 117)
(168288, 117)

Interv:
(589008, 14)
(84144, 14)
(168288, 14)

Ys:
(12271, 1)
(1753, 1)
(3506, 1)
CPU times: total: 15.6 ms
Wall time: 15.6 ms


In [91]:
def standardize_gru(vitals_train, vitals_dev, vitals_test, never_measured_sentinel=100):
    """Scale the vitals splits using statistics taken from the training split only.

    Works on copies, so the input frames are left untouched.

    * ``(*, 'mean')`` columns are passed through ``minmax_scaling`` with the
      training minimum and maximum.
    * ``(*, 'time_since_measured')`` columns first have values equal to
      ``never_measured_sentinel`` replaced by 0 and are then passed through
      ``standardize_time_since_measured`` with the training mean and standard
      deviation.

    The sentinel replacement is applied to all three splits. It used to be applied
    to the training split only, so dev/test kept the raw sentinel and were
    preprocessed differently from the data the statistics came from.

    Args:
        vitals_train, vitals_dev, vitals_test: DataFrames with two-level columns
            whose second level contains 'mean' and 'time_since_measured'.
        never_measured_sentinel: value in the time_since_measured columns to replace
            with 0 (default 100). NOTE(review): presumably the placeholder the
            imputer writes for "never measured" - confirm against the imputer.

    Returns:
        Tuple ``(X_train, X_dev, X_test)`` of transformed copies.
    """
    idx = pd.IndexSlice
    mean_cols = idx[:, 'mean']
    elapsed_cols = idx[:, 'time_since_measured']

    X_train, X_dev, X_test = vitals_train.copy(), vitals_dev.copy(), vitals_test.copy()
    splits = (X_train, X_dev, X_test)

    # Min-max scaling with training-set bounds
    train_min = X_train.loc[:, mean_cols].min()
    train_max = X_train.loc[:, mean_cols].max()
    for df in splits:
        df.loc[:, mean_cols] = minmax_scaling(df.loc[:, mean_cols], train_min, train_max)

    # Replace the sentinel in every split before any statistics are computed
    for df in splits:
        elapsed = df.loc[:, elapsed_cols]
        df.loc[:, elapsed_cols] = np.where(elapsed == never_measured_sentinel, 0, elapsed)

    # Standardization with training-set mean and standard deviation
    train_mean = X_train.loc[:, elapsed_cols].mean()
    train_std = X_train.loc[:, elapsed_cols].std()
    for df in splits:
        df.loc[:, elapsed_cols] = standardize_time_since_measured(
            df.loc[:, elapsed_cols], train_mean, train_std)

    return X_train, X_dev, X_test

In [92]:
%%time
vitals_train_std, vitals_dev_std, vitals_test_std = standardize_gru(vitals_train, vitals_dev, vitals_test)

CPU times: total: 1min 54s
Wall time: 2min 8s


In [93]:
# Output directory for the pickled splits (relative path)
RESAMPLED_DIR = '../data/resampled/'

In [94]:
# Create the output directory if it is missing so the writes below cannot fail on it
os.makedirs(RESAMPLED_DIR, exist_ok=True)

# Output file name -> split to persist. The vitals files hold the scaled
# (*_std) frames; dict order keeps the original write order.
splits_to_save = {
    'patients_train.pkl': patients_train,
    'patients_dev.pkl': patients_dev,
    'patients_test.pkl': patients_test,
    'vitals_train.pkl': vitals_train_std,
    'vitals_dev.pkl': vitals_dev_std,
    'vitals_test.pkl': vitals_test_std,
    'interv_train.pkl': interv_train,
    'interv_dev.pkl': interv_dev,
    'interv_test.pkl': interv_test,
    'Ys_train.pkl': Ys_train,
    'Ys_dev.pkl': Ys_dev,
    'Ys_test.pkl': Ys_test,
}

for pickle_name, split_df in splits_to_save.items():
    save_to_pickle(df=split_df, filename=os.path.join(RESAMPLED_DIR, pickle_name))