In [21]:
### migrate code from tensorflow v1 to v2: 
# !tf_upgrade_v2 \
#   --infile class_DeepLongitudinal-Original.py \
#   --outfile class_DeepLongitudinal-Original_v2.py

# !tf_upgrade_v2 \
#   --infile utils_network-Original.py \
#   --outfile utils_network-Original_v2.py

In [22]:
_EPSILON = 1e-08

import numpy as np
import pandas as pd
import tensorflow as tf
import random
import os

from sklearn.model_selection import train_test_split

import import_data as impt
from tf_slim.layers import layers as _layers
from class_DeepLongitudinal import Model_Longitudinal_Attention

from utils_eval             import c_index, brier_score
from utils_log              import save_logging, load_logging
from utils_helper           import f_get_minibatch, f_get_boosted_trainset

In [23]:
def _f_get_pred(sess, model, data, data_mi, pred_horizon):
    '''
        Run the model using only the measurements taken at or before `pred_horizon`.

        Zero-initialised copies of `data` and `data_mi` are filled with the leading
        measurements of every subject whose measurement time is <= pred_horizon, so
        that no future measurement reaches the model. Measurement times are rebuilt
        from feature 0 of `data` (presumably the time delta between visits -- confirm
        against import_data): the first measurement is at time 0 and each later one
        is at the running sum of the preceding values.

        `sess` is accepted for signature compatibility but is not used here.
        Returns whatever `model.predict(masked_data, masked_data_mi)` returns.
    '''
    n_subjects = np.shape(data)[0]

    # Shift the cumulative sum right by one step so that the first measurement sits at time 0.
    elapsed   = np.cumsum(data[:, :, 0], axis=1)
    meas_time = np.hstack([np.zeros([n_subjects, 1]), elapsed[:, :-1]])

    # Per subject: how many measurements have a time <= pred_horizon.
    n_available = np.sum(meas_time <= pred_horizon, axis=1)

    masked_data    = np.zeros(np.shape(data))
    masked_data_mi = np.zeros(np.shape(data_mi))
    for subj, n_keep in enumerate(n_available):
        masked_data[subj, :n_keep, :]    = data[subj, :n_keep, :]
        masked_data_mi[subj, :n_keep, :] = data_mi[subj, :n_keep, :]

    return model.predict(masked_data, masked_data_mi)


def f_get_risk_predictions(sess, model, data_, data_mi_, pred_time, eval_time):
    '''
        Compute cause-specific risk scores for every (prediction time, evaluation time) pair.

        For each prediction time t_M in `pred_time` the model is queried with the
        measurements available up to t_M (see _f_get_pred). For each offset dt in
        `eval_time` the risk of event k is the predicted mass in the time bins
        t_M .. t_M + dt, divided by the total mass (over all events) in the bins from
        t_M onwards.

        Returns a dict {k: array of shape [num_subjects, len(pred_time), len(eval_time)]}
        with one entry per event k in range(num_Event).
    '''
    # Probe with a single subject only to learn the number of events from the output shape.
    probe = _f_get_pred(sess, model, data_[[0]], data_mi_[[0]], 0)
    _, num_Event, _ = np.shape(probe)

    n_subjects = np.shape(data_)[0]
    risk_all = {k: np.zeros([n_subjects, len(pred_time), len(eval_time)]) for k in range(num_Event)}

    for p, p_time in enumerate(pred_time):
        ### PREDICTION
        pred_horizon = int(p_time)
        pred = _f_get_pred(sess, model, data_, data_mi_, pred_horizon)

        # Conditioning on t > t_pred: total predicted mass from the prediction time onwards.
        # It does not depend on the evaluation offset, so it is computed once per prediction time.
        mass_after_pred = np.sum(np.sum(pred[:, :, pred_horizon:], axis=2), axis=1, keepdims=True) + _EPSILON

        for t, t_time in enumerate(eval_time):
            # If eval_horizon runs past the last time bin, the slice simply stops at the last bin.
            eval_horizon = int(t_time) + pred_horizon

            # F(t | x, Y, t >= t_M) = \sum_{t_M <= \tau < t} P(\tau | x, Y, \tau > t_M)
            mass_until_eval = np.sum(pred[:, :, pred_horizon:(eval_horizon + 1)], axis=2)
            risk = mass_until_eval / mass_after_pred

            for k in range(num_Event):
                risk_all[k][:, p, t] = risk[:, k]

    return risk_all

In [24]:
### Set the prediction times (t) and evaluation offsets (delta t) used for the C-index and Brier score.
# Each (t, delta t) pair is evaluated at horizon t + delta t (see f_get_risk_predictions).
pred_time = list(range(1,33,1)) # prediction times 1, 2, ..., 32 (in years)
eval_time = list(range(1))  # evaluation offsets; list(range(1)) == [0], i.e. a single offset of 0


### 1. Import Dataset
#####      - Users must prepare their dataset in CSV format and adapt 'import_data.py', following our exemplar 'PBC2'.

In [25]:
# Longitudinal covariate table in long format (several rows per subject ID, one per exam year).
df = pd.read_csv('./data/data_longi_long_expanded_variables_between_y0_y15_all_subjects.csv')


# Per-fold subject-ID tables: column `fold` is used to select the subjects of each split.
# NOTE(review): all three tables are read from the SAME file ('all_testing_set_ID_...'), so the
# training, validation and test splits of every fold contain identical subjects. Presumably
# separate training / validation ID files were intended -- confirm and point each read at its own file.
trainingid_all = pd.read_csv('./data/all_testing_set_ID_dynamic_deephit_all_subjects_2.csv')
validationid_all = pd.read_csv('./data/all_testing_set_ID_dynamic_deephit_all_subjects_2.csv')
testingid_all = pd.read_csv('./data/all_testing_set_ID_dynamic_deephit_all_subjects_2.csv')



In [26]:
df

Unnamed: 0,ID,status,time,exam_year,time_te_in_yrs,AGE_Y0,MALE,RACEBLACK,ARMCI,ASMA,...,PULSE,SMKNW,WGT,WINE,WST,HBM,DBP,SBP,CHNOW,PATCK
0,100016012504,0,11825,0,32.375086,22,1,0,0.0,0,...,28,1,173.2,1,86.0,0,95.0,126.0,0,0
1,100016012504,0,11825,2,32.375086,22,1,0,30.0,0,...,29,1,164.5,1,78.0,0,71.0,103.0,0,1
2,100016012504,0,11825,5,32.375086,22,1,0,33.0,0,...,38,1,180.0,1,90.5,0,72.0,114.0,0,0
3,100023004268,0,11820,0,32.361396,30,1,0,22.0,0,...,21,0,167.0,0,80.0,0,74.0,108.0,0,0
4,100023004268,0,11820,2,32.361396,30,1,0,30.0,0,...,31,0,173.5,0,80.0,0,64.0,97.0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24919,416817227898,0,11604,2,31.770021,22,1,1,38.0,0,...,27,1,199.5,0,95.5,0,63.0,102.0,0,0
24920,416817227898,0,11604,5,31.770021,22,1,1,36.0,0,...,35,1,204.0,0,104.0,0,101.0,133.0,0,0
24921,416817227898,0,11604,7,31.770021,22,1,1,36.0,0,...,33,1,196.0,0,94.0,0,77.0,133.0,0,0
24922,416817227898,0,11604,10,31.770021,22,1,1,35.5,0,...,30,0,192.0,0,92.0,0,66.0,112.0,0,0


In [27]:
# Replace the original 'time' column by 'time_te_in_yrs' (renamed to 'time') and rename 'status' to 'event'.
# NOTE: this cell is not idempotent. Re-running it drops the freshly renamed 'time' column, and the
# rename then silently does nothing because the source columns no longer exist.
df = df.drop(columns = 'time').rename(columns={"status": "event", "time_te_in_yrs": "time"})


In [28]:
df

Unnamed: 0,ID,event,exam_year,time,AGE_Y0,MALE,RACEBLACK,ARMCI,ASMA,BEER,...,PULSE,SMKNW,WGT,WINE,WST,HBM,DBP,SBP,CHNOW,PATCK
0,100016012504,0,0,32.375086,22,1,0,0.0,0,20,...,28,1,173.2,1,86.0,0,95.0,126.0,0,0
1,100016012504,0,2,32.375086,22,1,0,30.0,0,8,...,29,1,164.5,1,78.0,0,71.0,103.0,0,1
2,100016012504,0,5,32.375086,22,1,0,33.0,0,20,...,38,1,180.0,1,90.5,0,72.0,114.0,0,0
3,100023004268,0,0,32.361396,30,1,0,22.0,0,6,...,21,0,167.0,0,80.0,0,74.0,108.0,0,0
4,100023004268,0,2,32.361396,30,1,0,30.0,0,8,...,31,0,173.5,0,80.0,0,64.0,97.0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24919,416817227898,0,2,31.770021,22,1,1,38.0,0,8,...,27,1,199.5,0,95.5,0,63.0,102.0,0,0
24920,416817227898,0,5,31.770021,22,1,1,36.0,0,24,...,35,1,204.0,0,104.0,0,101.0,133.0,0,0
24921,416817227898,0,7,31.770021,22,1,1,36.0,0,28,...,33,1,196.0,0,94.0,0,77.0,133.0,0,0
24922,416817227898,0,10,31.770021,22,1,1,35.5,0,12,...,30,0,192.0,0,92.0,0,66.0,112.0,0,0


In [29]:
df.columns.values

array(['ID', 'event', 'exam_year', 'time', 'AGE_Y0', 'MALE', 'RACEBLACK',
       'ARMCI', 'ASMA', 'BEER', 'BMI', 'CANCR', 'CGTDY', 'CHOL', 'DFPAY',
       'DIAB', 'ED', 'GALL', 'GLU', 'HDL', 'KIDNY', 'LDL', 'LIFE', 'LIQR',
       'LIVER', 'MENTL', 'NPREG', 'NTRIG', 'PSTYR', 'PULSE', 'SMKNW',
       'WGT', 'WINE', 'WST', 'HBM', 'DBP', 'SBP', 'CHNOW', 'PATCK'],
      dtype=object)

In [30]:
# Covariates passed to import_dataset as binary features.
bin_list = ['MALE', 'RACEBLACK', 'ASMA', 'CANCR', 'DIAB'
                                 ,  'GALL', 'KIDNY', 'LIVER', 'MENTL', 'SMKNW', 'HBM', 'CHNOW', 'PATCK']
# Covariates passed to import_dataset as continuous features.
cont_list = ['AGE_Y0', 'ARMCI', 'BEER', 'BMI', 'CGTDY', 'CHOL'
       , 'ED', 'HDL', 'LDL', 'LIFE', 'LIQR'
       , 'NPREG', 'NTRIG', 'PSTYR', 'PULSE', 'WGT'
       , 'WINE', 'WST', 'DBP', 'SBP', 'GLU', 'DFPAY']
# Total number of covariates handed to the model (displayed as a sanity check).
len(bin_list)+len(cont_list)

35

In [31]:
#data_mode                   = 'PBC2' 
data_mode                   = 'CARDIA_ASCVD' 
seed                        = 1234

# Seed the Python and NumPy global RNGs so that any sampling done through them is repeatable
# after a kernel restart (previously `seed` was defined but never used).
# NOTE(review): TensorFlow's own randomness (weight initialisation, dropout) is NOT seeded here.
random.seed(seed)
np.random.seed(seed)

##### IMPORT DATASET
# Quantities returned by / derived from impt.import_dataset:
#     num_Category            = max event/censoring time * 1.2
#     num_Event               = number of events, i.e. len(np.unique(label))-1
#     max_length              = maximum number of measurements
#     x_dim                   = data dimension including delta (1 + num_features)
#     x_dim_cont              = dim of continuous features
#     x_dim_bin               = dim of binary features
#     mask1, mask2, mask3     = used for cause-specific network (FCNet structure)

# (x_dim, x_dim_cont, x_dim_bin), (data, time, label), (mask1, mask2, mask3), (data_mi) = impt.import_dataset(norm_mode = 'standard')

(x_dim, x_dim_cont, x_dim_bin), (data, time, label), (mask1, mask2, mask3), (data_mi) = impt.import_dataset(df_ = df
                  , bin_list = bin_list
                  , cont_list = cont_list
                   , norm_mode = 'standard')

# mask1 is unpacked as a 3-D array: [subj, num_Event, num_Category]
_, num_Event, num_Category  = np.shape(mask1)
max_length                  = np.shape(data)[1]


# Output directory for logs and checkpoints, named after the dataset.
file_path = '{}'.format(data_mode)

# exist_ok=True creates the directory if needed and is a no-op when it already exists,
# without the check-then-create race of os.path.exists() + os.makedirs().
os.makedirs(file_path, exist_ok=True)

In [32]:
data.shape

(5083, 6, 36)

In [33]:
data[1,:,:10]

array([[ 2.        ,  1.42111709, -1.81182381,  0.44804029, -0.75962534,
        -0.01741329,  0.54192917,  0.64375276,  0.4619518 ,  0.47399923],
       [ 3.        ,  1.42111709, -0.22885939,  0.73322445, -0.75567558,
        -0.14191382, -0.00878362,  0.64375276,  0.31992237,  0.06789359],
       [ 2.        ,  1.42111709, -0.6246005 ,  0.16285613, -0.68612013,
        -0.5154154 , -0.3855871 ,  0.64375276, -0.67428363, -0.15077868],
       [ 3.        ,  1.42111709, -0.52566522,  0.16285613, -0.71651461,
        -0.5154154 ,  0.6578687 ,  0.64375276,  0.67499594,  0.53647702],
       [ 0.        ,  1.42111709, -0.42672994,  0.44804029, -0.05225887,
        -0.5154154 , -0.09573827,  0.64375276, -0.31921006,  0.13037138],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])

### 2. Set Hyper-Parameters
##### - Play with your own hyper-parameters!

In [38]:
# Training-mode switches, checked by the training loop:
#   burn_in_mode -- run model.train_burn_in for `iteration_burn_in` steps before the main training
#   boost_mode   -- pass the training set through f_get_boosted_trainset before training
burn_in_mode                = 'ON' #{'ON', 'OFF'}
boost_mode                  = 'ON' #{'ON', 'OFF'}

##### HYPER-PARAMETERS
# Single source of truth for the tunable settings; this dict is also what gets written to the log file below.
new_parser = {'mb_size': 2, # 4 #64

             'iteration_burn_in': 3000,
             'iteration': 25000,

             'keep_prob': 0.6, # 0.6
             'lr_train': 1e-4,

             'h_dim_RNN': 100,
             'h_dim_FC' : 100,
             'num_layers_RNN': 6, #14, #8 #6 #4 #2
             'num_layers_ATT':2,
             'num_layers_CS' :2,

             'RNN_type':'LSTM', #{'LSTM', 'GRU'}

             'FC_active_fn' : tf.nn.relu,
             'RNN_active_fn': tf.nn.tanh,

            'reg_W'         : 1e-5,
            'reg_W_out'     : 0.,

             # alpha / beta / gamma are handed to model.train as PARAMETERS
             # (presumably the weights of the loss terms -- confirm in class_DeepLongitudinal).
             'alpha' :1.0,
             'beta'  :0.1,
             'gamma' :1.0
}


# INPUT DIMENSIONS
# Taken from the import_dataset call on the full data set.
input_dims                  = { 'x_dim'         : x_dim,
                                'x_dim_cont'    : x_dim_cont,
                                'x_dim_bin'     : x_dim_bin,
                                'num_Event'     : num_Event,
                                'num_Category'  : num_Category,
                                'max_length'    : max_length }

# NETWORK HYPER-PARMETERS
network_settings            = { 'h_dim_RNN'         : new_parser['h_dim_RNN'],
                                'h_dim_FC'          : new_parser['h_dim_FC'],
                                'num_layers_RNN'    : new_parser['num_layers_RNN'],
                                'num_layers_ATT'    : new_parser['num_layers_ATT'],
                                'num_layers_CS'     : new_parser['num_layers_CS'],
                                'RNN_type'          : new_parser['RNN_type'],
                                'FC_active_fn'      : new_parser['FC_active_fn'],
                                'RNN_active_fn'     : new_parser['RNN_active_fn'],
                               # 'initial_W'         : tf.contrib.layers.xavier_initializer(),
                               
                                # NOTE(review): tf.contrib.layers.xavier_initializer() defaulted to the
                                # *uniform* Glorot variant, whereas glorot_normal() samples from a truncated
                                # normal -- confirm that this change of distribution is intended.
                                'initial_W'         : tf.keras.initializers.glorot_normal(),

                               
                                'reg_W'             : new_parser['reg_W'],
                                'reg_W_out'         : new_parser['reg_W_out']
                                 }


# Unpack the training settings into the plain names used by the training loop.
mb_size           = new_parser['mb_size']
iteration         = new_parser['iteration']
iteration_burn_in = new_parser['iteration_burn_in']

keep_prob         = new_parser['keep_prob']
lr_train          = new_parser['lr_train']

alpha             = new_parser['alpha']
beta              = new_parser['beta']
gamma             = new_parser['gamma']

# SAVE HYPERPARAMETERS
log_name = file_path + '/hyperparameters_log.txt'
save_logging(new_parser, log_name)

### 4. Train the Network

In [35]:
# ## Tuning: training with number of iteration 25000 -> 50000

# fold = 1
# print('FOLD '+str(fold) + '...')

# ##### get training, testing, and validation data:
# df_train = df.loc[df['ID'].isin(trainingid_all.iloc[:,fold])]
# df_val = df.loc[df['ID'].isin(validationid_all.iloc[:,fold])]
# df_test = df.loc[df['ID'].isin(testingid_all.iloc[:,fold])]

# # ### TRAINING-TESTING SPLIT in the format suitable for this network

# (x_dim, x_dim_cont, x_dim_bin), (te_data, te_time, te_label), (te_mask1, te_mask2, te_mask3), (te_data_mi) = impt.import_dataset(df_ = df_test)
# (x_dim, x_dim_cont, x_dim_bin), (va_data, va_time, va_label), (va_mask1, va_mask2, va_mask3), (va_data_mi) = impt.import_dataset(df_ = df_val)
# (x_dim, x_dim_cont, x_dim_bin), (tr_data, tr_time, tr_label), (tr_mask1, tr_mask2, tr_mask3), (tr_data_mi) = impt.import_dataset(df_ = df_train)

# if boost_mode == 'ON':
#     tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3 = f_get_boosted_trainset(tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3)  







# ##### CREATE AND TRAIN NETWORK:
# # tf.reset_default_graph()
# tf.compat.v1.reset_default_graph()

# # config = tf.ConfigProto()
# config = tf.compat.v1.ConfigProto()

# config.gpu_options.allow_growth = True
# sess = tf.compat.v1.Session(config=config)

# model = Model_Longitudinal_Attention(sess, "Dyanmic-DeepHit", input_dims, network_settings)
# # saver = tf.train.Saver()
# saver = tf.compat.v1.train.Saver()

# # sess.run(tf.global_variables_initializer())
# sess.run(tf.compat.v1.global_variables_initializer())

# ### TRAINING - BURN-IN
# if burn_in_mode == 'ON':
#     print( "BURN-IN TRAINING ...")
#     for itr in range(iteration_burn_in):
#         x_mb, x_mi_mb, k_mb, t_mb, m1_mb, m2_mb, m3_mb = f_get_minibatch(mb_size, tr_data, tr_data_mi, tr_label, tr_time, tr_mask1, tr_mask2, tr_mask3)
#         DATA = (x_mb, k_mb, t_mb)
#         MISSING = (x_mi_mb)

#         _, loss_curr = model.train_burn_in(DATA, MISSING, keep_prob, lr_train)

#         if (itr+1)%1000 == 0:
#             print('itr: {:04d} | loss: {:.4f}'.format(itr+1, loss_curr))


# ### TRAINING - MAIN
# print( "MAIN TRAINING ...")
# min_valid = 0.5

# for itr in range(iteration):
#     x_mb, x_mi_mb, k_mb, t_mb, m1_mb, m2_mb, m3_mb = f_get_minibatch(mb_size, tr_data, tr_data_mi, tr_label, tr_time, tr_mask1, tr_mask2, tr_mask3)
#     DATA = (x_mb, k_mb, t_mb)
#     MASK = (m1_mb, m2_mb, m3_mb)
#     MISSING = (x_mi_mb)
#     PARAMETERS = (alpha, beta, gamma)

#     _, loss_curr = model.train(DATA, MASK, MISSING, PARAMETERS, keep_prob, lr_train)

#     if (itr+1)%1000 == 0:
#         print('itr: {:04d} | loss: {:.4f}'.format(itr+1, loss_curr))

#     ### VALIDATION  (based on average C-index of our interest)
#     if (itr+1)%1000 == 0:        
#         risk_all = f_get_risk_predictions(sess, model, va_data, va_data_mi, pred_time, eval_time)

#         for p, p_time in enumerate(pred_time):
#             pred_horizon = int(p_time)
#             val_result1 = np.zeros([num_Event, len(eval_time)])

#             for t, t_time in enumerate(eval_time):                
#                 eval_horizon = int(t_time) + pred_horizon
#                 for k in range(num_Event):
#                     val_result1[k, t] = c_index(risk_all[k][:, p, t], va_time, (va_label[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)

#             if p == 0:
#                 val_final1 = val_result1
#             else:
#                 val_final1 = np.append(val_final1, val_result1, axis=0)

#         tmp_valid = np.mean(val_final1)

#         if tmp_valid >  min_valid:
#             min_valid = tmp_valid
#             saver.save(sess, file_path + '/model')
#             print( 'updated.... average c-index = ' + str('%.4f' %(tmp_valid)))








# ### PREDICTION ON TEST SET               
# #saver.restore(sess, file_path + '/model')

# risk_all = f_get_risk_predictions(sess, model, te_data, te_data_mi, pred_time, eval_time)

# for p, p_time in enumerate(pred_time):
#     pred_horizon = int(p_time)
#     result1, result2 = np.zeros([num_Event, len(eval_time)]), np.zeros([num_Event, len(eval_time)])

#     for t, t_time in enumerate(eval_time):                
#         eval_horizon = int(t_time) + pred_horizon
#         for k in range(num_Event):
#             result1[k, t] = c_index(risk_all[k][:, p, t], te_time, (te_label[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)
#             result2[k, t] = brier_score(risk_all[k][:, p, t], te_time, (te_label[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)

#     if p == 0:
#         final1, final2 = result1, result2
#     else:
#         final1, final2 = np.append(final1, result1, axis=0), np.append(final2, result2, axis=0)








# ### PRINT PERFORMANCE RESULTS
# row_header = []
# for p_time in pred_time:
#     for t in range(num_Event):
#         row_header.append('pred_time {}: event_{}'.format(p_time,k+1))

# col_header = []
# for t_time in eval_time:
#     col_header.append('eval_time {}'.format(t_time))


# # c-index result
# df1 = pd.DataFrame(final1, index = row_header, columns=col_header)

# # brier-score result
# df2 = pd.DataFrame(final2, index = row_header, columns=col_header)

# print('========================================================')
# print('--------------------------------------------------------')
# print('- C-INDEX: ')
# print(df1)
# print('--------------------------------------------------------')
# print('- BRIER-SCORE: ')
# print(df2)
# print('========================================================')








# ### SAVE C-INDEX, BRIER SCORE, and PREDICTED PROB RISK ON TEST SET
# actual_fold = fold+1
# work_dir = 'U:/Hieu/CARDIA_longi_project'
# savedir = os.path.join(work_dir,'rdata_files/dynamic_deephit_expanded_var_y15_2_fold_'+str(actual_fold)+'/')
# try: 
#     os.makedirs(savedir)
# except OSError:
#     if not os.path.isdir(savedir):
#         raise



# c_over_time = df1.iloc[:,0]
# # c_over_time.to_csv(savedir+'/c_index.csv', index = None, header = True)

# brier_over_time = df2.iloc[:,0]
# # brier_over_time.to_csv(savedir+'/brier_score.csv', index = None, header = True)



# prob_risk_test_df = pd.DataFrame(risk_all[0][:,:,0])
# prob_risk_test_df.columns = pred_time
# prob_risk_test_df.insert(loc=0, column='ID', value=np.unique(df_test['ID']))
# # prob_risk_test_df.to_csv(savedir+'/prob_risk_test.csv', index = None, header = True)



In [36]:
### TRAINING AND TESTING IN LOOP:

In [39]:
nfolds = 10
report_every = 1000   # number of training iterations between loss printouts / validation passes

# Cross-validation driver. For every fold:
#   1. select the train / validation / test subjects from the per-fold ID tables,
#   2. build and train a fresh network (optional burn-in, then main training),
#   3. every `report_every` iterations score the validation set and checkpoint the best model,
#   4. predict risks on the test set, compute C-index and Brier score, and write the results to disk.
for fold in range(nfolds):

    print('FOLD '+str(fold) + '...')

    ##### get training, testing, and validation data:
    # Column `fold` of each ID table lists the subject IDs of that split.
    df_train = df.loc[df['ID'].isin(trainingid_all.iloc[:,fold])]
    df_val = df.loc[df['ID'].isin(validationid_all.iloc[:,fold])]
    df_test = df.loc[df['ID'].isin(testingid_all.iloc[:,fold])]

    ### TRAINING-TESTING SPLIT in the format suitable for this network
    # NOTE(review): unlike the full-data import earlier in the notebook, these calls rely on
    # import_dataset's defaults for bin_list / cont_list / norm_mode -- confirm that the defaults
    # match, and that each split is meant to be normalised on its own.
    (x_dim, x_dim_cont, x_dim_bin), (te_data, te_time, te_label), (te_mask1, te_mask2, te_mask3), (te_data_mi) = impt.import_dataset(df_ = df_test)
    (x_dim, x_dim_cont, x_dim_bin), (va_data, va_time, va_label), (va_mask1, va_mask2, va_mask3), (va_data_mi) = impt.import_dataset(df_ = df_val)
    (x_dim, x_dim_cont, x_dim_bin), (tr_data, tr_time, tr_label), (tr_mask1, tr_mask2, tr_mask3), (tr_data_mi) = impt.import_dataset(df_ = df_train)

    if boost_mode == 'ON':
        tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3 = f_get_boosted_trainset(tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3)

    ##### CREATE AND TRAIN NETWORK:
    # Every fold starts from an empty graph and its own session.
    tf.compat.v1.reset_default_graph()

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.compat.v1.Session(config=config)

    # The session is needed until the test-set predictions have been computed; it is closed in
    # the `finally` below so that ten folds do not leave ten open sessions behind.
    try:
        model = Model_Longitudinal_Attention(sess, "Dyanmic-DeepHit", input_dims, network_settings)
        saver = tf.compat.v1.train.Saver()

        sess.run(tf.compat.v1.global_variables_initializer())

        ### TRAINING - BURN-IN
        if burn_in_mode == 'ON':
            print( "BURN-IN TRAINING ...")
            for itr in range(iteration_burn_in):
                x_mb, x_mi_mb, k_mb, t_mb, m1_mb, m2_mb, m3_mb = f_get_minibatch(mb_size, tr_data, tr_data_mi, tr_label, tr_time, tr_mask1, tr_mask2, tr_mask3)
                DATA = (x_mb, k_mb, t_mb)
                MISSING = (x_mi_mb)

                _, loss_curr = model.train_burn_in(DATA, MISSING, keep_prob, lr_train)

                if (itr+1)%report_every == 0:
                    print('itr: {:04d} | loss: {:.4f}'.format(itr+1, loss_curr))


        ### TRAINING - MAIN
        print( "MAIN TRAINING ...")
        min_valid = 0.5            # best validation score so far; a checkpoint is written only above this
        checkpoint_saved = False   # becomes True once THIS fold has written a checkpoint

        for itr in range(iteration):
            x_mb, x_mi_mb, k_mb, t_mb, m1_mb, m2_mb, m3_mb = f_get_minibatch(mb_size, tr_data, tr_data_mi, tr_label, tr_time, tr_mask1, tr_mask2, tr_mask3)
            DATA = (x_mb, k_mb, t_mb)
            MASK = (m1_mb, m2_mb, m3_mb)
            MISSING = (x_mi_mb)
            PARAMETERS = (alpha, beta, gamma)

            _, loss_curr = model.train(DATA, MASK, MISSING, PARAMETERS, keep_prob, lr_train)

            if (itr+1)%report_every == 0:
                print('itr: {:04d} | loss: {:.4f}'.format(itr+1, loss_curr))

                ### VALIDATION  (based on average C-index of our interest)
                risk_all = f_get_risk_predictions(sess, model, va_data, va_data_mi, pred_time, eval_time)

                val_blocks = []
                for p, p_time in enumerate(pred_time):
                    pred_horizon = int(p_time)
                    val_result1 = np.zeros([num_Event, len(eval_time)])

                    for t, t_time in enumerate(eval_time):
                        eval_horizon = int(t_time) + pred_horizon
                        for k in range(num_Event):
                            val_result1[k, t] = c_index(risk_all[k][:, p, t], va_time, (va_label[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)

                    val_blocks.append(val_result1)

                # Stack the per-prediction-time blocks: one row per (pred_time, event).
                val_final1 = np.concatenate(val_blocks, axis=0)
                tmp_valid = np.mean(val_final1)

                if tmp_valid >  min_valid:
                    min_valid = tmp_valid
                    saver.save(sess, file_path + '/model')
                    checkpoint_saved = True
                    print( 'updated.... average c-index = ' + str('%.4f' %(tmp_valid)))


        ### PREDICTION ON TEST SET
        # Evaluate the best validated weights when a checkpoint was written during this fold;
        # otherwise keep the weights of the final iteration (restoring would fail or pick up a
        # checkpoint left behind by an earlier fold).
        if checkpoint_saved:
            saver.restore(sess, file_path + '/model')

        risk_all = f_get_risk_predictions(sess, model, te_data, te_data_mi, pred_time, eval_time)
    finally:
        sess.close()

    c_blocks, brier_blocks = [], []
    for p, p_time in enumerate(pred_time):
        pred_horizon = int(p_time)
        result1, result2 = np.zeros([num_Event, len(eval_time)]), np.zeros([num_Event, len(eval_time)])

        for t, t_time in enumerate(eval_time):
            eval_horizon = int(t_time) + pred_horizon
            for k in range(num_Event):
                result1[k, t] = c_index(risk_all[k][:, p, t], te_time, (te_label[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)
                result2[k, t] = brier_score(risk_all[k][:, p, t], te_time, (te_label[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)

        c_blocks.append(result1)
        brier_blocks.append(result2)

    # One row per (pred_time, event), one column per eval_time.
    final1, final2 = np.concatenate(c_blocks, axis=0), np.concatenate(brier_blocks, axis=0)


    ### PRINT PERFORMANCE RESULTS
    # Row labels follow the row order of final1 / final2: events nested inside prediction times.
    # (The event index is the loop variable here; the previous code formatted a stale `k`.)
    row_header = []
    for p_time in pred_time:
        for k in range(num_Event):
            row_header.append('pred_time {}: event_{}'.format(p_time,k+1))

    col_header = []
    for t_time in eval_time:
        col_header.append('eval_time {}'.format(t_time))


    # c-index result
    df1 = pd.DataFrame(final1, index = row_header, columns=col_header)

    # brier-score result
    df2 = pd.DataFrame(final2, index = row_header, columns=col_header)

    print('========================================================')
    print('--------------------------------------------------------')
    print('- C-INDEX: ')
    print(df1)
    print('--------------------------------------------------------')
    print('- BRIER-SCORE: ')
    print(df2)
    print('========================================================')


    ### SAVE C-INDEX, BRIER SCORE, and PREDICTED PROB RISK ON TEST SET
    actual_fold = fold+1   # output folders are numbered from 1
    # NOTE(review): hard-coded absolute path on a mapped drive -- consider making this configurable.
    work_dir = 'U:/Hieu/CARDIA_longi_project'
    savedir = os.path.join(work_dir,'rdata_files/dynamic_deephit_expanded_var_all_subjects_2_fold_'+str(actual_fold)+'/')
    # Creates the folder if needed; still raises if the path exists but is not a directory.
    os.makedirs(savedir, exist_ok=True)

    # Only the first evaluation-time column is exported.
    c_over_time = df1.iloc[:,0]
    c_over_time.to_csv(os.path.join(savedir, 'c_index.csv'), index = None, header = True)

    brier_over_time = df2.iloc[:,0]
    brier_over_time.to_csv(os.path.join(savedir, 'brier_score.csv'), index = None, header = True)

    # Predicted risk of the first event at the first evaluation time: one row per test subject,
    # one column per prediction time.
    prob_risk_test_df = pd.DataFrame(risk_all[0][:,:,0])
    prob_risk_test_df.columns = pred_time
    # NOTE(review): assumes import_dataset orders subjects like np.unique (sorted IDs) -- confirm,
    # otherwise the IDs are attached to the wrong rows.
    prob_risk_test_df.insert(loc=0, column='ID', value=np.unique(df_test['ID']))
    prob_risk_test_df.to_csv(os.path.join(savedir, 'prob_risk_test.csv'), index = None, header = True)



FOLD 0...
BURN-IN TRAINING ...
itr: 1000 | loss: 4.4114
itr: 2000 | loss: 1.6496
itr: 3000 | loss: 1.8362
MAIN TRAINING ...
itr: 1000 | loss: 2.6800
itr: 2000 | loss: 1.5496
itr: 3000 | loss: 2.5686
itr: 4000 | loss: 2.4867
itr: 5000 | loss: 1.2150
itr: 6000 | loss: 1.1068
itr: 7000 | loss: 10.8495
itr: 8000 | loss: 2.3186
itr: 9000 | loss: 1.0001
itr: 10000 | loss: 4.1598
itr: 11000 | loss: 3.6945
itr: 12000 | loss: 0.3590
itr: 13000 | loss: 5.2541
itr: 14000 | loss: 1.4385
itr: 15000 | loss: 2.3767
itr: 16000 | loss: 1.7046
itr: 17000 | loss: 4.0954
itr: 18000 | loss: 6.7296
itr: 19000 | loss: 1.9252
itr: 20000 | loss: 2.8522
itr: 21000 | loss: 1.2373
itr: 22000 | loss: 0.0917
itr: 23000 | loss: 0.7289
itr: 24000 | loss: 5.8627
itr: 25000 | loss: 0.9789
--------------------------------------------------------
- C-INDEX: 
                       eval_time 0
pred_time 1: event_1     -1.000000
pred_time 2: event_1     -1.000000
pred_time 3: event_1     -1.000000
pred_time 4: event_1     



BURN-IN TRAINING ...
itr: 1000 | loss: 0.8616
itr: 2000 | loss: 3.7864
itr: 3000 | loss: 0.8679
MAIN TRAINING ...
itr: 1000 | loss: 6.7613
itr: 2000 | loss: 6.2552
itr: 3000 | loss: 2.6470
itr: 4000 | loss: 2.4273
itr: 5000 | loss: 2.0319
itr: 6000 | loss: 3.9374
itr: 7000 | loss: 1.1670
itr: 8000 | loss: 0.3310
itr: 9000 | loss: 3.9052
itr: 10000 | loss: 0.7267
itr: 11000 | loss: 0.9240
itr: 12000 | loss: 4.6200
itr: 13000 | loss: 3.7646
itr: 14000 | loss: 1.0840
itr: 15000 | loss: 0.6188
itr: 16000 | loss: 1.4529
itr: 17000 | loss: 0.1351
itr: 18000 | loss: 1.4097
itr: 19000 | loss: 2.5302
itr: 20000 | loss: 1.8616
itr: 21000 | loss: 0.5263
itr: 22000 | loss: 0.4335
itr: 23000 | loss: 0.9884
itr: 24000 | loss: 1.0509
itr: 25000 | loss: 2.0976
--------------------------------------------------------
- C-INDEX: 
                       eval_time 0
pred_time 1: event_1     -1.000000
pred_time 2: event_1     -1.000000
pred_time 3: event_1     -1.000000
pred_time 4: event_1     -1.000000
p



BURN-IN TRAINING ...
itr: 1000 | loss: 4.8048
itr: 2000 | loss: 1.5682
itr: 3000 | loss: 2.5122
MAIN TRAINING ...
itr: 1000 | loss: 2.7481
itr: 2000 | loss: 2.5874
itr: 3000 | loss: 3.5902
itr: 4000 | loss: 1.9604
itr: 5000 | loss: 3.1476
itr: 6000 | loss: 3.4868
itr: 7000 | loss: 3.7712
itr: 8000 | loss: 2.0613
itr: 9000 | loss: 2.9942
itr: 10000 | loss: 1.1791
itr: 11000 | loss: 1.9748
itr: 12000 | loss: 1.3144
itr: 13000 | loss: 1.4861
itr: 14000 | loss: 0.4876
itr: 15000 | loss: 3.0832
itr: 16000 | loss: 0.8966
itr: 17000 | loss: 6.3313
itr: 18000 | loss: 0.5911
itr: 19000 | loss: 1.1133
itr: 20000 | loss: 1.2195
itr: 21000 | loss: 3.8556
itr: 22000 | loss: 1.1249
itr: 23000 | loss: 1.5162
itr: 24000 | loss: 1.3910
itr: 25000 | loss: 1.8291
--------------------------------------------------------
- C-INDEX: 
                       eval_time 0
pred_time 1: event_1     -1.000000
pred_time 2: event_1     -1.000000
pred_time 3: event_1     -1.000000
pred_time 4: event_1     -1.000000
p



BURN-IN TRAINING ...
itr: 1000 | loss: 3.3112
itr: 2000 | loss: 4.6015
itr: 3000 | loss: 0.6258
MAIN TRAINING ...
itr: 1000 | loss: 3.0000
itr: 2000 | loss: 2.3389
itr: 3000 | loss: 2.4585
itr: 4000 | loss: 1.5259
itr: 5000 | loss: 2.8875
itr: 6000 | loss: 0.7102
itr: 7000 | loss: 2.4908
itr: 8000 | loss: 1.5543
itr: 9000 | loss: 3.9842
itr: 10000 | loss: 1.3045
itr: 11000 | loss: 0.8447
itr: 12000 | loss: 0.8300
itr: 13000 | loss: 3.1925
itr: 14000 | loss: 2.3171
itr: 15000 | loss: 0.1132
itr: 16000 | loss: 0.4596
itr: 17000 | loss: 1.3835
itr: 18000 | loss: 3.9311
itr: 19000 | loss: 0.4341
itr: 20000 | loss: 3.5238
itr: 21000 | loss: 1.0117
itr: 22000 | loss: 0.7921
itr: 23000 | loss: 1.0063
itr: 24000 | loss: 1.0972
itr: 25000 | loss: 0.6305
--------------------------------------------------------
- C-INDEX: 
                       eval_time 0
pred_time 1: event_1     -1.000000
pred_time 2: event_1     -1.000000
pred_time 3: event_1     -1.000000
pred_time 4: event_1     -1.000000
p



BURN-IN TRAINING ...
itr: 1000 | loss: 2.4400
itr: 2000 | loss: 2.2221
itr: 3000 | loss: 3.3524
MAIN TRAINING ...
itr: 1000 | loss: 12.9423
itr: 2000 | loss: 3.8502
itr: 3000 | loss: 2.1580
itr: 4000 | loss: 2.0828
itr: 5000 | loss: 2.8605
itr: 6000 | loss: 1.2664
itr: 7000 | loss: 1.6368
itr: 8000 | loss: 1.8838
itr: 9000 | loss: 0.7573
itr: 10000 | loss: 1.1356
itr: 11000 | loss: 0.5235
itr: 12000 | loss: 3.7389
itr: 13000 | loss: 0.8725
itr: 14000 | loss: 2.8094
itr: 15000 | loss: 0.7211
itr: 16000 | loss: 5.1252
itr: 17000 | loss: 0.4352
itr: 18000 | loss: 1.4220
itr: 19000 | loss: 3.8483
itr: 20000 | loss: 0.5022
itr: 21000 | loss: 1.5648
itr: 22000 | loss: 2.8269
itr: 23000 | loss: 0.6736
itr: 24000 | loss: 0.2285
itr: 25000 | loss: 0.4670
--------------------------------------------------------
- C-INDEX: 
                       eval_time 0
pred_time 1: event_1     -1.000000
pred_time 2: event_1     -1.000000
pred_time 3: event_1     -1.000000
pred_time 4: event_1     -1.000000




BURN-IN TRAINING ...
itr: 1000 | loss: 0.4433
itr: 2000 | loss: 2.4836
itr: 3000 | loss: 1.0242
MAIN TRAINING ...
itr: 1000 | loss: 1.8416
itr: 2000 | loss: 2.5330
itr: 3000 | loss: 1.0988
itr: 4000 | loss: 0.9954
itr: 5000 | loss: 1.6863
itr: 6000 | loss: 1.0267
itr: 7000 | loss: 1.4599
itr: 8000 | loss: 1.3872
itr: 9000 | loss: 4.3484
itr: 10000 | loss: 1.2567
itr: 11000 | loss: 2.4781
itr: 12000 | loss: 0.5316
itr: 13000 | loss: 1.4191
itr: 14000 | loss: 0.1955
itr: 15000 | loss: 2.6258
itr: 16000 | loss: 1.2629
itr: 17000 | loss: 1.4119
itr: 18000 | loss: 2.5065
itr: 19000 | loss: 0.3363
itr: 20000 | loss: 0.7843
itr: 21000 | loss: 0.7646
itr: 22000 | loss: 4.6804
itr: 23000 | loss: 1.0049
itr: 24000 | loss: 2.0336
itr: 25000 | loss: 0.9852
--------------------------------------------------------
- C-INDEX: 
                       eval_time 0
pred_time 1: event_1     -1.000000
pred_time 2: event_1     -1.000000
pred_time 3: event_1     -1.000000
pred_time 4: event_1     -1.000000
p



BURN-IN TRAINING ...
itr: 1000 | loss: 2.8065
itr: 2000 | loss: 0.7453
itr: 3000 | loss: 2.6267
MAIN TRAINING ...
itr: 1000 | loss: 2.0546
itr: 2000 | loss: 1.7749
itr: 3000 | loss: 1.1992
itr: 4000 | loss: 2.5345
itr: 5000 | loss: 1.3994
itr: 6000 | loss: 0.9892
itr: 7000 | loss: 2.9864
itr: 8000 | loss: 4.2047
itr: 9000 | loss: 3.5885
itr: 10000 | loss: 1.4260
itr: 11000 | loss: 1.3087
itr: 12000 | loss: 0.6148
itr: 13000 | loss: 2.3212
itr: 14000 | loss: 2.0202
itr: 15000 | loss: 0.9525
itr: 16000 | loss: 0.4902
itr: 17000 | loss: 0.4270
itr: 18000 | loss: 2.1343
itr: 19000 | loss: 2.0121
itr: 20000 | loss: 0.6746
itr: 21000 | loss: 2.1098
itr: 22000 | loss: 2.5651
itr: 23000 | loss: 6.9713
itr: 24000 | loss: 4.1921
itr: 25000 | loss: 1.1775
--------------------------------------------------------
- C-INDEX: 
                       eval_time 0
pred_time 1: event_1     -1.000000
pred_time 2: event_1     -1.000000
pred_time 3: event_1     -1.000000
pred_time 4: event_1     -1.000000
p



BURN-IN TRAINING ...
itr: 1000 | loss: 1.1583
itr: 2000 | loss: 1.1451
itr: 3000 | loss: 4.8293
MAIN TRAINING ...
itr: 1000 | loss: 1.6407
itr: 2000 | loss: 1.6887
itr: 3000 | loss: 1.9483
itr: 4000 | loss: 3.0367
itr: 5000 | loss: 2.8584
itr: 6000 | loss: 0.2787
itr: 7000 | loss: 6.3155
itr: 8000 | loss: 3.6637
itr: 9000 | loss: 0.7621
itr: 10000 | loss: 2.8473
itr: 11000 | loss: 3.1583
itr: 12000 | loss: 2.6294
itr: 13000 | loss: 2.1872
itr: 14000 | loss: 0.8775
itr: 15000 | loss: 0.4857
itr: 16000 | loss: 3.9915
itr: 17000 | loss: 3.0549
itr: 18000 | loss: 0.6294
itr: 19000 | loss: 1.1460
itr: 20000 | loss: 0.1670
itr: 21000 | loss: 0.9825
itr: 22000 | loss: 1.5001
itr: 23000 | loss: 0.1868
itr: 24000 | loss: 1.3690
itr: 25000 | loss: 3.5133
--------------------------------------------------------
- C-INDEX: 
                       eval_time 0
pred_time 1: event_1     -1.000000
pred_time 2: event_1     -1.000000
pred_time 3: event_1     -1.000000
pred_time 4: event_1     -1.000000
p



BURN-IN TRAINING ...
itr: 1000 | loss: 3.9790
itr: 2000 | loss: 1.8757
itr: 3000 | loss: 0.2334
MAIN TRAINING ...
itr: 1000 | loss: 0.8407
itr: 2000 | loss: 1.5074
itr: 3000 | loss: 2.0114
itr: 4000 | loss: 1.0215
itr: 5000 | loss: 5.0955
itr: 6000 | loss: 2.3537
itr: 7000 | loss: 0.4524
itr: 8000 | loss: 6.3093
itr: 9000 | loss: 2.6688
itr: 10000 | loss: 3.7370
itr: 11000 | loss: 3.8936
itr: 12000 | loss: 4.7863
itr: 13000 | loss: 3.9842
itr: 14000 | loss: 4.5772
itr: 15000 | loss: 1.0349
itr: 16000 | loss: 2.1813
itr: 17000 | loss: 0.2753
itr: 18000 | loss: 0.9669
itr: 19000 | loss: 1.7651
itr: 20000 | loss: 3.5540
itr: 21000 | loss: 0.3737
itr: 22000 | loss: 1.8538
itr: 23000 | loss: 0.0913
itr: 24000 | loss: 2.1730
itr: 25000 | loss: 0.7346
--------------------------------------------------------
- C-INDEX: 
                       eval_time 0
pred_time 1: event_1     -1.000000
pred_time 2: event_1     -1.000000
pred_time 3: event_1     -1.000000
pred_time 4: event_1     -1.000000
p



BURN-IN TRAINING ...
itr: 1000 | loss: 2.0218
itr: 2000 | loss: 0.6661
itr: 3000 | loss: 4.2295
MAIN TRAINING ...
itr: 1000 | loss: 1.8945
itr: 2000 | loss: 2.6535
itr: 3000 | loss: 2.2355
itr: 4000 | loss: 0.6514
itr: 5000 | loss: 2.7071
itr: 6000 | loss: 3.5986
itr: 7000 | loss: 2.2786
itr: 8000 | loss: 0.3285
itr: 9000 | loss: 1.3448
itr: 10000 | loss: 1.9038
itr: 11000 | loss: 1.8339
itr: 12000 | loss: 0.8511
itr: 13000 | loss: 1.2642
itr: 14000 | loss: 3.5479
itr: 15000 | loss: 2.1746
itr: 16000 | loss: 2.8564
itr: 17000 | loss: 0.1947
itr: 18000 | loss: 2.1891
itr: 19000 | loss: 5.9936
itr: 20000 | loss: 0.7184
itr: 21000 | loss: 0.9805
itr: 22000 | loss: 0.9083
itr: 23000 | loss: 2.5841
itr: 24000 | loss: 3.6883
itr: 25000 | loss: 2.6913
--------------------------------------------------------
- C-INDEX: 
                       eval_time 0
pred_time 1: event_1     -1.000000
pred_time 2: event_1     -1.000000
pred_time 3: event_1     -1.000000
pred_time 4: event_1     -1.000000
p

In [None]:
fold