In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import datetime as dt
%matplotlib inline

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

pd.set_option('display.max_columns', None)

In [4]:
df = pd.read_csv('imputed.csv')

#cleaning
categorical = ['ethnicity', 
              'marital_status',
              'language',
              'admission_location',
              'gender',
              'insurance',
              'first_careunit',
              'last_careunit',
              'admission_type']
proceduretype=['aortic','mit','tricuspid','pulmonary','cabg']
ptParams = ['weight', 'height']
boolFields = ['reintubation', 'liver_severe', 'liver_mild', 'rheum', 'cvd', 'aids', 'ckd', 'copd', 'arrhythmia', 'pud', 'smoking', 'pvd', 'paraplegia', 
              'ccf', 'met_ca', 't2dm', 't1dm', 'malig', 'mi', 'dementia', 'hospital_expire_flag', 'diab_un', 'diab_cc',]
deathInfo = ['dod', 'deathtime']
ptinfo = ['hadm_id', 'subject_id']
durations = ['duration1', 'duration2', 'icu_stay_duration', 'icu_stay_days']
timeFields =  ['admittime', 'dischtime', 'intime', 'outtime', 'ext_time', 
       'int_time1', 'ext_time1', 'int_time2', 'ext_time2']

tsColumns = [i for i in df.columns if '_max' in i or '_min' in i or '_mean' in i]
print([i for i in df.columns if i not in categorical + proceduretype + tsColumns + ptParams + boolFields + ptinfo + deathInfo + durations + timeFields])

for i in categorical:
    df[i] = df[i].astype('category')
    
df['icu_stay_days'] = [round(i/86400) for i in df['icu_stay_duration']]
df['icu_stay_duration'] /= 3600  # now icu_stay_duration
df = df[[i for i in df.columns if i not in ("Unnamed: 0", "0")]]
df['dod'] = df['dod'].apply(lambda x: dt.datetime.strptime(x, "%Y-%m-%d") if not pd.isnull(x) else np.NaN)
for i in timeFields:
    df[i] = df[i].apply(lambda x: dt.datetime.strptime(x, "%Y-%m-%d %H:%M:%S") if not pd.isnull(x) else np.NaN)
    
df = df[[i for i in df.columns if '_max' not in i and '_min' not in i]]
df = df[[i for i in df.columns if i not in ('last_careunit', 'infection_vent', 'icustay_seq', 'los')]]  # for some reason last_careunit messes up the Cox training

print(df.shape)
df

['Unnamed: 0', 'infection_vent', 'los', 'icustay_seq']
(9474, 91)


Unnamed: 0,temp_mean,hr_mean,spo2_mean,rr_mean,sbp_mean,dbp_mean,meanbp_mean,cardiac_index_mean,pt_mean,ptt_mean,inr_mean,inr_1_mean,fibrinogen_mean,hb_mean,hematocrit_mean,wcc_mean,lymphocytes_mean,neutrophils_mean,chloride_mean,magnesium_mean,potassium_mean,creatinine_mean,free_calcium_mean,sodium_mean,bicarb_mean,bun_mean,hba1c_mean,glucose_mean,lactate_mean,po2_mean,pco2_mean,baseexcess_mean,ph_mean,insulin_mean,prbc_mean,plt_mean,gender,ethnicity,marital_status,insurance,language,aortic,mit,tricuspid,pulmonary,cabg,weight,height,reintubation,liver_severe,liver_mild,rheum,cvd,aids,ckd,copd,arrhythmia,pud,smoking,pvd,paraplegia,ccf,met_ca,t2dm,t1dm,malig,mi,dementia,first_careunit,admission_location,admission_type,hospital_expire_flag,diab_un,diab_cc,icu_stay_duration,admittime,dischtime,intime,outtime,ext_time,deathtime,hadm_id,subject_id,int_time1,ext_time1,duration1,int_time2,ext_time2,duration2,dod,icu_stay_days
0,37.477779,91.714286,99.157895,23.363636,120.547619,60.023810,77.904762,2.331152,14.05,29.8,1.200000,1.200000,324.0,12.050000,34.333333,13.90,19.1,75.20,104.500000,2.600000,3.666667,0.80,1.150000,138.000000,25.500000,12.0,5.9,126.692308,1.700000,0.500000,41.000000,0.500000,7.400000,211.333333,360.000000,261.500000,M,white,SINGLE,Private,ENGL,0,0,0,0,1,84.00,172.72,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,CSRU,PHYSICIAN REFERRAL,ELECTIVE,0,0,0,50.110833,2198-01-31 08:00:00,2198-02-04 12:00:00,2198-01-31 12:27:58,2198-02-02 19:06:39,2198-01-31 22:00:00,,195663,27328,2198-01-31 17:00:00,2198-01-31 22:00:00,5.000000,NaT,NaT,,NaT,2
1,36.823404,92.122449,97.549020,11.836538,96.714286,51.183673,65.591837,2.081019,17.30,42.8,1.600000,1.600000,203.0,7.075000,21.250000,12.15,10.0,80.30,108.000000,2.100000,4.200000,0.60,1.180000,142.000000,26.000000,12.0,6.0,149.100000,2.100000,0.375000,43.750000,0.375000,7.386250,262.500000,375.000000,122.500000,F,other,DIVORCED,Self Pay,SPAN,1,0,0,0,0,60.00,170.18,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,CSRU,PHYSICIAN REFERRAL,ELECTIVE,0,0,0,50.766667,2198-05-08 07:15:00,2198-05-15 13:49:00,2198-05-08 13:14:00,2198-05-10 19:46:00,2198-05-09 09:29:00,,106984,6280,2198-05-08 17:00:00,2198-05-09 09:29:00,16.483333,NaT,NaT,,NaT,2
2,36.591667,87.708333,99.950000,13.960000,121.583333,61.791667,82.291667,2.042943,17.20,59.3,1.600000,1.600000,228.0,9.950000,28.000000,22.80,15.4,83.65,113.000000,2.600000,4.060000,1.20,1.156000,141.000000,25.000000,19.0,6.1,122.894737,2.575000,1.000000,40.750000,1.000000,7.404000,165.333333,375.000000,148.000000,F,white,MARRIED,Medicare,CANT,1,0,0,0,1,57.00,165.10,0,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,0,0,0,0,CSRU,PHYSICIAN REFERRAL,ELECTIVE,0,0,0,48.630000,2189-02-18 08:00:00,2189-03-17 14:20:00,2189-02-18 10:51:08,2189-02-20 13:37:48,2189-02-19 09:00:00,,123613,15201,2189-02-18 13:00:00,2189-02-19 09:00:00,20.000000,NaT,NaT,,2191-12-14,2
3,37.608696,89.782609,98.363636,16.227273,115.565217,58.565217,75.260870,3.040592,15.20,27.5,1.333333,1.333333,254.0,10.450000,31.000000,17.90,7.6,69.20,106.000000,1.800000,3.766667,0.40,1.110000,138.000000,27.000000,14.0,5.1,125.705882,2.200000,0.625000,45.750000,0.625000,7.373750,226.500000,550.000005,140.500000,M,white,MARRIED,Private,ALBA,0,0,0,0,1,135.00,190.50,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,CSRU,PHYSICIAN REFERRAL,ELECTIVE,0,1,0,23.550556,2118-01-25 07:15:00,2118-01-29 13:00:00,2118-01-25 10:46:42,2118-01-26 12:33:02,2118-01-25 17:30:00,,126027,25226,2118-01-25 13:00:00,2118-01-25 17:30:00,4.500000,NaT,NaT,,NaT,1
4,36.880000,88.027778,99.457143,12.585714,113.250000,66.041667,81.564815,2.378381,18.40,33.2,1.700000,1.700000,188.0,9.985714,30.000000,9.00,11.6,83.00,105.000000,2.050000,4.166667,0.65,0.963333,137.000000,23.000000,29.5,5.3,133.750000,1.500000,-3.222222,44.000000,-3.222222,7.313333,160.000000,250.000000,165.000000,M,white,MARRIED,Medicare,SPAN,1,0,0,0,0,70.00,175.26,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,CSRU,PHYSICIAN REFERRAL,ELECTIVE,0,1,0,47.000000,2198-01-01 07:15:00,2198-01-09 13:07:00,2198-01-01 10:47:00,2198-01-03 12:00:00,2198-01-01 21:00:00,,190332,19637,2198-01-01 13:00:00,2198-01-01 21:00:00,8.000000,NaT,NaT,,2203-12-06,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9469,36.003226,72.750000,98.176471,18.125000,101.625000,53.562500,68.812500,2.396073,12.20,31.1,1.100000,1.100000,158.0,12.800000,40.750000,16.55,12.9,81.30,108.000000,2.500000,3.920000,0.70,1.093333,139.000000,26.000000,9.0,5.4,126.750000,3.200000,140.000000,44.000000,-0.333333,7.367500,2.545723,356.249996,181.500000,M,white,MARRIED,Other,ENGL,1,0,0,0,0,96.80,178.00,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,CVICU,PHYSICIAN REFERRAL,ELECTIVE,0,0,0,54.318611,2121-05-01 19:37:00,2121-05-06 16:50:00,2121-05-02 09:33:46,2121-05-04 18:19:07,2121-05-02 18:01:00,,22051087,14971805,2121-05-02 12:00:00,2121-05-02 18:01:00,6.000000,NaT,NaT,,NaT,2
9470,37.610000,87.187500,98.941176,18.705882,107.235294,51.764706,68.352941,2.875829,13.30,29.9,1.200000,1.200000,244.0,10.150000,33.000000,7.75,12.0,77.90,109.500000,2.966667,4.360000,0.65,1.096667,136.333333,21.500000,11.5,5.9,110.750000,0.900000,185.333333,40.000000,-3.000000,7.350000,3.991799,375.000000,98.000000,M,white,MARRIED,Medicare,ENGL,0,0,0,0,1,85.10,168.00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,CVICU,TRANSFER FROM HOSPITAL,URGENT,0,0,0,70.493056,2111-11-18 18:24:00,2111-11-23 17:00:00,2111-11-19 17:29:25,2111-11-22 17:29:35,2111-11-20 04:30:00,,21555454,15547313,2111-11-19 19:00:00,2111-11-20 04:30:00,9.000000,NaT,NaT,,NaT,3
9471,36.512000,68.888889,98.000000,17.500000,109.333333,55.444444,70.222222,5.171814,12.60,23.4,1.200000,1.200000,233.0,9.333333,28.000000,14.70,18.5,71.40,109.500000,2.600000,4.333333,1.30,1.285000,138.500000,21.000000,40.0,11.4,120.000000,1.200000,136.600000,41.800000,-2.400000,7.335000,7.224064,350.000010,154.000000,M,white,MARRIED,Medicare,ENGL,0,0,0,0,1,100.00,175.00,0,0,0,0,0,0,1,0,1,0,1,0,0,1,0,1,0,0,1,0,CVICU,TRANSFER FROM HOSPITAL,URGENT,0,0,1,25.164167,2156-02-26 18:43:00,2156-03-06 18:30:00,2156-03-01 09:26:32,2156-03-02 17:09:51,2156-03-01 22:27:00,,22084741,16252024,2156-03-01 16:00:00,2156-03-01 22:27:00,6.000000,NaT,NaT,,NaT,1
9472,36.625000,79.071429,99.000000,14.142857,116.357143,54.214286,72.142857,3.300911,16.20,43.9,1.433333,1.433333,135.0,11.600000,34.666667,12.40,11.7,78.15,111.666667,2.800000,3.800000,0.65,1.083333,130.000000,22.333333,12.0,5.5,112.000000,2.266667,145.000000,41.166667,-3.166667,7.335714,4.916971,375.000000,128.000000,M,white,SINGLE,Medicare,ENGL,0,0,0,0,1,112.05,173.00,1,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,CVICU,PROCEDURE SITE,EW EMER.,0,0,0,142.522778,2127-03-04 15:15:00,2127-03-15 15:55:00,2127-03-05 10:07:40,2127-03-11 10:31:22,2127-03-06 12:00:00,,25588352,18504988,2127-03-05 12:00:00,2127-03-06 12:00:00,24.000000,2127-03-07 13:00:00,2127-03-08 05:00:00,16.0,NaT,6


In [5]:
# target variable: icu stay duration
data_y = np.array([(True, df['icu_stay_duration'][i]) for i in range(df.shape[0])], dtype=[('Status', '?'), ('Stay_in_hrs', '<f8')])
print(data_y)

# one hot encodes categorical variables and removes columns related to irrelevant stuff
data_x = df[[i for i in df.columns if i not in durations + ptinfo + deathInfo + timeFields]]
print(data_x.columns)

data_x.shape

[( True,  50.11083333) ( True,  50.76666667) ( True,  48.63      ) ...
 ( True,  25.16416667) ( True, 142.52277778) ( True,  22.29111111)]
Index(['temp_mean', 'hr_mean', 'spo2_mean', 'rr_mean', 'sbp_mean', 'dbp_mean',
       'meanbp_mean', 'cardiac_index_mean', 'pt_mean', 'ptt_mean', 'inr_mean',
       'inr_1_mean', 'fibrinogen_mean', 'hb_mean', 'hematocrit_mean',
       'wcc_mean', 'lymphocytes_mean', 'neutrophils_mean', 'chloride_mean',
       'magnesium_mean', 'potassium_mean', 'creatinine_mean',
       'free_calcium_mean', 'sodium_mean', 'bicarb_mean', 'bun_mean',
       'hba1c_mean', 'glucose_mean', 'lactate_mean', 'po2_mean', 'pco2_mean',
       'baseexcess_mean', 'ph_mean', 'insulin_mean', 'prbc_mean', 'plt_mean',
       'gender', 'ethnicity', 'marital_status', 'insurance', 'language',
       'aortic', 'mit', 'tricuspid', 'pulmonary', 'cabg', 'weight', 'height',
       'reintubation', 'liver_severe', 'liver_mild', 'rheum', 'cvd', 'aids',
       'ckd', 'copd', 'arrhythmia', 'pud'

(9474, 74)

In [8]:
Xt = np.column_stack((data_x.values, ))

# feature_names = X_no_grade.columns.tolist() + ["tgrade"]
# feature_names
Xt.shape

(74, 9474)

In [None]:
random_state = 20

X_train, X_test, y_train, y_test = train_test_split(
    Xt, data_y, test_size=0.25, random_state=random_state)

In [None]:
rsf = RandomSurvivalForest(n_estimators=1000,
                           min_samples_split=10,
                           min_samples_leaf=15,
                           max_features="sqrt",
                           n_jobs=-1,
                           random_state=random_state)
rsf.fit(X_train, y_train)

In [None]:
rsf.score(X_test, y_test)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

X, y = load_gbsg2()

grade_str = X.loc[:, "tgrade"].astype(object).values[:, np.newaxis]
grade_num = OrdinalEncoder(categories=[["I", "II", "III"]]).fit_transform(grade_str)

X_no_grade = X.drop("tgrade", axis=1)
Xt = OneHotEncoder().fit_transform(X_no_grade)
Xt = np.column_stack((Xt.values, grade_num))

feature_names = X_no_grade.columns.tolist() + ["tgrade"]

In [None]:
Xt.shape
y.shape

In [None]:
random_state = 20

X_train, X_test, y_train, y_test = train_test_split(
    Xt, y, test_size=0.25, random_state=random_state)