# Introduction to the xgbsurv package

This notebook introduces `xgbsurv` using a specific dataset. It structured by the following steps:

- Load data
- Load model
- Fit model
- Predict and evaluate model

The syntax conveniently follows that of sklearn.

In [1]:
from xgbsurv.datasets import load_metabric
from xgbsurv.models.utils import sort_X_y
from xgbsurv import XGBSurv
from sklearn.model_selection import train_test_split
import numpy as np
%load_ext autoreload
%autoreload 2


## Load Data

In [2]:
data, target = load_metabric(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=False)
target_sign = np.sign(target)
X_train, X_test, y_train, y_test = train_test_split(data, target, stratify=target_sign)
X_train, y_train = sort_X_y(X_train, y_train) 
X_test,  y_test = sort_X_y(X_test,  y_test)

## Load Model

In [3]:
model = XGBSurv(n_estimators=500, objective="efron_objective",
                                             eval_metric="efron_loss",
                                             learning_rate=0.01,
                                             random_state=7, 
                                             disable_default_eval_metric=True)

The options of loss and objective functions can be obtained like below:

In [4]:
print(model.get_loss_functions().keys())
print(model.get_objective_functions().keys())

dict_keys(['breslow_loss', 'efron_loss', 'cind_loss', 'deephit_loss', 'aft_loss'])
dict_keys(['breslow_objective', 'efron_objective', 'cind_objective', 'deephit_objective', 'aft_objective'])


## Fit Model

In [5]:
eval_set = [(X_train, y_train)]

In [6]:
model.fit(X_train, y_train, eval_set=eval_set)

[0]	validation_0-efron_likelihood:5415.62994
[1]	validation_0-efron_likelihood:5410.06634
[2]	validation_0-efron_likelihood:5404.62664
[3]	validation_0-efron_likelihood:5399.36299
[4]	validation_0-efron_likelihood:5394.22407
[5]	validation_0-efron_likelihood:5389.28687
[6]	validation_0-efron_likelihood:5384.44580
[7]	validation_0-efron_likelihood:5379.99260
[8]	validation_0-efron_likelihood:5375.39569
[9]	validation_0-efron_likelihood:5371.08487
[10]	validation_0-efron_likelihood:5367.03189
[11]	validation_0-efron_likelihood:5362.93716
[12]	validation_0-efron_likelihood:5359.09183
[13]	validation_0-efron_likelihood:5355.27417
[14]	validation_0-efron_likelihood:5351.47841
[15]	validation_0-efron_likelihood:5347.86162
[16]	validation_0-efron_likelihood:5343.88087
[17]	validation_0-efron_likelihood:5340.37225
[18]	validation_0-efron_likelihood:5336.45684
[19]	validation_0-efron_likelihood:5333.24360
[20]	validation_0-efron_likelihood:5329.59863
[21]	validation_0-efron_likelihood:5326.4930

The model can be saved like below. Note that objective and eval_metric are not saved.

In [7]:
model.save_model("efron_model.json")



## Predict

In [8]:
preds_train = model.predict(X_train, output_margin=True)
preds_test = model.predict(X_test, output_margin=True)

### Predict Cumulative Hazard

In [None]:
df_cum_hazards = model.predict_cumulative_hazard_function(X_test, dataframe=True)
df_cum_hazards.tail(5)

q value at the end 770
shape cum_hazard_baseline_final (771,)
thres 0.76666665
ind 0
shape cum_hazard_baseline_final (771,)
thres 1.2666667
ind 0
shape cum_hazard_baseline_final (772,)
thres 1.7666667
ind 0
shape cum_hazard_baseline_final (773,)
thres 2.5
ind 0
shape cum_hazard_baseline_final (774,)
thres 3.3666666
ind 1
shape cum_hazard_baseline_final (775,)
thres 3.7666667
ind 2
shape cum_hazard_baseline_final (776,)
thres 5.4333334
ind 5
shape cum_hazard_baseline_final (777,)
thres 12.4
ind 24
shape cum_hazard_baseline_final (778,)
thres 13.4
ind 24
shape cum_hazard_baseline_final (779,)
thres 17.766666
ind 47
shape cum_hazard_baseline_final (780,)
thres 19.6
ind 55
shape cum_hazard_baseline_final (781,)
thres 20.266666
ind 60
shape cum_hazard_baseline_final (782,)
thres 21.6
ind 67
shape cum_hazard_baseline_final (783,)
thres 23.766666
ind 78
shape cum_hazard_baseline_final (784,)
thres 23.8
ind 78
shape cum_hazard_baseline_final (785,)
thres 23.9
ind 79
shape cum_hazard_baseline_f

Unnamed: 0,time,patient_0,patient_1,patient_2,patient_3,patient_4,patient_5,patient_6,patient_7,patient_8,...,patient_466,patient_467,patient_468,patient_469,patient_470,patient_471,patient_472,patient_473,patient_474,patient_475
1291,330.36667,4.332714,2.983133,2.863574,9.217394,2.212409,4.804026,1.490031,2.564052,0.834569,...,0.870425,2.397685,2.1624,0.715836,2.928681,2.828265,2.964945,2.455471,0.74119,2.738542
1292,335.6,4.707767,3.241362,3.111453,10.015278,2.403922,5.219876,1.619012,2.786004,0.906812,...,0.945772,2.605235,2.349584,0.7778,3.182196,3.073088,3.221599,2.668024,0.80535,2.975598
1293,335.73334,5.647859,3.888629,3.732778,12.015226,2.88396,6.262232,1.942312,3.34234,1.087893,...,1.134633,3.125474,2.818771,0.933119,3.817648,3.686752,3.86492,3.2008,0.96617,3.569794
1294,337.03333,7.001727,4.820786,4.627575,14.895438,3.575285,7.763374,2.407911,4.143544,1.348676,...,1.406619,3.874693,3.494469,1.1568,4.73279,4.570516,4.791393,3.968076,1.197774,4.425522
1295,351.0,824.751707,567.852889,545.094152,1754.572507,421.14208,914.468001,283.634084,488.078883,158.864059,...,165.689371,456.410138,411.622609,136.262514,557.487657,538.373005,564.390696,467.409956,141.088942,521.293799


In [None]:
df_cum_hazards.to_csv('cumhazards_efron.csv')
np.savetxt("preds_train_efron.csv", preds_train, delimiter=",")

### Predict Survival Function

In [None]:
df_survival = model.predict_survival_function(X_train, dataframe=True)
df_survival

Unnamed: 0,time,patient_0,patient_1,patient_2,patient_3,patient_4,patient_5,patient_6,patient_7,patient_8,...,patient_1417,patient_1418,patient_1419,patient_1420,patient_1421,patient_1422,patient_1423,patient_1424,patient_1425,patient_1426
0,0.100000,0.999390,9.987946e-01,0.999359,9.997210e-01,9.997570e-01,9.987590e-01,9.997687e-01,0.999260,0.999421,...,9.998759e-01,9.999023e-01,9.999083e-01,9.998486e-01,9.999157e-01,9.999009e-01,9.998459e-01,9.998257e-01,9.999275e-01,9.998966e-01
1,0.766667,0.999390,9.987946e-01,0.999359,9.997210e-01,9.997570e-01,9.987590e-01,9.997687e-01,0.999260,0.999421,...,9.998759e-01,9.999023e-01,9.999083e-01,9.998486e-01,9.999157e-01,9.999009e-01,9.998459e-01,9.998257e-01,9.999275e-01,9.998966e-01
2,1.266667,0.999390,9.987946e-01,0.999359,9.997210e-01,9.997570e-01,9.987590e-01,9.997687e-01,0.999260,0.999421,...,9.998759e-01,9.999023e-01,9.999083e-01,9.998486e-01,9.999157e-01,9.999009e-01,9.998459e-01,9.998257e-01,9.999275e-01,9.998966e-01
3,1.766667,0.999390,9.987946e-01,0.999359,9.997210e-01,9.997570e-01,9.987590e-01,9.997687e-01,0.999260,0.999421,...,9.998759e-01,9.999023e-01,9.999083e-01,9.998486e-01,9.999157e-01,9.999009e-01,9.998459e-01,9.998257e-01,9.999275e-01,9.998966e-01
4,2.500000,0.999390,9.987946e-01,0.999359,9.997210e-01,9.997570e-01,9.987590e-01,9.997687e-01,0.999260,0.999421,...,9.998759e-01,9.999023e-01,9.999083e-01,9.998486e-01,9.999157e-01,9.999009e-01,9.998459e-01,9.998257e-01,9.999275e-01,9.998966e-01
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1291,330.366670,0.003798,1.632115e-05,0.002850,7.804657e-02,1.084669e-01,1.178312e-05,1.206866e-01,0.001148,0.005021,...,3.216814e-01,4.092718e-01,4.326405e-01,2.505817e-01,4.626796e-01,4.040580e-01,2.444933e-01,2.033761e-01,5.154311e-01,3.887578e-01
1292,335.600000,0.002344,6.285685e-06,0.001716,6.258537e-02,8.949315e-02,4.411788e-06,1.004997e-01,0.000639,0.003175,...,2.915999e-01,3.788144e-01,4.023734e-01,2.222904e-01,4.328187e-01,3.735738e-01,2.164280e-01,1.771840e-01,4.866932e-01,3.582288e-01
1293,335.733340,0.000700,5.749591e-07,0.000481,3.598660e-02,5.526790e-02,3.760097e-07,6.351952e-02,0.000147,0.001007,...,2.279874e-01,3.120632e-01,3.354886e-01,1.646295e-01,3.661678e-01,3.068911e-01,1.594346e-01,1.254130e-01,4.215057e-01,2.918307e-01
1294,337.033330,0.000123,1.835430e-08,0.000077,1.621915e-02,2.760748e-02,1.084146e-08,3.280560e-02,0.000018,0.000193,...,1.599534e-01,2.360509e-01,2.582120e-01,1.068301e-01,2.877984e-01,2.312105e-01,1.026670e-01,7.624353e-02,3.426602e-01,2.172279e-01


### Visualize Predictions

In [None]:
import sys
!conda install --yes --prefix {sys.prefix} -c plotly plotly_express
!conda install --yes --prefix {sys.prefix} -c ipykernel
import plotly_express as px

Collecting package metadata (current_repodata.json): done
Solving environment: done

# All requested packages already installed.


CondaValueError: too few arguments, must supply command line package specs or --file



In [None]:
# cumhazard
px.line(df_cum_hazards, x= 'time',y = df_cum_hazards.columns[1:4],range_y=[0,1], title='Cumulative Hazard over Time')


In [None]:
# survival function
px.line(df_survival, x= 'time',y = df_cum_hazards.columns[1:4],range_y=[0,1], title='Survival Function over Time')



In [None]:
#!conda remove --yes --prefix {sys.prefix} -c plotly plotly_express
#!conda remove --yes --prefix {sys.prefix} -c nbformat

## Evaluate

In [None]:
#from sksurv.metrics import concordance_index_censored
from xgbsurv.evaluation import cindex_censored, ibs

In [None]:
# train
cindex_censored(y_train, preds_train)

0.8680497584219261

In [None]:
# test
cindex_censored(y_test, preds_test)

0.629672807399626