In [1]:
# autoreload: re-import edited modules (such as the `benchmark` helpers imported
# below) before each cell executes, so edits to .py files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import lifelines
import pycox
import xgboost as xgb
from sklearn.model_selection import train_test_split

from xgbse import XGBSEKaplanNeighbors, XGBSEKaplanTree, XGBSEDebiasedBCE, XGBSEBootstrapEstimator
from xgbse.converters import convert_data_to_xgb_format, convert_to_structured, convert_y
from xgbse.non_parametric import get_time_bins

from benchmark import BenchmarkLifelines, BenchmarkXGBoost, BenchmarkXGBSE, BenchmarkPysurvival
from pysurvival.models.survival_forest import ConditionalSurvivalForestModel

# Seed NumPy's legacy global RNG for any code that draws from it.
# NOTE: split_train_test_valid passes its own `random_state`, so the data split
# does not depend on this seed.
SEED = 42
np.random.seed(SEED)

In [3]:
def split_train_test_valid(dataf, test_size=.2, valid_size=.01, random_state=1):
    """Split ``dataf`` into train, validation and test frames.

    The test set is carved off first and holds ``test_size`` of all rows. The
    validation set is then carved off what remains, so ``valid_size`` is a
    fraction of that post-test training portion, not of the full frame.
    Both splits reuse ``random_state``, so repeated calls give the same result.

    Returns:
        tuple: ``(df_train, df_valid, df_test)``.
    """
    train_and_valid, test_part = train_test_split(
        dataf, test_size=test_size, random_state=random_state
    )
    train_part, valid_part = train_test_split(
        train_and_valid, test_size=valid_size, random_state=random_state
    )
    return train_part, valid_part, test_part

## Data

In [4]:
from pycox.datasets import flchain

# Load FLCHAIN and rename its target columns to the "event"/"duration" names
# that every benchmark below is given.
df = (
    flchain.read_df()
    .rename(columns={'death': 'event', 'futime': 'duration'})
    .astype(float)
)

df_train, df_valid, df_test = split_train_test_valid(df)

# Evaluation time grid built from the training split only
# (size=30 is presumably the number of time points; see xgbse get_time_bins).
T_train, E_train = df_train.duration, df_train.event
TIME_BINS = time_bins = get_time_bins(T_train, E_train, size=30)

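## Models

Fit each candidate model and collect its evaluation metrics into a single `results` table for comparison: C-index, integrated Brier score, D-calibration p-value and maximum deviation, and training and inference time.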

In [5]:
# Leaderboard accumulator: each benchmark cell below adds one row of metrics.
results = pd.DataFrame(data=None)

## Lifelines

In [6]:
# Cox proportional-hazards baseline from lifelines.
# Positional args shared by every benchmark: splits, event column, duration column, time grid.
shared_args = (df_train, df_valid, df_test, "event", "duration", TIME_BINS)
bench_lifelines = BenchmarkLifelines(lifelines.CoxPHFitter(), *shared_args, "Cox-PH")

In [7]:
# Fit the wrapped CoxPHFitter (presumably on df_train; see BenchmarkLifelines.train).
bench_lifelines.train()

In [8]:
# Evaluate the fitted model (presumably on df_test). The output below shows the
# returned dict of discrimination/calibration metrics and timings.
cox_lifelines_results = bench_lifelines.test()
cox_lifelines_results

{'model': 'Cox-PH',
 'c-index': 0.7883798521392772,
 'ibs': 0.0985487804655452,
 'dcal_pval': 0.9711460792177572,
 'dcal_max_dev': 0.0113195501393272,
 'training_time': 1.192168951034546,
 'inference_time': 0.006709098815917969}

In [9]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([cox_lifelines_results])], ignore_index=True)

In [10]:
# lifelines' parametric fitters reject non-positive durations, so zero follow-up
# times in the training split are bumped to 1. The valid/test splits are passed unchanged.
weibull_train = df_train.assign(
    duration=df_train.duration.mask(df_train.duration == 0, other=1)
)
shared_args = (weibull_train, df_valid, df_test, "event", "duration", TIME_BINS)
bench_lifelines = BenchmarkLifelines(lifelines.WeibullAFTFitter(), *shared_args, "Weibull AFT")

In [11]:
# Fit the Weibull AFT model, then evaluate it. The output below shows the
# returned dict of metrics and timings.
bench_lifelines.train()
weibull_aft_results = bench_lifelines.test()
weibull_aft_results

{'model': 'Weibull AFT',
 'c-index': 0.788764446796106,
 'ibs': 0.09899526138314266,
 'dcal_pval': 0.8485679576992691,
 'dcal_max_dev': 0.0130897512621993,
 'training_time': 0.8402612209320068,
 'inference_time': 0.010332822799682617}

In [12]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([weibull_aft_results])], ignore_index=True)

## XGBoost

In [13]:
# Plain XGBoost with the accelerated-failure-time objective. The xgb module
# itself is handed to the benchmark helper.
shared_args = (df_train, df_valid, df_test, "event", "duration", TIME_BINS)
bench_xgboost = BenchmarkXGBoost(xgb, *shared_args, "XGB - AFT", "survival:aft")

In [14]:
# Fit and evaluate the XGBoost AFT baseline. IBS and D-calibration come back NaN
# in the output below, presumably because this model yields no survival curve
# for the benchmark to score. TODO: confirm in BenchmarkXGBoost.
bench_xgboost.train()
xgboost_aft_results = bench_xgboost.test()
xgboost_aft_results

{'model': 'XGB - AFT',
 'c-index': 0.7815428859787077,
 'ibs': nan,
 'dcal_pval': nan,
 'dcal_max_dev': nan,
 'training_time': 0.06519603729248047,
 'inference_time': 0.0008289813995361328}

In [15]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_aft_results])], ignore_index=True)

In [16]:
# The same XGBoost baseline, now with the Cox partial-likelihood objective.
shared_args = (df_train, df_valid, df_test, "event", "duration", TIME_BINS)
bench_xgboost = BenchmarkXGBoost(xgb, *shared_args, "XGB - Cox", "survival:cox")

bench_xgboost.train()
xgboost_cox_results = bench_xgboost.test()
xgboost_cox_results

{'model': 'XGB - Cox',
 'c-index': 0.7786192316925594,
 'ibs': nan,
 'dcal_pval': nan,
 'dcal_max_dev': nan,
 'training_time': 0.0846090316772461,
 'inference_time': 0.0008230209350585938}

In [17]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_cox_results])], ignore_index=True)

## XGBSE

In [18]:
# XGBSE debiased-BCE estimator. The trailing "survival:aft" string is forwarded
# to the benchmark helper (presumably the objective of the underlying XGBoost model).
shared_args = (df_train, df_valid, df_test, "event", "duration", TIME_BINS)
bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEDebiasedBCE(), *shared_args, "XGBSE - Debiased BCE", "survival:aft"
)

bench_xgboost_embedding.train()
xgboost_bce_results = bench_xgboost_embedding.test()
xgboost_bce_results

{'model': 'XGBSE - Debiased BCE',
 'c-index': 0.7841321633816893,
 'ibs': 0.1013538111648455,
 'dcal_pval': 0.03644799466710443,
 'dcal_max_dev': 0.0298175545947673,
 'training_time': 46.155426025390625,
 'inference_time': 0.47022414207458496}

In [19]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_bce_results])], ignore_index=True)

In [20]:
# XGBSE Kaplan Neighbors. Note the multi-second inference time in the output below.
shared_args = (df_train, df_valid, df_test, "event", "duration", TIME_BINS)
bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEKaplanNeighbors(), *shared_args, "XGBSE - Kaplan Neighbors", "survival:aft"
)

bench_xgboost_embedding.train()
xgboost_kn_results = bench_xgboost_embedding.test()
xgboost_kn_results

{'model': 'XGBSE - Kaplan Neighbors',
 'c-index': 0.7694232549936064,
 'ibs': 0.10297238924307031,
 'dcal_pval': 0.7320452038855624,
 'dcal_max_dev': 0.019724969746333948,
 'training_time': 31.365705966949463,
 'inference_time': 5.807218074798584}

In [21]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_kn_results])], ignore_index=True)

In [22]:
# XGBSE Kaplan Tree: a single tree, hence the small training/inference times below.
shared_args = (df_train, df_valid, df_test, "event", "duration", TIME_BINS)
bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEKaplanTree(), *shared_args, "XGBSE - Kaplan Tree", "survival:aft"
)

bench_xgboost_embedding.train()
xgboost_kt_results = bench_xgboost_embedding.test()
xgboost_kt_results

{'model': 'XGBSE - Kaplan Tree',
 'c-index': 0.767830984121385,
 'ibs': 0.10346911400864475,
 'dcal_pval': 0.9287914257772109,
 'dcal_max_dev': 0.01106687028007211,
 'training_time': 0.21167683601379395,
 'inference_time': 0.0029299259185791016}

In [23]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_kt_results])], ignore_index=True)

In [24]:
# Bootstrap ensemble of 100 XGBSE Kaplan Trees.
base_tree = XGBSEKaplanTree()
bootstrap_model = XGBSEBootstrapEstimator(base_tree, n_estimators=100)

shared_args = (df_train, df_valid, df_test, "event", "duration", TIME_BINS)
bench_xgboost_embedding = BenchmarkXGBSE(
    bootstrap_model, *shared_args, "XGBSE - Bootstrap Trees", "survival:aft"
)

bench_xgboost_embedding.train()
xgboost_bootstrap_results = bench_xgboost_embedding.test()
xgboost_bootstrap_results

{'model': 'XGBSE - Bootstrap Trees',
 'c-index': 0.7814485745501223,
 'ibs': 0.10047335035894789,
 'dcal_pval': 0.9845121854944557,
 'dcal_max_dev': 0.009485243521309514,
 'training_time': 17.497756004333496,
 'inference_time': 0.4245340824127197}

In [25]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_bootstrap_results])], ignore_index=True)

## PySurvival

In [26]:
# PySurvival Conditional Survival Forest. Inference dominates its runtime
# (over 100 s in the output below).
shared_args = (df_train, df_valid, df_test, "event", "duration", TIME_BINS)
bench_pysurvival = BenchmarkPysurvival(
    ConditionalSurvivalForestModel(), *shared_args, "Conditional Survival Forest"
)

bench_pysurvival.train()
pysurvival_results = bench_pysurvival.test()
pysurvival_results

{'model': 'Conditional Survival Forest',
 'c-index': 0.7614202566250716,
 'ibs': 0.10590986997908104,
 'dcal_pval': 0.03091278345062742,
 'dcal_max_dev': 0.029681239567899267,
 'training_time': 1.1027400493621826,
 'inference_time': 109.55284905433655}

In [27]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([pysurvival_results])], ignore_index=True)

In [28]:
# Leaderboard ordered by discrimination (C-index, best first), rounded for display.
leaderboard = results.sort_values(by="c-index", ascending=False)
leaderboard.round(decimals=3)

Unnamed: 0,c-index,dcal_max_dev,dcal_pval,ibs,inference_time,model,training_time
1,0.789,0.013,0.849,0.099,0.01,Weibull AFT,0.84
0,0.788,0.011,0.971,0.099,0.007,Cox-PH,1.192
4,0.784,0.03,0.036,0.101,0.47,XGBSE - Debiased BCE,46.155
2,0.782,,,,0.001,XGB - AFT,0.065
7,0.781,0.009,0.985,0.1,0.425,XGBSE - Bootstrap Trees,17.498
3,0.779,,,,0.001,XGB - Cox,0.085
5,0.769,0.02,0.732,0.103,5.807,XGBSE - Kaplan Neighbors,31.366
6,0.768,0.011,0.929,0.103,0.003,XGBSE - Kaplan Tree,0.212
8,0.761,0.03,0.031,0.106,109.553,Conditional Survival Forest,1.103


In [29]:
# `to_markdown` returns a str (it needs the optional `tabulate` package). Print it
# so the table renders as copy-pasteable Markdown rather than an escaped
# string repr full of "\n" sequences.
print(
    results.sort_values("c-index", ascending=False)
    .round(3)
    .set_index('model')
    .to_markdown()
)

'| model                       |   c-index |   dcal_max_dev |   dcal_pval |     ibs |   inference_time |   training_time |\n|:----------------------------|----------:|---------------:|------------:|--------:|-----------------:|----------------:|\n| Weibull AFT                 |     0.789 |          0.013 |       0.849 |   0.099 |            0.01  |           0.84  |\n| Cox-PH                      |     0.788 |          0.011 |       0.971 |   0.099 |            0.007 |           1.192 |\n| XGBSE - Debiased BCE        |     0.784 |          0.03  |       0.036 |   0.101 |            0.47  |          46.155 |\n| XGB - AFT                   |     0.782 |        nan     |     nan     | nan     |            0.001 |           0.065 |\n| XGBSE - Bootstrap Trees     |     0.781 |          0.009 |       0.985 |   0.1   |            0.425 |          17.498 |\n| XGB - Cox                   |     0.779 |        nan     |     nan     | nan     |            0.001 |           0.085 |\n| XGBSE - Kapla