In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as K
import collections

import random

In [2]:
sess = tf.InteractiveSession()

In [3]:
seqs = [[5, 8], [-1], [5, 6, 7], [0, 6, 9, 5], [-1], [9, 6], [5, 8, 7], [7, 4]]

In [4]:
def pairs(alist):
    return zip(alist, alist[1:])

# Code-level cost:

$\frac{1}{T}\sum\limits_{t=1}^T\sum\limits_{i:c_i\in V_t}\sum\limits_{j:c_j\in V_t, i\neq j}\log{\text{p}(c_j|c_i)}$

$\text{p}(c_j|c_i) = \frac{\exp{(e_j\cdot e_i)}}{\sum\limits_{k=1}^{|C|}\exp{(e_k\cdot e_i)}}$

Where $e_i$ is the $i^{th}$ column of $W_c$

## Assumptions:
- Should work with process_mimic, so seqs and types will have the form of the originals
- However, we will assume a visit has a maximum number (maxV = unspecified) of ICDs, some of which may be NA.
- A batch is a batch of patients, who have T_k visits, each of which has maxV ICDs, some of which may be NA (blank).
- That is, after processing, a batch will have shape [batch, maxT, maxV, |C|], where each [maxV, |C|] vector is a one_hot encoding of an ICD present in the visit (all zeros if one of the maxT ICDs is blank).

In [5]:
seqs = [[5, 8], [-1], [5, 6, 7], [0, 6, 9, 5], [-1], [9, 6], [5, 8, 7], [7, 4]]
seqs

[[5, 8], [-1], [5, 6, 7], [0, 6, 9, 5], [-1], [9, 6], [5, 8, 7], [7, 4]]

In [6]:
options = {'max_v': 5, 'max_t': 3, 'n_codes': 10}

def fill_visit(visit, **options):
    """Fill all deficit visits with -2.
    
    Ensure that all visits have the same number of ICDs
    for efficient tensor logic. If a visit has fewer ICDs,
    filler ICDs get one-hot encoded as the zero vector,
    so that they affect nothing.
    
    visit: a list of integer medical codes
    
    Note: No visit in training or testing should have more
    than max_v visits.
    """
    max_v = options['max_v']
    if visit != [-1]:
        new_visit = []
        new_visit.extend(visit)
        n_icd = len(visit)
        deficit = max_v - n_icd
        new_visit.extend([-2] * deficit)
        return new_visit
    

def fill_patient(patient, mask_batch, **options):
    """Ensure that all patients have max_t visits.
    
    Create visits full of -2s, which are one-hot
    encoded as zero vectors. This makes all patients
    commensurate for efficient tensor logic.
    
    patient: list of list of integer codes
    max_t: the number of visits all patients ought to have
    
    Note: No patient in training or test data should have more 
    than max_t visits.
    """
    max_t = options['max_t']
    max_v = options['max_v']
    new_patient = []
    new_patient.extend(patient)
    new_mask_batch = mask_batch
    t = len(new_patient)
    deficit = (max_t - t)
    new_patient.extend([[-2] * max_v] * deficit)
    new_mask_batch.append([[0] * max_v] * deficit)
    return new_patient, new_mask_batch, t

def tensorize_seqs(seqs, **options):
    """Convert med2vec to tensorflow data.
    
    seqs: list of list. cf  https://github.com/mp2893/med2vec
    
    returns:
        patients: tensor with shape [patients, max_t, max_v, |C|]
        row_masks: numpy array with shape [patients, max_t, max_v]
               Later, we will create a [patients, max_t, max_v, |C|]
               tensor where the [p, t, i, j] entry is p(c_j|c_i).
               Row_masks will drop the rows where c_i is the zero
               vector--that is, an NA ICD.
               
               A separate mask, col_mask, will be created from
               patients in order to mask, for each t, those j for
               which c_j did not appear in visit t, as well as p(c_i|c_i).
               
               The masks are to be applied in reverse order of creation.
               col_mask is applied with tf.multiply and row_masks
               with tf.boolean_mask to avoid needless reshaping.
    """
    max_v = options['max_v']
    n_codes = options['n_codes']
    patients = []
    new_patient = []
    row_masks = []
    mask_batch = []
    patients_ts = []
    for visit in seqs + [[-1]]:
        if visit != [-1]:
            visit = fill_visit(visit, **options)
            new_patient.append(visit)
        else:
            new_patient, mask_batch, t = fill_patient(new_patient,
                                                   mask_batch,
                                                   **options)
            patients.append(new_patient)
            patients_ts.append(t)
            row_masks.append(mask_batch)
            new_patient = []
            mask_batch = []
    patients = np.array(patients)
    row_masks = (patients != -2)
    patients = tf.one_hot(patients, depth=n_codes)
    return patients, row_masks, np.array(patients_ts, dtype=np.float32)

## Processing the cost of a single visit. Need to work on broadcasting for multiple visits, so batches can be in terms of patients

In [7]:
W = np.array([[1,2,3,5],[7,11,13,17]], dtype=np.float32)
W

array([[ 1.,  2.,  3.,  5.],
       [ 7., 11., 13., 17.]], dtype=float32)

In [8]:
c_i = np.array([1,0,0,0], dtype=np.float32)
c_j = np.array([0,0,1,0], dtype=np.float32)
c_k = np.array([0,0,0,0], dtype=np.float32)
x_t = c_i + c_j + c_k

In [9]:
V_t = np.stack([c_i, c_j, c_k])
V_t

array([[1., 0., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)

In [10]:
w_ij = tf.matmul(W, V_t, transpose_b=True)
w_ij = tf.reshape(tf.transpose(w_ij),[3,2,1])
w_ij.eval()

array([[[ 1.],
        [ 7.]],

       [[ 3.],
        [13.]],

       [[ 0.],
        [ 0.]]], dtype=float32)

In [11]:
w_i = tf.matmul(W, tf.reshape(c_i, [1,4]), transpose_b=True).eval()
w_i

array([[1.],
       [7.]], dtype=float32)

In [12]:
W_tiled = tf.tile(tf.expand_dims(W, 0), [3, 1, 1])
W_tiled.eval()

array([[[ 1.,  2.,  3.,  5.],
        [ 7., 11., 13., 17.]],

       [[ 1.,  2.,  3.,  5.],
        [ 7., 11., 13., 17.]],

       [[ 1.,  2.,  3.,  5.],
        [ 7., 11., 13., 17.]]], dtype=float32)

In [13]:
pre_sum = tf.multiply(W_tiled, w_ij)
pre_sum.eval()

array([[[  1.,   2.,   3.,   5.],
        [ 49.,  77.,  91., 119.]],

       [[  3.,   6.,   9.,  15.],
        [ 91., 143., 169., 221.]],

       [[  0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.]]], dtype=float32)

In [14]:
pre_soft = tf.reduce_sum(pre_sum, axis=-2)
pre_soft.eval()

array([[ 50.,  79.,  94., 124.],
       [ 94., 149., 178., 236.],
       [  0.,   0.,   0.,   0.]], dtype=float32)

In [15]:
tf.nn.softmax(pre_soft, axis=-1).eval()

array([[7.2812905e-33, 2.8625186e-20, 9.3576229e-14, 1.0000000e+00],
       [0.0000000e+00, 1.6458115e-38, 6.4702347e-26, 1.0000000e+00],
       [2.5000000e-01, 2.5000000e-01, 2.5000000e-01, 2.5000000e-01]],
      dtype=float32)

To do: deal with mask, log, sum.

## The above section all happens at the single-visit level. Need to expand to the level of multiple patients.

In [16]:
W = np.array([[1,2,3,5, 7],[11,13,17,23, 29]], dtype=np.float32)
W

array([[ 1.,  2.,  3.,  5.,  7.],
       [11., 13., 17., 23., 29.]], dtype=float32)

In [17]:
seqs = [[0,1],[1],[-1],[3,1],[1,2,3],[3]] # 2 patients, 2 and 3 visits respectively

In [18]:
options = {'max_t': 4, 'max_v': 3, 'n_codes': 5}

In [177]:
patients, row_masks, patients_ts = tensorize_seqs(seqs, **options)

In [20]:
row_masks

array([[[ True,  True, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False]],

       [[ True,  True, False],
        [ True,  True,  True],
        [ True, False, False],
        [False, False, False]]])

In [181]:
patients.eval()

array([[[[1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]],


       [[[0., 0., 0., 1., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]]], dtype=float32)

In [22]:
W_tiled = tf.expand_dims(W, 0)
W_tiled = tf.expand_dims(W_tiled, 0)
W_tiled = tf.tile(W_tiled, [2, 4, 1, 1])
W_tiled.eval()

array([[[[ 1.,  2.,  3.,  5.,  7.],
         [11., 13., 17., 23., 29.]],

        [[ 1.,  2.,  3.,  5.,  7.],
         [11., 13., 17., 23., 29.]],

        [[ 1.,  2.,  3.,  5.,  7.],
         [11., 13., 17., 23., 29.]],

        [[ 1.,  2.,  3.,  5.,  7.],
         [11., 13., 17., 23., 29.]]],


       [[[ 1.,  2.,  3.,  5.,  7.],
         [11., 13., 17., 23., 29.]],

        [[ 1.,  2.,  3.,  5.,  7.],
         [11., 13., 17., 23., 29.]],

        [[ 1.,  2.,  3.,  5.,  7.],
         [11., 13., 17., 23., 29.]],

        [[ 1.,  2.,  3.,  5.,  7.],
         [11., 13., 17., 23., 29.]]]], dtype=float32)

In [23]:
W_tiled.eval().shape

(2, 4, 2, 5)

In [24]:
w_ij = tf.matmul(W_tiled, patients, transpose_b=True)
w_ij.eval()

array([[[[ 1.,  2.,  0.],
         [11., 13.,  0.]],

        [[ 2.,  0.,  0.],
         [13.,  0.,  0.]],

        [[ 0.,  0.,  0.],
         [ 0.,  0.,  0.]],

        [[ 0.,  0.,  0.],
         [ 0.,  0.,  0.]]],


       [[[ 5.,  2.,  0.],
         [23., 13.,  0.]],

        [[ 2.,  3.,  5.],
         [13., 17., 23.]],

        [[ 5.,  0.,  0.],
         [23.,  0.,  0.]],

        [[ 0.,  0.,  0.],
         [ 0.,  0.,  0.]]]], dtype=float32)

In [25]:
w_ij.eval().shape

(2, 4, 2, 3)

In [26]:
w_ij = tf.transpose(w_ij, [0,1,3,2])
w_ij.eval()

array([[[[ 1., 11.],
         [ 2., 13.],
         [ 0.,  0.]],

        [[ 2., 13.],
         [ 0.,  0.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [ 0.,  0.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [ 0.,  0.],
         [ 0.,  0.]]],


       [[[ 5., 23.],
         [ 2., 13.],
         [ 0.,  0.]],

        [[ 2., 13.],
         [ 3., 17.],
         [ 5., 23.]],

        [[ 5., 23.],
         [ 0.,  0.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [ 0.,  0.],
         [ 0.,  0.]]]], dtype=float32)

In [27]:
w_ij = tf.reshape(w_ij,[2,4,3,2,1])
w_ij.eval()

array([[[[[ 1.],
          [11.]],

         [[ 2.],
          [13.]],

         [[ 0.],
          [ 0.]]],


        [[[ 2.],
          [13.]],

         [[ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.]]]],



       [[[[ 5.],
          [23.]],

         [[ 2.],
          [13.]],

         [[ 0.],
          [ 0.]]],


        [[[ 2.],
          [13.]],

         [[ 3.],
          [17.]],

         [[ 5.],
          [23.]]],


        [[[ 5.],
          [23.]],

         [[ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.]]]]], dtype=float32)

In [28]:
W

array([[ 1.,  2.,  3.,  5.,  7.],
       [11., 13., 17., 23., 29.]], dtype=float32)

In [29]:
pre_sum = tf.multiply(W, w_ij)
pre_sum.eval()

array([[[[[  1.,   2.,   3.,   5.,   7.],
          [121., 143., 187., 253., 319.]],

         [[  2.,   4.,   6.,  10.,  14.],
          [143., 169., 221., 299., 377.]],

         [[  0.,   0.,   0.,   0.,   0.],
          [  0.,   0.,   0.,   0.,   0.]]],


        [[[  2.,   4.,   6.,  10.,  14.],
          [143., 169., 221., 299., 377.]],

         [[  0.,   0.,   0.,   0.,   0.],
          [  0.,   0.,   0.,   0.,   0.]],

         [[  0.,   0.,   0.,   0.,   0.],
          [  0.,   0.,   0.,   0.,   0.]]],


        [[[  0.,   0.,   0.,   0.,   0.],
          [  0.,   0.,   0.,   0.,   0.]],

         [[  0.,   0.,   0.,   0.,   0.],
          [  0.,   0.,   0.,   0.,   0.]],

         [[  0.,   0.,   0.,   0.,   0.],
          [  0.,   0.,   0.,   0.,   0.]]],


        [[[  0.,   0.,   0.,   0.,   0.],
          [  0.,   0.,   0.,   0.,   0.]],

         [[  0.,   0.,   0.,   0.,   0.],
          [  0.,   0.,   0.,   0.,   0.]],

         [[  0.,   0.,   0.,   0.,   0.],
      

In [30]:
logits = tf.reduce_sum(pre_sum, -2)
logits.eval()

array([[[[122., 145., 190., 258., 326.],
         [145., 173., 227., 309., 391.],
         [  0.,   0.,   0.,   0.,   0.]],

        [[145., 173., 227., 309., 391.],
         [  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.]],

        [[  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.]],

        [[  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.]]],


       [[[258., 309., 406., 554., 702.],
         [145., 173., 227., 309., 391.],
         [  0.,   0.,   0.,   0.,   0.]],

        [[145., 173., 227., 309., 391.],
         [190., 227., 298., 406., 514.],
         [258., 309., 406., 554., 702.]],

        [[258., 309., 406., 554., 702.],
         [  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.]],

        [[  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.]]]

In [31]:
p_j_i = tf.nn.softmax(logits, -1)
p_j_i.eval()

array([[[[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.9374821e-30,
          1.0000000e+00],
         [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.4426007e-36,
          1.0000000e+00],
         [2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
          2.0000000e-01]],

        [[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.4426007e-36,
          1.0000000e+00],
         [2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
          2.0000000e-01],
         [2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
          2.0000000e-01]],

        [[2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
          2.0000000e-01],
         [2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
          2.0000000e-01],
         [2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
          2.0000000e-01]],

        [[2.0000000e-01, 2.0000000e-01, 2.0000000e-01, 2.0000000e-01,
          2.0000000e-01],
         [2.0000000e-01, 2.0000000

In [32]:
log_eps = 1e-6

In [33]:
log_p_j_i = tf.log(p_j_i + log_eps)
log_p_j_i.eval()

array([[[[-1.3815511e+01, -1.3815511e+01, -1.3815511e+01,
          -1.3815511e+01,  9.5367386e-07],
         [-1.3815511e+01, -1.3815511e+01, -1.3815511e+01,
          -1.3815511e+01,  9.5367386e-07],
         [-1.6094329e+00, -1.6094329e+00, -1.6094329e+00,
          -1.6094329e+00, -1.6094329e+00]],

        [[-1.3815511e+01, -1.3815511e+01, -1.3815511e+01,
          -1.3815511e+01,  9.5367386e-07],
         [-1.6094329e+00, -1.6094329e+00, -1.6094329e+00,
          -1.6094329e+00, -1.6094329e+00],
         [-1.6094329e+00, -1.6094329e+00, -1.6094329e+00,
          -1.6094329e+00, -1.6094329e+00]],

        [[-1.6094329e+00, -1.6094329e+00, -1.6094329e+00,
          -1.6094329e+00, -1.6094329e+00],
         [-1.6094329e+00, -1.6094329e+00, -1.6094329e+00,
          -1.6094329e+00, -1.6094329e+00],
         [-1.6094329e+00, -1.6094329e+00, -1.6094329e+00,
          -1.6094329e+00, -1.6094329e+00]],

        [[-1.6094329e+00, -1.6094329e+00, -1.6094329e+00,
          -1.6094329e+00, -

In [34]:
log_p_j_i.eval().shape

(2, 4, 3, 5)

In [35]:
patients_ts.shape

(2,)

In [36]:
patients.eval().shape

(2, 4, 3, 5)

In [37]:
def col_masks(patients, **options):
    """Create a mask to cover non-present ICDs.
    
    For each V_t, for each c_i in V_t,
    zero out those p(c_j|c_i) for which c_j is not
    in V_t or for which i==j.
    
    See doc string for tensorize_seqs.
    
    patients: [patients,max_t,max_v,|C|] tensor
    
    returns: a binary tensor with shape patients.shape
    """
    max_v = options['max_v']
    x_t = tf.reduce_sum(patients, axis=-2)
    x_t = tf.expand_dims(x_t, -2)
    x_t = tf.tile(x_t, [1,1,max_v,1])
    col_masks = x_t - patients
    return col_masks

In [38]:
col_mask = col_masks(patients, **options)
col_mask.eval()

array([[[[0., 1., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]],


       [[[0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 1., 0., 1., 0.]],

        [[0., 0., 1., 1., 0.],
         [0., 1., 0., 1., 0.],
         [0., 1., 1., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]]], dtype=float32)

In [39]:
# Mask p_i_i before dividing by number of visits to save computations
non_norm_summands = tf.multiply(log_p_j_i, col_mask)
non_norm_summands.eval()

array([[[[ -0.       , -13.815511 ,  -0.       ,  -0.       ,
            0.       ],
         [-13.815511 ,  -0.       ,  -0.       ,  -0.       ,
            0.       ],
         [ -1.6094329,  -1.6094329,  -0.       ,  -0.       ,
           -0.       ]],

        [[ -0.       ,  -0.       ,  -0.       ,  -0.       ,
            0.       ],
         [ -0.       ,  -1.6094329,  -0.       ,  -0.       ,
           -0.       ],
         [ -0.       ,  -1.6094329,  -0.       ,  -0.       ,
           -0.       ]],

        [[ -0.       ,  -0.       ,  -0.       ,  -0.       ,
           -0.       ],
         [ -0.       ,  -0.       ,  -0.       ,  -0.       ,
           -0.       ],
         [ -0.       ,  -0.       ,  -0.       ,  -0.       ,
           -0.       ]],

        [[ -0.       ,  -0.       ,  -0.       ,  -0.       ,
           -0.       ],
         [ -0.       ,  -0.       ,  -0.       ,  -0.       ,
           -0.       ],
         [ -0.       ,  -0.       ,  -0.       ,

In [40]:
# Now for each patient divide by number of real visits of that patient
# Mask rows corresponding to NA ICDs afterward to ensure patient-by-patient division
summands_w_dummies = non_norm_summands / tf.reshape(patients_ts, [2,1,1,1])
summands_w_dummies.eval()

array([[[[-0.        , -6.9077554 , -0.        , -0.        ,
           0.        ],
         [-6.9077554 , -0.        , -0.        , -0.        ,
           0.        ],
         [-0.80471647, -0.80471647, -0.        , -0.        ,
          -0.        ]],

        [[-0.        , -0.        , -0.        , -0.        ,
           0.        ],
         [-0.        , -0.80471647, -0.        , -0.        ,
          -0.        ],
         [-0.        , -0.80471647, -0.        , -0.        ,
          -0.        ]],

        [[-0.        , -0.        , -0.        , -0.        ,
          -0.        ],
         [-0.        , -0.        , -0.        , -0.        ,
          -0.        ],
         [-0.        , -0.        , -0.        , -0.        ,
          -0.        ]],

        [[-0.        , -0.        , -0.        , -0.        ,
          -0.        ],
         [-0.        , -0.        , -0.        , -0.        ,
          -0.        ],
         [-0.        , -0.        , -0.        ,

In [41]:
summands = tf.boolean_mask(summands_w_dummies, row_masks)
summands.eval()

array([[-0.       , -6.9077554, -0.       , -0.       ,  0.       ],
       [-6.9077554, -0.       , -0.       , -0.       ,  0.       ],
       [-0.       , -0.       , -0.       , -0.       ,  0.       ],
       [-0.       , -4.6051702, -0.       , -0.       ,  0.       ],
       [-0.       , -0.       , -0.       , -4.6051702,  0.       ],
       [-0.       , -0.       , -4.6051702, -4.6051702,  0.       ],
       [-0.       , -4.6051702, -0.       , -4.6051702,  0.       ],
       [-0.       , -4.6051702, -4.6051702, -0.       ,  0.       ],
       [-0.       , -0.       , -0.       , -0.       ,  0.       ]],
      dtype=float32)

In [42]:
codes_cost_per_visit = tf.reduce_sum(summands, -1)
codes_cost_per_visit.eval()

array([-6.9077554, -6.9077554,  0.       , -4.6051702, -4.6051702,
       -9.2103405, -9.2103405, -9.2103405,  0.       ], dtype=float32)

In [43]:
codes_cost = tf.reduce_mean(codes_cost_per_visit)
codes_cost.eval()

-5.6285415

## That's the code-level cost function. Now for the visits level cost function.

$\frac{1}{T}\sum\limits_{t=1}^T\sum\limits_{-w\leq i\leq w}-x_{t+i}^T\log{\hat{y_t}}+(1-x_{t+i})^T\log{(1-\hat{y_t})}$

Where $w$ is a pre-defined window of visits and

$\hat{y_t} = \frac{\exp{(W_sv_t+b_s)}}{\sum\limits_{j=1}^{|C|}\exp(W_{sj}v_t+b_{sj})}$

Where $W_{sj}$ is the $j^{th}$ row of $W_s$ and $b_{sj}$ is the $j^{th}$ element of $b_s$.

In [186]:
x_ts = tf.reduce_sum(patients, -2)
x_ts.eval()

array([[[1., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 1., 0., 1., 0.],
        [0., 1., 1., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0.]]], dtype=float32)

In [249]:
ys = tf.multiply(x_ts, np.array([[1,2,3,4],[5,6,7,8]]).reshape([2,4,1]))
ys.eval()

array([[[1., 1., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 5., 0., 5., 0.],
        [0., 6., 6., 6., 0.],
        [0., 0., 0., 7., 0.],
        [0., 0., 0., 0., 0.]]], dtype=float32)

In [252]:
y_2d = tf.reshape(ys, [8,5])
y_2d.eval()

array([[1., 1., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 5., 0., 5., 0.],
       [0., 6., 6., 6., 0.],
       [0., 0., 0., 7., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)

In [250]:
x_2d = x_ts.eval().reshape([8,5])
x_2d

array([[1., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 1., 0., 1., 0.],
       [0., 1., 1., 1., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)