In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import datetime as dt
%matplotlib inline

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

pd.options.display.max_columns = None  # never elide columns when displaying wide frames

In [2]:
df = pd.read_csv('imputed.csv')

# --- cleaning ----------------------------------------------------------------
# Column groups. Every column of the CSV is expected to fall in one of these;
# the print below lists any that do not, so unexpected fields get noticed.
categorical = ['ethnicity',
               'marital_status',
               'language',
               'admission_location',
               'gender',
               'insurance',
               'first_careunit',
               'last_careunit',
               'admission_type']
proceduretype = ['aortic', 'mit', 'tricuspid', 'pulmonary', 'cabg']
ptParams = ['weight', 'height']
boolFields = ['reintubation', 'liver_severe', 'liver_mild', 'rheum', 'cvd', 'aids', 'ckd', 'copd', 'arrhythmia', 'pud', 'smoking', 'pvd', 'paraplegia',
              'ccf', 'met_ca', 't2dm', 't1dm', 'malig', 'mi', 'dementia', 'hospital_expire_flag', 'diab_un', 'diab_cc']
deathInfo = ['dod', 'deathtime']
ptinfo = ['hadm_id', 'subject_id']
durations = ['duration1', 'duration2', 'icu_stay_duration', 'icu_stay_days']
timeFields = ['admittime', 'dischtime', 'intime', 'outtime', 'ext_time',
              'int_time1', 'ext_time1', 'int_time2', 'ext_time2']

# Time-series aggregates; only the *_mean variants survive further down.
tsColumns = [c for c in df.columns if '_max' in c or '_min' in c or '_mean' in c]
known_columns = set(categorical + proceduretype + tsColumns + ptParams + boolFields
                    + ptinfo + deathInfo + durations + timeFields)
print([c for c in df.columns if c not in known_columns])

df = df.astype({c: 'category' for c in categorical})

# icu_stay_duration is treated as seconds here: derive whole days first
# (Series.round is round-half-to-even, same as Python's round()), then convert
# the duration itself to hours.
df['icu_stay_days'] = (df['icu_stay_duration'] / 86400).round().astype('int64')
df['icu_stay_duration'] = df['icu_stay_duration'] / 3600  # now in hours

# Drop the CSV's saved-index column(s). drop() returns an independent frame,
# unlike re-selecting a column subset with df[[...]], so the assignments below
# do not trigger SettingWithCopyWarning.
df = df.drop(columns=["Unnamed: 0", "0"], errors='ignore')

# Parse date/time strings with explicit formats; missing values become NaT.
# (np.NaN, used previously as the missing marker, was removed in NumPy 2.0.)
df['dod'] = pd.to_datetime(df['dod'], format="%Y-%m-%d")
for col in timeFields:
    df[col] = pd.to_datetime(df[col], format="%Y-%m-%d %H:%M:%S")

# Keep only the *_mean time-series aggregates.
df = df.drop(columns=[c for c in df.columns if '_max' in c or '_min' in c])
# Columns excluded from modelling. Author's note: last_careunit "messes up the
# Cox training" for an undetermined reason.
df = df.drop(columns=['last_careunit', 'infection_vent', 'icustay_seq', 'los'], errors='ignore')

print(df.shape)
df

['Unnamed: 0', 'infection_vent', 'los', 'icustay_seq']
(9474, 91)


Unnamed: 0,temp_mean,hr_mean,spo2_mean,rr_mean,sbp_mean,dbp_mean,meanbp_mean,cardiac_index_mean,pt_mean,ptt_mean,inr_mean,inr_1_mean,fibrinogen_mean,hb_mean,hematocrit_mean,wcc_mean,lymphocytes_mean,neutrophils_mean,chloride_mean,magnesium_mean,potassium_mean,creatinine_mean,free_calcium_mean,sodium_mean,bicarb_mean,bun_mean,hba1c_mean,glucose_mean,lactate_mean,po2_mean,pco2_mean,baseexcess_mean,ph_mean,insulin_mean,prbc_mean,plt_mean,gender,ethnicity,marital_status,insurance,language,aortic,mit,tricuspid,pulmonary,cabg,weight,height,reintubation,liver_severe,liver_mild,rheum,cvd,aids,ckd,copd,arrhythmia,pud,smoking,pvd,paraplegia,ccf,met_ca,t2dm,t1dm,malig,mi,dementia,first_careunit,admission_location,admission_type,hospital_expire_flag,diab_un,diab_cc,icu_stay_duration,admittime,dischtime,intime,outtime,ext_time,deathtime,hadm_id,subject_id,int_time1,ext_time1,duration1,int_time2,ext_time2,duration2,dod,icu_stay_days
0,37.145834,89.393939,97.967742,24.382353,115.666667,60.515152,75.712121,2.759261,14.450000,29.80,1.200,1.200,527.0,12.050000,35.000000,13.900000,17.1,80.1,104.500000,1.80,3.666667,0.80,1.150000,138.000000,25.5,12.0,6.4,124.875000,1.657143,0.500000,41.000000,0.500000,7.400000,335.333333,843.750000,261.500000,M,white,SINGLE,Private,ENGL,0,0,0,0,1,84.00,172.72,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,CSRU,PHYSICIAN REFERRAL,ELECTIVE,0,0,0,50.110833,2198-01-31 08:00:00,2198-02-04 12:00:00,2198-01-31 12:27:58,2198-02-02 19:06:39,2198-01-31 22:00:00,,195663,27328,2198-01-31 17:00:00,2198-01-31 22:00:00,5.000000,NaT,NaT,,NaT,2
1,37.023530,92.277778,97.276316,12.807692,98.000000,53.222222,67.861111,2.222764,17.350000,42.80,1.600,1.600,188.5,7.320000,22.000000,19.566667,8.3,80.1,108.000000,2.10,4.375000,0.60,1.128000,142.000000,26.0,12.0,6.3,145.562500,1.400000,0.818182,42.818182,0.818182,7.398182,262.500000,375.000000,122.500000,F,other,DIVORCED,Self Pay,SPAN,1,0,0,0,0,60.00,170.18,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,CSRU,PHYSICIAN REFERRAL,ELECTIVE,0,0,0,50.766667,2198-05-08 07:15:00,2198-05-15 13:49:00,2198-05-08 13:14:00,2198-05-10 19:46:00,2198-05-09 09:29:00,,106984,6280,2198-05-08 17:00:00,2198-05-09 09:29:00,16.483333,NaT,NaT,,NaT,2
2,36.883784,85.027027,99.636364,15.342105,121.000000,58.459459,80.108108,2.108540,17.633333,59.30,1.600,1.600,138.0,10.060000,28.000000,21.850000,12.2,84.8,111.000000,2.60,4.285714,1.20,1.156000,140.333333,26.5,19.0,9.9,120.629630,1.400000,1.166667,41.333333,1.166667,7.404286,170.000000,375.000000,144.500000,F,asian,MARRIED,Medicare,CANT,1,0,0,0,1,57.00,165.10,0,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,0,0,0,0,CSRU,PHYSICIAN REFERRAL,ELECTIVE,0,0,0,48.630000,2189-02-18 08:00:00,2189-03-17 14:20:00,2189-02-18 10:51:08,2189-02-20 13:37:48,2189-02-19 09:00:00,,123613,15201,2189-02-18 13:00:00,2189-02-19 09:00:00,20.000000,NaT,NaT,,2191-12-14,2
3,37.532258,87.939394,97.531250,16.212121,113.191176,56.823529,73.575758,2.899448,13.800000,28.30,1.200,1.200,230.0,13.800000,40.166667,14.000000,8.1,88.0,106.000000,1.85,3.800000,0.55,1.125000,139.000000,27.5,12.5,5.6,126.533333,1.785714,0.625000,45.750000,0.625000,7.377778,229.666667,1184.000000,161.000000,M,other,MARRIED,Private,PTUN,0,0,0,0,1,135.00,190.50,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,CSRU,PHYSICIAN REFERRAL,ELECTIVE,0,1,0,23.550556,2118-01-25 07:15:00,2118-01-29 13:00:00,2118-01-25 10:46:42,2118-01-26 12:33:02,2118-01-25 17:30:00,,126027,25226,2118-01-25 13:00:00,2118-01-25 17:30:00,4.500000,NaT,NaT,,NaT,1
4,36.880362,87.240000,99.083333,13.397959,113.440000,63.810000,79.446667,2.386696,13.200000,30.50,1.175,1.175,377.0,9.812500,30.000000,11.600000,15.0,61.0,108.000000,3.60,4.420000,1.20,1.027500,137.000000,23.0,16.0,6.4,134.230769,1.500000,-3.000000,44.600000,-3.000000,7.318000,162.666667,187.500000,95.000000,M,white,MARRIED,Medicare,HIND,1,0,0,0,0,70.00,175.26,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,CSRU,PHYSICIAN REFERRAL,ELECTIVE,0,1,0,47.000000,2198-01-01 07:15:00,2198-01-09 13:07:00,2198-01-01 10:47:00,2198-01-03 12:00:00,2198-01-01 21:00:00,,190332,19637,2198-01-01 13:00:00,2198-01-01 21:00:00,8.000000,NaT,NaT,,2203-12-06,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9469,36.813333,71.035714,97.586207,19.666667,102.155172,52.482759,67.224138,2.528107,12.400000,28.95,1.150,1.150,210.0,12.333333,31.000000,16.333333,8.6,80.4,105.500000,2.00,4.000000,0.65,1.120000,137.666667,26.5,9.5,5.5,130.333333,2.780000,140.000000,44.000000,-0.333333,7.370000,4.853254,425.000000,178.333333,M,white,MARRIED,Other,ENGL,1,0,0,0,0,96.80,178.00,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,CVICU,PHYSICIAN REFERRAL,ELECTIVE,0,0,0,54.318611,2121-05-01 19:37:00,2121-05-06 16:50:00,2121-05-02 09:33:46,2121-05-04 18:19:07,2121-05-02 18:01:00,,22051087,14971805,2121-05-02 12:00:00,2121-05-02 18:01:00,6.000000,NaT,NaT,,NaT,2
9470,37.481667,88.285714,98.517241,20.413793,111.666667,52.150000,69.966667,2.756978,13.300000,29.90,1.200,1.200,655.0,10.150000,30.500000,7.750000,6.7,90.6,107.333333,2.00,4.300000,0.65,1.096667,136.000000,21.5,11.5,6.4,110.750000,0.900000,185.333333,40.000000,-3.000000,7.350000,3.493849,350.500000,98.000000,M,white,MARRIED,Medicare,ENGL,0,0,0,0,1,85.10,168.00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,CVICU,TRANSFER FROM HOSPITAL,URGENT,0,0,0,70.493056,2111-11-18 18:24:00,2111-11-23 17:00:00,2111-11-19 17:29:25,2111-11-22 17:29:35,2111-11-20 04:30:00,,21555454,15547313,2111-11-19 19:00:00,2111-11-20 04:30:00,9.000000,NaT,NaT,,NaT,3
9471,36.685556,70.500000,96.333333,19.222222,118.948718,59.256410,76.000000,3.026018,12.600000,23.40,1.200,1.200,168.0,9.333333,28.000000,14.700000,11.6,83.6,109.500000,2.30,4.333333,1.30,1.285000,138.500000,21.0,40.0,7.0,120.000000,1.200000,136.600000,41.800000,-2.400000,7.335000,5.596820,374.999994,154.000000,M,white,MARRIED,Medicare,ENGL,0,0,0,0,1,100.00,175.00,0,0,0,0,0,0,1,0,1,0,1,0,0,1,0,1,0,0,1,0,CVICU,TRANSFER FROM HOSPITAL,URGENT,0,0,1,25.164167,2156-02-26 18:43:00,2156-03-06 18:30:00,2156-03-01 09:26:32,2156-03-02 17:09:51,2156-03-01 22:27:00,,22084741,16252024,2156-03-01 16:00:00,2156-03-01 22:27:00,6.000000,NaT,NaT,,NaT,1
9472,36.691429,78.571429,98.178571,14.089286,119.214286,55.178571,74.357143,2.568054,13.900000,29.70,1.300,1.300,260.0,10.200000,30.000000,8.800000,8.4,78.2,107.000000,2.00,3.857143,0.60,1.120000,132.500000,24.0,7.0,5.3,101.000000,0.850000,133.777778,40.222222,-2.111111,7.356000,5.552349,375.000000,146.000000,M,white,SINGLE,Medicare,ENGL,0,0,0,0,1,112.05,173.00,1,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,CVICU,PROCEDURE SITE,EW EMER.,0,0,0,142.522778,2127-03-04 15:15:00,2127-03-15 15:55:00,2127-03-05 10:07:40,2127-03-11 10:31:22,2127-03-06 12:00:00,,25588352,18504988,2127-03-05 12:00:00,2127-03-06 12:00:00,24.000000,2127-03-07 13:00:00,2127-03-08 05:00:00,16.0,NaT,6


In [3]:
# Target: ICU stay length in hours, as a structured array (event flag, time).
# Every row is marked as an observed event (Status=True), i.e. no censoring.
# NOTE(review): rows with hospital_expire_flag == 1 may have had the stay ended
# by death rather than discharge -- confirm treating them as events is intended.
data_y = np.empty(len(df), dtype=[('Status', '?'), ('Stay_in_hrs', '<f8')])
data_y['Status'] = True
# Positional extraction: stays correct even if df's index is not 0..n-1
# (the previous df[col][i] loop was label-based).
data_y['Stay_in_hrs'] = df['icu_stay_duration'].to_numpy(dtype=float)
print(data_y)

# Features: everything except durations/target, identifiers, death info and
# raw timestamps.
# NOTE(review): boolFields (kept here) includes hospital_expire_flag and
# reintubation, which may only be known after the ICU stay -- check for leakage.
excluded = set(durations + ptinfo + deathInfo + timeFields)
data_x = df[[c for c in df.columns if c not in excluded]]
print(data_x.columns)

# One-hot encode the categorical variables.
data_x_numeric = OneHotEncoder().fit_transform(data_x)
data_x_numeric.shape

[( True,  50.11083333) ( True,  50.76666667) ( True,  48.63      ) ...
 ( True,  25.16416667) ( True, 142.52277778) ( True,  22.29111111)]
Index(['temp_mean', 'hr_mean', 'spo2_mean', 'rr_mean', 'sbp_mean', 'dbp_mean',
       'meanbp_mean', 'cardiac_index_mean', 'pt_mean', 'ptt_mean', 'inr_mean',
       'inr_1_mean', 'fibrinogen_mean', 'hb_mean', 'hematocrit_mean',
       'wcc_mean', 'lymphocytes_mean', 'neutrophils_mean', 'chloride_mean',
       'magnesium_mean', 'potassium_mean', 'creatinine_mean',
       'free_calcium_mean', 'sodium_mean', 'bicarb_mean', 'bun_mean',
       'hba1c_mean', 'glucose_mean', 'lactate_mean', 'po2_mean', 'pco2_mean',
       'baseexcess_mean', 'ph_mean', 'insulin_mean', 'prbc_mean', 'plt_mean',
       'gender', 'ethnicity', 'marital_status', 'insurance', 'language',
       'aortic', 'mit', 'tricuspid', 'pulmonary', 'cabg', 'weight', 'height',
       'reintubation', 'liver_severe', 'liver_mild', 'rheum', 'cvd', 'aids',
       'ckd', 'copd', 'arrhythmia', 'pud'

(9474, 132)

In [4]:
# The encoded frame is already a single 2-D numeric table, so its values are the
# feature matrix directly (column names are not carried along).
Xt = data_x_numeric.to_numpy()
Xt.shape

(9474, 132)

In [5]:
random_state = 20  # seed shared by the split and the forest below

# Hold out 25% of the rows for evaluation.
# NOTE(review): the split is by row, not by subject_id -- a patient with several
# admissions could land in both sets; confirm this is acceptable.
X_train, X_test, y_train, y_test = train_test_split(
    Xt,
    data_y,
    test_size=0.25,
    random_state=random_state,
)

In [6]:
np.shape(y_test)  # size of the held-out set

(2369,)

In [7]:
# Random survival forest hyper-parameters; the seed is shared with the split.
# Note (assuming scikit-learn tree semantics): min_samples_leaf=15 means a node
# needs at least 30 samples to split, so min_samples_split=10 never binds.
rsf_params = dict(
    n_estimators=500,
    min_samples_split=10,
    min_samples_leaf=15,
    max_features="sqrt",
    n_jobs=-1,  # all available cores
    random_state=random_state,
)
rsf = RandomSurvivalForest(**rsf_params)
rsf.fit(X_train, y_train)

RandomSurvivalForest(max_features='sqrt', min_samples_leaf=15,
                     min_samples_split=10, n_estimators=500, n_jobs=-1,
                     random_state=20)

In [8]:
# Concordance index of the forest's predictions on the held-out set.
test_cindex = rsf.score(X_test, y_test)
test_cindex

0.6855168435493669