In [1]:
# import necessary packages
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from lib.neural_ode_surv import *
from lib.utils import *

# Compute device. Set USE_GPU = True to run on "cuda:1" when CUDA is available;
# the default (False) keeps the CPU-only configuration this notebook was run with.
USE_GPU = False
DEVICE = torch.device("cuda:1" if USE_GPU and torch.cuda.is_available() else "cpu")

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


## Time-to-multiple-events prediction example: Framingham data
### Loading and pre-processing the data
Event of interest: ANYCHD (angina pectoris, myocardial infarction, coronary insufficiency, or fatal coronary heart disease)

Competing event: death from any cause (DEATH)

In [2]:
# Dataset and covariate documentation:
# see : https://biolincc.nhlbi.nih.gov/media/teachingstudies/FHS_Teaching_Longitudinal_Data_Documentation_2021a.pdf?link_time=2022-02-03_18:20:47.023970
# The public data is in person-time format: a patient can have several rows,
# one per measurement time (TIME), holding longitudinal measurements.
data_path = 'data/framingham.csv'
df_framingham = pd.read_csv(data_path)

# Categorical features
feat_cat = ['SEX', 'CIGPDAY', 'CURSMOKE', 'educ', 'DIABETES', 'PREVSTRK', 'PREVHYP', 'BPMEDS']
# Continuous features
feat_cont = ['AGE', 'SYSBP', 'DIABP', 'TOTCHOL', 'HDLC', 'LDLC', 'BMI', 'GLUCOSE', 'HEARTRTE']
# Features SurvLatent ODE is trained to reconstruct (here: every continuous feature).
# An independent copy, so editing one list never silently changes the other.
feat_reconstr = list(feat_cont)

# data_info_dic describes the dataframe layout to the model:
#   id_col            : unique identifier of a patient
#   event_col         : event-indicator column(s); each must be binary
#   time_col          : time of the measurement
#   time_to_event_col : observed time to event for each event (t_i = min(T_i, C_i))
#   feat_cat          : list of categorical features
#   feat_cont         : list of continuous-valued features
data_info_dic = {'id_col': 'RANDID', 'event_col': ['ANYCHD', 'DEATH'], 'time_to_event_col': ['TIMECHD', 'TIMEDTH'],
                 'time_col': 'TIME', 'feat_cat': feat_cat, 'feat_cont': feat_cont}

feats_dim = len(feat_cat) + len(feat_cont)
reconstr_dim = len(feat_reconstr)
# event_col is either one column name (single event) or a list/tuple of
# cause-specific event columns (competing events).
n_events = (len(data_info_dic['event_col'])
            if isinstance(data_info_dic['event_col'], (list, tuple)) else 1)
# The Framingham study has a long follow-up (about 20 years, or around 7500 days),
# so follow-up time is discretized into bins of `time_unit_days` days: after this
# step one time unit = 10 days. Note that np.round rounds exact halves to the
# nearest even integer.
time_unit_days = 10
df_framingham[data_info_dic['time_col']] = np.round(
    df_framingham[data_info_dic['time_col']].values / time_unit_days)
df_framingham[data_info_dic['time_to_event_col']] = np.round(
    df_framingham[data_info_dic['time_to_event_col']].values / time_unit_days)

# Patient-level 0.65 / 0.15 / 0.20 (train / valid / test) split: ids are sampled,
# so all rows of a given patient end up in the same split.
test_set_frac = 0.2
train_set_frac = 0.65
random_seed = 1991  # seeds NumPy's global RNG so the split is reproducible
np.random.seed(random_seed)

id_col, time_col = data_info_dic['id_col'], data_info_dic['time_col']

sample_ids = set(df_framingham[id_col].values)
n_test = int(len(sample_ids) * test_set_frac)
sample_ids_test = set(np.random.choice(list(sample_ids), size=n_test, replace=False))

# Draw train ids from the non-test ids; rescale the fraction so train_set_frac
# refers to the whole cohort. Whatever remains becomes the validation set.
remaining_ids = sample_ids - sample_ids_test
n_train = int(len(remaining_ids) * train_set_frac / (1 - test_set_frac))
sample_ids_train = set(np.random.choice(list(remaining_ids), size=n_train, replace=False))
sample_ids_valid = sample_ids - sample_ids_test - sample_ids_train


def _select_sorted(ids):
    """Rows of df_framingham whose patient id is in `ids`, sorted by (id, time)."""
    in_split = df_framingham[id_col].isin(ids)
    return df_framingham.loc[in_split].sort_values([id_col, time_col], ascending=(True, True))


data_test = _select_sorted(sample_ids_test)
data_train = _select_sorted(sample_ids_train)
data_valid = _select_sorted(sample_ids_valid)

# ---- One row per patient, imputation and outlier control --------------------
# Keep a single row per patient. Because each split is sorted by (id, time),
# drop_duplicates keeps each patient's earliest measurement.
data_train = data_train.drop_duplicates(data_info_dic['id_col'])
data_valid = data_valid.drop_duplicates(data_info_dic['id_col'])
data_test = data_test.drop_duplicates(data_info_dic['id_col'])

# Impute missing values in EVERY split with the TRAINING-set column means, so no
# validation/test statistics leak into preprocessing and all splits are treated
# alike. Then zero-fill any column whose training mean is undefined (all-NaN).
# numeric_only=True keeps .mean() working if a non-numeric column is present.
train_means = data_train.mean(numeric_only=True)
data_train = data_train.fillna(train_means).fillna(0)
data_valid = data_valid.fillna(train_means).fillna(0)
data_test = data_test.fillna(train_means).fillna(0)

# Outlier control: clip each feature to the TRAINING-set [0.5th, 99.5th]
# percentile range (quantiles 0.005 / 0.995). Values below the lower bound OR
# above the upper bound are set to that bound, in all three splits.
feats_oi = feat_cont + ['CIGPDAY']
feat_to_min_max_dict = {}  # feature -> (lower, upper) bounds from the training set
for feat in feats_oi:
    min_feat, max_feat = np.quantile(data_train[feat].dropna().values, q=[0.005, 0.995])
    feat_to_min_max_dict[feat] = (min_feat, max_feat)
    # Series.clip leaves NaN untouched. It replaces masked .loc assignments, which
    # emitted pandas warnings in the saved output (presumably the incompatible-dtype
    # FutureWarning when a float bound is written into an int column).
    data_train[feat] = data_train[feat].clip(lower=min_feat, upper=max_feat)
    data_valid[feat] = data_valid[feat].clip(lower=min_feat, upper=max_feat)
    data_test[feat] = data_test[feat].clip(lower=min_feat, upper=max_feat)

  data_train.loc[data_train[feat] < min_feat, feat] = min_feat
  data_valid.loc[data_valid[feat] < min_feat, feat] = min_feat
  data_test.loc[data_test[feat] < min_feat, feat] = min_feat


## Choose model hyperparameters and instantiate the model object

In [3]:
# ---- Optimization -----------------------------------------------------------
# lr                        : learning rate
# surv_loss_scale           : scaling factor for the survival loss in the total loss
# wait_until_full_surv_loss : number of epochs before the survival loss reaches full
#                             scale, letting the model learn an input representation
#                             before tuning survival estimates
batch_size = 100
lr = 0.01
surv_loss_scale = 100
wait_until_full_surv_loss = 3
early_stopping = True
n_epochs = 30  # number of training epochs

# ---- ODE-RNN encoder --------------------------------------------------------
# enc_latent_dim  : dimensionality of the encoder-side latent embedding
# enc_f_nn_layers : number of layers in the network f() learning encoder-side latent dynamics
# num_units_gru   : number of units in each GRU cell
enc_latent_dim = 50
enc_f_nn_layers = 5
num_units_gru = 80

# ---- Decoder ----------------------------------------------------------------
# dec_latent_dim  : dimensionality of the decoder-side latent embedding
# dec_g_nn_layers : number of layers in the network g() learning decoder-side latent dynamics
# haz_dec_layers  : number of layers in each cause-specific hazard decoder module
# num_units_ode   : number of units in f() and g()
dec_latent_dim = 40
dec_g_nn_layers = 7
haz_dec_layers = 3
num_units_ode = 70

# Prediction window: 8000 days from entry, expressed in 10-day time units.
max_pred_window = 800

model = SurvLatentODE(
    input_dim=feats_dim,
    reconstr_dim=reconstr_dim,
    enc_latent_dim=enc_latent_dim,
    enc_f_nn_layers=enc_f_nn_layers,
    num_units_gru=num_units_gru,
    dec_latent_dim=dec_latent_dim,
    dec_g_nn_layers=dec_g_nn_layers,
    num_units_ode=num_units_ode,
    haz_dec_layers=haz_dec_layers,
    n_events=n_events,
    device=DEVICE,
)
# Unique identifier for this training run; it also names the run's output directories.
run_id = 'framingham_competing_events_example_v1_1'

## Training and evaluating the model

In [None]:
# Train the model. Samples whose event time coincides with their latest
# observation time (i.e. remaining t_i = 0) are excluded by fit(), which prints
# the excluded ids.
model.fit(
    data_train,
    data_valid,
    data_info_dic,
    run_id=run_id,
    random_seed=random_seed,
    n_epochs=n_epochs,
    batch_size=batch_size,
    max_pred_window=max_pred_window,
    surv_loss_scale=surv_loss_scale,
    wait_until_full_surv_loss=wait_until_full_surv_loss,
    early_stopping=early_stopping,
    feat_reconstr=feat_reconstr,
)

Pre-processing data...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2882/2882 [00:00<00:00, 28952.70it/s]


excluded samples due to event times overlapping last observation times (i.e. remaining t_i = 0) :  [76273, 208566, 231492, 428306, 445400, 556045, 571377, 599475, 797308, 946128, 972314, 1075504, 1080397, 1232204, 1263082, 1406606, 1568334, 1663651, 1695438, 1798396, 1954038, 2087324, 2108588, 2134396, 2180046, 2181152, 2408348, 2434794, 2448708, 2474378, 2483517, 2507740, 2564697, 2640601, 2646666, 2682411, 2708769, 2727755, 2839250, 2865166, 2951629, 3117784, 3235453, 3361368, 3455001, 3587516, 3603542, 3683894, 3702628, 3710385, 3779954, 3810088, 3915943, 3927641, 4030316, 4066905, 4210168, 4220542, 4227351, 4229307, 4244133, 4295921, 4362066, 4504564, 4543147, 4637896, 4645346, 4700914, 4719511, 4754059, 4757585, 4897828, 4903592, 5130897, 5187398, 5266590, 5386797, 5406199, 5561922, 5571109, 5610759, 5611619, 5654796, 5784616, 5837883, 5871074, 5917953, 6171121, 6442383, 6662624, 6720746, 6759187, 7165088, 7173129, 7215259, 7392212, 7414089, 7416933, 7447033, 7507638, 7568367, 767

Pre-processing data...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 666/666 [00:00<00:00, 27577.22it/s]


excluded samples due to event times overlapping last observation times (i.e. remaining t_i = 0) :  [414678, 512339, 1101060, 2080190, 2097493, 3158323, 3718050, 4233073, 4588247, 4614345, 4726021, 5080716, 5865112, 6316687, 6949688, 7410564, 7546881, 7559394, 7668298, 8083473, 8466833, 9380233, 9802787, 9868819]
n =  24


/home/weijiesun/survival_project/survlatent_ode/lib/utils.py



using currently existing directory :  model_performance/framingham_competing_events_example_v1_1
using currently existing directory :  surv_curves/framingham_competing_events_example_v1_1


Training across 30 epochs:   3%|████▍                                                                                                                                    | 27/840 [03:39<49:30,  3.65s/it]

Epoch :  1



Loading validation set...:   0%|                                                                                                                                                    | 0/4 [00:00<?, ?it/s][A

In [None]:
# Evaluate on one row per patient, like the train/valid preprocessing:
# keep='first' keeps the earliest row given the (id, time) sort.
data_test = data_test.drop_duplicates(subset='RANDID', keep='first')

In [None]:
# Load the best checkpoint saved by model.fit() for this run_id, then evaluate
# it on the held-out test set.
print('Loading the trained model...')
print('run_id : ', run_id)
path = os.path.join('model_performance', run_id, 'best_model.pt')
# Check for the checkpoint explicitly instead of wrapping the load in a bare
# `except:`. That would mask any loading failure (e.g. a checkpoint saved with
# different hyperparameters, or a KeyboardInterrupt) behind a misleading KeyError.
if not os.path.exists(path):
    raise FileNotFoundError(
        f"Trained model not found at {path!r}; run the training cell for run_id={run_id!r} first.")
model_info = get_ckpt_model(path, model, DEVICE)

# Process the held-out test set
batch_dict_test = model.process_eval_data(data_test, data_info_dic, max_pred_window=max_pred_window,
                                          run_id=run_id, feat_reconstr=feat_reconstr, model_info=model_info)
# Get estimated survival probabilities.
# Note: survival probs are estimated from the latest observation of each sample.
# Because the model is generative, surv probs may differ across runs; here the
# random seed is set to control non-deterministic elements.
ef_surv_prob, cs_cif_total = model.get_surv_prob(batch_dict_test, model_info=model_info,
                                                 max_pred_window=max_pred_window, filename_suffix=run_id,
                                                 device=DEVICE, n_events=n_events)
# Evaluate the model on the held-out set and obtain a performance summary
df_test_result_comp = eval_model(model_info, batch_dict_test, ef_surv_prob, run_id=run_id,
                                 cs_cif_total=cs_cif_total, max_pred_window=max_pred_window, n_events=n_events)

In [None]:
# Display the held-out test-set performance summary returned by eval_model()
df_test_result_comp