In [35]:
# Import libraries

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
import re
import pickle

import os
path_dir = os.path.dirname(os.getcwd())

import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
pio.templates.default = "plotly_white"

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [36]:
cd ../src/

/Users/linafaik/Documents/survival_analysis/src


In [37]:
from train import *
from train_survival_ml import *

In [38]:
df = pd.read_csv(os.path.join(path_dir, "outputs/hdhi_clean.csv"))


Columns (1) have mixed types.Specify dtype option on import or set low_memory=False.



In [39]:
# Parameters

scaler_name = "StandardScaler" #MinMaxScaler
random_state=123
test_size = 0.3

# 1. Train / test split

In [45]:
# covariate columns (used when possible)

cols_x = [
    'age', 'gender', 'rural',
    'duration_of_stay', 'duration_of_intensive_unit_stay', 
    'smoking','alcohol', 'dm', 'htn', 'cad', 'prior_cmp', 'ckd', 'hb', 'tlc',
    'platelets', 'glucose', 'urea', 'creatinine', 'raised_cardiac_enzymes',
    'severe_anaemia', 'anaemia', 'stable_angina', 'acs', 'stemi',
    'atypical_chest_pain', 'heart_failure', 'hfref', 'hfnef', 'valvular',
    'chb', 'sss', 'aki', 'cva_infract', 'cva_bleed', 'af', 'vt', 'psvt',
    'congenital', 'uti', 'neuro_cardiogenic_syncope', 'orthostatic',
    'infective_endocarditis', 'dvt', 'cardiogenic_shock', 'shock',
    'pulmonary_embolism', 'chest_infection',
    'type_adm', 
    'first_visit'
]

col_target = "time_before_readm"

In [46]:
Xy_train, Xy_test, y_train, y_test = split_train_test(df, cols_x, col_target, 
                                                      test_size=test_size, random_state=random_state)

In [47]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# rescale
scaler = eval(scaler_name)()

Xy_train[cols_x] = scaler.fit_transform(Xy_train[cols_x])
Xy_test[cols_x] = scaler.transform(Xy_test[cols_x])

# 2. Kaplan-Meier estimator

In [48]:
from sksurv.nonparametric import kaplan_meier_estimator

time, probas = kaplan_meier_estimator(Xy_train["censored"], Xy_train[col_target])

fig = px.line(x=time, y=probas, width=800, height = 400)
fig.update_layout(dict(xaxis={'title' : 'Time (# days)'}, yaxis={'title' : 'Survival probability'}))

In [49]:
from sksurv.nonparametric import kaplan_meier_estimator

for i, age_bin in enumerate(df.age_bin.unique()):
    
    Xy_train_filter = df[df.age_bin == age_bin]

    time, probas = kaplan_meier_estimator(Xy_train_filter["censored"], Xy_train_filter[col_target])
    probas = pd.DataFrame({'time': time, 'age_bin' : age_bin, 'proba_readm': [1-p for p in probas]})
    
    preds = probas if i ==0 else pd.concat([probas, preds], axis=0)

preds.head()

Unnamed: 0,time,age_bin,proba_readm
0,42.0,"[110,120[",0.5
1,212.0,"[110,120[",0.5
0,88.0,"[0,10[",0.066667
1,111.0,"[0,10[",0.133333
2,170.0,"[0,10[",0.2


In [50]:
preds_graph = preds[preds.age_bin!="[110,120["].sort_values(by=["age_bin", "time"])
fig = px.line(preds_graph, x="time", y="proba_readm", color="age_bin", width=800, height = 400)
fig.update_layout(dict(xaxis={'title' : 'nb days'}, yaxis={'title' : 'proba'}))

# 3. Cox PH estimator

## 3.1 Model training & analysis

### Training

In [71]:
from sksurv.linear_model import CoxPHSurvivalAnalysis

# train an estimator
estimator = CoxPHSurvivalAnalysis(alpha=0.5)
estimator = estimator.fit(Xy_train[cols_x], y_train)

feat_importance, fig = plot_feat_imp(cols_x, estimator.coef_)
estimator.score(Xy_test[cols_x], y_test)
feat_importance

Unnamed: 0,feature,coef,coef_abs
45,pulmonary_embolism,-0.001735,0.001735
24,atypical_chest_pain,-0.002547,0.002547
30,sss,-0.005683,0.005683
25,heart_failure,0.006423,0.006423
29,chb,0.00652,0.00652
37,congenital,-0.00662,0.00662
36,psvt,0.006652,0.006652
2,rural,-0.00716,0.00716
1,gender,-0.00716,0.00716
47,type_adm,-0.00716,0.00716


In [72]:
from sksurv.linear_model import CoxPHSurvivalAnalysis

# train an estimator
estimator = CoxPHSurvivalAnalysis(alpha=10)
estimator = estimator.fit(Xy_train[cols_x], y_train)

feat_importance, fig = plot_feat_imp(cols_x, estimator.coef_)
estimator.score(Xy_test[cols_x], y_test)
feat_importance

Unnamed: 0,feature,coef,coef_abs
45,pulmonary_embolism,-0.001768,0.001768
24,atypical_chest_pain,-0.002535,0.002535
30,sss,-0.005676,0.005676
29,chb,0.0065,0.0065
37,congenital,-0.006623,0.006623
36,psvt,0.006643,0.006643
2,rural,-0.007128,0.007128
47,type_adm,-0.007128,0.007128
1,gender,-0.007128,0.007128
19,severe_anaemia,0.008746,0.008746


### Cumulative hazard functions

In [52]:
# predict cumulative hazard function
chf_funcs = estimator.predict_cumulative_hazard_function(Xy_test[cols_x].iloc[:3])

data = [go.Scatter(x=fn.x,y= fn(fn.x), name=i) for i, fn in enumerate(chf_funcs)]
fig = go.Figure(data, layout=dict(width=800, height=400))
fig.update_layout({"yaxis":{"range": [0,1]}})

### Survival functions

In [53]:
# predict survival function
surv_funcs = estimator.predict_survival_function(Xy_test[cols_x].iloc[:3])

# plot results
data = [go.Scatter(x=fn.x,y= fn(fn.x), name=i) for i, fn in enumerate(surv_funcs)]
go.Figure(data, layout=dict(width=800, height=400))

### Feature importance

In [54]:
feat_importance, fig = plot_feat_imp(cols_x, estimator.coef_)
fig

## 3.2. Model evaluation

### C-index

In [55]:
from sksurv.metrics import concordance_index_censored

prediction = estimator.predict(Xy_test[cols_x])
result = concordance_index_censored(list(Xy_test.censored), Xy_test[col_target], prediction)
result
# c-index, concordant,  discordant, tied_risk, tied_time

(0.630632914800985, 3887623, 2277014, 0, 3428)

### Time-dependant AUC

In [56]:
from sksurv.metrics import cumulative_dynamic_auc

times = np.percentile(df[col_target], np.linspace(5, 81, 15))
risk_score = estimator.predict(Xy_test[cols_x]) 

# Possible because the Cox PH is not time-dependant
auc, mean_auc = cumulative_dynamic_auc(y_train, y_test, risk_score, times)
mean_auc

0.6849824900597511

In [57]:
fig = px.line(x=times, y= auc, width=800, height=400)
fig.update_layout({
    "xaxis": dict(title = "Time before readmission (#days)"),
    "yaxis": dict(title = "Time-dependent AUC")
})

### Bier score

In [58]:
from sksurv.metrics import brier_score, integrated_brier_score

In [59]:
survs = estimator.predict_survival_function(Xy_test[cols_x])

In [60]:
T = 364/2
preds = [fn(T) for fn in survs]
times, score = brier_score(y_train, y_test, preds, T)
score

array([0.18845062])

In [61]:
times = np.arange(364/2, 365)

preds = np.asarray([[fn(t) for t in times] for fn in survs])
score = integrated_brier_score(y_train, y_test, preds, times)
print(score)

KeyboardInterrupt: 

## 3.3. Model fine-tuning

In [62]:
from sklearn.model_selection import KFold

cv = KFold(n_splits=5, random_state=random_state, shuffle=True)

grid_params = {"alpha": [1, 20]}

estimator_cox, results = grid_search(grid_params, df, cv, CoxPHSurvivalAnalysis, cols_x,  col_target, verbose = True)

2 total scenario to run
1/2: params: {'alpha': 1}
Fold 0: 0.626
Fold 1: 0.622
Fold 2: 0.638
Fold 3: 0.626
Fold 4: 0.616
2/2: params: {'alpha': 20}
Fold 0: 0.626
Fold 1: 0.622
Fold 2: 0.638
Fold 3: 0.626
Fold 4: 0.617


In [24]:
results

Unnamed: 0,alpha,fold_0,fold_1,fold_2,fold_3,fold_4,mean,std
0,0.3,0.621437,0.624121,0.622463,0.624962,0.61662,0.621921,0.002923
1,0.6,0.621436,0.624131,0.622474,0.624974,0.616613,0.621926,0.00293
2,1.0,0.621443,0.624138,0.622487,0.625005,0.6166,0.621935,0.002942
3,1.2,0.621448,0.624143,0.622492,0.625015,0.616609,0.621941,0.002942


In [25]:
with open(os.path.join(path_dir, "outputs/cox_ph.pkl"), "wb") as f:
    pickle.dump(estimator_cox, f)

# 4. Gradient Boosting Survival Analysis

In [29]:
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

grid_params = {
    "n_estimators": [100,200, 300],
    "min_samples_leaf": [2, 3],
    "random_state": [random_state],
    "verbose":[0]}

estimator_gb, results = grid_search(grid_params, df, cv, GradientBoostingSurvivalAnalysis, cols_x, col_target, verbose = True)

results

6 total scenario to run
1/6: params: {'n_estimators': 100, 'min_samples_leaf': 2, 'random_state': 123, 'verbose': 0}
Fold 0: 0.653
Fold 1: 0.657
Fold 2: 0.655
Fold 3: 0.645
Fold 4: 0.654
2/6: params: {'n_estimators': 100, 'min_samples_leaf': 3, 'random_state': 123, 'verbose': 0}
Fold 0: 0.652
Fold 1: 0.655
Fold 2: 0.653
Fold 3: 0.645
Fold 4: 0.655
3/6: params: {'n_estimators': 200, 'min_samples_leaf': 2, 'random_state': 123, 'verbose': 0}
Fold 0: 0.657
Fold 1: 0.658
Fold 2: 0.655
Fold 3: 0.646
Fold 4: 0.654
4/6: params: {'n_estimators': 200, 'min_samples_leaf': 3, 'random_state': 123, 'verbose': 0}
Fold 0: 0.657
Fold 1: 0.655
Fold 2: 0.654
Fold 3: 0.646
Fold 4: 0.655
5/6: params: {'n_estimators': 300, 'min_samples_leaf': 2, 'random_state': 123, 'verbose': 0}
Fold 0: 0.657
Fold 1: 0.658
Fold 2: 0.654
Fold 3: 0.643
Fold 4: 0.652
6/6: params: {'n_estimators': 300, 'min_samples_leaf': 3, 'random_state': 123, 'verbose': 0}
Fold 0: 0.655
Fold 1: 0.657
Fold 2: 0.652
Fold 3: 0.644
Fold 4: 0.65

Unnamed: 0,n_estimators,min_samples_leaf,random_state,verbose,fold_0,fold_1,fold_2,fold_3,fold_4,mean,std
0,100,2,123,0,0.652857,0.657498,0.65483,0.645339,0.653513,0.652807,0.004059
1,100,3,123,0,0.652061,0.654978,0.652907,0.645416,0.654609,0.651994,0.00346
2,200,2,123,0,0.656616,0.657684,0.655194,0.645512,0.65402,0.653805,0.004329
3,200,3,123,0,0.656897,0.654639,0.653975,0.646013,0.655239,0.653353,0.003796
4,300,2,123,0,0.656999,0.658285,0.653573,0.643493,0.652373,0.652944,0.005196
5,300,3,123,0,0.655376,0.657088,0.651854,0.644326,0.653414,0.652412,0.004412


In [30]:
feat_importance_gb, fig = plot_feat_imp(cols_x, estimator_gb.feature_importances_)
fig

In [31]:
with open(os.path.join(path_dir, "outputs/gradient_boosting.pkl"), "wb") as f:
    pickle.dump(estimator_gb, f)

## 5. Survival Support Vector Machine

In [32]:
from sksurv.svm import FastSurvivalSVM 

In [33]:
from sksurv.svm import FastSurvivalSVM 

grid_params = {
    "alpha": [1,2, 5, 10],
    "rank_ratio": [0],
    "max_iter": [1000],
    "tol": [1e-5],
    "random_state": [random_state],
    "verbose":[0]}

estimator_svm, results = grid_search(grid_params, df, cv, FastSurvivalSVM, cols_x, col_target, verbose = True)

results

4 total scenario to run
1/4: params: {'alpha': 1, 'rank_ratio': 0, 'max_iter': 1000, 'tol': 1e-05, 'random_state': 123, 'verbose': 0}
Fold 0: 0.521
Fold 1: 0.562
Fold 2: 0.56
Fold 3: 0.578
Fold 4: 0.57
2/4: params: {'alpha': 2, 'rank_ratio': 0, 'max_iter': 1000, 'tol': 1e-05, 'random_state': 123, 'verbose': 0}
Fold 0: 0.521
Fold 1: 0.562
Fold 2: 0.56
Fold 3: 0.589
Fold 4: 0.57
3/4: params: {'alpha': 5, 'rank_ratio': 0, 'max_iter': 1000, 'tol': 1e-05, 'random_state': 123, 'verbose': 0}
Fold 0: 0.521
Fold 1: 0.562
Fold 2: 0.56
Fold 3: 0.578
Fold 4: 0.57
4/4: params: {'alpha': 10, 'rank_ratio': 0, 'max_iter': 1000, 'tol': 1e-05, 'random_state': 123, 'verbose': 0}
Fold 0: 0.521
Fold 1: 0.562
Fold 2: 0.56
Fold 3: 0.589
Fold 4: 0.57


Unnamed: 0,alpha,rank_ratio,max_iter,tol,random_state,verbose,fold_0,fold_1,fold_2,fold_3,fold_4,mean,std
0,1,0,1000,1e-05,123,0,0.521225,0.562332,0.560435,0.577912,0.570156,0.558412,0.019599
1,2,0,1000,1e-05,123,0,0.521226,0.562351,0.560435,0.589245,0.570176,0.560687,0.022206
2,5,0,1000,1e-05,123,0,0.521225,0.562399,0.560438,0.577922,0.570182,0.558433,0.019606
3,10,0,1000,1e-05,123,0,0.521225,0.562358,0.560435,0.589193,0.570192,0.56068,0.022195


In [34]:
with open(os.path.join(path_dir, "outputs/svm.pkl"), "wb") as f:
    pickle.dump(estimator_svm, f)

In [35]:
from sksurv.svm import FastKernelSurvivalSVM 

grid_params = {
    "kernel": ["linear","poly","rbf","sigmoid","cosine"],
    "alpha": [2],
    "rank_ratio": [0],
    "max_iter": [1000],
    "tol": [1e-5],
    "random_state": [random_state]
}

estimator_ksvm, results = grid_search(grid_params, df, cv, FastKernelSurvivalSVM, cols_x, col_target, verbose = True)

results

5 total scenario to run
1/5: params: {'kernel': 'linear', 'alpha': 2, 'rank_ratio': 0, 'max_iter': 1000, 'tol': 1e-05, 'random_state': 123}
Fold 0: 0.481
Fold 1: 0.481
Fold 2: 0.494
Fold 3: 0.511
Fold 4: 0.488
2/5: params: {'kernel': 'poly', 'alpha': 2, 'rank_ratio': 0, 'max_iter': 1000, 'tol': 1e-05, 'random_state': 123}
Fold 0: 0.48
Fold 1: 0.481
Fold 2: 0.493
Fold 3: 0.511
Fold 4: 0.488
3/5: params: {'kernel': 'rbf', 'alpha': 2, 'rank_ratio': 0, 'max_iter': 1000, 'tol': 1e-05, 'random_state': 123}
Fold 0: 0.508
Fold 1: 0.513
Fold 2: 0.498
Fold 3: 0.49
Fold 4: 0.504
4/5: params: {'kernel': 'sigmoid', 'alpha': 2, 'rank_ratio': 0, 'max_iter': 1000, 'tol': 1e-05, 'random_state': 123}
Fold 0: 0.5
Fold 1: 0.5
Fold 2: 0.5
Fold 3: 0.5
Fold 4: 0.5
5/5: params: {'kernel': 'cosine', 'alpha': 2, 'rank_ratio': 0, 'max_iter': 1000, 'tol': 1e-05, 'random_state': 123}
Fold 0: 0.522
Fold 1: 0.512
Fold 2: 0.507
Fold 3: 0.515
Fold 4: 0.518


Unnamed: 0,kernel,alpha,rank_ratio,max_iter,tol,random_state,fold_0,fold_1,fold_2,fold_3,fold_4,mean,std
0,linear,2,0,1000,1e-05,123,0.481132,0.48118,0.494027,0.511236,0.488416,0.491198,0.011126
1,poly,2,0,1000,1e-05,123,0.480275,0.48054,0.493381,0.511083,0.487653,0.490586,0.011347
2,rbf,2,0,1000,1e-05,123,0.507686,0.51314,0.497807,0.490122,0.503848,0.502521,0.007964
3,sigmoid,2,0,1000,1e-05,123,0.5,0.5,0.5,0.5,0.5,0.5,0.0
4,cosine,2,0,1000,1e-05,123,0.522083,0.51222,0.506975,0.514902,0.518447,0.514925,0.00518


In [None]:
with open(os.path.join(path_dir, "outputs/ksvm.pkl"), "wb") as f:
    pickle.dump(estimator_ksvm, f)