<a href="https://colab.research.google.com/github/jlee3508/CSE6250_Team_J5/blob/main/CSE6250_Team_J5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pyhealth
!pip install torchdiffeq
# pdb ships with the Python standard library, so it needs no pip install

# Mount Notebook to Google Drive
Upload the data, pretrained model, figures, etc. to your Google Drive, then mount this notebook to Google Drive. After that, you can access the resources freely.

Instruction: https://colab.research.google.com/notebooks/io.ipynb

Example: https://colab.research.google.com/drive/1srw_HFWQ2SMgmWIawucXfusGzrj1_U0q

Video: https://www.youtube.com/watch?v=zc8g8lGcwQU

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')  # same mount point that the data paths in this notebook use

# Introduction

--------------------------------------------------------------------------------

Predicting which patients are at risk of readmission to the intensive care unit (ICU) is an active topic in contemporary healthcare research, because accurate predictions enable early intervention, help prevent readmissions, and support better allocation of healthcare resources. Various approaches, including deep learning architectures, are being explored to improve the prediction of readmission within 30 days of ICU discharge and to identify at-risk patients from electronic medical records (EMRs).

The original paper proposed and compared several neural network architectures for processing longitudinal EMR data recorded at irregular intervals. The approaches range from appending time-related information to the numerical vectors that represent time-stamped codes, to modifying the internal workings of recurrent cells with exponential time-decay functions or ordinary differential equations (ODEs). The paper also proposed using medical concept embeddings (MCEs), as well as neural ODEs that describe how the embedding of a medical code changes over time. The architectures were evaluated on several criteria: average precision, AUROC, F1 score, sensitivity, and specificity. The insights from this research aim to help healthcare providers identify ICU patients at increased risk of readmission and to improve the efficiency of the ICU readmission process.
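
To make the neural-ODE idea concrete, the sketch below moves a toy embedding forward in time with plain Euler steps, one step per interval between requested times. The names `toy_derivative` and `euler_integrate` are hypothetical, and the dynamics are a fixed linear decay rather than a learned network, so this only illustrates the mechanics and is not the paper's model.

```python
def toy_derivative(t, h):
    """Hypothetical dynamics dh/dt = -0.5 * h (stand-in for a learned network)."""
    return [-0.5 * value for value in h]


def euler_integrate(func, h0, times):
    """Integrate dh/dt = func(t, h) with one Euler step per interval in `times`.

    Returns the state at every requested time, starting with h0 itself.
    """
    h = list(h0)
    trajectory = [h]
    for t_prev, t_next in zip(times[:-1], times[1:]):
        dt = t_next - t_prev
        slopes = func(t_prev, h)
        # Euler update: h(t + dt) is approximated by h(t) + dt * dh/dt
        h = [value + dt * slope for value, slope in zip(h, slopes)]
        trajectory.append(h)
    return trajectory


# Embedding of a code recorded at t=0, evaluated at irregularly spaced times.
embedding = [1.0, -2.0]
eval_times = [0.0, 0.5, 2.0]
states = euler_integrate(toy_derivative, embedding, eval_times)
for t, state in zip(eval_times, states):
    print(t, state)
```

Because the step size follows the gaps between observations, widely spaced events move the embedding further than closely spaced ones, which is the property that makes ODE-based dynamics attractive for irregularly sampled records.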

The original paper reported fairly consistent results across all of the architectures, with an average F1 score of 0.37 and an average AUROC of 0.72. This indicates that the variance among the different architectures is minimal.


# Scope of Reproducibility:

-------------------------------------------------------------------------------
(1) Evaluate the feasibility of using neural ODEs to model how the predictive relevance of recorded medical codes changes over time.

I will implement and run the ODE-based architectures to assess how well neural ODEs handle time data.

(2) Perform a comprehensive comparison of deep learning models that have been proposed for processing time series sampled at irregular intervals, including MCEs, neural ODEs, attention mechanisms, and recurrent layers.

I will implement the various architectures and compare them against each other using metrics such as precision, AUROC, and F1 score.
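
The comparison metrics named above can be sketched by hand to show what each one measures. The notebook itself imports the scikit-learn implementations; the helper names (`confusion_counts`, `threshold_metrics`, `auroc`), the toy labels, and the 0.5 threshold below are assumptions made only for illustration.

```python
def confusion_counts(y_true, y_pred):
    """Return (tp, fp, tn, fn) for binary labels encoded as 0/1."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    return tp, fp, tn, fn


def safe_divide(numerator, denominator):
    """Return 0.0 instead of raising when a metric is undefined."""
    return numerator / denominator if denominator else 0.0


def threshold_metrics(y_true, y_pred):
    """Precision, sensitivity (recall), specificity and F1 from hard predictions."""
    tp, fp, tn, fn = confusion_counts(y_true, y_pred)
    precision = safe_divide(tp, tp + fp)
    sensitivity = safe_divide(tp, tp + fn)
    specificity = safe_divide(tn, tn + fp)
    f1 = safe_divide(2 * precision * sensitivity, precision + sensitivity)
    return {"precision": precision, "sensitivity": sensitivity,
            "specificity": specificity, "f1": f1}


def auroc(y_true, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    positives = [s for t, s in zip(y_true, scores) if t == 1]
    negatives = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum((1.0 if p > n else (0.5 if p == n else 0.0))
               for p in positives for n in negatives)
    return safe_divide(wins, len(positives) * len(negatives))


# Toy example: 1 = readmitted, 0 = not readmitted (made-up values).
y_true = [1, 0, 1, 0, 0, 1]
scores = [0.9, 0.4, 0.35, 0.2, 0.6, 0.8]
y_pred = [1 if s >= 0.5 else 0 for s in scores]

metrics = threshold_metrics(y_true, y_pred)
auc = auroc(y_true, scores)
print(metrics, auc)
```

AUROC depends only on how the scores rank patients, while precision, sensitivity, specificity, and F1 all change with the decision threshold, which is one reason to report both kinds of metric on an imbalanced task.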


In [None]:
# no code is required for this section
'''
if you want to use an image outside this notebook for explanation,
you can upload it to your Google Drive and show it with OpenCV or matplotlib
'''
# mount this notebook to your google drive
drive.mount('/content/gdrive')

# define dirs to workspace and data
img_dir = '/content/gdrive/My Drive/Colab Notebooks/<path-to-your-image>'

#import cv2
#img = cv2.imread(img_dir)
#cv2.imshow("Title", img)


# Methodology

The methodology is the core of the project. It consists of runnable code with the annotations needed to show the experiments executed to test the hypotheses.

The methodology contains at least two subsections, **data** and **model**.

In [None]:
# import  packages you need
import numpy as np
from google.colab import drive
import torch
import torch.nn as nn
from math import pi
from torchdiffeq import odeint, odeint_adjoint
from pdb import set_trace as bp
import torch.nn.functional as F
from torch.autograd import Variable
import math
import torch.utils.data as utils
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix, average_precision_score, roc_auc_score, f1_score
from time import time

##  Data
Data includes raw data (MIMIC III tables), descriptive statistics (our homework questions), and data processing (feature engineering).
  * Source of the data: where the data is collected from; if data is synthetic or self-generated, explain how. If possible, please provide a link to the raw datasets.
  * Statistics: include basic descriptive statistics of the dataset like size, cross validation split, label distribution, etc.
  * Data processing: how you manipulate the data, e.g., changing the class labels, splitting the dataset into train/valid/test, refining the dataset.
  * Illustration: printing results, plotting figures for illustration.
  * You can upload your raw dataset to Google Drive and mount this Colab to the same directory. If your raw dataset is too large, you can upload the processed dataset and have a code to load the processed dataset.

The data comes from MIMIC-III: https://physionet.org/content/mimiciii/1.4/
The database includes information such as demographics, vital sign measurements made at the bedside, laboratory test results, procedures, medications, caregiver notes, imaging reports, and mortality. The dataset contains 45,298 patients, 25,004,496 diagnoses and procedures, and 17,756,816 charts and prescriptions. Of all patients, 5,495 were readmitted and 39,803 were not.
The original paper used a 90/10 train/test split. One additional variable was synthesized to track the number of ICU stays per patient in a year.
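
The class counts above imply a strongly imbalanced label. The short calculation below works out the positive rate and the negative-to-positive ratio from those counts. The data loader in this notebook computes the same ratio as `pos_weight`, but on the training subset only, so the value it produces will likely differ somewhat from this whole-dataset figure.

```python
# Counts quoted in the data description above.
num_readmitted = 5495
num_not_readmitted = 39803

num_patients = num_readmitted + num_not_readmitted
positive_rate = num_readmitted / num_patients
# Ratio of negatives to positives, the usual choice for a positive-class weight.
pos_weight = num_not_readmitted / num_readmitted

print(f"patients:      {num_patients}")
print(f"positive rate: {positive_rate:.3f}")
print(f"pos_weight:    {pos_weight:.2f}")
```

With roughly one readmission for every seven non-readmissions, plain accuracy is a weak yardstick, which is consistent with the emphasis on average precision, AUROC, and F1.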

The paper compares the following neural network architectures for predicting readmission to the ICU. Each is listed with a brief description:

ODE + RNN + Attention: dynamics in time of embeddings are modelled using neural ODEs, embeddings are passed to RNN layers, dot-product attention is applied to RNN outputs.

ODE + RNN: dynamics in time of embeddings are modelled using neural ODEs, embeddings are passed to RNN layers, the final memory states are used for further processing.

RNN (ODE time decay) + Attention: embeddings are passed to RNN layers with dynamics in time of the internal memory states modelled using neural ODEs, dot-product attention is applied to RNN outputs.

RNN (ODE time decay): embeddings are passed to RNN layers with dynamics in time of the internal memory states modelled using neural ODEs, the final memory states are used for further processing.

RNN (exp time decay) + Attention: embeddings are passed to RNN layers with internal memory states decaying exponentially over time, dot-product attention is applied to RNN outputs.

RNN (exp time decay): embeddings are passed to RNN layers with internal memory states decaying exponentially over time, the final memory states are used for further processing.

RNN (concatenated Δtime) + Attention: embeddings are concatenated with time differences between observations and passed to RNN layers, dot-product attention is applied to RNN outputs.

RNN (concatenated Δtime): embeddings are concatenated with time differences between observations and passed to RNN layers, the final memory states are used for further processing.

ODE + Attention: dynamics in time of embeddings are modelled using neural ODEs, dot-product attention is applied to the embeddings.

Attention (concatenated time): embeddings are concatenated with elapsed times, dot-product attention is applied to the embeddings.

MCE + RNN + Attention: MCE is used to compute the embeddings, embeddings are passed to RNN layers, dot-product attention is applied to RNN outputs.

MCE + RNN: MCE is used to compute the embeddings, embeddings are passed to RNN layers, the final memory states are used for further processing.

MCE + Attention: MCE is used to compute the embeddings, dot-product attention is applied to the embeddings.
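
Two of the time-handling strategies listed above, concatenating Δtime and exponential time decay, can be sketched in plain Python. The helper names (`times_to_deltas`, `concat_time`, `exp_decay`), the timestamps, and the decay rates are hypothetical. In the notebook the decay rates are learned parameters and the operations run on batched tensors.

```python
import math


def times_to_deltas(times):
    """Turn absolute times into gaps between consecutive observations.

    The first entry is kept as-is and negative gaps are floored at 0.
    """
    deltas = [times[0]] + [b - a for a, b in zip(times[:-1], times[1:])]
    return [max(d, 0.0) for d in deltas]


def concat_time(embedding, delta):
    """'Concatenated delta-time' variant: append the gap as one extra feature."""
    return list(embedding) + [delta]


def exp_decay(hidden, delta, decay_rates):
    """'Exp time decay' variant: shrink each hidden unit by exp(-delta * rate)."""
    return [h * math.exp(-max(delta * rate, 0.0))
            for h, rate in zip(hidden, decay_rates)]


times = [0.0, 1.5, 1.5, 4.0]            # hypothetical timestamps, arbitrary units
deltas = times_to_deltas(times)

step_input = concat_time([0.2, -0.1], deltas[1])
decayed = exp_decay([1.0, 2.0], 2.0, [0.5, 0.0])

print(deltas)
print(step_input)
print(decayed)
```

A decay rate of zero leaves a hidden unit untouched, so a model with learned rates can decide per unit how quickly older information should fade.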


The loss function is binary cross entropy and the optimizer is Adam. The model is not pretrained.
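
For reference, a class-weighted binary cross entropy on a raw score (logit) can be written out as below. That the loss is applied to logits and weighted by the positive-class ratio is an assumption here, suggested by the data loader returning a `pos_weight` value. The function names `softplus` and `weighted_bce_with_logits` are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_bce_with_logits(logit, label, pos_weight=1.0):
    """Binary cross entropy on a raw score, with positives scaled by pos_weight.

    loss = -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
    using log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x).
    """
    return pos_weight * label * softplus(-logit) + (1.0 - label) * softplus(logit)


# A missed readmission (label 1, negative logit) costs more once it is weighted.
unweighted = weighted_bce_with_logits(-1.0, 1.0)
weighted = weighted_bce_with_logits(-1.0, 1.0, pos_weight=7.0)
print(unweighted, weighted)
```

Scaling only the positive term pushes the optimizer to pay more attention to the rare readmitted class without resampling the data.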


In [None]:

data_dir = '/content/gdrive/My Drive/cse6250_project_team_j5/data/'
logdir = '/content/gdrive/My Drive/cse6250_project_team_j5/logFolder/'
modeldir = '/content/gdrive/My Drive/cse6250_project_team_j5/models/'

def get_data(data, type):
  # Data
  static       = data['static'].astype('float32')
  label        = data['label'].astype('float32')
  dp           = data['dp'].astype('int64') # diagnoses/procedures
  cp           = data['cp'].astype('int64') # charts/prescriptions
  dp_times     = data['dp_times'].astype('float32')
  cp_times     = data['cp_times'].astype('float32')
  train_ids    = data['train_ids']
  validate_ids = data['validate_ids']
  test_ids     = data['test_ids']

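  if (type == 'TRAIN'):
    ids = train_ids
  elif (type == 'VALIDATE'):
    ids = validate_ids
  elif (type == 'TEST'):
    ids = test_ids
  elif (type == 'ALL'):
    ids = np.full_like(label, True, dtype=bool)
  else:
    # fail early instead of hitting an undefined `ids` below
    raise ValueError("type must be 'TRAIN', 'VALIDATE', 'TEST' or 'ALL', got " + repr(type))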

  static   = static[ids, :]
  label    = label[ids]
  dp       = dp[ids, :]
  cp       = cp[ids, :]
  dp_times = dp_times[ids, :]
  cp_times = cp_times[ids, :]

  return static, dp, cp, dp_times, cp_times, label


def get_dictionaries(data):
  return data['static_vars'], data['dict_dp'][()], data['dict_cp'][()]


def num_statics(data):
  return data['static_vars'].shape[0]


def vocab_sizes(data):
  return data['dp'].max()+1, data['cp'].max()+1



## Hyperparameters

In [None]:
min_count = 100 # words that occur fewer than min_count times are encoded as OTHER

# training
batch_size = 64
num_epochs = 1
dropout_rate = 0.5
patience = 10 # early stopping

# which data to load
# on_the_cloud = False
#all_train = False
# all_train = True

# network variants
#net_variant = 'birnn_concat_time_delta'
# net_variant = 'birnn_concat_time_delta_attention'
#net_variant = 'birnn_time_decay'
#net_variant = 'birnn_time_decay_attention'
#net_variant = 'ode_birnn'
#net_variant = 'ode_birnn_attention'
#net_variant = 'ode_attention'
#net_variant = 'attention_concat_time'
#net_variant = 'birnn_ode_decay'
#net_variant = 'birnn_ode_decay_attention'
net_variant = 'mce_attention'
#net_variant = 'mce_birnn'
#net_variant = 'mce_birnn_attention'

# bootstrapping
np_seed = 1234
bootstrap_samples = 2

# bayesian network
pi = 0.5 # note: this rebinds the name `pi` imported from math in the imports cell
sigma1 = math.exp(-0)
sigma2 = math.exp(-6)
samples = 1
test_samples = 10

In [None]:
def get_trainloader(data, type, shuffle=True, idx=None):
  # Data
  static, dp, cp, dp_times, cp_times, label = get_data(data, type)

  # Bootstrap
  if idx is not None:
    static, dp, cp, dp_times, cp_times, label = static[idx], dp[idx], cp[idx], dp_times[idx], cp_times[idx], label[idx]

  # Compute total batch count
  num_batches = len(label) // batch_size

  # Create dataset
  dataset = utils.TensorDataset(torch.from_numpy(static),
                                torch.from_numpy(dp),
                                torch.from_numpy(cp),
                                torch.from_numpy(dp_times),
                                torch.from_numpy(cp_times),
                                torch.from_numpy(label))

  # Create batch queues
  trainloader = utils.DataLoader(dataset,
                                 batch_size = batch_size,
                                 shuffle = shuffle,
                                 sampler = None,
                                 num_workers = 2,
                                 drop_last = True)

  # Weight of positive samples for training
  pos_weight = torch.tensor((len(label) - np.sum(label))/np.sum(label))

  return trainloader, num_batches, pos_weight

In [None]:
print('Load data...')
data = np.load(data_dir + 'data_arrays.npz')
trainloader, num_batches, pos_weight = get_trainloader(data, 'TRAIN')
print("Data Loaded")

##   Models
The model section includes the model definition, which is usually a class, as well as model training and other necessary parts.
  * Model architecture: layer number/size/type, activation function, etc
  * Training objectives: loss function, optimizer, weight of each loss term, etc
  * Others: whether the model is pretrained, Monte Carlo simulation for uncertainty analysis, etc
  * The code of model should have classes of the model, functions of model training, model validation, etc.
  * If your model training is done outside of this notebook, please upload the trained model here and develop a function to load and test it.

#### Cell Architecture
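
One of the building blocks in this section is a dot-product attention module. Its core steps (mask the padding, apply a softmax over positions, take a weighted average of the vectors) can be sketched in plain Python. Here the importance scores are supplied directly rather than produced by a learned projection and context vector, and `masked_attention` is a hypothetical helper name.

```python
import math


def masked_attention(vectors, scores, mask):
    """Weighted average of `vectors` using a softmax over `scores`.

    Positions with mask == 0 (padding) receive a very negative score, so
    their softmax weight is effectively zero.
    """
    masked = [s if m else -1e9 for s, m in zip(scores, mask)]
    peak = max(masked)                      # subtract the max for stability
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(vectors[0])
    output = [sum(w * vec[i] for w, vec in zip(weights, vectors))
              for i in range(dim)]
    return output, weights


vectors = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # last row stands for padding
scores = [0.0, 0.0, 3.0]
mask = [1, 1, 0]
output, weights = masked_attention(vectors, scores, mask)
print(output, weights)
```

Even though the padded position carries the highest raw score, the mask removes it from the average, so only real observations contribute to the pooled representation.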

In [None]:

class ODEFunc(nn.Module):
    """MLP modeling the derivative of ODE system.
    Parameters
    ----------
    device : torch.device
    data_dim : int
        Dimension of data.
    hidden_dim : int
        Dimension of hidden layers.
    augment_dim: int
        Dimension of augmentation. If 0 does not augment ODE, otherwise augments
        it with augment_dim dimensions.
    time_dependent : bool
        If True adds time as input, making ODE time dependent.
    non_linearity : string
        One of 'relu' and 'softplus'
    """
    def __init__(self, device, data_dim, hidden_dim, augment_dim=0,
                 time_dependent=False, non_linearity='relu'):
        super(ODEFunc, self).__init__()
        self.device = device
        self.augment_dim = augment_dim
        self.data_dim = data_dim
        self.input_dim = data_dim + augment_dim
        self.hidden_dim = hidden_dim
        self.nfe = 0  # Number of function evaluations
        self.time_dependent = time_dependent

        if time_dependent:
            self.fc1 = nn.Linear(self.input_dim + 1, hidden_dim)
        else:
            self.fc1 = nn.Linear(self.input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, self.input_dim)

        if non_linearity == 'relu':
            self.non_linearity = nn.ReLU(inplace=True)
        elif non_linearity == 'softplus':
            self.non_linearity = nn.Softplus()

    def forward(self, t, x):
        """
        Parameters
        ----------
        t : torch.Tensor
            Current time. Shape (1,).
        x : torch.Tensor
            Shape (batch_size, input_dim)
        """
        # Forward pass of model corresponds to one function evaluation, so
        # increment counter
        self.nfe += 1
        if self.time_dependent:
            # Shape (batch_size, 1)
            t_vec = torch.ones(x.shape[0], 1).to(self.device) * t
            # Shape (batch_size, data_dim + 1)
            t_and_x = torch.cat([t_vec, x], 1)
            # Shape (batch_size, hidden_dim)
            out = self.fc1(t_and_x)
        else:
            out = self.fc1(x)
        out = self.non_linearity(out)
        out = self.fc2(out)
        out = self.non_linearity(out)
        out = self.fc3(out)
        return out


class ODEBlock(nn.Module):
    """Solves ODE defined by odefunc.
    Parameters
    ----------
    device : torch.device
    odefunc : ODEFunc instance or anode.conv_models.ConvODEFunc instance
        Function defining dynamics of system.
    is_conv : bool
        If True, treats odefunc as a convolutional model.
    tol : float
        Error tolerance.
    adjoint : bool
        If True calculates gradient with adjoint method, otherwise
        backpropagates directly through operations of ODE solver.
    """
    def __init__(self, device, odefunc, is_conv=False, tol=1e-3, adjoint=False):
        super(ODEBlock, self).__init__()
        self.adjoint = adjoint
        self.device = device
        self.is_conv = is_conv
        self.odefunc = odefunc
        self.tol = tol

    def forward(self, x, eval_times=None):
        """Solves ODE starting from x.
        Parameters
        ----------
        x : torch.Tensor
            Shape (batch_size, self.odefunc.data_dim)
        eval_times : None or torch.Tensor
            If None, returns solution of ODE at final time t=1. If torch.Tensor
            then returns full ODE trajectory evaluated at points in eval_times.
        """
        # Forward pass corresponds to solving ODE, so reset number of function
        # evaluations counter
        self.odefunc.nfe = 0

        if eval_times is None:
            integration_time = torch.tensor([0, 1]).float().type_as(x)
        else:
            integration_time = eval_times.type_as(x)


        if self.odefunc.augment_dim > 0:
            if self.is_conv:
                batch_size, channels, height, width = x.shape
                aug = torch.zeros(batch_size, self.odefunc.augment_dim,
                                  height, width).to(self.device)
                # Shape (batch_size, channels + augment_dim, height, width)
                x_aug = torch.cat([x, aug], 1)
            else:
                aug = torch.zeros(x.shape[0], self.odefunc.augment_dim).to(self.device)
                # Shape (batch_size, data_dim + augment_dim)
                x_aug = torch.cat([x, aug], 1)
        else:
            x_aug = x

        if self.adjoint:
            out = odeint_adjoint(self.odefunc, x_aug, integration_time,
                                 rtol=self.tol, atol=self.tol, method='euler')
        else:
            out = odeint(self.odefunc, x_aug, integration_time,
                         rtol=self.tol, atol=self.tol, method='euler')

        if eval_times is None:
            return out[1]
        else:
            return out


class ODENet(nn.Module):
    """Wraps an ODEBlock and returns its features directly (no final Linear layer here).
    Parameters
    ----------
    device : torch.device
    data_dim : int
        Dimension of data.
    hidden_dim : int
        Dimension of hidden layers.
    output_dim : int
        Dimension of output after hidden layer. Should be 1 for regression or
        num_classes for classification.
    augment_dim: int
        Dimension of augmentation. If 0 does not augment ODE, otherwise augments
        it with augment_dim dimensions.
    time_dependent : bool
        If True adds time as input, making ODE time dependent.
    non_linearity : string
        One of 'relu' and 'softplus'
    tol : float
        Error tolerance.
    adjoint : bool
        If True calculates gradient with adjoint method, otherwise
        backpropagates directly through operations of ODE solver.
    """
    def __init__(self, device, data_dim, hidden_dim, output_dim=1,
                 augment_dim=0, time_dependent=False, non_linearity='relu',
                 tol=1e-3, adjoint=False):
        super(ODENet, self).__init__()
        self.device = device
        self.data_dim = data_dim
        self.hidden_dim = hidden_dim
        self.augment_dim = augment_dim
        self.output_dim = output_dim
        self.time_dependent = time_dependent
        self.tol = tol

        odefunc = ODEFunc(device, data_dim, hidden_dim, augment_dim,
                          time_dependent, non_linearity)

        self.odeblock = ODEBlock(device, odefunc, tol=tol, adjoint=adjoint)

    def forward(self, x, eval_times=None):
        features = self.odeblock(x, eval_times)
        return features

class Attention(torch.nn.Module):
  """
  Dot-product attention module.

  Args:
    inputs: A `Tensor` with embeddings in the last dimension.
    mask: A `Tensor`. Dimensions are the same as inputs but without the embedding dimension.
      Values are 0 for 0-padding in the input and 1 elsewhere.

  Returns:
    outputs: The input `Tensor` whose embeddings in the last dimension have undergone a weighted average.
      The second-last dimension of the `Tensor` is removed.
    attention_weights: weights given to each embedding.
  """
  def __init__(self, embedding_dim):
    super(Attention, self).__init__()
    self.context = nn.Parameter(torch.Tensor(embedding_dim))
    self.linear_hidden = nn.Linear(embedding_dim, embedding_dim)
    self.reset_parameters()

  def reset_parameters(self):
    nn.init.normal_(self.context)

  def forward(self, inputs, mask):
    hidden = torch.tanh(self.linear_hidden(inputs))
    importance = torch.sum(hidden * self.context, dim=-1)
    importance = importance.masked_fill(mask == 0, -1e9)
    attention_weights = F.softmax(importance, dim=-1)
    weighted_projection = inputs * torch.unsqueeze(attention_weights, dim=-1)
    outputs = torch.sum(weighted_projection, dim=-2)
    return outputs, attention_weights


class GRUExponentialDecay(nn.Module):
  """
  GRU RNN module where the hidden state decays exponentially
  (see e.g. Che et al. 2018, Recurrent Neural Networks for Multivariate Time Series
  with Missing Values).

  Args:
    inputs: A `Tensor` with embeddings in the last dimension.
    times: A `Tensor` with the same shape as inputs containing the recorded times (but no embedding dimension).

  Returns:
    outs: Hidden states of the RNN.
  """
  def __init__(self, input_size, hidden_size, bias=True):
    super(GRUExponentialDecay, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.gru_cell = nn.GRUCell(input_size, hidden_size)
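    # exponential decays vector; torch.Tensor(hidden_size) would leave the values
    # uninitialized, so draw them from U[0, 1) instead (this init is an assumption)
    self.decays = nn.Parameter(torch.rand(hidden_size))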

  def forward(self, inputs, times):
    # allocate the state on the same device as the inputs
    hn = torch.zeros(inputs.size(0), self.hidden_size, device=inputs.device) # batch_size x hidden_size
    outs = torch.zeros(inputs.size(0), inputs.size(1), self.hidden_size, device=inputs.device) # batch_size x seq_len x hidden_size

    # this is slow
    for seq in range(inputs.size(1)):
      hn = self.gru_cell(inputs[:,seq,:], hn)
      outs[:,seq,:] = hn
      hn = hn*torch.exp(-torch.clamp(torch.unsqueeze(times[:,seq], dim=-1)*self.decays, min=0))
    return outs


class GRUOdeDecay(nn.Module):
  """
  GRU RNN module where the hidden state decays according to an ODE.
  (see Rubanova et al. 2019, Latent ODEs for Irregularly-Sampled Time Series)

  Args:
    inputs: A `Tensor` with embeddings in the last dimension.
    times: A `Tensor` with the same shape as inputs containing the recorded times (but no embedding dimension).

  Returns:
    outs: Hidden states of the RNN.
  """
  def __init__(self, input_size, hidden_size, bias=True):
    super(GRUOdeDecay, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.gru_cell = nn.GRUCell(input_size, hidden_size)
    self.decays = nn.Parameter(torch.Tensor(hidden_size)) # exponential decays vector

    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    self.ode_net = ODENet(self.device, self.input_size, self.input_size, output_dim=self.input_size, augment_dim=0, time_dependent=False, non_linearity='softplus', tol=1e-3, adjoint=True)

  def forward(self, inputs, times):
    # allocate the state on the same device as the inputs
    hn = torch.zeros(inputs.size(0), self.hidden_size, device=inputs.device) # batch_size x hidden_size
    outs = torch.zeros(inputs.size(0), inputs.size(1), self.hidden_size, device=inputs.device) # batch_size x seq_len x hidden_size

    # this is slow
    for seq in range(inputs.size(1)):
      hn = self.gru_cell(inputs[:,seq,:], hn)
      outs[:,seq,:] = hn

      times_unique, inverse_indices = torch.unique(times[:,seq], sorted=True, return_inverse=True)
      if times_unique.size(0) > 1:
        hn = self.ode_net(hn, times_unique)
        hn = hn[inverse_indices, torch.arange(0, inverse_indices.size(0)), :]
    return outs


def abs_time_to_delta(times):
  delta = torch.cat((torch.unsqueeze(times[:, 0], dim=-1), times[:, 1:] - times[:, :-1]), dim=1)
  delta = torch.clamp(delta, min=0)
  return delta


#### Model Architectures

In [None]:
if net_variant == 'birnn_time_decay':
  # GRU
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))+1
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))+1

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # GRU layers
      self.gru_dp_fw = GRUExponentialDecay(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim)
      self.gru_cp_fw = GRUExponentialDecay(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim)
      self.gru_dp_bw = GRUExponentialDecay(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim)
      self.gru_cp_bw = GRUExponentialDecay(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim)

      # Fully connected output
      self.fc_dp  = nn.Linear(2*self.embed_dp_dim, 1)
      self.fc_cp  = nn.Linear(2*self.embed_cp_dim, 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Compute time delta
      ## output dim: batch_size x seq_len
      dp_t_delta_fw = abs_time_to_delta(dp_t)
      cp_t_delta_fw = abs_time_to_delta(cp_t)
      dp_t_delta_bw = abs_time_to_delta(torch.flip(dp_t, [1]))
      cp_t_delta_bw = abs_time_to_delta(torch.flip(cp_t, [1]))

      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp_fw = self.embed_dp(dp)
      embedded_cp_fw = self.embed_cp(cp)
      embedded_dp_bw = torch.flip(embedded_dp_fw, [1])
      embedded_cp_bw = torch.flip(embedded_cp_fw, [1])
      ## Dropout
      embedded_dp_fw = self.dropout(embedded_dp_fw)
      embedded_cp_fw = self.dropout(embedded_cp_fw)
      embedded_dp_bw = self.dropout(embedded_dp_bw)
      embedded_cp_bw = self.dropout(embedded_cp_bw)

      # GRU
      ## output dim rnn:        batch_size x seq_len x embedding_dim
      rnn_dp_fw = self.gru_dp_fw(embedded_dp_fw, dp_t_delta_fw)
      rnn_cp_fw = self.gru_cp_fw(embedded_cp_fw, cp_t_delta_fw)
      rnn_dp_bw = self.gru_dp_bw(embedded_dp_bw, dp_t_delta_bw)
      rnn_cp_bw = self.gru_cp_bw(embedded_cp_bw, cp_t_delta_bw)
      ## output dim rnn_hidden: batch_size x embedding_dim
      rnn_dp_fw = rnn_dp_fw[:,-1,:]
      rnn_cp_fw = rnn_cp_fw[:,-1,:]
      rnn_dp_bw = rnn_dp_bw[:,-1,:]
      rnn_cp_bw = rnn_cp_bw[:,-1,:]
      ## concatenate forward and backward: batch_size x 2*embedding_dim
      rnn_dp = torch.cat((rnn_dp_fw, rnn_dp_bw), dim=-1)
      rnn_cp = torch.cat((rnn_cp_fw, rnn_cp_bw), dim=-1)

      # Scores
      score_dp = self.fc_dp(self.dropout(rnn_dp))
      score_cp = self.fc_cp(self.dropout(rnn_cp))

      # Concatenate to variable collection
      all_features = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(all_features)).squeeze()

      return out, []

elif net_variant == 'birnn_concat_time_delta':
  # GRU
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # GRU layers
      self.gru_dp_fw = nn.GRU(input_size=self.embed_dp_dim+1, hidden_size=self.embed_dp_dim+1, num_layers=1, batch_first=True)
      self.gru_cp_fw = nn.GRU(input_size=self.embed_cp_dim+1, hidden_size=self.embed_cp_dim+1, num_layers=1, batch_first=True)
      self.gru_dp_bw = nn.GRU(input_size=self.embed_dp_dim+1, hidden_size=self.embed_dp_dim+1, num_layers=1, batch_first=True)
      self.gru_cp_bw = nn.GRU(input_size=self.embed_cp_dim+1, hidden_size=self.embed_cp_dim+1, num_layers=1, batch_first=True)

      # Fully connected output
      self.fc_dp  = nn.Linear(2*(self.embed_dp_dim+1), 1)
      self.fc_cp  = nn.Linear(2*(self.embed_cp_dim+1), 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Compute time delta
      ## output dim: batch_size x seq_len
      dp_t_delta_fw = abs_time_to_delta(dp_t)
      cp_t_delta_fw = abs_time_to_delta(cp_t)
      dp_t_delta_bw = abs_time_to_delta(torch.flip(dp_t, [1]))
      cp_t_delta_bw = abs_time_to_delta(torch.flip(cp_t, [1]))

      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp_fw = self.embed_dp(dp)
      embedded_cp_fw = self.embed_cp(cp)
      embedded_dp_bw = torch.flip(embedded_dp_fw, [1])
      embedded_cp_bw = torch.flip(embedded_cp_fw, [1])

      # Concatenate with time
      ## output dim: batch_size x seq_len x (embedding_dim+1)
      concat_dp_fw = torch.cat((embedded_dp_fw, torch.unsqueeze(dp_t_delta_fw, dim=-1)), dim=-1)
      concat_cp_fw = torch.cat((embedded_cp_fw, torch.unsqueeze(cp_t_delta_fw, dim=-1)), dim=-1)
      concat_dp_bw = torch.cat((embedded_dp_bw, torch.unsqueeze(dp_t_delta_bw, dim=-1)), dim=-1)
      concat_cp_bw = torch.cat((embedded_cp_bw, torch.unsqueeze(cp_t_delta_bw, dim=-1)), dim=-1)
      ## Dropout
      concat_dp_fw = self.dropout(concat_dp_fw)
      concat_cp_fw = self.dropout(concat_cp_fw)
      concat_dp_bw = self.dropout(concat_dp_bw)
      concat_cp_bw = self.dropout(concat_cp_bw)

      # GRU
      ## output dim rnn:        batch_size x seq_len x (embedding_dim+1)
      ## output dim rnn_hidden: 1 x batch_size x (embedding_dim+1)
      rnn_dp_fw, rnn_hidden_dp_fw = self.gru_dp_fw(concat_dp_fw)
      rnn_cp_fw, rnn_hidden_cp_fw = self.gru_cp_fw(concat_cp_fw)
      rnn_dp_bw, rnn_hidden_dp_bw = self.gru_dp_bw(concat_dp_bw)
      rnn_cp_bw, rnn_hidden_cp_bw = self.gru_cp_bw(concat_cp_bw)
      ## output dim rnn_hidden: batch_size x (embedding_dim+1)
      rnn_hidden_dp_fw = rnn_hidden_dp_fw.view(-1, self.embed_dp_dim+1)
      rnn_hidden_cp_fw = rnn_hidden_cp_fw.view(-1, self.embed_cp_dim+1)
      rnn_hidden_dp_bw = rnn_hidden_dp_bw.view(-1, self.embed_dp_dim+1)
      rnn_hidden_cp_bw = rnn_hidden_cp_bw.view(-1, self.embed_cp_dim+1)
      ## concatenate forward and backward: batch_size x 2*(embedding_dim+1)
      rnn_hidden_dp = torch.cat((rnn_hidden_dp_fw, rnn_hidden_dp_bw), dim=-1)
      rnn_hidden_cp = torch.cat((rnn_hidden_cp_fw, rnn_hidden_cp_bw), dim=-1)

      # Scores
      score_dp = self.fc_dp(self.dropout(rnn_hidden_dp))
      score_cp = self.fc_cp(self.dropout(rnn_hidden_cp))

      # Concatenate to variable collection
      all_features = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(all_features)).squeeze()

      return out, []


elif net_variant == 'birnn_concat_time_delta_attention':
  # GRU
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # GRU layers
      self.gru_dp_fw = nn.GRU(input_size=self.embed_dp_dim+1, hidden_size=self.embed_dp_dim+1, num_layers=1, batch_first=True)
      self.gru_cp_fw = nn.GRU(input_size=self.embed_cp_dim+1, hidden_size=self.embed_cp_dim+1, num_layers=1, batch_first=True)
      self.gru_dp_bw = nn.GRU(input_size=self.embed_dp_dim+1, hidden_size=self.embed_dp_dim+1, num_layers=1, batch_first=True)
      self.gru_cp_bw = nn.GRU(input_size=self.embed_cp_dim+1, hidden_size=self.embed_cp_dim+1, num_layers=1, batch_first=True)

      # Attention layers
      self.attention_dp = Attention(embedding_dim=2*(self.embed_dp_dim+1)) #+1 for the concatenated time
      self.attention_cp = Attention(embedding_dim=2*(self.embed_cp_dim+1))

      # Fully connected output
      self.fc_dp  = nn.Linear(2*(self.embed_dp_dim+1), 1)
      self.fc_cp  = nn.Linear(2*(self.embed_cp_dim+1), 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Compute time delta
      ## output dim: batch_size x seq_len
      dp_t_delta_fw = abs_time_to_delta(dp_t)
      cp_t_delta_fw = abs_time_to_delta(cp_t)
      dp_t_delta_bw = abs_time_to_delta(torch.flip(dp_t, [1]))
      cp_t_delta_bw = abs_time_to_delta(torch.flip(cp_t, [1]))

      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp_fw = self.embed_dp(dp)
      embedded_cp_fw = self.embed_cp(cp)
      embedded_dp_bw = torch.flip(embedded_dp_fw, [1])
      embedded_cp_bw = torch.flip(embedded_cp_fw, [1])

      # Concatenate with time
      ## output dim: batch_size x seq_len x (embedding_dim+1)
      concat_dp_fw = torch.cat((embedded_dp_fw, torch.unsqueeze(dp_t_delta_fw, dim=-1)), dim=-1)
      concat_cp_fw = torch.cat((embedded_cp_fw, torch.unsqueeze(cp_t_delta_fw, dim=-1)), dim=-1)
      concat_dp_bw = torch.cat((embedded_dp_bw, torch.unsqueeze(dp_t_delta_bw, dim=-1)), dim=-1)
      concat_cp_bw = torch.cat((embedded_cp_bw, torch.unsqueeze(cp_t_delta_bw, dim=-1)), dim=-1)
      ## Dropout
      concat_dp_fw = self.dropout(concat_dp_fw)
      concat_cp_fw = self.dropout(concat_cp_fw)
      concat_dp_bw = self.dropout(concat_dp_bw)
      concat_cp_bw = self.dropout(concat_cp_bw)

      # GRU
      ## output dim rnn:        batch_size x seq_len x (embedding_dim+1)
      ## output dim rnn_hidden: 1 x batch_size x (embedding_dim+1)
      rnn_dp_fw, rnn_hidden_dp_fw = self.gru_dp_fw(concat_dp_fw)
      rnn_cp_fw, rnn_hidden_cp_fw = self.gru_cp_fw(concat_cp_fw)
      rnn_dp_bw, rnn_hidden_dp_bw = self.gru_dp_bw(concat_dp_bw)
      rnn_cp_bw, rnn_hidden_cp_bw = self.gru_cp_bw(concat_cp_bw)
      # concatenate forward and backward
      ## output dim: batch_size x seq_len x 2*(embedding_dim+1)
      rnn_dp = torch.cat((rnn_dp_fw, torch.flip(rnn_dp_bw, [1])), dim=-1)
      rnn_cp = torch.cat((rnn_cp_fw, torch.flip(rnn_cp_bw, [1])), dim=-1)

      # Attention
      ## output dim: batch_size x 2*(embedding_dim+1)
      attended_dp, weights_dp = self.attention_dp(rnn_dp, (dp > 0).float())
      attended_cp, weights_cp = self.attention_cp(rnn_cp, (cp > 0).float())

      # Scores
      score_dp = self.fc_dp(self.dropout(attended_dp))
      score_cp = self.fc_cp(self.dropout(attended_cp))

      # Concatenate to variable collection
      combined = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(combined)).squeeze()

      return out, []


elif net_variant == 'birnn_time_decay':
  # GRU
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))+1
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))+1

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # GRU layers
      self.gru_dp_fw = GRUExponentialDecay(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim)
      self.gru_cp_fw = GRUExponentialDecay(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim)
      self.gru_dp_bw = GRUExponentialDecay(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim)
      self.gru_cp_bw = GRUExponentialDecay(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim)

      # Fully connected output
      self.fc_dp  = nn.Linear(2*self.embed_dp_dim, 1)
      self.fc_cp  = nn.Linear(2*self.embed_cp_dim, 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Compute time delta
      ## output dim: batch_size x seq_len
      dp_t_delta_fw = abs_time_to_delta(dp_t)
      cp_t_delta_fw = abs_time_to_delta(cp_t)
      dp_t_delta_bw = abs_time_to_delta(torch.flip(dp_t, [1]))
      cp_t_delta_bw = abs_time_to_delta(torch.flip(cp_t, [1]))

      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp_fw = self.embed_dp(dp)
      embedded_cp_fw = self.embed_cp(cp)
      embedded_dp_bw = torch.flip(embedded_dp_fw, [1])
      embedded_cp_bw = torch.flip(embedded_cp_fw, [1])
      ## Dropout
      embedded_dp_fw = self.dropout(embedded_dp_fw)
      embedded_cp_fw = self.dropout(embedded_cp_fw)
      embedded_dp_bw = self.dropout(embedded_dp_bw)
      embedded_cp_bw = self.dropout(embedded_cp_bw)

      # GRU
      ## output dim rnn:        batch_size x seq_len x embedding_dim
      rnn_dp_fw = self.gru_dp_fw(embedded_dp_fw, dp_t_delta_fw)
      rnn_cp_fw = self.gru_cp_fw(embedded_cp_fw, cp_t_delta_fw)
      rnn_dp_bw = self.gru_dp_bw(embedded_dp_bw, dp_t_delta_bw)
      rnn_cp_bw = self.gru_cp_bw(embedded_cp_bw, cp_t_delta_bw)
      ## output dim rnn_hidden: batch_size x embedding_dim
      rnn_dp_fw = rnn_dp_fw[:,-1,:]
      rnn_cp_fw = rnn_cp_fw[:,-1,:]
      rnn_dp_bw = rnn_dp_bw[:,-1,:]
      rnn_cp_bw = rnn_cp_bw[:,-1,:]
      ## concatenate forward and backward: batch_size x 2*embedding_dim
      rnn_dp = torch.cat((rnn_dp_fw, rnn_dp_bw), dim=-1)
      rnn_cp = torch.cat((rnn_cp_fw, rnn_cp_bw), dim=-1)

      # Scores
      score_dp = self.fc_dp(self.dropout(rnn_dp))
      score_cp = self.fc_cp(self.dropout(rnn_cp))

      # Concatenate to variable collection
      combined = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(combined)).squeeze()

      return out, []


elif net_variant == 'birnn_time_decay_attention':
  # GRU
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))+1
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))+1

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # GRU layers
      self.gru_dp_fw = GRUExponentialDecay(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim)
      self.gru_cp_fw = GRUExponentialDecay(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim)
      self.gru_dp_bw = GRUExponentialDecay(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim)
      self.gru_cp_bw = GRUExponentialDecay(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim)

      # Attention layers
      self.attention_dp = Attention(embedding_dim=2*self.embed_dp_dim)
      self.attention_cp = Attention(embedding_dim=2*self.embed_cp_dim)

      # Fully connected output
      self.fc_dp  = nn.Linear(2*self.embed_dp_dim, 1)
      self.fc_cp  = nn.Linear(2*self.embed_cp_dim, 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Compute time delta
      ## output dim: batch_size x seq_len
      dp_t_delta_fw = abs_time_to_delta(dp_t)
      cp_t_delta_fw = abs_time_to_delta(cp_t)
      dp_t_delta_bw = abs_time_to_delta(torch.flip(dp_t, [1]))
      cp_t_delta_bw = abs_time_to_delta(torch.flip(cp_t, [1]))

      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp_fw = self.embed_dp(dp)
      embedded_cp_fw = self.embed_cp(cp)
      embedded_dp_bw = torch.flip(embedded_dp_fw, [1])
      embedded_cp_bw = torch.flip(embedded_cp_fw, [1])
      ## Dropout
      embedded_dp_fw = self.dropout(embedded_dp_fw)
      embedded_cp_fw = self.dropout(embedded_cp_fw)
      embedded_dp_bw = self.dropout(embedded_dp_bw)
      embedded_cp_bw = self.dropout(embedded_cp_bw)

      # GRU
      ## output dim rnn:        batch_size x seq_len x embedding_dim
      rnn_dp_fw = self.gru_dp_fw(embedded_dp_fw, dp_t_delta_fw)
      rnn_cp_fw = self.gru_cp_fw(embedded_cp_fw, cp_t_delta_fw)
      rnn_dp_bw = self.gru_dp_bw(embedded_dp_bw, dp_t_delta_bw)
      rnn_cp_bw = self.gru_cp_bw(embedded_cp_bw, cp_t_delta_bw)
      # concatenate forward and backward
      ## output dim: batch_size x seq_len x 2*embedding_dim
      rnn_dp = torch.cat((rnn_dp_fw, torch.flip(rnn_dp_bw, [1])), dim=-1)
      rnn_cp = torch.cat((rnn_cp_fw, torch.flip(rnn_cp_bw, [1])), dim=-1)

      # Attention
      ## output dim: batch_size x 2*embedding_dim
      attended_dp, weights_dp = self.attention_dp(rnn_dp, (dp > 0).float())
      attended_cp, weights_cp = self.attention_cp(rnn_cp, (cp > 0).float())

      # Scores
      score_dp = self.fc_dp(self.dropout(attended_dp))
      score_cp = self.fc_cp(self.dropout(attended_cp))

      # Concatenate to variable collection
      combined = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(combined)).squeeze()

      return out, []


elif net_variant == 'attention_concat_time':
  # Attention Only
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(2*np.ceil(num_dp_codes**0.25))
      self.embed_cp_dim = int(2*np.ceil(num_cp_codes**0.25))

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # Attention layers
      self.attention_dp = Attention(embedding_dim=self.embed_dp_dim+1) #+1 for the concatenated time
      self.attention_cp = Attention(embedding_dim=self.embed_cp_dim+1)

      # Fully connected output
      self.fc_dp  = nn.Linear(self.embed_dp_dim+1, 1)
      self.fc_cp  = nn.Linear(self.embed_cp_dim+1, 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp = self.embed_dp(dp)
      embedded_cp = self.embed_cp(cp)

      # Concatenate with time
      ## output dim: batch_size x seq_len x (embedding_dim+1)
      concat_dp = torch.cat((embedded_dp, torch.unsqueeze(dp_t, dim=-1)), dim=-1)
      concat_cp = torch.cat((embedded_cp, torch.unsqueeze(cp_t, dim=-1)), dim=-1)
      ## Dropout
      concat_dp = self.dropout(concat_dp)
      concat_cp = self.dropout(concat_cp)

      # Attention
      ## output dim: batch_size x (embedding_dim+1)
      attended_dp, weights_dp = self.attention_dp(concat_dp, (dp > 0).float())
      attended_cp, weights_cp = self.attention_cp(concat_cp, (cp > 0).float())

      # Scores
      score_dp = self.fc_dp(self.dropout(attended_dp))
      score_cp = self.fc_cp(self.dropout(attended_cp))

      # Concatenate to variable collection
      combined = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(combined)).squeeze()

      return out, []


elif net_variant == 'ode_birnn':
  # ODE + GRU
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))+1
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))+1

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # ODE layers
      self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
      self.ode_dp = ODENet(self.device, self.embed_dp_dim, self.embed_dp_dim, output_dim=self.embed_dp_dim, augment_dim=0, time_dependent=False, non_linearity='softplus', tol=1e-3, adjoint=True)
      self.ode_cp = ODENet(self.device, self.embed_cp_dim, self.embed_cp_dim, output_dim=self.embed_cp_dim, augment_dim=0, time_dependent=False, non_linearity='softplus', tol=1e-3, adjoint=True)

      # GRU layers
      self.gru_dp_fw = nn.GRU(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim, num_layers=1, batch_first=True)
      self.gru_cp_fw = nn.GRU(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim, num_layers=1, batch_first=True)
      self.gru_dp_bw = nn.GRU(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim, num_layers=1, batch_first=True)
      self.gru_cp_bw = nn.GRU(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim, num_layers=1, batch_first=True)

      # Fully connected output
      self.fc_dp  = nn.Linear(2*self.embed_dp_dim, 1)
      self.fc_cp  = nn.Linear(2*self.embed_cp_dim, 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp = self.embed_dp(dp)
      embedded_cp = self.embed_cp(cp)

      # ODE
      ## Round times
      dp_t = torch.round(100*dp_t)/100
      cp_t = torch.round(100*cp_t)/100

      embedded_dp_long = embedded_dp.view(-1, self.embed_dp_dim)
      dp_t_long = dp_t.view(-1)
      dp_t_long_unique, inverse_indices = torch.unique(dp_t_long, sorted=True, return_inverse=True)
      ode_dp_long = self.ode_dp(embedded_dp_long, dp_t_long_unique)
      ode_dp_long = ode_dp_long[inverse_indices, torch.arange(0, inverse_indices.size(0)), :]
      ode_dp = ode_dp_long.view(dp.size(0), dp.size(1), self.embed_dp_dim)

      embedded_cp_long = embedded_cp.view(-1, self.embed_cp_dim)
      cp_t_long = cp_t.view(-1)
      cp_t_long_unique, inverse_indices = torch.unique(cp_t_long, sorted=True, return_inverse=True)
      ode_cp_long = self.ode_cp(embedded_cp_long, cp_t_long_unique)
      ode_cp_long = ode_cp_long[inverse_indices, torch.arange(0, inverse_indices.size(0)), :]
      ode_cp = ode_cp_long.view(cp.size(0), cp.size(1), self.embed_cp_dim)

      ## Dropout
      ode_dp = self.dropout(ode_dp)
      ode_cp = self.dropout(ode_cp)

      # Forward and backward sequences
      ## output dim: batch_size x seq_len x embedding_dim
      ode_dp_fw = ode_dp
      ode_cp_fw = ode_cp
      ode_dp_bw = torch.flip(ode_dp_fw, [1])
      ode_cp_bw = torch.flip(ode_cp_fw, [1])

      # GRU
      ## output dim rnn:        batch_size x seq_len x embedding_dim
      ## output dim rnn_hidden: 1 x batch_size x embedding_dim
      rnn_dp_fw, rnn_hidden_dp_fw = self.gru_dp_fw(ode_dp_fw)
      rnn_cp_fw, rnn_hidden_cp_fw = self.gru_cp_fw(ode_cp_fw)
      rnn_dp_bw, rnn_hidden_dp_bw = self.gru_dp_bw(ode_dp_bw)
      rnn_cp_bw, rnn_hidden_cp_bw = self.gru_cp_bw(ode_cp_bw)
      ## output dim rnn_hidden: batch_size x embedding_dim
      rnn_hidden_dp_fw = rnn_hidden_dp_fw.view(-1, self.embed_dp_dim)
      rnn_hidden_cp_fw = rnn_hidden_cp_fw.view(-1, self.embed_cp_dim)
      rnn_hidden_dp_bw = rnn_hidden_dp_bw.view(-1, self.embed_dp_dim)
      rnn_hidden_cp_bw = rnn_hidden_cp_bw.view(-1, self.embed_cp_dim)
      ## concatenate forward and backward: batch_size x 2*embedding_dim
      rnn_hidden_dp = torch.cat((rnn_hidden_dp_fw, rnn_hidden_dp_bw), dim=-1)
      rnn_hidden_cp = torch.cat((rnn_hidden_cp_fw, rnn_hidden_cp_bw), dim=-1)

      # Scores
      score_dp = self.fc_dp(self.dropout(rnn_hidden_dp))
      score_cp = self.fc_cp(self.dropout(rnn_hidden_cp))

      # Concatenate to variable collection
      combined = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(combined)).squeeze()

      return out, []


elif net_variant == 'ode_birnn_attention':
  # ODE + GRU + Attention
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))+1
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))+1

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # ODE layers
      self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
      self.ode_dp = ODENet(self.device, self.embed_dp_dim, self.embed_dp_dim, output_dim=self.embed_dp_dim, augment_dim=0, time_dependent=False, non_linearity='softplus', tol=1e-3, adjoint=True)
      self.ode_cp = ODENet(self.device, self.embed_cp_dim, self.embed_cp_dim, output_dim=self.embed_cp_dim, augment_dim=0, time_dependent=False, non_linearity='softplus', tol=1e-3, adjoint=True)

      # GRU layers
      self.gru_dp_fw = nn.GRU(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim, num_layers=1, batch_first=True)
      self.gru_cp_fw = nn.GRU(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim, num_layers=1, batch_first=True)
      self.gru_dp_bw = nn.GRU(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim, num_layers=1, batch_first=True)
      self.gru_cp_bw = nn.GRU(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim, num_layers=1, batch_first=True)

      # Attention layers
      self.attention_dp = Attention(embedding_dim=2*self.embed_dp_dim)
      self.attention_cp = Attention(embedding_dim=2*self.embed_cp_dim)

      # Fully connected output
      self.fc_dp  = nn.Linear(2*self.embed_dp_dim, 1)
      self.fc_cp  = nn.Linear(2*self.embed_cp_dim, 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp = self.embed_dp(dp)
      embedded_cp = self.embed_cp(cp)

      # ODE
      ## Round times
      dp_t = torch.round(100*dp_t)/100
      cp_t = torch.round(100*cp_t)/100

      embedded_dp_long = embedded_dp.view(-1, self.embed_dp_dim)
      dp_t_long = dp_t.view(-1)
      dp_t_long_unique, inverse_indices = torch.unique(dp_t_long, sorted=True, return_inverse=True)
      ode_dp_long = self.ode_dp(embedded_dp_long, dp_t_long_unique)
      ode_dp_long = ode_dp_long[inverse_indices, torch.arange(0, inverse_indices.size(0)), :]
      ode_dp = ode_dp_long.view(dp.size(0), dp.size(1), self.embed_dp_dim)

      embedded_cp_long = embedded_cp.view(-1, self.embed_cp_dim)
      cp_t_long = cp_t.view(-1)
      cp_t_long_unique, inverse_indices = torch.unique(cp_t_long, sorted=True, return_inverse=True)
      ode_cp_long = self.ode_cp(embedded_cp_long, cp_t_long_unique)
      ode_cp_long = ode_cp_long[inverse_indices, torch.arange(0, inverse_indices.size(0)), :]
      ode_cp = ode_cp_long.view(cp.size(0), cp.size(1), self.embed_cp_dim)

      ## Dropout
      ode_dp = self.dropout(ode_dp)
      ode_cp = self.dropout(ode_cp)

      # Forward and backward sequences
      ## output dim: batch_size x seq_len x embedding_dim
      ode_dp_fw = ode_dp
      ode_cp_fw = ode_cp
      ode_dp_bw = torch.flip(ode_dp_fw, [1])
      ode_cp_bw = torch.flip(ode_cp_fw, [1])

      # GRU
      ## output dim rnn:        batch_size x seq_len x embedding_dim
      ## output dim rnn_hidden: 1 x batch_size x embedding_dim
      rnn_dp_fw, rnn_hidden_dp_fw = self.gru_dp_fw(ode_dp_fw)
      rnn_cp_fw, rnn_hidden_cp_fw = self.gru_cp_fw(ode_cp_fw)
      rnn_dp_bw, rnn_hidden_dp_bw = self.gru_dp_bw(ode_dp_bw)
      rnn_cp_bw, rnn_hidden_cp_bw = self.gru_cp_bw(ode_cp_bw)
      # concatenate forward and backward
      ## output dim: batch_size x seq_len x 2*embedding_dim
      rnn_dp = torch.cat((rnn_dp_fw, torch.flip(rnn_dp_bw, [1])), dim=-1)
      rnn_cp = torch.cat((rnn_cp_fw, torch.flip(rnn_cp_bw, [1])), dim=-1)

      # Attention
      ## output dim: batch_size x 2*embedding_dim
      attended_dp, weights_dp = self.attention_dp(rnn_dp, (dp > 0).float())
      attended_cp, weights_cp = self.attention_cp(rnn_cp, (cp > 0).float())

      # Scores
      score_dp = self.fc_dp(self.dropout(attended_dp))
      score_cp = self.fc_cp(self.dropout(attended_cp))

      # Concatenate to variable collection
      combined = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(combined)).squeeze()

      return out, []


elif net_variant == 'ode_attention':
  # ODE + Attention
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(2*np.ceil(num_dp_codes**0.25))+1
      self.embed_cp_dim = int(2*np.ceil(num_cp_codes**0.25))+1

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # ODE layers
      self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
      self.ode_dp = ODENet(self.device, self.embed_dp_dim, self.embed_dp_dim, output_dim=self.embed_dp_dim, augment_dim=0, time_dependent=False, non_linearity='softplus', tol=1e-3, adjoint=True)
      self.ode_cp = ODENet(self.device, self.embed_cp_dim, self.embed_cp_dim, output_dim=self.embed_cp_dim, augment_dim=0, time_dependent=False, non_linearity='softplus', tol=1e-3, adjoint=True)

      # Attention layers
      self.attention_dp = Attention(embedding_dim=self.embed_dp_dim)
      self.attention_cp = Attention(embedding_dim=self.embed_cp_dim)

      # Fully connected output
      self.fc_dp  = nn.Linear(self.embed_dp_dim, 1)
      self.fc_cp  = nn.Linear(self.embed_cp_dim, 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp = self.embed_dp(dp)
      embedded_cp = self.embed_cp(cp)

      # ODE
      ## Round times
      dp_t = torch.round(100*dp_t)/100
      cp_t = torch.round(100*cp_t)/100

      embedded_dp_long = embedded_dp.view(-1, self.embed_dp_dim)
      dp_t_long = dp_t.view(-1)
      dp_t_long_unique, inverse_indices = torch.unique(dp_t_long, sorted=True, return_inverse=True)
      ode_dp_long = self.ode_dp(embedded_dp_long, dp_t_long_unique)
      ode_dp_long = ode_dp_long[inverse_indices, torch.arange(0, inverse_indices.size(0)), :]
      ode_dp = ode_dp_long.view(dp.size(0), dp.size(1), self.embed_dp_dim)

      embedded_cp_long = embedded_cp.view(-1, self.embed_cp_dim)
      cp_t_long = cp_t.view(-1)
      cp_t_long_unique, inverse_indices = torch.unique(cp_t_long, sorted=True, return_inverse=True)
      ode_cp_long = self.ode_cp(embedded_cp_long, cp_t_long_unique)
      ode_cp_long = ode_cp_long[inverse_indices, torch.arange(0, inverse_indices.size(0)), :]
      ode_cp = ode_cp_long.view(cp.size(0), cp.size(1), self.embed_cp_dim)

      ## Dropout
      ode_dp = self.dropout(ode_dp)
      ode_cp = self.dropout(ode_cp)

      # Attention
      ## output dim: batch_size x embedding_dim
      attended_dp, weights_dp = self.attention_dp(ode_dp, (dp > 0).float())
      attended_cp, weights_cp = self.attention_cp(ode_cp, (cp > 0).float())

      # Scores
      score_dp = self.fc_dp(self.dropout(attended_dp))
      score_cp = self.fc_cp(self.dropout(attended_cp))

      # Concatenate to variable collection
      combined = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(combined)).squeeze()

      return out, []


elif net_variant == 'birnn_ode_decay':
  # GRU
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))+1
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))+1

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # GRU layers
      self.gru_dp_fw = GRUOdeDecay(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim)
      self.gru_cp_fw = GRUOdeDecay(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim)
      self.gru_dp_bw = GRUOdeDecay(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim)
      self.gru_cp_bw = GRUOdeDecay(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim)

      # Fully connected output
      self.fc_dp  = nn.Linear(2*self.embed_dp_dim, 1)
      self.fc_cp  = nn.Linear(2*self.embed_cp_dim, 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Compute time delta
      ## output dim: batch_size x seq_len
      dp_t_delta_fw = abs_time_to_delta(dp_t)
      cp_t_delta_fw = abs_time_to_delta(cp_t)
      ## Round
      dp_t_delta_fw = torch.round(100*dp_t_delta_fw)/100
      cp_t_delta_fw = torch.round(100*cp_t_delta_fw)/100
      dp_t_delta_bw = abs_time_to_delta(torch.flip(dp_t, [1]))
      cp_t_delta_bw = abs_time_to_delta(torch.flip(cp_t, [1]))

      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp_fw = self.embed_dp(dp)
      embedded_cp_fw = self.embed_cp(cp)
      embedded_dp_bw = torch.flip(embedded_dp_fw, [1])
      embedded_cp_bw = torch.flip(embedded_cp_fw, [1])
      ## Dropout
      embedded_dp_fw = self.dropout(embedded_dp_fw)
      embedded_cp_fw = self.dropout(embedded_cp_fw)
      embedded_dp_bw = self.dropout(embedded_dp_bw)
      embedded_cp_bw = self.dropout(embedded_cp_bw)

      # GRU
      ## output dim rnn:        batch_size x seq_len x embedding_dim
      rnn_dp_fw = self.gru_dp_fw(embedded_dp_fw, dp_t_delta_fw)
      rnn_cp_fw = self.gru_cp_fw(embedded_cp_fw, cp_t_delta_fw)
      rnn_dp_bw = self.gru_dp_bw(embedded_dp_bw, dp_t_delta_bw)
      rnn_cp_bw = self.gru_cp_bw(embedded_cp_bw, cp_t_delta_bw)
      ## output dim rnn_hidden: batch_size x embedding_dim
      rnn_dp_fw = rnn_dp_fw[:,-1,:]
      rnn_cp_fw = rnn_cp_fw[:,-1,:]
      rnn_dp_bw = rnn_dp_bw[:,-1,:]
      rnn_cp_bw = rnn_cp_bw[:,-1,:]
      ## concatenate forward and backward: batch_size x 2*embedding_dim
      rnn_dp = torch.cat((rnn_dp_fw, rnn_dp_bw), dim=-1)
      rnn_cp = torch.cat((rnn_cp_fw, rnn_cp_bw), dim=-1)

      # Scores
      score_dp = self.fc_dp(self.dropout(rnn_dp))
      score_cp = self.fc_cp(self.dropout(rnn_cp))

      # Concatenate to variable collection
      combined = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(combined)).squeeze()

      return out, []


elif net_variant == 'birnn_ode_decay_attention':
  # GRU
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))+1
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))+1

      # Embedding layers
      self.embed_dp = nn.Embedding(num_embeddings=num_dp_codes, embedding_dim=self.embed_dp_dim, padding_idx=0)
      self.embed_cp = nn.Embedding(num_embeddings=num_cp_codes, embedding_dim=self.embed_cp_dim, padding_idx=0)

      # GRU layers
      self.gru_dp_fw = GRUOdeDecay(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim)
      self.gru_cp_fw = GRUOdeDecay(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim)
      self.gru_dp_bw = GRUOdeDecay(input_size=self.embed_dp_dim, hidden_size=self.embed_dp_dim)
      self.gru_cp_bw = GRUOdeDecay(input_size=self.embed_cp_dim, hidden_size=self.embed_cp_dim)

      # Attention layers
      self.attention_dp = Attention(embedding_dim=2*self.embed_dp_dim)
      self.attention_cp = Attention(embedding_dim=2*self.embed_cp_dim)

      # Fully connected output
      self.fc_dp  = nn.Linear(2*self.embed_dp_dim, 1)
      self.fc_cp  = nn.Linear(2*self.embed_cp_dim, 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Compute time delta
      ## output dim: batch_size x seq_len
      dp_t_delta_fw = abs_time_to_delta(dp_t)
      cp_t_delta_fw = abs_time_to_delta(cp_t)
      ## Round
      dp_t_delta_fw = torch.round(100*dp_t_delta_fw)/100
      cp_t_delta_fw = torch.round(100*cp_t_delta_fw)/100
      dp_t_delta_bw = abs_time_to_delta(torch.flip(dp_t, [1]))
      cp_t_delta_bw = abs_time_to_delta(torch.flip(cp_t, [1]))

      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp_fw = self.embed_dp(dp)
      embedded_cp_fw = self.embed_cp(cp)
      embedded_dp_bw = torch.flip(embedded_dp_fw, [1])
      embedded_cp_bw = torch.flip(embedded_cp_fw, [1])
      ## Dropout
      embedded_dp_fw = self.dropout(embedded_dp_fw)
      embedded_cp_fw = self.dropout(embedded_cp_fw)
      embedded_dp_bw = self.dropout(embedded_dp_bw)
      embedded_cp_bw = self.dropout(embedded_cp_bw)

      # GRU
      ## output dim rnn:        batch_size x seq_len x embedding_dim
      rnn_dp_fw = self.gru_dp_fw(embedded_dp_fw, dp_t_delta_fw)
      rnn_cp_fw = self.gru_cp_fw(embedded_cp_fw, cp_t_delta_fw)
      rnn_dp_bw = self.gru_dp_bw(embedded_dp_bw, dp_t_delta_bw)
      rnn_cp_bw = self.gru_cp_bw(embedded_cp_bw, cp_t_delta_bw)
      # concatenate forward and backward
      ## output dim: batch_size x seq_len x 2*embedding_dim
      rnn_dp = torch.cat((rnn_dp_fw, torch.flip(rnn_dp_bw, [1])), dim=-1)
      rnn_cp = torch.cat((rnn_cp_fw, torch.flip(rnn_cp_bw, [1])), dim=-1)

      # Attention
      ## output dim: batch_size x 2*embedding_dim
      attended_dp, weights_dp = self.attention_dp(rnn_dp, (dp > 0).float())
      attended_cp, weights_cp = self.attention_cp(rnn_cp, (cp > 0).float())

      # Scores
      score_dp = self.fc_dp(self.dropout(attended_dp))
      score_cp = self.fc_cp(self.dropout(attended_cp))

      # Concatenate to variable collection
      combined = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(combined)).squeeze()

      return out, []


elif net_variant == 'mce_attention':
  # Attention Only
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(2*np.ceil(num_dp_codes**0.25))
      self.embed_cp_dim = int(2*np.ceil(num_cp_codes**0.25))

      # Precomputed embedding weights
      self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
      self.emb_weight_dp = torch.Tensor(np.load(data_dir + 'emb_weight_dp_13.npy')).to(self.device)
      self.emb_weight_cp = torch.Tensor(np.load(data_dir + 'emb_weight_cp_11.npy')).to(self.device)

      # Attention layers
      self.attention_dp = Attention(embedding_dim=self.embed_dp_dim+1) #+1 for the concatenated time
      self.attention_cp = Attention(embedding_dim=self.embed_cp_dim+1)

      # Fully connected output
      self.fc_dp  = nn.Linear(self.embed_dp_dim+1, 1)
      self.fc_cp  = nn.Linear(self.embed_cp_dim+1, 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp = F.embedding(dp, self.emb_weight_dp, padding_idx=0)
      embedded_cp = F.embedding(cp, self.emb_weight_cp, padding_idx=0)
      ## Dropout
      embedded_dp = self.dropout(embedded_dp)
      embedded_cp = self.dropout(embedded_cp)

      # Attention
      ## output dim: batch_size x (embedding_dim+1)
      attended_dp, weights_dp = self.attention_dp(embedded_dp, (dp > 0).float())
      attended_cp, weights_cp = self.attention_cp(embedded_cp, (cp > 0).float())

      # Scores
      score_dp = self.fc_dp(self.dropout(attended_dp))
      score_cp = self.fc_cp(self.dropout(attended_cp))

      # Concatenate to variable collection
      all = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(all)).squeeze()

      return out, []


if net_variant == 'mce_birnn':
  # GRU
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))

      # Precomputed embedding weights
      self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
      self.emb_weight_dp = torch.Tensor(np.load(data_dir + 'emb_weight_dp_7.npy')).to(self.device)
      self.emb_weight_cp = torch.Tensor(np.load(data_dir + 'emb_weight_cp_6.npy')).to(self.device)

      # GRU layers
      self.gru_dp_fw = nn.GRU(input_size=self.embed_dp_dim+1, hidden_size=self.embed_dp_dim+1, num_layers=1, batch_first=True)
      self.gru_cp_fw = nn.GRU(input_size=self.embed_cp_dim+1, hidden_size=self.embed_cp_dim+1, num_layers=1, batch_first=True)
      self.gru_dp_bw = nn.GRU(input_size=self.embed_dp_dim+1, hidden_size=self.embed_dp_dim+1, num_layers=1, batch_first=True)
      self.gru_cp_bw = nn.GRU(input_size=self.embed_cp_dim+1, hidden_size=self.embed_cp_dim+1, num_layers=1, batch_first=True)

      # Fully connected output
      self.fc_dp  = nn.Linear(2*(self.embed_dp_dim+1), 1)
      self.fc_cp  = nn.Linear(2*(self.embed_cp_dim+1), 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp_fw = F.embedding(dp, self.emb_weight_dp, padding_idx=0)
      embedded_cp_fw = F.embedding(cp, self.emb_weight_cp, padding_idx=0)
      embedded_dp_bw = torch.flip(embedded_dp_fw, [1])
      embedded_cp_bw = torch.flip(embedded_cp_fw, [1])
      ## Dropout
      embedded_dp_fw = self.dropout(embedded_dp_fw)
      embedded_cp_fw = self.dropout(embedded_cp_fw)
      embedded_dp_bw = self.dropout(embedded_dp_bw)
      embedded_cp_bw = self.dropout(embedded_cp_bw)

      # GRU
      ## output dim rnn:        batch_size x seq_len x (embedding_dim+1)
      ## output dim rnn_hidden: 1 x batch_size x (embedding_dim+1)
      rnn_dp_fw, rnn_hidden_dp_fw = self.gru_dp_fw(embedded_dp_fw)
      rnn_cp_fw, rnn_hidden_cp_fw = self.gru_cp_fw(embedded_cp_fw)
      rnn_dp_bw, rnn_hidden_dp_bw = self.gru_dp_bw(embedded_dp_bw)
      rnn_cp_bw, rnn_hidden_cp_bw = self.gru_cp_bw(embedded_cp_bw)
      ## output dim rnn_hidden: batch_size x (embedding_dim+1)
      rnn_hidden_dp_fw = rnn_hidden_dp_fw.view(-1, self.embed_dp_dim+1)
      rnn_hidden_cp_fw = rnn_hidden_cp_fw.view(-1, self.embed_cp_dim+1)
      rnn_hidden_dp_bw = rnn_hidden_dp_bw.view(-1, self.embed_dp_dim+1)
      rnn_hidden_cp_bw = rnn_hidden_cp_bw.view(-1, self.embed_cp_dim+1)
      ## concatenate forward and backward: batch_size x 2*(embedding_dim+1)
      rnn_hidden_dp = torch.cat((rnn_hidden_dp_fw, rnn_hidden_dp_bw), dim=-1)
      rnn_hidden_cp = torch.cat((rnn_hidden_cp_fw, rnn_hidden_cp_bw), dim=-1)

      # Scores
      score_dp = self.fc_dp(self.dropout(rnn_hidden_dp))
      score_cp = self.fc_cp(self.dropout(rnn_hidden_cp))

      # Concatenate to variable collection
      all = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(all)).squeeze()

      return out, []


elif net_variant == 'mce_birnn_attention':
  # GRU
  class Net(nn.Module):
    def __init__(self, num_static, num_dp_codes, num_cp_codes):
      super(Net, self).__init__()

      # Embedding dimensions
      self.embed_dp_dim = int(np.ceil(num_dp_codes**0.25))
      self.embed_cp_dim = int(np.ceil(num_cp_codes**0.25))

      # Precomputed embedding weights
      self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
      self.emb_weight_dp = torch.Tensor(np.load(data_dir + 'emb_weight_dp_7.npy')).to(self.device)
      self.emb_weight_cp = torch.Tensor(np.load(data_dir + 'emb_weight_cp_6.npy')).to(self.device)

      # GRU layers
      self.gru_dp_fw = nn.GRU(input_size=self.embed_dp_dim+1, hidden_size=self.embed_dp_dim+1, num_layers=1, batch_first=True)
      self.gru_cp_fw = nn.GRU(input_size=self.embed_cp_dim+1, hidden_size=self.embed_cp_dim+1, num_layers=1, batch_first=True)
      self.gru_dp_bw = nn.GRU(input_size=self.embed_dp_dim+1, hidden_size=self.embed_dp_dim+1, num_layers=1, batch_first=True)
      self.gru_cp_bw = nn.GRU(input_size=self.embed_cp_dim+1, hidden_size=self.embed_cp_dim+1, num_layers=1, batch_first=True)

      # Attention layers
      self.attention_dp = Attention(embedding_dim=2*(self.embed_dp_dim+1)) #+1 for the concatenated time
      self.attention_cp = Attention(embedding_dim=2*(self.embed_cp_dim+1))

      # Fully connected output
      self.fc_dp  = nn.Linear(2*(self.embed_dp_dim+1), 1)
      self.fc_cp  = nn.Linear(2*(self.embed_cp_dim+1), 1)
      self.fc_all = nn.Linear(num_static + 2, 1)

      # Others
      self.dropout = nn.Dropout(p=0.5)

    def forward(self, stat, dp, cp, dp_t, cp_t):
      # Embedding
      ## output dim: batch_size x seq_len x embedding_dim
      embedded_dp_fw = F.embedding(dp, self.emb_weight_dp, padding_idx=0)
      embedded_cp_fw = F.embedding(cp, self.emb_weight_cp, padding_idx=0)
      embedded_dp_bw = torch.flip(embedded_dp_fw, [1])
      embedded_cp_bw = torch.flip(embedded_cp_fw, [1])
      ## Dropout
      embedded_dp_fw = self.dropout(embedded_dp_fw)
      embedded_cp_fw = self.dropout(embedded_cp_fw)
      embedded_dp_bw = self.dropout(embedded_dp_bw)
      embedded_cp_bw = self.dropout(embedded_cp_bw)

      # GRU
      ## output dim rnn:        batch_size x seq_len x (embedding_dim+1)
      ## output dim rnn_hidden: 1 x batch_size x (embedding_dim+1)
      rnn_dp_fw, rnn_hidden_dp_fw = self.gru_dp_fw(embedded_dp_fw)
      rnn_cp_fw, rnn_hidden_cp_fw = self.gru_cp_fw(embedded_cp_fw)
      rnn_dp_bw, rnn_hidden_dp_bw = self.gru_dp_bw(embedded_dp_bw)
      rnn_cp_bw, rnn_hidden_cp_bw = self.gru_cp_bw(embedded_cp_bw)
      # concatenate forward and backward
      ## output dim: batch_size x seq_len x 2*(embedding_dim+1)
      rnn_dp = torch.cat((rnn_dp_fw, torch.flip(rnn_dp_bw, [1])), dim=-1)
      rnn_cp = torch.cat((rnn_cp_fw, torch.flip(rnn_cp_bw, [1])), dim=-1)

      # Attention
      ## output dim: batch_size x 2*(embedding_dim+1)
      attended_dp, weights_dp = self.attention_dp(rnn_dp, (dp > 0).float())
      attended_cp, weights_cp = self.attention_cp(rnn_cp, (cp > 0).float())

      # Scores
      score_dp = self.fc_dp(self.dropout(attended_dp))
      score_cp = self.fc_cp(self.dropout(attended_cp))

      # Concatenate to variable collection
      all = torch.cat((stat, score_dp, score_cp), dim=1)

      # Final linear projection
      out = self.fc_all(self.dropout(all)).squeeze()

      return out, []

## Train

In [None]:
print('Load data...')
data = np.load(data_dir + 'data_arrays.npz')

trainloader, num_batches, pos_weight = get_trainloader(data, 'TRAIN')

static = num_statics(data)
num_dp_codes, num_cp_codes = vocab_sizes(data)

print('-----------------------------------------')
print('Train...')

# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
torch.backends.cudnn.benchmark = True

# Network
net = Net(static, num_dp_codes, num_cp_codes).to(device)

# Loss function and optimizer
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(device)
optimizer = optim.Adam(net.parameters(), lr = 0.001)

# Store times
epoch_times = []

# Train
for epoch in tqdm(range(num_epochs)):
  # print('-----------------------------------------')
  # print('Epoch: {}'.format(epoch))
  net.train()
  time_start = time()
  for i, (stat, dp, cp, dp_t, cp_t, label) in enumerate(tqdm(trainloader), 0):
    # move to GPU if available
    stat  = stat.to(device)
    dp    = dp.to(device)
    cp    = cp.to(device)
    dp_t  = dp_t.to(device)
    cp_t  = cp_t.to(device)
    label = label.to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    label_pred, _ = net(stat, dp, cp, dp_t, cp_t)
    loss = criterion(label_pred, label)
    loss.backward()
    optimizer.step()

  # timing (recorded once per epoch, inside the epoch loop)
  time_end = time()
  epoch_times.append(time_end-time_start)

In [None]:
print('Saving...')
torch.save(net.state_dict(), logdir + net_variant+ '.pt')
np.savez(logdir + net_variant, epoch_times=epoch_times)
print('Done')
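
The training cell above passes `pos_weight` to `nn.BCEWithLogitsLoss`, which scales the loss of positive (readmitted) examples. As a rough illustration of what that weighting does, here is a pure-Python sketch of the per-example loss. The helper names are hypothetical and the weight `3.0` is made up; how `get_trainloader` derives `pos_weight` is not shown in this section.

```python
import math

def softplus(x):
    # log(1 + exp(x)), written to avoid overflow for large |x|
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def weighted_bce_with_logits(logit, label, pos_weight=1.0):
    # -log(sigmoid(x)) = softplus(-x) and -log(1 - sigmoid(x)) = softplus(x),
    # so only the positive-label term is scaled by pos_weight.
    return pos_weight * label * softplus(-logit) + (1 - label) * softplus(logit)

# Same uncertain prediction (logit 0), scored with and without the weight.
print(weighted_bce_with_logits(0.0, 1))                  # positive, unweighted
print(weighted_bce_with_logits(0.0, 1, pos_weight=3.0))  # positive, weighted 3x
print(weighted_bce_with_logits(0.0, 0, pos_weight=3.0))  # negative, unaffected
```

A weight above 1 makes a missed positive cost more than a missed negative, which is one common way to compensate for class imbalance.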

## Test

In [None]:
from __future__ import print_function
import torch
import numpy as np
import pandas as pd
import pickle
import scipy.stats as st
import os
from tqdm import tqdm
#import matplotlib.pyplot as plt
from sklearn.metrics import *
from sklearn.calibration import calibration_curve
from pdb import set_trace as bp

def round(num):
  return np.round(num*1000)/1000

if __name__ == '__main__':
  # Load data
  print('Load data...')
  data = np.load(data_dir + 'data_arrays.npz')
  test_ids_patients = pd.read_pickle(data_dir + 'test_ids_patients.pkl')

  # Patients in test data
  patients = test_ids_patients.drop_duplicates()
  num_patients = patients.shape[0]
  row_ids = pd.DataFrame({'ROW_IDX': test_ids_patients.index}, index=test_ids_patients)

  # Vocabulary sizes
  num_static = num_statics(data)
  num_dp_codes, num_cp_codes = vocab_sizes(data)

  # CUDA for PyTorch
  use_cuda = torch.cuda.is_available()
  device = torch.device('cuda:0' if use_cuda else 'cpu')
  torch.backends.cudnn.benchmark = True

  # Network
  net = Net(num_static, num_dp_codes, num_cp_codes).to(device)

  print('Evaluate...')
  # Set log dir to read trained model from
  model = logdir + net_variant +'.pt'

  # Restore variables from disk
  net.load_state_dict(torch.load(model, map_location=device))

  # Bootstrapping
  np.random.seed(np_seed)
  avpre_vec = np.zeros(bootstrap_samples)
  auroc_vec = np.zeros(bootstrap_samples)
  f1_vec    = np.zeros(bootstrap_samples)
  sensitivity_vec = np.zeros(bootstrap_samples)
  specificity_vec = np.zeros(bootstrap_samples)
  ppv_vec = np.zeros(bootstrap_samples)
  npv_vec = np.zeros(bootstrap_samples)

  for sample in range(bootstrap_samples):
    print('Bootstrap sample {}'.format(sample))

    # Test data
    sample_patients = patients.sample(n=num_patients, replace=True)
    idx = np.squeeze(row_ids.loc[sample_patients].values)
    testloader, _, _ = get_trainloader(data, 'TEST', shuffle=False, idx=idx)

    # evaluate on test data
    net.eval()
    label_pred = torch.Tensor([])
    label_test = torch.Tensor([])
    with torch.no_grad():
      for i, (stat, dp, cp, dp_t, cp_t, label_batch) in enumerate(tqdm(testloader), 0):
        # move to GPU if available
        stat  = stat.to(device)
        dp    = dp.to(device)
        cp    = cp.to(device)
        dp_t  = dp_t.to(device)
        cp_t  = cp_t.to(device)

        label_pred_batch, _ = net(stat, dp, cp, dp_t, cp_t)
        label_pred = torch.cat((label_pred, label_pred_batch.cpu()))
        label_test = torch.cat((label_test, label_batch))

    label_sigmoids = torch.sigmoid(label_pred).cpu().numpy()

    # Average precision
    avpre = average_precision_score(label_test, label_sigmoids)

    # Determine AUROC score
    auroc = roc_auc_score(label_test, label_sigmoids)

    # Sensitivity, specificity
    fpr, tpr, thresholds = roc_curve(label_test, label_sigmoids)
    youden_idx = np.argmax(tpr - fpr)
    sensitivity = tpr[youden_idx]
    specificity = 1-fpr[youden_idx]

    # F1, PPV, NPV score
    f1 = 0
    ppv = 0
    npv = 0
    for t in thresholds:
      label_pred = (np.array(label_sigmoids) >= t).astype(int)
      f1_temp = f1_score(label_test, label_pred)
      ppv_temp = precision_score(label_test, label_pred, pos_label=1)
      npv_temp = precision_score(label_test, label_pred, pos_label=0)
      if f1_temp > f1:
        f1 = f1_temp
      if (ppv_temp+npv_temp) > (ppv+npv):
        ppv = ppv_temp
        npv = npv_temp

    # Store in vectors
    avpre_vec[sample] = avpre
    auroc_vec[sample] = auroc
    f1_vec[sample]    = f1
    sensitivity_vec[sample]  = sensitivity
    specificity_vec[sample]  = specificity
    ppv_vec[sample]  = ppv
    npv_vec[sample]  = npv

  avpre_mean = np.mean(avpre_vec)
  avpre_lci, avpre_uci = st.t.interval(0.95, bootstrap_samples-1, loc=avpre_mean, scale=st.sem(avpre_vec))
  auroc_mean = np.mean(auroc_vec)
  auroc_lci, auroc_uci = st.t.interval(0.95, bootstrap_samples-1, loc=auroc_mean, scale=st.sem(auroc_vec))
  f1_mean = np.mean(f1_vec)
  f1_lci, f1_uci = st.t.interval(0.95,bootstrap_samples-1, loc=f1_mean, scale=st.sem(f1_vec))
  ppv_mean = np.mean(ppv_vec)
  ppv_lci, ppv_uci = st.t.interval(0.95, bootstrap_samples-1, loc=ppv_mean, scale=st.sem(ppv_vec))
  npv_mean = np.mean(npv_vec)
  npv_lci, npv_uci = st.t.interval(0.95, bootstrap_samples-1, loc=npv_mean, scale=st.sem(npv_vec))
  sensitivity_mean = np.mean(sensitivity_vec)
  sensitivity_lci, sensitivity_uci = st.t.interval(0.95, bootstrap_samples-1, loc=sensitivity_mean, scale=st.sem(sensitivity_vec))
  specificity_mean = np.mean(specificity_vec)
  specificity_lci, specificity_uci = st.t.interval(0.95, bootstrap_samples-1, loc=specificity_mean, scale=st.sem(specificity_vec))

  epoch_times = np.load(logdir + net_variant + '.npz')['epoch_times']
  times_mean = np.mean(epoch_times)
  times_lci, times_uci = st.t.interval(0.95, len(epoch_times)-1, loc=np.mean(epoch_times), scale=st.sem(epoch_times))
  times_std = np.std(epoch_times)

  print('------------------------------------------------')
  print('Net variant: {}'.format(net_variant))
  print('Average Precision: {} [{},{}]'.format(round(avpre_mean), round(avpre_lci), round(avpre_uci)))
  print('AUROC: {} [{},{}]'.format(round(auroc_mean), round(auroc_lci), round(auroc_uci)))
  print('F1: {} [{},{}]'.format(round(f1_mean), round(f1_lci), round(f1_uci)))
  print('PPV: {} [{},{}]'.format(round(ppv_mean), round(ppv_lci), round(ppv_uci)))
  print('NPV: {} [{},{}]'.format(round(npv_mean), round(npv_lci), round(npv_uci)))
  print('Sensitivity: {} [{},{}]'.format(round(sensitivity_mean), round(sensitivity_lci), round(sensitivity_uci)))
  print('Specificity: {} [{},{}]'.format(round(specificity_mean), round(specificity_lci), round(specificity_uci)))
  print('Time: {} [{},{}] std: {}'.format(round(times_mean), round(times_lci), round(times_uci), round(times_std)))
  print('Done')
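
The evaluation cell above reports sensitivity and specificity at the ROC point that maximizes Youden's J statistic (`tpr - fpr`). The following pure-Python sketch shows that selection on made-up labels and scores; `roc_points` and `youden_point` are hypothetical helpers, not part of the notebook, and they do not reproduce every detail of scikit-learn's `roc_curve`.

```python
def roc_points(labels, scores):
    """Return (threshold, tpr, fpr) for each distinct score, from high to low."""
    pos = sum(labels)
    neg = len(labels) - pos
    points = []
    for t in sorted(set(scores), reverse=True):
        tp = sum(1 for y, s in zip(labels, scores) if s >= t and y == 1)
        fp = sum(1 for y, s in zip(labels, scores) if s >= t and y == 0)
        points.append((t, tp / pos, fp / neg))
    return points

def youden_point(labels, scores):
    """Pick the threshold that maximizes J = tpr - fpr."""
    return max(roc_points(labels, scores), key=lambda p: p[1] - p[2])

# Toy data (made up): 4 readmitted and 4 non-readmitted stays.
labels = [0, 0, 1, 0, 1, 1, 0, 1]
scores = [0.1, 0.4, 0.35, 0.2, 0.8, 0.7, 0.5, 0.6]

threshold, sensitivity, fpr = youden_point(labels, scores)
specificity = 1 - fpr
print(threshold, sensitivity, specificity)
```

Because the threshold is chosen on the same data it is evaluated on, the reported sensitivity and specificity describe the best achievable operating point rather than a pre-specified one.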

In [None]:
from sklearn import linear_model
def round(num):
  return np.round(num*1000)/1000

if True:
  # Load icu_pat table
  print('Loading data...')
  icu_pat = pd.read_pickle(data_dir + 'icu_pat_admit.pkl')

  print('Loading last vital signs measurements...')
  charts = pd.read_pickle(data_dir + 'charts_outputs_last_only.pkl')
  charts = charts.drop(columns=['CHARTTIME'])
  charts = pd.get_dummies(charts, columns = ['VALUECAT']).groupby('ICUSTAY_ID').sum()
  charts.drop(columns=['VALUECAT_CHART_BP_n', 'VALUECAT_CHART_BT_n', 'VALUECAT_CHART_GC_n', 'VALUECAT_CHART_HR_n', 'VALUECAT_CHART_RR_n', 'VALUECAT_CHART_UO_n'], inplace=True) # drop reference columns

  print('-----------------------------------------')

  print('Create array of static variables...')

  num_icu_stays = len(icu_pat['ICUSTAY_ID'])

  # static variables
  print('Create static array...')
  icu_pat = pd.get_dummies(icu_pat, columns = ['ADMISSION_LOCATION', 'INSURANCE', 'MARITAL_STATUS', 'ETHNICITY'])
  icu_pat.drop(columns=['ADMISSION_LOCATION_Emergency Room Admit', 'INSURANCE_Medicare', 'MARITAL_STATUS_Married/Life Partner', 'ETHNICITY_White'], inplace=True) # drop reference columns

  # merge with last vital signs measurements
  icu_pat = pd.merge(icu_pat, charts, how='left', on='ICUSTAY_ID').fillna(0)

  static_columns = icu_pat.columns.str.contains('AGE|GENDER_M|LOS|NUM_RECENT_ADMISSIONS|ADMISSION_LOCATION|INSURANCE|MARITAL_STATUS|ETHNICITY|PRE_ICU_LOS|ELECTIVE_SURGERY|VALUECAT')
  static = icu_pat.loc[:, static_columns].values
  static_vars = icu_pat.loc[:, static_columns].columns.values.tolist()

  # classification label
  print('Create label array...')
  label = icu_pat.loc[:, 'POSITIVE'].values

  print('-----------------------------------------')

  print('Split data into train/validate/test...')
  # Split patients to avoid data leaks
  patients = icu_pat['SUBJECT_ID'].drop_duplicates()
  train, validate, test = np.split(patients.sample(frac=1, random_state=123), [int(.9*len(patients)), int(.9*len(patients))])
  train_ids = icu_pat['SUBJECT_ID'].isin(train).values
  test_ids = icu_pat['SUBJECT_ID'].isin(test).values

  data_train = static[train_ids, :]
  data_test = static[test_ids, :]

  label_train = label[train_ids]
  label_test = label[test_ids]

  # Patients in test data
  test_ids_patients = pd.read_pickle(data_dir + 'test_ids_patients.pkl')
  patients = test_ids_patients.drop_duplicates()
  num_patients = patients.shape[0]
  row_ids = pd.DataFrame({'ROW_IDX': test_ids_patients.index}, index=test_ids_patients)

  print('-----------------------------------------')

  # Fit logistic regression model
  print('Fit logistic regression model...')
  regr = linear_model.LogisticRegression()
  regr.fit(data_train, label_train)

  # Bootstrapping
  np.random.seed(np_seed)
  avpre_vec = np.zeros(bootstrap_samples)
  auroc_vec = np.zeros(bootstrap_samples)
  f1_vec    = np.zeros(bootstrap_samples)
  sensitivity_vec = np.zeros(bootstrap_samples)
  specificity_vec = np.zeros(bootstrap_samples)
  ppv_vec = np.zeros(bootstrap_samples)
  npv_vec = np.zeros(bootstrap_samples)

  for sample in range(bootstrap_samples):
    print('Bootstrap sample {}'.format(sample))

    sample_patients = patients.sample(n=num_patients, replace=True)
    idx = np.squeeze(row_ids.loc[sample_patients].values)
    data_test_bs, label_test_bs = data_test[idx], label_test[idx]

    label_sigmoids = regr.predict_proba(data_test_bs)[:, 1]

    print('Evaluate...')
    # Average precision
    avpre = average_precision_score(label_test_bs, label_sigmoids)

    # Determine AUROC score
    auroc = roc_auc_score(label_test_bs, label_sigmoids)

    # Sensitivity, specificity
    fpr, tpr, thresholds = roc_curve(label_test_bs, label_sigmoids)
    youden_idx = np.argmax(tpr - fpr)
    sensitivity = tpr[youden_idx]
    specificity = 1-fpr[youden_idx]

    # F1, PPV, NPV score
    f1 = 0
    ppv = 0
    npv = 0
    for t in thresholds:
      label_pred = (np.array(label_sigmoids) >= t).astype(int)
      f1_temp = f1_score(label_test_bs, label_pred)
      ppv_temp = precision_score(label_test_bs, label_pred, pos_label=1)
      npv_temp = precision_score(label_test_bs, label_pred, pos_label=0)
      if f1_temp > f1:
        f1 = f1_temp
      if (ppv_temp+npv_temp) > (ppv+npv):
        ppv = ppv_temp
        npv = npv_temp

    # Store in vectors
    avpre_vec[sample] = avpre
    auroc_vec[sample] = auroc
    f1_vec[sample]    = f1
    sensitivity_vec[sample]  = sensitivity
    specificity_vec[sample]  = specificity
    ppv_vec[sample]  = ppv
    npv_vec[sample]  = npv

  avpre_mean = np.mean(avpre_vec)
  avpre_lci, avpre_uci = st.t.interval(0.95, bootstrap_samples-1, loc=avpre_mean, scale=st.sem(avpre_vec))
  auroc_mean = np.mean(auroc_vec)
  auroc_lci, auroc_uci = st.t.interval(0.95, bootstrap_samples-1, loc=auroc_mean, scale=st.sem(auroc_vec))
  f1_mean = np.mean(f1_vec)
  f1_lci, f1_uci = st.t.interval(0.95, bootstrap_samples-1, loc=f1_mean, scale=st.sem(f1_vec))
  ppv_mean = np.mean(ppv_vec)
  ppv_lci, ppv_uci = st.t.interval(0.95, bootstrap_samples-1, loc=ppv_mean, scale=st.sem(ppv_vec))
  npv_mean = np.mean(npv_vec)
  npv_lci, npv_uci = st.t.interval(0.95, bootstrap_samples-1, loc=npv_mean, scale=st.sem(npv_vec))
  sensitivity_mean = np.mean(sensitivity_vec)
  sensitivity_lci, sensitivity_uci = st.t.interval(0.95, bootstrap_samples-1, loc=sensitivity_mean, scale=st.sem(sensitivity_vec))
  specificity_mean = np.mean(specificity_vec)
  specificity_lci, specificity_uci = st.t.interval(0.95, bootstrap_samples-1, loc=specificity_mean, scale=st.sem(specificity_vec))

  print('------------------------------------------------')
  print('Net variant: logistic regression')
  print('Average Precision: {} [{},{}]'.format(round(avpre_mean), round(avpre_lci), round(avpre_uci)))
  print('AUROC: {} [{},{}]'.format(round(auroc_mean), round(auroc_lci), round(auroc_uci)))
  print('F1: {} [{},{}]'.format(round(f1_mean), round(f1_lci), round(f1_uci)))
  print('PPV: {} [{},{}]'.format(round(ppv_mean), round(ppv_lci), round(ppv_uci)))
  print('NPV: {} [{},{}]'.format(round(npv_mean), round(npv_lci), round(npv_uci)))
  print('Sensitivity: {} [{},{}]'.format(round(sensitivity_mean), round(sensitivity_lci), round(sensitivity_uci)))
  print('Specificity: {} [{},{}]'.format(round(specificity_mean), round(specificity_lci), round(specificity_uci)))
  print('Done')
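
Both evaluation cells above resample patients rather than individual ICU stays (`patients.sample(n=num_patients, replace=True)`), so all stays of a drawn patient enter a bootstrap sample together. A minimal standard-library sketch of that idea follows; `patient_bootstrap_rows` and the toy IDs are hypothetical and only illustrate the resampling step, not the metric computation.

```python
import random
from collections import defaultdict

def patient_bootstrap_rows(patient_ids, rng):
    """Resample patients with replacement; return row indices of all their stays."""
    rows_by_patient = defaultdict(list)
    for row, pid in enumerate(patient_ids):
        rows_by_patient[pid].append(row)
    patients = sorted(rows_by_patient)
    # Draw as many patients as there are unique patients, with replacement.
    sampled = [rng.choice(patients) for _ in patients]
    return [row for pid in sampled for row in rows_by_patient[pid]]

# Toy data (made up): patient 'a' has two ICU stays, 'b' and 'c' one each.
patient_ids = ['a', 'a', 'b', 'c']
rng = random.Random(0)
rows = patient_bootstrap_rows(patient_ids, rng)
print(rows)
```

Keeping a patient's stays together avoids treating correlated stays of the same patient as independent observations.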

# Results
In this project, I attempted to replicate the results of the original paper and to test its two main hypotheses.

1. Evaluate the feasibility of using neural ODEs to model how the predictive relevance of recorded medical codes changes over time. I will implement and run ODE architecture models to determine the power of ODEs with time data.

  To this end, I implemented neural ODEs to model how medical code embeddings change over time. These neural ODEs were expected to give better insight into, and more appropriate weight to, variables such as length of stay and change in vital signs over time. From the table below, we can conclude that the inclusion of the ODE architecture led to a sizable increase in AUROC: the three ODE models without decay reached 0.693 to 0.707, compared with 0.658 for the logistic regression used as a baseline. However, the use of decay gates generally caused the models to perform worse. This is true for both the ODE models (birnn_ode_decay, AUROC 0.640) and the non-ODE models (birnn_time_decay, AUROC 0.516). In conclusion, neural ODEs have merit in predictive model architectures but should be used in isolation from decay gates.

2. Perform a comprehensive comparison of deep learning models that have been proposed for processing time series sampled at irregular intervals, including MCEs, neural ODEs, attention mechanisms, and recurrent layers. I will implement various types of architectures and compare them against each other using metrics such as precision, AUROC, and F1 score.

  I have implemented the 14 models outlined in the paper, including the baseline logistic regression. They have been compared against each other using precision, AUROC, and F1 score. Overall, the various architectures produced broadly similar results.
  
  The time-decay functionality incorporated time-related information by applying an exponential decay, driven by the time differences between observations, to the internal memory state of the recurrent cell. This approach to representing time caused most models that contained it to perform markedly worse than their counterparts. This can be clearly seen in the results of the birnn_time_decay model.
  
  As mentioned above, the ODE architectures performed best overall. Modeling the time dynamics in the memory state of the recurrent cells using ODEs proved to be far more effective than decay for the purpose of representing time in RNN models.

  MCEs (medical concept embeddings) provided marginal gains in performance. The pretrained MCEs provided by the original paper helped the two MCE models in the table (AUROC 0.669 and 0.673) outperform the logistic regression baseline (0.658) and most other models; only the ODE models without decay and birnn_time_decay_attention achieved a higher AUROC.

  Attention mechanisms, which allow these models to focus on specific inputs, also tended to help, though not uniformly: adding attention lowered all three metrics for birnn_concat_time_delta and gave mixed results for ode_birnn and birnn_ode_decay. The benefit is clearest for the birnn_time_decay models, where precision rose from 0.136 to 0.295.
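
As a rough illustration of the time-decay idea discussed above, the sketch below shrinks a memory vector by `exp(-delta_t / tau)` according to the time elapsed since the previous observation. This is a simplified stand-in, not the notebook's decay cell: the function name `decay_memory`, the time constant `tau`, and the toy values are all assumptions made for illustration.

```python
import math

def decay_memory(hidden, delta_t, tau=1.0):
    """Scale each memory component by exp(-delta_t / tau).

    `tau` is an assumed time constant in the same units as `delta_t`.
    """
    factor = math.exp(-delta_t / tau)
    return [h * factor for h in hidden]

hidden = [1.0, -0.5, 0.25]
recent = decay_memory(hidden, delta_t=0.0)   # no time elapsed: memory unchanged
stale = decay_memory(hidden, delta_t=10.0)   # long gap: memory mostly forgotten
print(recent)
print(stale)
```

With a fixed `tau`, every code is forgotten at the same rate regardless of its clinical meaning, which is one possible reason a learned time dynamic such as an ODE can behave differently.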

  


##### Table 1

 Summary statistics for the different algorithms used to predict readmission within 30 days of discharge from the intensive care unit.

Model | Precision | AUROC | F1
-------------------|------------------|-----------------|-------------
birnn_concat_time_delta | 0.277 | 0.652 | 0.297
birnn_concat_time_delta_attention | 0.262 | 0.618 | 0.282
birnn_time_decay | 0.136 | 0.516 | 0.227
birnn_time_decay_attention | 0.295 | 0.691 | 0.323
ode_birnn | 0.293 | 0.693 | 0.328
ode_birnn_attention | 0.289 | 0.698 | 0.336
ode_attention | 0.305 | 0.707 | 0.338
attention_concat_time | 0.258 | 0.625 | 0.285
birnn_ode_decay | 0.241 | 0.640 | 0.284
birnn_ode_decay_attention | 0.263 | 0.633 | 0.288
mce_attention | 0.276 | 0.669 | 0.306
mce_birnn_attention | 0.272 | 0.673 | 0.314
logistic regression | 0.265 | 0.658 | 0.296
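
The comparisons in this section can be checked directly against Table 1. The sketch below copies the table values into a dictionary (the variable names are my own), ranks the models by AUROC, and computes the AUROC change from adding attention for every model that has an `_attention` counterpart.

```python
# (precision, AUROC, F1) copied from Table 1
table1 = {
    "birnn_concat_time_delta": (0.277, 0.652, 0.297),
    "birnn_concat_time_delta_attention": (0.262, 0.618, 0.282),
    "birnn_time_decay": (0.136, 0.516, 0.227),
    "birnn_time_decay_attention": (0.295, 0.691, 0.323),
    "ode_birnn": (0.293, 0.693, 0.328),
    "ode_birnn_attention": (0.289, 0.698, 0.336),
    "ode_attention": (0.305, 0.707, 0.338),
    "attention_concat_time": (0.258, 0.625, 0.285),
    "birnn_ode_decay": (0.241, 0.640, 0.284),
    "birnn_ode_decay_attention": (0.263, 0.633, 0.288),
    "mce_attention": (0.276, 0.669, 0.306),
    "mce_birnn_attention": (0.272, 0.673, 0.314),
    "logistic_regression": (0.265, 0.658, 0.296),
}

# Rank models from best to worst AUROC.
ranked = sorted(table1, key=lambda name: table1[name][1], reverse=True)
for name in ranked:
    print(f"{name:35s} AUROC={table1[name][1]:.3f}")

# AUROC change from adding attention, for each model with an "_attention" twin.
attention_gain = {
    name: round(table1[name + "_attention"][1] - table1[name][1], 3)
    for name in table1
    if name + "_attention" in table1
}
print(attention_gain)
```

Only four architectures appear both with and without attention, so the attention comparison rests on a small number of pairs.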

## Model comparison

In [None]:
# compare your model with others
# you don't need to re-run all other experiments; instead, you can refer directly to the metrics/numbers in the paper

# Discussion

Assess whether the paper is reproducible or not.

* Explain why it is not reproducible if your results are kind of negative.
* Describe “What was easy” and “What was difficult” during the reproduction.
* Make suggestions to the author or other reproducers on how to improve the reproducibility.

-------------------------------------------------------------------------------

The paper is reproducible, though doing so within the training-time constraint is difficult because it requires training multiple models. In this demo, I used a small subset of the data to keep the training time reasonable. In addition, I elected to upload preprocessed data from the original paper's codebase because processing the data is (1) too time-consuming and (2) not the main focus of the paper. I believe that further training with hyperparameter tuning could lead to better performance than what is shown here and in the paper. However, some of these architectures are complex enough that the amount of computation exceeds what can reasonably be replicated on the base version of Colab.

The easy portions of the reproduction were the MCEs, as they were provided as pretrained embeddings. In addition, the attention mechanisms were familiar and mostly straightforward to implement.

The difficult portions were the neural ODEs and the decay functionality. The original dataset had to be slightly modified to create data that the de novo ODE and decay cell architectures could accept. This led to awkward data restructuring that caused the largest share of delays in this project. The ODE decay combination cell was particularly difficult to understand: having the decay of the hidden state computed by an ODE did not feel intuitive.

To the authors, I would recommend limiting the scope of the architectures for the sake of reproducibility. Understanding the time-decay models was by far the most difficult part. For this notebook, I recommend that the graders run the model that I have left uncommented for validation.

# References

1. Barbieri, S., Kemp, J., Perez-Concha, O. et al. Benchmarking Deep Learning Architectures for Predicting Readmission to the ICU and Describing Patients-at-Risk. Sci Rep 10, 1111 (2020). https://doi.org/10.1038/s41598-020-58053-z
2. Cai, X. et al. Medical concept embedding with time-aware attention. arXiv preprint arXiv:1806.02873 (2018).
3. Rubanova, Y., Chen, R. T. & Duvenaud, D. Latent ODEs for irregularly-sampled time series. arXiv preprint arXiv:1907.03907 (2019).
4. Mozer, M. C., Kazakov, D. & Lindsey, R. V. Discrete event, continuous time RNNs. arXiv preprint arXiv:1710.04110 (2017).


