# Recurrent Neural Network for diagnosis codes

## Overview

In this project, we will build a bi-directional RNN on diagnosis codes. The recurrent structure of an RNN lets us model the temporal relationship between a patient's visits. More specifically, we will still perform **Heart Failure Prediction**, but with a different input format.

In [31]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Define data path
DATA_PATH = "lib/data"

## About Raw Data

To get started, we will implement a naive RNN model for heart failure prediction using the diagnosis codes.

We will use the same dataset synthesized from [MIMIC-III](https://mimic.physionet.org/gettingstarted/access/), but with different input formats.

The data has been preprocessed for you. Let us load it and take a look.

In [32]:
pids = pickle.load(open(os.path.join(DATA_PATH,'train/pids.pkl'), 'rb'))
vids = pickle.load(open(os.path.join(DATA_PATH,'train/vids.pkl'), 'rb'))
hfs = pickle.load(open(os.path.join(DATA_PATH,'train/hfs.pkl'), 'rb'))
seqs = pickle.load(open(os.path.join(DATA_PATH,'train/seqs.pkl'), 'rb'))
types = pickle.load(open(os.path.join(DATA_PATH,'train/types.pkl'), 'rb'))
rtypes = pickle.load(open(os.path.join(DATA_PATH,'train/rtypes.pkl'), 'rb'))

assert len(pids) == len(vids) == len(hfs) == len(seqs) == 1000
assert len(types) == 619

where

- `pids`: contains the patient ids
- `vids`: contains a list of visit ids for each patient
- `hfs`: contains the heart failure label (0: normal, 1: heart failure) for each patient
- `seqs`: contains a list of visits for each patient, where each visit is a list of integer diagnosis labels
- `types`: contains the map from ICD-9 codes to integer labels
- `rtypes`: contains the map from integer labels to ICD-9 codes
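
As a rough illustration of how `types` and `rtypes` relate, the sketch below builds a toy pair of lookups. It assumes both behave like dictionaries (the real `rtypes` may instead be a list indexed by label). The two code/label pairs appear in the example patient output in this notebook; `toy_types` and `toy_rtypes` are hypothetical names.

```python
# Toy stand-ins for `types` / `rtypes`; assumed here to be dict-like.
toy_types = {"DIAG_041": 12, "DIAG_276": 103}                     # code -> label
toy_rtypes = {label: code for code, label in toy_types.items()}  # label -> code

# Translate one visit from integer labels back to ICD-9 codes.
visit_labels = [12, 103]
visit_codes = [toy_rtypes[label] for label in visit_labels]
print(visit_codes)
```

The two maps are inverses of each other, so a round trip through both returns the original label.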

Let us take a patient as an example.

In [33]:
# take the patient at index 3 (the 4th patient) as an example

print("Patient ID:", pids[3])
print("Heart Failure:", hfs[3])
print("# of visits:", len(vids[3]))
for visit in range(len(vids[3])):
    print(f"\t{visit}-th visit id:", vids[3][visit])
    print(f"\t{visit}-th visit diagnosis labels:", seqs[3][visit])
    print(f"\t{visit}-th visit diagnosis codes:", [rtypes[label] for label in seqs[3][visit]])

Patient ID: 47537
Heart Failure: 0
# of visits: 2
	0-th visit id: 0
	0-th visit diagnosis labels: [12, 103, 262, 285, 290, 292, 359, 416, 39, 225, 275, 294, 326, 267, 93]
	0-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_518', 'DIAG_560', 'DIAG_567', 'DIAG_569', 'DIAG_707', 'DIAG_785', 'DIAG_155', 'DIAG_456', 'DIAG_537', 'DIAG_571', 'DIAG_608', 'DIAG_529', 'DIAG_263']
	1-th visit id: 1
	1-th visit diagnosis labels: [12, 103, 240, 262, 290, 292, 319, 359, 510, 513, 577, 307, 8, 280, 18, 131]
	1-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_482', 'DIAG_518', 'DIAG_567', 'DIAG_569', 'DIAG_599', 'DIAG_707', 'DIAG_995', 'DIAG_998', 'DIAG_V09', 'DIAG_584', 'DIAG_031', 'DIAG_553', 'DIAG_070', 'DIAG_305']


Note that `seqs` is a list of lists of lists. That is, `seqs[i][j][k]` gives you the k-th diagnosis label of the j-th visit of the i-th patient.

You can look up the meaning of each ICD-9 code online. For example, `DIAG_276` represents *disorders of fluid, electrolyte, and acid-base balance*.

Next, let us see the number of heart failure patients.

In [34]:
print("number of heart failure patients:", sum(hfs))
print("ratio of heart failure patients: %.2f" % (sum(hfs) / len(hfs)))

number of heart failure patients: 548
ratio of heart failure patients: 0.55


Now we have the data. Let us build the naive RNN.

## 1 Build the dataset

### 1.1 CustomDataset

First, let us implement a custom dataset using the PyTorch `Dataset` class, which characterizes the key features of the dataset we want to generate.

We will use the sequences of diagnosis codes `seqs` as input and heart failure `hfs` as output.

In [35]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    
    def __init__(self, seqs, hfs):

        self.x = seqs
        self.y = hfs
    
    def __len__(self):

        return len(self.y)
    
    def __getitem__(self, index):
        
        # Generates one sample of data.
        return self.x[index], self.y[index]

dataset = CustomDataset(seqs, hfs)

### 1.2 Collate Function

Note that we do not convert the data to tensors in `CustomDataset`. Instead, we will do this in a collate function `collate_fn()`.

This collate function `collate_fn()` will be called by `DataLoader` after fetching a list of samples using the indices from `CustomDataset` to collate the list of samples into batches.

For example, assume the `DataLoader` gets a list of two samples.

```
[ [ [0, 1, 2], [8, 0] ], 
  [ [12, 13, 6, 7], [12], [23, 11] ] ]
```

where the first sample has two visits `[0, 1, 2]` and `[8, 0]` and the second sample has three visits `[12, 13, 6, 7]`, `[12]`, and `[23, 11]`.

The collate function `collate_fn()` is supposed to pad them into the same shape (3, 4), where 3 is the maximum number of visits and 4 is the maximum number of diagnosis codes.

```
[ [ [0, 1, 2, *0*], [8, 0, *0*, *0*], [*0*, *0*, *0*, *0*]  ], 
  [ [12, 13, 6, 7], [12, *0*, *0*, *0*], [23, 11, *0*, *0*] ] ]
```

Further, the padding information will be stored in a mask with the same shape, where 1 indicates that the diagnosis code at this position is from the original input, and 0 indicates that the diagnosis code at this position is the padded value.

```
[ [ [1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0] ], 
  [ [1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0] ] ]
```

Lastly, we will build a second copy of the diagnosis sequence in reversed time order. This is the input to the reverse-direction RNN in our model. Note that we only flip the true visits; the padding visits stay at the end.

```
[ [ [8, 0, *0*, *0*], [0, 1, 2, *0*], [*0*, *0*, *0*, *0*]  ], 
  [ [23, 11, *0*, *0*], [12, *0*, *0*, *0*], [12, 13, 6, 7] ] ]
```

And a reversed mask as well.

```
[ [ [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0] ], 
  [ [1, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1] ] ]
```

We need to pad the sequences to the same length so that we can do batch training on a GPU. We also need the mask so that, during training, we can ignore the padded values, as they do not carry any information.
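
The padding, masking, and time reversal described above can be sketched in plain Python lists, independent of PyTorch. This is only an illustration on the toy batch from this section; `pad_batch` is a hypothetical helper name, and it assumes `0` is the padding value, as in the examples above.

```python
def pad_batch(batch, pad_value=0):
    """Pad a list of patients (each a list of visits) to a common shape."""
    max_visits = max(len(patient) for patient in batch)
    max_codes = max(len(visit) for patient in batch for visit in patient)

    padded, mask, rev_padded, rev_mask = [], [], [], []
    for patient in batch:
        # Pad every true visit to max_codes and record which slots are real.
        visits = [v + [pad_value] * (max_codes - len(v)) for v in patient]
        flags = [[1] * len(v) + [0] * (max_codes - len(v)) for v in patient]
        # Padding visits are all zeros and always go at the end.
        n_pad = max_visits - len(patient)
        empty = [[pad_value] * max_codes for _ in range(n_pad)]
        no_flags = [[0] * max_codes for _ in range(n_pad)]

        padded.append(visits + empty)
        mask.append(flags + no_flags)
        # Reverse only the true visits; padding stays at the end.
        rev_padded.append(visits[::-1] + empty)
        rev_mask.append(flags[::-1] + no_flags)
    return padded, mask, rev_padded, rev_mask

batch = [[[0, 1, 2], [8, 0]], [[12, 13, 6, 7], [12], [23, 11]]]
padded, mask, rev_padded, rev_mask = pad_batch(batch)
print(rev_padded)
```

Keeping the padding visits at the end in both directions means the index of the last true visit is the same for the forward and the reversed sequence.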

In [37]:
def collate_fn(data):
    """
    Collate the list of samples into batches. For each patient, pad the diagnosis
    sequences to the same shape (max # visits, max # diagnosis codes). The padding information
    is stored in `masks`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This is the input to the reverse-direction RNN
        rev_masks: same as masks but in reversed time. This is the padding mask for rev_x
        y: a tensor of shape (# patients) of type torch.float
    """

    sequences, labels = zip(*data)

    y = torch.tensor(labels, dtype=torch.float)
    
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):

            num_codes_ij = len(visit)
            num_pad_ij = max_num_codes - num_codes_ij
            codes_padded = torch.tensor(visit + [0] * num_pad_ij, dtype=torch.long)
            masks_padded = torch.tensor([1] * num_codes_ij + [0] * num_pad_ij, dtype=torch.bool)
            
            x[i_patient, j_visit] = codes_padded
            masks[i_patient, j_visit] = masks_padded
            
            rev_j_visit = len(sequences[i_patient]) - 1 - j_visit
            rev_x[i_patient, rev_j_visit] = codes_padded
            rev_masks[i_patient, rev_j_visit] = masks_padded
    
    return x, masks, rev_x, rev_masks, y

In [38]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
loader_iter = iter(loader)
next(loader_iter)

(tensor([[[ 85, 112, 346, 380, 269, 511, 114, 103, 530, 597, 511],
          [ 85, 103, 112, 513, 511,  19, 149, 530, 186,  66,   0]],
 
         [[167, 187, 274, 359,  85,   2, 186,  85,   0,   0,   0],
          [187, 306, 511, 103,  85, 186, 103,   0,   0,   0,   0]]]),
 tensor([[[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True],
          [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           False]],
 
         [[ True,  True,  True,  True,  True,  True,  True,  True, False, False,
           False],
          [ True,  True,  True,  True,  True,  True,  True, False, False, False,
           False]]]),
 tensor([[[ 85, 103, 112, 513, 511,  19, 149, 530, 186,  66,   0],
          [ 85, 112, 346, 380, 269, 511, 114, 103, 530, 597, 511]],
 
         [[187, 306, 511, 103,  85, 186, 103,   0,   0,   0,   0],
          [167, 187, 274, 359,  85,   2, 186,  85,   0,   0,   0]]]),
 tensor([[[ True,  True,  True,  True,  True, 

In [39]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)

Now we have `CustomDataset` and `collate_fn()`. Let us split the dataset into training and validation sets.

In [40]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 800
Length of val dataset: 200


### 1.3 DataLoader 

Now, we can load the dataset into the data loader.

In [41]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    '''
    Use this function to return the data loaders for the train and validation datasets.
    
    Arguments:
        train_dataset: train dataset of type `CustomDataset`
        val_dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    '''
    
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    
    return train_loader, val_loader

train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

## 2 Naive RNN

Let us implement a naive bi-directional RNN model.

<img src="img/bi-rnn.jpg" width="600"/>

Remember from class that, first of all, we need to transform the diagnosis code for each visit of a patient to an embedding. To do this, we can use `nn.Embedding()`, where `num_embeddings` is the number of diagnosis codes and `embedding_dim` is the embedding dimension.

Then, we can construct a simple RNN structure. The input at each step is the representation of one visit (conceptually, a multi-hot vector over diagnosis codes): $\boldsymbol{X}_0$ for the 0-th visit and $\boldsymbol{X}_t$ for the t-th visit.

Each input is then mapped to a hidden state $\boldsymbol{\overleftrightarrow{h}}_t$. The forward hidden state $\boldsymbol{\overrightarrow{h}}_t$ is determined by $\boldsymbol{\overrightarrow{h}}_{t-1}$ and the current input $\boldsymbol{X}_t$.

Similarly, we will have another RNN to process the sequence in the reverse order, so that the hidden state $\boldsymbol{\overleftarrow{h}}_t$ is determined by $\boldsymbol{\overleftarrow{h}}_{t+1}$ and $\boldsymbol{X}_t$.

Finally, once we have the $\boldsymbol{\overrightarrow{h}}_T$ and $\boldsymbol{\overleftarrow{h}}_{0}$, we will concatenate the two vectors as the feature vector and train a NN to perform the classification.
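
In symbols, the description above can be summarized roughly as follows. The functions $f$ and $g$ (the forward and backward recurrent cells), the classifier parameters $\boldsymbol{W}$ and $b$, and the sigmoid $\sigma$ are notation introduced here for illustration and are not taken from the figure.

$$
\boldsymbol{\overrightarrow{h}}_t = f\left(\boldsymbol{\overrightarrow{h}}_{t-1}, \boldsymbol{X}_t\right), \qquad
\boldsymbol{\overleftarrow{h}}_t = g\left(\boldsymbol{\overleftarrow{h}}_{t+1}, \boldsymbol{X}_t\right), \qquad
\hat{y} = \sigma\left(\boldsymbol{W}\left[\boldsymbol{\overrightarrow{h}}_T ; \boldsymbol{\overleftarrow{h}}_0\right] + b\right)
$$

Here $[\cdot\,;\cdot]$ denotes concatenation and $\hat{y}$ is the predicted probability of heart failure.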

Now, let us build this model. The forward steps will be:

    1. Pass the sequence through the embedding layer;
    2. Sum the embeddings for each diagnosis code up for a visit of a patient;
    3. Pass the embeddings through the RNN layer;
    4. Obtain the hidden state at the last visit;
    5. Do 1-4 for both directions and concatenate the hidden states;
    6. Pass the hidden state through the linear and activation layers.

### 2.1 Mask Selection

Importantly, we need to use `masks` to mask out the padding before step 2 and before step 4. So, let us first perform the mask selection.

In [43]:
def sum_embeddings_with_mask(x, masks):
    """
    Zero out the embeddings of padded diagnosis codes and then sum the remaining embeddings within each visit.
    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)
    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
    """
    masks_ext = masks.unsqueeze(-1).long()
    sum_embeddings = (x * masks_ext).sum(dim=2)
    return sum_embeddings
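
As a small hand-checkable illustration of the masked sum, the sketch below applies the same broadcast-and-sum idea to a tiny made-up tensor. The values are hypothetical; the 9s stand in for whatever vector an embedding layer would return at a padded position, which is generally not zero.

```python
import torch

# One patient, two visits, up to three codes per visit, embedding_dim = 2.
emb = torch.tensor([[[[1., 1.], [2., 2.], [9., 9.]],    # third slot is padding
                     [[5., 5.], [9., 9.], [9., 9.]]]])  # last two slots are padding
mask = torch.tensor([[[True, True, False],
                      [True, False, False]]])

# Broadcast the mask over the embedding dimension, then sum over codes (dim=2).
summed = (emb * mask.unsqueeze(-1)).sum(dim=2)
print(summed.shape)
```

Only the unmasked vectors contribute, so the first visit sums to `[3, 3]` and the second to `[5, 5]`.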

In [44]:
import random
import ast
import inspect

def uses_loop(function):
    loop_statements = ast.For, ast.While, ast.AsyncFor

    nodes = ast.walk(ast.parse(inspect.getsource(function)))
    return any(isinstance(node, loop_statements) for node in nodes)

def generate_random_mask(batch_size, max_num_visits , max_num_codes):
    num_visits = [random.randint(1, max_num_visits) for _ in range(batch_size)]
    num_codes = []
    for n in num_visits:
        num_codes_visit = [0] * max_num_visits
        for i in range(n):
            num_codes_visit[i] = (random.randint(1, max_num_codes))
        num_codes.append(num_codes_visit)
    masks = [torch.ones((l,), dtype=torch.bool) for num_codes_visit in num_codes for l in num_codes_visit]
    masks = torch.stack([torch.cat([i, i.new_zeros(max_num_codes - i.size(0))], 0) for i in masks], 0)
    masks = masks.view((batch_size, max_num_visits, max_num_codes)).bool()
    return masks

batch_size = 16
max_num_visits = 10
max_num_codes = 20
embedding_dim = 100

torch.random.manual_seed(7)
x = torch.randn((batch_size, max_num_visits , max_num_codes, embedding_dim))
masks = generate_random_mask(batch_size, max_num_visits , max_num_codes)
out = sum_embeddings_with_mask(x, masks)

assert uses_loop(sum_embeddings_with_mask) is False
assert out.shape == (batch_size, max_num_visits, embedding_dim)

In [45]:
def get_last_visit(hidden_states, masks):
    """
    Obtain the hidden state for the last true visit (not padding visits)
    Arguments:
        hidden_states: the hidden states of each visit of shape (batch_size, # visits, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)
    Outputs:
        last_hidden_state: the hidden state for the last true visit of shape (batch_size, embedding_dim)
    """
    masks_bool = masks.sum(dim=masks.ndim-1).bool()
    masks_last_visit = masks_bool.sum(dim=masks_bool.ndim-1) - 1
    last_hidden_states = hidden_states[np.arange(len(hidden_states)), masks_last_visit]
    return last_hidden_states
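
The index arithmetic behind selecting the last true visit can be checked on the toy mask from the collate section. This is a plain-Python sketch of the idea only, not the tensor implementation; `last_index` is a hypothetical name.

```python
# Padding mask for two patients: max # visits = 3, max # codes = 4.
mask = [
    [[1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0]],  # 2 true visits
    [[1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0]],  # 3 true visits
]

# A visit is "true" if at least one of its code slots is unmasked.
# Because padding visits always come last, (count - 1) is the last true index.
last_index = [sum(any(visit) for visit in patient) - 1 for patient in mask]
print(last_index)
```

This relies on every true visit having at least one diagnosis code and on padding visits being placed after all true visits.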

In [46]:
max_num_visits = 10
batch_size = 16
max_num_codes = 20
embedding_dim = 100

torch.random.manual_seed(7)
hidden_states = torch.randn((batch_size, max_num_visits, embedding_dim))
masks = generate_random_mask(batch_size, max_num_visits , max_num_codes)
out = get_last_visit(hidden_states, masks)

assert out.shape == (batch_size, embedding_dim)

### 2.2 Build NaiveRNN

In [47]:
class NaiveRNN(nn.Module):
    
    def __init__(self, num_codes):
        """
        Arguments:
            num_codes: total number of diagnosis codes
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_codes, embedding_dim=128)
        self.rnn = nn.GRU(input_size=128, hidden_size=128, batch_first=True)
        self.rev_rnn = nn.GRU(input_size=128, hidden_size=128, batch_first=True)
        self.fc = nn.Linear(in_features=256, out_features=1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, masks, rev_x, rev_masks):
        """
        Arguments:
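            x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)
            rev_x: the diagnosis sequence in reversed time, same shape as x
            rev_masks: the padding masks in reversed time, same shape as masks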

        Outputs:
            probs: probabilities of shape (batch_size)
        """
        
        batch_size = x.shape[0]
        
        # 1. Pass the sequence through the embedding layer;
        x = self.embedding(x)
        
        # 2. Sum the embeddings for each diagnosis code up for a visit of a patient.
        x = sum_embeddings_with_mask(x, masks)
        
        # 3. Pass the embeddings through the RNN layer;
        output, _ = self.rnn(x)
        
        # 4. Obtain the hidden state at the last visit.
        true_h_n = get_last_visit(output, masks)

        # 5. Do the step 1-4 again for the reverse order, and concatenate the hidden states for both directions;
        rev_x = self.embedding(rev_x)
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        rev_output, _ = self.rev_rnn(rev_x)
        true_h_n_rev = get_last_visit(rev_output, rev_masks)
        
        # 6. Pass the hidden state through the linear and activation layers.
        logits = self.fc(torch.cat([true_h_n, true_h_n_rev], 1))        
        probs = self.sigmoid(logits)
        return probs.view(batch_size)

# load the model here
naive_rnn = NaiveRNN(num_codes = len(types))
naive_rnn

NaiveRNN(
  (embedding): Embedding(619, 128)
  (rnn): GRU(128, 128, batch_first=True)
  (rev_rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

## 3 Model Training

### 3.1 Loss and Optimizer

In [50]:
# Specify Binary Cross Entropy as the loss function (`nn.BCELoss`) and assign it to `criterion`.
criterion = nn.BCELoss()

# Specify Adam as the optimizer (`torch.optim.Adam`) with learning rate 0.001 and assign it to `optimizer`.
optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=0.001)
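
For reference, the quantity that `nn.BCELoss` computes with its default mean reduction can be written out by hand. The sketch below is plain Python and only illustrates the formula; `bce` is a hypothetical helper name, not part of the notebook.

```python
import math

def bce(probs, targets):
    """Mean binary cross entropy over probabilities and 0/1 targets."""
    losses = [-(y * math.log(p) + (1 - y) * math.log(1 - p))
              for p, y in zip(probs, targets)]
    return sum(losses) / len(losses)

# A model that always predicts 0.5 scores log(2), regardless of the labels.
baseline = bce([0.5, 0.5], [1.0, 0.0])
print(round(baseline, 4))
```

Because this loss expects probabilities rather than raw logits, the model ends with a sigmoid layer. The value log(2), about 0.693, is a handy reference point when reading the training loss.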

### 3.2 Evaluate

Next, let us implement the `eval_model()` function.

In [52]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

def eval_model(model, val_loader):
    """
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
    """
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for x, masks, rev_x, rev_masks, y in val_loader:
            y_hat = model(x, masks, rev_x, rev_masks)
            y_score = torch.cat((y_score, y_hat.detach().to('cpu')), dim=0)
            # Threshold the probabilities at 0.5 to get hard predictions.
            y_hat = (y_hat > 0.5).long()
            y_pred = torch.cat((y_pred, y_hat.detach().to('cpu')), dim=0)
            y_true = torch.cat((y_true, y.long().detach().to('cpu')), dim=0)

    # Calculate precision, recall, f1, and roc auc scores.
    # Reference: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
    p, r, f, support = precision_recall_fscore_support(
        y_true=y_true.cpu().data.numpy(),
        y_pred=y_pred.cpu().data.numpy(), average='binary')
    roc_auc = roc_auc_score(
        y_true=y_true.cpu().data.numpy(),
        y_score=y_score.cpu().data.numpy())
    return p, r, f, roc_auc
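
To make the threshold-based metrics concrete, here is a plain-Python sketch of precision, recall, and F1 for the positive class, which is what `average='binary'` reports. `binary_metrics` and the toy labels and scores are hypothetical and only for illustration; the notebook itself relies on scikit-learn.

```python
def binary_metrics(y_true, y_score, threshold=0.5):
    """Precision, recall and F1 for the positive class at a given threshold."""
    y_pred = [int(s > threshold) for s in y_score]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1

# Toy data: 2 true positives, 1 false positive, 1 false negative.
y_true = [1, 1, 1, 0, 0]
y_score = [0.9, 0.8, 0.3, 0.6, 0.1]
precision, recall, f1 = binary_metrics(y_true, y_score)
print(precision, recall, f1)
```

Unlike these three metrics, ROC AUC is computed from the raw scores and does not depend on the 0.5 threshold.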

### 3.3 Training and evaluation

Now let us implement the `train()` function. Note that `train()` should call `eval_model()` at the end of each training epoch to see the results on the validation dataset.

In [55]:
def train(model, train_loader, val_loader, n_epochs):
    """  
    Arguments:
        model: the RNN model
        train_loader: training dataloader
        val_loader: validation dataloader
        n_epochs: total number of epochs
    """
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x, masks, rev_x, rev_masks, y in train_loader:
            optimizer.zero_grad()
            probs = model(x=x, masks=masks, rev_x=rev_x, rev_masks=rev_masks)
            loss = criterion(input=probs, target=y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        
        p, r, f, roc_auc = eval_model(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'
              .format(epoch+1, p, r, f, roc_auc))

In [56]:
# number of epochs to train the model
n_epochs = 5
train(naive_rnn, train_loader, val_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.620101
Epoch: 1 	 Validation p: 0.71, r:0.83, f: 0.76, roc_auc: 0.78
Epoch: 2 	 Training Loss: 0.433294
Epoch: 2 	 Validation p: 0.73, r:0.83, f: 0.78, roc_auc: 0.79
Epoch: 3 	 Training Loss: 0.326476
Epoch: 3 	 Validation p: 0.71, r:0.88, f: 0.79, roc_auc: 0.78
Epoch: 4 	 Training Loss: 0.222812
Epoch: 4 	 Validation p: 0.75, r:0.80, f: 0.77, roc_auc: 0.79
Epoch: 5 	 Training Loss: 0.141010
Epoch: 5 	 Validation p: 0.76, r:0.82, f: 0.79, roc_auc: 0.79


In [57]:
p, r, f, roc_auc = eval_model(naive_rnn, val_loader)
print(roc_auc)

0.790997442455243
