In [1]:
# Enable IPython's autoreload extension so edits to imported local modules
# (e.g. the `benchmark` helpers imported below) are picked up without a kernel
# restart; mode 2 reloads all modules before each cell is executed.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import lifelines
import pycox
import xgboost as xgb
from sklearn.model_selection import train_test_split

from xgbse import XGBSEKaplanNeighbors, XGBSEKaplanTree, XGBSEDebiasedBCE, XGBSEBootstrapEstimator
from xgbse.converters import convert_data_to_xgb_format, convert_to_structured, convert_y
from xgbse.non_parametric import get_time_bins

from benchmark import BenchmarkLifelines, BenchmarkXGBoost, BenchmarkXGBSE, BenchmarkPysurvival
from pysurvival.models.survival_forest import ConditionalSurvivalForestModel

# Seed NumPy's legacy global RNG so anything drawing from `np.random` is reproducible.
# (The train/valid/test split has its own explicit `random_state`.)
np.random.seed(seed=42)

In [3]:
def split_train_test_valid(dataf, test_size=.2, valid_size=.01, random_state=1):
    """Split a frame into train, validation and test parts.

    Two successive ``train_test_split`` calls are made with the same
    ``random_state``:

    1. ``test_size`` is carved off ``dataf`` as the test set.
    2. ``valid_size`` is carved off the *remaining* rows as the validation set,
       so with floats the validation share of ``dataf`` is
       ``(1 - test_size) * valid_size`` (0.8% with the defaults).

    Returns
    -------
    (df_train, df_valid, df_test)
    """
    non_test, test_part = train_test_split(
        dataf, test_size=test_size, random_state=random_state
    )
    train_part, valid_part = train_test_split(
        non_test, test_size=valid_size, random_state=random_state
    )
    return train_part, valid_part, test_part

## Data

In [4]:
from pycox.datasets import sac3

# Drop the `*_true` columns (presumably the simulator's latent ground truth —
# see pycox docs) so they cannot leak into the feature set; only the observed
# `duration` / `event` columns remain as targets.
sac3_df = sac3.read_df().drop(columns=["duration_true", "event_true", "censoring_true"])

df_train, df_valid, df_test = split_train_test_valid(sac3_df)

# Time grid shared by every benchmark, derived from the training targets only.
T_train, E_train = df_train.duration, df_train.event
TIME_BINS = time_bins = get_time_bins(T_train, E_train, size=30)

## Model

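We fit each survival model on the same train/validation/test split and compare their test-set performance: C-index, integrated Brier score (IBS), D-calibration, and training/inference time.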

In [5]:
# Leaderboard: one row of metrics is added per benchmarked model.
results = pd.DataFrame(data=None)

## Lifelines

In [6]:

# Wrap lifelines' Cox proportional-hazards fitter in the shared benchmark harness:
# model, train/valid/test frames, event & duration column names, time bins, and
# the display name reported under `model` in the results.
bench_lifelines = BenchmarkLifelines(
    lifelines.CoxPHFitter(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "Cox-PH",
)

In [7]:
# Fit the wrapped Cox-PH model using the splits handed to the benchmark object.
bench_lifelines.train()

In [8]:
# Compute the metrics dict (c-index, ibs, D-calibration, timings) and display it.
cox_lifelines_results = bench_lifelines.test()
cox_lifelines_results

{'model': 'Cox-PH',
 'c-index': 0.6821608577751815,
 'ibs': 0.16538314664366988,
 'dcal_pval': 2.960248829034621e-63,
 'dcal_max_dev': 0.03529930427043007,
 'training_time': 2.137861967086792,
 'inference_time': 0.036772727966308594}

In [9]:
# DataFrame.append was removed in pandas 2.0; pd.concat adds this model's metrics as a new row.
results = pd.concat([results, pd.DataFrame([cox_lifelines_results])], ignore_index=True)

In [10]:
# Same harness, now wrapping lifelines' Weibull accelerated-failure-time fitter.
bench_lifelines = BenchmarkLifelines(
    lifelines.WeibullAFTFitter(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "Weibull AFT",
)

In [11]:
# Fit the Weibull AFT model, then compute and display its metrics dict.
bench_lifelines.train()
weibull_aft_results = bench_lifelines.test()
weibull_aft_results

{'model': 'Weibull AFT',
 'c-index': 0.682028716612183,
 'ibs': 0.16549935329896318,
 'dcal_pval': 8.433756927102771e-82,
 'dcal_max_dev': 0.03897261022924145,
 'training_time': 2.0566418170928955,
 'inference_time': 0.03670310974121094}

In [12]:
# DataFrame.append was removed in pandas 2.0; pd.concat adds this model's metrics as a new row.
results = pd.concat([results, pd.DataFrame([weibull_aft_results])], ignore_index=True)

## XGBoost

In [13]:
# Plain XGBoost baseline; the last argument selects the `survival:aft` objective.
bench_xgboost = BenchmarkXGBoost(
    xgb,
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGB - AFT",
    "survival:aft",
)

In [14]:
# Fit the XGBoost AFT baseline, then compute and display its metrics dict.
# NOTE(review): the recorded output has NaN for ibs/D-calibration, presumably
# because this baseline yields no survival curves — confirm in `benchmark`.
bench_xgboost.train()
xgboost_aft_results = bench_xgboost.test()
xgboost_aft_results

{'model': 'XGB - AFT',
 'c-index': 0.6714661818988993,
 'ibs': nan,
 'dcal_pval': nan,
 'dcal_max_dev': nan,
 'training_time': 0.7346310615539551,
 'inference_time': 0.003381967544555664}

In [15]:
# DataFrame.append was removed in pandas 2.0; pd.concat adds this model's metrics as a new row.
results = pd.concat([results, pd.DataFrame([xgboost_aft_results])], ignore_index=True)

In [16]:
# XGBoost baseline again, this time with the `survival:cox` objective.
bench_xgboost = BenchmarkXGBoost(
    xgb,
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGB - Cox",
    "survival:cox",
)
bench_xgboost.train()
xgboost_cox_results = bench_xgboost.test()
xgboost_cox_results

{'model': 'XGB - Cox',
 'c-index': 0.6700909312669946,
 'ibs': nan,
 'dcal_pval': nan,
 'dcal_max_dev': nan,
 'training_time': 0.9130971431732178,
 'inference_time': 0.002251148223876953}

In [17]:
# DataFrame.append was removed in pandas 2.0; pd.concat adds this model's metrics as a new row.
results = pd.concat([results, pd.DataFrame([xgboost_cox_results])], ignore_index=True)

## XGBSE

In [18]:
# XGBSE Debiased BCE model; the trailing objective string is forwarded to the
# benchmark (presumably for the underlying XGBoost model — confirm in `benchmark`).
bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEDebiasedBCE(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGBSE - Debiased BCE",
    "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_bce_results = bench_xgboost_embedding.test()
xgboost_bce_results

{'model': 'XGBSE - Debiased BCE',
 'c-index': 0.6992551992108598,
 'ibs': 0.1622480403343549,
 'dcal_pval': 1.174853878363636e-100,
 'dcal_max_dev': 0.03820920790971053,
 'training_time': 689.0448880195618,
 'inference_time': 4.066883087158203}

In [19]:
# DataFrame.append was removed in pandas 2.0; pd.concat adds this model's metrics as a new row.
results = pd.concat([results, pd.DataFrame([xgboost_bce_results])], ignore_index=True)

In [20]:
# XGBSE Kaplan Neighbors model (slowest at inference in the recorded run).
bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEKaplanNeighbors(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGBSE - Kaplan Neighbors",
    "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_kn_results = bench_xgboost_embedding.test()
xgboost_kn_results

{'model': 'XGBSE - Kaplan Neighbors',
 'c-index': 0.6312882215825405,
 'ibs': 0.18583668971044806,
 'dcal_pval': 1.168283462662509e-69,
 'dcal_max_dev': 0.03795180577534555,
 'training_time': 531.032201051712,
 'inference_time': 1318.8425133228302}

In [21]:
# DataFrame.append was removed in pandas 2.0; pd.concat adds this model's metrics as a new row.
results = pd.concat([results, pd.DataFrame([xgboost_kn_results])], ignore_index=True)

In [22]:
# XGBSE Kaplan Tree model: a single tree with Kaplan-Meier estimates per leaf.
bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEKaplanTree(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGBSE - Kaplan Tree",
    "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_kt_results = bench_xgboost_embedding.test()
xgboost_kt_results

{'model': 'XGBSE - Kaplan Tree',
 'c-index': 0.6309520608231479,
 'ibs': 0.1914181890695273,
 'dcal_pval': 6.061067356778973e-56,
 'dcal_max_dev': 0.033688866155989705,
 'training_time': 2.7165441513061523,
 'inference_time': 0.06923484802246094}

In [23]:
# DataFrame.append was removed in pandas 2.0; pd.concat adds this model's metrics as a new row.
results = pd.concat([results, pd.DataFrame([xgboost_kt_results])], ignore_index=True)

In [24]:
# Bootstrap ensemble of 100 models built from the `base_tree` Kaplan Tree template.
base_tree = XGBSEKaplanTree()
bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEBootstrapEstimator(base_tree, n_estimators=100),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "XGBSE - Bootstrap Trees",
    "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_bootstrap_results = bench_xgboost_embedding.test()
xgboost_bootstrap_results

{'model': 'XGBSE - Bootstrap Trees',
 'c-index': 0.6769903585433326,
 'ibs': 0.16786412434551737,
 'dcal_pval': 1.6511792922848247e-98,
 'dcal_max_dev': 0.043015745266438726,
 'training_time': 220.7722568511963,
 'inference_time': 3.1344659328460693}

In [25]:
# DataFrame.append was removed in pandas 2.0; pd.concat adds this model's metrics as a new row.
results = pd.concat([results, pd.DataFrame([xgboost_bootstrap_results])], ignore_index=True)

## PySurvival

In [26]:
# PySurvival's Conditional Survival Forest as a non-XGBoost tree-ensemble reference.
bench_pysurvival = BenchmarkPysurvival(
    ConditionalSurvivalForestModel(),
    df_train, df_valid, df_test,
    "event", "duration",
    TIME_BINS,
    "Conditional Survival Forest",
)
bench_pysurvival.train()
pysurvival_results = bench_pysurvival.test()
pysurvival_results

{'model': 'Conditional Survival Forest',
 'c-index': 0.6220749239105159,
 'ibs': 0.18715605994364778,
 'dcal_pval': 7.800830276383053e-98,
 'dcal_max_dev': 0.04440325562340154,
 'training_time': 837.0685200691223,
 'inference_time': 691.9388620853424}

In [27]:
# DataFrame.append was removed in pandas 2.0; pd.concat adds this model's metrics as a new row.
results = pd.concat([results, pd.DataFrame([pysurvival_results])], ignore_index=True)

In [28]:
# Leaderboard: highest concordance first, all numeric metrics rounded for display.
results.sort_values(by="c-index", ascending=False).round(decimals=3)

Unnamed: 0,c-index,dcal_max_dev,dcal_pval,ibs,inference_time,model,training_time
4,0.699,0.038,0.0,0.162,4.067,XGBSE - Debiased BCE,689.045
0,0.682,0.035,0.0,0.165,0.037,Cox-PH,2.138
1,0.682,0.039,0.0,0.165,0.037,Weibull AFT,2.057
7,0.677,0.043,0.0,0.168,3.134,XGBSE - Bootstrap Trees,220.772
2,0.671,,,,0.003,XGB - AFT,0.735
3,0.67,,,,0.002,XGB - Cox,0.913
5,0.631,0.038,0.0,0.186,1318.843,XGBSE - Kaplan Neighbors,531.032
6,0.631,0.034,0.0,0.191,0.069,XGBSE - Kaplan Tree,2.717
8,0.622,0.044,0.0,0.187,691.939,Conditional Survival Forest,837.069


In [29]:
# Render the leaderboard as a Markdown table (e.g. for a README). print() shows the
# table itself; a bare string would display as an escaped, truncated repr.
print(
    results
    .sort_values("c-index", ascending=False)
    .round(3)
    .set_index("model")
    .to_markdown()
)

'| model                       |   c-index |   dcal_max_dev |   dcal_pval |     ibs |   inference_time |   training_time |\n|:----------------------------|----------:|---------------:|------------:|--------:|-----------------:|----------------:|\n| XGBSE - Debiased BCE        |     0.699 |          0.038 |           0 |   0.162 |            4.067 |         689.045 |\n| Cox-PH                      |     0.682 |          0.035 |           0 |   0.165 |            0.037 |           2.138 |\n| Weibull AFT                 |     0.682 |          0.039 |           0 |   0.165 |            0.037 |           2.057 |\n| XGBSE - Bootstrap Trees     |     0.677 |          0.043 |           0 |   0.168 |            3.134 |         220.772 |\n| XGB - AFT                   |     0.671 |        nan     |         nan | nan     |            0.003 |           0.735 |\n| XGB - Cox                   |     0.67  |        nan     |         nan | nan     |            0.002 |           0.913 |\n| XGBSE - Kapla