# Introduction to the xgbsurv package

This notebook introduces `xgbsurv` using the METABRIC dataset. It is structured in the following steps:

- Load data
- Load model
- Fit model
- Predict and evaluate model

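The API follows scikit-learn conventions: the model is configured in its constructor, trained with `fit`, and its hyperparameters can be inspected with `get_params`.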

In [1]:
from xgbsurv.datasets import load_metabric
from xgbsurv.models.eh_final import eh_likelihood #, eh_objective
from xgbsurv import XGBSurv
import numpy as np
%load_ext autoreload
%autoreload 2


## Load Data

In [2]:
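# Adjust `path` to the location of the xgbsurv data directory on your machine.
data = load_metabric(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=False)
n = data.target.shape[0]
# Duplicate the target into two identical columns, one per model output.
target2 = np.tile(data.target.reshape(n, 1), (1, 2))
target2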

array([[ 1.0000000e-01,  1.0000000e-01],
       [-7.6666665e-01, -7.6666665e-01],
       [-1.2333333e+00, -1.2333333e+00],
       ...,
       [-3.3703333e+02, -3.3703333e+02],
       [ 3.5100000e+02,  3.5100000e+02],
       [ 3.5520001e+02,  3.5520001e+02]], dtype=float32)
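`data.target` packs survival time and censoring status into a single signed value. Judging from the negative entries in the output above, censored observations appear to be stored as negative times. The sketch below assumes that convention, uses the hypothetical helpers `encode_target` and `decode_target` (not part of `xgbsurv`), and repeats the tiling step that turns the target into two identical columns.

```python
import numpy as np

def encode_target(time, event):
    """Pack (time, event) into one signed array: event -> +time, censored -> -time.
    Hypothetical helper; assumes strictly positive times (a censored time of 0 would be ambiguous)."""
    time = np.asarray(time, dtype=np.float32)
    event = np.asarray(event, dtype=bool)
    return np.where(event, time, -time)

def decode_target(y):
    """Recover (time, event) from a signed target."""
    y = np.asarray(y)
    return np.abs(y), y > 0

time = np.array([0.1, 0.77, 1.23])
event = np.array([True, False, False])
y = encode_target(time, event)

# Duplicate the signed target into two identical columns, as done for `target2`.
y2 = np.tile(y.reshape(-1, 1), (1, 2))
print(y2.shape)  # (3, 2)
```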

In [3]:
# test run likelihood
#eh_likelihood(data.target, target2)

## Load Model

In [4]:
model = XGBSurv(n_estimators=2000, objective="eh_objective",
                eval_metric="eh_loss",
                learning_rate=0.01,
                random_state=7, disable_default_metric=True, base_score=0.3)
model.get_params()

{'eval_metric': CPUDispatcher(<function eh_likelihood at 0x7fceca8db130>),
 'objective': CPUDispatcher(<function eh_objective at 0x7fceca8dbb50>),
 'base_score': 0.3,
 'booster': None,
 'callbacks': None,
 'colsample_bylevel': None,
 'colsample_bynode': None,
 'colsample_bytree': None,
 'early_stopping_rounds': None,
 'enable_categorical': False,
 'feature_types': None,
 'gamma': None,
 'gpu_id': None,
 'grow_policy': None,
 'importance_type': None,
 'interaction_constraints': None,
 'learning_rate': 0.01,
 'max_bin': None,
 'max_cat_threshold': None,
 'max_cat_to_onehot': None,
 'max_delta_step': None,
 'max_depth': None,
 'max_leaves': None,
 'min_child_weight': None,
 'missing': nan,
 'monotone_constraints': None,
 'n_estimators': 2000,
 'n_jobs': None,
 'num_parallel_tree': None,
 'predictor': None,
 'random_state': 7,
 'reg_alpha': None,
 'reg_lambda': None,
 'sampling_method': None,
 'scale_pos_weight': None,
 'subsample': None,
 'tree_method': None,
 'validate_parameters': None,

The available loss and objective functions can be listed as follows:

In [5]:
print(model.get_loss_functions().keys())
print(model.get_objective_functions().keys())

dict_keys(['breslow_loss', 'efron_loss', 'cind_loss', 'deephit_loss', 'aft_loss', 'ah_loss', 'eh_loss'])
dict_keys(['breslow_objective', 'efron_objective', 'cind_objective', 'deephit_objective', 'aft_objective', 'ah_objective', 'eh_objective'])
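Every objective has a loss with the same prefix, and the model above pairs `eh_objective` with `eh_loss`. A small sketch that builds such pairs from the printed keys, so that `objective` and `eval_metric` can be chosen consistently. The prefix-matching rule is an assumption drawn from the names, and `pair_objectives_with_losses` is a hypothetical helper.

```python
# Key lists as printed by get_loss_functions() / get_objective_functions().
losses = ["breslow_loss", "efron_loss", "cind_loss", "deephit_loss",
          "aft_loss", "ah_loss", "eh_loss"]
objectives = ["breslow_objective", "efron_objective", "cind_objective",
              "deephit_objective", "aft_objective", "ah_objective", "eh_objective"]

def pair_objectives_with_losses(objectives, losses):
    """Map each '<name>_objective' to '<name>_loss' when that loss exists."""
    suffix = "_objective"
    available = set(losses)
    pairs = {}
    for obj in objectives:
        if obj.endswith(suffix):
            loss = obj[: -len(suffix)] + "_loss"
            if loss in available:
                pairs[obj] = loss
    return pairs

pairs = pair_objectives_with_losses(objectives, losses)
print(pairs["eh_objective"])  # eh_loss
```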


## Fit Model

In [6]:
eval_set = [(data.data, target2)]


In [7]:
model.fit(data.data, target2, eval_set=eval_set)

Parameters: { "disable_default_metric" } are not used.

y shape (3806,)
(1903, 2)
-likelihood 3.1140353835665406
[0]	validation_0-rmse:146.53337	validation_0-eh_likelihood:3.11403
y shape (3806,)
(1903, 2)
-likelihood 3.114034866679437
[1]	validation_0-rmse:146.53337	validation_0-eh_likelihood:3.11403
y shape (3806,)
(1903, 2)
-likelihood 3.1140343497937337
[2]	validation_0-rmse:146.53337	validation_0-eh_likelihood:3.11403
y shape (3806,)
(1903, 2)
-likelihood 3.114033832909432
[3]	validation_0-rmse:146.53337	validation_0-eh_likelihood:3.11403
y shape (3806,)
(1903, 2)
-likelihood 3.114033316026527
[4]	validation_0-rmse:146.53337	validation_0-eh_likelihood:3.11403
y shape (3806,)
(1903, 2)
-likelihood 3.1140327991450247
[5]	validation_0-rmse:146.53337	validation_0-eh_likelihood:3.11403
y shape (3806,)
(1903, 2)
-likelihood 3.114032282264922
[6]	validation_0-rmse:146.53337	validation_0-eh_likelihood:3.11403
y shape (3806,)
(1903, 2)
-likelihood 3.114031767574591
[7]	validation_0-rmse:14

The model can be saved as shown below. Note that the custom objective and eval_metric are not saved with it.

In [10]:
model.save_model("eh_model.json")
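Because the objective and eval_metric are not stored in `eh_model.json`, one option is to keep the constructor arguments in a separate JSON file. A minimal sketch using only the standard library. The file name `eh_model_params.json` is hypothetical, and reloading via `load_model` assumes `XGBSurv` inherits that method from XGBoost's scikit-learn wrapper. `disable_default_metric` is left out because the training log reports it as unused.

```python
import json

# Constructor arguments from the "Load Model" cell, with objective and metric given as strings.
params = {
    "n_estimators": 2000,
    "objective": "eh_objective",
    "eval_metric": "eh_loss",
    "learning_rate": 0.01,
    "random_state": 7,
    "base_score": 0.3,
}

# Hypothetical sidecar file stored next to eh_model.json.
with open("eh_model_params.json", "w") as f:
    json.dump(params, f, indent=2)

with open("eh_model_params.json") as f:
    restored = json.load(f)

# Assumed usage (requires xgbsurv):
# model = XGBSurv(**restored)
# model.load_model("eh_model.json")
```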

## Predict

In [13]:
preds = model.predict(np.nan_to_num(data.data), output_margin=True, validate_features=False)
preds

array([[0.2997616 , 0.30381492],
       [0.3010133 , 0.2983907 ],
       [0.2997616 , 0.30381492],
       ...,
       [0.2997616 , 0.2983907 ],
       [0.2997616 , 0.2983907 ],
       [0.2997616 , 0.2983907 ]], dtype=float32)
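With `output_margin=True` the model returns two raw scores per patient, both close to `base_score=0.3`. These are presumably the two linear predictors $\eta_1(x)$ and $\eta_2(x)$ of the extended hazards (EH) model, whose standard form with baseline hazard $\lambda_0$ and baseline cumulative hazard $\Lambda_0$ is given below. Which column of `preds` corresponds to which predictor is an assumption here.

$$
\lambda(t \mid x) = \lambda_0\!\left(t\, e^{\eta_1(x)}\right) e^{\eta_2(x)},
\qquad
\Lambda(t \mid x) = \Lambda_0\!\left(t\, e^{\eta_1(x)}\right) e^{\eta_2(x) - \eta_1(x)}
$$

Setting $\eta_1 = 0$ gives a proportional hazards model; setting $\eta_1 = \eta_2$ gives an accelerated failure time model.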

### Predict Cumulative Hazard

In [None]:
df_cum_hazards = model.predict_cumulative_hazard_function(data, dataframe=True)
df_cum_hazards.head(5)

denominator 1.4150272607803345
(1685,) (1685,)


|  | time | patient_0 | patient_1 | patient_2 | patient_3 | patient_4 | patient_5 | patient_6 | patient_7 | patient_8 | ... | patient_1893 | patient_1894 | patient_1895 | patient_1896 | patient_1897 | patient_1898 | patient_1899 | patient_1900 | patient_1901 | patient_1902 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.1 | 0.000589 | 0.000476 | 0.000552 | 0.000584 | 0.000437 | 0.000632 | 0.000476 | 0.000635 | 0.000481 | ... | 0.000476 | 0.000476 | 0.000491 | 0.000588 | 0.000468 | 0.000468 | 0.000493 | 0.000476 | 0.000476 | 0.000432 |
| 1 | 0.766667 | 0.001771 | 0.001431 | 0.001657 | 0.001753 | 0.001312 | 0.001897 | 0.001431 | 0.001909 | 0.001445 | ... | 0.001431 | 0.001431 | 0.001474 | 0.001767 | 0.001407 | 0.001407 | 0.001479 | 0.001431 | 0.001431 | 0.001298 |
| 2 | 1.233333 | 0.003545 | 0.002864 | 0.003316 | 0.003509 | 0.002626 | 0.003797 | 0.002864 | 0.003821 | 0.002892 | ... | 0.002864 | 0.002864 | 0.002951 | 0.003537 | 0.002816 | 0.002816 | 0.002962 | 0.002864 | 0.002864 | 0.002598 |
| 3 | 1.266667 | 0.005912 | 0.004776 | 0.005531 | 0.005852 | 0.004379 | 0.006333 | 0.004776 | 0.006373 | 0.004824 | ... | 0.004776 | 0.004776 | 0.004923 | 0.0059 | 0.004698 | 0.004698 | 0.00494 | 0.004776 | 0.004776 | 0.004334 |
| 4 | 1.433333 | 0.008874 | 0.007169 | 0.008302 | 0.008784 | 0.006573 | 0.009506 | 0.007169 | 0.009565 | 0.00724 | ... | 0.007169 | 0.007169 | 0.007389 | 0.008855 | 0.007051 | 0.007051 | 0.007414 | 0.007169 | 0.007169 | 0.006505 |
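The table holds one cumulative hazard curve per patient. Under the extended hazards model with predictors $\eta_1$, $\eta_2$ and baseline cumulative hazard $\Lambda_0$, the patient-level curve is $\Lambda(t \mid x) = \Lambda_0(t e^{\eta_1}) e^{\eta_2 - \eta_1}$. The sketch below evaluates that formula for a given step-function baseline. It only illustrates the model form: the helper `eh_cumulative_hazard` is hypothetical, and neither the baseline estimator used by `xgbsurv` nor the mapping of `preds` columns to $\eta_1$/$\eta_2$ is assumed to match.

```python
import numpy as np

def eh_cumulative_hazard(times, base_times, base_cumhaz, eta1, eta2):
    """Evaluate Lambda(t | x) = Lambda0(t * exp(eta1)) * exp(eta2 - eta1).

    `base_times` (sorted) and `base_cumhaz` define a right-continuous step
    function Lambda0 that is 0 before the first base time. `eta1` and `eta2`
    hold one value per patient. Returns an array of shape (n_times, n_patients).
    """
    times = np.asarray(times, dtype=float)
    base_times = np.asarray(base_times, dtype=float)
    base_cumhaz = np.asarray(base_cumhaz, dtype=float)
    eta1 = np.asarray(eta1, dtype=float)
    eta2 = np.asarray(eta2, dtype=float)

    # Accelerated time t * exp(eta1) for every (time, patient) pair.
    scaled = times[:, None] * np.exp(eta1)[None, :]
    # Index of the last base time <= scaled time (-1 if there is none).
    idx = np.searchsorted(base_times, scaled, side="right") - 1
    cumhaz0 = np.where(idx >= 0, base_cumhaz[np.clip(idx, 0, None)], 0.0)
    # Patient-specific multiplier exp(eta2 - eta1).
    return cumhaz0 * np.exp(eta2 - eta1)[None, :]

# Toy baseline and three patients: baseline, AFT-like (eta1 == eta2), PH-like (eta1 == 0).
H = eh_cumulative_hazard(
    times=[0.5, 1.0, 2.5, 3.0],
    base_times=[1.0, 2.0, 3.0],
    base_cumhaz=[0.1, 0.3, 0.6],
    eta1=[0.0, np.log(2), 0.0],
    eta2=[0.0, np.log(2), np.log(2)],
)
print(H)
```

In the AFT-like column the curve is reached earlier in time; in the PH-like column it is doubled at every time point.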


In [None]:
df_cum_hazards.to_csv('test.csv')

In [None]:
np.savetxt("preds.csv", preds, delimiter=",")

In [None]:
len(np.unique(np.absolute(data.target)))

1685

In [None]:
len(data.target)

1903
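The cumulative hazard appears to be evaluated on the distinct absolute observed times: 1,685 of them for 1,903 patients, consistent with the `(1685,)` shapes printed by the cumulative hazard cell. Tied times collapse into one grid point, including an event and a censoring at the same time, since taking the absolute value discards the sign. A toy illustration, assuming negative values mark censored patients:

```python
import numpy as np

# Toy signed target; negative values are assumed to mark censored patients.
y = np.array([5.0, -5.0, 3.0, 3.0, -7.5])

time_grid = np.unique(np.abs(y))  # sorted distinct times, sign (status) dropped
print(len(y), len(time_grid))     # 5 3
print(time_grid)
```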

### Predict Survival Function

In [None]:
df_survival = model.predict_survival_function(data, dataframe=True)
df_survival

|  | time | patient_0 | patient_1 | patient_2 | patient_3 | patient_4 | patient_5 | patient_6 | patient_7 | patient_8 | ... | patient_1893 | patient_1894 | patient_1895 | patient_1896 | patient_1897 | patient_1898 | patient_1899 | patient_1900 | patient_1901 | patient_1902 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.100000 | 9.994107e-01 | 9.995238e-01 | 9.994486e-01 | 9.994167e-01 | 9.995634e-01 | 9.993687e-01 | 9.995238e-01 | 9.993648e-01 | 9.995191e-01 | ... | 9.995238e-01 | 9.995238e-01 | 9.995093e-01 | 9.994119e-01 | 9.995317e-01 | 9.995317e-01 | 9.995076e-01 | 9.995238e-01 | 9.995238e-01 | 9.995680e-01 |
| 1 | 0.766667 | 9.982309e-01 | 9.985704e-01 | 9.983447e-01 | 9.982488e-01 | 9.986892e-01 | 9.981049e-01 | 9.985704e-01 | 9.980931e-01 | 9.985563e-01 | ... | 9.985704e-01 | 9.985704e-01 | 9.985267e-01 | 9.982345e-01 | 9.985940e-01 | 9.985940e-01 | 9.985216e-01 | 9.985704e-01 | 9.985704e-01 | 9.987028e-01 |
| 2 | 1.233333 | 9.964618e-01 | 9.971403e-01 | 9.966892e-01 | 9.964975e-01 | 9.973777e-01 | 9.962100e-01 | 9.971403e-01 | 9.961864e-01 | 9.971122e-01 | ... | 9.971403e-01 | 9.971403e-01 | 9.970530e-01 | 9.964690e-01 | 9.971875e-01 | 9.971875e-01 | 9.970428e-01 | 9.971403e-01 | 9.971403e-01 | 9.974050e-01 |
| 3 | 1.266667 | 9.941055e-01 | 9.952349e-01 | 9.944840e-01 | 9.941651e-01 | 9.956301e-01 | 9.936866e-01 | 9.952349e-01 | 9.936474e-01 | 9.951880e-01 | ... | 9.952349e-01 | 9.952349e-01 | 9.950896e-01 | 9.941176e-01 | 9.953134e-01 | 9.953134e-01 | 9.950726e-01 | 9.952349e-01 | 9.952349e-01 | 9.956755e-01 |
| 4 | 1.433333 | 9.911657e-01 | 9.928563e-01 | 9.917321e-01 | 9.912548e-01 | 9.934482e-01 | 9.905388e-01 | 9.928563e-01 | 9.904802e-01 | 9.927861e-01 | ... | 9.928563e-01 | 9.928563e-01 | 9.926387e-01 | 9.911838e-01 | 9.929738e-01 | 9.929738e-01 | 9.926133e-01 | 9.928563e-01 | 9.928563e-01 | 9.935161e-01 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1680 | 335.600000 | 4.656672e-217 | 1.646696e-175 | 3.954090e-203 | 7.237864e-215 | 5.566516e-161 | 1.755569e-232 | 1.646696e-175 | 6.337285e-234 | 3.116488e-177 | ... | 1.646696e-175 | 1.646696e-175 | 7.458802e-181 | 1.300179e-216 | 1.265915e-172 | 1.265915e-172 | 1.774117e-181 | 1.646696e-175 | 1.646696e-175 | 2.578949e-159 |
| 1681 | 335.733340 | 4.656672e-217 | 1.646696e-175 | 3.954090e-203 | 7.237864e-215 | 5.566516e-161 | 1.755569e-232 | 1.646696e-175 | 6.337285e-234 | 3.116488e-177 | ... | 1.646696e-175 | 1.646696e-175 | 7.458802e-181 | 1.300179e-216 | 1.265915e-172 | 1.265915e-172 | 1.774117e-181 | 1.646696e-175 | 1.646696e-175 | 2.578949e-159 |
| 1682 | 337.033330 | 4.656672e-217 | 1.646696e-175 | 3.954090e-203 | 7.237864e-215 | 5.566516e-161 | 1.755569e-232 | 1.646696e-175 | 6.337285e-234 | 3.116488e-177 | ... | 1.646696e-175 | 1.646696e-175 | 7.458802e-181 | 1.300179e-216 | 1.265915e-172 | 1.265915e-172 | 1.774117e-181 | 1.646696e-175 | 1.646696e-175 | 2.578949e-159 |
| 1683 | 351.000000 | 4.656672e-217 | 1.646696e-175 | 3.954090e-203 | 7.237864e-215 | 5.566516e-161 | 1.755569e-232 | 1.646696e-175 | 6.337285e-234 | 3.116488e-177 | ... | 1.646696e-175 | 1.646696e-175 | 7.458802e-181 | 1.300179e-216 | 1.265915e-172 | 1.265915e-172 | 1.774117e-181 | 1.646696e-175 | 1.646696e-175 | 2.578949e-159 |
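The survival table is consistent with $S(t \mid x) = \exp(-\Lambda(t \mid x))$. For `patient_0` at `time` 0.1, the cumulative hazard is 0.000589 and $\exp(-0.000589) \approx 0.999411$, which matches the value shown. A sketch that derives a survival frame from a cumulative hazard frame laid out like `df_cum_hazards` (a `time` column followed by one column per patient, as in the displayed output). `survival_from_cumhaz` is a hypothetical helper, not an `xgbsurv` function.

```python
import numpy as np
import pandas as pd

def survival_from_cumhaz(df_cumhaz: pd.DataFrame) -> pd.DataFrame:
    """Return S(t) = exp(-Lambda(t)) for every patient column, leaving `time` unchanged."""
    out = df_cumhaz.copy()
    patient_cols = [c for c in out.columns if c != "time"]
    out[patient_cols] = np.exp(-out[patient_cols])
    return out

# First two rows of patient_0 and patient_1 from the cumulative hazard output.
toy = pd.DataFrame({
    "time": [0.1, 0.766667],
    "patient_0": [0.000589, 0.001771],
    "patient_1": [0.000476, 0.001431],
})
print(survival_from_cumhaz(toy).round(6))
```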


### Visualize Predictions

In [None]:
import sys
!conda install --yes --prefix {sys.prefix} -c plotly plotly_express
!conda install --yes --prefix {sys.prefix} ipykernel
import plotly_express as px

Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/JUSC/miniconda3/envs/xgbsurv

  added / updated specs:
    - plotly_express


The following NEW packages will be INSTALLED:

  patsy              pkgs/main/osx-64::patsy-0.5.3-py310hecd8cb5_0 
  plotly             plotly/noarch::plotly-5.13.1-py_0 
  plotly_express     plotly/noarch::plotly_express-0.4.1-py_0 
  statsmodels        pkgs/main/osx-64::statsmodels-0.13.5-py310h7b7cdfe_1 
  tenacity           pkgs/main/osx-64::tenacity-8.0.1-py310hecd8cb5_1 



Downloading and Extracting Packages

Preparing transaction: done
Verifying transaction: done
Executing transaction: done



In [None]:
# cumulative hazard (can exceed 1, so no fixed y-range)
px.line(df_cum_hazards, x='time', y=df_cum_hazards.columns[1:4], title='Cumulative Hazard over Time')


In [None]:
# survival function
px.line(df_survival, x='time', y=df_survival.columns[1:4], range_y=[0, 1], title='Survival Function over Time')



In [None]:
!conda remove --yes --prefix {sys.prefix} -c plotly plotly_express

Collecting package metadata (repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/JUSC/miniconda3/envs/xgbsurv

  removed specs:
    - plotly_express


The following packages will be REMOVED:

  patsy-0.5.3-py310hecd8cb5_0
  plotly-5.13.1-py_0
  plotly_express-0.4.1-py_0
  statsmodels-0.13.5-py310h7b7cdfe_1
  tenacity-8.0.1-py310hecd8cb5_1


Preparing transaction: done
Verifying transaction: done
Executing transaction: done



## Evaluate

In [14]:
#from sksurv.metrics import concordance_index_censored
from xgbsurv.evaluation import cindex_censored, ibs

In [19]:
cindex_censored(data.target, preds[:,1])

0.659739525116822
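`cindex_censored` reports a concordance index: the fraction of comparable patient pairs whose predicted risks are ordered consistently with their observed times. A value of 0.5 corresponds to random ordering and 1.0 to perfect ordering. In Harrell's formulation, a pair is comparable when the patient with the shorter time had an event. The O(n²) sketch below assumes censored patients are stored as negative times, as the negative entries in `data.target` suggest, and that a higher score means higher risk. It breaks prediction ties with half credit and ignores tied times. It is not necessarily how `xgbsurv.evaluation.cindex_censored` handles ties or score direction.

```python
import numpy as np

def harrell_cindex(y, risk):
    """Harrell's C for a signed target (negative = censored, assumed) and risk scores."""
    y = np.asarray(y, dtype=float)
    risk = np.asarray(risk, dtype=float)
    time, event = np.abs(y), y > 0
    concordant = 0.0
    comparable = 0
    for i in range(len(y)):
        if not event[i]:
            continue  # only patients with an event can anchor a comparable pair
        later = time > time[i]  # patients still at risk after patient i's event
        comparable += int(later.sum())
        # Higher risk for the earlier event is concordant; equal risk counts half.
        concordant += np.sum(risk[i] > risk[later]) + 0.5 * np.sum(risk[i] == risk[later])
    if comparable == 0:
        return float("nan")
    return concordant / comparable

y = [1.0, 2.0, -3.0, 4.0]  # third patient censored at time 3
print(harrell_cindex(y, [4, 3, 2, 1]))  # 1.0
```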

In [17]:
np.savetxt('preds_eh.csv', preds)