In [None]:
# autoreload: re-import edited modules before each cell executes, so changes to local
# code (presumably the `benchmark` module imported below lives next to this notebook)
# are picked up without restarting the kernel. Mode 2 reloads all modules every time.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import lifelines
import pycox
import xgboost as xgb
from sklearn.model_selection import train_test_split

from xgbse import XGBSEKaplanNeighbors, XGBSEKaplanTree, XGBSEDebiasedBCE, XGBSEBootstrapEstimator
from xgbse.converters import convert_data_to_xgb_format, convert_to_structured, convert_y
from xgbse.non_parametric import get_time_bins

from benchmark import BenchmarkLifelines, BenchmarkXGBoost, BenchmarkXGBSE, BenchmarkPysurvival
from pysurvival.models.survival_forest import ConditionalSurvivalForestModel

# Fix NumPy's legacy global RNG so any np.random-based randomness is repeatable.
# Note: split_train_test_valid passes its own random_state to the splitter, and the
# model libraries may keep separate RNG state — TODO confirm each model is seeded.
np.random.seed(seed=42)

In [3]:
def split_train_test_valid(dataf, test_size=.2, valid_size=.01, random_state=1):
    """Split ``dataf`` into train, validation and test partitions.

    The test partition is carved out first as ``test_size`` of ``dataf``. The
    validation partition is then ``valid_size`` of the *remaining* rows, not of
    the full frame. Both splits use the same ``random_state``.

    Returns:
        tuple: ``(df_train, df_valid, df_test)``.
    """
    remainder, test_part = train_test_split(
        dataf, test_size=test_size, random_state=random_state
    )
    train_part, valid_part = train_test_split(
        remainder, test_size=valid_size, random_state=random_state
    )
    return train_part, valid_part, test_part

## Data

In [4]:
from pycox.datasets import support

In [5]:
# Load the SUPPORT dataset via pycox.datasets and split it into train/valid/test.
support_df = support.read_df()
df_train, df_valid, df_test = split_train_test_valid(support_df)

# Evaluation time grid (size=30) derived from the training split's durations and events only.
T_train, E_train = df_train["duration"], df_train["event"]
TIME_BINS = get_time_bins(T_train, E_train, size=30)
time_bins = TIME_BINS  # lower-case alias kept for any later cells that refer to it

## Model

We fit several survival models and compare their performance on the held-out test split.

In [6]:
# Accumulator for the per-model metric dicts returned by each benchmark's .test();
# every model section below adds one row to it.
results = pd.DataFrame()

## Lifelines

In [7]:

# Cox proportional-hazards baseline (lifelines).
# Positional arguments, presumably: estimator, train/valid/test frames, event column,
# duration column, evaluation time grid, display name — confirm against benchmark.py.
bench_lifelines = BenchmarkLifelines(
    lifelines.CoxPHFitter(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "Cox-PH",
)

In [8]:
# Fit the Cox-PH benchmark; the dict returned by test() includes a 'training_time' entry.
bench_lifelines.train()

In [9]:
# Evaluate (presumably on the test split passed at construction); the returned metrics
# dict is displayed and later appended to `results`.
cox_lifelines_results = bench_lifelines.test()
cox_lifelines_results

{'model': 'Cox-PH',
 'c-index': 0.5780217892329218,
 'ibs': 0.2013043483549938,
 'dcal_pval': 0.0,
 'dcal_max_dev': 0.16042116729003866,
 'training_time': 0.5187549591064453,
 'inference_time': 0.009529829025268555}

In [10]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([cox_lifelines_results])], ignore_index=True)

In [11]:
# Weibull accelerated-failure-time baseline (lifelines), same argument layout as Cox-PH.
bench_lifelines = BenchmarkLifelines(
    lifelines.WeibullAFTFitter(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "Weibull AFT",
)

In [12]:
# Fit, then evaluate; the returned metrics dict is displayed and later appended to `results`.
bench_lifelines.train()
weibull_aft_results = bench_lifelines.test()
weibull_aft_results

{'model': 'Weibull AFT',
 'c-index': 0.5764595917903831,
 'ibs': 0.20147880877502886,
 'dcal_pval': 5.912690299687692e-252,
 'dcal_max_dev': 0.13831859544149858,
 'training_time': 0.6235389709472656,
 'inference_time': 0.008682012557983398}

In [13]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([weibull_aft_results])], ignore_index=True)

## XGBoost

In [14]:
# Plain XGBoost with the AFT objective. The `xgb` module itself is passed in; the trailing
# string is presumably forwarded to XGBoost as its objective name.
bench_xgboost = BenchmarkXGBoost(
    xgb,
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGB - AFT",
    "survival:aft",
)

In [15]:
# Fit and evaluate. The ibs and dcal_* metrics come back as NaN for this model (see output),
# presumably because it yields risk scores rather than full survival curves.
bench_xgboost.train()
xgboost_aft_results = bench_xgboost.test()
xgboost_aft_results

{'model': 'XGB - AFT',
 'c-index': 0.6118077667273784,
 'ibs': nan,
 'dcal_pval': nan,
 'dcal_max_dev': nan,
 'training_time': 0.05042719841003418,
 'inference_time': 0.0008549690246582031}

In [16]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_aft_results])], ignore_index=True)

In [17]:
# Plain XGBoost with the Cox objective; constructed, fitted and evaluated in one cell.
bench_xgboost = BenchmarkXGBoost(
    xgb,
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGB - Cox",
    "survival:cox",
)
bench_xgboost.train()
xgboost_cox_results = bench_xgboost.test()
xgboost_cox_results

{'model': 'XGB - Cox',
 'c-index': 0.61197556100068,
 'ibs': nan,
 'dcal_pval': nan,
 'dcal_max_dev': nan,
 'training_time': 0.13672184944152832,
 'inference_time': 0.0010209083557128906}

In [18]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_cox_results])], ignore_index=True)

## XGBSE

In [19]:
# XGBSE Debiased BCE estimator; the final argument is the XGBoost objective string
# handed to the benchmark wrapper.
bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEDebiasedBCE(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGBSE - Debiased BCE",
    "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_bce_results = bench_xgboost_embedding.test()
xgboost_bce_results

{'model': 'XGBSE - Debiased BCE',
 'c-index': 0.607062937775623,
 'ibs': 0.189858717682179,
 'dcal_pval': 8.767402398701876e-189,
 'dcal_max_dev': 0.11941058998551468,
 'training_time': 62.01675295829773,
 'inference_time': 1.2210862636566162}

In [20]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_bce_results])], ignore_index=True)

In [21]:
# XGBSE Kaplan Neighbors estimator, same argument layout as the Debiased BCE cell.
bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEKaplanNeighbors(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGBSE - Kaplan Neighbors",
    "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_kn_results = bench_xgboost_embedding.test()
xgboost_kn_results

{'model': 'XGBSE - Kaplan Neighbors',
 'c-index': 0.5777793792849416,
 'ibs': 0.20156917507339067,
 'dcal_pval': 6.99924919299361e-176,
 'dcal_max_dev': 0.11043982034195823,
 'training_time': 49.236324310302734,
 'inference_time': 8.93327784538269}

In [22]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_kn_results])], ignore_index=True)

In [23]:
# XGBSE Kaplan Tree estimator (single tree), same argument layout as the other XGBSE cells.
bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEKaplanTree(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGBSE - Kaplan Tree",
    "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_kt_results = bench_xgboost_embedding.test()
xgboost_kt_results

{'model': 'XGBSE - Kaplan Tree',
 'c-index': 0.5976060379731908,
 'ibs': 0.20288464464793987,
 'dcal_pval': 1.360307803727658e-142,
 'dcal_max_dev': 0.0965224071702945,
 'training_time': 0.19405484199523926,
 'inference_time': 0.005036115646362305}

In [24]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_kt_results])], ignore_index=True)

In [25]:
# Bootstrap ensemble of 100 XGBSEKaplanTree base estimators.
base_tree = XGBSEKaplanTree()

bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEBootstrapEstimator(base_tree, n_estimators=100),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGBSE - Bootstrap Trees",
    "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_bootstrap_results = bench_xgboost_embedding.test()
xgboost_bootstrap_results

{'model': 'XGBSE - Bootstrap Trees',
 'c-index': 0.6072864208207579,
 'ibs': 0.18781453318160898,
 'dcal_pval': 2.711414876542169e-197,
 'dcal_max_dev': 0.10253653541171698,
 'training_time': 19.81351399421692,
 'inference_time': 0.5238931179046631}

In [26]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([xgboost_bootstrap_results])], ignore_index=True)

## PySurvival

In [27]:
# Conditional survival forest from pysurvival; this benchmark takes no objective string.
bench_pysurvival = BenchmarkPysurvival(
    ConditionalSurvivalForestModel(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "Conditional Survival Forest",
)
bench_pysurvival.train()
pysurvival_results = bench_pysurvival.test()
pysurvival_results

{'model': 'Conditional Survival Forest',
 'c-index': 0.5947109497956622,
 'ibs': 0.1946863785642482,
 'dcal_pval': 0.0,
 'dcal_max_dev': 0.16581709491434668,
 'training_time': 2.6257541179656982,
 'inference_time': 115.48593997955322}

In [28]:
# DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
results = pd.concat([results, pd.DataFrame([pysurvival_results])], ignore_index=True)

In [29]:
# Leaderboard: best C-index first. Sort happens before rounding so the order is unaffected by ties.
(
    results
    .sort_values(by="c-index", ascending=False)
    .round(decimals=3)
)

Unnamed: 0,c-index,dcal_max_dev,dcal_pval,ibs,inference_time,model,training_time
3,0.612,,,,0.001,XGB - Cox,0.137
2,0.612,,,,0.001,XGB - AFT,0.05
7,0.607,0.103,0.0,0.188,0.524,XGBSE - Bootstrap Trees,19.814
4,0.607,0.119,0.0,0.19,1.221,XGBSE - Debiased BCE,62.017
6,0.598,0.097,0.0,0.203,0.005,XGBSE - Kaplan Tree,0.194
8,0.595,0.166,0.0,0.195,115.486,Conditional Survival Forest,2.626
0,0.578,0.16,0.0,0.201,0.01,Cox-PH,0.519
5,0.578,0.11,0.0,0.202,8.933,XGBSE - Kaplan Neighbors,49.236
1,0.576,0.138,0.0,0.201,0.009,Weibull AFT,0.624


In [30]:
# to_markdown() returns a str. Printing it renders a copy-pasteable Markdown table instead
# of an escaped repr full of literal "\n". Requires the optional `tabulate` package.
print(results.sort_values("c-index", ascending=False).round(3).set_index('model').to_markdown())

'| model                       |   c-index |   dcal_max_dev |   dcal_pval |     ibs |   inference_time |   training_time |\n|:----------------------------|----------:|---------------:|------------:|--------:|-----------------:|----------------:|\n| XGB - Cox                   |     0.612 |        nan     |         nan | nan     |            0.001 |           0.137 |\n| XGB - AFT                   |     0.612 |        nan     |         nan | nan     |            0.001 |           0.05  |\n| XGBSE - Bootstrap Trees     |     0.607 |          0.103 |           0 |   0.188 |            0.524 |          19.814 |\n| XGBSE - Debiased BCE        |     0.607 |          0.119 |           0 |   0.19  |            1.221 |          62.017 |\n| XGBSE - Kaplan Tree         |     0.598 |          0.097 |           0 |   0.203 |            0.005 |           0.194 |\n| Conditional Survival Forest |     0.595 |          0.166 |           0 |   0.195 |          115.486 |           2.626 |\n| Cox-PH       