In [1]:
import pandas as pd
import numpy as np
from numpy import savetxt
from xgbsurv.datasets import (load_metabric, load_flchain, load_rgbsg, load_support, load_tcga)
from xgbsurv.models.utils import sort_X_y_pandas, transform_back, transform
from xgbsurv.preprocessing.dataset_preprocessing import discretizer_df
import torch
from torch import nn
from sklearn.metrics import make_scorer
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold, train_test_split
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import StandardScaler, LabelEncoder, LabelBinarizer, OneHotEncoder
from sklearn.compose import make_column_transformer, make_column_selector
from sklearn.decomposition import PCA
from loss_functions_pytorch import DeephitLoss, deephit_likelihood_1_torch
from skorch import NeuralNet, NeuralNetRegressor
from skorch.callbacks import EarlyStopping, Callback, LRScheduler
from skorch.dataset import ValidSplit
from pycox.evaluation import EvalSurv
from pycox.models import DeepHitSingle
from scipy.stats import uniform as scuniform
from scipy.stats import randint as scrandint
from scipy.stats import loguniform as scloguniform
import random
from functools import partial
import os
#torch.set_default_dtype(torch.float64)
#torch.set_default_tensor_type(torch.DoubleTensor)

## Set Parameters

In [2]:
# set parameters, put into function
n_outer_splits = 5
n_inner_splits = 5
rand_state = 42
n_iter = 10 # set to 50
#n_iter_cind = 200
early_stopping_rounds=15
base_score = 0.0

param_grid_breslow = {
    'estimator__module__n_layers': [1, 2, 4],
    'estimator__module__n_layers': [1, 2, 4],
    'estimator__module__num_nodes': [64, 128, 256, 512],
    'estimator__module__dropout': scuniform(0.0,0.7),
    'estimator__optimizer__weight_decay': [0.4, 0.2, 0.1, 0.05, 0.02, 0.01, 0],
    'estimator__batch_size': [64, 128, 256, 512, 1024],
    #lr not in paper because of learning rate finder
    # note: setting learning rate higher would make exp(partial_hazard) explode
    #'estimator__lr': scloguniform(0.001,0.01), # scheduler unten einbauen
    # use callback instead
    'estimator__lr':[0.01]
    #'max_epochs':  scrandint(10,20), # corresponds to num_rounds
}

## Set Seed

In [3]:
def seed_torch(seed=42):
    """Sets all seeds within torch and adjacent libraries.

    Args:
        seed: Random seed to be used by the seeding functions.

    Returns:
        None
    """
    random.seed(seed)
    #os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    #torch.backends.cudnn.benchmark = False
    #torch.backends.cudnn.deterministic = True
    return None


class FixSeed(Callback):
    def __init__(self, seed):
        self.seed = seed

    def initialize(self):
        seed_torch(self.seed)
        return super().initialize()

## Load Data

In [4]:
# data = load_flchain(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=True)
# print(type(data.data),type(data.target))
# X, y = data.data, data.target
# #X, y = data.data.astype(np.float64), data.target.astype(np.float64)
# #one_hot_dict = {'load_flchain':[''], 'load_support':['cancer'], 'load_rgbsg':['grade']}
# #X = pd.get_dummies(X, columns=['cancer'])
# #X = pd.get_dummies(X, columns=['grade'])
# #X = pd.get_dummies(X, columns=['mgus'])
# X.mgus.value

## Set Loss Function

In [28]:
# Define Scorer
def custom_scoring_function(y_true, y_pred, duration_bins):

        #y_true = torch.from_numpy(y_true)
        if isinstance(y_pred, np.ndarray):
            y_pred = torch.from_numpy(y_pred)
        if isinstance(y_true, np.ndarray):
            y_true = torch.from_numpy(y_true)
        if isinstance(y_pred, pd.Series):
            y_pred = torch.tensor(y_pred.values)
        if isinstance(y_true, pd.Series):
            y_true = torch.tensor(y_true.values)
        print('score functions types',type(y_true), type(y_pred), type(duration_bins) )
        score = deephit_likelihood_1_torch(y_true, y_pred, duration_bins).to(torch.float32)
        return score.numpy()

scoring_function = make_scorer(custom_scoring_function, greater_is_better=False)

In [29]:
## Set up Custom Splitter

## Set Torch Model

In [44]:

class SurvivalModel(nn.Module):
    def __init__(self, n_layers, input_units, num_nodes, dropout, out_features):
        super(SurvivalModel, self).__init__()
        self.n_layers = n_layers
        self.in_features = input_units
        self.num_nodes = num_nodes
        self.dropout = dropout
        self.out_features = out_features
        model = []
        # first layer
        model.append(torch.nn.Linear(input_units, num_nodes))
        model.append(torch.nn.ReLU())
        model.append(torch.nn.Dropout(dropout))
        model.append(torch.nn.BatchNorm1d(num_nodes))

        for i in range(n_layers-1):
            model.append(torch.nn.Linear(num_nodes, num_nodes))
            #init.kaiming_normal_(model[-1].weight, nonlinearity='relu')
            model.append(torch.nn.ReLU())
            model.append(torch.nn.Dropout(dropout))
            model.append(torch.nn.BatchNorm1d(num_nodes))

        # output layer
        model.append(torch.nn.Linear(num_nodes, out_features))
        model.append(torch.nn.Softmax(dim=1))
    
        self.layers = nn.Sequential(*model)

        # for layer in self.layers:
        #     if isinstance(layer, nn.Linear):
        #         #nn.init.uniform_(layer.weight, a=-0.5, b=0.5)
        #         nn.init.kaiming_normal_(layer.weight)


    def forward(self, X):
        X = X.to(torch.float32)
        res = self.layers(X)
        #print(res)
        return res


## Set up Scaler

In [45]:
class CustomStandardScaler(StandardScaler):
    
    def __init__(self, copy=True, with_mean=True, with_std=True):
        super().__init__(copy=copy, with_mean=with_mean, with_std=with_std)
        
    def fit(self, X, y=None):
        return super().fit(X, y)
    
    def transform(self, X, y=None):
        X_transformed = super().transform(X, y)
        return X_transformed.astype(np.float32)
    
    def fit_transform(self, X, y=None):
        X_transformed = super().fit_transform(X, y)
        return X_transformed.astype(np.float32)

## Custom Split

In [46]:
# Define stratified inner k-fold cross-validation
class CustomSplit(StratifiedKFold):
    def __init__(self, n_splits=5, shuffle=True, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def split(self, X, y, groups=None):
        print('split', X.dtypes)
        try:
            if y.shape[1]>1:
                y = y[:,0]
        except:
            pass
        bins = np.sign(y)
        return super().split(X, bins, groups=groups)

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.n_splits

outer_custom_cv = CustomSplit(n_splits=n_outer_splits, shuffle=True, random_state=rand_state)
inner_custom_cv = CustomSplit(n_splits=n_outer_splits, shuffle=True, random_state=rand_state)



## Setting Training Procedure

In [47]:
# custom transform allowing for zero time value due to deephit
def transform(time, event):

    #if isinstance(time, pd.Series):
    #    time = time.to_numpy()
    #    event = event.to_numpy()
    event_mod = np.copy(event) 
    event_mod[event_mod==0] = -1
    y = event_mod*time
    return y

In [48]:
import skorch.callbacks
class InputShapeSetter(skorch.callbacks.Callback):
    def on_train_begin(self, net, X, y):
        net.set_params(module__input_units=X.shape[-1])

In [53]:
data_set_fns = [load_metabric,  load_flchain, load_rgbsg, load_support] #, load_flchain, load_rgbsg, load_support, load_tcga]
data_set_fns_str = ['load_metabric', 'load_flchain', 'load_rgbsg', 'load_support'] 
one_hot_dict = {'load_flchain': ['mgus'], 'load_support':['cancer'], 'load_rgbsg':['grade']}

#n_iter = 10
n_cuts = 10
for idx, dataset in enumerate(data_set_fns):

        model = '_deephit_'
        num_durations = 10
        data = dataset(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=True)
        X  = data.data
        y = data.target
        if data_set_fns_str[idx] in one_hot_dict.keys():
                X = pd.get_dummies(X, columns=one_hot_dict[data_set_fns_str[idx]])
        filename = data.filename
        dataset_name = filename.split('_')[0]
        # add IBS later
        outer_scores = {'cindex_train_'+dataset_name:[], 'cindex_test_'+dataset_name:[],
                        'ibs_train_'+dataset_name:[], 'ibs_test_'+dataset_name:[]}
        best_params = {'best_params_'+dataset_name:[]}
        best_model = {'best_model_'+dataset_name:[]}
        
        ct = make_column_transformer(
                (StandardScaler(), make_column_selector(dtype_include=['float32'])),
                #(OneHotEncoder(sparse_output=False), make_column_selector(dtype_include=['category', 'object']))
                remainder='passthrough')

        # Deephit transformation
        labtrans = DeepHitSingle.label_transform(num_durations)

        for i, (train_index, test_index) in enumerate(outer_custom_cv.split(X, y)):
                # Split data into training and testing sets for outer fold
                X_train, X_test = X.iloc[train_index], X.iloc[test_index]
                y_train, y_test = y.iloc[train_index], y.iloc[test_index]

                # transform training data as in deephit implementation by Kvamme
                time, event = transform_back(y_train.values)
                y_train_inter = labtrans.fit_transform(time, event)
                y_train_final = pd.Series(transform(y_train_inter[0],y_train_inter[1]))
                cuts = torch.from_numpy(labtrans.cuts)
                X_train, y_train = sort_X_y_pandas(X_train, y_train_final)
                X_test, y_test = sort_X_y_pandas(X_test, y_test)

                print('dtype X_train', X_train.dtypes)
                print('dtype y_train', y_train.dtypes)
                print('dtype X_test', X_test.dtypes)
                print('dtype y_test', y_test.dtypes)
                # save splits and data
                savetxt('splits/train_index_'+str(i)+'_'+filename, train_index, delimiter=',')
                savetxt('splits/test_index_'+str(i)+'_'+filename, test_index, delimiter=',')
                
                savetxt('splits/X_train_'+str(i)+'_'+filename, X_train, delimiter=',')
                savetxt('splits/X_test_'+str(i)+'_'+filename, X_test, delimiter=',')

                savetxt('splits/y_train_'+str(i)+'_'+filename, y_train, delimiter=',')
                savetxt('splits/y_test_'+str(i)+'_'+filename, y_test, delimiter=',')
                net = NeuralNet(
                SurvivalModel, 
                module__n_layers = 1,
                module__input_units = X.shape[1],
                #module__idx_durations = 
                #module__num_nodes = 32,
                #module__dropout = 0.1, # these could also be removed
                module__out_features = labtrans.out_features,
                # for split sizes when result size = 1
                iterator_train__drop_last=True,
                #iterator_valid__drop_last=True,
                criterion=DeephitLoss(duration_bins=cuts),
                optimizer=torch.optim.AdamW,
                optimizer__weight_decay = 0.4,
                batch_size=32, # separate train and valid->iterator_train__batch_size=128 and iterator_valid__batch_size=128 ?
                callbacks=[
                (
                        "sched",
                        LRScheduler(
                        torch.optim.lr_scheduler.ReduceLROnPlateau,
                        monitor="valid_loss",
                        patience=5,
                        ),
                ),
                (
                        "es",
                        EarlyStopping(
                        monitor="valid_loss",
                        patience=20, # increase for deephit
                        load_best=True,
                        ),
                ),
                ("seed", FixSeed(seed=42)),
                ("Input Shape Setter",InputShapeSetter())
                ],
                
                #[EarlyStopping(patience=10)],
                # add extensive callback, and random number seed
                #TODO: enable stratification, verify
                train_split=ValidSplit(0.2), # might cause lower performance in metrics, explain in thesis
                #lr=0.001,
                max_epochs=50, #0,#100
                #train_split=None,
                verbose=1
                )
                #strat = np.sign(y_train.values)
                #valid_split = ValidSplit(cv=0.1, stratified=strat, random_state=42)
                
                pipe = Pipeline([('scaler',ct),
                        ('estimator', net)])
                
                custom_scoring = partial(custom_scoring_function,duration_bins=cuts)
                
                rs = RandomizedSearchCV(pipe, param_grid_breslow, scoring = make_scorer(custom_scoring, greater_is_better=False), n_jobs=-1, 
                                    n_iter=2, refit=True)


                rs.fit(X_train, y_train)
                best_preds_train = rs.best_estimator_.predict(X_train)
                best_preds_test = rs.best_estimator_.predict(X_test)
                # save hyperparameter settings
                params = rs.best_estimator_.get_params
                best_params['best_params_'+dataset_name] += [rs.best_params_]
                best_model['best_model_'+dataset_name] += [params]
                surv_func_train = 1-np.cumsum(best_preds_train)

                try:
                    #use survival func formula here von kvamme paper

                    df_survival_train = 1-np.cumsum(best_preds_train)
                    durations_train, events_train = transform_back(y_train.values)
                    time_grid_train = np.linspace(durations_train.min(), durations_train.max(), 100)
                    ev = EvalSurv(df_survival_train, durations_train, events_train, censor_surv='km')
                    print('Concordance Index',ev.concordance_td('antolini'))
                    print('Integrated Brier Score:',ev.integrated_brier_score(time_grid_train))
                    cindex_score_train = ev.concordance_td('antolini')
                    ibs_score_train = ev.integrated_brier_score(time_grid_train)

                    outer_scores['cindex_train_'+dataset_name] += [cindex_score_train]
                    outer_scores['ibs_train_'+dataset_name] += [ibs_score_train]

                except:
                    outer_scores['cindex_train_'+dataset_name] += [np.nan]
                    outer_scores['ibs_train_'+dataset_name] += [np.nan]
                    
                try:
                    df_survival_test = 1-np.cumsum(best_preds_test)
                    durations_test, events_test = transform_back(y_test.values)
                    print('durations',durations_test.min(), durations_test.max())
                    time_grid_test = np.linspace(durations_test.min(), durations_test.max(), 100)
                    ev = EvalSurv(df_survival_test, durations_test, events_test, censor_surv='km')
                    print('Concordance Index',ev.concordance_td('antolini'))
                    print('Integrated Brier Score:',ev.integrated_brier_score(time_grid_test))
                    cindex_score_test = ev.concordance_td('antolini')
                    ibs_score_test = ev.integrated_brier_score(time_grid_test)

                    outer_scores['cindex_test_'+dataset_name] += [cindex_score_test]
                    outer_scores['ibs_test_'+dataset_name] += [ibs_score_test]
                except: 
                    outer_scores['cindex_test_'+dataset_name] += [np.nan]
                    outer_scores['ibs_test_'+dataset_name] += [np.nan]
            
        df_best_params = pd.DataFrame(best_params)
        df_best_model = pd.DataFrame(best_model)
        df_outer_scores = pd.DataFrame(outer_scores)
        df_metrics = pd.concat([df_best_params,df_best_model,df_outer_scores], axis=1)
        df_metrics.to_csv('metrics/metric_summary'+model+str(i)+'_'+filename, index=False)


                

split MKI67                float32
EGFR                 float32
PGR                  float32
ERBB2                float32
hormone_treatment    float32
radiotherapy         float32
chemotherapy         float32
ER_positive          float32
age                  float32
dtype: object
dtype X_train MKI67                float32
EGFR                 float32
PGR                  float32
ERBB2                float32
hormone_treatment    float32
radiotherapy         float32
chemotherapy         float32
ER_positive          float32
age                  float32
dtype: object
dtype y_train float64
dtype X_test MKI67                float32
EGFR                 float32
PGR                  float32
ERBB2                float32
hormone_treatment    float32
radiotherapy         float32
chemotherapy         float32
ER_positive          float32
age                  float32
dtype: object
dtype y_test float32
Re-initializing module because the following parameters were re-set: module__input_units.
Re-initia

In [54]:
surv_func_train = 1-np.cumsum(best_preds_train, axis=1)
pd.DataFrame(surv_func_train)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,0.999999,0.000003,0.000003,0.000002,0.000002,0.000001,0.000001,8.940697e-07,4.768372e-07,0.000000e+00
1,0.999989,0.000070,0.000060,0.000051,0.000042,0.000034,0.000027,1.960993e-05,1.001358e-05,-1.192093e-07
2,0.999996,0.000022,0.000020,0.000017,0.000014,0.000011,0.000009,6.377697e-06,3.397465e-06,1.192093e-07
3,0.999984,0.000104,0.000090,0.000077,0.000064,0.000052,0.000042,2.932549e-05,1.525879e-05,0.000000e+00
4,0.999999,0.000007,0.000006,0.000005,0.000004,0.000003,0.000003,2.026558e-06,1.072884e-06,0.000000e+00
...,...,...,...,...,...,...,...,...,...,...
7094,0.999985,0.000101,0.000087,0.000075,0.000062,0.000050,0.000040,2.831221e-05,1.448393e-05,1.192093e-07
7095,0.999977,0.000158,0.000137,0.000118,0.000098,0.000079,0.000063,4.404783e-05,2.235174e-05,-2.384186e-07
7096,0.999977,0.000160,0.000139,0.000119,0.000099,0.000080,0.000063,4.446507e-05,2.259016e-05,5.960464e-08
7097,0.999977,0.000157,0.000136,0.000116,0.000097,0.000078,0.000062,4.392862e-05,2.211332e-05,0.000000e+00


In [61]:
np.set_printoptions(suppress=True)
np.cumsum(best_preds_train[0])

array([0.00000058, 0.999997  , 0.99999744, 0.99999785, 0.9999982 ,
       0.9999985 , 0.9999988 , 0.9999991 , 0.9999995 , 1.        ],
      dtype=float32)

In [62]:
np.cumsum(best_preds_test[0])

array([0.00000134, 0.99999243, 0.9999933 , 0.9999943 , 0.99999523,
       0.9999961 , 0.9999969 , 0.9999978 , 0.99999875, 0.9999999 ],
      dtype=float32)

pass time classification bins in the loss class, add the bins to the loss init and add it to self, then in the forward pass it to the loss function

In [None]:
data_set_fns = [load_metabric,  load_flchain, load_rgbsg, load_support] #, load_flchain, load_rgbsg, load_support, load_tcga]
data_set_fns_str = ['load_metabric', 'load_flchain', 'load_rgbsg', 'load_support'] 
one_hot_dict = {'load_flchain': ['mgus'], 'load_support':['cancer'], 'load_rgbsg':['grade']}

n_cuts = 10
for idx, dataset in enumerate(data_set_fns):
    # get name of current dataset
    data = dataset(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=True)
    X  = data.data #.astype(np.float32)
    time, event = transform_back(data.target.values)
    min_time = time.min()
    X['time'] = time
    X['event'] = event
    y = data.target #.values #.to_numpy()
    df = pd.concat([X,y])
    df = discretizer_df(df=X, n_cuts=n_cuts, type = 'equidistant', min_time=min_time)
    target = transform(df.time.values, df.event.values)
    y = pd.Series(target)
    X = df.drop(['time', 'event'], axis=1)
    print(data_set_fns_str[idx])
    if data_set_fns_str[idx] in one_hot_dict.keys():
        X = pd.get_dummies(X, columns=one_hot_dict[data_set_fns_str[idx]])
    X, y = sort_X_y_pandas(X, y)
    print(X.shape,y.shape)
    
    net = NeuralNet(
        SurvivalModel, 
        module__n_layers = 1,
        module__input_units = X.shape[1],
        #module__idx_durations = 
        #module__num_nodes = 32,
        #module__dropout = 0.1, # these could also be removed
        module__out_features = n_cuts,
        # for split sizes when result size = 1
        iterator_train__drop_last=True,
        #iterator_valid__drop_last=True,
        criterion=DeephitLoss,
        optimizer=torch.optim.AdamW,
        optimizer__weight_decay = 0.4,
        batch_size=32, # separate train and valid->iterator_train__batch_size=128 and iterator_valid__batch_size=128 ?
        callbacks=[
            (
                "sched",
                LRScheduler(
                    torch.optim.lr_scheduler.ReduceLROnPlateau,
                    monitor="valid_loss",
                    patience=5,
                ),
            ),
            (
                "es",
                EarlyStopping(
                    monitor="valid_loss",
                    patience=10,
                    load_best=True,
                ),
            ),
            ("seed", FixSeed(seed=42)),
        ],
        
        #[EarlyStopping(patience=10)],
        # add extensive callback, and random number seed
        #TODO: enable stratification, verify
        train_split=ValidSplit(0.2), # might cause lower performance in metrics, explain in thesis
        #lr=0.001,
        #max_epochs=1, #0,#100
        #train_split=None,
        verbose=1
    )
    best_model,params, outer_scores, best_preds_train, best_preds_test, X_train, X_test, y_train, y_test = train_eval(X, y, net, n_iter, data.filename)


load_metabric
(1903, 9) (1903,)
split MKI67                float32
EGFR                 float32
PGR                  float32
ERBB2                float32
hormone_treatment    float32
radiotherapy         float32
chemotherapy         float32
ER_positive          float32
age                  float32
dtype: object
(1522, 9) <class 'pandas.core.frame.DataFrame'>
(1522,) <class 'pandas.core.series.Series'>
(381, 9) <class 'pandas.core.frame.DataFrame'>
(381,) <class 'pandas.core.series.Series'>
  epoch    valid_loss     dur
-------  ------------  ------
      1      [36m395.7535[0m  0.0086
  epoch    valid_loss     dur
-------  ------------  ------
      1      [36m397.0852[0m  0.0092
  epoch    valid_loss     dur
-------  ------------  ------
      1      [36m351.2593[0m  0.0093
  epoch    valid_loss     dur
-------  ------------  ------
      1      [36m396.6680[0m  0.0130
      2      395.7535  0.0074
  epoch    valid_loss     dur
-------  ------------  ------
      1      [36m3

In [None]:
data = load_metabric(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", as_frame=True)
num_durations = 10
labtrans = DeepHitSingle.label_transform(num_durations)
X = data.data
time, event = transform_back(data.target.values)
labtrans = DeepHitSingle.label_transform(num_durations)
#get_target = lambda df: (df['duration'].values, df['event'].values)
y_train = labtrans.fit_transform(time, event)
y_train

(array([1, 0, 0, ..., 8, 9, 9]),
 array([1., 0., 0., ..., 0., 1., 1.], dtype=float32))

## TCGA

In [None]:
param_grid_breslow_tcga = {
    'estimator__module__n_layers': [1, 2, 4],
    'estimator__module__n_layers': [1, 2, 4],
    'estimator__module__num_nodes': [64, 128, 256, 512],
    'estimator__module__dropout': scuniform(0.0,0.7),
    'estimator__optimizer__weight_decay': [0.4, 0.2, 0.1, 0.05, 0.02, 0.01, 0],
    'estimator__batch_size': [64, 128, 256, 512, 1024],
    #lr not in paper because of learning rate finder
    # note: setting learning rate higher would make exp(partial_hazard) explode
    #'estimator__lr': scloguniform(0.001,0.01), # scheduler unten einbauen
    # use callback instead
    'estimator__lr':[0.01],
    'estimator__max_epochs':  scrandint(10,20), # corresponds to num_rounds
    'pca__n_components': [8, 10, 12, 14, 16]
}

In [None]:

def train_eval(X, y, net, n_iter, filename):

        dataset_name = filename.split('_')[0]
        # add IBS later
        outer_scores = {'cindex_train_'+dataset_name:[], 'cindex_test_'+dataset_name:[],
                        'ibs_train_'+dataset_name:[], 'ibs_test_'+dataset_name:[]}
        best_params = {'best_params_'+dataset_name:[]}
        best_model = {'best_model_'+dataset_name:[]}
        ct = make_column_transformer(
                #(OneHotEncoder(sparse_output=False), make_column_selector(dtype_include=['category', 'object']))
                (StandardScaler(), make_column_selector(dtype_include=['float32']))
                ,remainder='drop')
        pipe = Pipeline([('scaler',ct),
                         ('pca', PCA()),#n_components=10
                        ('estimator', net)])
        rs = RandomizedSearchCV(pipe, param_grid_breslow_tcga, scoring = scoring_function, n_jobs=-1, 
                                    n_iter=2, refit=True)
        for i, (train_index, test_index) in enumerate(outer_custom_cv.split(X, y)):
                # Split data into training and testing sets for outer fold
                X_train, X_test = X.iloc[train_index], X.iloc[test_index]
                y_train, y_test = y.iloc[train_index], y.iloc[test_index]
                X_train, y_train = sort_X_y_pandas(X_train, y_train)
                X_test, y_test = sort_X_y_pandas(X_test, y_test)

                print(X_train.shape, type(X_train))
                print(y_train.shape, type(y_train))
                print(X_test.shape, type(X_test))
                print(y_test.shape, type(y_test))
                # save splits and data
                savetxt('splits/train_index_'+str(i)+'_'+filename, train_index, delimiter=',')
                savetxt('splits/test_index_'+str(i)+'_'+filename, test_index, delimiter=',')
                
                savetxt('splits/X_train_'+str(i)+'_'+filename, X_train, delimiter=',')
                savetxt('splits/X_test_'+str(i)+'_'+filename, X_test, delimiter=',')

                savetxt('splits/y_train_'+str(i)+'_'+filename, y_train, delimiter=',')
                savetxt('splits/y_test_'+str(i)+'_'+filename, y_test, delimiter=',')

                strat = np.sign(y_train)
                valid_split = ValidSplit(cv=0.1, stratified=strat, random_state=42)




                rs.fit(X_train, y_train)
                best_preds_train = rs.best_estimator_.predict(X_train)
                best_preds_test = rs.best_estimator_.predict(X_test)
                # save hyperparameter settings
                params = rs.best_estimator_.get_params
                best_params['best_params_'+dataset_name] += [rs.best_params_]
                best_model['best_model_'+dataset_name] += [params]
                try:
                    cum_hazard_train = get_cumulative_hazard_function_breslow(
                            X_train.values, X_train.values, y_train.values, y_train.values,
                            best_preds_train.reshape(-1), best_preds_train.reshape(-1)
                            )

                    df_survival_train = np.exp(-cum_hazard_train)
                    durations_train, events_train = transform_back(y_train.values)
                    time_grid_train = np.linspace(durations_train.min(), durations_train.max(), 100)
                    ev = EvalSurv(df_survival_train, durations_train, events_train, censor_surv='km')
                    print('Concordance Index',ev.concordance_td('antolini'))
                    print('Integrated Brier Score:',ev.integrated_brier_score(time_grid_train))
                    cindex_score_train = ev.concordance_td('antolini')
                    ibs_score_train = ev.integrated_brier_score(time_grid_train)

                    outer_scores['cindex_train_'+dataset_name] += [cindex_score_train]
                    outer_scores['ibs_train_'+dataset_name] += [ibs_score_train]

                except:
                    outer_scores['cindex_train_'+dataset_name] += [np.nan]
                    outer_scores['ibs_train_'+dataset_name] += [np.nan]
                    
                try:
                    cum_hazard_test = get_cumulative_hazard_function_breslow(
                            X_train.values, X_test.values, y_train.values, y_test.values,
                            best_preds_train.reshape(-1), best_preds_test.reshape(-1)
                            )
                    df_survival_test = np.exp(-cum_hazard_test)
                    durations_test, events_test = transform_back(y_test.values)
                    print('durations',durations_test.min(), durations_test.max())
                    time_grid_test = np.linspace(durations_test.min(), durations_test.max(), 100)
                    ev = EvalSurv(df_survival_test, durations_test, events_test, censor_surv='km')
                    print('Concordance Index',ev.concordance_td('antolini'))
                    print('Integrated Brier Score:',ev.integrated_brier_score(time_grid_test))
                    cindex_score_test = ev.concordance_td('antolini')
                    ibs_score_test = ev.integrated_brier_score(time_grid_test)

                    outer_scores['cindex_test_'+dataset_name] += [cindex_score_test]
                    outer_scores['ibs_test_'+dataset_name] += [ibs_score_test]
                except: 
                    outer_scores['cindex_test_'+dataset_name] += [np.nan]
                    outer_scores['ibs_test_'+dataset_name] += [np.nan]
            
        df_best_params = pd.DataFrame(best_params)
        df_best_model = pd.DataFrame(best_model)
        df_outer_scores = pd.DataFrame(outer_scores)
        df_metrics = pd.concat([df_best_params,df_best_model,df_outer_scores], axis=1)
        df_metrics.to_csv('metrics/metric_summary_'+str(i)+'_'+filename, index=False)
        return best_model, best_params, outer_scores, best_preds_train, best_preds_test, X_train, X_test, y_train, y_test

                
#cv=inner_custom_cv,pipe

In [None]:
cancer_types = ['BLCA',
    'BRCA',
    'HNSC',
    'KIRC',
    'LGG',
    'LIHC',
    'LUAD',
    'LUSC',
    'OV',
    'STAD']
import skorch.callbacks

class InputShapeSetter(skorch.callbacks.Callback):
    def on_train_begin(self, net, X, y):
        net.set_params(module__input_units=X.shape[-1])

for idx, cancer_type in enumerate(cancer_types):
    # get name of current dataset
    data = load_tcga(path="/Users/JUSC/Documents/xgbsurv/xgbsurv/datasets/data/", cancer_type=cancer_type, as_frame=True)
    X  = data.data #.astype(np.float32)
    y = data.target #.values #.to_numpy()

    X, y = sort_X_y_pandas(X, y)

    net = NeuralNet(
        SurvivalModel, 
        module__n_layers = 1,
        module__input_units = X.shape[1],
        #module__num_nodes = 32,
        #module__dropout = 0.1, # these could also be removed
        module__out_features = 1,
        # for split sizes when result size = 1
        iterator_train__drop_last=True,
        #iterator_valid__drop_last=True,
        criterion=DeephitLoss,
        optimizer=torch.optim.AdamW,
        optimizer__weight_decay = 0.4,
        batch_size=32, # separate train and valid->iterator_train__batch_size=128 and iterator_valid__batch_size=128 ?
        callbacks=[
            (
                "sched",
                LRScheduler(
                    torch.optim.lr_scheduler.ReduceLROnPlateau,
                    monitor="valid_loss",
                    patience=5,
                ),
            ),
            (
                "es",
                EarlyStopping(
                    monitor="valid_loss",
                    patience=10,
                    load_best=True,
                ),
            ),
            ("seed", FixSeed(seed=42)),
            ("Inpout Shape Setter",InputShapeSetter())
        ],#[EarlyStopping(patience=10)],#,InputShapeSetter()],
        #TODO: enable stratification, verify
        train_split=ValidSplit(0.2), # might cause lower performance in metrics, explain in thesis
        #lr=0.001,
        #max_epochs=1, #0,#100
        #train_split=None,
        verbose=1
    )
    best_model,params, outer_scores, best_preds_train, best_preds_test, X_train, X_test, y_train, y_test = train_eval(X, y, net, n_iter, data.filename)

split gex_?|100130426      float32
gex_?|100133144      float32
gex_?|100134869      float32
gex_?|10357          float32
gex_?|10431          float32
                      ...   
gex_ZYG11A|440590    float32
gex_ZYG11B|79699     float32
gex_ZYX|7791         float32
gex_ZZEF1|23140      float32
gex_ZZZ3|26009       float32
Length: 20531, dtype: object
(324, 20531) <class 'pandas.core.frame.DataFrame'>
(324,) <class 'pandas.core.series.Series'>
(82, 20531) <class 'pandas.core.frame.DataFrame'>
(82,) <class 'pandas.core.series.Series'>
Re-initializing module because the following parameters were re-set: module__input_units.
Re-initializing criterion.
Re-initializing optimizer.
Re-initializing module because the following parameters were re-set: module__input_units.
Re-initializing criterion.
Re-initializing optimizer.
Re-initializing module because the following parameters were re-set: module__input_units.
Re-initializing criterion.
Re-initializing optimizer.
Re-initializing module becau

ValueError: 
All the 10 fits failed.
It is very likely that your model is misconfigured.
You can try to debug the error by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
1 fits failed with the following error:
Traceback (most recent call last):
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/pipeline.py", line 405, in fit
    self._final_estimator.fit(Xt, y, **fit_params_last_step)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1230, in fit
    self.partial_fit(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1189, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1101, in fit_loop
    self.run_single_epoch(iterator_train, training=True, prefix="train",
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1137, in run_single_epoch
    step = step_fn(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1016, in train_step
    self._step_optimizer(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 972, in _step_optimizer
    optimizer.step(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1006, in step_fn
    step = self.train_step_single(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 906, in train_step_single
    loss = self.get_loss(y_pred, yi, X=Xi, training=True)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1571, in get_loss
    return self.criterion_(y_pred, y_true)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 288, in forward
    loss = deephit_likelihood_1_torch(input, prediction)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 259, in deephit_likelihood_1_torch
    raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
ValueError: Network output `phi` is too small for `idx_durations`. Need at least `phi.shape[1] = 119`, but got `phi.shape[1] = 2`

--------------------------------------------------------------------------------
2 fits failed with the following error:
Traceback (most recent call last):
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/pipeline.py", line 405, in fit
    self._final_estimator.fit(Xt, y, **fit_params_last_step)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1230, in fit
    self.partial_fit(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1189, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1101, in fit_loop
    self.run_single_epoch(iterator_train, training=True, prefix="train",
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1137, in run_single_epoch
    step = step_fn(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1016, in train_step
    self._step_optimizer(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 972, in _step_optimizer
    optimizer.step(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1006, in step_fn
    step = self.train_step_single(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 906, in train_step_single
    loss = self.get_loss(y_pred, yi, X=Xi, training=True)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1571, in get_loss
    return self.criterion_(y_pred, y_true)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 288, in forward
    loss = deephit_likelihood_1_torch(input, prediction)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 259, in deephit_likelihood_1_torch
    raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
ValueError: Network output `phi` is too small for `idx_durations`. Need at least `phi.shape[1] = 118`, but got `phi.shape[1] = 2`

--------------------------------------------------------------------------------
1 fits failed with the following error:
Traceback (most recent call last):
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/pipeline.py", line 405, in fit
    self._final_estimator.fit(Xt, y, **fit_params_last_step)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1230, in fit
    self.partial_fit(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1189, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1101, in fit_loop
    self.run_single_epoch(iterator_train, training=True, prefix="train",
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1137, in run_single_epoch
    step = step_fn(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1016, in train_step
    self._step_optimizer(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 972, in _step_optimizer
    optimizer.step(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1006, in step_fn
    step = self.train_step_single(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 906, in train_step_single
    loss = self.get_loss(y_pred, yi, X=Xi, training=True)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1571, in get_loss
    return self.criterion_(y_pred, y_true)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 288, in forward
    loss = deephit_likelihood_1_torch(input, prediction)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 259, in deephit_likelihood_1_torch
    raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
ValueError: Network output `phi` is too small for `idx_durations`. Need at least `phi.shape[1] = 112`, but got `phi.shape[1] = 2`

--------------------------------------------------------------------------------
1 fits failed with the following error:
Traceback (most recent call last):
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/pipeline.py", line 405, in fit
    self._final_estimator.fit(Xt, y, **fit_params_last_step)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1230, in fit
    self.partial_fit(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1189, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1101, in fit_loop
    self.run_single_epoch(iterator_train, training=True, prefix="train",
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1137, in run_single_epoch
    step = step_fn(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1016, in train_step
    self._step_optimizer(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 972, in _step_optimizer
    optimizer.step(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1006, in step_fn
    step = self.train_step_single(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 906, in train_step_single
    loss = self.get_loss(y_pred, yi, X=Xi, training=True)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1571, in get_loss
    return self.criterion_(y_pred, y_true)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 288, in forward
    loss = deephit_likelihood_1_torch(input, prediction)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 259, in deephit_likelihood_1_torch
    raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
ValueError: Network output `phi` is too small for `idx_durations`. Need at least `phi.shape[1] = 116`, but got `phi.shape[1] = 2`

--------------------------------------------------------------------------------
1 fits failed with the following error:
Traceback (most recent call last):
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/pipeline.py", line 405, in fit
    self._final_estimator.fit(Xt, y, **fit_params_last_step)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1230, in fit
    self.partial_fit(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1189, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1101, in fit_loop
    self.run_single_epoch(iterator_train, training=True, prefix="train",
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1137, in run_single_epoch
    step = step_fn(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1016, in train_step
    self._step_optimizer(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 972, in _step_optimizer
    optimizer.step(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1006, in step_fn
    step = self.train_step_single(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 906, in train_step_single
    loss = self.get_loss(y_pred, yi, X=Xi, training=True)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1571, in get_loss
    return self.criterion_(y_pred, y_true)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 288, in forward
    loss = deephit_likelihood_1_torch(input, prediction)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 259, in deephit_likelihood_1_torch
    raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
ValueError: Network output `phi` is too small for `idx_durations`. Need at least `phi.shape[1] = 62`, but got `phi.shape[1] = 2`

--------------------------------------------------------------------------------
3 fits failed with the following error:
Traceback (most recent call last):
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/pipeline.py", line 405, in fit
    self._final_estimator.fit(Xt, y, **fit_params_last_step)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1230, in fit
    self.partial_fit(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1189, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1101, in fit_loop
    self.run_single_epoch(iterator_train, training=True, prefix="train",
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1137, in run_single_epoch
    step = step_fn(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1016, in train_step
    self._step_optimizer(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 972, in _step_optimizer
    optimizer.step(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1006, in step_fn
    step = self.train_step_single(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 906, in train_step_single
    loss = self.get_loss(y_pred, yi, X=Xi, training=True)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1571, in get_loss
    return self.criterion_(y_pred, y_true)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 288, in forward
    loss = deephit_likelihood_1_torch(input, prediction)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 259, in deephit_likelihood_1_torch
    raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
ValueError: Network output `phi` is too small for `idx_durations`. Need at least `phi.shape[1] = 63`, but got `phi.shape[1] = 2`

--------------------------------------------------------------------------------
1 fits failed with the following error:
Traceback (most recent call last):
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/sklearn/pipeline.py", line 405, in fit
    self._final_estimator.fit(Xt, y, **fit_params_last_step)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1230, in fit
    self.partial_fit(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1189, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1101, in fit_loop
    self.run_single_epoch(iterator_train, training=True, prefix="train",
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1137, in run_single_epoch
    step = step_fn(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1016, in train_step
    self._step_optimizer(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 972, in _step_optimizer
    optimizer.step(step_fn)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1006, in step_fn
    step = self.train_step_single(batch, **fit_params)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 906, in train_step_single
    loss = self.get_loss(y_pred, yi, X=Xi, training=True)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/skorch/net.py", line 1571, in get_loss
    return self.criterion_(y_pred, y_true)
  File "/Users/JUSC/miniconda3/envs/xgbsurv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 288, in forward
    loss = deephit_likelihood_1_torch(input, prediction)
  File "/Users/JUSC/Documents/xgbsurv/experiments/deep_learning/loss_functions_pytorch.py", line 259, in deephit_likelihood_1_torch
    raise ValueError(f"Network output `phi` is too small for `idx_durations`."+
ValueError: Network output `phi` is too small for `idx_durations`. Need at least `phi.shape[1] = 61`, but got `phi.shape[1] = 2`
