In [1]:
import tensorflow as tf
tf.enable_eager_execution()

In [2]:
import import_data as impt
import numpy as np

(x_dim, x_dim_cont, x_dim_bin), (data, time, label), (mask1, mask2, mask3), (data_mi) = impt.import_dataset(norm_mode = 'standard')

_, num_Event, num_Category  = np.shape(mask1)  # dim of mask3: [subj, Num_Event, Num_Category]
max_length                  = np.shape(data)[1]

pred_time = [52, 3*52, 5*52] # prediction time (in months)
eval_time = [12, 36, 60, 120] # months evaluation time (for C-index and Brier-Score)

In [3]:
# Hyperparams
burn_in_mode                = 'ON' #{'ON', 'OFF'}
boost_mode                  = 'ON' #{'ON', 'OFF'}

##### HYPER-PARAMETERS
new_parser = {'mb_size': 32,

             'iteration_burn_in': 3000,
             'iteration': 25000,

             'keep_prob': 0.6,
             'lr_train': 1e-4,

             'h_dim_RNN': 100,
             'h_dim_FC' : 100,
             'num_layers_RNN':2,
             'num_layers_ATT':2,
             'num_layers_CS' :2,

             'RNN_type':'LSTM', #{'LSTM', 'GRU'}

             'FC_active_fn' : tf.nn.relu,
             'RNN_active_fn': tf.nn.tanh,

            'reg_W'         : 1e-5,
            'reg_W_out'     : 0.,

             'alpha' :1.0,
             'beta'  :0.1,
             'gamma' :1.0
}


# INPUT DIMENSIONS
input_dims                  = { 'x_dim'         : x_dim,
                                'x_dim_cont'    : x_dim_cont,
                                'x_dim_bin'     : x_dim_bin,
                                'num_Event'     : num_Event,
                                'num_Category'  : num_Category,
                                'max_length'    : max_length }

# NETWORK HYPER-PARMETERS
network_settings            = { 'h_dim_RNN'         : new_parser['h_dim_RNN'],
                                'h_dim_FC'          : new_parser['h_dim_FC'],
                                'num_layers_RNN'    : new_parser['num_layers_RNN'],
                                'num_layers_ATT'    : new_parser['num_layers_ATT'],
                                'num_layers_CS'     : new_parser['num_layers_CS'],
                                'RNN_type'          : new_parser['RNN_type'],
                                'FC_active_fn'      : new_parser['FC_active_fn'],
                                'RNN_active_fn'     : new_parser['RNN_active_fn'],
                                'initial_W'         : tf.contrib.layers.xavier_initializer(),

                                'reg_W'             : new_parser['reg_W'],
                                'reg_W_out'         : new_parser['reg_W_out']
                                 }


mb_size           = new_parser['mb_size']
iteration         = new_parser['iteration']
iteration_burn_in = new_parser['iteration_burn_in']

keep_prob         = new_parser['keep_prob']
lr_train          = new_parser['lr_train']

alpha             = new_parser['alpha']
beta              = new_parser['beta']
gamma             = new_parser['gamma']



For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [12]:
import numpy as np
import tensorflow as tf
import random
import utils_network as utils
from tensorflow.contrib.layers import fully_connected as FC_Net

_EPSILON = 1e-08

def log(x):
    return tf.log(x + _EPSILON)

def div(x, y):
    return tf.div(x, (y + _EPSILON))

def get_seq_length(sequence):
    used = tf.sign(tf.reduce_max(tf.abs(sequence), 2))
    tmp_length = tf.reduce_sum(used, 1)
    tmp_length = tf.cast(tmp_length, tf.int32)
    return tmp_length

# INPUT DIMENSIONS
x_dim              = input_dims['x_dim']
x_dim_cont         = input_dims['x_dim_cont']
x_dim_bin          = input_dims['x_dim_bin']

num_Event          = input_dims['num_Event']
num_Category       = input_dims['num_Category']
max_length         = input_dims['max_length']

# NETWORK HYPER-PARMETERS
h_dim1             = network_settings['h_dim_RNN']
h_dim2             = network_settings['h_dim_FC']
num_layers_RNN     = network_settings['num_layers_RNN']
num_layers_ATT     = network_settings['num_layers_ATT']
num_layers_CS      = network_settings['num_layers_CS']

RNN_type           = network_settings['RNN_type']

FC_active_fn       = network_settings['FC_active_fn']
RNN_active_fn      = network_settings['RNN_active_fn']
initial_W          = network_settings['initial_W']

reg_W              = tf.contrib.layers.l1_regularizer(scale=network_settings['reg_W'])
reg_W_out          = tf.contrib.layers.l1_regularizer(scale=network_settings['reg_W_out'])


INFO:tensorflow:Scale of 0 disables regularizer.


In [15]:
x = data

seq_length     = get_seq_length(x)
tmp_range      = tf.expand_dims(tf.range(0, max_length, 1), axis=0)

rnn_mask1 = tf.cast(tf.less_equal(tmp_range, tf.expand_dims(seq_length - 1, axis=1)), tf.float32)            
rnn_mask2 = tf.cast(tf.equal(tmp_range, tf.expand_dims(seq_length - 1, axis=1)), tf.float32) 


### DEFINE LOOP FUNCTION FOR RAW_RNN w/ TEMPORAL ATTENTION
def loop_fn_att(time, cell_output, cell_state, loop_state):

    emit_output = cell_output 

    if cell_output is None:  # time == 0
        next_cell_state = cell.zero_state(mb_size, tf.float32)
        next_loop_state = loop_state_ta
    else:
        next_cell_state = cell_state
        tmp_h = utils.create_concat_state(next_cell_state, num_layers_RNN, RNN_type)

        e = utils.create_FCNet(tf.concat([tmp_h, all_last], axis=1), num_layers_ATT, h_dim2, 
                                tf.nn.tanh, 1, None, initial_W, keep_prob=keep_prob)
        e = tf.exp(e)

        next_loop_state = (loop_state[0].write(time-1, e),                # save att power (e_{j})
                            loop_state[1].write(time-1, tmp_h))  # save all the hidden states

    # elements_finished = (time >= seq_length)
    elements_finished = (time >= max_length-1)

    #this gives the break-point (no more recurrence after the max_length)
    finished = tf.reduce_all(elements_finished)    
    next_input = tf.cond(finished, lambda: tf.zeros([mb_size, 2*x_dim], dtype=tf.float32),  # [x_hist, mi_hist]
                                    lambda: inputs_ta.read(time))

    return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)



# divide into the last x and previous x's
x_last = tf.slice(x, [0,(max_length-1), 1], [-1,-1,-1])      #current measurement
x_last = tf.reshape(x_last, [-1, (x_dim_cont+x_dim_bin)])    #remove the delta of the last measurement

x_last = tf.reduce_sum(tf.tile(tf.expand_dims(rnn_mask2, axis=2), [1,1,x_dim]) * x, reduction_indices=1)    #sum over time since all others time stamps are 0
x_last = tf.slice(x_last, [0,1], [-1,-1])                               #remove the delta of the last measurement
x_hist = x * (1.-tf.tile(tf.expand_dims(rnn_mask2, axis=2), [1,1,x_dim]))                                    #since all others time stamps are 0 and measurements are 0-padded
x_hist = tf.slice(x_hist, [0, 0, 0], [-1,(max_length-1),-1])  

# do same thing for missing indicator
mi_last = tf.slice(x_mi, [0,(max_length-1), 1], [-1,-1,-1])      #current measurement
mi_last = tf.reshape(mi_last, [-1, (x_dim_cont+x_dim_bin)])    #remove the delta of the last measurement

mi_last = tf.reduce_sum(tf.tile(tf.expand_dims(rnn_mask2, axis=2), [1,1,x_dim]) * x_mi, reduction_indices=1)    #sum over time since all others time stamps are 0
mi_last = tf.slice(mi_last, [0,1], [-1,-1])                               #remove the delta of the last measurement
mi_hist = x_mi * (1.-tf.tile(tf.expand_dims(rnn_mask2, axis=2), [1,1,x_dim]))                                    #since all others time stamps are 0 and measurements are 0-padded
mi_hist = tf.slice(mi_hist, [0, 0, 0], [-1,(max_length-1),-1])  

all_hist = tf.concat([x_hist, mi_hist], axis=2)
all_last = tf.concat([x_last, mi_last], axis=1)


#extract inputs for the temporal attention: mask (to incorporate only the measured time) and x_{M}
seq_length     = get_seq_length(x_hist)
rnn_mask_att   = tf.cast(tf.not_equal(tf.reduce_sum(x_hist, reduction_indices=2), 0), dtype=tf.float32)  #[mb_size, max_length-1], 1:measurements 0:no measurements


##### SHARED SUBNETWORK: RNN w/ TEMPORAL ATTENTION
#change the input tensor to TensorArray format with [max_length, mb_size, x_dim]
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_length-1).unstack(_transpose_batch_time(all_hist), name = 'Shared_Input')


#create a cell with RNN hyper-parameters (RNN types, #layers, #nodes, activation functions, keep proability)
cell = utils.create_rnn_cell(h_dim1, num_layers_RNN, keep_prob, 
                                RNN_type, RNN_active_fn)

#define the loop_state TensorArray for information from rnn time steps
loop_state_ta = (tf.TensorArray(size=max_length-1, dtype=tf.float32),  #e values (e_{j})
                    tf.TensorArray(size=max_length-1, dtype=tf.float32))  #hidden states (h_{j})

rnn_outputs_ta, rnn_final_state, loop_state_ta = tf.nn.raw_rnn(cell, loop_fn_att)
#rnn_outputs_ta  : TensorArray
#rnn_final_state : Tensor
#rnn_states_ta   : (TensorArray, TensorArray)

rnn_outputs = _transpose_batch_time(rnn_outputs_ta.stack())
# rnn_outputs =  tf.reshape(rnn_outputs, [-1, max_length-1, h_dim1])

rnn_states  = _transpose_batch_time(loop_state_ta[1].stack())

att_weight  = _transpose_batch_time(loop_state_ta[0].stack()) #e_{j}
att_weight  = tf.reshape(att_weight, [-1, max_length-1]) * rnn_mask_att # masking to set 0 for the unmeasured e_{j}

#get a_{j} = e_{j}/sum_{l=1}^{M-1}e_{l}
att_weight  = div(att_weight,(tf.reduce_sum(att_weight, axis=1, keepdims=True) + _EPSILON)) #softmax (tf.exp is done, previously)

# 1) expand att_weight to hidden state dimension, 2) c = \sum_{j=1}^{M} a_{j} x h_{j}
context_vec = tf.reduce_sum(tf.tile(tf.reshape(att_weight, [-1, max_length-1, 1]), [1, 1, num_layers_RNN*h_dim1]) * rnn_states, axis=1)


z_mean      = FC_Net(rnn_outputs, x_dim, activation_fn=None, weights_initializer=initial_W, scope="RNN_out_mean1")
z_std       = tf.exp(FC_Net(rnn_outputs, x_dim, activation_fn=None, weights_initializer=initial_W, scope="RNN_out_std1"))

epsilon          = tf.random_normal([mb_size, max_length-1, x_dim], mean=0.0, stddev=1.0, dtype=tf.float32)
z           = z_mean + z_std * epsilon


##### CS-SPECIFIC SUBNETWORK w/ FCNETS 
inputs = tf.concat([x_last, context_vec], axis=1)


#1 layer for combining inputs
h = FC_Net(inputs, h_dim2, activation_fn=FC_active_fn, weights_initializer=initial_W, scope="Layer1")
h = tf.nn.dropout(h, keep_prob=keep_prob)

# (num_layers_CS-1) layers for cause-specific (num_Event subNets)
out = []
for _ in range(num_Event):
    cs_out = utils.create_FCNet(h, (num_layers_CS), h_dim2, FC_active_fn, h_dim2, FC_active_fn, initial_W, reg_W, keep_prob)
    out.append(cs_out)
out = tf.stack(out, axis=1) # stack referenced on subject
out = tf.reshape(out, [-1, num_Event*h_dim2])
out = tf.nn.dropout(out, keep_prob=keep_prob)

out = FC_Net(out, num_Event * num_Category, activation_fn=tf.nn.softmax, 
                weights_initializer=initial_W, weights_regularizer=reg_W_out, scope="Output")
out = tf.reshape(out, [-1, num_Event, num_Category])


##### GET LOSS FUNCTIONS
loss_Log_Likelihood()      #get loss1: Log-Likelihood loss
loss_Ranking()             #get loss2: Ranking loss
loss_RNN_Prediction()      #get loss3: RNN prediction loss

LOSS_TOTAL     = a*LOSS_1 + b*LOSS_2 + c*LOSS_3 + tf.losses.get_regularization_loss()
LOSS_BURNIN    = LOSS_3 + tf.losses.get_regularization_loss()

solver         = tf.train.AdamOptimizer(learning_rate=lr_rate).minimize(LOSS_TOTAL)
solver_burn_in = tf.train.AdamOptimizer(learning_rate=lr_rate).minimize(LOSS_BURNIN)



NameError: name 'x_mi' is not defined

In [6]:
# x = tf.random.normal([32,16,16], 0, 1, tf.float32, seed=1)
# x.shape[1]
# max_length = x.shape[1]
# tf.slice(x, [0,(max_length-1), 1], [-1,-1,-1])

<tf.Tensor: id=32, shape=(32, 1, 15), dtype=float32, numpy=
array([[[-1.21184146e+00,  4.48889762e-01, -7.37403184e-02,
          7.54432082e-01,  1.48847151e+00, -1.52995682e+00,
         -6.03821278e-01,  5.47428668e-01, -8.55400443e-01,
          1.44296741e+00, -6.51406288e-01, -3.30137253e-01,
         -1.14750016e+00, -2.40512347e+00, -3.09611768e-01]],

       [[ 1.14851883e-02, -1.92178220e-01, -2.14976120e+00,
         -5.76461256e-01, -1.12946346e-01,  1.36483836e+00,
          9.55794930e-01, -1.29790676e+00,  1.19078565e+00,
          1.66978955e+00,  5.24244070e-01,  5.06962612e-02,
         -1.28597736e-01,  9.25031304e-02, -6.54789388e-01]],

       [[ 1.40942708e-01,  1.30323052e+00,  3.42019469e-01,
          4.77353007e-01, -4.16334450e-01,  7.89854348e-01,
          8.25113595e-01, -5.37863255e-01,  1.52390099e+00,
          2.24795505e-01,  2.97209203e-01, -3.64836961e-01,
          3.32593441e-01, -8.87461960e-01,  1.05487764e+00]],

       [[ 5.10776818e-01,  1.30