In [27]:
# Import libraries

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime
import re
import pickle

import os
path_dir = os.path.dirname(os.getcwd())

import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
pio.templates.default = "plotly_white"

import torch # For building the networks 
import torchtuples as tt # Some useful functions

from pycox.datasets import metabric
from pycox.models import LogisticHazard, PMF, DeepHitSingle, CoxPH
from pycox.evaluation import EvalSurv

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
cd ../src/

/Users/linafaik/Documents/survival_analysis/src


In [29]:
from train import *
from train_survival_ml import *
from train_survival_deep import *

In [30]:
# We also set some seeds to make this reproducable.
# Note that on gpu, there is still some randomness.
np.random.seed(1234)
_ = torch.manual_seed(123)

In [31]:
# Parameters

scaler_name = "StandardScaler" #MinMaxScaler
random_state = 123

In [72]:
df = pd.read_csv(os.path.join(path_dir, "outputs/customer_subscription_clean.csv"))

# 1. Train / test split

In [73]:
df.columns

Index(['date_time', 'customer_id', 'signup_date_time', 'cancel_date_time',
       'price', 'billing_cycle', 'age', 'duration', 'censored',
       'product=prd_1', 'gender=female', 'channel=email', 'reason=support',
       'nb_cases', 'time_since_signup', 'date_month_cos', 'date_month_sin',
       'date_weekday_cos', 'date_weekday_sin', 'date_hour_cos',
       'date_hour_sin'],
      dtype='object')

In [74]:
# covariate columns (used when possible)

cols_x = [
    'age', 'gender', 'rural',
    'duration_of_stay', 'duration_of_intensive_unit_stay', 
    'smoking','alcohol', 'dm', 'htn', 'cad', 'prior_cmp', 'ckd', 'hb', 'tlc',
    'platelets', 'glucose', 'urea', 'creatinine', 'raised_cardiac_enzymes',
    'severe_anaemia', 'anaemia', 'stable_angina', 'acs', 'stemi',
    'atypical_chest_pain', 'heart_failure', 'hfref', 'hfnef', 'valvular',
    'chb', 'sss', 'aki', 'cva_infract', 'cva_bleed', 'af', 'vt', 'psvt',
    'congenital', 'uti', 'neuro_cardiogenic_syncope', 'orthostatic',
    'infective_endocarditis', 'dvt', 'cardiogenic_shock', 'shock',
    'pulmonary_embolism', 'chest_infection',
    'type_adm', 
    'first_visit', 'nb_visits',
    'duration_of_stay_lag1', 'duration_of_intensive_unit_stay_lag1',
    'cardiogenic_shock_lag1', 'cad_lag1', 'time_before_readm_lag1'
]

col_target = "time_before_readm"

In [75]:
cols_x = [
    'price', 'billing_cycle', 'age',
    'product=prd_1', 'gender=female', 'channel=email', 'reason=support',
    'nb_cases', 'time_since_signup', 
    'date_month_cos', 'date_month_sin',
    'date_weekday_cos', 'date_weekday_sin', 'date_hour_cos',
    'date_hour_sin'
]

col_target = "duration"

In [76]:
Xy_train, Xy_test, y_train, y_test = split_train_test(
    df, cols_x, col_target, test_size=0.15, col_stratify= "censored", random_state=random_state)

Xy_train, Xy_val, y_train, y_val = split_train_test(
    Xy_train, cols_x, col_target, test_size=0.2,  col_stratify= "censored", random_state=random_state)

n_train, n_test, n_val = Xy_train.shape[0], Xy_test.shape[0], Xy_val.shape[0]
n_tot =  n_train + n_test + n_val

print("Train: {}%, Test: {}%, Val: {}%".format(
    round(n_train/n_tot *100),
    round(n_test/n_tot *100),
    round(n_val/n_tot *100)
))

Train: 68%, Test: 15%, Val: 17%


In [77]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# rescale
scaler = eval(scaler_name)()

Xy_train[cols_x] = scaler.fit_transform(Xy_train[cols_x])
Xy_test[cols_x] = scaler.transform(Xy_test[cols_x])

In [78]:
#with open(os.path.join(path_dir, "outputs/cox_ph.pkl"), "rb") as f:
    #estimator = pickle.load(f)

In [79]:
from sksurv.linear_model import CoxPHSurvivalAnalysis

# train an estimator
estimator = CoxPHSurvivalAnalysis(alpha=0.5)
estimator = estimator.fit(Xy_train[cols_x], y_train)

In [80]:
estimator.score(Xy_test[cols_x], y_test)

0.6818926526547785

In [81]:
feat_importance, fig = plot_feat_imp(cols_x, estimator.coef_)
fig

# 2. DeepSurv

Source: https://nbviewer.org/github/havakv/pycox/blob/master/examples/cox-ph.ipynb

In [82]:
n_top = 40
cols_x_reduc = list(feat_importance.feature.iloc[:n_top])
cols_x_reduc

['date_weekday_cos',
 'date_hour_cos',
 'date_weekday_sin',
 'gender=female',
 'date_hour_sin',
 'age',
 'time_since_signup',
 'nb_cases',
 'date_month_cos',
 'price',
 'product=prd_1',
 'billing_cycle',
 'channel=email',
 'reason=support',
 'date_month_sin']

In [83]:
get_target = lambda df: (df[col_target].values, df['censored'].values)

y_train = get_target(Xy_train)
y_val = get_target(Xy_val)
y_test = get_target(Xy_test)

train = (np.array(Xy_train[cols_x_reduc]).astype(np.float32), y_train)
val = (np.array(Xy_val[cols_x_reduc]).astype(np.float32), y_val)
test = (np.array(Xy_test[cols_x_reduc]).astype(np.float32), y_test)

In [88]:
params = {
    'n_nodes': 128,
    'n_layers': 4,
    'dropout': 0.4,
    'lr':0.005, 
    'batch_size': 64,
}


logs_df, model, score = train_deep_surv(
    train, val, test, CoxPH, out_features = 1, tolerance=10, 
    print_lr=True, print_logs=True, verbose = True,
    **params
)

print('score', score)

concordance: 0.652570868600722
0:	[21s / 21s],		train_loss: 3.0540,	val_loss: 29.2488
1:	[15s / 36s],		train_loss: 3.0369,	val_loss: 21.1583
2:	[12s / 48s],		train_loss: 3.0339,	val_loss: 19.1340
3:	[13s / 1m:2s],		train_loss: 3.0318,	val_loss: 12.9004
4:	[9s / 1m:11s],		train_loss: 3.0306,	val_loss: 13.3456
concordance: 0.6588056433214963
5:	[17s / 1m:29s],		train_loss: 3.0301,	val_loss: 23.7524
6:	[9s / 1m:38s],		train_loss: 3.0289,	val_loss: 14.4282
7:	[9s / 1m:48s],		train_loss: 3.0286,	val_loss: 12.7086
8:	[9s / 1m:57s],		train_loss: 3.0283,	val_loss: 12.1495
9:	[10s / 2m:7s],		train_loss: 3.0275,	val_loss: 17.6870
concordance: 0.6616914091750455
10:	[17s / 2m:24s],		train_loss: 3.0276,	val_loss: 13.1918
11:	[9s / 2m:34s],		train_loss: 3.0271,	val_loss: 8.8010
12:	[9s / 2m:43s],		train_loss: 3.0257,	val_loss: 13.8372
13:	[9s / 2m:52s],		train_loss: 3.0262,	val_loss: 9.0589
14:	[9s / 3m:1s],		train_loss: 3.0274,	val_loss: 8.7906
concordance: 0.6603391914794575
15:	[16s / 3m:18s],		

score 0.7045817952566035


In [89]:
surv = model.predict_surv_df(test[0])
ev = EvalSurv(surv, test[1][0], test[1][1])
score = ev.concordance_td()
score

0.7045817952566035

In [85]:
grid_params = {
    "n_nodes" :[32, 64, 128],
    "n_layers" :[2, 4],
    "dropout" :[0.3, 0.4],
    "lr" :[0.01, 0.005]
}

best_model, table = grid_search_deep(train, val, test, 1, grid_params, CoxPH)

24 total scenario to run
1/24: params: {'n_nodes': 32, 'n_layers': 2, 'dropout': 0.3, 'lr': 0.01}
concordance: 0.6504141121319281
0:	[28s / 28s],		train_loss: 1.7751,	val_loss: 34.9330
1:	[20s / 48s],		train_loss: 1.7673,	val_loss: 99.0763
2:	[27s / 1m:15s],		train_loss: 1.7659,	val_loss: 41.2679
3:	[24s / 1m:39s],		train_loss: 1.7654,	val_loss: 58.9233
4:	[24s / 2m:4s],		train_loss: 1.7647,	val_loss: 93.6270
concordance: 0.6423004425339345
5:	[30s / 2m:35s],		train_loss: 1.7645,	val_loss: 31.7624
6:	[20s / 2m:55s],		train_loss: 1.7665,	val_loss: 37.1184
7:	[18s / 3m:14s],		train_loss: 1.7641,	val_loss: 51.9789
8:	[20s / 3m:35s],		train_loss: 1.7644,	val_loss: 57.5091
9:	[20s / 3m:55s],		train_loss: 1.7638,	val_loss: 78.7920
concordance: 0.649419447325389
10:	[26s / 4m:22s],		train_loss: 1.7638,	val_loss: 58.7192
11:	[18s / 4m:40s],		train_loss: 1.7637,	val_loss: 74.4302
12:	[20s / 5m:0s],		train_loss: 1.7640,	val_loss: 56.9902
13:	[19s / 5m:19s],		train_loss: 1.7636,	val_loss: 91.4547

28:	[41s / 20m:4s],		train_loss: 1.7635,	val_loss: 10.1493
29:	[31s / 20m:36s],		train_loss: 1.7648,	val_loss: 10.1115
concordance: 0.6586946568968153
30:	[39s / 21m:15s],		train_loss: 1.7644,	val_loss: 21.7910
31:	[31s / 21m:47s],		train_loss: 1.7655,	val_loss: 16.3775
32:	[31s / 22m:18s],		train_loss: 1.7641,	val_loss: 16.5737
33:	[31s / 22m:49s],		train_loss: 1.7656,	val_loss: 9.5660
34:	[37s / 23m:27s],		train_loss: 1.7640,	val_loss: 15.5554
Current score: 0.7029911558302366 vs. best score: 0.6998000406757322
6/24: params: {'n_nodes': 32, 'n_layers': 4, 'dropout': 0.3, 'lr': 0.005}
concordance: 0.6513218306927372
0:	[38s / 38s],		train_loss: 1.7804,	val_loss: 49.3429
1:	[30s / 1m:9s],		train_loss: 1.7692,	val_loss: 14.6531
2:	[42s / 1m:51s],		train_loss: 1.7681,	val_loss: 19.4893
3:	[38s / 2m:30s],		train_loss: 1.7651,	val_loss: 24.0068
4:	[38s / 3m:9s],		train_loss: 1.7640,	val_loss: 26.8162
concordance: 0.659067969390784
5:	[55s / 4m:4s],		train_loss: 1.7639,	val_loss: 18.8007
6:


overflow encountered in exp



concordance: 0.6539490217862658
20:	[27s / 9m:1s],		train_loss: 1.7619,	val_loss: 35.3761
21:	[21s / 9m:22s],		train_loss: 1.7624,	val_loss: 52.1026
22:	[21s / 9m:44s],		train_loss: 1.7614,	val_loss: 69.7244
23:	[19s / 10m:4s],		train_loss: 1.7624,	val_loss: 37.3000
24:	[20s / 10m:24s],		train_loss: 1.7614,	val_loss: 87.0978



overflow encountered in exp



concordance: 0.6484240840624551
25:	[39s / 11m:4s],		train_loss: 1.7614,	val_loss: 48.4945
26:	[22s / 11m:26s],		train_loss: 1.7619,	val_loss: 32.2269
27:	[33s / 11m:59s],		train_loss: 1.7612,	val_loss: 52.4732
28:	[25s / 12m:25s],		train_loss: 1.7608,	val_loss: 38.6607
29:	[27s / 12m:53s],		train_loss: 1.7610,	val_loss: 49.7750
Current score: 0.6975826365205902 vs. best score: 0.7029911558302366
10/24: params: {'n_nodes': 64, 'n_layers': 2, 'dropout': 0.3, 'lr': 0.005}
concordance: 0.6507453171840314
0:	[36s / 36s],		train_loss: 1.7705,	val_loss: 46.4719
1:	[24s / 1m:0s],		train_loss: 1.7613,	val_loss: 42.9896
2:	[20s / 1m:21s],		train_loss: 1.7599,	val_loss: 37.4015
3:	[23s / 1m:45s],		train_loss: 1.7589,	val_loss: 34.0024
4:	[23s / 2m:8s],		train_loss: 1.7584,	val_loss: 38.3617
concordance: 0.6512078678566279
5:	[30s / 2m:39s],		train_loss: 1.7562,	val_loss: 64.2460
6:	[20s / 3m:0s],		train_loss: 1.7571,	val_loss: 68.3413
7:	[20s / 3m:20s],		train_loss: 1.7556,	val_loss: 44.8217
8:	

31:	[32s / 20m:1s],		train_loss: 1.7616,	val_loss: 5.0091
32:	[32s / 20m:33s],		train_loss: 1.7619,	val_loss: 11.3569
33:	[33s / 21m:6s],		train_loss: 1.7604,	val_loss: 9.1955
34:	[32s / 21m:39s],		train_loss: 1.7611,	val_loss: 15.6003
concordance: 0.6611726792827489
35:	[41s / 22m:21s],		train_loss: 1.7612,	val_loss: 16.3114
36:	[32s / 22m:53s],		train_loss: 1.7597,	val_loss: 26.4593
37:	[34s / 23m:28s],		train_loss: 1.7611,	val_loss: 10.1839
38:	[35s / 24m:4s],		train_loss: 1.7600,	val_loss: 12.0969
39:	[33s / 24m:37s],		train_loss: 1.7613,	val_loss: 19.3626
concordance: 0.6538691085901994
40:	[40s / 25m:18s],		train_loss: 1.7603,	val_loss: 23.5972
41:	[33s / 25m:51s],		train_loss: 1.7613,	val_loss: 7.8698
42:	[32s / 26m:24s],		train_loss: 1.7620,	val_loss: 11.1046
43:	[33s / 26m:57s],		train_loss: 1.7598,	val_loss: 9.5172
44:	[32s / 27m:30s],		train_loss: 1.7605,	val_loss: 11.3205
concordance: 0.6584852737348893
45:	[40s / 28m:11s],		train_loss: 1.7610,	val_loss: 19.6043
46:	[32s / 


overflow encountered in exp



concordance: 0.64683670147149
0:	[39s / 39s],		train_loss: 1.7785,	val_loss: 148.9819
1:	[29s / 1m:9s],		train_loss: 1.7711,	val_loss: 192.8973
2:	[28s / 1m:37s],		train_loss: 1.7699,	val_loss: 36.8388
3:	[28s / 2m:6s],		train_loss: 1.7701,	val_loss: 27.1849
4:	[28s / 2m:35s],		train_loss: 1.7678,	val_loss: 33.9914



overflow encountered in exp



concordance: 0.6497309877068005
5:	[36s / 3m:11s],		train_loss: 1.7664,	val_loss: 70.2941
6:	[27s / 3m:39s],		train_loss: 1.7666,	val_loss: 26.3376
7:	[27s / 4m:6s],		train_loss: 1.7677,	val_loss: 47.8540
8:	[26s / 4m:33s],		train_loss: 1.7679,	val_loss: 80.5235
9:	[26s / 5m:0s],		train_loss: 1.7670,	val_loss: 58.8040
concordance: 0.6516421836806305
10:	[36s / 5m:36s],		train_loss: 1.7651,	val_loss: 35.5838
11:	[27s / 6m:4s],		train_loss: 1.7660,	val_loss: 45.0953
12:	[27s / 6m:31s],		train_loss: 1.7648,	val_loss: 27.4272
13:	[27s / 6m:59s],		train_loss: 1.7649,	val_loss: 37.0089
14:	[28s / 7m:28s],		train_loss: 1.7655,	val_loss: 35.3459



overflow encountered in exp



concordance: 0.5791227522218645
15:	[44s / 8m:12s],		train_loss: 1.7664,	val_loss: 13.2295
16:	[30s / 8m:42s],		train_loss: 1.7645,	val_loss: 69.5888
17:	[28s / 9m:10s],		train_loss: 1.7646,	val_loss: 32.5881
18:	[27s / 9m:37s],		train_loss: 1.7651,	val_loss: 33.6881
19:	[28s / 10m:6s],		train_loss: 1.7652,	val_loss: 45.4317
concordance: 0.5941981420063976
20:	[37s / 10m:43s],		train_loss: 1.7647,	val_loss: 25.0309
21:	[28s / 11m:12s],		train_loss: 1.7651,	val_loss: 55.8740
22:	[30s / 11m:42s],		train_loss: 1.7649,	val_loss: 46.4835
23:	[27s / 12m:10s],		train_loss: 1.7655,	val_loss: 36.5460
24:	[32s / 12m:42s],		train_loss: 1.7647,	val_loss: 24.4814
concordance: 0.6524557940876116
25:	[36s / 13m:18s],		train_loss: 1.7652,	val_loss: 41.1815
26:	[27s / 13m:45s],		train_loss: 1.7645,	val_loss: 15.8148
27:	[26s / 14m:12s],		train_loss: 1.7629,	val_loss: 19.7122
28:	[31s / 14m:44s],		train_loss: 1.7642,	val_loss: 47.6226
29:	[40s / 15m:24s],		train_loss: 1.7640,	val_loss: 40.7047
concordan


overflow encountered in exp



concordance: 0.6023382415612882
15:	[58s / 18m:8s],		train_loss: 1.7651,	val_loss: 3258.9050
16:	[50s / 18m:59s],		train_loss: 1.7648,	val_loss: 135734.7500
17:	[54s / 19m:53s],		train_loss: 1.7658,	val_loss: 67.6407
18:	[53s / 20m:47s],		train_loss: 1.7654,	val_loss: 60845.6562
19:	[45s / 21m:32s],		train_loss: 1.7632,	val_loss: 5636.6074



overflow encountered in exp



concordance: 0.5556364583649838
20:	[59s / 22m:32s],		train_loss: 1.7621,	val_loss: 143398.7344
21:	[47s / 23m:19s],		train_loss: 1.7623,	val_loss: 25.7974
22:	[48s / 24m:7s],		train_loss: 1.7616,	val_loss: 367.2150
Current score: 0.6995147800314381 vs. best score: 0.7034254578418767
22/24: params: {'n_nodes': 128, 'n_layers': 4, 'dropout': 0.3, 'lr': 0.005}
concordance: 0.6601263208404484
0:	[56s / 56s],		train_loss: 1.7762,	val_loss: 13.8259
1:	[35s / 1m:32s],		train_loss: 1.7643,	val_loss: 13.5254
2:	[32s / 2m:4s],		train_loss: 1.7612,	val_loss: 27.3265
3:	[33s / 2m:37s],		train_loss: 1.7586,	val_loss: 9.6973
4:	[37s / 3m:15s],		train_loss: 1.7582,	val_loss: 20.8342



overflow encountered in exp



concordance: 0.6544100471244467
5:	[41s / 3m:56s],		train_loss: 1.7576,	val_loss: 33.2281
6:	[33s / 4m:30s],		train_loss: 1.7572,	val_loss: 20.5247
7:	[31s / 5m:2s],		train_loss: 1.7568,	val_loss: 20.6723
8:	[31s / 5m:34s],		train_loss: 1.7564,	val_loss: 31.3397
9:	[31s / 6m:5s],		train_loss: 1.7561,	val_loss: 14.8633
concordance: 0.6573151587786639
10:	[44s / 6m:49s],		train_loss: 1.7565,	val_loss: 20.6125
11:	[49s / 7m:38s],		train_loss: 1.7556,	val_loss: 19.6140
12:	[42s / 8m:21s],		train_loss: 1.7554,	val_loss: 23.2593
13:	[44s / 9m:6s],		train_loss: 1.7552,	val_loss: 22.6672
14:	[1m:19s / 10m:25s],		train_loss: 1.7566,	val_loss: 28.5096
concordance: 0.6640097864442709
15:	[1m:6s / 11m:31s],		train_loss: 1.7540,	val_loss: 15.7202
16:	[32s / 12m:4s],		train_loss: 1.7546,	val_loss: 22.7837
17:	[34s / 12m:39s],		train_loss: 1.7546,	val_loss: 9.0076
18:	[32s / 13m:12s],		train_loss: 1.7539,	val_loss: 22.4596
19:	[32s / 13m:44s],		train_loss: 1.7545,	val_loss: 23.0365
concordance: 0.651

56:	[1m:22s / 43m:24s],		train_loss: 1.7537,	val_loss: 12.6665
57:	[44s / 44m:8s],		train_loss: 1.7548,	val_loss: 10.3159
58:	[33s / 44m:41s],		train_loss: 1.7554,	val_loss: 7.1895
59:	[32s / 45m:13s],		train_loss: 1.7555,	val_loss: 10.1598
concordance: 0.6608787938172622
60:	[43s / 45m:57s],		train_loss: 1.7552,	val_loss: 7.9473
Current score: 0.7041290806938355 vs. best score: 0.7058750645685248


In [86]:
table

Unnamed: 0,n_nodes,n_layers,dropout,lr,score
21,128,4,0.3,0.005,0.705875
23,128,4,0.4,0.005,0.704129
12,64,4,0.3,0.01,0.703425
4,32,4,0.3,0.01,0.702991
13,64,4,0.3,0.005,0.702621
15,64,4,0.4,0.005,0.701399
9,64,2,0.3,0.005,0.701087
17,128,2,0.3,0.005,0.700704
6,32,4,0.4,0.01,0.700245
14,64,4,0.4,0.01,0.700141


In [87]:
table.sort_values(by="score", ascending=False, inplace=True)
best_params = table.drop('score', axis=1).iloc[0].to_dict()
best_score = table.score.iloc[0]

print('Best score: ', best_score)
print('Best params: ')
print(best_params)

Best score:  0.7058750645685248
Best params: 
{'n_nodes': 128.0, 'n_layers': 4.0, 'dropout': 0.3, 'lr': 0.005}


# 3. DeepHit

In [90]:
from pycox.models import DeepHitSingle

In [91]:
num_durations = int(df[col_target].max())
labtrans = DeepHitSingle.label_transform(num_durations)

In [92]:
get_target = lambda df: (df[col_target].values, df['censored'].values)

y_train = labtrans.fit_transform(*get_target(Xy_train))
y_val = labtrans.transform(*get_target(Xy_val))

train = (np.array(Xy_train[cols_x_reduc]).astype(np.float32), y_train)
val = (np.array(Xy_val[cols_x_reduc]).astype(np.float32), y_val)
test = (np.array(Xy_test[cols_x_reduc]).astype(np.float32), y_test)

In [103]:
params = {
    'n_nodes': 128,
    'n_layers': 4,
    'out_features': num_durations,
    'dropout': 0.1,
    'model_params': {'alpha': 0.2, 'sigma': 0.1, 'duration_index': labtrans.cuts},
    'discrete': True,
    'lr':0.005, 
    'batch_size': 128,
}

In [104]:
logs_df, model, score = train_deep_surv(
    train, val, test, DeepHitSingle, tolerance=10, 
    print_lr=True, print_logs=True, verbose = True,
    **params
)

concordance: 0.6244590987575174
0:	[1m:30s / 1m:30s],		train_loss: 0.8827,	val_loss: 23.4963
1:	[28s / 1m:58s],		train_loss: 0.7699,	val_loss: 15.2776
2:	[28s / 2m:26s],		train_loss: 0.7403,	val_loss: 11.4759
3:	[26s / 2m:53s],		train_loss: 0.7283,	val_loss: 12.8435
4:	[29s / 3m:22s],		train_loss: 0.7204,	val_loss: 26.5724
concordance: 0.6379005919105127
5:	[1m:53s / 5m:16s],		train_loss: 0.7179,	val_loss: 14.9949
6:	[22s / 5m:38s],		train_loss: 0.7105,	val_loss: 17.1590
7:	[25s / 6m:4s],		train_loss: 0.7083,	val_loss: 13.9112
8:	[25s / 6m:29s],		train_loss: 0.7050,	val_loss: 12.7882
9:	[30s / 6m:59s],		train_loss: 0.7032,	val_loss: 20.0412
concordance: 0.6542120400366527
10:	[1m:33s / 8m:32s],		train_loss: 0.7016,	val_loss: 16.1270
11:	[21s / 8m:54s],		train_loss: 0.7004,	val_loss: 18.0156
12:	[23s / 9m:18s],		train_loss: 0.6987,	val_loss: 27.0577
13:	[25s / 9m:43s],		train_loss: 0.6975,	val_loss: 21.9931
14:	[26s / 10m:10s],		train_loss: 0.6968,	val_loss: 11.5833
concordance: 0.65010

FileNotFoundError: [Errno 2] No such file or directory: 'weight_checkpoint_2023-1-25_10-13-59_ANYoDZu2jIdJI3yIaAPp.pt'

In [None]:
grid_params = {
    'n_nodes': [32],
    'n_layers': [4],
    'dropout': [0.1],
    'model_params': [
        {'alpha': 0.01, 'sigma': 0.2, 'duration_index': labtrans.cuts},
        {'alpha': 0.05, 'sigma': 0.2, 'duration_index': labtrans.cuts},
        {'alpha': 0.1, 'sigma': 0.2, 'duration_index': labtrans.cuts},
    ],
    'epochs': [512],
    'batch_size': [16],
    'discrete': [True],
    'output_bias': [True],
}

best_model, table = grid_search_deep(train, val, test, labtrans.out_features, grid_params, DeepHitSingle)

In [None]:
table

# Draft

In [None]:
# finding the best learning rate from this model
lrfinder = model.lr_finder(train[0], train[1], batch_size, tolerance=10)
lr = lrfinder.get_best_lr()
model.optimizer.set_lr(lr)

lrfinder_df = lrfinder.to_pandas()
fig = px.line(x=lrfinder_df.index, y=lrfinder_df.train_loss, 
              log_x=True, width=700, height=400)

fig.update_layout(dict(xaxis={'title':'lr'}, yaxis={'title':'batch_loss'}))
fig.show()

print("Best learning rate: ", lr)

In [None]:
callbacks = [tt.callbacks.EarlyStopping(patience=15)]

log = model.fit(train[0], train[1], batch_size, 100, callbacks, verbose,
            val_data=val, val_batch_size=batch_size)

In [None]:
logs_df = log.to_pandas().reset_index().melt(
    id_vars="index", value_name="loss", var_name="dataset").reset_index()

fig = px.line(logs_df, y="loss", x="index", color="dataset", width=800, height = 400)
fig.show()

# scoring the model
surv = model.interpolate(10).predict_surv_df(test[0])
ev = EvalSurv(surv, test[1][0], test[1][1], censor_surv='km')
score = ev.concordance_td()
score

In [None]:

best_score

In [None]:
# finding the best learning rate from this model
lrfinder = model.lr_finder(train[0], y_train, params['batch_size'], tolerance=10, lr_min=0.001, lr_max=1.0)

lrfinder_df = lrfinder.to_pandas()
fig = px.line(x=lrfinder_df.index, y=lrfinder_df.train_loss, 
              log_x=True, width=700, height=400)
fig.update_layout(dict(xaxis={'title':'lr'}, yaxis={'title':'batch_loss'}))


print("Best LR: ", lrfinder.get_best_lr())
fig

In [None]:
help(model.lr_finder)

In [None]:
# setting the new number
model.optimizer.set_lr(lrfinder.get_best_lr())
model.optimizer.param_groups[0]['lr']

In [None]:
callbacks = [tt.callbacks.EarlyStopping()]
verbose = True

log = model.fit(train[0], train[1], params['batch_size'], 2, callbacks, verbose,
                val_data=val, val_batch_size=params['batch_size'])

In [None]:
logs_df = log.to_pandas().reset_index().melt(
    id_vars="index", value_name="loss", var_name="dataset").reset_index()

px.line(logs_df, y="loss", x="index", color="dataset", width=800, height = 400)

In [None]:
model.partial_log_likelihood(*val).mean()

In [None]:
_ = model.compute_baseline_hazards()
surv = model.predict_surv_df(test[0])

In [None]:
N = 3
surv_df = surv[np.random.choice(surv.columns, N)]\
    .reset_index().melt(id_vars="duration", var_name="patient_id", value_name="S")
px.line(surv_df, x="duration", y="S", color="patient_id", width=800, height = 400)

In [None]:
ev = EvalSurv(surv, test[1][0], test[1][1], censor_surv='km')

In [None]:
test[1][0].shape[0]

In [None]:
test[0][0].shape

In [None]:
surv.shape[1]

In [None]:
ev.concordance_td()

In [None]:
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
brier_scores = ev.brier_score(time_grid)

In [None]:
brier_scores_df = pd.DataFrame(brier_scores).reset_index().rename(columns={"index":"duration"})
px.line(brier_scores_df, x="duration", y="brier_score", width=800, height = 400)

In [None]:
time_grid = np.arange(1, 91)

In [None]:
ev.integrated_brier_score(time_grid)

In [None]:
ev.integrated_nbll(time_grid)

In [None]:
help(ev.integrated_nbll)