## XGBSEDebiasedBCE and XGBSEStackedWeibull with pre-trained XGB models for the first step

In [1]:
import numpy as np
import pandas as pd
import xgboost as xgb

# lib utils
from xgbse.converters import convert_data_to_xgb_format

from xgbse import (
    XGBSEDebiasedBCE,
    XGBSEStackedWeibull
)

from xgbse.metrics import concordance_index


from tests.data import get_data
(
    X_train,
    X_test,
    X_valid,
    T_train,
    T_test,
    T_valid,
    E_train,
    E_test,
    E_valid,
    y_train,
    y_test,
    y_valid,
    features,
) = get_data()

In [2]:
y = np.concatenate((y_train, y_test, y_valid))
X = np.concatenate((X_train, X_test, X_valid))

In [3]:
DEFAULT_PARAMS = {
    "objective": "survival:aft",
    "eval_metric": "aft-nloglik",
    "aft_loss_distribution": "normal",
    "aft_loss_distribution_scale": 1,
    "tree_method": "hist",
    "learning_rate": 5e-2,
    "max_depth": 8,
    "booster": "dart",
    "subsample": 0.5,
    "min_child_weight": 50,
    "colsample_bynode": 0.5,
}


In [4]:
num_boost_round=1000
validation_data=None
early_stopping_rounds=None
verbose_eval=0
time_bins = None

### Pre train XGB Model

In [5]:
dtrain = convert_data_to_xgb_format(X, y, DEFAULT_PARAMS["objective"])

# converting validation data to xgb format
evals = ()
# training XGB
bst = xgb.train(
    DEFAULT_PARAMS,
    dtrain,
    num_boost_round=num_boost_round,
    early_stopping_rounds=early_stopping_rounds,
    evals=evals,
    verbose_eval=verbose_eval,
)

### Fit XGBSE Debiased BCE with pre-trained XGB as 1st step model

In [6]:
xgbse_bce_pre = XGBSEDebiasedBCE()

xgbse_bce_pre.fit(
    X_train,
    y_train,
    num_boost_round=1000,
    validation_data=(X_valid, y_valid),
    early_stopping_rounds=10,
    verbose_eval=100,
    pre_fitted_xgb_model=[bst, DEFAULT_PARAMS]
)

preds_bce_pre = xgbse_bce_pre.predict(X_test)
cindex_bce_pre = concordance_index(y_test, preds_bce_pre)

### Fit XGBSE Stacked Weibull with pre-trained XGB as 1st step model

In [7]:
xgbse_stacked_weibull_pre = XGBSEStackedWeibull()

xgbse_stacked_weibull_pre.fit(
    X_train,
    y_train,
    num_boost_round=1000,
    validation_data=(X_valid, y_valid),
    early_stopping_rounds=10,
    verbose_eval=100,
    pre_fitted_xgb_model=[bst, DEFAULT_PARAMS]
)

preds_stacked_weibull_pre = xgbse_stacked_weibull_pre.predict(X_test)
cindex_stacked_weibull_pre = concordance_index(y_test, preds_stacked_weibull_pre)

### Display c-indexes and predictions for both

In [8]:
print(f"C-index for BCE - Pretrained XGB: {cindex_bce_pre}\n",
      f"\nC-index for Stacked Weibull - Pretrained XGB: {cindex_stacked_weibull_pre}")

C-index for BCE - Pretrained XGB: 0.6631950573698147
 
C-index for Stacked Weibull - Pretrained XGB: 0.7410414827890556


In [9]:
preds_bce_pre.head()

Unnamed: 0,73,289,506,722,939,1155,1372,1588,1805,2021,2238,2455
0,0.997565,0.97284,0.88672,0.779075,0.693134,0.661225,0.592907,0.563684,0.530866,0.450125,0.369603,0.280199
1,0.997751,0.968089,0.907702,0.818664,0.789311,0.739369,0.68538,0.655571,0.617229,0.536732,0.429817,0.34588
2,0.99766,0.982277,0.946258,0.899952,0.868477,0.813466,0.778464,0.748397,0.704141,0.602144,0.528592,0.428014
3,0.997707,0.983928,0.951476,0.886026,0.814765,0.775998,0.730448,0.702448,0.661934,0.521963,0.467084,0.347061
4,0.997877,0.952951,0.879745,0.706452,0.543965,0.504007,0.447338,0.428935,0.411937,0.299799,0.258493,0.198051


In [10]:
preds_stacked_weibull_pre.head()

Unnamed: 0,73.0,289.0,506.0,722.0,939.0,1155.0,1372.0,1588.0,1805.0,2021.0,2238.0,2455.0
0,0.987672,0.908366,0.80159,0.687057,0.574111,0.469942,0.376953,0.297382,0.230533,0.176204,0.132573,0.098379
1,0.996325,0.969895,0.930209,0.882503,0.829234,0.773002,0.714954,0.656947,0.599509,0.544004,0.490563,0.439925
2,0.999998,0.999975,0.999932,0.999869,0.999789,0.999692,0.999578,0.999449,0.999304,0.999145,0.998971,0.998782
3,0.99845,0.986585,0.967929,0.944575,0.917401,0.887454,0.8551,0.821178,0.785848,0.749847,0.713207,0.676416
4,0.983862,0.883301,0.75297,0.619043,0.493156,0.383003,0.290108,0.21531,0.156392,0.11165,0.078183,0.053832


## Fit models without pre-trained XGB as 1st step

* Assert common lib behavior is maintained

In [11]:
xgbse_bce = XGBSEDebiasedBCE()
xgbse_stacked_weibull = XGBSEStackedWeibull()

In [12]:
xgbse_bce.fit(
    X_train,
    y_train,
    num_boost_round=1000,
    validation_data=(X_valid, y_valid),
    early_stopping_rounds=10,
    verbose_eval=100,
)

preds_bce = xgbse_bce.predict(X_test)
cindex_bce = concordance_index(y_test, preds_bce)


[0]	validation-aft-nloglik:27.02227
[88]	validation-aft-nloglik:4.08169


In [13]:
xgbse_stacked_weibull.fit(
    X_train,
    y_train,
    num_boost_round=1000,
    validation_data=(X_valid, y_valid),
    early_stopping_rounds=10,
    verbose_eval=100,
)

preds_stacked_weibull = xgbse_stacked_weibull.predict(X_test)
cindex_stacked_weibull = concordance_index(y_test, preds_stacked_weibull)

[0]	validation-aft-nloglik:27.02227
[88]	validation-aft-nloglik:4.08169


### Display c-indexes and predictions for both

In [14]:
print(f"C-index for BCE - Pretrained XGB: {cindex_bce}\n",
      f"\nC-index for Stacked Weibull - Pretrained XGB: {cindex_stacked_weibull}")

C-index for BCE - Pretrained XGB: 0.6613415710503089
 
C-index for Stacked Weibull - Pretrained XGB: 0.6629302736098852


In [15]:
preds_bce.head()

Unnamed: 0,73,289,506,722,939,1155,1372,1588,1805,2021,2238,2455
0,0.997637,0.946927,0.857662,0.750377,0.684929,0.643229,0.580754,0.554754,0.524208,0.430345,0.363963,0.282642
1,0.997646,0.948405,0.865267,0.758653,0.695343,0.652587,0.590367,0.563944,0.532798,0.436037,0.368939,0.286862
2,0.997681,0.955166,0.889802,0.796541,0.734185,0.688963,0.628619,0.601607,0.569069,0.463767,0.394987,0.307775
3,0.997681,0.955403,0.890025,0.795736,0.732476,0.68745,0.62719,0.600195,0.567793,0.462052,0.393711,0.306483
4,0.99764,0.933002,0.847715,0.730012,0.648441,0.602726,0.538457,0.514436,0.488051,0.389711,0.332585,0.257718


In [16]:
preds_stacked_weibull.head()

Unnamed: 0,73.0,289.0,506.0,722.0,939.0,1155.0,1372.0,1588.0,1805.0,2021.0,2238.0,2455.0
0,0.990213,0.934623,0.862261,0.783593,0.702988,0.624368,0.54908,0.479121,0.414601,0.356466,0.30423,0.258021
1,0.993499,0.955147,0.903443,0.845271,0.783486,0.720896,0.658525,0.598103,0.539907,0.48505,0.433405,0.385459
2,0.998634,0.989566,0.976247,0.960123,0.941716,0.921645,0.900064,0.877447,0.853814,0.82958,0.804693,0.779418
3,0.997189,0.979457,0.954265,0.924556,0.891469,0.856265,0.819347,0.781632,0.743246,0.704932,0.666667,0.628903
4,0.986045,0.909683,0.814024,0.71404,0.615791,0.524121,0.440353,0.366244,0.301308,0.245829,0.19865,0.159224
