# Deep Learning For Healthcare Course Project: INPREM

https://www.kdd.org/kdd2020/accepted-papers/view/inprem-an-interpretable-and-trustworthy-predictive-model-for-healthcare

# 1. Setup

## 1.1 Import Libraries

We will need the [sparsemax](https://pypi.org/project/sparsemax/) library later.

In [1]:
!pip3 install -U sparsemax
!pip3 install -U psutil

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [2]:
import os
import psutil
import pickle
import json
import ast
import random
import math

import numpy as np
import pandas as pd

from icd9 import ICD9

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from sparsemax import Sparsemax

In [4]:
# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# validate GPU usage
print("using GPU") if torch.cuda.is_available() else print("no GPU found")

# define data path
use_demo = True
if use_demo:
    DATA_PATH = "demodata/" # work with open source data
    print("using demo data")
else:
    DATA_PATH = "data/" # work with certified patient data
    print("using patient data")

# print metrics for each epoch (CPU, RAM, VRAM data) - will add time to each epoch
print_metrics = False

# icd9 tree structure
tree = ICD9('codes.json')
# tree.find('001.1').parent.parent.code

no GPU found
using demo data


In [5]:
!ls {DATA_PATH}

ADMISSIONS.csv	   D_ICD_DIAGNOSES.csv	rtypes.pkl  types.pkl
DIAGNOSES_ICD.csv  ICUSTAYS.csv		rtypes.txt  types.txt


## 1.2 Import Raw Data

First, we load the MIMIC-III dataset. We use 4 files from the dataset, `DIAGNOSES_ICD`, `D_ICD_DIAGNOSES`, `ICUSTAYS`, and `ADMISSIONS`. We initially format the data and drop the columns we don't need.

For example, `subject_id` refers to a unique patient, `hadm_id` refers to a unique admission to the hospital, and `icustay_id` refers to a unique admission to an intensive care unit. We also keep the `out time` for each ICU stay.

In [6]:
def load_dataset(filepath):
    return pd.read_csv(filepath)

def convert_datetime_to_day(df):
    temp = pd.DataFrame()
    temp["date"] = pd.to_datetime(df['outtime'], format="%Y-%m-%d %H:%M:%S")
    return str(temp['date'].dt.year) + str(temp['date'].dt.month) + str(temp['date'].dt.day)

diag_icd = load_dataset(os.path.join(DATA_PATH, 'DIAGNOSES_ICD.csv'))
icd_descriptions = load_dataset(os.path.join(DATA_PATH, 'D_ICD_DIAGNOSES.csv'))
icustays = load_dataset(os.path.join(DATA_PATH, 'ICUSTAYS.csv'))
admissions = load_dataset(os.path.join(DATA_PATH, 'ADMISSIONS.csv'))

diag_icd = diag_icd.rename(columns={"hadm_id".upper(): "hadm_id", "icd9_code".upper(): "icd9_code"})
icustays = icustays.rename(columns={"subject_id".upper(): "subject_id", "hadm_id".upper(): "hadm_id", "icustay_id".upper(): "icustay_id", "outtime".upper(): "outtime"})

diag_icd = diag_icd[["hadm_id", "icd9_code"]]
icustays = icustays[["subject_id", "hadm_id", "icustay_id", "outtime"]]


print(f"diag_icd ({len(diag_icd)} lines):\n", diag_icd.head(), end="\n\n")
print(f"icustays ({len(icustays)} lines):\n", icustays.head(), end="\n\n")
# print(f"admissions ({admissions.size} lines):\n", admissions.head(), end="\n\n")

diag_icd (1761 lines):
    hadm_id icd9_code
0   142345     99591
1   142345     99662
2   142345      5672
3   142345     40391
4   142345     42731

icustays (136 lines):
    subject_id  hadm_id  icustay_id              outtime
0       10006   142345      206504  2164-10-25 12:21:07
1       10011   105331      232110  2126-08-28 18:59:00
2       10013   165520      264446  2125-10-07 15:13:52
3       10017   199207      204881  2149-05-31 22:19:17
4       10019   177759      228977  2163-05-16 03:47:04



We merge the datasets:

In [7]:
joined_df = pd.merge(diag_icd, icustays, how='inner', on='hadm_id')[["icd9_code", "subject_id", "icustay_id", "outtime"]]
print(f"joined_df ({len(joined_df)} lines):\n", joined_df.head())

joined_df (1897 lines):
   icd9_code  subject_id  icustay_id              outtime
0     99591       10006      206504  2164-10-25 12:21:07
1     99662       10006      206504  2164-10-25 12:21:07
2      5672       10006      206504  2164-10-25 12:21:07
3     40391       10006      206504  2164-10-25 12:21:07
4     42731       10006      206504  2164-10-25 12:21:07


Clean up ICD-9 codes and convert to category labels for `y`, append all of our converted codes to `X` and `y`, and finally build a dictionary to convert types to codes.

In [8]:
def convert_codes(codes):
    """Clean up ICD-9 codes and convert to category labels for `y`.
    
    If the codes contain 3, 4, or 5 digits, always use left 3. 
    If it starts with E, use (Exxx). If it starts with V, use (Vxx).
    Also appends to a unique set of data to check number of unique codes.
    
    Inputs:
        codes: a set of ICD-9 codes
    
    Outputs:
        out: an array of category codes
    """
    out = []
    for code in codes:
        code = str(code)
        if code[0] == "E":
            c = code[:4]
        else:
            c = code[:3]
        out.append(c)
        
        unique_codes.add(c)
       
    return out


def build_dictionaries(codes):
    """Construct dicts to map/ reverse map string input codes to keys, e.g. {'001': 0, '002': 1}
    
    Inputs:
        codes: a set of unique category codes
    
    Outputs:
        types: a dictionary to map codes to types
        rtypes: a dictionary to map types to codes
    """
    types = dict((diag, idx) for idx, diag in enumerate(codes))
    rtypes = dict((idx, diag) for idx, diag in enumerate(codes))
    
    return types, rtypes


X = []
y = []
all_codes = []
unique_codes = set()

# set up X and y
for name, patient in joined_df.sort_values("outtime").groupby(["subject_id"]):
    visits = []
    for _, visit in patient.groupby(["icustay_id"]):
        codes = visit["icd9_code"].tolist()
        codes = convert_codes(codes)
        visits.append(codes)
    if len(visits) >= 2:
        x, y_ = visits[:-1], visits[-1]
        X.append(x)
        y.append(y_)

types, rtypes = build_dictionaries(list(unique_codes))
print(f"Using {len(X)} patients")

# check mapping
# print(unique_codes)
# print('diag mapping for DIAG_V10:', types['V10']) # 75
# print('reverse mapping for index 75:', rtypes[75])

Using 19 patients


In [9]:
assert(len(X) == len(y))

In [10]:
# print("visits (x):", X, "\n")
# print("last visit (y):", y)

print(len(unique_codes), "codes") 
print(len(types), "types mappings") 

275 codes
275 types mappings


In [11]:
# ICD 9 Codes for Binary Classification
diabetes = ("Diabetes", "250.xx")
heart_failure = ("Heary Failure", "428.xx")
chronic_kidney_disease = ("Chronic Kidney Disease", "585.xx")

## 2. Build the Dataset

## 2.1 Build Custom Dataset

We implement a custom dataset using PyTorch class `Dataset`, which will characterize the key features of the dataset we want to generate.

We will use the sequences of diagnosis codes seqs up to the last visit as input and the diagnosis code of last visit as output.

In [12]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, x, y):
        self.x = x
        self.y = y
    
    def __len__(self):
        '''Return the number of samples.
        '''
        
        return len(self.y)
    
    def __getitem__(self, index):
        '''Generates one sample of data.
        '''
        
        return (self.x[index], self.y[index])
        
        
dataset = CustomDataset(X, y)

## 2.2 Collate Function

This collate function `collate_fn()` will be called by `DataLoader` after fetching a list of samples using the indices from `CustomDataset` to collate the list of samples into batches.

In [70]:
def collate_fn_retain(data):
    """
    Collate the the list of samples into batches. For each patient, pad the diagnosis sequences 
    to the sample shape. The padding infomation is stored in `mask`.

    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patients) of type torch.float
    """
    sequences, labels = zip(*data)
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    y = torch.zeros((len(labels), len(types)), dtype=torch.bool)
    
    for i, label in enumerate(labels):
        # create one-hot vector
        for l in label:
            plc = types[l]
            y[i][plc] = True
    
    # y = y[y.any(axis=1)]
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):
        count = 0
        for j_visit, visit in enumerate(patient):
            """
            TODO: update `x`, `rev_x`, `masks`, and `rev_masks`
            """
            visit_len = len(visit)
            
            typed_visit = visit.copy()
            for idx, code in enumerate(visit): # convert to mapping
                typed_visit[idx] = types[visit[idx]]
            
            x[i_patient][j_visit][:visit_len] = torch.tensor(typed_visit, dtype=torch.long)
            masks[i_patient][j_visit][:visit_len] = torch.ones((visit_len),dtype=torch.bool)
            count += 1
            
        reverse_x = x[i_patient][:count]
        reverse_mask = masks[i_patient][:count]
        
        rev_x[i_patient][:count] = torch.flip(reverse_x, [0])
        rev_masks[i_patient][:count] = torch.flip(reverse_mask, [0])
        
    return x, masks, rev_x, rev_masks, y

In [83]:
num_visits = [len(patient) for patient in X]

max_num_visits = max(num_visits)
max_num_codes = len(types)

def collate_fn_inprem(data):
    """
    Collate the the list of samples into batches. For each patient, pad the diagnosis sequences 
    to the sample shape. 
        
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.long
        o:a tensor of shape (# patients, max # visits) of type torch.long
        y: a tensor of shape (# patients) of type torch.float
    """
    sequences, labels = zip(*data)
    num_patients = len(sequences)

    num_types = len(types)
    max_num_codes = num_types
    
    y = torch.zeros((len(labels), num_types), dtype=torch.bool)
    
    for i, label in enumerate(labels):
        # create one-hot vector
        for l in label:
            plc = types[l]
            y[i][plc] = True
        print(y[i])
        
    # y = y[y.any(axis=1)]
    
    x = torch.zeros((num_patients, max_num_visits, num_types), dtype=torch.long)
    o = torch.zeros((num_patients, max_num_visits), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, num_types), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):
        count = 1
        for j_visit, visit in enumerate(patient):
            for code in visit : # convert to mapping
                plc = types[code]
                x[i_patient][j_visit][plc] = True
            o[i_patient][j_visit] = count
            count += 1
    
    return x, o, y


print(max_num_codes, max_num_visits)

275 14


## 2.3 Split Dataset

For each task, we randomly split each dataset into training and validation sets in an 80:20 ratio. This differs from the paper's implementation which uses a 75:10:15 training, testing, and validation split five times. We ran into issues evaluating the model against a test set and chose to return back to the 80:20 training/ validation split we had learned in previous implementations like RETAIN.

In [84]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.75)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 14
Length of val dataset: 5


## 2.4 Dataloader

In [85]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    '''    
    Return the data loader for  train and validation dataset
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    '''
    
    batch_size = 32
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn_inprem, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn_inprem, shuffle=False)

    return train_loader, val_loader


train_loader, val_loader= load_data(train_dataset, val_dataset, collate_fn_inprem)

# 3. Build the Models

We treat the medical events taking place in EHR as medical codes, which are denoted as  𝑐1,𝑐2,...𝑐|𝐶|  ∈ 𝐶, where |𝐶| is the total number of unique medical codes.

One specific patient consist of a sequence of visits  𝑣1,𝑣2,...𝑣𝑇  where we denote the number of visits in total as T.

## Retain (Baseline)

First we implement the baseline RETAIN model used in the paper to compare against INPREM.

In [86]:
def load_data_retain(train_dataset, val_dataset, collate_fn):
    
    batch_size = 32
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)
    
    return train_loader, val_loader

train_loader_retain, val_loader_retain = load_data_retain(train_dataset, val_dataset, collate_fn_retain)

In [87]:
def attention_sum(alpha, beta, rev_v, rev_masks):
    """
    TODO: mask select the hidden states for true visits (not padding visits) and then
        sum the them up.

    Arguments:
        alpha: the alpha attention weights of shape (batch_size, seq_length, 1)
        beta: the beta attention weights of shape (batch_size, seq_length, hidden_dim)
        rev_v: the visit embeddings in reversed time of shape (batch_size, # visits, embedding_dim)
        rev_masks: the padding masks in reversed time of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        c: the context vector of shape (batch_size, hidden_dim)
        
    NOTE: Do NOT use for loop.
    """
    
    out = None
    
    rev_masks = rev_masks.max(axis=2).values
    rev_masks = rev_masks.unsqueeze(-1)
    
    v = rev_masks * rev_v
    
    out = torch.sum(alpha * beta * v, 1)
    
    return out

def sum_embeddings_with_mask(x, masks):
    """
    Mask select the embeddings for true visits (not padding visits) and then sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
    """
    
    x = x * masks.unsqueeze(-1)
    x = torch.sum(x, dim = -2)
    return x

In [88]:
# alpha attention
# beta attention
# attention sum
# sum embeddings with mask

class AlphaAttentionRetain(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        """
        Define the linear layer `self.a_att` for alpha-attention using `nn.Linear()`;
        
        Arguments:
            hidden_dim: the hidden dimension
        """
        
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        """
        TODO: Implement the alpha attention.
        
        Arguments:
            g: the output tensor from RNN-alpha of shape (batch_size, seq_length, hidden_dim) 
        
        Outputs:
            alpha: the corresponding attention weights of shape (batch_size, seq_length, 1)
            
        HINT: consider `torch.softmax`
        """
        
        
        y = self.a_att(g)
        y = torch.softmax(y, 2)
        
        return y
    
class BetaAttentionRetain(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        """
        Define the linear layer `self.b_att` for beta-attention using `nn.Linear()`;
        
        Arguments:
            hidden_dim: the hidden dimension
        """
        
        self.b_att = nn.Linear(hidden_dim, hidden_dim)


    def forward(self, h):
        """
        TODO: Implement the beta attention.
        
        Arguments:
            h: the output tensor from RNN-beta of shape (batch_size, seq_length, hidden_dim) 
        
        Outputs:
            beta: the corresponding attention weights of shape (batch_size, seq_length, hidden_dim)
            
        HINT: consider `torch.tanh`
        """
        
        y = self.b_att(h)
        y = torch.tanh(y)
        
        return y
    

class RETAIN(nn.Module):
    
    def __init__(self, num_codes, embedding_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(num_codes, embedding_dim)

        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)

        self.att_a = AlphaAttentionRetain(embedding_dim)
        self.att_b = BetaAttentionRetain(embedding_dim)

        self.fc = nn.Linear(embedding_dim, num_codes)
        self.sigmoid = nn.Sigmoid()
        
    
    def forward(self, x, masks, rev_x, rev_masks):
        """
        Arguments:
            rev_x: the diagnosis sequence in reversed time of shape (# visits, batch_size, # diagnosis codes)
            rev_masks: the padding masks in reversed time of shape (# visits, batch_size, # diagnosis codes)

        Outputs:
            probs: probabilities of shape (batch_size)
        """
        rev_x = self.embedding(rev_x)
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)

        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)

        alpha = self.att_a(g)
        beta = self.att_b(h)

        c = attention_sum(alpha, beta, rev_x, rev_masks)

        logits = self.fc(c)
        probs = self.sigmoid(logits)
        
        return probs.squeeze()
    
    
# load the model
retain = RETAIN(num_codes = len(types))
retain

RETAIN(
  (embedding): Embedding(275, 128)
  (rnn_a): GRU(128, 128, batch_first=True)
  (rnn_b): GRU(128, 128, batch_first=True)
  (att_a): AlphaAttentionRetain(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (att_b): BetaAttentionRetain(
    (b_att): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc): Linear(in_features=128, out_features=275, bias=True)
  (sigmoid): Sigmoid()
)

In [89]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval_retain(model, val_loader):
    
    """
    Evaluate the model.
    
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
    """
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    model.eval()
    for x, masks, rev_x, rev_masks, y in val_loader:
        y_logit = model(x, masks, rev_x, rev_masks)
        y_hat = (y_logit > 0.5).int()
        y_score = torch.cat((y_score,  y_logit.detach().to('cpu')), dim=0)
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred)
    # roc_auc = roc_auc_score(y_true, y_score)
    roc_auc = 0.0
    
    return p, r, f, roc_auc


def train_retain(model, train_loader, val_loader, n_epochs):
    """
    Train the model.
    
    Arguments:
        model: the RNN model
        train_loader: training dataloder
        val_loader: validation dataloader
        n_epochs: total number of epochs
    """
    
    for epoch in range(n_epochs):
        model.train()
        
        train_loss = 0
        for x, masks, rev_x, rev_masks, y in train_loader:
            optimizer.zero_grad()
            y_hat = model(x, masks, rev_x, rev_masks)
            y = y.type(torch.FloatTensor)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()  
            
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f, roc_auc = eval_retain(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'.format(epoch+1, p, r, f, roc_auc))
    
    return round(roc_auc, 2)


# load the model
retain = RETAIN(num_codes = len(types))

# load the loss function
criterion = nn.BCELoss()
# load the optimizer
optimizer = torch.optim.Adam(retain.parameters(), lr=1e-3)

n_epochs = 5
train_retain(retain, train_loader_retain, val_loader_retain, n_epochs)

Epoch: 1 	 Training Loss: 0.905809


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TypeError: unsupported format string passed to numpy.ndarray.__format__

## INPREM

Each visit contains a subset of medical codes, and we denote each visit as a binary vector  $v_{t} ∈ \{0, 1\}_{|C|}$, where the 𝑖-th element is set to 1 if the 𝑡-th visit contains the medical code $c_{i}$, otherwise 0. The visits  $v_{1}, v_{2},... v_{T}$ are stacked to form an input matrix $X ∈ \{0, 1\}^{|C|xT}$ , which we use as the input for the network

$E_{v} = {W}_{v}X$

$E_{o} = {W}_{o}O$

$E_{r} = \alpha(\beta \odot (E_{v}+E_{o}))^{T}$

In [90]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim=256):
        super().__init__()
        
        self.lin = nn.Linear(hidden_dim, 1)
        
        self.sparsemax = Sparsemax()
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, g):
        
        y = self.lin(g)
        sparse_max = self.sparsemax(y)
        soft_max = F.softmax(y, dim=1)
        
        out = (sparse_max + soft_max) / 2
        
        return out
    
class BetaAttention(torch.nn.Module):
    
    def __init__(self, hidden_dim=256):
        super().__init__()
        
        self.lin = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, h):
        
        y = self.lin(h)
        out = torch.tanh(y)
        
        return out
    
class MultiHeadAttention(torch.nn.Module):
    
    def __init__(self, num_codes, hidden_dim=256, num_attention_layers=2):
        super().__init__()
        
        self.fc_q = nn.Linear(hidden_dim, hidden_dim)
        self.fc_v = nn.Linear(hidden_dim, hidden_dim)
        self.fc_k = nn.Linear(hidden_dim, hidden_dim)
        
        self.fc_multi = nn.Linear(num_attention_layers * hidden_dim, hidden_dim)
        
        self.conv1 = torch.nn.Conv1d(hidden_dim, hidden_dim, 1)
        
        self.conv2 = torch.nn.Conv1d(hidden_dim, hidden_dim, 1)
        
        self.hidden_dim = hidden_dim
        
        self.relu = nn.ReLU()

    def forward(self, E):
        
        Q = self.fc_q(E)
        K = self.fc_k(E)
        V = self.fc_v(E)
        
        attn = V * F.softmax((Q * K) / math.sqrt(self.hidden_dim), dim=1)
        
        x = self.fc_multi(torch.concat((attn, attn), dim=2))
        
        x = x.permute(0, 2, 1)
        
        x = self.conv1(x)
        x = self.conv2(x)
        
        x = self.relu(x)
        
        x = x.permute(0, 2, 1)
        
        return x

In [91]:
class INPREM(nn.Module):
    
    def __init__(self, num_codes, num_visits, embedding_dim=256):
        super().__init__()
        
        self.embedding_dim = embedding_dim
        
        # self.embedding_v = nn.Embedding(num_codes, embedding_dim)
        # self.embedding_o = nn.Embedding(num_codes, embedding_dim)
        
        self.embedding_v = nn.Linear(num_codes, embedding_dim, bias=False)
        self.embedding_o = nn.Linear(1, embedding_dim, bias=False)
        
        self.att_a = AlphaAttention(embedding_dim)
        self.att_b = BetaAttention(embedding_dim)
        
        self.multi_attn = MultiHeadAttention(num_codes, embedding_dim)
        
        self.fc = nn.Linear(embedding_dim, num_codes)
        
        self.do = nn.Dropout(.5)
        
        self.tan = nn.Tanh()
    
    def forward(self, X, O):
        """
        Implement the INPREM Model
        """
        
        X = X.type(torch.FloatTensor)
        O = O.type(torch.FloatTensor)
        
        O = O.unsqueeze(dim=2)
        
        # Get E_v
        E_v = self.embedding_v(X)
        E_o = self.embedding_o(O)
        
        E = torch.add(E_o,E_v)
        
        H =  self.multi_attn(E)
        H =  self.multi_attn(H)
        
        beta = self.att_b(H)
        alpha = self.att_a(H).permute(0,2,1)
        
        E_r = (alpha @ (beta * E))
        
        out = self.fc(E_r)
        
        out = out.squeeze()
        out = F.softmax(out, dim=1)
        # out = self.tan(out)
        
        return out
    
    
# load the model here
model = INPREM(len(types), max_num_visits)
model

INPREM(
  (embedding_v): Linear(in_features=275, out_features=256, bias=False)
  (embedding_o): Linear(in_features=1, out_features=256, bias=False)
  (att_a): AlphaAttention(
    (lin): Linear(in_features=256, out_features=1, bias=True)
    (sparsemax): Sparsemax(dim=-1)
    (softmax): Softmax(dim=1)
  )
  (att_b): BetaAttention(
    (lin): Linear(in_features=256, out_features=256, bias=True)
  )
  (multi_attn): MultiHeadAttention(
    (fc_q): Linear(in_features=256, out_features=256, bias=True)
    (fc_v): Linear(in_features=256, out_features=256, bias=True)
    (fc_k): Linear(in_features=256, out_features=256, bias=True)
    (fc_multi): Linear(in_features=512, out_features=256, bias=True)
    (conv1): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    (relu): ReLU()
  )
  (fc): Linear(in_features=256, out_features=275, bias=True)
  (do): Dropout(p=0.5, inplace=False)
  (tan): Tanh()
)

### INPREM_O (Removing Ordering)

In [54]:
class INPREM_O(nn.Module):
    
    def __init__(self, num_codes, num_visits, embedding_dim=256):
        super().__init__()
        
        self.embedding_dim = embedding_dim
        
        # self.embedding_v = nn.Embedding(num_codes, embedding_dim)
        # self.embedding_o = nn.Embedding(num_codes, embedding_dim)
        
        self.embedding_v = nn.Linear(num_codes, embedding_dim, bias=False)
        self.embedding_o = nn.Linear(1, embedding_dim, bias=False)
        
        self.att_a = AlphaAttention(embedding_dim)
        self.att_b = BetaAttention(embedding_dim)
        
        self.multi_attn = MultiHeadAttention(num_codes, embedding_dim)
        
        self.fc = nn.Linear(embedding_dim, num_codes)
        
        self.do = nn.Dropout(.5)
        
        self.tan = nn.Tanh()
    
    def forward(self, X, O):
        """
        Implement the INPREM Model
        """
        
        X = X.type(torch.FloatTensor)
        
        E = self.embedding_v(X)
        
        H =  self.multi_attn(E)
        H =  self.multi_attn(H)
        
        beta = self.att_b(H)
        alpha = self.att_a(H).permute(0,2,1)
        
        E_r = (alpha @ (beta * E))
        
        out = self.fc(E_r)
        
        out = out.squeeze()
        out = F.softmax(out, dim=1)
        
        return out
    
    
# load the model here
model = INPREM_O(len(types), max_num_visits)
model

INPREM_O(
  (embedding_v): Linear(in_features=1071, out_features=256, bias=False)
  (embedding_o): Linear(in_features=1, out_features=256, bias=False)
  (att_a): AlphaAttention(
    (lin): Linear(in_features=256, out_features=1, bias=True)
    (sparsemax): Sparsemax(dim=-1)
    (softmax): Softmax(dim=1)
  )
  (att_b): BetaAttention(
    (lin): Linear(in_features=256, out_features=256, bias=True)
  )
  (multi_attn): MultiHeadAttention(
    (fc_q): Linear(in_features=256, out_features=256, bias=True)
    (fc_v): Linear(in_features=256, out_features=256, bias=True)
    (fc_k): Linear(in_features=256, out_features=256, bias=True)
    (fc_multi): Linear(in_features=512, out_features=256, bias=True)
    (conv1): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    (relu): ReLU()
  )
  (fc): Linear(in_features=256, out_features=1071, bias=True)
  (do): Dropout(p=0.5, inplace=False)
  (tan): Tanh()
)

In [55]:
# load the model
model = INPREM_O(len(types), max_num_visits)

# load the loss function
criterion = nn.BCELoss()

# load the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

n_epochs = 5
train(model, train_loader, val_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.082935


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 2 	 Training Loss: 0.081949


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 3 	 Training Loss: 0.081131


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 4 	 Training Loss: 0.080633


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 5 	 Training Loss: 0.080049


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


0.0

### INPREM_S (Removing Sparsemax)

In [56]:
class AlphaAttentionAlt(torch.nn.Module):

    def __init__(self, hidden_dim=256):
        super().__init__()
        
        self.lin = nn.Linear(hidden_dim, 1)
        
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, g):
        
        y = self.lin(g)
        soft_max = F.softmax(y, dim=1)
        
        out = soft_max
        
        return out

In [57]:
class INPREM_S(nn.Module):
    
    def __init__(self, num_codes, num_visits, embedding_dim=256):
        super().__init__()
        
        self.embedding_dim = embedding_dim
        
        self.embedding_v = nn.Linear(num_codes, embedding_dim, bias=False)
        self.embedding_o = nn.Linear(1, embedding_dim, bias=False)
        
        self.att_a = AlphaAttentionAlt(embedding_dim)
        self.att_b = BetaAttention(embedding_dim)
        
        self.multi_attn = MultiHeadAttention(num_codes, embedding_dim)
        
        self.fc = nn.Linear(embedding_dim, num_codes)
        
        self.do = nn.Dropout(.5)
        
        self.tan = nn.Tanh()
    
    def forward(self, X, O):
        """
        Implement the INPREM Model
        """
        
        X = X.type(torch.FloatTensor)
        
        E = self.embedding_v(X)
        
        H =  self.multi_attn(E)
        H =  self.multi_attn(H)
        
        beta = self.att_b(H)
        alpha = self.att_a(H).permute(0,2,1)
        
        E_r = (alpha @ (beta * E))
        
        out = self.fc(E_r)
        
        out = out.squeeze()
        out = F.softmax(out, dim=1)
        
        return out
    
    
# load the model here
model = INPREM_S(len(types), max_num_visits)
model

INPREM_S(
  (embedding_v): Linear(in_features=1071, out_features=256, bias=False)
  (embedding_o): Linear(in_features=1, out_features=256, bias=False)
  (att_a): AlphaAttentionAlt(
    (lin): Linear(in_features=256, out_features=1, bias=True)
    (softmax): Softmax(dim=1)
  )
  (att_b): BetaAttention(
    (lin): Linear(in_features=256, out_features=256, bias=True)
  )
  (multi_attn): MultiHeadAttention(
    (fc_q): Linear(in_features=256, out_features=256, bias=True)
    (fc_v): Linear(in_features=256, out_features=256, bias=True)
    (fc_k): Linear(in_features=256, out_features=256, bias=True)
    (fc_multi): Linear(in_features=512, out_features=256, bias=True)
    (conv1): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    (relu): ReLU()
  )
  (fc): Linear(in_features=256, out_features=1071, bias=True)
  (do): Dropout(p=0.5, inplace=False)
  (tan): Tanh()
)

In [None]:
# load the model
model = INPREM_S(len(types), max_num_visits)

# load the loss function
criterion = nn.BCELoss()

# load the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

n_epochs = 5
train(model, train_loader, val_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.082849


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 2 	 Training Loss: 0.081906


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 3 	 Training Loss: 0.081091


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 4 	 Training Loss: 0.080491


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# 4. Training and Inferencing

## 4.1 Define Evaluation Method

In [92]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval(model, val_loader):
    
    """
    Evaluate the model.
    
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
    """
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    model.eval()
    for x, o, y in val_loader:
        y_logit = model(x, o)
        """
        TODO: obtain the predicted class (0, 1) by comparing y_logit against 0.5, 
              assign the predicted class to y_hat.
        """
        y_hat = (y_logit > 0.5).int()

        y_score = torch.cat((y_score,  y_logit.detach().to('cpu')), dim=0)
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred)
    
    # roc_auc = roc_auc_score(y_true, y_score)
    roc_auc = 0.0
    
    return p, r, f, roc_auc

## 4.2 Train the Model

In [93]:
def train(model, train_loader, val_loader, n_epochs):
    """
    Train the model.
    
    Arguments:
        model: the RNN model
        train_loader: training dataloder
        val_loader: validation dataloader
        n_epochs: total number of epochs
    
    Outputs: 
        roc_auc (rounded, 2): the ROC AUC score for the predictions
    """
    for epoch in range(n_epochs):
        model.train()
        
        train_loss = 0
        for x, o, y in train_loader:
        # for DATA, y in train_loader:
            optimizer.zero_grad()
            
            y_hat = model(x, o)
            torch.set_printoptions(threshold=6)
            
            y = y.type(torch.FloatTensor)

            loss = criterion(y_hat, y)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
        train_loss = train_loss / len(train_loader)
        
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        
        p, r, f, roc_auc = eval(model, val_loader)
        
        # print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'.format(epoch+1, p, r, f, roc_auc))
        
        # get performance data
        if print_metrics:
            print('The CPU usage over the last 5 seconds is: {}'.format(psutil.cpu_percent(5)))
            print('RAM used is: {} GB'.format(psutil.virtual_memory()[3]/1000000000))
            print('GPU memory (VRAM) used: {} GB'.format((torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0])/1000000000))
        
    return round(roc_auc, 2)

### Run Training

For training all approaches, we use Adam with the batch size of 32 and the learning rate of 0.0005. The weight decay is set to 𝜆 = 0.0001 and the dropout rate is set to 0.5 for all approaches

In [94]:
# load the model
model = INPREM(len(types), max_num_visits)

# load the loss function
criterion = nn.BCELoss()

# load the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

n_epochs = 5
train(model, train_loader, val_loader, n_epochs)

tensor([False, False, False,  ..., False, False, False])
tensor([ True, False, False,  ..., False,  True, False])
tensor([False, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False, False])
tensor([ True, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False,  True])
tensor([False, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False,  True])
tensor([False, False, False,  ..., False, False, False])
Epoch: 1 	 Training Loss: 0.359537
tensor([False, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False, Fals

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


0.0

## Abblations