# Introduction to the xgbsurv package - Accelerated Hazards

This notebook introduces `xgbsurv` using a specific dataset. It structured by the following steps:

- Load data
- Load model
- Fit model
- Predict and evaluate model

The syntax conveniently follows that of sklearn.

In [23]:
from xgbsurv.datasets import load_metabric, load_flchain, load_support
from xgbsurv.models.utils import sort_X_y, transform_back
from pycox.evaluation import EvalSurv
from xgbsurv import XGBSurv
from sklearn.model_selection import train_test_split
import numpy as np
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load Data

In [24]:
data, target = load_metabric(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=False, return_X_y=True)
target_sign = np.sign(target)
X_train, X_test, y_train, y_test = train_test_split(data, target, stratify=target_sign)

## Load Model

In [25]:
model = XGBSurv(n_estimators=1, objective="ah_objective",
                                             eval_metric="ah_loss",
                                             learning_rate=0.6,
                                             random_state=42, 
                                             disable_default_eval_metric=True,
                                             base_score=0.0)

The options of loss and objective functions can be obtained like below:

In [26]:
print(model.get_loss_functions().keys())
print(model.get_objective_functions().keys())

dict_keys(['breslow_loss', 'efron_loss', 'cind_loss', 'deephit_loss', 'aft_loss', 'ah_loss', 'eh_loss'])
dict_keys(['breslow_objective', 'efron_objective', 'cind_objective', 'deephit_objective', 'aft_objective', 'ah_objective', 'eh_objective'])


## Fit Model

In [27]:
eval_set = [(X_train, y_train)]

In [28]:
model.fit(X_train, y_train, eval_set=eval_set)

[0]	validation_0-ah_likelihood:2572.24724


The model can be saved like below. Note that objective and eval_metric are not saved.

## Predict

In [29]:
preds_train = model.predict(X_train, output_margin=True)
preds_test = model.predict(X_test, output_margin=True)

In [30]:
preds_test

array([-0.02512996,  0.01891272, -0.02512996,  0.01225335,  0.01225335,
        0.01891272, -0.02512996, -0.00105614, -0.02512996, -0.0087131 ,
       -0.0087131 , -0.02512996,  0.02299402,  0.00521444, -0.02512996,
        0.03388387, -0.0087131 , -0.00105614, -0.0087131 , -0.02512996,
       -0.02512996, -0.02512996, -0.02512996, -0.02512996, -0.02512996,
       -0.0087131 , -0.00105614,  0.05524958,  0.02299402,  0.01891272,
        0.05524958, -0.02512996,  0.02299402,  0.02299402,  0.02299402,
       -0.02512996, -0.02512996, -0.0087131 , -0.02512996, -0.00105614,
       -0.02512996,  0.01225335,  0.01891272, -0.02512996, -0.02512996,
       -0.0087131 , -0.02512996, -0.0087131 ,  0.00521444, -0.02512996,
        0.02299402,  0.17136551, -0.02512996,  0.01891272,  0.02299402,
       -0.02512996, -0.00105614, -0.00105614,  0.02299402,  0.01891272,
       -0.02512996,  0.02299402,  0.02299402,  0.05524958, -0.02512996,
        0.10535192, -0.02512996, -0.02512996, -0.02512996,  0.05

## Evaluate

In [31]:
df_cum_hazards = model.predict_cumulative_hazard_function(X_train, X_test, y_train, y_test)
df_cum_hazards # = df_cum_hazards.T.sort_index(axis=0)

0.0333241953125
(13655,)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,466,467,468,469,470,471,472,473,474,475
1.233333,0.000718,0.000687,0.000718,0.000691,0.000691,0.000687,0.000718,0.000701,0.000718,0.000706,...,0.000696,0.000662,0.000718,0.000718,0.000706,0.000718,0.000718,0.000706,0.000684,0.000706
1.766667,0.000718,0.000687,0.000718,0.000691,0.000691,0.000687,0.000718,0.000701,0.000718,0.000706,...,0.000696,0.000663,0.000718,0.000718,0.000706,0.000718,0.000718,0.000706,0.000684,0.000706
2.300000,0.000721,0.000691,0.000721,0.000695,0.000695,0.000691,0.000721,0.000704,0.000721,0.000709,...,0.000700,0.000668,0.000721,0.000721,0.000709,0.000721,0.000721,0.000709,0.000688,0.000709
2.533333,0.000726,0.000698,0.000726,0.000702,0.000702,0.000698,0.000726,0.000710,0.000726,0.000715,...,0.000707,0.000678,0.000726,0.000726,0.000715,0.000726,0.000726,0.000715,0.000696,0.000715
3.500000,0.000840,0.000849,0.000840,0.000844,0.000844,0.000849,0.000840,0.000837,0.000840,0.000843,...,0.000841,0.000849,0.000840,0.000840,0.000843,0.000840,0.000840,0.000843,0.000845,0.000843
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
298.033325,1.375397,1.375066,1.375397,1.375086,1.375086,1.375066,1.375397,1.375342,1.375397,1.375272,...,1.375248,1.374253,1.375397,1.375397,1.375272,1.375397,1.375397,1.375272,1.374960,1.375272
300.866669,1.388456,1.387880,1.388456,1.388007,1.388007,1.387880,1.388456,1.388173,1.388456,1.388379,...,1.388131,1.387065,1.388456,1.388456,1.388379,1.388456,1.388456,1.388379,1.387857,1.388379
318.200012,1.467610,1.466184,1.467610,1.466524,1.466524,1.466184,1.467610,1.466954,1.467610,1.467153,...,1.466745,1.464656,1.467610,1.467610,1.467153,1.467610,1.467610,1.467153,1.466051,1.467153
335.733337,1.546589,1.544340,1.546589,1.544616,1.544616,1.544340,1.546589,1.545491,1.546589,1.545858,...,1.545100,1.541870,1.546589,1.546589,1.545858,1.546589,1.546589,1.545858,1.544087,1.545858


In [20]:
df_survival_function = np.exp(-df_cum_hazards)
durations_test, events_test = transform_back(y_test)
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ev = EvalSurv(df_survival_function, durations_test, events_test, censor_surv='km')
print('Concordance Index',ev.concordance_td('antolini'))
print('Brier Score',ev.integrated_brier_score(time_grid))

Concordance Index 0.3051432168712622
Brier Score 0.18527379131814165


In [21]:
from xgbsurv.evaluation import cindex_censored, ibs
print(cindex_censored(y_train, preds_train))
print(cindex_censored(y_test, preds_test))

0.5530854094251766
0.5207994963802329


In [22]:
df_survival_function

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,466,467,468,469,470,471,472,473,474,475
0.100000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,...,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
0.766667,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,...,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
1.266667,0.999985,0.999981,0.999985,0.999985,0.999985,0.999985,0.999985,0.999993,0.999985,0.999985,...,0.999981,0.999985,0.999985,0.999970,0.999985,0.999984,0.999970,0.999985,0.999985,0.999985
2.533333,0.999425,0.999410,0.999425,0.999425,0.999425,0.999425,0.999425,0.999572,0.999425,0.999425,...,0.999410,0.999425,0.999425,0.999324,0.999425,0.999442,0.999324,0.999425,0.999424,0.999425
5.500000,0.996868,0.996828,0.996868,0.996868,0.996868,0.996868,0.996868,0.997076,0.996868,0.996868,...,0.996853,0.996868,0.996868,0.996727,0.996868,0.996866,0.996727,0.996868,0.996865,0.996868
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
297.233337,0.253824,0.253724,0.253824,0.253824,0.253824,0.253824,0.253824,0.254992,0.253824,0.253824,...,0.253726,0.253824,0.253824,0.253492,0.253824,0.253896,0.253492,0.253824,0.253832,0.253824
297.799988,0.253143,0.253056,0.253143,0.253143,0.253143,0.253143,0.253143,0.254302,0.253143,0.253143,...,0.253057,0.253143,0.253143,0.252811,0.253143,0.253208,0.252811,0.253143,0.253151,0.253143
300.700012,0.249694,0.249627,0.249694,0.249694,0.249694,0.249694,0.249694,0.250837,0.249694,0.249694,...,0.249625,0.249694,0.249694,0.249472,0.249694,0.249758,0.249472,0.249694,0.249738,0.249694
307.633331,0.241734,0.241692,0.241734,0.241734,0.241734,0.241734,0.241734,0.242659,0.241734,0.241734,...,0.241680,0.241734,0.241734,0.241613,0.241734,0.241781,0.241613,0.241734,0.241730,0.241734
