# Introduction to the xgbsurv package

This notebook introduces `xgbsurv` using the METABRIC dataset. It is structured in the following steps:

- Load data
- Load model
- Fit model
- Predict and evaluate model

The API follows scikit-learn conventions: the model is configured in its constructor and then trained and applied with `fit` and `predict`.

In [1]:
from xgbsurv.datasets import load_metabric
from xgbsurv.models.utils import sort_X_y
from xgbsurv import XGBSurv
from sklearn.model_selection import train_test_split
import numpy as np
#%load_ext autoreload
#%autoreload 2


## Load Data

In [2]:
data, target = load_metabric(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=False, return_X_y=True)
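# Stratify the split on the sign of the target, which encodes event vs. censoring
target_sign = np.sign(target)
X_train, X_test, y_train, y_test = train_test_split(data, target, stratify=target_sign)
# Sort the samples by (absolute) survival time
X_train, y_train = sort_X_y(X_train, y_train)
X_test, y_test = sort_X_y(X_test, y_test)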
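Stratifying on `np.sign(target)` makes sense because the target packs time and event status into one signed array. The sketch below illustrates that encoding and the sort step under an assumption: positive values are observed event times and negative values are censored times (the convention of XGBoost's `survival:cox`, and consistent with the ascending absolute values in the `y_train` output later in this notebook). The helpers `to_signed_target`, `from_signed_target` and `sort_by_time` are hypothetical and are not the source of `xgbsurv`'s `sort_X_y` or `transform_back`.

```python
import numpy as np

# Hypothetical helpers illustrating an ASSUMED signed-target convention:
# +time = event observed, -time = censored. Not xgbsurv's actual code.

def to_signed_target(time, event):
    """Encode (time, event) as one float array: +time for events, -time for censored."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    return np.where(event, time, -time)

def from_signed_target(y):
    """Recover (time, event) from the sign of the encoded target."""
    y = np.asarray(y, dtype=float)
    return np.abs(y), (y > 0).astype(int)

def sort_by_time(X, y):
    """Sort rows of X and y by ascending absolute survival time."""
    order = np.argsort(np.abs(y), kind="stable")
    return X[order], y[order]

y = to_signed_target([5.0, 2.0, 7.0], [1, 0, 1])   # array([ 5., -2.,  7.])
X = np.arange(6).reshape(3, 2)
X_sorted, y_sorted = sort_by_time(X, y)            # y_sorted: [-2., 5., 7.]
```

One limitation of this encoding is that a survival time of exactly zero cannot carry a sign, so such records are ambiguous.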

## Load Model

In [3]:
model = XGBSurv(n_estimators=10,
                objective="cind_objective",
                eval_metric="cind_loss",
                learning_rate=0.1,
                random_state=7,
                disable_default_eval_metric=True)

The available loss and objective functions can be listed as follows:

In [4]:
print(model.get_loss_functions().keys())
print(model.get_objective_functions().keys())

dict_keys(['breslow_loss', 'efron_loss', 'cind_loss', 'deephit_loss', 'aft_loss', 'ah_loss', 'eh_loss'])
dict_keys(['breslow_objective', 'efron_objective', 'cind_objective', 'deephit_objective', 'aft_objective', 'ah_objective', 'eh_objective'])
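The two lists pair each loss with an objective of the same name. In XGBoost, a custom objective returns the per-sample gradient and Hessian of the loss with respect to the raw predictions. As a minimal sketch, assuming `breslow_objective` corresponds to the negative Cox partial log-likelihood with Breslow handling of ties (as in XGBoost's `survival:cox`), the two quantities can be computed as below. The function names are hypothetical and this is not the `xgbsurv` implementation.

```python
import numpy as np

# Hypothetical sketch of a Breslow (Cox) loss and its derivatives; not xgbsurv's code.

def _risk_set(time):
    """at_risk[i, j] = 1.0 when subject j is still at risk at time[i]."""
    time = np.asarray(time, dtype=float)
    return (time[None, :] >= time[:, None]).astype(float)

def breslow_loss(preds, time, event):
    """Negative Cox partial log-likelihood, Breslow ties."""
    f = np.asarray(preds, dtype=float)
    event = np.asarray(event, dtype=float)
    risk_sum = _risk_set(time) @ np.exp(f)        # S_i = sum_{j: t_j >= t_i} exp(f_j)
    return -np.sum(event * (f - np.log(risk_sum)))

def breslow_grad_hess(preds, time, event):
    """Gradient and diagonal Hessian of breslow_loss w.r.t. the predictions."""
    f = np.asarray(preds, dtype=float)
    event = np.asarray(event, dtype=float)
    exp_f = np.exp(f)
    at_risk = _risk_set(time)
    risk_sum = at_risk @ exp_f
    # a_m = sum over events i with t_i <= t_m of 1 / S_i
    a = at_risk.T @ (event / risk_sum)
    b = at_risk.T @ (event / risk_sum**2)
    grad = exp_f * a - event
    hess = exp_f * a - exp_f**2 * b
    return grad, hess
```

The dense n x n risk-set matrix keeps the sketch short but costs O(n^2) memory; with samples sorted by time, the same sums can be computed with cumulative sums instead.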


## Fit Model

In [5]:
# Monitor the loss on the training data during boosting
eval_set = [(X_train, y_train)]

In [6]:
model.fit(X_train, y_train, eval_set=eval_set)

[0]	validation_0-cind_loss:-416.31332
[1]	validation_0-cind_loss:-419.14242
[2]	validation_0-cind_loss:-421.98425
[3]	validation_0-cind_loss:-424.89049
[4]	validation_0-cind_loss:-427.80346
[5]	validation_0-cind_loss:-430.69999
[6]	validation_0-cind_loss:-433.59587
[7]	validation_0-cind_loss:-436.48752
[8]	validation_0-cind_loss:-439.37142
[9]	validation_0-cind_loss:-442.26470


The model can be saved as shown below. Note that the objective and eval_metric are not saved with the model.

In [8]:
model.save_model("cind_model.json")

## Predict

In [9]:
preds_train = model.predict(X_train)
preds_test = model.predict(X_test)

### Predict Cumulative Hazard
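For Cox-type objectives such as `breslow_objective` or `efron_objective`, a common way to turn risk scores into cumulative hazards is the Breslow estimator: estimate the baseline cumulative hazard $H_0(t)$ on the training data, then scale it by $\exp(f(x))$ for each new sample. The sketch below assumes the predictions are log partial hazards, which may not hold for the `cind_objective` model fitted above; `breslow_baseline_cumhaz` and `predict_cumhaz` are hypothetical helpers, not part of the `xgbsurv` API.

```python
import numpy as np

# Hypothetical Breslow estimator sketch; assumes predictions are log partial hazards.

def breslow_baseline_cumhaz(time, event, log_risk):
    """Return sorted unique event times and the baseline cumulative hazard H0 at each."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    risk = np.exp(np.asarray(log_risk, dtype=float))
    event_times = np.unique(time[event])
    increments = np.empty(len(event_times))
    for k, t in enumerate(event_times):
        d = np.sum(event & (time == t))      # number of events at time t
        at_risk = risk[time >= t].sum()      # sum of exp(f) over the risk set at t
        increments[k] = d / at_risk
    return event_times, np.cumsum(increments)

def predict_cumhaz(event_times, h0, log_risk_new, grid):
    """H(t | x) = H0(t) * exp(f(x)), with H0 treated as a right-continuous step function."""
    idx = np.searchsorted(event_times, grid, side="right") - 1
    h0_grid = np.where(idx >= 0, h0[np.clip(idx, 0, None)], 0.0)   # H0 = 0 before the first event
    return np.exp(np.asarray(log_risk_new, dtype=float))[:, None] * h0_grid[None, :]
```

With all log risks equal to zero, the Breslow estimate reduces to the Nelson-Aalen estimator, which is a convenient sanity check.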

## Evaluate

In [10]:
from xgbsurv.evaluation import cindex_censored, ibs
from sksurv.metrics import concordance_index_censored
from xgbsurv.models.utils import transform_back
np.set_printoptions(suppress=True)

In [11]:
event_time, event_indicator = transform_back(y_test)
concordance_index_censored(event_indicator.astype('bool'), event_time, preds_test)[0]

0.36459058296531566

In [12]:
# train
cindex_censored(y_train, preds_train)

0.2836651503851431

In [13]:
# test
cindex_censored(y_test, preds_test)

0.36459058296531566
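Both `concordance_index_censored` and `cindex_censored` report Harrell's concordance index, and they agree on the test set above. To make the number concrete, here is a minimal from-scratch sketch (the function `harrell_cindex` is hypothetical): a pair is comparable when the subject with the shorter time had an event, and it is concordant when that subject also has the higher risk score, with tied scores counting as one half. This simple version skips pairs with tied times, whereas scikit-survival treats some tied-time pairs as comparable, so results can differ slightly on data with tied times.

```python
import numpy as np

# Hypothetical from-scratch Harrell's C-index (tied times are not treated as comparable).

def harrell_cindex(event, time, risk):
    """Fraction of comparable pairs whose risk scores are ordered correctly."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    concordant = 0.0
    comparable = 0
    for i in np.flatnonzero(event):
        later = time > time[i]               # subjects known to survive longer than i
        n = int(later.sum())
        if n == 0:
            continue
        comparable += n
        # higher risk for the earlier event is concordant; tied risk counts as 0.5
        concordant += np.sum(risk[i] > risk[later]) + 0.5 * np.sum(risk[i] == risk[later])
    return concordant / comparable if comparable else float("nan")

c = harrell_cindex([1, 1, 0, 1], [1.0, 2.0, 3.0, 4.0], [0.9, 0.7, 0.5, 0.1])
```

Because the comparable pairs depend only on times and events, negating the scores swaps concordant and discordant pairs, so the C-index of `-risk` equals one minus the C-index of `risk`. A value below 0.5 therefore means the scores rank pairs in the opposite order more often than not.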

In [14]:
y_train

array([   0.1       ,   -0.76666665,   -1.2333333 , ..., -337.03333   ,
        351.        ,  355.2       ], dtype=float32)

In [15]:
y_test.shape

(476,)

In [16]:
preds_train

array([0.48590532, 0.54625905, 0.4864021 , ..., 0.5182941 , 0.50174665,
       0.5532081 ], dtype=float32)

In [17]:
preds_test

array([0.4605856 , 0.49737754, 0.48590532, 0.5315069 , 0.47892365,
       0.5182941 , 0.5315069 , 0.49737754, 0.51430863, 0.46143258,
       0.49352917, 0.50174665, 0.4605856 , 0.47892365, 0.51317614,
       0.5182941 , 0.47434598, 0.49737754, 0.5209666 , 0.4884806 ,
       0.4745113 , 0.47434598, 0.5300912 , 0.50174665, 0.4605856 ,
       0.49737754, 0.48590532, 0.55315506, 0.51430863, 0.4993268 ,
       0.4884806 , 0.4888425 , 0.5206782 , 0.46143258, 0.50174665,
       0.47434598, 0.48590532, 0.50174665, 0.48590532, 0.49737754,
       0.4888425 , 0.4884806 , 0.48118225, 0.47434598, 0.50174665,
       0.49737754, 0.50174665, 0.5209666 , 0.47892365, 0.54625905,
       0.47434598, 0.48590532, 0.48590532, 0.46143258, 0.4888425 ,
       0.54625905, 0.46143258, 0.5532081 , 0.5300912 , 0.4888425 ,
       0.49737754, 0.47892365, 0.49737754, 0.47892365, 0.4884806 ,
       0.48590532, 0.4884806 , 0.5300912 , 0.50174665, 0.4993268 ,
       0.49737754, 0.48590532, 0.5182941 , 0.49737754, 0.48590