In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import xgboost as xgb



from sklearn.pipeline import Pipeline
from sklearn.model_selection import KFold, train_test_split



from derma.general.preprocessing.transformers import (TransformToNumeric, 
                                                      TransformToDatetime, 
                                                      ComputeAge,
                                                      TransformToObject,
                                                      KeepColumns,
                                                      ComputeAJCC,
                                                      LinkTumourPartToParent,
                                                      TransformCbRegression,
                                                      ConvertCategoriesToNaN,
                                                      ExponentialTransformer,
                                                      RenameLabValues,
                                                      CustomScaler,
                                                      CustomImputer,
                                                      TransformNodMets)
from derma.general.preprocessing.encoders import (OrdinalEncoder,
                                                  GenderEncoder,
                                                  AbsentPresentEncoder,
                                                  LABEncoder,
                                                  CategoricalEncoder)
from derma.general.ingestion.data_loader_csv import SurvivalLoader
from derma.sol.survival.notebooks import config_os as settings_file
path = '/home/carlos.hernandez/datasets/csvs/data-surv_20220302.csv'
X, time_xgboost, time, event = SurvivalLoader(target='os').load_data(path)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self[k1] = value[k2]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  target['time'] = (target['date'] - target['cb_examined_at']


In [2]:

pipe = Pipeline(steps=[
    ('TransformToNumeric', TransformToNumeric(**settings_file.transform_to_numeric)), 
    ('TransformToDatetime', TransformToDatetime(**settings_file.transform_to_datetime)),
    ('TransformToObject', TransformToObject(**settings_file.transform_to_object)),
    ('ComputeAge', ComputeAge(**settings_file.compute_age)),
    ('tr_tm', LinkTumourPartToParent(**settings_file.link_tumour_part_to_parent)),
    ('tr_cb', TransformCbRegression(**settings_file.transform_cb_regression)),
    ('tr0', ConvertCategoriesToNaN(**settings_file.convert_categories_to_nan)),
    ('tr2', GenderEncoder(**settings_file.gender_encoder)),
    ('tr3', AbsentPresentEncoder(**settings_file.absent_present_encoder)),
    ("tr4", CategoricalEncoder(**settings_file.categorical_encoder)),
    ('tr7', LABEncoder(**settings_file.lab_encoder)),
    ('OrdinalEncoder', OrdinalEncoder(**settings_file.ordinal_encoder)),
    ('ComputeAJCC', ComputeAJCC(**settings_file.compute_ajcc)),
    ('tr5', ExponentialTransformer(**settings_file.exponential_transformer)),
#    ('TransformNodMets', TransformNodMets(**settings_file.transform_nod_mets)),
    ('KeepColumns', KeepColumns(**settings_file.keep_cols)),
    ('CustomImputer', CustomImputer(strategy='mean')),
    ('CustomScaler', CustomScaler()),
    ('RenameLabValues', RenameLabValues(**settings_file.rename_lab_values)),
    ])

target = pd.concat([time, event], axis=1)
y = target
def create_target(row):
    if row['event'] == 0:
        return -row['time']
    return row['time'] 

splits = []
labels = []
for i in range(5):

    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                            test_size=0.2, random_state=i)
    X_test, X_val, y_test, y_val = train_test_split(X_test, y_test,
                                            test_size=0.5, random_state=i)
    
    X_train_pre = pipe.fit_transform(X_train.copy(), y_train)
    
    X_val_pre   = pipe.transform(X_val.copy())
    X_test_pre  = pipe.transform(X_test.copy())
    
    
    splits.append([X_train, X_val_pre, X_test_pre, X_train_pre])
    labels.append([y_train, y_val, y_test])
    break

### From here onwards we leave XXMM behind and dive into the specifics of XGBoost

In [3]:
import xgboost as xgb
from xgbse.metrics import concordance_index
from xgbse.converters import (
    convert_data_to_xgb_format, # <- it requires specifc format
    convert_to_structured
)


# conver labels
y_val = convert_to_structured(y_val['time'], y_val['event'])
y_train = convert_to_structured(y_train['time'], y_train['event'])
y_test = convert_to_structured(y_test['time'], y_test['event'])

# and data
dtrain = convert_data_to_xgb_format(X_train_pre, y_train, 'survival:aft')
dval = convert_data_to_xgb_format(X_val_pre, y_val, 'survival:aft')

# Instantiate some hyperparams
PARAMS_XGB_AFT = {
    'objective': 'survival:aft', # <- we could also use survival:cox
    'eval_metric': 'aft-nloglik',
    'aft_loss_distribution': 'normal',
    'aft_loss_distribution_scale': 1.0,
    'tree_method': 'hist', 
    'learning_rate': 5e-2, 
    'max_depth': 8, 
    'booster':'dart',
    'subsample':0.5,
    'min_child_weight': 50,
    'colsample_bynode':0.5
}


bst = xgb.train(
        PARAMS_XGB_AFT,
        dtrain,
        num_boost_round=200,
        early_stopping_rounds=10,
        evals=[(dval, 'val')],
        verbose_eval=0
    )

In [6]:
dval = convert_data_to_xgb_format(X_val_pre, y_val, 'survival:aft')
dtest = convert_data_to_xgb_format(X_test_pre, y_test, 'survival:aft')

preds = bst.predict(dtest)
c_index = concordance_index(y_test, -preds, risk_strategy='precomputed')
print(f'Hooray we got {round(c_index,3)} of concordance index')

Hooray we got 0.826 of concordance index


### Now its SurvTIME! (hehe)