In [218]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as K
import pickle

# In the real code we'll use argparse instead of namespace, but
# namespace lets us set the params from within the notebook
from argparse import Namespace

In [2]:
# InteractiveSession installs itself as the default session, which is what
# lets the bare `.eval()` calls further down run without naming a session.
sess = tf.InteractiveSession()

In [3]:
# Toy med2vec-format data: each inner list is one visit's integer codes,
# and [-1] marks the boundary between consecutive patients.
seqs = [
    [5, 4], [2, 3, 4],
    [-1],
    [5, 6, 4], [0, 3, 6, 5],
    [-1],
    [0, 3], [5, 6, 2], [3, 4],
    [-1],
    [0, 6], [0, 5, 1], [5, 6, 2, 1], [0, 1, 5],
]
seqs

[[5, 4],
 [2, 3, 4],
 [-1],
 [5, 6, 4],
 [0, 3, 6, 5],
 [-1],
 [0, 3],
 [5, 6, 2],
 [3, 4],
 [-1],
 [0, 6],
 [0, 5, 1],
 [5, 6, 2, 1],
 [0, 1, 5]]

In [108]:
# seqs above is a toy data set that already holds a fixed number of patients;
# in the real graph the batch size (n_patients) has to be given explicitly.
# `options` is notebook-only scaffolding: the real code builds `args` from
# argparse, and the Namespace lets us set the same fields from here.
options = dict(
    n_patients=4,      # batch size: number of patients in seqs
    max_v=6,           # codes per visit after padding (fill_visit)
    max_t=5,           # visits per patient after padding (fill_patient)
    n_codes=7,         # vocabulary size |C|, the one-hot depth
    code_emb_dim=4,
    visit_emb_dim=4,
    log_eps=1e-6,      # added inside log() so probabilities of 0 stay finite
    win=3,             # neighbouring visits on each side used by the visit cost
)
args = Namespace(**options)


def fill_visit(visit, args):
    """Pad a visit with -2 filler codes up to ``args.max_v`` codes.

    Every visit must hold the same number of codes for efficient tensor
    logic. Filler -2 entries are later one-hot encoded as the zero vector
    (tf.one_hot maps out-of-range indices to all zeros), so they affect
    nothing.

    visit: a list of integer medical codes.
    args: namespace providing ``max_v``, the number of codes per padded visit.

    returns: a new list of length ``args.max_v``; ``visit`` itself is not
        modified. The patient separator ``[-1]`` is not a visit and yields
        None.

    raises: ValueError if ``visit`` has more than ``args.max_v`` codes.
        Padding cannot fix that, and letting it through would produce a
        ragged array later.
    """
    if visit == [-1]:
        return None
    # Read max_v from args (not the notebook-global `options`) so the
    # function honours whatever configuration the caller passes in.
    max_v = args.max_v
    deficit = max_v - len(visit)
    if deficit < 0:
        raise ValueError(
            "visit has {} codes but max_v is {}".format(len(visit), max_v))
    return list(visit) + [-2] * deficit
    

def fill_patient(patient, mask_batch, args):
    """Pad a patient's visit list with dummy visits up to ``args.max_t``.

    Dummy visits are lists of ``args.max_v`` -2s, which are one-hot encoded
    as zero vectors. This makes all patients commensurate for efficient
    tensor logic.

    patient: list of visits, each a list of integer codes (tensorize_seqs
        passes visits already padded by fill_visit).
    mask_batch: list to which one element is appended: a ``deficit x max_v``
        block of zeros covering the dummy visits. The appended copy is
        returned; the caller's list is not modified.
    args: namespace providing ``max_t`` (visits every patient should have)
        and ``max_v`` (codes per visit).

    returns: (new_patient, new_mask_batch, t), where t is the number of real
        visits the patient had before padding.

    raises: ValueError if the patient has more than ``args.max_t`` visits.
    """
    max_t = args.max_t
    max_v = args.max_v
    t = len(patient)
    deficit = max_t - t
    if deficit < 0:
        raise ValueError(
            "patient has {} visits but max_t is {}".format(t, max_t))
    # Build a fresh list per dummy visit; `[row] * n` would repeat one list
    # object n times, so mutating one dummy visit would change them all.
    dummy_visits = [[-2] * max_v for _ in range(deficit)]
    new_patient = list(patient) + dummy_visits
    new_mask_batch = list(mask_batch)
    new_mask_batch.append([[0] * max_v for _ in range(deficit)])
    return new_patient, new_mask_batch, t

def tensorize_seqs(seqs, args):
    """Convert med2vec-format sequences to padded TensorFlow data.

    seqs: list of visits (lists of integer codes), with ``[-1]`` between
        consecutive patients. cf  https://github.com/mp2893/med2vec
    args: namespace providing ``max_t``, ``max_v`` and ``n_codes``.

    returns:
        patients: tensor with shape [patients, max_t, max_v, |C|], from
               tf.one_hot; filler -2 codes become zero vectors.
        row_masks: boolean numpy array with shape [patients, max_t, max_v],
               True where the entry is a real code.
               Later, we will create a [patients, max_t, max_v, |C|]
               tensor where the [p, t, i, j] entry is p(c_j|c_i).
               Row_masks will drop the rows where c_i is the zero
               vector--that is, an NA ICD.

               A separate mask, col_mask, will be created from
               patients in order to mask, for each t, those j for
               which c_j did not appear in visit t, as well as p(c_i|c_i).

               The masks are to be applied in reverse order of creation.
               col_mask is applied with tf.multiply and row_masks
               with tf.boolean_mask to avoid needless reshaping.
        visit_counts: float32 numpy array with shape [patients], the number
               of real visits of each patient.

    raises: ValueError if a patient has no visits (two consecutive ``[-1]``s,
        or a leading or trailing ``[-1]``). Such a patient would get a visit
        count of 0, and the cost functions divide by the visit count.
    """
    padded_patients = []
    visit_counts = []
    current = []
    # The appended [-1] flushes the final patient, which has no separator.
    for visit in seqs + [[-1]]:
        if visit != [-1]:
            current.append(fill_visit(visit, args))
            continue
        if not current:
            raise ValueError(
                "patient {} has no visits".format(len(padded_patients)))
        # fill_patient's mask output is not needed: row_masks is derived
        # directly from the padded codes below.
        padded, _, n_visits = fill_patient(current, [], args)
        padded_patients.append(padded)
        visit_counts.append(n_visits)
        current = []
    codes = np.array(padded_patients)
    row_masks = codes != -2
    patients = tf.one_hot(codes, depth=args.n_codes)
    return patients, row_masks, np.array(visit_counts, dtype=np.float32)

def col_masks(patients, args):
    """Create a mask to cover non-present ICDs.

    For each V_t, for each c_i in V_t, zero out those p(c_j|c_i) for which
    c_j is not in V_t or for which i == j.

    See doc string for tensorize_seqs.

    patients: [patients, max_t, max_v, |C|] one-hot tensor.
    args: not needed: broadcasting replaces the tile over args.max_v, so the
        mask always matches patients' own shape. Kept so callers still work.

    returns: a tensor with patients' shape whose [p, t, i, j] entry is
        x_t[j] - c_i[j], where x_t is the sum of visit t's one-hot codes.
        This is 1 when c_j is in V_t and j != i (for visits without repeated
        codes) and 0 otherwise. Rows for filler codes (c_i = 0) equal x_t;
        row_masks removes those rows.
    """
    # Multi-hot visit vectors with a singleton max_v axis, so the
    # subtraction broadcasts x_t against every code slot of its visit.
    visit_totals = tf.expand_dims(tf.reduce_sum(patients, axis=-2), -2)
    return visit_totals - patients

In [111]:
patients, row_masks, visit_counts = tensorize_seqs(seqs, args)
# The cost functions reshape visit_counts with args.n_patients, so the batch
# built from seqs must contain exactly that many patients.
assert visit_counts.shape == (args.n_patients,), visit_counts.shape

In [76]:
# demo_dim is the length of the demographic vector, chosen arbitrarily here
demo_dim = 6
# Draw the fake demographics once and freeze them as a constant. A bare
# tf.truncated_normal op samples fresh values every time it is evaluated, so
# the values displayed, the values fed to predictions, and the values pickled
# to 'demo' would otherwise all be different draws. The seed makes the draw
# reproducible across kernel restarts.
D_t = tf.constant(
    tf.truncated_normal([args.n_patients, args.max_t, demo_dim],  # [patients, max_t, demo_dim]
                        mean=0.0,
                        stddev=5.0,
                        seed=0).eval())

In [97]:
# Change these constants to Variables for the actual code, but we can't debug on Variables.
# W_c is applied to the code vectors, W_v to [code embedding, demographics],
# and W_s to the visit embedding (see predictions).
W_c = tf.truncated_normal([args.code_emb_dim, args.n_codes],
                          mean=0.0,
                          stddev=1.0,
                          dtype=tf.float32)
# W_v acts on the concatenation of a code embedding (code_emb_dim wide) and
# a demographic vector (demo_dim wide, defined with D_t).
W_v = tf.truncated_normal([args.visit_emb_dim, args.code_emb_dim + demo_dim],
                          mean=0.0,
                          stddev=1.0,
                          dtype=tf.float32)
W_s = tf.truncated_normal([args.n_codes, args.visit_emb_dim],
                          mean=0.0,
                          stddev=1.0,
                          dtype=tf.float32)

# Biases are column vectors so they broadcast across the columns of W @ X^T.
b_c = tf.zeros([W_c.shape[0], 1], dtype=tf.float32)
b_v = tf.zeros([W_v.shape[0], 1], dtype=tf.float32)
b_s = tf.zeros([W_s.shape[0], 1], dtype=tf.float32)

In [66]:
def codes_cost(patients, row_masks, visit_counts, W_c=W_c, b_c=b_c, args=args):
    """Calculate the cost for the code embeddings.

    For every real code c_i of a visit V_t and every other code c_j of the
    same visit, adds -log p(c_j|c_i). Here p(.|c_i) is a softmax over the dot
    products of c_i's ReLU(W_c) embedding with every code's embedding. Each
    patient's terms are divided by that patient's number of real visits.

    patients: [n_patients, max_t, max_v, n_codes] one-hot tensor.
    row_masks: boolean [n_patients, max_t, max_v] array, True for real codes.
    visit_counts: [n_patients] float32 count of real visits per patient.
    W_c: [code_emb_dim, n_codes] code embedding weights.
    b_c: not used by this cost; kept for interface compatibility.
    args: namespace providing n_patients, max_t, max_v, code_emb_dim and
        log_eps.

    returns: scalar cost, the batch average over patients of each patient's
        average per-visit negative log-likelihood. Lower is better.
    """
    W_c_prime = tf.nn.relu(W_c)

    # tf.matmul doesn't broadcast, and we need to keep these grouped by visit,
    # so we need to tile W_c to one copy for every (real or dummy) visit
    W_c_tiled = tf.expand_dims(tf.expand_dims(W_c_prime, 0), 0)
    W_c_tiled = tf.tile(W_c_tiled, [args.n_patients, args.max_t, 1, 1])

    # w_ij is an n_patients X max_t array of code_emb_dim X max_v
    # matrices whose columns are the representations of the codes
    # appearing in each visit in seqs
    w_ij = tf.matmul(W_c_tiled, patients, transpose_b=True)

    # We want a patients X visits X max_v array of
    # code_emb_dim X 1 vectors which are the columns from w_ij.
    w_ij = tf.transpose(w_ij, [0, 1, 3, 2])
    w_ij = tf.reshape(w_ij, [args.n_patients,
                             args.max_t,
                             args.max_v,
                             args.code_emb_dim,
                             1])

    # tf.multiply will broadcast these columns to
    # each column of W_c in each tile of W_c_tiled
    pre_sum = tf.multiply(W_c_prime, w_ij)
    logits = tf.reduce_sum(pre_sum, -2)

    # logits now has an n_patients X max_t array of max_v X n_codes
    # matrices whose i, jth element is the dot product of the embedding of
    # code i (which appears in visit t) with that of code j (which may or
    # may not)

    # The probability of code j given that code i is in the same visit
    p_j_i = tf.nn.softmax(logits, -1)
    log_p_j_i = tf.log(p_j_i + args.log_eps)

    # Keep only j in V_t with j != i. See docstring for col_masks
    col_mask = col_masks(patients, args)
    non_norm_summands = tf.multiply(log_p_j_i, col_mask)

    # Divide each patient's terms by their number of real visits *before*
    # dropping rows for NA ICDs, while the patient axis still exists.
    summands_w_dummies = non_norm_summands / tf.reshape(visit_counts, [args.n_patients, 1, 1, 1])
    summands = tf.boolean_mask(summands_w_dummies, row_masks)

    # Summing everything adds up each patient's average per-visit
    # log-likelihood (1/visit_count is already applied). Dividing by the
    # number of patients gives the batch average, and negating turns the
    # log-likelihood (<= 0) into a cost that should be minimized.
    cost = -tf.reduce_sum(summands) / float(args.n_patients)
    return cost

In [69]:
# x_ts gets used in predictions and visits cost calculations both, so make them outside of both functions
x_ts = tf.reduce_sum(pat

In [208]:
def predictions(x_ts, W_c=W_c, D_t=D_t, W_v=W_v, W_s=W_s, b_c=b_c, b_v=b_v, b_s=b_s, args=args):
    r"""Get \hat{y}_t, the predicted code distribution for each real visit.

    Computes, as in med2vec,
        u_t = ReLU(W_c x_t + b_c)
        v_t = ReLU(W_v [u_t, d_t] + b_v)
        \hat{y}_t = softmax(W_s v_t + b_s)

    x_ts: visit vectors reshapeable to [-1, args.n_codes]; assumed to be
        [n_patients, max_t, n_codes] with all-zero rows for dummy visits, so
        rows line up with D_t -- TODO confirm against the x_ts cell.
    D_t: [n_patients, max_t, demo_dim] demographics, including dummy visits.
        demo_dim is the module-level demographic length.
    W_c, W_v, W_s: weight matrices; b_c, b_v, b_s: column-vector biases.
    args: namespace providing n_codes.

    returns: [n_real_visits, n_codes] tensor; each row is a softmax over
        codes for one real visit, in patient-major order.
    """
    # We don't need to group by visit in this branch. We also don't need
    # to buffer patients with dummy visits.
    x_2d = tf.reshape(x_ts, [-1, args.n_codes])
    # True for rows containing at least one code, i.e. real visits.
    real_visit_mask = tf.greater(tf.reduce_sum(x_2d, -1), 0.)

    d_2d = tf.reshape(D_t, [-1, demo_dim])

    # u_t = ReLU(W_c x_t + b_c); the nonlinearity keeps the visit
    # representation from being a purely linear map of the codes.
    u_ts = tf.nn.relu(tf.add(tf.matmul(W_c, x_2d, transpose_b=True), b_c))
    u_ts = tf.transpose(u_ts)

    # In order to store D_t as a tensor it needs dummy visits just like
    # x_ts does. This also ensures that everything aligns correctly when we
    # concatenate here. After concatenating, we can ditch the dummy visits.
    full_vec = tf.concat([u_ts, d_2d], axis=-1)
    full_vec = tf.boolean_mask(full_vec, real_visit_mask)

    # v_t = ReLU(W_v [u_t, d_t] + b_v)
    v_t = tf.nn.relu(tf.add(tf.matmul(W_v, full_vec, transpose_b=True), b_v))
    v_t = tf.transpose(v_t)

    pre_soft = tf.add(tf.matmul(W_s, v_t, transpose_b=True), b_s)
    pre_soft = tf.transpose(pre_soft)

    y_2d = tf.nn.softmax(pre_soft, axis=-1)
    return y_2d

In [209]:
# Pass the current weight tensors explicitly: the defaults in predictions'
# signature were bound when that function was defined, so they would not
# pick up a later re-run of the weights cell.
y_2d = predictions(x_ts, W_c=W_c, D_t=D_t, W_v=W_v, W_s=W_s, b_c=b_c, b_v=b_v, b_s=b_s, args=args)
y_2d

<tf.Tensor 'Softmax_19:0' shape=(?, 7) dtype=float32>

In [210]:
# One row for every real visit, as expected.
# NOTE(review): the weights are tf.truncated_normal ops, so each eval draws
# fresh weights and repeated evals give different predictions.
y_2d.eval()

array([[8.6085853e-18, 9.0658553e-08, 4.4319490e-09, 3.5209838e-02,
        9.6479005e-01, 1.0396769e-26, 2.0756780e-22],
       [6.9340423e-08, 6.5690224e-17, 1.8994191e-30, 2.2932563e-12,
        5.3013420e-32, 9.9999988e-01, 2.2513388e-19],
       [1.2073571e-26, 3.6456187e-16, 1.6760706e-03, 3.2039849e-15,
        9.9832386e-01, 2.3809569e-35, 2.4894056e-18],
       [6.4953824e-08, 2.7575591e-17, 4.9990022e-21, 1.6656639e-14,
        3.2074747e-25, 9.9999988e-01, 7.8157717e-13],
       [3.5257922e-15, 1.3173596e-33, 3.2943777e-33, 1.4652709e-26,
        0.0000000e+00, 1.0000000e+00, 8.6632166e-19],
       [8.6561788e-04, 7.5667504e-07, 9.6243136e-15, 1.0567531e-08,
        2.7946128e-15, 9.9913329e-01, 4.0201925e-07],
       [3.1391869e-04, 5.1624817e-03, 4.7404148e-02, 7.8861404e-04,
        9.4576669e-01, 1.9582774e-05, 5.4457737e-04],
       [0.0000000e+00, 4.2276565e-20, 5.2035183e-13, 2.9162840e-27,
        1.0000000e+00, 0.0000000e+00, 2.3846293e-33],
       [2.5320991e-26, 1

In [213]:
def loop_ops(win_start, total, padded=None, window_shape=None):
    r"""One step of a sliding-window sum, for passing to tf.while_loop.

    Adds the slice of ``padded`` that starts at row ``win_start`` and has
    shape ``window_shape`` to ``total``, so x_ts from surrounding visits are
    added together before taking the dot product with log(\hat{y}). Then
    moves the window start back by one row.

    ``padded`` and ``window_shape`` must be bound by the caller (e.g. with a
    lambda). The tensors they stand for, ``x_double_pad`` and
    ``normed_x_pad_2d``, only exist inside visits_cost, not at module scope.
    visits_cost uses its own nested version of this step.

    returns: (win_start - 1, total + slice).
    raises: ValueError if ``padded`` or ``window_shape`` is not given.
    """
    if padded is None or window_shape is None:
        raise ValueError(
            "loop_ops needs `padded` and `window_shape`; bind them before "
            "passing it to tf.while_loop")
    summand = tf.slice(padded, [win_start, 0], window_shape)
    return (win_start - 1, tf.add(total, summand))

def visits_cost(x_ts, y_2d, visit_counts, args):
    r"""Calculate the visits cost.

    For every real visit t and every real visit t+j of the same patient with
    0 < |j| <= args.win, adds the binary cross-entropy
        -( x_{t+j}^T log(\hat{y}_t) + (1 - x_{t+j})^T log(1 - \hat{y}_t) ).
    Each patient's terms are divided by that patient's number of real visits.

    x_ts: assumed [n_patients, max_t, n_codes] visit vectors with all-zero
        rows for dummy visits -- TODO confirm (the padding below needs rank 3).
    y_2d: [n_real_visits, n_codes] predictions, one row per real visit in
        patient-major order (as returned by predictions).
    visit_counts: [n_patients] float32 count of real visits per patient.
    args: namespace providing n_patients, n_codes, win and log_eps.

    returns: scalar cost (a negative log-likelihood); lower is better.
    """
    # We'll add the x vectors within the window before taking the dot
    # product with \hat{y}_t. Pad each patient's visit axis with `win` empty
    # visits on both sides so window sums never reach another patient.
    x_pad = tf.pad(x_ts, [[0, 0], [args.win, args.win], [0, 0]])

    # 1. for real visits, 0. for dummy visits and the padding just added.
    real_visit = tf.minimum(tf.reduce_sum(x_pad, -1), 1.)

    # Complement vectors 1 - x, kept only for real visits. Padding and dummy
    # visits are not neighbours, so they must not contribute to the
    # (1 - x) log(1 - \hat{y}) term. They would if z were just 1 - x_pad.
    z_pad = (1. - x_pad) * tf.expand_dims(real_visit, -1)

    # Note that this is a different mask than the one produced in
    # predictions: it covers the padded layout.
    visit_mask = tf.greater(tf.reshape(real_visit, [-1]), 0.)

    # We need to flatten to do the window function, so divide each vector
    # by the number of visits of that patient *first*.
    counts = tf.reshape(visit_counts, [args.n_patients, 1, 1])
    normed_x_pad_2d = tf.reshape(x_pad / counts, [-1, args.n_codes])
    normed_z_pad_2d = tf.reshape(z_pad / counts, [-1, args.n_codes])

    # Pad the flattened list of visits once more so every window slice below
    # stays in bounds.
    x_double_pad = tf.pad(normed_x_pad_2d, [[args.win, args.win], [0, 0]])
    z_double_pad = tf.pad(normed_z_pad_2d, [[args.win, args.win], [0, 0]])
    window_shape = normed_x_pad_2d.shape

    def loop_ops(win_start, totalx, totalz):
        """Slide window function for tf.while_loop.

        Adds the rows offset by (win_start - args.win) to the running totals
        of x and of (1 - x), then moves the window start back by one.
        """
        summandx = tf.slice(x_double_pad, [win_start, 0], window_shape)
        summandz = tf.slice(z_double_pad, [win_start, 0], window_shape)
        return (win_start - 1, tf.add(totalx, summandx), tf.add(totalz, summandz))

    # win_start runs 2*win, ..., 0, i.e. offsets j = win, ..., -win.
    loop_vars = (2 * args.win, tf.zeros(window_shape), tf.zeros(window_shape))
    loop_cond = lambda win_start, totalx, totalz: tf.less(-1, win_start)
    _, window_x_total, window_z_total = tf.while_loop(loop_cond, loop_ops, loop_vars)

    # Subtract out the j = 0 term: a visit is not its own neighbour.
    correct_x_totals_pad = tf.subtract(window_x_total, normed_x_pad_2d)
    correct_z_totals_pad = tf.subtract(window_z_total, normed_z_pad_2d)

    # Keep only rows for real visits so they line up with y_2d.
    final_x_total = tf.boolean_mask(correct_x_totals_pad, visit_mask)
    final_z_total = tf.boolean_mask(correct_z_totals_pad, visit_mask)

    sumx = tf.reduce_sum(tf.multiply(final_x_total, tf.log(y_2d + args.log_eps)))
    sumz = tf.reduce_sum(tf.multiply(final_z_total, tf.log(1. - y_2d + args.log_eps)))

    # Both sums are log-likelihood terms (<= 0 up to log_eps), so the cost
    # negates both. Subtracting only sumx would reward putting probability
    # mass on codes absent from neighbouring visits.
    return -(sumx + sumz)

In [216]:
# Inspect the demographic tensor: shape [n_patients, max_t, demo_dim].
# NOTE(review): if D_t is a bare random op (tf.truncated_normal), this eval
# is a fresh draw and will not match the values pickled below.
D_t.eval()

array([[[-4.468975  ,  0.23625278, -5.0017333 ,  6.235909  ,
          7.915654  ,  5.0930014 ],
        [-1.3927264 ,  2.830865  , -5.4919147 ,  7.143624  ,
         -1.1168783 ,  6.765215  ],
        [-1.8450476 ,  5.873818  , -0.3785586 ,  2.9186127 ,
          2.6538143 , -2.3643394 ],
        [ 4.080118  , -2.243655  ,  4.24465   , -3.3345423 ,
          5.734318  ,  2.8201718 ],
        [ 6.5554004 , -2.7130067 ,  4.681351  , -0.59912705,
          1.648201  ,  3.964186  ]],

       [[-2.48018   , -2.193721  ,  9.169472  ,  2.6991568 ,
          4.9216933 , -0.7434596 ],
        [ 3.7783136 , -0.4528087 ,  0.96824914,  0.4719281 ,
          0.0434884 , -1.6019058 ],
        [-3.5891201 ,  5.016054  ,  1.9721509 , -0.7044741 ,
          6.6981606 , -2.4262743 ],
        [ 2.2948596 ,  1.4530307 ,  2.363258  ,  0.5358587 ,
         -0.42536604,  2.8189712 ],
        [-2.6930475 ,  8.998437  , -3.9350052 ,  0.72095776,
          5.6756454 , -0.17493442]],

       [[-6.3981533 , -6.1

In [221]:
# Save the toy data with pickle protocol 2, presumably so Python 2 tools
# (e.g. the original med2vec code) can load it -- verify.
with open('seqs', 'wb') as seqs_file:
    pickle.dump(seqs, seqs_file, protocol=2)
    
# NOTE(review): D_t.eval() evaluates the tensor again here; if D_t is a
# random op, this is a new draw rather than the values displayed above.
with open('demo', 'wb') as demo_file:
    pickle.dump(D_t.eval(), demo_file, protocol=2)

In [222]:
# Sanity check: the one-hot patients tensor is [n_patients, max_t, max_v, n_codes].
patients.shape

TensorShape([Dimension(4), Dimension(5), Dimension(6), Dimension(7)])

In [226]:
# Re-display the toy code sequences; labs below mirrors this layout
# visit-for-visit, with the same [-1] patient separators.
seqs

[[5, 4],
 [2, 3, 4],
 [-1],
 [5, 6, 4],
 [0, 3, 6, 5],
 [-1],
 [0, 3],
 [5, 6, 2],
 [3, 4],
 [-1],
 [0, 6],
 [0, 5, 1],
 [5, 6, 2, 1],
 [0, 1, 5]]

In [227]:
# n_labels = 4
# Per-visit labels for the toy data, laid out visit-for-visit like seqs,
# with [-1] between patients.
labs = [
    [3, 2], [0, 3, 2],
    [-1],
    [3, 3, 2], [0, 3, 3, 2],
    [-1],
    [0, 2], [3, 2, 0], [1, 2],
    [-1],
    [0, 1], [0, 3, 0], [2, 2, 0, 1], [0, 1, 3],
]

In [228]:
# Save the toy labels next to 'seqs' and 'demo', again with pickle protocol 2.
with open('labs', 'wb') as labs_file:
    pickle.dump(labs, labs_file, protocol=2)