In [3]:
# import necessary packages
import os
import warnings

import numpy as np
import pandas as pd
import torch

from lib.neural_ode_surv import *
from lib.utils import *

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

## Loading and pre-processing the data

In [5]:
# Framingham Heart Study teaching dataset. It is in person-time format for the longitudinal
# measurements: a patient (RANDID) can have several rows, one per measurement time (TIME).
# Dataset and covariate documentation:
# https://biolincc.nhlbi.nih.gov/media/teachingstudies/FHS_Teaching_Longitudinal_Data_Documentation_2021a.pdf
df_framingham = pd.read_csv(filepath_or_buffer='data/framingham.csv')

# Categorical features.
# NOTE(review): CIGPDAY (presumably cigarettes per day) is listed as categorical here, but it is
# also clipped as a numeric feature in the outlier step below. Confirm this encoding is intended.
feat_cat = ['SEX', 'CIGPDAY', 'CURSMOKE','educ', 'DIABETES', 'PREVAP', 'PREVCHD', 'PREVMI', 'PREVSTRK', 'PREVHYP', 'BPMEDS']
# Continuous-valued features.
feat_cont = ['AGE', 'SYSBP', 'DIABP', 'TOTCHOL', 'HDLC', 'LDLC', 'BMI', 'GLUCOSE', 'HEARTRTE']
# Features that SurvLatent ODE is asked to reconstruct.
feat_reconstr = ['SYSBP', 'DIABP', 'TOTCHOL', 'HDLC', 'LDLC', 'BMI', 'GLUCOSE', 'HEARTRTE']
# data_info_dic maps each role to its column(s):
#   id_col            : unique identifier for a patient
#   event_col         : event-indicator column (binary); a list/tuple of columns for several event types
#   time_col          : time of measurement
#   time_to_event_col : observed time to event, i.e. t_i = min(T_i, C_i)
#   feat_cat          : list of categorical features
#   feat_cont         : list of continuous-valued features
data_info_dic = {'id_col' : 'RANDID', 'event_col' : 'DEATH', 'time_col' : 'TIME', 'time_to_event_col' : 'TIMEDTH', 'feat_cat' : feat_cat, 'feat_cont' : feat_cont}

# Model input width = all features; reconstruction width = reconstructed features only.
feats_dim = len(feat_cat) + len(feat_cont)
reconstr_dim = len(feat_reconstr)
# One event type unless event_col names several columns. isinstance (rather than
# `type(...) == list`) also accepts tuples and list subclasses.
n_events = len(data_info_dic['event_col']) if isinstance(data_info_dic['event_col'], (list, tuple)) else 1
# Follow-up in the Framingham study spans roughly 20 years (~7500 days). Both the measurement
# times and the observed times-to-event are therefore discretized into 10-day units: day count
# divided by 10, then rounded half-to-even (same semantics as np.round).
df_framingham[data_info_dic['time_col']] = df_framingham[data_info_dic['time_col']].div(10).round()
df_framingham[data_info_dic['time_to_event_col']] = df_framingham[data_info_dic['time_to_event_col']].div(10).round()

# Patient-level 0.65-0.15-0.2 (train-valid-test) split. All rows of a patient go to exactly one
# subset, so one patient's longitudinal measurements never appear in two subsets.
test_set_frac = 0.2
train_set_frac = 0.65
random_seed = 1991 # set random seed for reproducibility
np.random.seed(random_seed)

# Use the configured id column (not a hard-coded attribute) so the split follows data_info_dic.
sample_ids = set(df_framingham[data_info_dic['id_col']].values)
sample_ids_test = set(np.random.choice(list(sample_ids), size = int(len(sample_ids)*test_set_frac), replace = False))
# Train is drawn from the non-test ids. The fraction is rescaled so train is ~train_set_frac of ALL ids.
sample_ids_train = set(np.random.choice(list(sample_ids - sample_ids_test),
                                        size = int(len(sample_ids - sample_ids_test)*train_set_frac/(1-test_set_frac)),
                                        replace = False))
sample_ids_valid = sample_ids - sample_ids_test - sample_ids_train
# Sanity check: the three subsets are disjoint and together cover every patient.
assert not (sample_ids_test & sample_ids_train), 'test and train patient ids overlap'
assert (sample_ids_test | sample_ids_train | sample_ids_valid) == sample_ids, 'split does not cover all patients'

# Rows of each subset, ordered by patient and then by measurement time.
data_test = df_framingham.loc[df_framingham[data_info_dic['id_col']].isin(sample_ids_test)].sort_values([data_info_dic['id_col'], data_info_dic['time_col']], ascending = (True, True))
data_train = df_framingham.loc[df_framingham[data_info_dic['id_col']].isin(sample_ids_train)].sort_values([data_info_dic['id_col'], data_info_dic['time_col']], ascending = (True, True))
data_valid = df_framingham.loc[df_framingham[data_info_dic['id_col']].isin(sample_ids_valid)].sort_values([data_info_dic['id_col'], data_info_dic['time_col']], ascending = (True, True))

# Outlier processing: winsorize each continuous feature (plus CIGPDAY) at the 0.005 and 0.995
# quantiles of its training-set values. A value below the lower bound OR above the upper bound is
# replaced by that bound. The bounds come from the training set only and are reused unchanged for
# the validation and test cohorts, so nothing is learned from held-out data.
feats_oi = feat_cont + ['CIGPDAY']
feat_to_min_max_dict = {}
for feat in feats_oi:
    # NaNs are excluded when computing the bounds. They compare False below, so clipping leaves them untouched.
    min_feat, max_feat = np.quantile(data_train[feat].dropna().values, q = [0.005, 0.995])
    feat_to_min_max_dict[feat] = (min_feat, max_feat)
    for split_df in (data_train, data_valid, data_test):
        split_df.loc[split_df[feat] < min_feat, feat] = min_feat
        split_df.loc[split_df[feat] > max_feat, feat] = max_feat

## Choose model hyperparameters and instantiate the model object

In [None]:
# Hyperparameter configuration used in the next cell (kept here for reference):
# {'batch_size': 75,
#  'dec_g_nn_layers': 5,
#  'dec_latent_dim': 56,
#  'enc_f_nn_layers': 7,
#  'enc_latent_dim': 75,
#  'haz_dec_layers': 2,
#  'lr': 0.01,
#  'num_units_gru': 50,
#  'num_units_ode': 70,
#  'surv_loss_scale': 50,
#  'wait_until_full_surv_loss': 3}

In [8]:
# --- Optimisation ---
# batch_size                : number of samples per mini-batch
# lr                        : learning rate
#   NOTE(review): `lr` is defined here but not passed to model.fit() in the training cell;
#   confirm whether fit() accepts a learning rate.
# surv_loss_scale           : scaling factor for the survival loss in the total loss
# wait_until_full_surv_loss : number of epochs before the full survival-loss scaling applies, letting
#                             the model learn input representations before tuning survival estimates
# early_stopping            : forwarded to model.fit()
batch_size = 75
lr = 0.01
surv_loss_scale = 50
wait_until_full_surv_loss = 3
early_stopping = True

# --- ODE-RNN encoder ---
# enc_f_nn_layers : number of layers in the network f() that learns the encoder-side latent dynamics
# enc_latent_dim  : dimensionality of the encoder-side latent embedding
# num_units_gru   : number of units in each GRU cell
enc_latent_dim = 75
enc_f_nn_layers = 7
num_units_gru = 50

# --- Decoder ---
# dec_g_nn_layers : number of layers in the network g() that learns the decoder-side latent dynamics
# dec_latent_dim  : dimensionality of the decoder-side latent embedding
# haz_dec_layers  : number of layers in the cause-specific decoder module for hazard estimation
# num_units_ode   : number of units in functions f() and g()
dec_g_nn_layers = 5
dec_latent_dim = 56
haz_dec_layers = 2
num_units_ode = 70

# Prediction window: 8000 days from entry, expressed in 10-day units.
max_pred_window = 800
n_epochs = 15  # number of training epochs

# Build the SurvLatent ODE model from the feature dimensions and the hyperparameters defined earlier.
model = SurvLatentODE(
    input_dim = feats_dim,
    reconstr_dim = reconstr_dim,
    dec_latent_dim = dec_latent_dim,
    enc_latent_dim = enc_latent_dim,
    enc_f_nn_layers = enc_f_nn_layers,
    dec_g_nn_layers = dec_g_nn_layers,
    num_units_ode = num_units_ode,
    num_units_gru = num_units_gru,
    device = DEVICE,
    n_events = n_events,
    haz_dec_layers = haz_dec_layers,
)
# Unique identifier for this training run.
run_id = 'framingham_death_outcome_example'

## Training and evaluating the model

In [None]:
# Train the model on the training split, using the validation split for monitoring / early stopping.
# NOTE(review): the learning rate `lr` defined with the other hyperparameters is not forwarded here;
# confirm whether fit() takes it.
model.fit(
    data_train,
    data_valid,
    data_info_dic,
    max_pred_window = max_pred_window,
    run_id = run_id,
    n_epochs = n_epochs,
    batch_size = batch_size,
    surv_loss_scale = surv_loss_scale,
    early_stopping = early_stopping,
    feat_reconstr = feat_reconstr,
    wait_until_full_surv_loss = wait_until_full_surv_loss,
    random_seed = random_seed,
)

Pre-processing data...: 100%|██████████| 2882/2882 [00:00<00:00, 19227.44it/s]
Pre-processing data...: 100%|██████████| 666/666 [00:00<00:00, 28809.89it/s]


> /Users/intaemoon/github/survlatent_ode/lib/neural_ode_surv.py(166)fit()
-> param_dics['input_dim'] = data_obj["input_dim"]
(Pdb) c


/Users/intaemoon/github/survlatent_ode/lib/utils.py



Directory 'framingham_death_outcome_example' created
Directory 'surv_curves/framingham_death_outcome_example/reconstruction' created


Training across 15 epochs:   6%|▌         | 38/623 [02:04<33:04,  3.39s/it]

Epoch :  1



Loading validation set...:   0%|          | 0/5 [00:00<?, ?it/s][A
Loading validation set...:  20%|██        | 1/5 [00:01<00:04,  1.24s/it][A
Loading validation set...:  40%|████      | 2/5 [00:02<00:03,  1.26s/it][A
Loading validation set...:  60%|██████    | 3/5 [00:03<00:02,  1.29s/it][A
Loading validation set...:  80%|████████  | 4/5 [00:05<00:01,  1.32s/it][A
Loading validation set...: 100%|██████████| 5/5 [00:06<00:00,  1.22s/it][A




survival log-likelihood :  tensor(-2065.9897)
reconstr. likelihood :  tensor(0.8562)
remaining time-to-event in validation/test set [25%, 37.5%, 50%, 62.5%, 75%] percentiles :  [107 154 194 226 302]
Performance at quantiles : 
BS(t) at the percentiles :  {0.25: 0.1559, 0.375: 0.2381, 0.5: 0.3005, 0.625: 0.3409, 0.75: 0.4274}
AUC(t) at the percentiles :  {0.25: 0.6031, 0.375: 0.6167, 0.5: 0.6416, 0.625: 0.6507, 0.75: 0.6401}
mean AUC(t) over 25-75 percentile :  0.6249
Integrated BS(t) over 25-75 percentile :  0.3057




Training across 15 epochs:   6%|▋         | 39/623 [02:15<55:30,  5.70s/it]

Best iteration so far wrt. mean AUC :  1
Storing the latest model...
model_performance/framingham_death_outcome_example/latest_model.pt




Training across 15 epochs:  12%|█▏        | 77/623 [04:20<28:28,  3.13s/it]

Epoch :  2



Loading validation set...:   0%|          | 0/5 [00:00<?, ?it/s][A
Loading validation set...:  20%|██        | 1/5 [00:01<00:05,  1.28s/it][A
Loading validation set...:  40%|████      | 2/5 [00:02<00:04,  1.33s/it][A
Loading validation set...:  60%|██████    | 3/5 [00:03<00:02,  1.31s/it][A
Loading validation set...:  80%|████████  | 4/5 [00:05<00:01,  1.28s/it][A
Loading validation set...: 100%|██████████| 5/5 [00:06<00:00,  1.21s/it][A




survival log-likelihood :  tensor(-1580.9055)
reconstr. likelihood :  tensor(0.8557)
remaining time-to-event in validation/test set [25%, 37.5%, 50%, 62.5%, 75%] percentiles :  [107 154 194 226 302]
Performance at quantiles : 
BS(t) at the percentiles :  {0.25: 0.0682, 0.375: 0.0988, 0.5: 0.1256, 0.625: 0.1514, 0.75: 0.1724}
AUC(t) at the percentiles :  {0.25: 0.6036, 0.375: 0.6264, 0.5: 0.6513, 0.625: 0.6553, 0.75: 0.6453}
mean AUC(t) over 25-75 percentile :  0.6283
Integrated BS(t) over 25-75 percentile :  0.1286




Training across 15 epochs:  13%|█▎        | 78/623 [04:30<49:00,  5.40s/it]

Best iteration so far wrt. mean AUC :  2
Storing the latest model...
model_performance/framingham_death_outcome_example/latest_model.pt




Training across 15 epochs:  19%|█▊        | 116/623 [06:40<26:54,  3.18s/it]

Epoch :  3



Loading validation set...:   0%|          | 0/5 [00:00<?, ?it/s][A
Loading validation set...:  20%|██        | 1/5 [00:01<00:05,  1.26s/it][A
Loading validation set...:  40%|████      | 2/5 [00:02<00:04,  1.33s/it][A
Loading validation set...:  60%|██████    | 3/5 [00:04<00:02,  1.35s/it][A
Loading validation set...:  80%|████████  | 4/5 [00:05<00:01,  1.32s/it][A
Loading validation set...: 100%|██████████| 5/5 [00:06<00:00,  1.22s/it][A




survival log-likelihood :  tensor(-1582.1737)
reconstr. likelihood :  tensor(0.8556)
remaining time-to-event in validation/test set [25%, 37.5%, 50%, 62.5%, 75%] percentiles :  [107 154 194 226 302]
Performance at quantiles : 
BS(t) at the percentiles :  {0.25: 0.0691, 0.375: 0.0998, 0.5: 0.126, 0.625: 0.1504, 0.75: 0.173}
AUC(t) at the percentiles :  {0.25: 0.4895, 0.375: 0.5033, 0.5: 0.5578, 0.625: 0.5604, 0.75: 0.561}
mean AUC(t) over 25-75 percentile :  0.5248
Integrated BS(t) over 25-75 percentile :  0.1289




Training across 15 epochs:  19%|█▉        | 117/623 [06:50<46:23,  5.50s/it]

Best iteration so far wrt. mean AUC :  2
Storing the latest model...
model_performance/framingham_death_outcome_example/latest_model.pt




Training across 15 epochs:  25%|██▍       | 155/623 [09:00<28:05,  3.60s/it]

Epoch :  4



Loading validation set...:   0%|          | 0/5 [00:00<?, ?it/s][A
Loading validation set...:  20%|██        | 1/5 [00:02<00:08,  2.11s/it][A
Loading validation set...:  40%|████      | 2/5 [00:03<00:05,  1.86s/it][A
Loading validation set...:  60%|██████    | 3/5 [00:05<00:03,  1.77s/it][A
Loading validation set...:  80%|████████  | 4/5 [00:07<00:01,  1.72s/it][A
Loading validation set...: 100%|██████████| 5/5 [00:08<00:00,  1.63s/it][A




survival log-likelihood :  tensor(-1564.1597)
reconstr. likelihood :  tensor(0.8558)
remaining time-to-event in validation/test set [25%, 37.5%, 50%, 62.5%, 75%] percentiles :  [107 154 194 226 302]
Performance at quantiles : 
BS(t) at the percentiles :  {0.25: 0.0669, 0.375: 0.0953, 0.5: 0.1192, 0.625: 0.1424, 0.75: 0.1621}
AUC(t) at the percentiles :  {0.25: 0.6576, 0.375: 0.6547, 0.5: 0.6782, 0.625: 0.6775, 0.75: 0.6702}
mean AUC(t) over 25-75 percentile :  0.6675
Integrated BS(t) over 25-75 percentile :  0.1221


Best iteration so far wrt. mean AUC :  4
Storing the latest model...
model_performance/framingham_death_outcome_example/latest_model.pt


Storing the best model...


Training across 15 epochs:  31%|███       | 194/623 [11:41<27:56,  3.91s/it]

Epoch :  5



Loading validation set...:   0%|          | 0/5 [00:00<?, ?it/s][A
Loading validation set...:  20%|██        | 1/5 [00:02<00:08,  2.06s/it][A
Loading validation set...:  40%|████      | 2/5 [00:03<00:05,  1.97s/it][A
Loading validation set...:  60%|██████    | 3/5 [00:05<00:03,  1.85s/it][A
Loading validation set...:  80%|████████  | 4/5 [00:07<00:01,  1.73s/it][A
Loading validation set...: 100%|██████████| 5/5 [00:08<00:00,  1.67s/it][A




survival log-likelihood :  tensor(-1557.7345)
reconstr. likelihood :  tensor(0.8558)
remaining time-to-event in validation/test set [25%, 37.5%, 50%, 62.5%, 75%] percentiles :  [107 154 194 226 302]
Performance at quantiles : 
BS(t) at the percentiles :  {0.25: 0.067, 0.375: 0.0941, 0.5: 0.1151, 0.625: 0.1378, 0.75: 0.1565}
AUC(t) at the percentiles :  {0.25: 0.6592, 0.375: 0.6723, 0.5: 0.6955, 0.625: 0.6863, 0.75: 0.6683}
mean AUC(t) over 25-75 percentile :  0.6732
Integrated BS(t) over 25-75 percentile :  0.1186


Best iteration so far wrt. mean AUC :  5
Storing the latest model...
model_performance/framingham_death_outcome_example/latest_model.pt


Storing the best model...


Training across 15 epochs:  37%|███▋      | 233/623 [14:16<23:54,  3.68s/it]

Epoch :  6



Loading validation set...:   0%|          | 0/5 [00:00<?, ?it/s][A
Loading validation set...:  20%|██        | 1/5 [00:02<00:08,  2.05s/it][A
Loading validation set...:  40%|████      | 2/5 [00:03<00:05,  1.87s/it][A
Loading validation set...:  60%|██████    | 3/5 [00:05<00:03,  1.95s/it][A
Loading validation set...:  80%|████████  | 4/5 [00:07<00:01,  1.82s/it][A
Loading validation set...: 100%|██████████| 5/5 [00:08<00:00,  1.72s/it][A




survival log-likelihood :  tensor(-1557.9792)
reconstr. likelihood :  tensor(0.8557)
remaining time-to-event in validation/test set [25%, 37.5%, 50%, 62.5%, 75%] percentiles :  [107 154 194 226 302]
Performance at quantiles : 
BS(t) at the percentiles :  {0.25: 0.0676, 0.375: 0.0941, 0.5: 0.1145, 0.625: 0.1367, 0.75: 0.1561}
AUC(t) at the percentiles :  {0.25: 0.6515, 0.375: 0.667, 0.5: 0.7047, 0.625: 0.6902, 0.75: 0.6811}
mean AUC(t) over 25-75 percentile :  0.6735
Integrated BS(t) over 25-75 percentile :  0.1183


Best iteration so far wrt. mean AUC :  6
Storing the latest model...
model_performance/framingham_death_outcome_example/latest_model.pt


Storing the best model...


Training across 15 epochs:  43%|████▎     | 266/623 [16:28<20:54,  3.51s/it]

In [None]:
# Load the best checkpoint stored during training and evaluate it on the validation set.
print('Loading the trained model...')
print('run_id : ', run_id)
path = 'model_performance/' + run_id + '/best_model.pt'
# Fail with a clear, correctly-typed error when no best model was stored. The previous bare
# `except:` turned every failure (even KeyboardInterrupt, or a genuine error while loading an
# existing checkpoint) into a misleading KeyError. Errors from loading an existing file now
# propagate unchanged.
if not os.path.isfile(path):
    raise FileNotFoundError(f'Model not found at {path!r}; run the training cell first.')
model_info = get_ckpt_model(path, model, DEVICE)

# Process the valid set
batch_dict_valid = model.process_eval_data(data_valid, data_info_dic, max_pred_window = max_pred_window, run_id = run_id, feat_reconstr = feat_reconstr, model_info = model_info, random_seed = random_seed)

# Get estimated survival probabilities
# Note : survival probs are estimated from the latest observation for each sample
# due to generative nature, surv probs may be different across runs. In this example, we set the random seed to control non-deterministic elements
surv_prob = model.get_surv_prob(batch_dict_valid, model_info = model_info, max_pred_window = max_pred_window, filename_suffix = run_id, device = DEVICE, n_events = n_events)
df_valid_result_comp = eval_model(model_info, batch_dict_valid, surv_prob, run_id = run_id, max_pred_window = max_pred_window)


In [None]:
# Process the held-out test set.
# Pass random_seed exactly as the validation-set call does. Without it, the seed control promised by
# the note below does not reach the test-set processing. (Assumes process_eval_data uses random_seed
# to seed its stochastic steps; confirm.)
batch_dict_test = model.process_eval_data(data_test, data_info_dic, max_pred_window = max_pred_window, run_id = run_id, feat_reconstr = feat_reconstr, model_info = model_info, random_seed = random_seed)
# Get estimated survival probabilities
# Note : survival probs are estimated from the latest observation for each sample
# due to generative nature, surv probs may be different across runs. In this example, we set the random seed to control non-deterministic elements
surv_prob = model.get_surv_prob(batch_dict_test, model_info = model_info, max_pred_window = max_pred_window, filename_suffix = run_id, device = DEVICE, n_events = n_events)
# Evaluate the model on the held-out set and obtain model performance summary
df_test_result_comp = eval_model(model_info, batch_dict_test, surv_prob, run_id = run_id, max_pred_window = max_pred_window)

In [None]:
# plot survival curves