In [1]:
import pandas as pd
import numpy as np
from numpy import savetxt
from xgbsurv.datasets import (load_metabric, load_flchain, load_rgbsg, load_support, load_tcga)
from xgbsurv.models.utils import sort_X_y_pandas, transform_back, transform
from xgbsurv.models.breslow_final import get_cumulative_hazard_function_breslow, breslow_estimator_loop
import torch
from torch import nn
from sklearn.metrics import make_scorer
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold, train_test_split
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import StandardScaler, LabelEncoder, LabelBinarizer, OneHotEncoder
from sklearn.compose import make_column_transformer, make_column_selector
from sklearn.decomposition import PCA
from loss_functions_pytorch import BreslowLoss, breslow_likelihood_torch
from skorch import NeuralNet
from skorch.callbacks import EarlyStopping, Callback, LRScheduler
import skorch.callbacks
from sklearn.model_selection import ShuffleSplit
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import check_cv
from numbers import Number
import torch.utils.data
from skorch.utils import flatten
from skorch.utils import is_pandas_ndframe
from skorch.utils import check_indexing
from skorch.utils import multi_indexing
from skorch.utils import to_numpy
from skorch.dataset import get_len
from skorch.dataset import ValidSplit
from pycox.evaluation import EvalSurv
from scipy.stats import uniform as scuniform
from scipy.stats import randint as scrandint
from scipy.stats import loguniform as scloguniform
import random
import os
#torch.set_default_dtype(torch.float64)
#torch.set_default_tensor_type(torch.DoubleTensor)

## Set Parameters

In [2]:
# set parameters, put into function
n_outer_splits = 5
n_inner_splits = 5
rand_state = 42
n_iter = 50
early_stopping_rounds=10
base_score = 0.0

# set seed for scipy
np.random.seed(rand_state)

param_grid_breslow = {
    'estimator__module__n_layers': [1, 2, 4],
    'estimator__module__num_nodes': [64, 128, 256, 512],
    'estimator__module__dropout': scuniform(0.0,0.7),
    'estimator__optimizer__weight_decay': [0.4, 0.2, 0.1, 0.05, 0.02, 0.01, 0],
    'estimator__batch_size': [64, 128, 256, 512, 1024],
    #lr not in paper because of learning rate finder
    # note: setting learning rate higher would make exp(partial_hazard) explode
    #'estimator__lr': scloguniform(0.001,0.01), # scheduler unten einbauen
    # use callback instead
    'estimator__lr':[0.01],
    'estimator__max_epochs':  scrandint(150,250) # corresponds to num_rounds
}

## Set Seed

In [3]:
def seed_torch(seed=rand_state):
    """Sets all seeds within torch and adjacent libraries.

    Args:
        seed: Random seed to be used by the seeding functions.

    Returns:
        None
    """
    random.seed(seed)
    #os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    #torch.backends.cudnn.benchmark = False
    #torch.backends.cudnn.deterministic = True
    return None


class FixSeed(Callback):
    def __init__(self, seed):
        self.seed = seed

    def initialize(self):
        seed_torch(self.seed)
        return super().initialize()

## Set Loss Function

In [4]:
# Define Scorer
def custom_scoring_function(y_true, y_pred):

        #y_true = torch.from_numpy(y_true)
        if isinstance(y_pred, np.ndarray):
            y_pred = torch.from_numpy(y_pred)
        if isinstance(y_true, np.ndarray):
            y_true = torch.from_numpy(y_true)
        if isinstance(y_pred, pd.Series):
            y_pred = torch.tensor(y_pred.values)
        if isinstance(y_true, pd.Series):
            y_true = torch.tensor(y_true.values)
        score = breslow_likelihood_torch(y_true, y_pred).to(torch.float32)
        return score.numpy()

scoring_function = make_scorer(custom_scoring_function, greater_is_better=False)

## Set Torch Model

In [5]:

class SurvivalModel(nn.Module):
    def __init__(self, n_layers, input_units, num_nodes, dropout, out_features):
        super(SurvivalModel, self).__init__()
        self.n_layers = n_layers
        self.in_features = input_units
        self.num_nodes = num_nodes
        self.dropout = dropout
        self.out_features = out_features
        model = []
        # first layer
        model.append(torch.nn.Linear(input_units, num_nodes))
        model.append(torch.nn.ReLU())
        model.append(torch.nn.Dropout(dropout))
        model.append(torch.nn.BatchNorm1d(num_nodes))

        for i in range(n_layers-1):
            model.append(torch.nn.Linear(num_nodes, num_nodes))
            #init.kaiming_normal_(model[-1].weight, nonlinearity='relu')
            model.append(torch.nn.ReLU())
            model.append(torch.nn.Dropout(dropout))
            model.append(torch.nn.BatchNorm1d(num_nodes))

        # output layer
        model.append(torch.nn.Linear(num_nodes, out_features))
    
        self.layers = nn.Sequential(*model)

        # for layer in self.layers:
        #     if isinstance(layer, nn.Linear):
        #         #nn.init.uniform_(layer.weight, a=-0.5, b=0.5)
        #         nn.init.kaiming_normal_(layer.weight)


    def forward(self, X):
        X = X.to(torch.float32)
        res = self.layers(X)
        #print(res)
        return res


## Set up Scaler

In [6]:
class CustomStandardScaler(StandardScaler):
    
    def __init__(self, copy=True, with_mean=True, with_std=True):
        super().__init__(copy=copy, with_mean=with_mean, with_std=with_std)
        
    def fit(self, X, y=None):
        return super().fit(X, y)
    
    def transform(self, X, y=None):
        X_transformed = super().transform(X, y)
        return X_transformed.astype(np.float32)
    
    def fit_transform(self, X, y=None):
        X_transformed = super().fit_transform(X, y)
        return X_transformed.astype(np.float32)

## Custom Split

In [7]:


# Define stratified inner k-fold cross-validation
class CustomSplit(StratifiedKFold):
    def __init__(self, n_splits=2, shuffle=True, random_state=rand_state):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def split(self, X, y, groups=None):
        print('split', X.dtypes)
        try:
            if y.shape[1]>1:
                y = y[:,0]
        except:
            pass
        bins = np.sign(y)
        return super().split(X, bins, groups=groups)

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.n_splits

outer_custom_cv = CustomSplit(n_splits=n_outer_splits, shuffle=True, random_state=rand_state)
inner_custom_cv = CustomSplit(n_splits=n_outer_splits, shuffle=True, random_state=rand_state)



## Custom Valid Split

In [8]:



class CustomStandardScaler(StandardScaler):
    
    def __init__(self, copy=True, with_mean=True, with_std=True):
        super().__init__(copy=copy, with_mean=with_mean, with_std=with_std)
        
    def fit(self, X, y=None):
        return super().fit(X, y)
    
    def transform(self, X, y=None):
        X_transformed = super().transform(X, y)
        return X_transformed.astype(np.float32)
    
    def fit_transform(self, X, y=None):
        X_transformed = super().fit_transform(X, y)
        return X_transformed.astype(np.float32)
    
class CustomValidSplit():

    def __init__(
            self,
            cv=5,
            stratified=False,
            random_state=None,
    ):
        self.stratified = stratified
        self.random_state = random_state

        if isinstance(cv, Number) and (cv <= 0):
            raise ValueError("Numbers less than 0 are not allowed for cv "
                             "but ValidSplit got {}".format(cv))

        if not self._is_float(cv) and random_state is not None:
            raise ValueError(
                "Setting a random_state has no effect since cv is not a float. "
                "You should leave random_state to its default (None), or set cv "
                "to a float value.",
            )

        self.cv = cv

    def _is_stratified(self, cv):
        return isinstance(cv, (StratifiedKFold, StratifiedShuffleSplit))

    def _is_float(self, x):
        if not isinstance(x, Number):
            return False
        return not float(x).is_integer()

    def _check_cv_float(self):
        cv_cls = StratifiedShuffleSplit if self.stratified else ShuffleSplit
        return cv_cls(test_size=self.cv, random_state=self.random_state)

    def _check_cv_non_float(self, y):
        return check_cv(
            self.cv,
            y=y,
            classifier=self.stratified,
        )

    def check_cv(self, y):
        """Resolve which cross validation strategy is used."""
        y_arr = None
        if self.stratified:
            # Try to convert y to numpy for sklearn's check_cv; if conversion
            # doesn't work, still try.
            try:
                y_arr = to_numpy(y)
            except (AttributeError, TypeError):
                y_arr = y

        if self._is_float(self.cv):
            return self._check_cv_float()
        return self._check_cv_non_float(y_arr)

    def _is_regular(self, x):
        return (x is None) or isinstance(x, np.ndarray) or is_pandas_ndframe(x)

    def __call__(self, dataset, y=None, groups=None):
        # key change here
        y = np.sign(y)
        bad_y_error = ValueError(
            "Stratified CV requires explicitly passing a suitable y.")
        if (y is None) and self.stratified:
            raise bad_y_error

        cv = self.check_cv(y)
        if self.stratified and not self._is_stratified(cv):
            raise bad_y_error

        # pylint: disable=invalid-name
        len_dataset = get_len(dataset)
        if y is not None:
            len_y = get_len(y)
            if len_dataset != len_y:
                raise ValueError("Cannot perform a CV split if dataset and y "
                                 "have different lengths.")

        args = (np.arange(len_dataset),)
        if self._is_stratified(cv):
            args = args + (to_numpy(y),)

        idx_train, idx_valid = next(iter(cv.split(*args, groups=groups)))
        dataset_train = torch.utils.data.Subset(dataset, idx_train)
        dataset_valid = torch.utils.data.Subset(dataset, idx_valid)
        return dataset_train, dataset_valid


## Input Shape Setter

In [9]:
class InputShapeSetter(skorch.callbacks.Callback):
    def on_train_begin(self, net, X, y):
        net.set_params(module__input_units=X.shape[-1])

## Setting Training Procedure

In [18]:

def train_eval(X, y, net, n_iter, filename):
        model = 'skorch_breslow_'
        dataset_name = filename.split('_')[0]
        # add IBS later
        outer_scores = {'cindex_train_'+dataset_name:[], 'cindex_test_'+dataset_name:[],
                        'ibs_train_'+dataset_name:[], 'ibs_test_'+dataset_name:[]}
        best_params = {'best_params_'+dataset_name:[]}
        best_model = {'best_model_'+dataset_name:[]}
        ct = make_column_transformer(
                (StandardScaler(), make_column_selector(dtype_include=['float32'])),
                #(OneHotEncoder( handle_unknown='infrequent_if_exist'), make_column_selector(dtype_include=['category', 'object'])),#sparse_output=False,
                remainder='passthrough')

        pipe = Pipeline([('scaler',ct),
                        ('estimator', net)])
        rs = RandomizedSearchCV(pipe, param_grid_breslow, scoring = scoring_function, n_jobs=-1, 
                                    n_iter=n_iter, refit=True, random_state=rand_state)
        for i, (train_index, test_index) in enumerate(outer_custom_cv.split(X, y)):
                # Split data into training and testing sets for outer fold
                X_train, X_test = X.iloc[train_index], X.iloc[test_index]
                y_train, y_test = y.iloc[train_index], y.iloc[test_index]
                X_train, y_train = sort_X_y_pandas(X_train, y_train)
                X_test, y_test = sort_X_y_pandas(X_test, y_test)

                #print(X_train.shape, type(X_train))
                #print(y_train.shape, type(y_train))
                #print(X_test.shape, type(X_test))
                #print(y_test.shape, type(y_test))
                
                # save splits and data
                savetxt('splits/'+model+'train_index_'+str(i)+'_'+filename, train_index, delimiter=',')
                savetxt('splits/'+model+'test_index_'+str(i)+'_'+filename, test_index, delimiter=',')
                
                #savetxt('splits/X_train_'+str(i)+'_'+filename, X_train, delimiter=',')
                #savetxt('splits/X_test_'+str(i)+'_'+filename, X_test, delimiter=',')

                #savetxt('splits/y_train_'+str(i)+'_'+filename, y_train, delimiter=',')
                #savetxt('splits/y_test_'+str(i)+'_'+filename, y_test, delimiter=',')


                rs.fit(X_train, y_train)
                best_preds_train = rs.best_estimator_.predict(X_train)
                best_preds_test = rs.best_estimator_.predict(X_test)

                savetxt('predictions/'+model+'best_preds_train_'+str(i)+'_'+filename, best_preds_train, delimiter=',')
                savetxt('predictions/'+model+'best_preds_test_'+str(i)+'_'+filename, best_preds_test, delimiter=',')

                # save hyperparameter settings
                params = rs.best_estimator_.get_params
                best_params['best_params_'+dataset_name] += [rs.best_params_]
                best_model['best_model_'+dataset_name] += [params]
                try:
                    cum_hazard_train = get_cumulative_hazard_function_breslow(
                            X_train.values, X_train.values, y_train.values, y_train.values,
                            best_preds_train.reshape(-1), best_preds_train.reshape(-1)
                            )

                    df_survival_train = np.exp(-cum_hazard_train)
                    durations_train, events_train = transform_back(y_train.values)
                    time_grid_train = np.linspace(durations_train.min(), durations_train.max(), 100)
                    ev = EvalSurv(df_survival_train, durations_train, events_train, censor_surv='km')
                    print('Concordance Index',ev.concordance_td('antolini'))
                    print('Integrated Brier Score:',ev.integrated_brier_score(time_grid_train))
                    cindex_score_train = ev.concordance_td('antolini')
                    ibs_score_train = ev.integrated_brier_score(time_grid_train)

                    outer_scores['cindex_train_'+dataset_name] += [cindex_score_train]
                    outer_scores['ibs_train_'+dataset_name] += [ibs_score_train]

                except:
                    outer_scores['cindex_train_'+dataset_name] += [np.nan]
                    outer_scores['ibs_train_'+dataset_name] += [np.nan]
                    
                try:
                    cum_hazard_test = get_cumulative_hazard_function_breslow(
                            X_train.values, X_test.values, y_train.values, y_test.values,
                            best_preds_train.reshape(-1), best_preds_test.reshape(-1)
                            )
                    df_survival_test = np.exp(-cum_hazard_test)
                    durations_test, events_test = transform_back(y_test.values)
                    time_grid_test = np.linspace(durations_test.min(), durations_test.max(), 100)
                    ev = EvalSurv(df_survival_test, durations_test, events_test, censor_surv='km')
                    print('Concordance Index',ev.concordance_td('antolini'))
                    print('Integrated Brier Score:',ev.integrated_brier_score(time_grid_test))
                    cindex_score_test = ev.concordance_td('antolini')
                    ibs_score_test = ev.integrated_brier_score(time_grid_test)

                    outer_scores['cindex_test_'+dataset_name] += [cindex_score_test]
                    outer_scores['ibs_test_'+dataset_name] += [ibs_score_test]
                except: 
                    outer_scores['cindex_test_'+dataset_name] += [np.nan]
                    outer_scores['ibs_test_'+dataset_name] += [np.nan]
            
        df_best_params = pd.DataFrame(best_params)
        df_best_model = pd.DataFrame(best_model)
        df_outer_scores = pd.DataFrame(outer_scores)
        df_metrics = pd.concat([df_best_params,df_best_model,df_outer_scores], axis=1)
        df_metrics.to_csv('metrics/'+model+'metric_summary_'+'_'+filename, index=False)
        
        # cindex
        df_agg_metrics_cindex = pd.DataFrame({'dataset':[dataset_name],
                                              'cindex_train_mean':df_outer_scores['cindex_train_'+dataset_name].mean(),
                                              'cindex_train_std':df_outer_scores['cindex_train_'+dataset_name].std(),
                                              'cindex_test_mean':df_outer_scores['cindex_test_'+dataset_name].mean(),
                                              'cindex_test_std':df_outer_scores['cindex_test_'+dataset_name].std() })
        # IBS
        df_agg_metrics_ibs = pd.DataFrame({'dataset':[dataset_name],
                                              'ibs_train_mean':df_outer_scores['ibs_train_'+dataset_name].mean(),
                                              'ibs_train_std':df_outer_scores['ibs_train_'+dataset_name].std(),
                                              'ibs_test_mean':df_outer_scores['ibs_test_'+dataset_name].mean(),
                                              'ibs_test_std':df_outer_scores['ibs_test_'+dataset_name].std() })

        return df_agg_metrics_cindex, df_agg_metrics_ibs, best_model, best_params, outer_scores, best_preds_train, best_preds_test #, X_train, X_test, y_train, y_test

                

In [22]:
data_set_fns = [  load_support] #load_metabric,  load_flchain,#, load_flchain, load_rgbsg, load_support, load_tcga]load_rgbsg,
data_set_fns_str = [ 'load_support'] #'load_metabric', 'load_flchain','load_rgbsg', 
one_hot_dict = { 'load_support':['cancer','race'],}#'load_flchain': ['mgus'], 'load_rgbsg':['grade'],
agg_metrics_cindex = []
agg_metrics_ibs = []

for idx, dataset in enumerate(data_set_fns):
    # get name of current dataset
    data = dataset(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=True)
    X  = data.data #.astype(np.float32)
    y = data.target #.values #.to_numpy()

    #print(data_set_fns_str[idx])
    if data_set_fns_str[idx] in one_hot_dict.keys():
        X = pd.get_dummies(X, columns=one_hot_dict[data_set_fns_str[idx]])
    X, y = sort_X_y_pandas(X, y)
    
    net = NeuralNet(
        SurvivalModel, 
        #module__n_layers = 1,
        module__input_units = X.shape[1],
        #module__num_nodes = 32,
        #module__dropout = 0.1, # these could also be removed
        module__out_features = 1,
        # for split sizes when result size = 1
        iterator_train__drop_last=True,
        #iterator_valid__drop_last=True,
        criterion=BreslowLoss,
        optimizer=torch.optim.AdamW,
        optimizer__weight_decay = 0.4,
        batch_size=32, # separate train and valid->iterator_train__batch_size=128 and iterator_valid__batch_size=128 ?
        callbacks=[
            (
                "sched",
                LRScheduler(
                    torch.optim.lr_scheduler.ReduceLROnPlateau,
                    monitor="valid_loss",
                    patience=5,
                ),
            ),
            (
                "es",
                EarlyStopping(
                    monitor="valid_loss",
                    patience=early_stopping_rounds,
                    load_best=True,
                ),
            ),
            ("seed", FixSeed(seed=42)),
            #("Input Shape Setter",InputShapeSetter())
        ],
        train_split = CustomValidSplit(0.2, stratified=True, random_state=rand_state), 
        verbose=0
    )
    df_agg_metrics_cindex, df_agg_metrics_ibs, best_model,params, outer_scores, best_preds_train, best_preds_test = train_eval(X, y, net, n_iter, data.filename)
    agg_metrics_cindex.append(df_agg_metrics_cindex)
    agg_metrics_ibs.append(df_agg_metrics_ibs)


split age                 float32
sex                   uint8
n_comorbidities     float32
diabetes              uint8
dementia              uint8
blood_pressure      float32
heart_rate          float32
respiration_rate    float32
temperature         float32
white_blood_cell    float32
serum_sodium        float32
serum_creatinine    float32
cancer_0.0            uint8
cancer_1.0            uint8
cancer_2.0            uint8
race_0.0              uint8
race_1.0              uint8
race_2.0              uint8
race_3.0              uint8
race_4.0              uint8
race_5.0              uint8
race_6.0              uint8
race_7.0              uint8
race_8.0              uint8
race_9.0              uint8
dtype: object
Concordance Index 0.6092034454057412
Integrated Brier Score: 0.19817574921421613
Concordance Index 0.5982227630068273
Integrated Brier Score: 0.20341964201203153
Concordance Index 0.5979460445204812
Integrated Brier Score: 0.19715680167865088
Concordance Index 0.5970365943119679


In [23]:
df_final_breslow_1_cindex = pd.concat([df for df in agg_metrics_cindex]).round(4)
df_final_breslow_1_cindex.to_csv('metrics/final_deep_learning_breslow_1_cindex.csv', index=False)
df_final_breslow_1_cindex.to_csv('/Users/JUSC/Documents/644928e0fb7e147893e8ec15/05_thesis/tables/final_deep_learning_breslow_1_cindex.csv', index=False)  #
df_final_breslow_1_cindex

Unnamed: 0,dataset,cindex_train_mean,cindex_train_std,cindex_test_mean,cindex_test_std
0,SUPPORT,0.6069,0.0051,0.5941,0.009


In [24]:
df_final_breslow_1_ibs = pd.concat([df for df in agg_metrics_ibs]).round(4)
df_final_breslow_1_ibs.to_csv('metrics/final_deep_learning_breslow_1_ibs.csv', index=False)
df_final_breslow_1_ibs.to_csv('/Users/JUSC/Documents/644928e0fb7e147893e8ec15/05_thesis/tables/final_deep_learning_breslow_1_ibs.csv', index=False) 
df_final_breslow_1_ibs

Unnamed: 0,dataset,ibs_train_mean,ibs_train_std,ibs_test_mean,ibs_test_std
0,SUPPORT,0.1951,0.0027,0.1989,0.0035


## TCGA

In [26]:
param_grid_breslow_tcga = {
    'estimator__module__n_layers': [1, 2, 4],
    'estimator__module__num_nodes': [64, 128, 256, 512],
    'estimator__module__dropout': scuniform(0.0,0.7),
    'estimator__optimizer__weight_decay': [0.4, 0.2, 0.1, 0.05, 0.02, 0.01, 0],
    'estimator__batch_size': [64, 128, 256, 512, 1024],
    #lr not in paper because of learning rate finder
    # note: setting learning rate higher would make exp(partial_hazard) explode
    #'estimator__lr': scloguniform(0.001,0.01), # scheduler unten einbauen
    # use callback instead
    'estimator__lr':[0.01],
    'estimator__max_epochs':  scrandint(150,250), # corresponds to num_rounds
    #'pca__n_components': [8, 16, 32, 64]
}

In [35]:

def train_eval(X, y, net, n_iter, filename):
        model = 'skorch_breslow_tcga_'
        dataset_name = filename.split('_')[0]

        # add IBS later
        outer_scores = {'cindex_train_'+dataset_name:[], 'cindex_test_'+dataset_name:[],
                        'ibs_train_'+dataset_name:[], 'ibs_test_'+dataset_name:[]}
        best_params = {'best_params_'+dataset_name:[]}
        best_model = {'best_model_'+dataset_name:[]}
        ct = make_column_transformer(
                (StandardScaler(), make_column_selector(dtype_include=['float32'])),
                #(OneHotEncoder(sparse_output=False), make_column_selector(dtype_include=['category', 'object'])),
                remainder='drop')
        pipe = Pipeline([('scaler',ct),
                         #('pca', PCA()),
                        ('estimator', net)])
        rs = RandomizedSearchCV(pipe, param_grid_breslow_tcga, scoring = scoring_function, n_jobs=-1, 
                                    n_iter=n_iter, refit=True, random_state=rand_state)
        for i, (train_index, test_index) in enumerate(outer_custom_cv.split(X, y)):
                # Split data into training and testing sets for outer fold
                X_train, X_test = X.iloc[train_index], X.iloc[test_index]
                y_train, y_test = y.iloc[train_index], y.iloc[test_index]
                X_train, y_train = sort_X_y_pandas(X_train, y_train)
                X_test, y_test = sort_X_y_pandas(X_test, y_test)

                # print(X_train.shape, type(X_train))
                # print(y_train.shape, type(y_train))
                # print(X_test.shape, type(X_test))
                # print(y_test.shape, type(y_test))
                # save splits and data
                savetxt('splits/'+model+'train_index_'+str(i)+'_'+filename, train_index, delimiter=',')
                savetxt('splits/'+model+'test_index_'+str(i)+'_'+filename, test_index, delimiter=',')
                
                #savetxt('splits/X_train_'+str(i)+'_'+filename, X_train, delimiter=',')
                #savetxt('splits/X_test_'+str(i)+'_'+filename, X_test, delimiter=',')

                #savetxt('splits/y_train_'+str(i)+'_'+filename, y_train, delimiter=',')
                #savetxt('splits/y_test_'+str(i)+'_'+filename, y_test, delimiter=',')




                rs.fit(X_train, y_train)
                best_preds_train = rs.best_estimator_.predict(X_train)
                best_preds_test = rs.best_estimator_.predict(X_test)
                savetxt('predictions/'+model+'best_preds_train_'+str(i)+'_'+filename, best_preds_train, delimiter=',')
                savetxt('predictions/'+model+'best_preds_test_'+str(i)+'_'+filename, best_preds_test, delimiter=',')

                # save hyperparameter settings
                params = rs.best_estimator_.get_params
                best_params['best_params_'+dataset_name] += [rs.best_params_]
                best_model['best_model_'+dataset_name] += [params]
                try:
                    cum_hazard_train = get_cumulative_hazard_function_breslow(
                            X_train.values, X_train.values, y_train.values, y_train.values,
                            best_preds_train.reshape(-1), best_preds_train.reshape(-1)
                            )

                    df_survival_train = np.exp(-cum_hazard_train)
                    durations_train, events_train = transform_back(y_train.values)
                    time_grid_train = np.linspace(durations_train.min(), durations_train.max(), 100)
                    ev = EvalSurv(df_survival_train, durations_train, events_train, censor_surv='km')
                    print('Concordance Index',ev.concordance_td('antolini'))
                    print('Integrated Brier Score:',ev.integrated_brier_score(time_grid_train))
                    cindex_score_train = ev.concordance_td('antolini')
                    ibs_score_train = ev.integrated_brier_score(time_grid_train)

                    outer_scores['cindex_train_'+dataset_name] += [cindex_score_train]
                    outer_scores['ibs_train_'+dataset_name] += [ibs_score_train]

                except:
                    outer_scores['cindex_train_'+dataset_name] += [np.nan]
                    outer_scores['ibs_train_'+dataset_name] += [np.nan]
                    
                try:
                    cum_hazard_test = get_cumulative_hazard_function_breslow(
                            X_train.values, X_test.values, y_train.values, y_test.values,
                            best_preds_train.reshape(-1), best_preds_test.reshape(-1)
                            )
                    df_survival_test = np.exp(-cum_hazard_test)
                    durations_test, events_test = transform_back(y_test.values)
                    print('durations',durations_test.min(), durations_test.max())
                    time_grid_test = np.linspace(durations_test.min(), durations_test.max(), 100)
                    ev = EvalSurv(df_survival_test, durations_test, events_test, censor_surv='km')
                    print('Concordance Index',ev.concordance_td('antolini'))
                    print('Integrated Brier Score:',ev.integrated_brier_score(time_grid_test))
                    cindex_score_test = ev.concordance_td('antolini')
                    ibs_score_test = ev.integrated_brier_score(time_grid_test)

                    outer_scores['cindex_test_'+dataset_name] += [cindex_score_test]
                    outer_scores['ibs_test_'+dataset_name] += [ibs_score_test]
                except: 
                    outer_scores['cindex_test_'+dataset_name] += [np.nan]
                    outer_scores['ibs_test_'+dataset_name] += [np.nan]
            
        df_best_params = pd.DataFrame(best_params)
        df_best_model = pd.DataFrame(best_model)
        df_outer_scores = pd.DataFrame(outer_scores)
        df_metrics = pd.concat([df_best_params,df_best_model,df_outer_scores], axis=1)
        df_metrics.to_csv('metrics/'+model+'metric_summary_'+'_'+filename, index=False)
        # cindex
        df_agg_metrics_cindex = pd.DataFrame({'dataset':[dataset_name],
                                              'cindex_train_mean':df_outer_scores['cindex_train_'+dataset_name].mean(),
                                              'cindex_train_std':df_outer_scores['cindex_train_'+dataset_name].std(),
                                              'cindex_test_mean':df_outer_scores['cindex_test_'+dataset_name].mean(),
                                              'cindex_test_std':df_outer_scores['cindex_test_'+dataset_name].std() })
        # IBS
        df_agg_metrics_ibs = pd.DataFrame({'dataset':[dataset_name],
                                              'ibs_train_mean':df_outer_scores['ibs_train_'+dataset_name].mean(),
                                              'ibs_train_std':df_outer_scores['ibs_train_'+dataset_name].std(),
                                              'ibs_test_mean':df_outer_scores['ibs_test_'+dataset_name].mean(),
                                              'ibs_test_std':df_outer_scores['ibs_test_'+dataset_name].std() })
        return df_agg_metrics_cindex, df_agg_metrics_ibs,best_model, best_params, outer_scores, best_preds_train, best_preds_test

In [36]:
cancer_types = ['BLCA',
    'BRCA',
    'HNSC',
    'KIRC',
    'LGG',
    'LIHC',
    'LUAD',
    'LUSC',
    'OV',
    'STAD']


class InputShapeSetter(skorch.callbacks.Callback):
    def on_train_begin(self, net, X, y):
        net.set_params(module__input_units=X.shape[-1])

agg_metrics_cindex = []
agg_metrics_ibs = []

for idx, cancer_type in enumerate(cancer_types):
    # get name of current dataset
    data = load_tcga(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", cancer_type=cancer_type, as_frame=True)
    X  = data.data #.astype(np.float32)
    y = data.target #.values #.to_numpy()

    X, y = sort_X_y_pandas(X, y)

    net = NeuralNet(
        SurvivalModel, 
        module__n_layers = 1,
        module__input_units = X.shape[1],
        #module__num_nodes = 32,
        #module__dropout = 0.1, # these could also be removed
        module__out_features = 1,
        # for split sizes when result size = 1
        iterator_train__drop_last=True,
        #iterator_valid__drop_last=True,
        criterion=BreslowLoss,
        optimizer=torch.optim.AdamW,
        optimizer__weight_decay = 0.4,
        batch_size=32, # separate train and valid->iterator_train__batch_size=128 and iterator_valid__batch_size=128 ?
        callbacks=[
            (
                "sched",
                LRScheduler(
                    torch.optim.lr_scheduler.ReduceLROnPlateau,
                    monitor="valid_loss",
                    patience=5,
                ),
            ),
            (
                "es",
                EarlyStopping(
                    monitor="valid_loss",
                    patience=10,
                    load_best=True,
                ),
            ),
            ("seed", FixSeed(seed=42)),
            #("Inpout Shape Setter",InputShapeSetter())
        ],
        train_split = CustomValidSplit(0.2, stratified=True, random_state=rand_state),  
        verbose=0
    )
    df_agg_metrics_cindex, df_agg_metrics_ibs, best_model,params, outer_scores, best_preds_train, best_preds_test = train_eval(X, y, net, n_iter, data.filename)
    agg_metrics_cindex.append(df_agg_metrics_cindex)
    agg_metrics_ibs.append(df_agg_metrics_ibs)

split gex_?|100130426      float32
gex_?|100133144      float32
gex_?|100134869      float32
gex_?|10357          float32
gex_?|10431          float32
                      ...   
gex_ZYG11A|440590    float32
gex_ZYG11B|79699     float32
gex_ZYX|7791         float32
gex_ZZEF1|23140      float32
gex_ZZZ3|26009       float32
Length: 20531, dtype: object
Concordance Index 0.6865297392971718
Integrated Brier Score: 0.24073819193104293
durations 17.0 4343.0
Concordance Index 0.6027624309392265
Integrated Brier Score: 0.22048646125462354


  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change


Concordance Index 0.7150806900389538
Integrated Brier Score: 0.17974245817811713
durations 15.0 3817.0
Concordance Index 0.5548880393227744
Integrated Brier Score: 0.28346270300502047
Concordance Index 0.6818708369533739
Integrated Brier Score: 0.20533589638426658
durations 59.0 5050.0
Concordance Index 0.5878136200716846
Integrated Brier Score: 0.2534493172889495
Concordance Index 0.6872096850883289
Integrated Brier Score: 0.21303965768824187
durations 55.0 5041.0
Concordance Index 0.6025200458190149
Integrated Brier Score: 0.3033755215942722
Concordance Index 0.7270485450574227
Integrated Brier Score: 0.18766475149697115
durations 13.0 3432.0
Concordance Index 0.5656108597285068
Integrated Brier Score: 0.25119327956884474
split gex_?|100130426      float32
gex_?|100133144      float32
gex_?|100134869      float32
gex_?|10357          float32
gex_?|10431          float32
                      ...   
gex_ZYG11A|440590    float32
gex_ZYG11B|79699     float32
gex_ZYX|7791         float32

  new_threshold = score - abs_threshold_change


Concordance Index 0.5238727734161903
Integrated Brier Score: 0.1880374423318559
durations 5.0 8008.0
Concordance Index 0.5103954341622503
Integrated Brier Score: 0.18859614237535777


  new_threshold = score - abs_threshold_change


Concordance Index 0.5251194721393397
Integrated Brier Score: 0.18360544454616307
durations 5.0 8556.0
Concordance Index 0.5085820895522388
Integrated Brier Score: 0.20030729613019166
Concordance Index 0.5422024030719683
Integrated Brier Score: 0.20947900948657192
durations 1.0 7106.0
Concordance Index 0.43394911118856744
Integrated Brier Score: 0.19080975683857737


  new_threshold = score - abs_threshold_change


Concordance Index 0.5116346223259619
Integrated Brier Score: 0.20385366190870347
durations 1.0 8391.0
Concordance Index 0.6072684642438453
Integrated Brier Score: 0.21383338502535842
split gex_?|100130426      float32
gex_?|100133144      float32
gex_?|100134869      float32
gex_?|10357          float32
gex_?|10431          float32
                      ...   
gex_ZYG11A|440590    float32
gex_ZYG11B|79699     float32
gex_ZYX|7791         float32
gex_ZZEF1|23140      float32
gex_ZZZ3|26009       float32
Length: 20531, dtype: object
Concordance Index 0.7116390732226576
Integrated Brier Score: 0.3912594593229989
durations 23.0 4760.0
Concordance Index 0.564308121296619
Integrated Brier Score: 0.5599550714892444


  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change


Concordance Index 0.6232340844062947
Integrated Brier Score: 0.19154693999937855
durations 14.0 4282.0
Concordance Index 0.5905044510385756
Integrated Brier Score: 0.210604136656739
Concordance Index 0.5639498711225669
Integrated Brier Score: 0.18033452932493985
durations 11.0 6417.0
Concordance Index 0.5789473684210527
Integrated Brier Score: 0.2050914092668158
Concordance Index 0.5611388384754991
Integrated Brier Score: 0.16962283056859345
durations 2.0 5480.0
Concordance Index 0.5764541971438523
Integrated Brier Score: 0.21708490442678716
Concordance Index 0.5670484344532353
Integrated Brier Score: 0.18352485075508992
durations 14.0 5152.0
Concordance Index 0.5777051561365287
Integrated Brier Score: 0.18657243659927855
split gex_?|100130426      float32
gex_?|100133144      float32
gex_?|100134869      float32
gex_?|10357          float32
gex_?|10431          float32
                      ...   
gex_ZYG11A|440590    float32
gex_ZYG11B|79699     float32
gex_ZYX|7791         float32
g

  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change


Concordance Index 0.7366567624078354
Integrated Brier Score: 0.1586454433472276
durations 3.0 3431.0
Concordance Index 0.6706638115631691
Integrated Brier Score: 0.17021119442680066
Concordance Index 0.5824812357228326
Integrated Brier Score: 0.1937834688834565
durations 3.0 4537.0
Concordance Index 0.5403189066059225
Integrated Brier Score: 0.20724791922148658
Concordance Index 0.5767652774778665
Integrated Brier Score: 0.1996279729321882
durations 16.0 3987.0
Concordance Index 0.5854922279792746
Integrated Brier Score: 0.18757009115553533
Concordance Index 0.7250973082616085
Integrated Brier Score: 0.15387122433222522
durations 13.0 4067.0
Concordance Index 0.7052585064074238
Integrated Brier Score: 0.18532130128360208


  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change


Concordance Index 0.6919998880711867
Integrated Brier Score: 0.1795882945077531
durations 2.0 3944.0
Concordance Index 0.6236341562120599
Integrated Brier Score: 0.18447520375801926
split gex_?|100130426      float32
gex_?|100133144      float32
gex_?|100134869      float32
gex_?|10357          float32
gex_?|10431          float32
                      ...   
gex_ZYG11A|440590    float32
gex_ZYG11B|79699     float32
gex_ZYX|7791         float32
gex_ZZEF1|23140      float32
gex_ZZZ3|26009       float32
Length: 20531, dtype: object
Concordance Index 0.8132323678974133
Integrated Brier Score: 0.1637504525652532
durations 3.0 4695.0
Concordance Index 0.8232971372161896
Integrated Brier Score: 0.1529292281402801
Concordance Index 0.8011350737797956
Integrated Brier Score: 0.14520852727836547
durations 3.0 6423.0
Concordance Index 0.7541557305336833
Integrated Brier Score: 0.2263621622850644
Concordance Index 0.8234325244317557
Integrated Brier Score: 0.16662543752336234
durations 4.0 5166.0

  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change


Concordance Index 0.5593409826419535
Integrated Brier Score: 0.1995470853656357
durations 6.0 2728.0
Concordance Index 0.4513491414554375
Integrated Brier Score: 0.23073728962897
split gex_?|100130426      float32
gex_?|100133144      float32
gex_?|100134869      float32
gex_?|10357          float32
gex_?|10431          float32
                      ...   
gex_ZYG11A|440590    float32
gex_ZYG11B|79699     float32
gex_ZYX|7791         float32
gex_ZZEF1|23140      float32
gex_ZZZ3|26009       float32
Length: 20531, dtype: object
Concordance Index 0.5226191616384127
Integrated Brier Score: 0.18450846905340393
durations 4.0 3635.0
Concordance Index 0.4332993890020367
Integrated Brier Score: 0.20524213871088098


  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change


Concordance Index 0.517420778811288
Integrated Brier Score: 0.1903121968414787
durations 18.0 4765.0
Concordance Index 0.532870122928915
Integrated Brier Score: 0.19647577916986905
Concordance Index 0.48144427001569856
Integrated Brier Score: 0.19582237039805123
durations 11.0 7248.0
Concordance Index 0.49725274725274726
Integrated Brier Score: 0.202862123775123
Concordance Index 0.7173777792331117
Integrated Brier Score: 0.1745931945875225
durations 28.0 3940.0
Concordance Index 0.5324074074074074
Integrated Brier Score: 0.3120608922064661
Concordance Index 0.5071540603693864
Integrated Brier Score: 0.18072169254179427
durations 19.0 6732.0
Concordance Index 0.37321792260692466
Integrated Brier Score: 0.20951826183401617
split gex_?|100130426      float32
gex_?|100133144      float32
gex_?|100134869      float32
gex_?|10357          float32
gex_?|10431          float32
                      ...   
gex_ZYG11A|440590    float32
gex_ZYG11B|79699     float32
gex_ZYX|7791         float32
g

  new_threshold = score - abs_threshold_change


Concordance Index 0.5500802975234553
Integrated Brier Score: 0.1891094179723483
durations 1.0 3838.0
Concordance Index 0.42496793501496366
Integrated Brier Score: 0.2107241211709256
Concordance Index 0.5259777528765476
Integrated Brier Score: 0.19278964525333633
durations 4.0 4026.0
Concordance Index 0.5361179361179361
Integrated Brier Score: 0.1907237564563564
Concordance Index 0.524290238660785
Integrated Brier Score: 0.19128313965834026
durations 12.0 5287.0
Concordance Index 0.5350236355822948
Integrated Brier Score: 0.20777328631062786
Concordance Index 0.5345884413309983
Integrated Brier Score: 0.18689111849991893
durations 3.0 4765.0
Concordance Index 0.4909887968826108
Integrated Brier Score: 0.19276061017691287
Concordance Index 0.5122448401800629
Integrated Brier Score: 0.18319901004024172
durations 2.0 4694.0
Concordance Index 0.5686357841053974
Integrated Brier Score: 0.22088940859535142
split gex_?|100133144      float32
gex_?|100134869      float32
gex_?|10357          fl

  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change


Concordance Index 0.50398636723267
Integrated Brier Score: 0.1255951876104938
durations 8.0 4424.0
Concordance Index 0.4167468719923003
Integrated Brier Score: 0.1372662974038205
Concordance Index 0.5110780226325193
Integrated Brier Score: 0.14115753444426504
durations 9.0 5481.0
Concordance Index 0.45454545454545453
Integrated Brier Score: 0.11973710993513564


  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change


Concordance Index 0.5568898592916061
Integrated Brier Score: 0.12560496381114752
durations 53.0 4624.0
Concordance Index 0.5131195335276968
Integrated Brier Score: 0.14867737841925624


  new_threshold = score - abs_threshold_change


Concordance Index 0.536114795610756
Integrated Brier Score: 0.12393962095199665
durations 24.0 3871.0
Concordance Index 0.5099403578528827
Integrated Brier Score: 0.14738343861386174


  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change
  new_threshold = score - abs_threshold_change


Concordance Index 0.531405721316518
Integrated Brier Score: 0.12053763421344445
durations 11.0 3525.0
Concordance Index 0.5112359550561798
Integrated Brier Score: 0.16691852019968492
split gex_?|100133144      float32
gex_?|100134869      float32
gex_?|10357          float32
gex_?|10431          float32
gex_?|155060         float32
                      ...   
gex_ZYG11A|440590    float32
gex_ZYG11B|79699     float32
gex_ZYX|7791         float32
gex_ZZEF1|23140      float32
gex_ZZZ3|26009       float32
Length: 19076, dtype: object
Concordance Index 0.592316173352557
Integrated Brier Score: 0.20382519501371732
durations 7.0 3540.0
Concordance Index 0.5735797399041752
Integrated Brier Score: 0.2236037344148463
Concordance Index 0.583249791144528
Integrated Brier Score: 0.21030430825108473
durations 14.0 2267.0
Concordance Index 0.6364306784660767
Integrated Brier Score: 0.18671620490057064
Concordance Index 0.5777292014915777
Integrated Brier Score: 0.2155463001582481
durations 8.0 3720.

In [37]:
df_final_breslow_1_cindex = pd.concat([df for df in agg_metrics_cindex]).round(4)
df_final_breslow_1_cindex.to_csv('metrics/final_deep_learning_tcga_breslow_1_cindex.csv', index=False)
df_final_breslow_1_cindex.to_csv('/Users/JUSC/Documents/644928e0fb7e147893e8ec15/05_thesis/tables/final_deep_learning_tcga_breslow_1_cindex.csv', index=False)  #
df_final_breslow_1_cindex

Unnamed: 0,dataset,cindex_train_mean,cindex_train_std,cindex_test_mean,cindex_test_std
0,BLCA,0.6995,0.0202,0.5827,0.0217
0,BRCA,0.5298,0.0142,0.4997,0.0704
0,HNSC,0.6054,0.0647,0.5776,0.0093
0,KIRC,0.6626,0.0775,0.6251,0.0657
0,LGG,0.8197,0.0207,0.7498,0.0558
0,LIHC,0.5485,0.0251,0.5073,0.1002
0,LUAD,0.5492,0.0953,0.4738,0.0693
0,LUSC,0.5294,0.014,0.5111,0.0555
0,OV,0.5279,0.0211,0.4811,0.0436
0,STAD,0.59,0.01,0.5845,0.068


In [38]:
df_final_breslow_1_ibs = pd.concat([df for df in agg_metrics_ibs]).round(4)
df_final_breslow_1_ibs.to_csv('metrics/final_deep_learning_tcga_breslow_1_ibs.csv', index=False)
df_final_breslow_1_ibs.to_csv('/Users/JUSC/Documents/644928e0fb7e147893e8ec15/05_thesis/tables/final_deep_learning_tcga_breslow_1_ibs.csv', index=False) 
df_final_breslow_1_ibs

Unnamed: 0,dataset,ibs_train_mean,ibs_train_std,ibs_test_mean,ibs_test_std
0,BLCA,0.2053,0.0239,0.2624,0.032
0,BRCA,0.1937,0.0121,0.2003,0.0108
0,HNSC,0.2233,0.0942,0.2759,0.1592
0,KIRC,0.1771,0.0204,0.187,0.0132
0,LGG,0.1646,0.0116,0.2177,0.0982
0,LIHC,0.2065,0.0062,0.2181,0.012
0,LUAD,0.1852,0.0082,0.2252,0.0488
0,LUSC,0.1887,0.0038,0.2046,0.0127
0,OV,0.1274,0.008,0.144,0.0173
0,STAD,0.2103,0.0065,0.2104,0.0193
