In [1]:
# autoreload
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import lifelines
import pycox
import xgboost as xgb
from sklearn.model_selection import train_test_split

from xgbse import XGBSEKaplanNeighbors, XGBSEKaplanTree, XGBSEDebiasedBCE, XGBSEBootstrapEstimator
from xgbse.converters import convert_data_to_xgb_format, convert_to_structured, convert_y
from xgbse.non_parametric import get_time_bins

from benchmark import BenchmarkLifelines, BenchmarkXGBoost, BenchmarkXGBSE, BenchmarkPysurvival
from pysurvival.models.survival_forest import ConditionalSurvivalForestModel

# setting seed
np.random.seed(42)

In [3]:
def split_train_test_valid(dataf, test_size=.2, valid_size=.01, random_state=1):
    df_train, df_test = train_test_split(dataf, test_size=test_size, random_state=random_state)
    df_train, df_valid = train_test_split(df_train, test_size=valid_size, random_state=random_state)
    return df_train, df_valid,  df_test

## Data

In [4]:
from pycox.datasets import metabric
metabric_df = metabric.read_df()
df_train, df_valid, df_test = split_train_test_valid(metabric_df)

T_train = df_train.duration
E_train = df_train.event
TIME_BINS = time_bins = get_time_bins(T_train, E_train, size=30)

## Model

Let us fit a model and check performance.

In [5]:
results = pd.DataFrame()

## Lifelines

In [6]:

bench_lifelines = BenchmarkLifelines(lifelines.CoxPHFitter(),
        df_train,
        df_valid,
        df_test,
        "event",
        "duration",
        TIME_BINS,
        "Cox-PH"
)

In [7]:
bench_lifelines.train()

In [8]:
cox_lifelines_results = bench_lifelines.test()
cox_lifelines_results

{'model': 'Cox-PH',
 'c-index': 0.6216153910637942,
 'ibs': 0.15360476043227422,
 'dcal_pval': 0.5666124375648811,
 'dcal_max_dev': 0.025805768690018743,
 'training_time': 0.21692919731140137,
 'inference_time': 0.004318952560424805}

In [9]:
results = results.append(cox_lifelines_results, ignore_index=True)

In [10]:
bench_lifelines = BenchmarkLifelines(lifelines.WeibullAFTFitter(),
        df_train.assign(duration = np.where(df_train.duration == 0, 1, df_train.duration)),
        df_valid,
        df_test,
        "event",
        "duration",
        TIME_BINS,
        "Weibull AFT"
)

In [11]:
bench_lifelines.train()
weibull_aft_results = bench_lifelines.test()
weibull_aft_results

{'model': 'Weibull AFT',
 'c-index': 0.6221812389974013,
 'ibs': 0.15412227274195567,
 'dcal_pval': 0.666759786591344,
 'dcal_max_dev': 0.024245639416549558,
 'training_time': 0.3900749683380127,
 'inference_time': 0.00800180435180664}

In [12]:
results = results.append(weibull_aft_results, ignore_index=True)

## XGBoost

In [13]:
bench_xgboost = BenchmarkXGBoost(xgb,
        df_train,
        df_valid,
        df_test,
        "event",
        "duration",
        TIME_BINS,
        "XGB - AFT",
        "survival:aft",
)

In [14]:
bench_xgboost.train()
xgboost_aft_results = bench_xgboost.test()
xgboost_aft_results

{'model': 'XGB - AFT',
 'c-index': 0.6101726884064046,
 'ibs': nan,
 'dcal_pval': nan,
 'dcal_max_dev': nan,
 'training_time': 0.02436208724975586,
 'inference_time': 0.0007061958312988281}

In [15]:
results = results.append(xgboost_aft_results, ignore_index=True)

In [16]:
bench_xgboost = BenchmarkXGBoost(xgb,
        df_train,
        df_valid,
        df_test,
        "event",
        "duration",
        TIME_BINS,
        "XGB - Cox",
        "survival:cox",
)
bench_xgboost.train()
xgboost_cox_results = bench_xgboost.test()
xgboost_cox_results

{'model': 'XGB - Cox',
 'c-index': 0.6192157766786822,
 'ibs': nan,
 'dcal_pval': nan,
 'dcal_max_dev': nan,
 'training_time': 0.022723913192749023,
 'inference_time': 0.0006630420684814453}

In [17]:
results = results.append(xgboost_cox_results, ignore_index=True)

## XGB SE

In [18]:
bench_xgboost_embedding = BenchmarkXGBSE(XGBSEDebiasedBCE(),
        df_train,
        df_valid,
        df_test,
        "event",
        "duration",
        TIME_BINS,
        "XGBSE - Debiased BCE",
        "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_bce_results = bench_xgboost_embedding.test()
xgboost_bce_results

{'model': 'XGBSE - Debiased BCE',
 'c-index': 0.6322198004862101,
 'ibs': 0.1566039133167816,
 'dcal_pval': 0.3811202061981239,
 'dcal_max_dev': 0.03232402429331624,
 'training_time': 14.198437929153442,
 'inference_time': 0.36887192726135254}

In [19]:
results = results.append(xgboost_bce_results, ignore_index=True)

In [20]:
bench_xgboost_embedding = BenchmarkXGBSE(XGBSEKaplanNeighbors(),
        df_train,
        df_valid,
        df_test,
        "event",
        "duration",
        TIME_BINS,
        "XGBSE - Kaplan Neighbors",
        "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_kn_results = bench_xgboost_embedding.test()
xgboost_kn_results

{'model': 'XGBSE - Kaplan Neighbors',
 'c-index': 0.6269595104367508,
 'ibs': 0.15569091727108705,
 'dcal_pval': 0.52481525977227,
 'dcal_max_dev': 0.02351833078188803,
 'training_time': 25.679295301437378,
 'inference_time': 0.7914042472839355}

In [21]:
results = results.append(xgboost_kn_results, ignore_index=True)

In [22]:
bench_xgboost_embedding = BenchmarkXGBSE(XGBSEKaplanTree(),
        df_train,
        df_valid,
        df_test,
        "event",
        "duration",
        TIME_BINS,
        "XGBSE - Kaplan Tree",
        "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_kt_results = bench_xgboost_embedding.test()
xgboost_kt_results

{'model': 'XGBSE - Kaplan Tree',
 'c-index': 0.589791684131109,
 'ibs': 0.16482640216664288,
 'dcal_pval': 0.18030272928223431,
 'dcal_max_dev': 0.03586345646948624,
 'training_time': 0.1144261360168457,
 'inference_time': 0.01436305046081543}

In [23]:
results = results.append(xgboost_kt_results, ignore_index=True)

In [24]:
base_tree = XGBSEKaplanTree()

bench_xgboost_embedding = BenchmarkXGBSE(
    XGBSEBootstrapEstimator(base_tree, n_estimators=100),
        df_train,
        df_valid,
        df_test,
        "event",
        "duration",
        TIME_BINS,
        "XGBSE - Bootstrap Trees",
        "survival:aft",
)
bench_xgboost_embedding.train()
xgboost_bootstrap_results = bench_xgboost_embedding.test()
xgboost_bootstrap_results

{'model': 'XGBSE - Bootstrap Trees',
 'c-index': 0.6242769720848352,
 'ibs': 0.1549380947038349,
 'dcal_pval': 0.5634196757562135,
 'dcal_max_dev': 0.023621561823879617,
 'training_time': 11.956449031829834,
 'inference_time': 0.4664280414581299}

In [25]:
results = results.append(xgboost_bootstrap_results, ignore_index=True)

## Pysurvival

In [26]:
bench_pysurvival = BenchmarkPysurvival(ConditionalSurvivalForestModel(),
        df_train,
        df_valid,
        df_test,
        "event",
        "duration",
        TIME_BINS,
        "Conditional Survival Forest",

)
bench_pysurvival.train()
pysurvival_results = bench_pysurvival.test()
pysurvival_results

{'model': 'Conditional Survival Forest',
 'c-index': 0.6233129348646157,
 'ibs': 0.15248870955612373,
 'dcal_pval': 0.28919180650906656,
 'dcal_max_dev': 0.03175346126555047,
 'training_time': 0.20072293281555176,
 'inference_time': 31.874172925949097}

In [27]:
results = results.append(pysurvival_results, ignore_index=True)

In [28]:
results.sort_values("c-index", ascending=False).round(3)

Unnamed: 0,c-index,dcal_max_dev,dcal_pval,ibs,inference_time,model,training_time
4,0.632,0.032,0.381,0.157,0.369,XGBSE - Debiased BCE,14.198
5,0.627,0.024,0.525,0.156,0.791,XGBSE - Kaplan Neighbors,25.679
7,0.624,0.024,0.563,0.155,0.466,XGBSE - Bootstrap Trees,11.956
8,0.623,0.032,0.289,0.152,31.874,Conditional Survival Forest,0.201
1,0.622,0.024,0.667,0.154,0.008,Weibull AFT,0.39
0,0.622,0.026,0.567,0.154,0.004,Cox-PH,0.217
3,0.619,,,,0.001,XGB - Cox,0.023
2,0.61,,,,0.001,XGB - AFT,0.024
6,0.59,0.036,0.18,0.165,0.014,XGBSE - Kaplan Tree,0.114


In [29]:
results.sort_values("c-index", ascending=False).round(3).set_index('model').to_markdown()

'| model                       |   c-index |   dcal_max_dev |   dcal_pval |     ibs |   inference_time |   training_time |\n|:----------------------------|----------:|---------------:|------------:|--------:|-----------------:|----------------:|\n| XGBSE - Debiased BCE        |     0.632 |          0.032 |       0.381 |   0.157 |            0.369 |          14.198 |\n| XGBSE - Kaplan Neighbors    |     0.627 |          0.024 |       0.525 |   0.156 |            0.791 |          25.679 |\n| XGBSE - Bootstrap Trees     |     0.624 |          0.024 |       0.563 |   0.155 |            0.466 |          11.956 |\n| Conditional Survival Forest |     0.623 |          0.032 |       0.289 |   0.152 |           31.874 |           0.201 |\n| Weibull AFT                 |     0.622 |          0.024 |       0.667 |   0.154 |            0.008 |           0.39  |\n| Cox-PH                      |     0.622 |          0.026 |       0.567 |   0.154 |            0.004 |           0.217 |\n| XGB - Cox    