# Introduction to the xgbsurv package

This notebook introduces `xgbsurv` using a specific dataset. It structured by the following steps:

- Load data
- Load model
- Fit model
- Predict and evaluate model

The syntax conveniently follows that of sklearn.

In [1]:
from xgbsurv.datasets import load_metabric, load_flchain
from xgbsurv.models.eh_final import eh_likelihood #, eh_objective
from sklearn.model_selection import train_test_split
from pycox.evaluation import EvalSurv
from xgbsurv import XGBSurv
from xgbsurv.models.utils import transform_back, sort_X_y_pandas
from xgbsurv.models.eh_final import get_cumulative_hazard_function_eh
import numpy as np
import pandas as pd
%load_ext autoreload
%autoreload 2


## Load Data

In [2]:
data, target = load_metabric(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=False, return_X_y=True)
target_sign = np.sign(target)
n = target.shape[0]
target2 = target.reshape(n,1)
target2 = np.tile(target2, (1,2)) #.shape
target2
X_train, X_test, y_train, y_test = train_test_split(data, target2, stratify=target_sign)



In [3]:
#X_train, y_train = sort_X_y_pandas(pd.DataFrame(X_train), pd.DataFrame(y_train))
#X_test, y_test = sort_X_y_pandas(pd.DataFrame(X_test), pd.DataFrame(y_test))
X_train = X_train.astype(np.float32)
y_train = y_train.astype(np.float32)
X_test = X_test.astype(np.float32)
y_test = y_test.astype(np.float32)

## Load Model

In [4]:
model = XGBSurv(n_estimators=1, objective="eh_objective",
                                             eval_metric="eh_loss",
                                             learning_rate=0.1,
                                             random_state=7, 
                                             disable_default_metric=True, 
                                             base_score=0.0,
                                             verbosity=3)

The options of loss and objective functions can be obtained like below:

In [5]:
print(model.get_loss_functions().keys())
print(model.get_objective_functions().keys())

dict_keys(['breslow_loss', 'efron_loss', 'cind_loss', 'deephit_loss', 'aft_loss', 'ah_loss', 'eh_loss'])
dict_keys(['breslow_objective', 'efron_objective', 'cind_objective', 'deephit_objective', 'aft_objective', 'ah_objective', 'eh_objective'])


## Fit Model

In [6]:
eval_set = [(X_train, y_train)]


In [7]:
model.fit(X_train, y_train, eval_set=eval_set)

Parameters: { "disable_default_metric" } are not used.

[18:41:09] DEBUG: /Users/runner/work/xgboost/xgboost/python-package/build/temp.macosx-11.0-arm64-cpython-38/xgboost/src/gbm/gbtree.cc:157: Using tree method: 2
[18:41:09] INFO: /Users/runner/work/xgboost/xgboost/python-package/build/temp.macosx-11.0-arm64-cpython-38/xgboost/src/tree/updater_prune.cc:98: tree pruning end, 88 extra nodes, 0 pruned nodes, max_depth=6
[18:41:09] INFO: /Users/runner/work/xgboost/xgboost/python-package/build/temp.macosx-11.0-arm64-cpython-38/xgboost/src/tree/updater_prune.cc:98: tree pruning end, 88 extra nodes, 0 pruned nodes, max_depth=6
[0]	validation_0-rmse:147.87695	validation_0-eh_likelihood:2573.52013
[18:41:09] BoostOneIter: 0.003615s, 1 calls @ 3615us

[18:41:09] Configure: 0.000212s, 1 calls @ 212us

[18:41:09] EvalOneIter: 0.000216s, 1 calls @ 216us

[18:41:09] BoostNewTrees: 0.003535s, 1 calls @ 3535us

[18:41:09] CommitModel: 0s, 1 calls @ 0us

[18:41:09] PrunerUpdate: 1.5e-05s, 2 calls @ 1

The model can be saved like below. Note that objective and eval_metric are not saved.

In [8]:
model.save_model("eh_model.json")

## Predict

In [9]:
preds_train = model.predict(X_train)
preds_test = model.predict(X_test)
preds_test

array([[-3.5511148e-03,  3.4251341e-03],
       [-3.5511148e-03,  2.9360494e-02],
       [ 5.2676457e-03, -3.4116633e-02],
       [-1.0072008e-02,  3.4251341e-03],
       [ 1.0844442e-02,  2.9360494e-02],
       [-3.5511148e-03, -3.4579483e-03],
       [-3.5511148e-03, -3.4579483e-03],
       [-3.5511148e-03,  3.4251341e-03],
       [-1.0072008e-02,  1.3999209e-02],
       [-3.5511148e-03,  1.2568697e-02],
       [-3.5511148e-03,  2.9360494e-02],
       [ 2.2387272e-02, -8.9475168e-03],
       [ 3.0304831e-03, -3.4579483e-03],
       [-3.5511148e-03, -3.4579483e-03],
       [-1.5624416e-03,  3.1563792e-02],
       [ 1.0439034e-02,  3.1563792e-02],
       [ 5.2676457e-03, -8.9475168e-03],
       [ 1.8303421e-03,  4.1445833e-02],
       [-1.2164873e-03,  1.2067969e-02],
       [ 1.3051051e-02, -1.4600220e-02],
       [-1.2164873e-03,  1.2067969e-02],
       [-3.5511148e-03,  1.6326148e-02],
       [-3.5511148e-03,  2.9360494e-02],
       [ 7.0587532e-03,  3.1563792e-02],
       [-3.55111

### Predict Cumulative Hazard

In [10]:
df_cum_hazards = model.predict_cumulative_hazard_function(X_train, X_test, y_train, y_test)
df_cum_hazards.tail(5)

predictor_train [[ 0.00705875  0.04144583]
 [-0.00355111 -0.00345795]
 [-0.00355111 -0.00345795]
 ...
 [ 0.01248097  0.00342513]
 [-0.00355111  0.02936049]
 [-0.00121649 -0.03421191]]
predictor_test [[-3.5511148e-03  3.4251341e-03]
 [-3.5511148e-03  2.9360494e-02]
 [ 5.2676457e-03 -3.4116633e-02]
 [-1.0072008e-02  3.4251341e-03]
 [ 1.0844442e-02  2.9360494e-02]
 [-3.5511148e-03 -3.4579483e-03]
 [-3.5511148e-03 -3.4579483e-03]
 [-3.5511148e-03  3.4251341e-03]
 [-1.0072008e-02  1.3999209e-02]
 [-3.5511148e-03  1.2568697e-02]
 [-3.5511148e-03  2.9360494e-02]
 [ 2.2387272e-02 -8.9475168e-03]
 [ 3.0304831e-03 -3.4579483e-03]
 [-3.5511148e-03 -3.4579483e-03]
 [-1.5624416e-03  3.1563792e-02]
 [ 1.0439034e-02  3.1563792e-02]
 [ 5.2676457e-03 -8.9475168e-03]
 [ 1.8303421e-03  4.1445833e-02]
 [-1.2164873e-03  1.2067969e-02]
 [ 1.3051051e-02 -1.4600220e-02]
 [-1.2164873e-03  1.2067969e-02]
 [-3.5511148e-03  1.6326148e-02]
 [-3.5511148e-03  2.9360494e-02]
 [ 7.0587532e-03  3.1563792e-02]
 [-3.5511

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,466,467,468,469,470,471,472,473,474,475
287.933319,1.290257,1.348709,1.347918,1.313181,1.267127,1.364636,1.36431,1.304173,1.347684,1.359063,...,1.265052,1.358684,1.297088,1.304173,1.264728,1.313408,1.265052,1.351144,1.313181,1.264675
291.166656,1.305071,1.364099,1.363495,1.328331,1.281728,1.380303,1.379939,1.319219,1.363232,1.374783,...,1.279508,1.37421,1.312015,1.319219,1.279283,1.328586,1.279508,1.366651,1.328331,1.279265
292.666656,1.31187,1.371246,1.370644,1.335356,1.288509,1.387494,1.387347,1.326196,1.370442,1.382011,...,1.286212,1.38141,1.318937,1.326196,1.286032,1.335553,1.286212,1.373926,1.335356,1.286031
295.333344,1.324103,1.383982,1.38335,1.347841,1.300445,1.400433,1.400226,1.338596,1.383255,1.39489,...,1.298273,1.394364,1.331238,1.338596,1.298027,1.347933,1.298273,1.386732,1.347841,1.298055
335.600006,1.507663,1.575646,1.575286,1.534915,1.480696,1.594575,1.59447,1.524387,1.575245,1.587936,...,1.47823,1.58764,1.516011,1.524387,1.47819,1.534956,1.47823,1.578881,1.534915,1.478219


In [11]:
df_survival_function = np.exp(-df_cum_hazards)
durations_test, events_test = transform_back(y_test[:,0])
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ev = EvalSurv(df_survival_function, durations_test, events_test, censor_surv='km')
print('Concordance Index',ev.concordance_td('antolini'))
print('Brier Score',ev.integrated_brier_score(time_grid))

Concordance Index 0.4550129195644628
Brier Score 0.17900772647057084


In [12]:
df_survival_function

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,466,467,468,469,470,471,472,473,474,475
0.766667,0.999317,0.999297,0.999282,0.999294,0.999333,0.999277,0.999272,0.999299,0.999276,0.999289,...,0.999326,0.999277,0.999305,0.999299,0.999322,0.999300,0.999326,0.999285,0.999294,0.999321
3.766667,0.997105,0.996986,0.996992,0.997080,0.997175,0.996939,0.996951,0.997100,0.997003,0.996953,...,0.997180,0.996971,0.997123,0.997100,0.997194,0.997069,0.997180,0.996970,0.997080,0.997188
4.433333,0.996426,0.996250,0.996280,0.996381,0.996478,0.996220,0.996229,0.996406,0.996286,0.996208,...,0.996511,0.996254,0.996434,0.996406,0.996523,0.996375,0.996511,0.996259,0.996381,0.996515
5.433333,0.995333,0.995077,0.995135,0.995261,0.995407,0.995064,0.995069,0.995293,0.995136,0.995063,...,0.995437,0.995101,0.995330,0.995293,0.995446,0.995259,0.995437,0.995115,0.995261,0.995435
5.833333,0.994859,0.994582,0.994639,0.994775,0.994942,0.994563,0.994566,0.994811,0.994638,0.994523,...,0.994972,0.994601,0.994851,0.994811,0.994980,0.994776,0.994972,0.994619,0.994775,0.994968
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
287.933319,0.275200,0.259575,0.259781,0.268963,0.281640,0.255474,0.255557,0.271397,0.259841,0.256901,...,0.282225,0.256999,0.273327,0.271397,0.282316,0.268902,0.282225,0.258944,0.268963,0.282331
291.166656,0.271153,0.255611,0.255765,0.264919,0.277557,0.251502,0.251594,0.267344,0.255833,0.252895,...,0.278174,0.253039,0.269277,0.267344,0.278237,0.264851,0.278174,0.254959,0.264919,0.278242
292.666656,0.269316,0.253790,0.253943,0.263064,0.275682,0.249700,0.249737,0.265485,0.253995,0.251073,...,0.276316,0.251224,0.267419,0.265485,0.276365,0.263013,0.276316,0.253111,0.263064,0.276365
295.333344,0.266041,0.250579,0.250737,0.259800,0.272411,0.246490,0.246541,0.262214,0.250761,0.247860,...,0.273003,0.247991,0.264150,0.262214,0.273070,0.259777,0.273003,0.249891,0.259800,0.273062


### Predict Survival Function

In [13]:
df_survival = model.predict_survival_function(data, dataframe=True)
df_survival

TypeError: XGBSurv.predict_survival_function() got an unexpected keyword argument 'dataframe'

### Visualize Predictions

In [None]:
import sys
!conda install --yes --prefix {sys.prefix} -c plotly plotly_express
!conda install --yes --prefix {sys.prefix} -c ipykernel
import plotly_express as px

Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/JUSC/miniconda3/envs/xgbsurv

  added / updated specs:
    - plotly_express


The following NEW packages will be INSTALLED:

  patsy              pkgs/main/osx-64::patsy-0.5.3-py310hecd8cb5_0 
  plotly             plotly/noarch::plotly-5.13.1-py_0 
  plotly_express     plotly/noarch::plotly_express-0.4.1-py_0 
  statsmodels        pkgs/main/osx-64::statsmodels-0.13.5-py310h7b7cdfe_1 
  tenacity           pkgs/main/osx-64::tenacity-8.0.1-py310hecd8cb5_1 



Downloading and Extracting Packages

Preparing transaction: done
Verifying transaction: done
Executing transaction: done

CondaValueError: too few arguments, must supply command line package specs or --file



In [None]:
# cumhazard
px.line(df_cum_hazards, x= 'time',y = df_cum_hazards.columns[1:4],range_y=[0,1], title='Cumulative Hazard over Time')


In [None]:
# survival function
px.line(df_survival, x= 'time',y = df_cum_hazards.columns[1:4],range_y=[0,1], title='Survival Function over Time')



In [None]:
!conda remove --yes --prefix {sys.prefix} -c plotly plotly_express
!conda remove --yes --prefix {sys.prefix} -c nbformat

Collecting package metadata (repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/JUSC/miniconda3/envs/xgbsurv

  removed specs:
    - plotly_express


The following packages will be REMOVED:

  patsy-0.5.3-py310hecd8cb5_0
  plotly-5.13.1-py_0
  plotly_express-0.4.1-py_0
  statsmodels-0.13.5-py310h7b7cdfe_1
  tenacity-8.0.1-py310hecd8cb5_1


Preparing transaction: done
Verifying transaction: done
Executing transaction: done

CondaValueError: no package names supplied,
       try "conda remove -h" for more details



## Evaluate

In [None]:
#from sksurv.metrics import concordance_index_censored
from xgbsurv.evaluation import cindex_censored, ibs

In [None]:
cindex_censored(data.target, preds[:,1])

0.659739525116822

In [None]:
np.savetxt('preds_eh.csv', preds)