In [1]:
# Import libraries

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
import re
import pickle

import os
path_dir = os.path.dirname(os.getcwd())

import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
pio.templates.default = "plotly_white"

import torch # For building the networks 
import torchtuples as tt # Some useful functions

from pycox.datasets import metabric
from pycox.models import LogisticHazard, PMF, DeepHitSingle, CoxPH
from pycox.evaluation import EvalSurv

%load_ext autoreload
%autoreload 2

In [2]:
cd ../src/

/Users/linafaik/Documents/survival_analysis/src


In [3]:
from train import *
from train_survival_ml import *
from train_survival_deep import *

In [4]:
# We also set some seeds to make this reproducable.
# Note that on gpu, there is still some randomness.
np.random.seed(1234)
_ = torch.manual_seed(123)

In [5]:
# Parameters

scaler_name = "StandardScaler" #MinMaxScaler
random_state = 123
test_size = 0.3

In [67]:
df = pd.read_csv(os.path.join(path_dir, "outputs/hdhi_clean.csv"))


Columns (1) have mixed types.Specify dtype option on import or set low_memory=False.



# 1. Train / test split

In [68]:
# covariate columns (used when possible)

cols_x = [
    'age', 'gender', 'rural',
    'duration_of_stay', 'duration_of_intensive_unit_stay', 
    'smoking','alcohol', 'dm', 'htn', 'cad', 'prior_cmp', 'ckd', 'hb', 'tlc',
    'platelets', 'glucose', 'urea', 'creatinine', 'raised_cardiac_enzymes',
    'severe_anaemia', 'anaemia', 'stable_angina', 'acs', 'stemi',
    'atypical_chest_pain', 'heart_failure', 'hfref', 'hfnef', 'valvular',
    'chb', 'sss', 'aki', 'cva_infract', 'cva_bleed', 'af', 'vt', 'psvt',
    'congenital', 'uti', 'neuro_cardiogenic_syncope', 'orthostatic',
    'infective_endocarditis', 'dvt', 'cardiogenic_shock', 'shock',
    'pulmonary_embolism', 'chest_infection',
    'type_adm', 
    'first_visit', 'nb_visits',
    'duration_of_stay_lag1', 'duration_of_intensive_unit_stay_lag1',
    'cardiogenic_shock_lag1', 'cad_lag1', 'time_before_readm_lag1'
]

col_target = "time_before_readm"

In [69]:
Xy_train, Xy_test, y_train, y_test = split_train_test(
    df, cols_x+['doa'], col_target, test_size=0.15, random_state=random_state)

Xy_train, Xy_val, y_train, y_val = split_train_test(
    Xy_train, cols_x, col_target, test_size=0.2, random_state=random_state)

n_train, n_test, n_val = Xy_train.shape[0], Xy_test.shape[0], Xy_val.shape[0]
n_tot =  n_train + n_test + n_val

print("Train: {}%, Test: {}%, Val: {}%".format(
    round(n_train/n_tot *100),
    round(n_test/n_tot *100),
    round(n_val/n_tot *100)
))

Train: 68%, Test: 15%, Val: 17%


In [70]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# rescale
scaler = eval(scaler_name)()

Xy_train[cols_x] = scaler.fit_transform(Xy_train[cols_x])
Xy_test[cols_x] = scaler.transform(Xy_test[cols_x])

In [71]:
#with open(os.path.join(path_dir, "outputs/cox_ph.pkl"), "rb") as f:
    #estimator = pickle.load(f)

In [72]:
from sksurv.linear_model import CoxPHSurvivalAnalysis

# train an estimator
estimator = CoxPHSurvivalAnalysis(alpha=0.5)
estimator = estimator.fit(Xy_train[cols_x], y_train)

In [73]:
estimator.score(Xy_test[cols_x], y_test)

0.6339186792998365

In [74]:
feat_importance, fig = plot_feat_imp(cols_x, estimator.coef_)
fig

# 2. DeepSurv

Source: https://nbviewer.org/github/havakv/pycox/blob/master/examples/cox-ph.ipynb

In [75]:
n_top = 40
cols_x_reduc = list(feat_importance.feature.iloc[:n_top])
cols_x_reduc

['dvt',
 'type_adm',
 'gender',
 'rural',
 'atypical_chest_pain',
 'chb',
 'sss',
 'age',
 'cva_bleed',
 'severe_anaemia',
 'pulmonary_embolism',
 'chest_infection',
 'aki',
 'psvt',
 'congenital',
 'neuro_cardiogenic_syncope',
 'acs',
 'heart_failure',
 'valvular',
 'orthostatic',
 'vt',
 'duration_of_stay_lag1',
 'hfref',
 'stable_angina',
 'htn',
 'cad_lag1',
 'shock',
 'anaemia',
 'infective_endocarditis',
 'first_visit',
 'alcohol',
 'duration_of_intensive_unit_stay_lag1',
 'cva_infract',
 'smoking',
 'tlc',
 'platelets',
 'stemi',
 'prior_cmp',
 'cardiogenic_shock_lag1',
 'raised_cardiac_enzymes']

In [76]:
get_target = lambda df: (df[col_target].values, df['censored'].values)

y_train = get_target(Xy_train)
y_val = get_target(Xy_val)
y_test = get_target(Xy_test)

train = (np.array(Xy_train[cols_x_reduc]).astype(np.float32), y_train)
val = (np.array(Xy_val[cols_x_reduc]).astype(np.float32), y_val)
test = (np.array(Xy_test[cols_x_reduc]).astype(np.float32), y_test)

In [77]:
params = {
    'n_nodes': 264,
    'n_layers': 4,
    'dropout': 0.4,
    'lr':0.005
}


logs_df, model, score = train_deep_surv(
    train, val, test, CoxPH, out_features = 1, tolerance=10, 
    print_lr=True, print_logs=True, verbose = True,
    **params
)

print('score', score)

concordance: 0.49542856307029476
0:	[5s / 5s],		train_loss: 1.9072,	val_loss: 2.4611
1:	[3s / 8s],		train_loss: 1.8467,	val_loss: 3.9352
2:	[2s / 11s],		train_loss: 1.8224,	val_loss: 2.4810
3:	[2s / 14s],		train_loss: 1.8034,	val_loss: 4.5627
4:	[2s / 17s],		train_loss: 1.7895,	val_loss: 3.2483
concordance: 0.5466598272906236
5:	[2s / 20s],		train_loss: 1.7875,	val_loss: 1.8318
6:	[2s / 22s],		train_loss: 1.7824,	val_loss: 2.4341
7:	[2s / 25s],		train_loss: 1.7837,	val_loss: 1.8555
8:	[2s / 27s],		train_loss: 1.7812,	val_loss: 1.8770
9:	[2s / 30s],		train_loss: 1.7823,	val_loss: 2.6103
concordance: 0.4911468917406464
10:	[3s / 33s],		train_loss: 1.7771,	val_loss: 4.4214
11:	[2s / 36s],		train_loss: 1.7844,	val_loss: 1.9348
12:	[3s / 39s],		train_loss: 1.7786,	val_loss: 1.9218
13:	[3s / 42s],		train_loss: 1.7778,	val_loss: 1.8841
14:	[2s / 45s],		train_loss: 1.7759,	val_loss: 4.0899
concordance: 0.5081631181098708
15:	[2s / 48s],		train_loss: 1.7763,	val_loss: 2.0168
16:	[2s / 50s],		tr

score 0.6078104479324722


In [78]:
surv = model.predict_surv_df(test[0])
ev = EvalSurv(surv, test[1][0], test[1][1])
score = ev.concordance_td()
score

0.6078104479324722

In [51]:
grid_params = {
    "n_nodes" :[232, 248, 264],
    "n_layers" :[4, 5],
    "dropout" :[0.3, 0.4],
    "lr" :[0.01, 0.005, 0.001]
}

best_model, table = grid_search_deep(train, val, test, 1, grid_params, CoxPH)

36 total scenario to run
1/36: params: {'n_nodes': 232, 'n_layers': 4, 'dropout': 0.3, 'lr': 0.01}
concordance: 0.5048211044180049
0:	[1s / 1s],		train_loss: 1.9120,	val_loss: 13.9515
1:	[1s / 3s],		train_loss: 1.8116,	val_loss: 2.4332
2:	[1s / 5s],		train_loss: 1.7955,	val_loss: 2.0563
3:	[1s / 7s],		train_loss: 1.7892,	val_loss: 5.8767
4:	[1s / 9s],		train_loss: 1.7911,	val_loss: 5.4681
concordance: 0.4930594960146512
5:	[1s / 10s],		train_loss: 1.7903,	val_loss: 3.2110
6:	[1s / 12s],		train_loss: 1.7906,	val_loss: 1.9036
7:	[1s / 14s],		train_loss: 1.7905,	val_loss: 2.4391
8:	[1s / 16s],		train_loss: 1.7866,	val_loss: 3.9835
9:	[1s / 18s],		train_loss: 1.7884,	val_loss: 2.4450
concordance: 0.49485205814276645
10:	[1s / 20s],		train_loss: 1.7899,	val_loss: 4.1043
11:	[1s / 22s],		train_loss: 1.7880,	val_loss: 2.3542
12:	[1s / 24s],		train_loss: 1.7883,	val_loss: 1.9747
13:	[1s / 25s],		train_loss: 1.7827,	val_loss: 4.2597
14:	[1s / 27s],		train_loss: 1.7879,	val_loss: 4.5053



overflow encountered in exp



concordance: 0.5093204151844236
15:	[2s / 29s],		train_loss: 1.7810,	val_loss: 9690.1289
16:	[2s / 31s],		train_loss: 1.7838,	val_loss: 3.5557
17:	[1s / 33s],		train_loss: 1.7809,	val_loss: 2.5852
18:	[1s / 35s],		train_loss: 1.7841,	val_loss: 1.9255
19:	[1s / 37s],		train_loss: 1.7797,	val_loss: 2.2638
concordance: 0.4995942272841675
20:	[2s / 39s],		train_loss: 1.7796,	val_loss: 60.4871
21:	[2s / 41s],		train_loss: 1.7798,	val_loss: 4.0647
Current score: 0.593094852538029 vs. best score: -100
2/36: params: {'n_nodes': 232, 'n_layers': 4, 'dropout': 0.3, 'lr': 0.005}
concordance: 0.5115538043524861
0:	[2s / 2s],		train_loss: 1.8780,	val_loss: 7.6443
1:	[1s / 4s],		train_loss: 1.8288,	val_loss: 5.6912
2:	[1s / 6s],		train_loss: 1.8067,	val_loss: 1.9125
3:	[1s / 7s],		train_loss: 1.7920,	val_loss: 10.3598
4:	[1s / 9s],		train_loss: 1.7826,	val_loss: 2.0246
concordance: 0.503790658602659
5:	[1s / 11s],		train_loss: 1.7813,	val_loss: 2.2488
6:	[1s / 13s],		train_loss: 1.7779,	val_loss: 2.


overflow encountered in exp



concordance: 0.501403938290267
30:	[2s / 1m:7s],		train_loss: 1.7531,	val_loss: 15.1942
31:	[1s / 1m:9s],		train_loss: 1.7516,	val_loss: 29.0351
32:	[2s / 1m:11s],		train_loss: 1.7504,	val_loss: 275.2091
33:	[1s / 1m:13s],		train_loss: 1.7502,	val_loss: 205.6651
34:	[1s / 1m:15s],		train_loss: 1.7507,	val_loss: 26.4903
concordance: 0.4993334634630545
35:	[1s / 1m:16s],		train_loss: 1.7504,	val_loss: 539.6981
36:	[1s / 1m:18s],		train_loss: 1.7503,	val_loss: 3.3641
Current score: 0.6099321662182666 vs. best score: 0.593094852538029
3/36: params: {'n_nodes': 232, 'n_layers': 4, 'dropout': 0.3, 'lr': 0.001}
concordance: 0.5093738793334333
0:	[2s / 2s],		train_loss: 1.8662,	val_loss: 4.0174
1:	[2s / 4s],		train_loss: 1.8068,	val_loss: 5.7973
2:	[1s / 6s],		train_loss: 1.8063,	val_loss: 5.2317
3:	[1s / 8s],		train_loss: 1.7995,	val_loss: 6.6659
4:	[1s / 9s],		train_loss: 1.7968,	val_loss: 2.8747
concordance: 0.49768666679780516
5:	[1s / 11s],		train_loss: 1.7943,	val_loss: 1.8551
6:	[1s / 1


overflow encountered in exp



concordance: 0.5032393726133427
15:	[2s / 37s],		train_loss: 1.7923,	val_loss: 605.8121
16:	[2s / 40s],		train_loss: 1.7842,	val_loss: 4.4483
17:	[1s / 41s],		train_loss: 1.7890,	val_loss: 951.9289
18:	[2s / 43s],		train_loss: 1.7912,	val_loss: 2.0384
19:	[1s / 45s],		train_loss: 1.7866,	val_loss: 44.9756



overflow encountered in exp



concordance: 0.501364344557274
20:	[2s / 48s],		train_loss: 1.7904,	val_loss: 315.3583
21:	[1s / 49s],		train_loss: 1.7893,	val_loss: 192737.7812
22:	[1s / 51s],		train_loss: 1.7941,	val_loss: 4135.8096
23:	[2s / 54s],		train_loss: 1.7865,	val_loss: 6815.6870
Current score: 0.6009151822714116 vs. best score: 0.6099321662182666
5/36: params: {'n_nodes': 232, 'n_layers': 4, 'dropout': 0.4, 'lr': 0.005}
concordance: 0.4909199212967376
0:	[2s / 2s],		train_loss: 1.9059,	val_loss: 11.8731
1:	[2s / 4s],		train_loss: 1.8448,	val_loss: 15.0202
2:	[1s / 5s],		train_loss: 1.8175,	val_loss: 7.3371
3:	[4s / 10s],		train_loss: 1.8056,	val_loss: 2.1543
4:	[2s / 13s],		train_loss: 1.7890,	val_loss: 3.0656
concordance: 0.4968958008954741
5:	[1s / 15s],		train_loss: 1.7870,	val_loss: 2.0165
6:	[1s / 17s],		train_loss: 1.7862,	val_loss: 1.8368
7:	[1s / 19s],		train_loss: 1.7810,	val_loss: 2.3614
8:	[2s / 21s],		train_loss: 1.7835,	val_loss: 3.9041
9:	[2s / 23s],		train_loss: 1.7799,	val_loss: 2.1017
con


overflow encountered in exp



concordance: 0.5026633720645787
0:	[2s / 2s],		train_loss: 1.9234,	val_loss: 28.8466
1:	[2s / 4s],		train_loss: 1.8198,	val_loss: 18.0954
2:	[2s / 6s],		train_loss: 1.7999,	val_loss: 3.6162
3:	[2s / 8s],		train_loss: 1.7959,	val_loss: 7.9408
4:	[2s / 10s],		train_loss: 1.7941,	val_loss: 2.0629
concordance: 0.5147616532930638
5:	[2s / 12s],		train_loss: 1.7890,	val_loss: 1.8312
6:	[2s / 14s],		train_loss: 1.7921,	val_loss: 1.8414
7:	[2s / 16s],		train_loss: 1.7927,	val_loss: 1.8156
8:	[2s / 18s],		train_loss: 1.7920,	val_loss: 1.9264
9:	[2s / 20s],		train_loss: 1.7881,	val_loss: 2.3825
concordance: 0.5278477603313163
10:	[2s / 22s],		train_loss: 1.7890,	val_loss: 1.8572
11:	[2s / 25s],		train_loss: 1.7919,	val_loss: 1.8731
12:	[2s / 27s],		train_loss: 1.7856,	val_loss: 1.8511
13:	[2s / 29s],		train_loss: 1.7858,	val_loss: 1.9774
14:	[2s / 31s],		train_loss: 1.7840,	val_loss: 2.3700
concordance: 0.5167668110703069
15:	[2s / 33s],		train_loss: 1.7878,	val_loss: 1.8436
16:	[2s / 36s],		tra


overflow encountered in exp


overflow encountered in exp



concordance: 0.49669631909421647
20:	[2s / 48s],		train_loss: 1.7887,	val_loss: 32.4788
21:	[2s / 51s],		train_loss: 1.7864,	val_loss: 963.6024
22:	[2s / 53s],		train_loss: 1.7921,	val_loss: 339.2501
Current score: 0.5012776212273834 vs. best score: 0.6099321662182666
11/36: params: {'n_nodes': 232, 'n_layers': 5, 'dropout': 0.4, 'lr': 0.005}
concordance: 0.4950401914218286
0:	[2s / 2s],		train_loss: 1.9018,	val_loss: 2.7344
1:	[2s / 4s],		train_loss: 1.8471,	val_loss: 2.7565
2:	[2s / 7s],		train_loss: 1.8236,	val_loss: 1.9647
3:	[2s / 9s],		train_loss: 1.8015,	val_loss: 2.2229
4:	[2s / 11s],		train_loss: 1.7997,	val_loss: 2.9908
concordance: 0.5281173507808036
5:	[2s / 13s],		train_loss: 1.7922,	val_loss: 2.0933
6:	[2s / 16s],		train_loss: 1.7866,	val_loss: 2.3007
7:	[2s / 18s],		train_loss: 1.7863,	val_loss: 1.8453
8:	[2s / 20s],		train_loss: 1.7845,	val_loss: 1.8233
9:	[2s / 22s],		train_loss: 1.7831,	val_loss: 1.8682
concordance: 0.491707256547719
10:	[2s / 24s],		train_loss: 1.779


overflow encountered in exp



concordance: 0.5191797590784395
20:	[1s / 40s],		train_loss: 1.7813,	val_loss: 870.4969
21:	[1s / 42s],		train_loss: 1.7771,	val_loss: 1537.0574
22:	[2s / 44s],		train_loss: 1.7771,	val_loss: 3293.3372
23:	[4s / 48s],		train_loss: 1.7796,	val_loss: 300.3329
24:	[2s / 51s],		train_loss: 1.7720,	val_loss: 128.6498



overflow encountered in exp



concordance: 0.5097549374898178
25:	[2s / 54s],		train_loss: 1.7812,	val_loss: 1108.7439
Current score: 0.5890021931222665 vs. best score: 0.6153677942545297
14/36: params: {'n_nodes': 248, 'n_layers': 4, 'dropout': 0.3, 'lr': 0.005}
concordance: 0.5055791857006603
0:	[2s / 2s],		train_loss: 1.8935,	val_loss: 4.6718
1:	[2s / 4s],		train_loss: 1.8355,	val_loss: 3.9804
2:	[1s / 6s],		train_loss: 1.8104,	val_loss: 9.4833
3:	[2s / 8s],		train_loss: 1.7952,	val_loss: 5.3503
4:	[1s / 10s],		train_loss: 1.7827,	val_loss: 7.5749
concordance: 0.4927755307703831
5:	[1s / 12s],		train_loss: 1.7844,	val_loss: 5.5455
6:	[1s / 14s],		train_loss: 1.7798,	val_loss: 2.4541
7:	[1s / 16s],		train_loss: 1.7809,	val_loss: 2.2806
8:	[1s / 17s],		train_loss: 1.7801,	val_loss: 3.0967
9:	[1s / 19s],		train_loss: 1.7762,	val_loss: 2.1797
concordance: 0.5439127283511808
10:	[1s / 21s],		train_loss: 1.7755,	val_loss: 2.0291
11:	[1s / 23s],		train_loss: 1.7741,	val_loss: 2.0428
12:	[1s / 25s],		train_loss: 1.7746,


overflow encountered in exp


overflow encountered in exp



concordance: 0.001458411196805063
20:	[2s / 42s],		train_loss: 1.7920,	val_loss: 8439.5576
21:	[1s / 44s],		train_loss: 1.7912,	val_loss: 3729.2668
Current score: 0.5013885524117969 vs. best score: 0.6153677942545297
17/36: params: {'n_nodes': 248, 'n_layers': 4, 'dropout': 0.4, 'lr': 0.005}
concordance: 0.48914148177402117
0:	[2s / 2s],		train_loss: 1.8967,	val_loss: 7.5837
1:	[2s / 4s],		train_loss: 1.8482,	val_loss: 3.8125
2:	[1s / 6s],		train_loss: 1.8240,	val_loss: 4.3822
3:	[1s / 7s],		train_loss: 1.8057,	val_loss: 6.4602
4:	[1s / 9s],		train_loss: 1.7959,	val_loss: 1.8564
concordance: 0.5060684331019748
5:	[2s / 12s],		train_loss: 1.7922,	val_loss: 2.0030
6:	[1s / 13s],		train_loss: 1.7841,	val_loss: 2.7014
7:	[1s / 15s],		train_loss: 1.7859,	val_loss: 1.8696
8:	[2s / 18s],		train_loss: 1.7844,	val_loss: 4.1428
9:	[1s / 19s],		train_loss: 1.7825,	val_loss: 3.1867
concordance: 0.496853937458042
10:	[1s / 21s],		train_loss: 1.7809,	val_loss: 2.7702
11:	[2s / 23s],		train_loss: 1.7


overflow encountered in exp


overflow encountered in exp



concordance: 0.17824165492725094
15:	[4s / 1m:16s],		train_loss: 1.7864,	val_loss: 33275.2734
16:	[2s / 1m:19s],		train_loss: 1.7856,	val_loss: 9.1050
17:	[2s / 1m:21s],		train_loss: 1.7800,	val_loss: 114600.7656
18:	[2s / 1m:23s],		train_loss: 1.7846,	val_loss: 4300.4966
19:	[2s / 1m:25s],		train_loss: 1.7843,	val_loss: 2508.3750
concordance: 0.5005964278887158
20:	[2s / 1m:28s],		train_loss: 1.7744,	val_loss: 41116.7383
21:	[2s / 1m:30s],		train_loss: 1.7816,	val_loss: 29923.1875
22:	[2s / 1m:32s],		train_loss: 1.7781,	val_loss: 176.8635
Current score: 0.5843982302013337 vs. best score: 0.6153677942545297
20/36: params: {'n_nodes': 248, 'n_layers': 5, 'dropout': 0.3, 'lr': 0.005}
concordance: 0.4817568722867575
0:	[2s / 2s],		train_loss: 1.8941,	val_loss: 2.0210
1:	[2s / 4s],		train_loss: 1.8375,	val_loss: 4.5994
2:	[2s / 7s],		train_loss: 1.8140,	val_loss: 5.1193
3:	[2s / 9s],		train_loss: 1.7959,	val_loss: 3.1449
4:	[2s / 11s],		train_loss: 1.7909,	val_loss: 3.5841
concordance: 0.5


overflow encountered in exp



concordance: 0.5014160433806087
20:	[4s / 1m:3s],		train_loss: 1.7935,	val_loss: 861.2737
21:	[4s / 1m:8s],		train_loss: 1.7945,	val_loss: 209749.0781
22:	[3s / 1m:12s],		train_loss: 1.7930,	val_loss: 9735.4590
23:	[3s / 1m:15s],		train_loss: 1.7930,	val_loss: 3419.0522
24:	[2s / 1m:17s],		train_loss: 1.7912,	val_loss: 44.0565
Current score: 0.6066950795006821 vs. best score: 0.6153677942545297
23/36: params: {'n_nodes': 248, 'n_layers': 5, 'dropout': 0.4, 'lr': 0.005}
concordance: 0.48640699230368445
0:	[2s / 2s],		train_loss: 1.9063,	val_loss: 2.1718
1:	[2s / 4s],		train_loss: 1.8520,	val_loss: 3.0915
2:	[2s / 7s],		train_loss: 1.8211,	val_loss: 14.2358
3:	[2s / 9s],		train_loss: 1.8098,	val_loss: 7.1432
4:	[2s / 11s],		train_loss: 1.7985,	val_loss: 2.5164
concordance: 0.5113154853863818
5:	[2s / 14s],		train_loss: 1.7887,	val_loss: 2.4239
6:	[2s / 16s],		train_loss: 1.7861,	val_loss: 1.9331
7:	[2s / 19s],		train_loss: 1.7856,	val_loss: 1.8429
8:	[2s / 21s],		train_loss: 1.7854,	val_


overflow encountered in exp



concordance: 0.5022376763875586
15:	[2s / 31s],		train_loss: 1.7863,	val_loss: 17801.2422
16:	[2s / 33s],		train_loss: 1.7803,	val_loss: 2758019.2500
17:	[1s / 35s],		train_loss: 1.7859,	val_loss: 8080084.0000
18:	[2s / 37s],		train_loss: 1.7833,	val_loss: 111783.6094
19:	[1s / 39s],		train_loss: 1.7797,	val_loss: 3203719.5000
Current score: 0.5014303109898377 vs. best score: 0.6153677942545297
26/36: params: {'n_nodes': 264, 'n_layers': 4, 'dropout': 0.3, 'lr': 0.005}
concordance: 0.503700879182624
0:	[2s / 2s],		train_loss: 1.9037,	val_loss: 2.3871
1:	[1s / 3s],		train_loss: 1.8398,	val_loss: 2.3965
2:	[1s / 5s],		train_loss: 1.8187,	val_loss: 12.9877
3:	[2s / 8s],		train_loss: 1.7969,	val_loss: 2.2879
4:	[1s / 9s],		train_loss: 1.7903,	val_loss: 2.5231
concordance: 0.5181737756331593
5:	[2s / 12s],		train_loss: 1.7839,	val_loss: 3.3872
6:	[2s / 14s],		train_loss: 1.7792,	val_loss: 2.8689
7:	[1s / 16s],		train_loss: 1.7775,	val_loss: 2.9597
8:	[1s / 17s],		train_loss: 1.7756,	val_los


overflow encountered in exp



concordance: 0.501364344557274
30:	[2s / 1m:1s],		train_loss: 1.7560,	val_loss: 14540.1543
31:	[1s / 1m:3s],		train_loss: 1.7541,	val_loss: 2.0420
32:	[1s / 1m:5s],		train_loss: 1.7503,	val_loss: 17.5307
Current score: 0.5505562497609242 vs. best score: 0.6153677942545297
27/36: params: {'n_nodes': 264, 'n_layers': 4, 'dropout': 0.3, 'lr': 0.001}
concordance: 0.5186340212555299
0:	[2s / 2s],		train_loss: 1.8610,	val_loss: 3.2563
1:	[2s / 4s],		train_loss: 1.8117,	val_loss: 10.5630
2:	[2s / 6s],		train_loss: 1.8064,	val_loss: 2.5651
3:	[2s / 8s],		train_loss: 1.8034,	val_loss: 1.8671
4:	[1s / 10s],		train_loss: 1.7971,	val_loss: 1.8889
concordance: 0.5083832794404624
5:	[2s / 12s],		train_loss: 1.7932,	val_loss: 6.1845
6:	[2s / 14s],		train_loss: 1.7865,	val_loss: 2.8487
7:	[2s / 16s],		train_loss: 1.7899,	val_loss: 3.9920
8:	[2s / 18s],		train_loss: 1.7871,	val_loss: 6.5476
9:	[2s / 20s],		train_loss: 1.7872,	val_loss: 3.1547
concordance: 0.5143571415241418
10:	[1s / 22s],		train_loss:

8:	[2s / 22s],		train_loss: 1.7961,	val_loss: 1.9359
9:	[2s / 24s],		train_loss: 1.7932,	val_loss: 1.8166
concordance: 0.5127736443938048
10:	[2s / 27s],		train_loss: 1.7920,	val_loss: 1.8209
11:	[2s / 29s],		train_loss: 1.7911,	val_loss: 1.8338
12:	[2s / 32s],		train_loss: 1.7911,	val_loss: 97.1710
13:	[2s / 34s],		train_loss: 1.7891,	val_loss: 2.3481
14:	[2s / 37s],		train_loss: 1.7823,	val_loss: 6705097.5000



overflow encountered in exp



concordance: 0.5001954467711437
15:	[2s / 39s],		train_loss: 1.7875,	val_loss: 10973593.0000
16:	[2s / 42s],		train_loss: 1.7894,	val_loss: 563302.3750
17:	[2s / 45s],		train_loss: 1.7852,	val_loss: 640.6959
18:	[2s / 47s],		train_loss: 1.7814,	val_loss: 1328.6193
19:	[2s / 50s],		train_loss: 1.7804,	val_loss: 84.3618
concordance: 0.5013577876333388
20:	[2s / 52s],		train_loss: 1.7865,	val_loss: 923.4852
21:	[2s / 54s],		train_loss: 1.7795,	val_loss: 224336.1094
22:	[2s / 57s],		train_loss: 1.7796,	val_loss: 1585286.6250
23:	[2s / 59s],		train_loss: 1.7785,	val_loss: 3308414.7500
24:	[2s / 1m:1s],		train_loss: 1.7788,	val_loss: 393131.8750
Current score: 0.602295765488926 vs. best score: 0.6153677942545297
32/36: params: {'n_nodes': 264, 'n_layers': 5, 'dropout': 0.3, 'lr': 0.005}
concordance: 0.4889629316914796
0:	[2s / 2s],		train_loss: 1.9032,	val_loss: 12.1491
1:	[2s / 5s],		train_loss: 1.8476,	val_loss: 1.8649
2:	[2s / 7s],		train_loss: 1.8157,	val_loss: 3.2992
3:	[2s / 9s],		trai


overflow encountered in exp



concordance: 0.49496251709213535
15:	[2s / 37s],		train_loss: 1.7980,	val_loss: 3.6090
16:	[2s / 39s],		train_loss: 1.7937,	val_loss: 9.0001
17:	[2s / 41s],		train_loss: 1.7888,	val_loss: 2300.5085
18:	[2s / 44s],		train_loss: 1.7896,	val_loss: 169.1670
19:	[2s / 46s],		train_loss: 1.7958,	val_loss: 3.3558



overflow encountered in exp


overflow encountered in exp



concordance: 0.03136630659066687
20:	[2s / 48s],		train_loss: 1.7940,	val_loss: 662208.3750
21:	[2s / 50s],		train_loss: 1.7922,	val_loss: 638140.8750
22:	[2s / 52s],		train_loss: 1.7888,	val_loss: 8147.6050
23:	[2s / 55s],		train_loss: 1.7915,	val_loss: 5797.0415
Current score: 0.503055707855713 vs. best score: 0.6153677942545297
35/36: params: {'n_nodes': 264, 'n_layers': 5, 'dropout': 0.4, 'lr': 0.005}
concordance: 0.506300447333526
0:	[2s / 2s],		train_loss: 1.9151,	val_loss: 12.8083
1:	[2s / 4s],		train_loss: 1.8574,	val_loss: 4.1540
2:	[2s / 6s],		train_loss: 1.8250,	val_loss: 1.8475
3:	[2s / 8s],		train_loss: 1.8035,	val_loss: 3.5549
4:	[2s / 10s],		train_loss: 1.7969,	val_loss: 2.0482
concordance: 0.4912094347074124
5:	[2s / 12s],		train_loss: 1.7891,	val_loss: 2.9268
6:	[2s / 14s],		train_loss: 1.7842,	val_loss: 2.9816
7:	[2s / 16s],		train_loss: 1.7845,	val_loss: 2.2286
8:	[2s / 18s],		train_loss: 1.7855,	val_loss: 2.1228
9:	[2s / 20s],		train_loss: 1.7827,	val_loss: 2.0713
c

In [52]:
table

Unnamed: 0,n_nodes,n_layers,dropout,lr,score
10,232,5,0.4,0.005,0.615368
1,232,4,0.3,0.005,0.609932
28,264,4,0.4,0.005,0.6097
22,248,5,0.4,0.005,0.609517
19,248,5,0.3,0.005,0.607932
7,232,5,0.3,0.005,0.607726
4,232,4,0.4,0.005,0.607076
34,264,5,0.4,0.005,0.606806
21,248,5,0.4,0.01,0.606695
23,248,5,0.4,0.001,0.606534


In [53]:
table.sort_values(by="score", ascending=False, inplace=True)
best_params = table.drop('score', axis=1).iloc[0].to_dict()
best_score = table.score.iloc[0]

print('Best score: ', best_score)
print('Best params: ')
print(best_params)

Best score:  0.6153677942545297
Best params: 
{'n_nodes': 232.0, 'n_layers': 5.0, 'dropout': 0.4, 'lr': 0.005}


# 3. DeepHit

In [57]:
from pycox.models import DeepHitSingle

In [58]:
num_durations = int(df[col_target].max())
labtrans = DeepHitSingle.label_transform(num_durations)

In [59]:
get_target = lambda df: (df[col_target].values, df['censored'].values)

y_train = labtrans.fit_transform(*get_target(Xy_train))
y_val = labtrans.transform(*get_target(Xy_val))

train = (np.array(Xy_train[cols_x_reduc]).astype(np.float32), y_train)
val = (np.array(Xy_val[cols_x_reduc]).astype(np.float32), y_val)
test = (np.array(Xy_test[cols_x_reduc]).astype(np.float32), y_test)

In [60]:
params = {
    'n_nodes': 236,
    'n_layers': 4,
    'out_features': 729,
    'dropout': 0.1,
    'model_params': {'alpha': 0.2, 'sigma': 0.1, 'duration_index': labtrans.cuts},
    'discrete': True,
}

In [61]:
logs_df, model, score = train_deep_surv(
    train, val, test, DeepHitSingle, tolerance=10, 
    print_lr=True, print_logs=True, verbose = True,
    **params
)

concordance: 0.5030442491096127
0:	[4s / 4s],		train_loss: 33.3355,	val_loss: 35.8477
1:	[3s / 8s],		train_loss: 26.5236,	val_loss: 110.4716
2:	[3s / 11s],		train_loss: 29.6387,	val_loss: 97.1768
3:	[2s / 14s],		train_loss: 25.5052,	val_loss: 248.8673
4:	[2s / 17s],		train_loss: 11.9761,	val_loss: 312.6776
concordance: 0.5010419184825589
5:	[3s / 20s],		train_loss: 22.5591,	val_loss: 267.9500
6:	[2s / 23s],		train_loss: 7.9705,	val_loss: 313.3258
7:	[3s / 26s],		train_loss: 6.4626,	val_loss: 263.2984
8:	[2s / 29s],		train_loss: 5.5661,	val_loss: 582.5009
9:	[2s / 32s],		train_loss: 2.7997,	val_loss: 753.6334
concordance: 0.500992207186588
10:	[2s / 35s],		train_loss: 2.9495,	val_loss: 24518.8633
11:	[2s / 38s],		train_loss: 3.4734,	val_loss: 9061.5703
12:	[3s / 41s],		train_loss: 3.6451,	val_loss: 56414.7539
13:	[2s / 44s],		train_loss: 7.8709,	val_loss: 106712.0547
14:	[2s / 46s],		train_loss: 4.8069,	val_loss: 43522.3359
concordance: 0.5014287581613583
15:	[3s / 50s],		train_loss: 2.

In [62]:
grid_params = {
    'n_nodes': [32],
    'n_layers': [4],
    'dropout': [0.1],
    'model_params': [
        {'alpha': 0.01, 'sigma': 0.2, 'duration_index': labtrans.cuts},
        {'alpha': 0.05, 'sigma': 0.2, 'duration_index': labtrans.cuts},
        {'alpha': 0.1, 'sigma': 0.2, 'duration_index': labtrans.cuts},
    ],
    'epochs': [512],
    'batch_size': [16],
    'discrete': [True],
    'output_bias': [True],
}

best_model, table = grid_search_deep(train, val, test, labtrans.out_features, grid_params, DeepHitSingle)

3 total scenario to run
1/3: params: {'n_nodes': 32, 'n_layers': 4, 'dropout': 0.1, 'model_params': {'alpha': 0.01, 'sigma': 0.2}, 'epochs': 512, 'batch_size': 16, 'discrete': True, 'output_bias': True}
concordance: 0.5123649909333662
0:	[2s / 2s],		train_loss: 0.4140,	val_loss: 1.5919
1:	[2s / 4s],		train_loss: 0.4049,	val_loss: 0.6698
2:	[2s / 6s],		train_loss: 0.4028,	val_loss: 0.5355
3:	[2s / 8s],		train_loss: 0.4011,	val_loss: 0.4680
4:	[2s / 11s],		train_loss: 0.4001,	val_loss: 0.5460
concordance: 0.4976244561407582
5:	[2s / 13s],		train_loss: 0.3991,	val_loss: 0.9318
6:	[2s / 15s],		train_loss: 0.3983,	val_loss: 1.3449
7:	[2s / 17s],		train_loss: 0.3981,	val_loss: 0.4674
8:	[2s / 20s],		train_loss: 0.3978,	val_loss: 0.5604
9:	[2s / 22s],		train_loss: 0.3958,	val_loss: 0.4443
concordance: 0.5062444453304439
10:	[3s / 25s],		train_loss: 0.3961,	val_loss: 0.4938
11:	[2s / 28s],		train_loss: 0.3976,	val_loss: 0.4298
12:	[2s / 30s],		train_loss: 0.3960,	val_loss: 0.4702
13:	[2s / 33s

In [63]:
table

Unnamed: 0,n_nodes,n_layers,dropout,model_params,epochs,batch_size,discrete,output_bias,lr,score
2,32,4,0.1,"{'alpha': 0.1, 'sigma': 0.2}",512,16,True,True,0.01,0.620541
0,32,4,0.1,"{'alpha': 0.01, 'sigma': 0.2}",512,16,True,True,0.01,0.616121
1,32,4,0.1,"{'alpha': 0.05, 'sigma': 0.2}",512,16,True,True,0.01,0.608533


# Draft

In [None]:
# finding the best learning rate from this model
lrfinder = model.lr_finder(train[0], train[1], batch_size, tolerance=10)
lr = lrfinder.get_best_lr()
model.optimizer.set_lr(lr)

lrfinder_df = lrfinder.to_pandas()
fig = px.line(x=lrfinder_df.index, y=lrfinder_df.train_loss, 
              log_x=True, width=700, height=400)

fig.update_layout(dict(xaxis={'title':'lr'}, yaxis={'title':'batch_loss'}))
fig.show()

print("Best learning rate: ", lr)

In [None]:
callbacks = [tt.callbacks.EarlyStopping(patience=15)]

log = model.fit(train[0], train[1], batch_size, 100, callbacks, verbose,
            val_data=val, val_batch_size=batch_size)

In [None]:
logs_df = log.to_pandas().reset_index().melt(
    id_vars="index", value_name="loss", var_name="dataset").reset_index()

fig = px.line(logs_df, y="loss", x="index", color="dataset", width=800, height = 400)
fig.show()

# scoring the model
surv = model.interpolate(10).predict_surv_df(test[0])
ev = EvalSurv(surv, test[1][0], test[1][1], censor_surv='km')
score = ev.concordance_td()
score

In [None]:

best_score

In [None]:
# finding the best learning rate from this model
lrfinder = model.lr_finder(train[0], y_train, params['batch_size'], tolerance=10, lr_min=0.001, lr_max=1.0)

lrfinder_df = lrfinder.to_pandas()
fig = px.line(x=lrfinder_df.index, y=lrfinder_df.train_loss, 
              log_x=True, width=700, height=400)
fig.update_layout(dict(xaxis={'title':'lr'}, yaxis={'title':'batch_loss'}))


print("Best LR: ", lrfinder.get_best_lr())
fig

In [None]:
help(model.lr_finder)

In [None]:
# setting the new number
model.optimizer.set_lr(lrfinder.get_best_lr())
model.optimizer.param_groups[0]['lr']

In [None]:
callbacks = [tt.callbacks.EarlyStopping()]
verbose = True

log = model.fit(train[0], train[1], params['batch_size'], 2, callbacks, verbose,
                val_data=val, val_batch_size=params['batch_size'])

In [None]:
logs_df = log.to_pandas().reset_index().melt(
    id_vars="index", value_name="loss", var_name="dataset").reset_index()

px.line(logs_df, y="loss", x="index", color="dataset", width=800, height = 400)

In [None]:
model.partial_log_likelihood(*val).mean()

In [None]:
_ = model.compute_baseline_hazards()
surv = model.predict_surv_df(test[0])

In [None]:
N = 3
surv_df = surv[np.random.choice(surv.columns, N)]\
    .reset_index().melt(id_vars="duration", var_name="patient_id", value_name="S")
px.line(surv_df, x="duration", y="S", color="patient_id", width=800, height = 400)

In [None]:
ev = EvalSurv(surv, test[1][0], test[1][1], censor_surv='km')

In [None]:
test[1][0].shape[0]

In [None]:
test[0][0].shape

In [None]:
surv.shape[1]

In [None]:
ev.concordance_td()

In [None]:
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
brier_scores = ev.brier_score(time_grid)

In [None]:
brier_scores_df = pd.DataFrame(brier_scores).reset_index().rename(columns={"index":"duration"})
px.line(brier_scores_df, x="duration", y="brier_score", width=800, height = 400)

In [None]:
time_grid = np.arange(1, 91)

In [None]:
ev.integrated_brier_score(time_grid)

In [None]:
ev.integrated_nbll(time_grid)

In [None]:
help(ev.integrated_nbll)