NaiveRNN for Heart Failure Prediction
---

In [1]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed every random number generator used in this notebook so runs are repeatable.
seed = 24
# NOTE: PYTHONHASHSEED only takes effect for child processes; the running
# interpreter's hash randomization is fixed at startup.
os.environ["PYTHONHASHSEED"] = str(seed)
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

  from .autonotebook import tqdm as notebook_tqdm


---

In [2]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use it on trusted,
    locally generated files such as the preprocessing outputs below.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


hfs = _load_pickle('../output/output.hfs')
seqs = _load_pickle('../output/output.seqs')
types = _load_pickle('../output/output.types')
print(len(hfs))
print(len(seqs))
# Every patient must have both a label and a visit sequence.
assert len(hfs) == len(seqs)
print(len(types))

5447
5447
4512


where

- `hfs`: contains the heart failure label (0: normal, 1: heart failure) for each patient
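- `seqs`: contains, for each patient, a list of visits, where each visit is a list of integer-encoded ICD-9 codes
- `types`: contains the mapping between ICD-9 codes and the integer ids used in `seqs` (its size is the code vocabulary size)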

In [3]:
# GRAM paper Table 1 replication: basic statistics of the cohort.
print("Table 1: Basic statistics of MIMIC-III:")
num_patients = len(hfs)
print("# of patients:", num_patients)

# Count visits and codes, and track the largest visit.
num_visits = 0
num_codes = 0
max_codes_in_visit = 0
for patient_record in seqs:
    num_visits += len(patient_record)
    for visit in patient_record:
        num_codes += len(visit)
        max_codes_in_visit = max(max_codes_in_visit, len(visit))
print("# of visits:", num_visits)
print("Avg. # of visits per patient:", round(num_visits / num_patients, 2))
print("# of unique ICD9 codes: ", len(types))
print("Avg. # of codes per visit:", round(num_codes / num_visits, 2))
print("Max # of codes per visit:", max_codes_in_visit)

# Class balance (labels are 0/1, so the sum is the number of HF patients).
num_hf = sum(hfs)
print("Number of HF patients:", num_hf)
print("Number of normal patients:", num_patients - num_hf)
print("Ratio of HF patients: %.2f" % (num_hf / num_patients))

# Show one patient's record as an example.
example_idx = 14
print(f"\nPatient at index {example_idx}:")
print("Heart Failure status:", hfs[example_idx])
print("Visits and diagnosis:")
for visit_idx, visit_codes in enumerate(seqs[example_idx]):
    print(f"\t{visit_idx}-th visit id:", visit_idx)
    print(f"\t{visit_idx}-th visit diagnosis labels:", visit_codes)

Table 1: Basic statistics of MIMIC-III:
# of patients: 5447
# of visits: 11902
Avg. # of visits per patient: 2.19
# of unique ICD9 codes:  4512
Avg. # of codes per visit: 11.32
Max # of codes per visit: 39
Number of HF patients: 1280
Number of normal patients: 4167
Ratio of HF patients: 0.23

Patient at index 14:
Heart Failure status: 1
Visits and diagnosis: 1
	0-th visit id: 0
	0-th visit diagnosis labels: [250, 157, 251, 252, 5, 11, 253, 0, 12]
Visits and diagnosis: 1
	1-th visit id: 1
	1-th visit diagnosis labels: [107, 28, 16, 254, 255, 256, 0, 5, 11]
Visits and diagnosis: 1
	2-th visit id: 2
	2-th visit diagnosis labels: [257, 41, 258, 62, 259, 260, 261, 139, 16, 180, 59, 262, 263, 264, 265]


## 1 Build the dataset

### 1.1 CustomDataset

First, we implement a custom dataset by subclassing PyTorch's `Dataset` class, which defines how many samples the dataset has and how to retrieve each one.

We will use the sequences of diagnosis codes `seqs` as input and heart failure `hfs` as output.

In [4]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Map-style dataset pairing each patient's visit sequence with its label.

    ``self.x`` holds the per-patient visit sequences and ``self.y`` the labels;
    both are stored as given (no copying or tensor conversion).
    """

    def __init__(self, seqs, hfs):
        self.x, self.y = seqs, hfs

    def __len__(self):
        # The dataset size is defined by the number of labels.
        return len(self.y)

    def __getitem__(self, index):
        sample, label = self.x[index], self.y[index]
        return sample, label
        
dataset = CustomDataset(seqs=seqs, hfs=hfs)  # inputs: visit sequences; targets: HF labels

### 1.2 Collate Function

In [5]:
def collate_fn(data):
    """Pad a batch of patients into dense code-id and mask tensors.

    Args:
        data: list of ``(sequence, label)`` pairs as returned by
            ``CustomDataset.__getitem__``. Each ``sequence`` is a list of visits
            and each visit is a list of integer code ids.

    Returns:
        x: LongTensor of shape (num_patients, max_num_visits, max_num_codes)
            with the code ids, zero-padded.
        masks: BoolTensor of the same shape; True where ``x`` holds a real code.
        y: FloatTensor of shape (num_patients,) with the labels.
    """
    sequences, labels = zip(*data)

    y = torch.tensor(labels, dtype=torch.float)

    num_patients = len(sequences)
    # default=0 keeps batches made only of empty visits from raising in max().
    max_num_visits = max((len(patient) for patient in sequences), default=0)
    max_num_codes = max(
        (len(visit) for patient in sequences for visit in patient), default=0
    )

    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            n_codes = len(visit)
            # Fill a whole visit row at once rather than one code per tensor write.
            x[i_patient, j_visit, :n_codes] = torch.tensor(visit, dtype=torch.long)
            masks[i_patient, j_visit, :n_codes] = True

    return x, masks, y

In [7]:
from torch.utils.data.dataset import random_split

# Fraction of the dataset used for training; the remainder becomes the validation set.
# AUC is needed for 10%, 20%, ..., 100% of the training data. Models in this series are
# trained on 7%, 14%, 21%, ..., 70% of the full dataset (i.e. int(len(dataset) * 0.07),
# ..., int(len(dataset) * 0.70)), mirroring a setup with 20% held out as test data and the
# rest as validation, so results compare apples to apples.
# NOTE(review): this notebook only creates train/val splits; there is no separate test
# split here, so all reported metrics are on the validation set.
TRAIN_RATIO = 0.14
split = int(len(dataset) * TRAIN_RATIO)

lengths = [split, len(dataset) - split]
# Uses torch's global RNG, which was seeded in the first cell.
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 762
Length of val dataset: 4685


### 1.3 DataLoader

In [8]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn, batch_size=32):
    """Wrap the train/val datasets in DataLoaders.

    Args:
        train_dataset: dataset used for training; shuffled every epoch.
        val_dataset: dataset used for evaluation. It is not shuffled, so batch
            order is fixed.
        collate_fn: function that turns a list of samples into a batch.
        batch_size: number of samples per batch (default 32).

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                            shuffle=False, collate_fn=collate_fn)
    return train_loader, val_loader

train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn=collate_fn)  # shuffled train, ordered val

## 2 Naive RNN

### 2.1 Mask Selection 

In [9]:
def sum_embeddings_with_mask(x, masks):
    """Sum code embeddings within each visit, ignoring padded positions.

    Args:
        x: embeddings indexed as (batch, visits, codes, emb), as produced by
            the embedding lookup in ``NaiveRNN.forward``.
        masks: boolean tensor (batch, visits, codes); False marks padding.

    Returns:
        Tensor (batch, visits, emb): the per-visit sum of unmasked embeddings.
    """
    # Broadcast the mask over the embedding axis so padded rows become zero.
    return (x * masks.unsqueeze(-1)).sum(dim=2)

In [11]:
def get_last_visit(hidden_states, masks):
    """Select each patient's hidden state at their last non-empty visit.

    Args:
        hidden_states: (batch, visits, hidden) RNN outputs.
        masks: (batch, visits, codes) mask; a visit counts as real if any
            code in it is unmasked.

    Returns:
        Tensor (batch, hidden).

    Assumes real visits are packed at the front and padding at the end, as
    ``collate_fn`` produces. A patient with no real visits gets index -1,
    i.e. the final (padded) position.
    """
    visit_is_real = torch.sum(masks, dim=2) > 0
    last_idx = torch.sum(visit_is_real, dim=1) - 1
    rows = torch.arange(hidden_states.shape[0], device=hidden_states.device)
    return hidden_states[rows, last_idx, :]


### 2.2 Build NaiveRNN

In [13]:
class NaiveRNN(nn.Module):
    """GRU baseline for heart-failure prediction.

    Code embeddings are summed per visit, a GRU runs over the visit sequence,
    and a linear layer plus sigmoid maps the hidden state at each patient's
    last real visit to a probability.

    Args:
        num_codes: code vocabulary size (rows of the embedding table).
        embedding_dim: width of each code embedding (default 128).
        hidden_dim: GRU hidden size, which is also the classifier input width
            (default 128).
    """

    def __init__(self, num_codes, embedding_dim=128, hidden_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_codes, embedding_dim=embedding_dim)
        self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)
        # The fc input must match the GRU hidden size; both are tied to hidden_dim.
        self.fc = nn.Linear(in_features=hidden_dim, out_features=1)
        self.sigmoid = nn.Sigmoid()  # GRAM paper uses Softmax for activation function

    def forward(self, x, masks):
        """Return HF probabilities of shape (batch,).

        ``x`` and ``masks`` are the (batch, visits, codes) tensors produced by
        ``collate_fn``.
        """
        batch_size = x.shape[0]
        x = self.embedding(x)                     # (batch, visits, codes, emb)
        x = sum_embeddings_with_mask(x, masks)    # (batch, visits, emb)
        output, _ = self.rnn(x)                   # (batch, visits, hidden)
        true_h_n = get_last_visit(output, masks)  # (batch, hidden)
        logits = self.fc(true_h_n)                # (batch, 1)
        probs = self.sigmoid(logits)  # GRAM paper uses Softmax for activation function
        return probs.view(batch_size)
    
naive_rnn = NaiveRNN(len(types))  # vocabulary size = number of entries in `types`
naive_rnn

NaiveRNN(
  (embedding): Embedding(4512, 128)
  (rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

## 3 Model Training 

### 3.1 Loss and Optimizer

In [14]:
criterion = nn.BCELoss()  # expects probabilities, matching NaiveRNN's sigmoid output
optimizer = torch.optim.Adam(params=naive_rnn.parameters(), lr=1e-3)

### 3.2 Evaluate

In [15]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

def eval_model(model, val_loader):
    """Evaluate ``model`` on ``val_loader``.

    Predictions are thresholded at 0.5 for precision/recall/F1. The raw
    scores are used for ROC AUC.

    Returns:
        ``(precision, recall, f1, roc_auc)`` for the positive (HF) class.

    Raises:
        ValueError: if ``val_loader`` yields no batches.

    Leaves ``model`` in eval mode.
    """
    model.eval()
    scores, preds, trues = [], [], []
    # No gradients are needed for evaluation; skipping autograd saves memory and time.
    with torch.no_grad():
        for x, masks, y in val_loader:
            y_hat = model(x, masks).to('cpu')
            scores.append(y_hat)
            preds.append((y_hat > 0.5).int())
            trues.append(y.to('cpu'))

    if not scores:
        raise ValueError("val_loader yielded no batches")

    # Concatenate once at the end rather than growing tensors inside the loop.
    y_score = torch.cat(scores, dim=0)
    y_pred = torch.cat(preds, dim=0)
    y_true = torch.cat(trues, dim=0)

    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)

    return p, r, f, roc_auc

### 3.3 Training and evaluation

In [17]:
def train(model, train_loader, val_loader, n_epochs, opt=None, loss_fn=None):
    """Train ``model`` for ``n_epochs`` epochs and report validation metrics after each.

    Args:
        model: module called as ``model(x, masks)``.
        train_loader: yields ``(x, masks, y)`` training batches.
        val_loader: passed to ``eval_model`` after every epoch.
        n_epochs: number of passes over ``train_loader``.
        opt: optimizer to step. Defaults to the notebook-level ``optimizer``.
        loss_fn: loss called as ``loss_fn(y_hat, y)``. Defaults to the
            notebook-level ``criterion``.
    """
    # Fall back to the notebook globals so existing calls keep working.
    if opt is None:
        opt = optimizer
    if loss_fn is None:
        loss_fn = criterion

    for epoch in range(n_epochs):
        # eval_model switches the model to eval mode, so re-enable training mode each epoch.
        model.train()
        train_loss = 0.0
        for x, masks, y in train_loader:
            opt.zero_grad()
            y_hat = model(x, masks)
            loss = loss_fn(y_hat, y)
            loss.backward()
            opt.step()

            train_loss += loss.item()

        # Mean of per-batch losses; max() avoids division by zero on an empty loader.
        train_loss = train_loss / max(len(train_loader), 1)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f, roc_auc = eval_model(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'
              .format(epoch+1, p, r, f, roc_auc))

In [18]:
# number of epochs to train the model
n_epochs = 100
train(naive_rnn, train_loader, val_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.591764
Epoch: 1 	 Validation p: 0.41, r:0.03, f: 0.06, roc_auc: 0.72
Epoch: 2 	 Training Loss: 0.417939
Epoch: 2 	 Validation p: 0.63, r:0.20, f: 0.30, roc_auc: 0.79
Epoch: 3 	 Training Loss: 0.299142
Epoch: 3 	 Validation p: 0.72, r:0.29, f: 0.41, roc_auc: 0.83
Epoch: 4 	 Training Loss: 0.193046
Epoch: 4 	 Validation p: 0.73, r:0.44, f: 0.55, roc_auc: 0.85
Epoch: 5 	 Training Loss: 0.106736
Epoch: 5 	 Validation p: 0.71, r:0.54, f: 0.61, roc_auc: 0.86
Epoch: 6 	 Training Loss: 0.058417
Epoch: 6 	 Validation p: 0.73, r:0.58, f: 0.65, roc_auc: 0.87
Epoch: 7 	 Training Loss: 0.033947
Epoch: 7 	 Validation p: 0.72, r:0.62, f: 0.67, roc_auc: 0.87
Epoch: 8 	 Training Loss: 0.021183
Epoch: 8 	 Validation p: 0.71, r:0.63, f: 0.67, roc_auc: 0.87
Epoch: 9 	 Training Loss: 0.014437
Epoch: 9 	 Validation p: 0.73, r:0.65, f: 0.69, roc_auc: 0.88
Epoch: 10 	 Training Loss: 0.010510
Epoch: 10 	 Validation p: 0.73, r:0.66, f: 0.69, roc_auc: 0.88
Epoch: 11 	 Training Loss: 0

In [19]:
# Final validation metrics for the trained model; only ROC AUC is reported.
p, r, f, roc_auc = eval_model(model=naive_rnn, val_loader=val_loader)
print(roc_auc)

0.8890676001642535
