# Deep Learning For Healthcare Course Project: INPREM

https://www.kdd.org/kdd2020/accepted-papers/view/inprem-an-interpretable-and-trustworthy-predictive-model-for-healthcare

## Setup

In [1]:
!pip3 install -U sparsemax

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [2]:
import os
import pickle
import json
import ast
import random
import numpy as np
import pandas as pd

from icd9 import ICD9


# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from sparsemax import Sparsemax

In [3]:
# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# define data path
use_demo = True
if use_demo:
    DATA_PATH = "demodata/" # work with open source data
else:
    DATA_PATH = "data/" # work with PATIENT Data

tree = ICD9('codes.json')
# tree.find('001.1').parent.parent.code

In [4]:
!ls {DATA_PATH}

ADMISSIONS.csv	   D_ICD_DIAGNOSES.csv	rtypes.pkl  types.pkl
DIAGNOSES_ICD.csv  ICUSTAYS.csv		rtypes.txt  types.txt


## Import Raw Data

For example, SUBJECT_ID refers to a unique patient, HADM_ID refers to a unique admission to the hospital, and ICUSTAY_ID refers to a unique admission to an intensive care unit.

In [5]:
def load_dataset(filepath):
    return pd.read_csv(filepath)

def convert_datetime_to_day(df):
    temp = pd.DataFrame()
    temp["date"] = pd.to_datetime(df['outtime'], format="%Y-%m-%d %H:%M:%S")
    return str(temp['date'].dt.year) + str(temp['date'].dt.month) + str(temp['date'].dt.day)

diag_icd = load_dataset(os.path.join(DATA_PATH, 'DIAGNOSES_ICD.csv'))
icd_descriptions = load_dataset(os.path.join(DATA_PATH, 'D_ICD_DIAGNOSES.csv'))
icustays = load_dataset(os.path.join(DATA_PATH, 'ICUSTAYS.csv'))
admissions = load_dataset(os.path.join(DATA_PATH, 'ADMISSIONS.csv'))

diag_icd = diag_icd.rename(columns={"hadm_id".upper(): "hadm_id", "icd9_code".upper(): "icd9_code"})
icustays = icustays.rename(columns={"subject_id".upper(): "subject_id", "hadm_id".upper(): "hadm_id", "icustay_id".upper(): "icustay_id", "outtime".upper(): "outtime"})

diag_icd = diag_icd[["hadm_id", "icd9_code"]]
icustays = icustays[["subject_id", "hadm_id", "icustay_id", "outtime"]]


print(f"diag_icd ({len(diag_icd)} lines):\n", diag_icd.head(), end="\n\n")
print(f"icustays ({len(icustays)} lines):\n", icustays.head(), end="\n\n")
# print(f"admissions ({admissions.size} lines):\n", admissions.head(), end="\n\n")

diag_icd (1761 lines):
    hadm_id icd9_code
0   142345     99591
1   142345     99662
2   142345      5672
3   142345     40391
4   142345     42731

icustays (136 lines):
    subject_id  hadm_id  icustay_id              outtime
0       10006   142345      206504  2164-10-25 12:21:07
1       10011   105331      232110  2126-08-28 18:59:00
2       10013   165520      264446  2125-10-07 15:13:52
3       10017   199207      204881  2149-05-31 22:19:17
4       10019   177759      228977  2163-05-16 03:47:04



In [6]:
joined_df = pd.merge(diag_icd, icustays, how='inner', on='hadm_id')[["icd9_code", "subject_id", "icustay_id", "outtime"]]
print(f"joined_df ({len(joined_df)} lines):\n", joined_df.head())

joined_df (1897 lines):
   icd9_code  subject_id  icustay_id              outtime
0     99591       10006      206504  2164-10-25 12:21:07
1     99662       10006      206504  2164-10-25 12:21:07
2      5672       10006      206504  2164-10-25 12:21:07
3     40391       10006      206504  2164-10-25 12:21:07
4     42731       10006      206504  2164-10-25 12:21:07


Convert y in to category labels. If (3, 4 ,5) always use left 3. If it starts with E, use (Exxx). If it starts with V, use (Vxx).

In [29]:
def convert_codes(codes):
    out = []
    for code in codes:
        code = str(code)
        if code[0] == "E":
            c = code[:4]
        else:
            c = code[:3]
        out.append(c)
        
        unique_codes.add(c)
       
    return out


def build_dictionaries(codes): # use our codes (kept getting keyError)
    """Construct dicts to map/ reverse map string input codes to keys, e.g. {'001': 0, '002': 1}
    """
    types = dict((diag, idx) for idx, diag in enumerate(codes))
    rtypes = dict((idx, diag) for idx, diag in enumerate(codes))
    
    return types, rtypes


def build_dictionaries_using_tree(): # use tree (all codes)
    """Construct dicts to map/ reverse map string codes to keys using icd9 tree, e.g. {'001': 0, '002': 1}
    
    Currently doesn't handle V correctly, heirarchy has a decimal place.
    """
    categories = tree.find('001-139').siblings
    category_codes = [category.code for category in categories] # 001-139, 140-239...
    
    subcategories = []
    all_codes = []
    
    for category in category_codes:
        nodes = (tree.find(category).children)
        subcategories.append([node.code for node in nodes])
    subcategories = [item for sublist in subcategories for item in sublist] # flatten list
    for subcategory in subcategories:
        nodes = (tree.find(subcategory).children)
        all_codes.append([node.code for node in nodes])
    all_codes = [item for sublist in all_codes for item in sublist] # flatten list
    
    types = dict((diag, idx) for idx, diag in enumerate(all_codes))
    rtypes = dict((idx, diag) for idx, diag in enumerate(all_codes))
    
    return types, rtypes


X = []
y = []
all_codes = []
unique_codes = set()

for name, patient in joined_df.sort_values("outtime").groupby(["subject_id"]):
    visits = []
    for _, visit in patient.groupby(["icustay_id"]):
        codes = visit["icd9_code"].tolist()
        codes = convert_codes(codes)
        visits.append(codes)
    if len(visits) >= 2:
        x, y_ = visits[:-1], visits[-1]
        X.append(x)
        y.append(y_)

# flattened_codes = [] # don't need any of this, use set
# [flattened_codes.append(item) for sublist in all_codes for item in sublist if item not in flattened_codes]
types, rtypes = build_dictionaries(list(unique_codes))
# types, rtypes = build_dictionaries_using_tree() # use tree

print(f"Using {len(X)} patients")

# check mapping
# print(unique_codes)
print('diag mapping for DIAG_V10:', types['V10']) # 75
print('reverse mapping for index 75:', rtypes[75])


Using 19 patients
diag mapping for DIAG_V10: 75
reverse mapping for index 75: V10


In [30]:
assert(len(X) == len(y))

In [31]:
# print("visits (x):", X, "\n")
# print("last visit (y):", y)

print(len(unique_codes), "codes") 

275 codes


In [32]:
# ICD 9 Codes for Binary Classification
diabetes = ("Diabetes", "250.xx")
heart_failure = ("Heary Failure", "428.xx")
chronic_kidney_disease = ("Chronic Kidney Disease", "585.xx")

## Split Dataset

For each task, we randomly split each dataset into training, validation, and testing sets five times in a 75:10:15 ratio

In [33]:
from sklearn.model_selection import train_test_split

train_size = 0.75

X_train, X_remain, y_train, y_remain = train_test_split(X, y, train_size=0.75)

test_size = 0.6 # (valid is 10% of remaining 25%, test is 15% of remaining 25%)

X_valid, X_test, y_valid, y_test = train_test_split(X_remain, y_remain, test_size=0.6)

## Build Custom Dataset

In [34]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, x, y):
        self.x = x
        self.y = y
    
    def __len__(self):
        
        return len(self.y)
    
    def __getitem__(self, index):
        
        return (self.x[index], self.y[index])
        
        
train_dataset = CustomDataset(X_train, y_train)
val_dataset = CustomDataset(X_valid, y_valid)
test_dataset = CustomDataset(X_test, y_test)

## Load the Data (DataLoader)

For each task, we randomly split each dataset into training, validation, and testing sets five times in a 75:10:15 ratio

In [35]:
def collate_fn(data):
    """
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patients) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of hf labels
        using: `sequences, labels = zip(*data)`
    """

    sequences, labels = zip(*data)
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    y = torch.zeros(len(labels), dtype=torch.float)
    
    # for label in labels:
    #    # create one-hot vector
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):
        count = 0
        for j_visit, visit in enumerate(patient):
            """
            TODO: update `x`, `rev_x`, `masks`, and `rev_masks`
            """
            visit_len = len(visit)
            
            typed_visit = visit
            for idx, code in enumerate(visit): # convert to mapping
                typed_visit[idx] = types[visit[idx]]
            
            x[i_patient][j_visit][:visit_len] = torch.tensor(typed_visit, dtype=torch.long)
            masks[i_patient][j_visit][:visit_len] = torch.ones((visit_len),dtype=torch.bool)
            count += 1
            
        reverse_x = x[i_patient][:count]
        reverse_mask = masks[i_patient][:count]
        
        rev_x[i_patient][:count] = torch.flip(reverse_x, [0])
        rev_masks[i_patient][:count] = torch.flip(reverse_mask, [0])
        
    return x, masks, rev_x, rev_masks, y

In [36]:
x, masks, rev_x, rev_masks, y = collate_fn(train_dataset)

In [37]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, test_loader, collate_fn):
    
    batch_size = 32
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)
    test_loader = DataLoader(test_loader, batch_size=batch_size, collate_fn=collate_fn)
    
    return train_loader, val_loader, test_loader


train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset, collate_fn)

## Build Model

We treat the medical events taking place in EHR as medical codes, which are denoted as $c_{1}, c_{2},... c_{|C|}$ ∈ 𝐶, where |𝐶| is the total number of unique medical codes.

One specific patient consist of a sequence of visits $v_{1}, v_{2},... v_{T}$ where we denote the number of visits in total as T.

Each visit contains a subset of medical codes, and we denote each visit as a binary vector  $v_{t} ∈ \{0, 1\}_{|C|}$, where the 𝑖-th element is set to 1 if the 𝑡-th visit contains the medical code $c_{i}$, otherwise 0. The visits  $v_{1}, v_{2},... v_{T}$ are stacked to form an input matrix $X ∈ \{0, 1\}^{|C|xT}$ , which we use as the input for the network

$E_{v} = {W}_{v}X$

$E_{o} = {W}_{o}O$

$E_{r} = \alpha(\beta \odot (E_{v}+E_{o}))^{T}$

In [16]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim=256):
        super().__init__()
        
        self.a_att = nn.Linear(hidden_dim, 1)
        
        self.sparsemax = Sparsemax(dim=-1)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, g):
        
        y = self.a_att(g)
        sparse_max = self.sparsemax(y)
        soft_max = self.softmax(y)
        
        out = (sparse_max + soft_max) / 2
        
        return out
    
class BetaAttention(torch.nn.Module):

    def __init__(self, hidden_dim=256):
        
        self.b_att = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, h):
        
        y = self.b_att(h)
        out = torch.tanh(y)
        
        return out

In [17]:
class INPREM(nn.Module):
    
    def __init__(self, num_codes, embedding_dim=256):
        super().__init__()
        
        self.embedding_v = nn.Embedding(num_codes, embedding_dim)
        self.embedding_o = nn.Embedding(num_codes, embedding_dim)
        
        self.att_a = AlphaAttention(embedding_dim)
        
        self.att_b = BetaAttention(embedding_dim)
        
        self.do = nn.Dropout(.5)
    
    def forward(self, X):
    
        # Pass through embedding
        ev = self.embedding_v(X)
        eo = self.embedding_o(o)
        
        er = self.att_a * (self.att_b @ (ev + eo)).T # double check this
        
        # Softmax
        out = F.softmax(x)
    

# load the model here
model = INPREM(num_codes = len(types))
model

AttributeError: cannot assign module before Module.__init__() call

## Evaluation

In [18]:
def eval_model(model, dataloader, device=None):
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    
    for DATA in dataloader:
        y_logit = model(DATA)

        y_hat = (y_logit > 0.5).int()

        y_score = torch.cat((y_score,  y_logit.detach().to('cpu')), dim=0)
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)
    
    return p, r, f, roc_auc

## Train the Model

In [None]:
def train(model, train_loader, val_loader, n_epochs):

    for epoch in range(n_epochs):
        model.train()
        
        train_loss = 0
        for DATA, y in train_loader:
            optimizer.zero_grad()
            y_hat = model(x, masks, rev_x, rev_masks)

            loss = criterion(y_hat, y)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
        train_loss = train_loss / len(train_loader)
        
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        
        p, r, f, roc_auc = eval(model, val_loader)
        
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'.format(epoch+1, p, r, f, roc_auc))
        
    return round(roc_auc, 2)

## Run

For training all approaches, we use Adam with the batch size of 32 and the learning rate of 0.0005. The weight decay is set to 𝜆 = 0.0001 and the dropout rate is set to 0.5 for all approaches

In [7]:
# load the model
model = IMPREM(num_codes = len(types))

# load the loss function
criterion = nn.BCELoss()
# load the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

n_epochs = 5
train(model, train_loader, val_loader, n_epochs)

NameError: name 'IMPREM' is not defined

## Abblations