In [1]:
# Import libraries

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
import re

import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
pio.templates.default = "plotly_white"

# %load_ext nb_black
%load_ext autoreload
%autoreload 2

In [2]:
# Make the project's `src` directory importable.
# NOTE(review): './src' is relative to the current working directory, so this only
# helps when the notebook is started from the directory that contains `src/` — confirm.
import sys
sys.path.append('./src')

In [None]:
cd ../src/

In [5]:
# Import the project helpers by name rather than with `import *`, so it is clear
# where `grid_search` and `plot_feat_imp` (used further down) come from.
# NOTE(review): this requires `train.py` to be importable (see the path setup above);
# add further names here if other helpers from `train` are needed.
from train import grid_search, plot_feat_imp

ModuleNotFoundError: No module named 'train'

In [None]:
# Load the data set into `df`; the path is relative to the current working directory.
# NOTE(review): presumably a cleaned export with one row per admission — confirm.
df = pd.read_csv("./datasets/hdhi_clean.csv")

In [None]:
# Notebook-wide parameters

# Name of the scikit-learn scaler applied to the covariates
# (the alternative is "MinMaxScaler").
scaler_name = "StandardScaler"

# Seed shared by the train/test split, the CV folds and the stochastic estimators.
random_state = 123

# 1. Train / test splits

In [None]:
# Covariate columns fed to the models (used when possible).
# The order is significant: it is the feature order seen by the estimators and
# by the feature-importance plots.
cols_x = [
    "age", "gender", "rural",
    "duration_of_stay", "duration_of_intensive_unit_stay",
    "smoking_", "alcohol", "dm", "htn", "cad", "prior_cmp", "ckd",
    "hb", "tlc", "platelets", "glucose", "urea", "creatinine",
    "raised_cardiac_enzymes", "severe_anaemia", "anaemia",
    "stable_angina", "acs", "stemi", "atypical_chest_pain",
    "heart_failure", "hfref", "hfnef", "valvular",
    "chb", "sss", "aki", "cva_infract", "cva_bleed",
    "af", "vt", "psvt", "congenital", "uti",
    "neuro_cardiogenic_syncope", "orthostatic",
    "infective_endocarditis", "dvt", "cardiogenic_shock", "shock",
    "pulmonary_embolism", "chest_infection",
    "type_adm",
]

In [None]:
from sklearn.model_selection import train_test_split

# Order the rows by `doa` (in place), then keep only the rows in which every
# covariate and both target columns are present.
df.sort_values(by="doa", inplace=True)
target_cols = ["censored", "time_before_readm"]
Xy = df[cols_x + target_cols].dropna()

# Seeded 70 / 30 split (stratification on `censored` is left disabled).
Xy_train, Xy_test = train_test_split(Xy, test_size=0.3, random_state=random_state)


def _to_survival_array(frame):
    """Pack `censored` and `time_before_readm` of `frame` into a structured (bool, float64) array."""
    pairs = list(zip(frame.censored, frame.time_before_readm))
    return np.array(pairs, dtype=[('censored', '?'), ('time_before_readm', '<f8')])


# Structured targets in the layout expected by the scikit-survival estimators.
y_train = _to_survival_array(Xy_train)
y_test = _to_survival_array(Xy_test)

In [None]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Resolve the scaler class from an explicit whitelist rather than eval(): a typo in
# `scaler_name` now fails with a clear message and no arbitrary code is evaluated.
available_scalers = {"StandardScaler": StandardScaler, "MinMaxScaler": MinMaxScaler}
if scaler_name not in available_scalers:
    raise ValueError(
        f"Unknown scaler_name {scaler_name!r}; expected one of {sorted(available_scalers)}"
    )
scaler = available_scalers[scaler_name]()

# The frames returned by train_test_split are derived from `Xy`; take explicit copies
# before assigning columns so pandas does not emit SettingWithCopyWarning and `Xy`
# is never touched.
Xy_train = Xy_train.copy()
Xy_test = Xy_test.copy()

# Fit on the training covariates only, then apply the same transform to the test set.
Xy_train[cols_x] = scaler.fit_transform(Xy_train[cols_x])
Xy_test[cols_x] = scaler.transform(Xy_test[cols_x])

# 2. Kaplan-Meier estimator

In [None]:
from sksurv.nonparametric import kaplan_meier_estimator

# Kaplan-Meier curve on the training split (event indicator first, observed time second).
# NOTE(review): scikit-survival expects True to mean "event observed" — confirm the
# `censored` column is encoded that way.
time, probas = kaplan_meier_estimator(
    Xy_train["censored"],
    Xy_train["time_before_readm"],
)

fig = px.line(x=time, y=probas, width=800, height=400)
fig.update_layout(xaxis_title="Time (# days)", yaxis_title="Survival probability")

In [None]:
from sksurv.nonparametric import kaplan_meier_estimator

# One Kaplan-Meier curve per age bin, expressed as a readmission probability
# (1 - survival). The curves are estimated on the full `df` (not on the training
# split), because `age_bin` is not part of the modelling frame.
km_frames = []
for age_bin in df["age_bin"].dropna().unique():
    # Keep only the two target columns and drop incomplete rows; the estimator
    # cannot handle missing values.
    df_bin = df.loc[df["age_bin"] == age_bin, ["censored", "time_before_readm"]].dropna()
    if df_bin.empty:
        continue

    # Cast the event flag to bool, as is done when building y_train / y_test.
    time, surv = kaplan_meier_estimator(
        df_bin["censored"].astype(bool), df_bin["time_before_readm"]
    )
    km_frames.append(
        pd.DataFrame({'time': time, 'age_bin': age_bin, 'proba_readm': 1 - surv})
    )

# Concatenate once at the end rather than growing the frame inside the loop.
preds = pd.concat(km_frames, axis=0)

preds.head()

In [None]:
# Leave out the "[110,120[" bin, then order the rows so that each age-bin line is
# drawn from left to right.
preds_graph = (
    preds.loc[preds["age_bin"] != "[110,120["]
    .sort_values(by=["age_bin", "time"])
)
fig = px.line(
    preds_graph, x="time", y="proba_readm", color="age_bin", width=800, height=400
)
fig.update_layout(xaxis_title="nb days", yaxis_title="proba")

# 3. Cox PH estimator

## 3.1 Model training & analysis

### Training

In [None]:
from sksurv.linear_model import CoxPHSurvivalAnalysis

# Fit a Cox proportional-hazards model (alpha=0.5) on the scaled training covariates.
# `fit` returns the estimator itself, so construction and fitting can be chained.
estimator = CoxPHSurvivalAnalysis(alpha=0.5).fit(Xy_train[cols_x], y_train)

### Cumulative hazard functions

In [None]:
# Predicted cumulative hazard functions for the first three test rows.
chf_funcs = estimator.predict_cumulative_hazard_function(Xy_test[cols_x].iloc[:3])

# One trace per row, evaluated at the step function's own time points.
data = [go.Scatter(x=fn.x, y=fn(fn.x), name=i) for i, fn in enumerate(chf_funcs)]

# The y axis is fixed to [0, 1]; larger cumulative hazards fall outside the view.
fig = go.Figure(
    data,
    layout=dict(width=800, height=400, yaxis=dict(range=[0, 1])),
)
fig

### Survival functions

In [None]:
# Predicted survival functions for the first three test rows.
surv_funcs = estimator.predict_survival_function(Xy_test[cols_x].iloc[:3])

# One trace per row, evaluated at the step function's own time points.
data = [
    go.Scatter(x=fn.x, y=fn(fn.x), name=i)
    for i, fn in enumerate(surv_funcs)
]
go.Figure(data=data, layout={"width": 800, "height": 400})

### Feature importance

In [None]:
# Plot the fitted Cox coefficients against their covariate names.
# NOTE(review): `plot_feat_imp` is a project helper not defined in this notebook;
# it presumably returns (table, figure) — confirm against its definition.
feat_importance, fig = plot_feat_imp(cols_x, estimator.coef_)
fig

## 3.2. Model evaluation

### C-index

In [None]:
from sksurv.metrics import concordance_index_censored

# Risk scores on the test covariates, compared with the observed test outcomes.
event_test = list(Xy_test.censored)
time_test = Xy_test.time_before_readm
prediction = estimator.predict(Xy_test[cols_x])

result = concordance_index_censored(event_test, time_test, prediction)
result
# Tuple layout: c-index, concordant, discordant, tied_risk, tied_time

### Time-dependent AUC

In [None]:
from sksurv.metrics import cumulative_dynamic_auc

# Build the evaluation grid from the test follow-up times rather than from the raw
# `df`: the raw frame still contains rows dropped from the split (possibly with NaN
# times), and cumulative_dynamic_auc requires every time point to lie inside the
# follow-up range of the test data.
test_times = y_test["time_before_readm"]
times = np.unique(np.percentile(test_times, np.linspace(5, 81, 15)))
# The upper bound is exclusive; np.unique keeps `times` aligned with the returned
# `auc` array when several percentiles coincide.
times = times[times < test_times.max()]

risk_score = estimator.predict(Xy_test[cols_x])

# A single risk score per sample is sufficient because the Cox PH prediction
# is not time-dependent.
auc, mean_auc = cumulative_dynamic_auc(y_train, y_test, risk_score, times)
mean_auc

In [None]:
# AUC at each evaluation time computed in the previous cell.
fig = px.line(x=times, y=auc, width=800, height=400)
fig.update_layout(
    xaxis_title="Time before readmission (#days)",
    yaxis_title="Time-dependent AUC",
)

### Brier score

In [None]:
from sksurv.metrics import brier_score, integrated_brier_score

In [None]:
# Predicted survival functions for every row of the test set (reused by the Brier-score cells).
survs = estimator.predict_survival_function(Xy_test[cols_x])

In [None]:
# Brier score at a single horizon T (364 / 2 days).
T = 364 / 2

# Predicted survival probability at T for each test row.
preds = [surv_fn(T) for surv_fn in survs]

times, score = brier_score(y_train, y_test, preds, T)
score

In [None]:
# Integrated Brier score over a daily grid from 364 / 2 up to (but excluding) 365.
times = np.arange(364 / 2, 365)

# Rows are test samples, columns are time points; each step function is evaluated
# on the whole grid in one call.
preds = np.asarray([surv_fn(times) for surv_fn in survs])

score = integrated_brier_score(y_train, y_test, preds, times)
print(score)

## 3.3. Model fine-tuning

In [None]:
# Name of the time-to-event column passed to `grid_search` in the tuning cells below.
col_target = "time_before_readm"

In [None]:
from sklearn.model_selection import KFold

# 5-fold shuffled cross-validation with a fixed seed.
cv = KFold(n_splits=5, shuffle=True, random_state=random_state)

# Candidate regularisation strengths for the Cox model.
grid_params = dict(alpha=[0.3, 0.6, 1, 1.2])

# NOTE(review): `grid_search` is a project helper; it receives the full `df`, so
# any scaling or row filtering presumably happens inside it — confirm.
estimator_cox, results = grid_search(
    grid_params,
    df,
    cv,
    CoxPHSurvivalAnalysis,
    cols_x,
    col_target,
    verbose=True,
)

In [None]:
# Display the second value returned by the Cox grid search.
results

# 4. Gradient Boosting Survival Analysis

In [None]:
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

# Search space for the gradient-boosting model; single-value lists pin a setting.
grid_params = dict(
    n_estimators=[100, 200, 300],
    min_samples_leaf=[2, 3],
    random_state=[random_state],
    verbose=[1],
)

estimator_gb, results = grid_search(
    grid_params,
    df,
    cv,
    GradientBoostingSurvivalAnalysis,
    cols_x,
    col_target,
    verbose=True,
)

results

In [None]:
# Plot the gradient-boosting feature importances against their covariate names.
# NOTE(review): `plot_feat_imp` is a project helper not defined in this notebook;
# it presumably returns (table, figure) — confirm against its definition.
feat_importance_gb, fig = plot_feat_imp(cols_x, estimator_gb.feature_importances_)
fig

# 5. Survival Support Vector Machine

In [None]:
from sksurv.svm import FastSurvivalSVM 

In [None]:
from sksurv.svm import FastSurvivalSVM

# Search space for the linear survival SVM: only `alpha` varies, the remaining
# settings are pinned through single-value lists.
grid_params = dict(
    alpha=[1, 2, 5, 10],
    rank_ratio=[0],
    max_iter=[1000],
    tol=[1e-5],
    random_state=[random_state],
    verbose=[0],
)

estimator_svm, results = grid_search(
    grid_params,
    df,
    cv,
    FastSurvivalSVM,
    cols_x,
    col_target,
    verbose=True,
)

results

In [None]:
from sksurv.svm import FastKernelSurvivalSVM

# Search space for the kernel survival SVM: only the kernel varies, the remaining
# settings are pinned through single-value lists.
grid_params = dict(
    kernel=["linear", "poly", "rbf", "sigmoid", "cosine"],
    alpha=[2],
    rank_ratio=[0],
    max_iter=[1000],
    tol=[1e-5],
    random_state=[random_state],
)

estimator_ksvm, results = grid_search(
    grid_params,
    df,
    cv,
    FastKernelSurvivalSVM,
    cols_x,
    col_target,
    verbose=True,
)

results