# Introduction to the xgbsurv package

This notebook introduces `xgbsurv` using the METABRIC dataset. It is structured in the following steps:

- Load data
- Load model
- Fit model
- Predict and evaluate model

The syntax follows scikit-learn conventions (`fit`, `predict`).

In [1]:
from xgbsurv.datasets import load_metabric
from xgbsurv.models.eh_final import eh_likelihood #, eh_objective
from sklearn.model_selection import train_test_split
from pycox.evaluation import EvalSurv
from xgbsurv import XGBSurv
from xgbsurv.models.utils import transform_back, sort_X_y_pandas
import numpy as np
import pandas as pd
%load_ext autoreload
%autoreload 2


## Load Data

In [2]:
data, target = load_metabric(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=False, return_X_y=True)
# The sign of the target is used to stratify the split below.
target_sign = np.sign(target)
n = target.shape[0]
# Duplicate the target into two identical columns, giving shape (n, 2).
target2 = np.tile(target.reshape(n, 1), (1, 2))
X_train, X_test, y_train, y_test = train_test_split(data, target2, stratify=target_sign)
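The cell above stratifies on `np.sign(target)`, and the sorted targets shown further down contain negative values. This suggests that the target packs time and event status into one signed number. The sketch below illustrates that idea on toy data. It assumes that positive values mark observed events and negative values mark censored observations; this convention is inferred from the notebook, not confirmed by it, and the toy arrays are hypothetical.

```python
import numpy as np

# Hypothetical toy data: follow-up times and event indicators (1 = event, 0 = censored).
time = np.array([5.0, 12.5, 3.2, 8.0])
event = np.array([1, 0, 1, 0])

# Assumed encoding: keep the time, flip the sign for censored observations.
signed_target = np.where(event == 1, time, -time)

# Decoding recovers both pieces from the single signed value.
decoded_time = np.abs(signed_target)
decoded_event = (signed_target > 0).astype(int)

# Mirror the notebook: duplicate the target into two identical columns.
target_2col = np.tile(signed_target.reshape(-1, 1), (1, 2))
print(target_2col.shape)  # (4, 2)
```

Under this assumption, stratifying on the sign keeps the share of censored observations similar in the training and test sets.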


In [3]:
# Sort features and targets of the same split together.
sort_X_y_pandas(pd.DataFrame(X_test), pd.DataFrame(y_test))

[3.69666672e+01 4.25666656e+01 1.19300003e+02 1.69833328e+02
 2.32766663e+02 9.02666702e+01 7.98666687e+01 3.55200012e+02
 2.19666672e+02 5.72333336e+01 1.93966660e+02 1.03133331e+02
 1.40233337e+02 1.21966667e+02 7.01666641e+01 2.29999995e+00
 1.96466660e+02 5.56333351e+01 2.62633331e+02 1.46366669e+02
 1.98999996e+01 1.25599998e+02 7.50999985e+01 6.42333298e+01
 2.40433334e+02 8.99000015e+01 4.30333328e+01 1.36066666e+02
 1.11699997e+02 2.98999996e+01 2.28800003e+02 5.58333321e+01
 2.58166656e+02 1.82600006e+02 4.71333351e+01 2.92033325e+02
 4.23333321e+01 2.40033340e+02 1.00833336e+02 1.31300003e+02
 4.66666679e+01 9.12333298e+01 5.86666679e+01 2.60200012e+02
 1.33233337e+02 8.53000031e+01 2.34233337e+02 3.14333324e+01
 6.86999969e+01 2.39166672e+02 1.10933334e+02 2.12999992e+01
 3.52333336e+01 1.94199997e+02 1.60399994e+02 1.28699997e+02
 4.00000000e+01 3.84333344e+01 3.47000008e+01 3.14666672e+01
 5.52000008e+01 5.87666664e+01 7.93666687e+01 4.84333344e+01
 2.35399994e+02 2.216000

(            0         1          2         3    4    5    6    7      8
 82   5.539938  6.160773   9.297399  5.586134  1.0  1.0  0.0  1.0  43.60
 287  5.595089  5.730703   9.639565  6.055458  1.0  0.0  0.0  1.0  70.28
 209  6.560032  5.487788   9.283164  5.823293  1.0  0.0  0.0  0.0  60.04
 15   7.480446  5.912105  14.333509  6.300035  1.0  1.0  1.0  0.0  53.69
 237  5.814371  7.310476  10.655068  5.841495  0.0  0.0  0.0  1.0  47.14
 ..        ...       ...        ...       ...  ...  ...  ...  ...    ...
 316  5.735562  7.169067   9.311744  5.782729  1.0  0.0  0.0  1.0  62.96
 202  5.915191  6.080590  10.453157  5.981409  0.0  1.0  0.0  1.0  70.86
 109  6.426667  6.306776  10.828656  5.721386  0.0  0.0  0.0  1.0  63.58
 102  5.517889  6.498012   8.273461  5.577678  1.0  0.0  0.0  1.0  73.16
 7    5.761712  5.277199  13.308283  5.947731  1.0  1.0  0.0  1.0  67.60
 
 [476 rows x 9 columns],
               0           1
 82     0.100000    0.100000
 287   -1.433333   -1.433333
 209   -2.

## Load Model

In [4]:
model = XGBSurv(
    n_estimators=500,
    objective="eh_objective",
    eval_metric="eh_loss",
    learning_rate=0.01,
    random_state=7,
    disable_default_metric=True,
    base_score=0.3,
)

The available loss and objective functions can be listed as follows:

In [5]:
print(model.get_loss_functions().keys())
print(model.get_objective_functions().keys())

dict_keys(['breslow_loss', 'efron_loss', 'cind_loss', 'deephit_loss', 'aft_loss', 'ah_loss', 'eh_loss'])
dict_keys(['breslow_objective', 'efron_objective', 'cind_objective', 'deephit_objective', 'aft_objective', 'ah_objective', 'eh_objective'])
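The two lists share the same prefixes, which matches how the model above combines `eh_objective` with `eh_loss`. The sketch below pairs the names by prefix using only the strings printed above. Whether objectives and losses from different families can be mixed is not stated in this notebook.

```python
# Names copied from the output above.
loss_names = ['breslow_loss', 'efron_loss', 'cind_loss', 'deephit_loss',
              'aft_loss', 'ah_loss', 'eh_loss']
objective_names = ['breslow_objective', 'efron_objective', 'cind_objective',
                   'deephit_objective', 'aft_objective', 'ah_objective',
                   'eh_objective']

# Map each family prefix (e.g. "eh") to its (objective, loss) pair.
pairs = {}
for loss in loss_names:
    prefix = loss.rsplit("_", 1)[0]
    objective = prefix + "_objective"
    if objective in objective_names:
        pairs[prefix] = (objective, loss)

print(pairs["eh"])  # ('eh_objective', 'eh_loss')
```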


## Fit Model

In [6]:
eval_set = [(X_train, y_train)]


In [7]:
model.fit(X_train, y_train, eval_set=eval_set)

Parameters: { "disable_default_metric" } are not used.

[0]	validation_0-rmse:146.78805	validation_0-eh_likelihood:3.11825
[1]	validation_0-rmse:146.78803	validation_0-eh_likelihood:3.11821
[2]	validation_0-rmse:146.78802	validation_0-eh_likelihood:3.11816
[3]	validation_0-rmse:146.78800	validation_0-eh_likelihood:3.11812
[4]	validation_0-rmse:146.78799	validation_0-eh_likelihood:3.11808
[5]	validation_0-rmse:146.78797	validation_0-eh_likelihood:3.11804
[6]	validation_0-rmse:146.78796	validation_0-eh_likelihood:3.11800
[7]	validation_0-rmse:146.78794	validation_0-eh_likelihood:3.11796
[8]	validation_0-rmse:146.78793	validation_0-eh_likelihood:3.11792
[9]	validation_0-rmse:146.78791	validation_0-eh_likelihood:3.11788
[10]	validation_0-rmse:146.78790	validation_0-eh_likelihood:3.11783
[11]	validation_0-rmse:146.78788	validation_0-eh_likelihood:3.11779
[12]	validation_0-rmse:146.78787	validation_0-eh_likelihood:3.11775
[13]	validation_0-rmse:146.78785	validation_0-eh_likelihood:3.11771
[1
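The log reports two metrics per boosting round. The sketch below parses the first three lines of the output above to compare how each metric moves between rounds. It only uses the text shown in this notebook. XGBoost's scikit-learn wrappers also expose `evals_result()` for this purpose, but whether `XGBSurv` inherits that method is an assumption.

```python
import re

# First three log lines copied from the output above.
log = """[0]\tvalidation_0-rmse:146.78805\tvalidation_0-eh_likelihood:3.11825
[1]\tvalidation_0-rmse:146.78803\tvalidation_0-eh_likelihood:3.11821
[2]\tvalidation_0-rmse:146.78802\tvalidation_0-eh_likelihood:3.11816
"""

pattern = re.compile(
    r"\[(\d+)\]\s+validation_0-rmse:([\d.]+)\s+validation_0-eh_likelihood:([\d.]+)"
)

# Each row is (round, rmse, eh_likelihood).
rows = [(int(i), float(rmse), float(lik)) for i, rmse, lik in pattern.findall(log)]

# Round-to-round decrease of each metric.
rmse_drops = [a[1] - b[1] for a, b in zip(rows, rows[1:])]
likelihood_drops = [a[2] - b[2] for a, b in zip(rows, rows[1:])]

print(rows)
print(rmse_drops, likelihood_drops)
```

In these rounds both values decrease slightly, with `eh_likelihood` being the custom metric requested through `eval_metric="eh_loss"`.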

The model can be saved as shown below. Note that `objective` and `eval_metric` are not stored in the saved file.

In [8]:
model.save_model("eh_model.json")
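Since the objective and metric names are not part of the saved file, one option is to record them separately. The sketch below writes them to a small JSON file and reads them back. The file name `eh_model.meta.json` is hypothetical, and a temporary directory stands in for the model's real location.

```python
import json
import tempfile
from pathlib import Path

# Settings used when the model was constructed in this notebook.
custom_settings = {"objective": "eh_objective", "eval_metric": "eh_loss"}

with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical sidecar file that would sit next to eh_model.json.
    meta_path = Path(tmp_dir) / "eh_model.meta.json"
    meta_path.write_text(json.dumps(custom_settings, indent=2))

    # Later: read the names back before rebuilding the estimator.
    restored = json.loads(meta_path.read_text())

print(restored["objective"], restored["eval_metric"])
```

The restored names could then be passed to the `XGBSurv` constructor again. How `XGBSurv` reloads a saved model is not shown in this notebook.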

## Predict

In [9]:
preds_train = model.predict(X_train, output_margin=True)
preds_test = model.predict(X_test, output_margin=True)

### Predict Cumulative Hazard

In [10]:
df_cum_hazards = model.predict_cumulative_hazard_function(X_train, X_test, y_train, y_test)
df_cum_hazards.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,466,467,468,469,470,471,472,473,474,475
0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1.433333,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4.166667,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [11]:
# Survival function from cumulative hazard: S(t) = exp(-H(t)).
df_survival_function = np.exp(-df_cum_hazards)
durations_test, events_test = transform_back(y_test[:, 0])
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ev = EvalSurv(df_survival_function, durations_test, events_test, censor_surv='km')
print('Concordance Index', ev.concordance_td('antolini'))
print('Brier Score', ev.integrated_brier_score(time_grid))

Concordance Index 0.5832908146578227
Brier Score 0.1744622560282251
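To give a feel for what a concordance index measures, here is a minimal pure-Python sketch of Harrell's C on toy data. It is a simplification and not the time-dependent Antolini variant computed by `concordance_td('antolini')` above. The function name `harrell_c_index` and the toy values are hypothetical.

```python
def harrell_c_index(durations, events, risk_scores):
    """Share of comparable pairs in which the earlier event has the higher risk."""
    concordant = 0.0
    comparable = 0
    n = len(durations)
    for i in range(n):
        if not events[i]:
            continue  # a pair is anchored on an observed event
        for j in range(n):
            if durations[i] < durations[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5  # ties count half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Hypothetical toy data: the third observation is censored.
durations = [2.0, 4.0, 6.0, 8.0]
events = [1, 1, 0, 1]
risk = [0.9, 0.4, 0.5, 0.1]

c_index = harrell_c_index(durations, events, risk)
print(c_index)  # 0.8
```

A value of 0.5 corresponds to random ordering and 1.0 to perfect ordering, which puts the score of about 0.58 reported above into context.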


In [12]:
df_survival_function

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,466,467,468,469,470,471,472,473,474,475
0.100000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,...,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
1.433333,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,...,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
2.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,...,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
2.300000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,...,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
4.166667,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,...,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
297.233337,0.272807,0.274745,0.264122,0.225491,0.228685,0.230219,0.243248,0.271453,0.293185,0.292065,...,0.283916,0.250448,0.224168,0.253838,0.269520,0.293185,0.224168,0.267855,0.274151,0.279536
318.200012,0.247385,0.248539,0.246508,0.200879,0.203947,0.205422,0.226547,0.245331,0.268027,0.274012,...,0.258233,0.225652,0.200295,0.236422,0.244181,0.268027,0.200295,0.242559,0.247960,0.261662
330.366669,0.231941,0.232644,0.230205,0.193353,0.189120,0.190553,0.211128,0.229500,0.260204,0.257216,...,0.242593,0.218014,0.192982,0.220336,0.228800,0.260204,0.192982,0.234740,0.232076,0.245070
351.000000,0.210816,0.210940,0.208002,0.172612,0.175499,0.170476,0.190188,0.207899,0.238289,0.234195,...,0.221146,0.196815,0.172796,0.198484,0.207776,0.238289,0.172796,0.212970,0.210391,0.222392
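The table above follows from the relation S(t) = exp(-H(t)) between the survival function S and the cumulative hazard H. The short sketch below checks this on hypothetical hazard values and inverts one survival value taken from the table.

```python
import math

# Hypothetical cumulative hazard values H(t) for one sample on a small time grid.
cum_hazard = [0.0, 0.3, 0.8, 1.3]

# S(t) = exp(-H(t)): a zero hazard gives survival 1, larger hazards give smaller survival.
survival = [math.exp(-h) for h in cum_hazard]
print(survival[0])  # 1.0

# Inverting the relation for the first value in the 297.233337 row above.
implied_hazard = -math.log(0.272807)
print(round(implied_hazard, 2))  # 1.3
```

This also explains the rows of ones at early time points: the cumulative hazard there is 0.0 for every sample.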


### Predict Survival Function

In [13]:
df_survival = model.predict_survival_function(data, dataframe=True)
df_survival

TypeError: XGBSurv.predict_survival_function() got an unexpected keyword argument 'dataframe'
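The error says that `dataframe` is not an accepted keyword. One way to check which arguments a method accepts is `inspect.signature`. The sketch below uses a hypothetical stand-in function whose parameters mirror the call to `predict_cumulative_hazard_function` above; the real signature of `predict_survival_function` is not shown in this notebook.

```python
import inspect


# Hypothetical stand-in; the parameter names are an assumption.
def predict_survival_function(X_train, X_test, y_train, y_test):
    return None


# Parameter names the function accepts.
accepted = set(inspect.signature(predict_survival_function).parameters)

# Keyword arguments from the failing call.
requested = {"dataframe": True}
unsupported = sorted(set(requested) - accepted)
print(unsupported)  # ['dataframe']
```

In the notebook, the same check would be run against `model.predict_survival_function` directly.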

### Visualize Predictions

In [None]:
import sys
!conda install --yes --prefix {sys.prefix} -c plotly plotly_express
!conda install --yes --prefix {sys.prefix} ipykernel
import plotly_express as px

Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/JUSC/miniconda3/envs/xgbsurv

  added / updated specs:
    - plotly_express


The following NEW packages will be INSTALLED:

  patsy              pkgs/main/osx-64::patsy-0.5.3-py310hecd8cb5_0 
  plotly             plotly/noarch::plotly-5.13.1-py_0 
  plotly_express     plotly/noarch::plotly_express-0.4.1-py_0 
  statsmodels        pkgs/main/osx-64::statsmodels-0.13.5-py310h7b7cdfe_1 
  tenacity           pkgs/main/osx-64::tenacity-8.0.1-py310hecd8cb5_1 



Downloading and Extracting Packages

Preparing transaction: done
Verifying transaction: done
Executing transaction: done



In [None]:
# Cumulative hazard for three test samples
px.line(df_cum_hazards, x='time', y=df_cum_hazards.columns[1:4], range_y=[0, 1], title='Cumulative Hazard over Time')


In [None]:
# Survival function for three test samples
px.line(df_survival, x='time', y=df_survival.columns[1:4], range_y=[0, 1], title='Survival Function over Time')
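The plotting cells reference a `time` column, while the frames printed earlier in this notebook carry the time points in an unnamed index. The sketch below shows one way to turn such an index into a `time` column with pandas. The miniature frame is hypothetical and only mimics the layout seen above.

```python
import pandas as pd

# Hypothetical miniature of the frames above: the index holds time points,
# each column is one test sample.
df = pd.DataFrame(
    {0: [1.0, 0.9, 0.7], 1: [1.0, 0.8, 0.5], 2: [1.0, 0.95, 0.85]},
    index=[0.1, 1.433333, 2.0],
)

# Name the index and turn it into a regular column.
plot_df = df.rename_axis("time").reset_index()

# Column 0 is now "time", so the slice [1:4] picks the three sample columns.
sample_cols = list(plot_df.columns[1:4])
print(list(plot_df.columns))  # ['time', 0, 1, 2]
```

After this step, the slice `columns[1:4]` used in the plotting cells selects sample columns and skips `time`.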



In [None]:
!conda remove --yes --prefix {sys.prefix} -c plotly plotly_express
!conda remove --yes --prefix {sys.prefix} nbformat

Collecting package metadata (repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/JUSC/miniconda3/envs/xgbsurv

  removed specs:
    - plotly_express


The following packages will be REMOVED:

  patsy-0.5.3-py310hecd8cb5_0
  plotly-5.13.1-py_0
  plotly_express-0.4.1-py_0
  statsmodels-0.13.5-py310h7b7cdfe_1
  tenacity-8.0.1-py310hecd8cb5_1


Preparing transaction: done
Verifying transaction: done
Executing transaction: done



## Evaluate

In [None]:
#from sksurv.metrics import concordance_index_censored
from xgbsurv.evaluation import cindex_censored, ibs

In [None]:
cindex_censored(data.target, preds[:,1])

0.659739525116822

In [None]:
np.savetxt('preds_eh.csv', preds)