# Test on FLCHAIN

In [29]:
import numpy

In [30]:
import sklearn

In [31]:
import survwrap

In [32]:
# Show the names of the datasets that survwrap makes available.
survwrap.list_available_datasets()

('flchain', 'gbsg2', 'metabric', 'support')

In [33]:
# Load the 'flchain' dataset through survwrap and summarise the columns of its dataframe.
# NOTE(review): the `mb_` prefix looks like a leftover from another dataset's notebook
# (the data loaded here is 'flchain') — if renamed, rename it in the cells that use it too.
mb_df = survwrap.get_data('flchain')
mb_df.dataframe.info()

<class 'pandas.core.frame.DataFrame'>
Index: 6524 entries, 0 to 6523
Data columns (total 10 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   age         6524 non-null   float64
 1   sex         6524 non-null   float64
 2   sample.yr   6524 non-null   int64  
 3   kappa       6524 non-null   float64
 4   lambda      6524 non-null   float64
 5   flc.grp     6524 non-null   int64  
 6   creatinine  6524 non-null   float64
 7   mgus        6524 non-null   float64
 8   time        6524 non-null   float64
 9   event       6524 non-null   float64
dtypes: float64(8), int64(2)
memory usage: 560.7 KB


In [34]:
# Split the dataset into the feature matrix X and the survival target y,
# then display both shapes as a quick sanity check.
X, y = mb_df.get_X_y()
X.shape, y.shape

((6524, 8), (6524,))

In [35]:
# Peek at the first ten targets: y is a structured array with an 'event' flag and a 'time' field.
y[:10]

array([( True,   85.), ( True, 1281.), ( True,   69.), ( True,  115.),
       ( True, 1039.), ( True, 1355.), ( True, 2851.), ( True,  372.),
       ( True, 3309.), ( True, 1326.)],
      dtype=[('event', '?'), ('time', '<f8')])

### Generate a (stratified) train-test split and scale the features (only)

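First do the stratified split, THEN scale the features: the scaler is fitted on the `X_train` set ONLY and then applied to both `X_train` and `X_test`, so no test-set statistics leak into the preprocessing.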

In [36]:
from sklearn.preprocessing import StandardScaler, RobustScaler

In [37]:
# Hold out a test set using survwrap's survival-aware splitter, with a fixed seed.
# presumably stratified on the event indicator, as the section heading says — confirm in survwrap.
X_train, X_test, y_train, y_test = survwrap.survival_train_test_split(X, y, rng_seed=2309)

In [38]:
# Fit the scaler on the training features only, then apply the very same
# transformation to both partitions.
scaler = StandardScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
(X_train.shape, X_test.shape)

((4893, 8), (1631, 8))

In [39]:
# Number of observed events in the full data, the training set and the test set.
tuple(survwrap.get_indicator(target).sum() for target in (y, y_train, y_test))


(1962, 1472, 490)

## Check possible dimensionality reduction

In [40]:
from sklearn.decomposition import PCA

In [41]:
# Fit a PCA that keeps as many components as needed to explain 99.5% of the
# training-set variance, and report how many that turns out to be.
pca = PCA(n_components=0.995, random_state=2308)
pca.fit(X_train)
print('PCA components:', pca.n_components_)

PCA components: 8


No reduction here: PCA needs all 8 components (out of 8 features) to retain 99.5% of the variance, so the full feature set is kept.

In [42]:
## Stratified CV splitter for survival analysis

In [43]:
#from sklearn.model_selection import RepeatedStratifiedKFold, StratifiedKFold

In [44]:
#testkf= RepeatedStratifiedKFold(n_splits=5,n_repeats=2,random_state=2307)
#for trn,tst in testkf.split(X_train, survwrap.get_indicator(y_train)):
#    print(trn,tst) 

# Test CoxNet

In [45]:
# Seed handed to every model constructed below (CoxNet and FastCPH).
rng_seed=2401

In [46]:
# Baseline CoxNet: only the seed is set, every other hyper-parameter is left
# at whatever default survwrap uses. Fitted on the scaled training set.
coxnet = survwrap.CoxNet(rng_seed)
coxnet.fit(X_train, y_train)

In [47]:
# Score the baseline model on the held-out test set.
# presumably a concordance index — confirm against survwrap's `score` implementation.
coxnet.score(X_test, y_test)

0.7927079761717836

In [48]:
#from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

In [49]:
# Start from the model's own parameter grid and pin l1_ratio to a single value,
# so that only alpha is searched.
cox_grid = survwrap.CoxNet().get_parameter_grid()
cox_grid.update({'l1_ratio': [0.5]})
cox_grid

{'alpha': [0.001,
  0.003,
  0.005,
  0.008,
  0.01,
  0.02,
  0.03,
  0.04,
  0.05,
  0.06,
  0.07,
  0.08,
  0.09,
  0.1,
  0.15,
  0.2,
  0.3],
 'l1_ratio': [0.5]}

In [50]:
# Hyper-parameter search for CoxNet over `cox_grid`, using 4 parallel jobs.
# The call returns the best model, its parameters and the search object.
opt_coxnet, opt_coxnet_params, opt_coxnet_search = survwrap.optimize(
    survwrap.CoxNet(rng_seed),
    X_train,
    y_train,
    user_grid=cox_grid,
    n_jobs=4,
)
# Test-set score of the selected model, together with the selected parameters.
(opt_coxnet.score(X_test, y_test), opt_coxnet_params)

(0.793126155601201, {'alpha': 0.02, 'l1_ratio': 0.5})

In [51]:
# A truthy `scoring` attribute on the search object means custom scorers were
# supplied; report which of the two cases applies.
print('multi score' if opt_coxnet_search.scoring else 'Single score')

Single score


In [52]:
# Tabulate the cross-validation results of the CoxNet search as a dataframe.
# The raw results should also be available via the attribute kept below for reference.
#opt_coxnet_search.cv_results_
survwrap.get_model_scores_df(opt_coxnet_search)

['rank_test_score', 'mean_test_score', 'std_test_score', 'params', 'mean_fit_time', 'std_fit_time']


Unnamed: 0,rank_test_score,mean_test_score,std_test_score,params,mean_fit_time,std_fit_time
5,1,0.791243,0.011211,"{'alpha': 0.02, 'l1_ratio': 0.5}",0.476548,0.050124
6,2,0.791213,0.011123,"{'alpha': 0.03, 'l1_ratio': 0.5}",0.48222,0.030772
4,3,0.791023,0.011584,"{'alpha': 0.01, 'l1_ratio': 0.5}",0.452358,0.015244
3,4,0.791018,0.011652,"{'alpha': 0.008, 'l1_ratio': 0.5}",0.449075,0.010527
0,5,0.790998,0.011742,"{'alpha': 0.001, 'l1_ratio': 0.5}",0.446674,0.011657
1,5,0.790998,0.011742,"{'alpha': 0.003, 'l1_ratio': 0.5}",0.448055,0.010208
2,5,0.790998,0.011742,"{'alpha': 0.005, 'l1_ratio': 0.5}",0.443161,0.009915
7,8,0.790919,0.011149,"{'alpha': 0.04, 'l1_ratio': 0.5}",0.47715,0.043137
8,9,0.790656,0.011157,"{'alpha': 0.05, 'l1_ratio': 0.5}",0.459422,0.022498
9,10,0.790189,0.011183,"{'alpha': 0.06, 'l1_ratio': 0.5}",0.445735,0.006774


# Test FastCPH

In [53]:
# Build a FastCPH model (only the seed is set) and display it.
fl_lasso=survwrap.FastCPH(rng_seed)
fl_lasso

In [54]:
# Exploratory fit of FastCPH with the settings it was constructed with.
# NOTE(review): the score below is computed on the TRAINING set, so it is not a
# generalisation estimate — score on X_test / y_test for that.
# NOTE(review): the two commented lines refer to `fl_dsm`, which is not defined in
# the visible part of this notebook — presumably left over from another notebook.
#fl_dsm.layer_sizes=[3]
#fl_dsm.learning_rate=0.01
fl_lasso_xplore=fl_lasso.fit(X_train,y_train)
fl_lasso_xplore.score(X_train,y_train)

  return torch.empty(output_size, device=input.device).scatter_reduce(


epoch: 0
loss: 8.593567848205566
epoch: 1
loss: 8.584389686584473
epoch: 2
loss: 8.575233459472656
epoch: 3
loss: 8.5661039352417
epoch: 4
loss: 8.556997299194336
epoch: 5
loss: 8.547910690307617
epoch: 6
loss: 8.538849830627441
epoch: 7
loss: 8.529813766479492
epoch: 8
loss: 8.520800590515137
epoch: 9
loss: 8.51181411743164
epoch: 10
loss: 8.502851486206055
epoch: 11
loss: 8.493910789489746
epoch: 12
loss: 8.484992027282715
epoch: 13
loss: 8.476089477539062
epoch: 14
loss: 8.46721076965332
epoch: 15
loss: 8.458356857299805
epoch: 16
loss: 8.449514389038086
epoch: 17
loss: 8.440692901611328
epoch: 18
loss: 8.431884765625
epoch: 19
loss: 8.42309284210205
epoch: 20
loss: 8.414307594299316
epoch: 21
loss: 8.40553092956543
epoch: 22
loss: 8.396761894226074
epoch: 23
loss: 8.387993812561035
epoch: 24
loss: 8.379231452941895
epoch: 25
loss: 8.370476722717285
epoch: 26
loss: 8.361723899841309
epoch: 27
loss: 8.35296630859375
epoch: 28
loss: 8.344212532043457
epoch: 29
loss: 8.33544921875
epoc

0.7942605159620818

In [63]:
# Ask the model for its parameter grid, passing the number of feature columns as max_width.
fl_grid =fl_lasso.get_parameter_grid(max_width=X_train.shape[1])
fl_grid

{'layer_sizes': [[8], [8, 8], [8, 8, 8], [8, 8, 8, 8]]}

In [56]:
# Stratified CV
#opt_dsm, opt_dsm_params, opt_dsm_search = optimize(survwrap.DeepSurvivalMachines(rng_seed=2308),  X_train, y_train, n_jobs=8,
                                                  # user_grid=grid,cv=RepeatedStratifiedKFold(n_splits=5, n_repeats=2, random_state=2308).split(X_train,survwrap.get_indicator(y_train)))
#opt_dsm.score(X_test, y_test), opt_dsm_params

In [64]:
# Stratified CV
# Fresh FastCPH instance dedicated to the cross-validated search.
fl_lasso_cv = survwrap.FastCPH(rng_seed)

# Restrict the search to layouts with 2, 3 and 4 layers. The layer width is
# taken from the number of feature columns (the same value passed as
# `max_width` when the grid was first built) instead of a hard-coded 8, so the
# cell keeps working if the feature set changes.
fl_grid['layer_sizes'] = [[X_train.shape[1]] * n_layers for n_layers in range(2, 5)]
fl_grid

{'layer_sizes': [[8, 8], [8, 8, 8], [8, 8, 8, 8]]}

In [65]:
# Grid-search the FastCPH layer layouts in `fl_grid`, cross-validating with
# survwrap's survival splitter (3 splits, 1 repeat) built on the training data.
opt_lasso, opt_lasso_params, opt_lasso_search = survwrap.optimize(
    fl_lasso_cv,
    X_train,
    y_train,
    user_grid=fl_grid,
    cv=survwrap.survival_crossval_splitter(
        X_train, y_train, n_repeats=1, n_splits=3
    ),
)

epoch: 0
loss: 7.954215049743652
epoch: 1
loss: 7.949530124664307
epoch: 2
loss: 7.944864749908447
epoch: 3
loss: 7.940215587615967
epoch: 4
loss: 7.935583114624023
epoch: 5
loss: 7.93096923828125
epoch: 6
loss: 7.926385402679443
epoch: 7
loss: 7.921855449676514
epoch: 8
loss: 7.917341709136963
epoch: 9
loss: 7.912844181060791
epoch: 10
loss: 7.9083662033081055
epoch: 11
loss: 7.903908729553223
epoch: 12
loss: 7.899469375610352
epoch: 13
loss: 7.895051002502441
epoch: 14
loss: 7.890653133392334
epoch: 15
loss: 7.886277675628662
epoch: 16
loss: 7.8819260597229
epoch: 17
loss: 7.877594947814941
epoch: 18
loss: 7.87328577041626
epoch: 19
loss: 7.869001865386963
epoch: 20
loss: 7.864734649658203
epoch: 21
loss: 7.860491752624512
epoch: 22
loss: 7.856273651123047
epoch: 23
loss: 7.852082252502441
epoch: 24
loss: 7.8479204177856445
epoch: 25
loss: 7.843785285949707
epoch: 26
loss: 7.839672088623047
epoch: 27
loss: 7.835582256317139
epoch: 28
loss: 7.831511497497559
epoch: 29
loss: 7.82745695

In [66]:
# Tabulate the cross-validation results of the FastCPH search as a dataframe.
survwrap.get_model_scores_df(opt_lasso_search)

['rank_test_score', 'mean_test_score', 'std_test_score', 'params', 'mean_fit_time', 'std_fit_time']


Unnamed: 0,rank_test_score,mean_test_score,std_test_score,params,mean_fit_time,std_fit_time
2,1,0.792478,0.013033,"{'layer_sizes': [8, 8, 8, 8]}",19.106784,0.668043
1,2,0.791197,0.014476,"{'layer_sizes': [8, 8, 8]}",17.164177,0.206789
0,3,0.790994,0.011804,"{'layer_sizes': [8, 8]}",16.022369,2.096882


In [72]:
# Shape of the score table (the grid searched above has three candidates).
survwrap.get_model_scores_df(opt_lasso_search).shape

['rank_test_score', 'mean_test_score', 'std_test_score', 'params', 'mean_fit_time', 'std_fit_time']


(3, 6)

In [67]:
# Test-set score of the model returned by the search, together with its parameters.
opt_lasso.score(X_test, y_test), opt_lasso_params

(0.7922897967423662, {'layer_sizes': [8, 8, 8, 8]})

In [77]:
# Import only the helpers used below rather than `*`, so it stays clear where
# each name comes from and the notebook namespace is not polluted.
from survwrap.metrics import (
    concordance_index_score,
    make_time_dependent_scorer,
    neg_brier_score,
    roc_auc_td_score,
)

# Scorers for the multi-metric search. Each one is built with
# time_mode='quantiles'; the Brier and AUC scorers use the three quartiles,
# the c-index scorer uses the median only.
# NOTE(review): in the recorded run every 'c-index-median' evaluation raised
# inside sksurv's `_check_estimate_1d` and its mean test score came out as NaN.
# Presumably the predictions handed to `concordance_index_score` are not 1-D —
# confirm in survwrap.metrics before relying on this scorer.
# The commented 'c-index-td' entry would also need `concordance_index_td_scorer`
# added to the import above.
scorers = {
    #'c-index-td': concordance_index_td_scorer,
    'neg-brier': make_time_dependent_scorer(neg_brier_score, time_mode='quantiles', time_values=[0.25, 0.5, 0.75]),
    'auc': make_time_dependent_scorer(roc_auc_td_score, time_mode='quantiles', time_values=[0.25, 0.5, 0.75]),
    'c-index-median': make_time_dependent_scorer(concordance_index_score, time_mode='quantiles', time_values=[0.5]),
}

# Reference call kept from an earlier experiment (uses names not defined here):
#best_model, best_params, search_results = survwrap.optimize(deephit, X, y, mode='sklearn-random', user_grid=dict(dropout=[0.0, 0.2, 0.9]), scoring=scoring, tries=3, refit='c-index-td')

In [78]:
# Multi-metric grid search over the FastCPH layer layouts (3 splits, 1 repeat).
#
# `refit` names the metric used to pick the final model. In the recorded run
# the 'c-index-median' scorer raised in every one of the 9 cross-validation
# fits, so its mean score is NaN for all candidates and refitting on it selects
# an arbitrary one. Select on the time-dependent AUC instead (the other
# discrimination metric in `scorers`); switch back once the c-index scorer works.
opt_lasso, opt_lasso_params, opt_lasso_search = survwrap.optimize(
    fl_lasso_cv,
    X_train,
    y_train,
    user_grid=fl_grid,
    scoring=scorers,
    refit='auc',
    cv=survwrap.survival_crossval_splitter(
        X_train, y_train, n_repeats=1, n_splits=3
    ),
)

epoch: 0
loss: 7.9480061531066895
epoch: 1
loss: 7.943586826324463
epoch: 2
loss: 7.939194202423096
epoch: 3
loss: 7.934821605682373
epoch: 4
loss: 7.930466651916504
epoch: 5
loss: 7.92613410949707
epoch: 6
loss: 7.921833515167236
epoch: 7
loss: 7.9175848960876465
epoch: 8
loss: 7.913358211517334
epoch: 9
loss: 7.909154415130615
epoch: 10
loss: 7.9049811363220215
epoch: 11
loss: 7.900837421417236
epoch: 12
loss: 7.896726608276367
epoch: 13
loss: 7.892647743225098
epoch: 14
loss: 7.888603210449219
epoch: 15
loss: 7.8845977783203125
epoch: 16
loss: 7.880639553070068
epoch: 17
loss: 7.876720428466797
epoch: 18
loss: 7.872848987579346
epoch: 19
loss: 7.869026184082031
epoch: 20
loss: 7.865254878997803
epoch: 21
loss: 7.861538887023926
epoch: 22
loss: 7.857884407043457
epoch: 23
loss: 7.854283809661865
epoch: 24
loss: 7.850744724273682
epoch: 25
loss: 7.847270965576172
epoch: 26
loss: 7.843865871429443
epoch: 27
loss: 7.8405232429504395
epoch: 28
loss: 7.837236404418945
epoch: 29
loss: 7.83

Traceback (most recent call last):
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 187, in scorer
    return score_func(y, y_pred, pred_times, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 11, in concordance_index_score
    r = sksurv.metrics.concordance_index_censored(
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 214, in concordance_index_censored
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 47, in _check_inputs
    estimate = _check_estimate_1d(estimate, event_time)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 38,

epoch: 0
loss: 7.980625629425049
epoch: 1
loss: 7.976261138916016
epoch: 2
loss: 7.971918106079102
epoch: 3
loss: 7.967597007751465
epoch: 4
loss: 7.963296890258789
epoch: 5
loss: 7.959023952484131
epoch: 6
loss: 7.9547834396362305
epoch: 7
loss: 7.9506096839904785
epoch: 8
loss: 7.946456432342529
epoch: 9
loss: 7.942324161529541
epoch: 10
loss: 7.9382171630859375
epoch: 11
loss: 7.9341325759887695
epoch: 12
loss: 7.9300689697265625
epoch: 13
loss: 7.926029682159424
epoch: 14
loss: 7.922010898590088
epoch: 15
loss: 7.918015480041504
epoch: 16
loss: 7.91403865814209
epoch: 17
loss: 7.910078525543213
epoch: 18
loss: 7.906129837036133
epoch: 19
loss: 7.902195930480957
epoch: 20
loss: 7.898271560668945
epoch: 21
loss: 7.894362449645996
epoch: 22
loss: 7.89046049118042
epoch: 23
loss: 7.886563777923584
epoch: 24
loss: 7.882673263549805
epoch: 25
loss: 7.878788471221924
epoch: 26
loss: 7.874904632568359
epoch: 27
loss: 7.871020317077637
epoch: 28
loss: 7.867135047912598
epoch: 29
loss: 7.863

Traceback (most recent call last):
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 187, in scorer
    return score_func(y, y_pred, pred_times, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 11, in concordance_index_score
    r = sksurv.metrics.concordance_index_censored(
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 214, in concordance_index_censored
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 47, in _check_inputs
    estimate = _check_estimate_1d(estimate, event_time)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 38,

epoch: 0
loss: 7.965539932250977
epoch: 1
loss: 7.9612345695495605
epoch: 2
loss: 7.956954479217529
epoch: 3
loss: 7.952699661254883
epoch: 4
loss: 7.948470115661621
epoch: 5
loss: 7.944273471832275
epoch: 6
loss: 7.940122127532959
epoch: 7
loss: 7.936063289642334
epoch: 8
loss: 7.932027816772461
epoch: 9
loss: 7.928022384643555
epoch: 10
loss: 7.924047470092773
epoch: 11
loss: 7.92010498046875
epoch: 12
loss: 7.9161906242370605
epoch: 13
loss: 7.91231107711792
epoch: 14
loss: 7.908459186553955
epoch: 15
loss: 7.904630184173584
epoch: 16
loss: 7.900825500488281
epoch: 17
loss: 7.897034168243408
epoch: 18
loss: 7.893259048461914
epoch: 19
loss: 7.88949728012085
epoch: 20
loss: 7.885750770568848
epoch: 21
loss: 7.882022857666016
epoch: 22
loss: 7.878310203552246
epoch: 23
loss: 7.874615669250488
epoch: 24
loss: 7.870940685272217
epoch: 25
loss: 7.867282867431641
epoch: 26
loss: 7.863636493682861
epoch: 27
loss: 7.859999179840088
epoch: 28
loss: 7.85637092590332
epoch: 29
loss: 7.85274982

Traceback (most recent call last):
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 187, in scorer
    return score_func(y, y_pred, pred_times, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 11, in concordance_index_score
    r = sksurv.metrics.concordance_index_censored(
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 214, in concordance_index_censored
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 47, in _check_inputs
    estimate = _check_estimate_1d(estimate, event_time)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 38,

epoch: 0
loss: 8.061453819274902
epoch: 1
loss: 8.060834884643555
epoch: 2
loss: 8.057311058044434
epoch: 3
loss: 8.053888320922852
epoch: 4
loss: 8.050497055053711
epoch: 5
loss: 8.04713249206543
epoch: 6
loss: 8.04378890991211
epoch: 7
loss: 8.040467262268066
epoch: 8
loss: 8.037162780761719
epoch: 9
loss: 8.033878326416016
epoch: 10
loss: 8.030611038208008
epoch: 11
loss: 8.027360916137695
epoch: 12
loss: 8.024127006530762
epoch: 13
loss: 8.02091121673584
epoch: 14
loss: 8.017707824707031
epoch: 15
loss: 8.01451587677002
epoch: 16
loss: 8.011332511901855
epoch: 17
loss: 8.008157730102539
epoch: 18
loss: 8.004988670349121
epoch: 19
loss: 8.001825332641602
epoch: 20
loss: 7.998663902282715
epoch: 21
loss: 7.995502948760986
epoch: 22
loss: 7.992340564727783
epoch: 23
loss: 7.98917818069458
epoch: 24
loss: 7.9860124588012695
epoch: 25
loss: 7.9828410148620605
epoch: 26
loss: 7.979665279388428
epoch: 27
loss: 7.976484775543213
epoch: 28
loss: 7.973296165466309
epoch: 29
loss: 7.970098972

Traceback (most recent call last):
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 187, in scorer
    return score_func(y, y_pred, pred_times, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 11, in concordance_index_score
    r = sksurv.metrics.concordance_index_censored(
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 214, in concordance_index_censored
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 47, in _check_inputs
    estimate = _check_estimate_1d(estimate, event_time)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 38,

epoch: 0
loss: 8.06873893737793
epoch: 1
loss: 8.06854248046875
epoch: 2
loss: 8.065099716186523
epoch: 3
loss: 8.061758041381836
epoch: 4
loss: 8.058449745178223
epoch: 5
loss: 8.055169105529785
epoch: 6
loss: 8.051919937133789
epoch: 7
loss: 8.048697471618652
epoch: 8
loss: 8.04549789428711
epoch: 9
loss: 8.042318344116211
epoch: 10
loss: 8.039155960083008
epoch: 11
loss: 8.036008834838867
epoch: 12
loss: 8.032876968383789
epoch: 13
loss: 8.02975845336914
epoch: 14
loss: 8.026654243469238
epoch: 15
loss: 8.0235595703125
epoch: 16
loss: 8.020471572875977
epoch: 17
loss: 8.017390251159668
epoch: 18
loss: 8.014311790466309
epoch: 19
loss: 8.011234283447266
epoch: 20
loss: 8.008159637451172
epoch: 21
loss: 8.005084991455078
epoch: 22
loss: 8.002012252807617
epoch: 23
loss: 7.998935699462891
epoch: 24
loss: 7.995856285095215
epoch: 25
loss: 7.992777347564697
epoch: 26
loss: 7.989687919616699
epoch: 27
loss: 7.986591815948486
epoch: 28
loss: 7.983488082885742
epoch: 29
loss: 7.980379104614

Traceback (most recent call last):
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 187, in scorer
    return score_func(y, y_pred, pred_times, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 11, in concordance_index_score
    r = sksurv.metrics.concordance_index_censored(
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 214, in concordance_index_censored
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 47, in _check_inputs
    estimate = _check_estimate_1d(estimate, event_time)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 38,

epoch: 0
loss: 8.067981719970703
epoch: 1
loss: 8.06786823272705
epoch: 2
loss: 8.064461708068848
epoch: 3
loss: 8.061158180236816
epoch: 4
loss: 8.057890892028809
epoch: 5
loss: 8.054657936096191
epoch: 6
loss: 8.051458358764648
epoch: 7
loss: 8.048286437988281
epoch: 8
loss: 8.045141220092773
epoch: 9
loss: 8.042023658752441
epoch: 10
loss: 8.038932800292969
epoch: 11
loss: 8.035862922668457
epoch: 12
loss: 8.032809257507324
epoch: 13
loss: 8.029773712158203
epoch: 14
loss: 8.026754379272461
epoch: 15
loss: 8.023750305175781
epoch: 16
loss: 8.020756721496582
epoch: 17
loss: 8.017770767211914
epoch: 18
loss: 8.014790534973145
epoch: 19
loss: 8.011817932128906
epoch: 20
loss: 8.008849143981934
epoch: 21
loss: 8.005878448486328
epoch: 22
loss: 8.002908706665039
epoch: 23
loss: 7.99993896484375
epoch: 24
loss: 7.996967792510986
epoch: 25
loss: 7.99399471282959
epoch: 26
loss: 7.991018772125244
epoch: 27
loss: 7.988039493560791
epoch: 28
loss: 7.985051155090332
epoch: 29
loss: 7.982053756

Traceback (most recent call last):
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 187, in scorer
    return score_func(y, y_pred, pred_times, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 11, in concordance_index_score
    r = sksurv.metrics.concordance_index_censored(
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 214, in concordance_index_censored
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 47, in _check_inputs
    estimate = _check_estimate_1d(estimate, event_time)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 38,

epoch: 0
loss: 9.783014297485352
epoch: 1
loss: 10.193513870239258
epoch: 2
loss: 10.150776863098145
epoch: 3
loss: 10.107815742492676
epoch: 4
loss: 10.0646390914917
epoch: 5
loss: 10.02124309539795
epoch: 6
loss: 9.977622985839844
epoch: 7
loss: 9.93378734588623
epoch: 8
loss: 9.889730453491211
epoch: 9
loss: 9.845455169677734
epoch: 10
loss: 9.800962448120117
epoch: 11
loss: 9.756248474121094
epoch: 12
loss: 9.711299896240234
epoch: 13
loss: 9.666139602661133
epoch: 14
loss: 9.620768547058105
epoch: 15
loss: 9.575199127197266
epoch: 16
loss: 9.529437065124512
epoch: 17
loss: 9.48399543762207
epoch: 18
loss: 9.439083099365234
epoch: 19
loss: 9.393828392028809
epoch: 20
loss: 9.348305702209473
epoch: 21
loss: 9.302557945251465
epoch: 22
loss: 9.256616592407227
epoch: 23
loss: 9.210502624511719
epoch: 24
loss: 9.164252281188965
epoch: 25
loss: 9.117876052856445
epoch: 26
loss: 9.07140064239502
epoch: 27
loss: 9.024852752685547
epoch: 28
loss: 8.978261947631836
epoch: 29
loss: 8.9316606

Traceback (most recent call last):
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 187, in scorer
    return score_func(y, y_pred, pred_times, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 11, in concordance_index_score
    r = sksurv.metrics.concordance_index_censored(
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 214, in concordance_index_censored
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 47, in _check_inputs
    estimate = _check_estimate_1d(estimate, event_time)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 38,

epoch: 0
loss: 8.7959566116333
epoch: 1
loss: 8.856557846069336
epoch: 2
loss: 8.831271171569824
epoch: 3
loss: 8.806096076965332
epoch: 4
loss: 8.781013488769531
epoch: 5
loss: 8.756102561950684
epoch: 6
loss: 8.731256484985352
epoch: 7
loss: 8.70648193359375
epoch: 8
loss: 8.681793212890625
epoch: 9
loss: 8.65733528137207
epoch: 10
loss: 8.633013725280762
epoch: 11
loss: 8.60880184173584
epoch: 12
loss: 8.584724426269531
epoch: 13
loss: 8.560802459716797
epoch: 14
loss: 8.537020683288574
epoch: 15
loss: 8.513387680053711
epoch: 16
loss: 8.489893913269043
epoch: 17
loss: 8.466545104980469
epoch: 18
loss: 8.443344116210938
epoch: 19
loss: 8.42029094696045
epoch: 20
loss: 8.397398948669434
epoch: 21
loss: 8.37466812133789
epoch: 22
loss: 8.352090835571289
epoch: 23
loss: 8.329673767089844
epoch: 24
loss: 8.307425498962402
epoch: 25
loss: 8.285356521606445
epoch: 26
loss: 8.263474464416504
epoch: 27
loss: 8.241759300231934
epoch: 28
loss: 8.220166206359863
epoch: 29
loss: 8.1987800598144

Traceback (most recent call last):
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 187, in scorer
    return score_func(y, y_pred, pred_times, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 11, in concordance_index_score
    r = sksurv.metrics.concordance_index_censored(
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 214, in concordance_index_censored
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 47, in _check_inputs
    estimate = _check_estimate_1d(estimate, event_time)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 38,

epoch: 0
loss: 10.075550079345703
epoch: 1
loss: 10.445169448852539
epoch: 2
loss: 10.402426719665527
epoch: 3
loss: 10.359468460083008
epoch: 4
loss: 10.316298484802246
epoch: 5
loss: 10.272933006286621
epoch: 6
loss: 10.229354858398438
epoch: 7
loss: 10.18556022644043
epoch: 8
loss: 10.141554832458496
epoch: 9
loss: 10.097345352172852
epoch: 10
loss: 10.052946090698242
epoch: 11
loss: 10.008323669433594
epoch: 12
loss: 9.963456153869629
epoch: 13
loss: 9.918378829956055
epoch: 14
loss: 9.873096466064453
epoch: 15
loss: 9.827605247497559
epoch: 16
loss: 9.781903266906738
epoch: 17
loss: 9.735997200012207
epoch: 18
loss: 9.689888000488281
epoch: 19
loss: 9.643571853637695
epoch: 20
loss: 9.597040176391602
epoch: 21
loss: 9.55030345916748
epoch: 22
loss: 9.50336742401123
epoch: 23
loss: 9.4562349319458
epoch: 24
loss: 9.408894538879395
epoch: 25
loss: 9.36123275756836
epoch: 26
loss: 9.31342887878418
epoch: 27
loss: 9.26538372039795
epoch: 28
loss: 9.21712875366211
epoch: 29
loss: 9.168

Traceback (most recent call last):
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 187, in scorer
    return score_func(y, y_pred, pred_times, **kwargs)
  File "/usr/local/ivan/Unito/survwrap/survwrap/metrics.py", line 11, in concordance_index_score
    r = sksurv.metrics.concordance_index_censored(
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 214, in concordance_index_censored
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 47, in _check_inputs
    estimate = _check_estimate_1d(estimate, event_time)
  File "/usr/local/ivan/Unito/conda/envs/hive/lib/python3.8/site-packages/sksurv/metrics.py", line 38,

epoch: 0
loss: 8.370159149169922
epoch: 1
loss: 8.365800857543945
epoch: 2
loss: 8.361469268798828
epoch: 3
loss: 8.357159614562988
epoch: 4
loss: 8.352876663208008
epoch: 5
loss: 8.348620414733887
epoch: 6
loss: 8.344400405883789
epoch: 7
loss: 8.340253829956055
epoch: 8
loss: 8.336128234863281
epoch: 9
loss: 8.33202838897705
epoch: 10
loss: 8.327960014343262
epoch: 11
loss: 8.323922157287598
epoch: 12
loss: 8.319914817810059
epoch: 13
loss: 8.315943717956543
epoch: 14
loss: 8.312003135681152
epoch: 15
loss: 8.308094024658203
epoch: 16
loss: 8.304222106933594
epoch: 17
loss: 8.30038070678711
epoch: 18
loss: 8.296566009521484
epoch: 19
loss: 8.292780876159668
epoch: 20
loss: 8.289022445678711
epoch: 21
loss: 8.285293579101562
epoch: 22
loss: 8.281590461730957
epoch: 23
loss: 8.277910232543945
epoch: 24
loss: 8.27425765991211
epoch: 25
loss: 8.270627975463867
epoch: 26
loss: 8.267013549804688
epoch: 27
loss: 8.263416290283203
epoch: 28
loss: 8.259832382202148
epoch: 29
loss: 8.256258010

In [79]:
# Tabulate the multi-metric search results (rank / mean / std columns per scorer).
# NOTE(review): in the recorded output the mean and std columns of
# 'c-index-median' are empty and every candidate has rank 1, because that
# scorer raised during cross-validation — do not read a ranking into them.
survwrap.get_model_scores_df(opt_lasso_search)

['rank_test_neg-brier', 'mean_test_neg-brier', 'std_test_neg-brier', 'rank_test_auc', 'mean_test_auc', 'std_test_auc', 'rank_test_c-index-median', 'mean_test_c-index-median', 'std_test_c-index-median', 'params', 'mean_fit_time', 'std_fit_time']


Unnamed: 0,rank_test_neg-brier,mean_test_neg-brier,std_test_neg-brier,rank_test_auc,mean_test_auc,std_test_auc,rank_test_c-index-median,mean_test_c-index-median,std_test_c-index-median,params,mean_fit_time,std_fit_time
0,1,-0.093461,0.002908,3,0.81575,0.016025,1,,,"{'layer_sizes': [8, 8]}",13.56029,0.48902
1,2,-0.093614,0.00262,2,0.81639,0.01445,1,,,"{'layer_sizes': [8, 8, 8]}",13.543517,0.405181
2,3,-0.094147,0.002876,1,0.817227,0.015398,1,,,"{'layer_sizes': [8, 8, 8, 8]}",14.749032,0.223881
