# RETAIN

## Overview

Previously, we tried heart failure prediction with classical machine learning models, a neural network (NN), and a recurrent neural network (RNN).

In this project, we will try a different approach. We will implement RETAIN, an RNN model with an attention mechanism, proposed by Choi et al. in the paper [RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism](https://arxiv.org/abs/1608.05745).

In [1]:
import os
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# define data path
DATA_PATH = "lib/data/"

## About Raw Data

We will perform heart failure prediction using the diagnosis codes. We will use the same dataset as before, which is synthesized from [MIMIC-III](https://mimic.physionet.org/gettingstarted/access/).

The data has been preprocessed. Let us load them and take a look.

In [3]:
pids = pickle.load(open(os.path.join(DATA_PATH,'train/pids.pkl'), 'rb'))
vids = pickle.load(open(os.path.join(DATA_PATH,'train/vids.pkl'), 'rb'))
hfs = pickle.load(open(os.path.join(DATA_PATH,'train/hfs.pkl'), 'rb'))
seqs = pickle.load(open(os.path.join(DATA_PATH,'train/seqs.pkl'), 'rb'))
types = pickle.load(open(os.path.join(DATA_PATH,'train/types.pkl'), 'rb'))
rtypes = pickle.load(open(os.path.join(DATA_PATH,'train/rtypes.pkl'), 'rb'))

assert len(pids) == len(vids) == len(hfs) == len(seqs) == 1000
assert len(types) == 619

where

- `pids`: contains the patient IDs
- `vids`: contains a list of visit IDs for each patient
- `hfs`: contains the heart failure label (0: normal, 1: heart failure) for each patient
- `seqs`: contains a list of visits for each patient, where each visit is a list of diagnosis labels (integer-encoded ICD-9 codes)
- `types`: contains the map from ICD-9 codes to ICD-9 labels
- `rtypes`: contains the map from ICD-9 labels to ICD-9 codes

Let us take a patient as an example.

In [4]:
# take the patient at index 3 (the 4th patient) as an example

print("Patient ID:", pids[3])
print("Heart Failure:", hfs[3])
print("# of visits:", len(vids[3]))
for visit in range(len(vids[3])):
    print(f"\t{visit}-th visit id:", vids[3][visit])
    print(f"\t{visit}-th visit diagnosis labels:", seqs[3][visit])
    print(f"\t{visit}-th visit diagnosis codes:", [rtypes[label] for label in seqs[3][visit]])

Patient ID: 47537
Heart Failure: 0
# of visits: 2
	0-th visit id: 0
	0-th visit diagnosis labels: [12, 103, 262, 285, 290, 292, 359, 416, 39, 225, 275, 294, 326, 267, 93]
	0-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_518', 'DIAG_560', 'DIAG_567', 'DIAG_569', 'DIAG_707', 'DIAG_785', 'DIAG_155', 'DIAG_456', 'DIAG_537', 'DIAG_571', 'DIAG_608', 'DIAG_529', 'DIAG_263']
	1-th visit id: 1
	1-th visit diagnosis labels: [12, 103, 240, 262, 290, 292, 319, 359, 510, 513, 577, 307, 8, 280, 18, 131]
	1-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_482', 'DIAG_518', 'DIAG_567', 'DIAG_569', 'DIAG_599', 'DIAG_707', 'DIAG_995', 'DIAG_998', 'DIAG_V09', 'DIAG_584', 'DIAG_031', 'DIAG_553', 'DIAG_070', 'DIAG_305']


Note that `seqs` is a list of lists of lists. That is, `seqs[i][j][k]` gives you the k-th diagnosis label for the j-th visit of the i-th patient.

You can look up the meaning of each ICD-9 code online. For example, `DIAG_276` represents *disorders of fluid, electrolyte, and acid-base balance*.
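
The nested indexing can be tried on a tiny stand-in before touching the real data. This is a minimal sketch: `toy_seqs` and `toy_rtypes` are hypothetical names, the three label/code pairs are copied from the example output above, and a `dict` is assumed for the label-to-code map (the pickled `rtypes` may be stored differently).

```python
# Toy stand-ins for the pickled objects (hypothetical names).
# The label -> code pairs are taken from the example patient printed above.
toy_rtypes = {12: "DIAG_041", 103: "DIAG_276", 262: "DIAG_518"}
toy_seqs = [
    [[12, 103], [262]],  # patient 0: two visits
    [[103]],             # patient 1: one visit
]

# seqs[i][j][k]: k-th diagnosis label of the j-th visit of the i-th patient
i, j, k = 0, 0, 1
label = toy_seqs[i][j][k]
print(label, toy_rtypes[label])

# the same nesting gives per-patient visit counts and per-visit code counts
num_visits = [len(patient) for patient in toy_seqs]
num_codes = [len(visit) for patient in toy_seqs for visit in patient]
print(num_visits, num_codes)
```

The two list comprehensions at the end are the same ones `collate_fn()` relies on to find the padding sizes.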

Further, let us see the number of heart failure patients.

In [5]:
print("number of heart failure patients:", sum(hfs))
print("ratio of heart failure patients: %.2f" % (sum(hfs) / len(hfs)))

number of heart failure patients: 548
ratio of heart failure patients: 0.55


## 1 Build the dataset

### 1.1 CustomDataset

First, let us implement a custom dataset using PyTorch class `Dataset`, which will characterize the key features of the dataset we want to generate.

We will use the sequences of diagnosis codes `seqs` as input and heart failure `hfs` as output.

In [7]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    
    def __init__(self, seqs, hfs):
        self.x = seqs
        self.y = hfs
    
    def __len__(self):

        return len(self.y)      
    
    def __getitem__(self, index):

        x = self.x[index]
        y = self.y[index]
        return x, y        

dataset = CustomDataset(seqs, hfs)

### 1.2 Collate Function

We do not convert the data to tensor in the built `CustomDataset`. Instead, we will do this using a collate function `collate_fn()`. 

This collate function `collate_fn()` will be called by `DataLoader` after fetching a list of samples using the indices from `CustomDataset` to collate the list of samples into batches.

For example, assume the `DataLoader` gets a list of two samples.

```
[ [ [0, 1, 2], [8, 0] ], 
  [ [12, 13, 6, 7], [12], [23, 11] ] ]
```

where the first sample has two visits `[0, 1, 2]` and `[8, 0]` and the second sample has three visits `[12, 13, 6, 7]`, `[12]`, and `[23, 11]`.

The collate function `collate_fn()` is supposed to pad them into the same shape (3, 4), where 3 is the maximum number of visits and 4 is the maximum number of diagnosis codes.

```
[ [ [0, 1, 2, *0*], [8, 0, *0*, *0*], [*0*, *0*, *0*, *0*]  ], 
  [ [12, 13, 6, 7], [12, *0*, *0*, *0*], [23, 11, *0*, *0*] ] ]
```

Further, the padding information will be stored in a mask with the same shape, where 1 indicates that the diagnosis code at this position is from the original input, and 0 indicates that the diagnosis code at this position is the padded value.

```
[ [ [1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0] ], 
  [ [1, 1, 1, 1], [1, 0, 0, 0], [1, 1, 0, 0] ] ]
```

Lastly, we will have another diagnosis sequence in reversed time. This will be used in our RNN model for masking. Note that we only flip the true visits.

```
[ [ [8, 0, *0*, *0*], [0, 1, 2, *0*], [*0*, *0*, *0*, *0*]  ], 
  [ [23, 11, *0*, *0*], [12, *0*, *0*, *0*], [12, 13, 6, 7] ] ]
```

And a reversed mask as well.

```
[ [ [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0] ], 
  [ [1, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1] ] ]
```

We need to pad the sequences to the same length so that we can do batch training on a GPU. We also need the mask so that, during training, we can ignore the padded values, as they do not contain any information.
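
To make the padding concrete, the toy batch above can be reproduced with nested lists before moving to tensors. This is a minimal sketch, not the `collate_fn()` defined below: the helper name `pad_batch` is hypothetical, and plain Python lists stand in for `torch` tensors.

```python
def pad_batch(sequences, pad_value=0):
    """Pad a batch of patients (lists of visits) and build the masks.

    `pad_batch` is a hypothetical helper used for illustration only.
    """
    max_visits = max(len(patient) for patient in sequences)
    max_codes = max(len(visit) for patient in sequences for visit in patient)

    def empty(fill):
        # a fresh (patients, max_visits, max_codes) nested list filled with `fill`
        return [[[fill] * max_codes for _ in range(max_visits)] for _ in sequences]

    x, masks = empty(pad_value), empty(0)
    rev_x, rev_masks = empty(pad_value), empty(0)
    for i, patient in enumerate(sequences):
        n = len(patient)  # number of true visits for this patient
        for j, visit in enumerate(patient):
            for k, code in enumerate(visit):
                x[i][j][k] = code
                masks[i][j][k] = 1
                # only the true visits are flipped; padding visits stay at the end
                rev_x[i][n - j - 1][k] = code
                rev_masks[i][n - j - 1][k] = 1
    return x, masks, rev_x, rev_masks


batch = [[[0, 1, 2], [8, 0]],
         [[12, 13, 6, 7], [12], [23, 11]]]
x, masks, rev_x, rev_masks = pad_batch(batch)
print(x[0])      # [[0, 1, 2, 0], [8, 0, 0, 0], [0, 0, 0, 0]]
print(rev_x[0])  # [[8, 0, 0, 0], [0, 1, 2, 0], [0, 0, 0, 0]]
```

Note that the label `0` is also the padding value (the visit `[8, 0]` contains a real `0`), so the padded tensor alone cannot tell real codes from padding. That is why the mask is carried alongside it.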

In [9]:
def collate_fn(data):
    """
    Collate the list of samples into batches. For each patient, we pad the diagnosis
    sequences to the same shape (max # visits, max # diagnosis codes). The padding information
    is stored in `masks`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This is the input to our RNN model
        rev_masks: same as masks but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patients) of type torch.float
    """

    sequences, labels = zip(*data)

    y = torch.tensor(labels, dtype=torch.float)
    
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            n = len(patient)
            for k_code, code in enumerate(visit):
                x[i_patient, j_visit, k_code] = code
                masks[i_patient, j_visit, k_code] = 1
                rev_x[i_patient, n-j_visit-1, k_code] = code
                rev_masks[i_patient, n-j_visit-1, k_code] = 1            
    
    return x, masks, rev_x, rev_masks, y

In [10]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)

assert x.dtype == rev_x.dtype == torch.long
assert y.dtype == torch.float
assert masks.dtype == rev_masks.dtype == torch.bool

assert x.shape == rev_x.shape == masks.shape == rev_masks.shape == (10, 3, 24)
assert y.shape == (10,)

Now we have `CustomDataset` and `collate_fn()`. Let us split the dataset into training and validation sets.

In [11]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 800
Length of val dataset: 200


### 1.3 DataLoader

Now, we can load the dataset into the data loader.

In [12]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    '''
    Return the data loaders for the train and validation datasets.
    Arguments:
        train_dataset: train dataset of type `CustomDataset`
        val_dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    '''
    
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

## 2 RETAIN

RETAIN is essentially an RNN model with an attention mechanism.
 
The idea of attention is quite simple: it boils down to weighted averaging. Let us consider machine translation in class as an example. When generating a translation of a source text, we first pass the source text through an encoder (an LSTM or an equivalent model) to obtain a sequence of encoder hidden states $\boldsymbol{h}_1, \dots, \boldsymbol{h}_T$. Then, at each step of generating a translation (decoding), we selectively attend to these encoder hidden states, that is, we construct a context vector $\boldsymbol{c}_i$ that is a weighted average of encoder hidden states.

$$\boldsymbol{c}_i = \sum_{j} a_{ij}\boldsymbol{h}_j$$

We choose the weights $a_{ij}$ based both on encoder hidden states $\boldsymbol{h}_1, \dots, \boldsymbol{h}_T$ and decoder hidden states $\boldsymbol{s}_1, \dots, \boldsymbol{s}_T$ and normalize them so that they encode a categorical probability distribution $p(\boldsymbol{h}_j | \boldsymbol{s}_i)$.

$$\boldsymbol{a}_{i} = \text{Softmax}\left( a(\boldsymbol{s}_i, \boldsymbol{h}_j) \right)$$
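
The weighted-averaging idea can be checked by hand with a few numbers. The following is a minimal sketch in plain Python; the hidden states and alignment scores are made-up values, and `softmax` and `context_vector` are hypothetical helper names rather than part of the notebook.

```python
import math


def softmax(scores):
    """Normalize raw scores into non-negative weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def context_vector(weights, hidden_states):
    """Weighted average of hidden states: c = sum_j a_j * h_j."""
    dim = len(hidden_states[0])
    return [sum(w * h[d] for w, h in zip(weights, hidden_states)) for d in range(dim)]


# three toy encoder hidden states of dimension 2 (made-up numbers)
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# hypothetical alignment scores a(s_i, h_j) for a single decoding step i
scores = [2.0, 0.0, 0.0]

a = softmax(scores)
c = context_vector(a, h)
print([round(w, 3) for w in a])
print([round(v, 3) for v in c])
```

Because the first score is the largest, the first hidden state receives the largest weight and dominates the context vector.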

RETAIN has two different attention mechanisms. 
- One helps identify the important visits. This attention $\alpha_i$, which is a scalar for the i-th visit, tells you the importance of the i-th visit.
- The other is a similar attention mechanism, but its attention weight $\mathbf{\beta}_i$ is a vector. It gives us a more detailed view of the underlying cause of the prediction, that is, which features within a visit are important.

<img src=./img/retain-1.png>

Unfolded view of RETAIN’s architecture: Given input sequence $\mathbf{x}_1, \dots, \mathbf{x}_i$, we predict the label $\mathbf{y}_i$.
- Step 1: Embedding, 
- Step 2: generating $\alpha$ values using RNN-$\alpha$, 
- Step 3: generating $\mathbf{\beta}$ values using RNN-$\beta$, 
- Step 4: Generating the context vector using attention and representation vectors, 
- Step 5: Making prediction. 

Note that in Steps 2 and 3 we use RNN in the reversed time.
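
For reference, the five steps can be written out as below. This is a sketch in the notation of Choi et al., not a transcription of the figures: $\mathbf{W}_{emb}$, $\mathbf{w}_{\alpha}$, $b_{\alpha}$, $\mathbf{W}_{\beta}$, $\mathbf{b}_{\beta}$, $\mathbf{w}$, and $b$ denote learnable parameters, and the sigmoid $\sigma$ in Step 5 is an assumption that matches the binary heart failure label used in this notebook (the paper states the output as a softmax over classes).

$$
\begin{aligned}
\text{Step 1:}\quad & \mathbf{v}_j = \mathbf{W}_{emb}\,\mathbf{x}_j \\
\text{Step 2:}\quad & \mathbf{g}_i, \dots, \mathbf{g}_1 = \text{RNN}_{\alpha}(\mathbf{v}_i, \dots, \mathbf{v}_1) \\
& e_j = \mathbf{w}_{\alpha}^{\top}\mathbf{g}_j + b_{\alpha}, \quad j = 1, \dots, i \\
& \alpha_1, \dots, \alpha_i = \text{Softmax}(e_1, \dots, e_i) \\
\text{Step 3:}\quad & \mathbf{h}_i, \dots, \mathbf{h}_1 = \text{RNN}_{\beta}(\mathbf{v}_i, \dots, \mathbf{v}_1) \\
& \boldsymbol{\beta}_j = \tanh(\mathbf{W}_{\beta}\mathbf{h}_j + \mathbf{b}_{\beta}), \quad j = 1, \dots, i \\
\text{Step 4:}\quad & \mathbf{c}_i = \sum_{j=1}^{i} \alpha_j\, \boldsymbol{\beta}_j \odot \mathbf{v}_j \\
\text{Step 5:}\quad & \hat{y}_i = \sigma(\mathbf{w}^{\top}\mathbf{c}_i + b)
\end{aligned}
$$

When $\mathbf{x}_j$ is a multi-hot vector over diagnosis codes, $\mathbf{W}_{emb}\,\mathbf{x}_j$ is the sum of the embeddings of the codes in visit $j$, which is what the masked embedding sum in the implementation computes.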

<img src=./img/retain-2.png>

<img src=./img/retain-3.png>

Let us first implement RETAIN step-by-step.

### 2.1 Step 2: AlphaAttention

Implement the alpha attention in the second equation of step 2.

In [35]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        """
        Arguments:
            g: the output tensor from RNN-alpha of shape (batch_size, seq_length, hidden_dim) 
        
        Outputs:
            alpha: the corresponding attention weights of shape (batch_size, seq_length, 1)
        """

        alpha = torch.softmax(self.a_att(g), dim=1)
        return alpha

### 2.2 Step 3: BetaAttention

Implement the beta attention in the second equation of step 3.

In [37]:
class BetaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        """
        Arguments:
            hidden_dim: the hidden dimension
        """
        super().__init__()
        self.b_att = nn.Linear(hidden_dim, hidden_dim)


    def forward(self, h):
        """
        Arguments:
            h: the output tensor from RNN-beta of shape (batch_size, seq_length, hidden_dim)         
        Outputs:
            beta: the corresponding attention weights of shape (batch_size, seq_length, hidden_dim)
        """
        beta = torch.tanh(self.b_att(h))
        return beta

### 2.3 Attention Sum

Implement the sum of attention in step 4.

In [85]:
def attention_sum(alpha, beta, rev_v, rev_masks):
    """
    Mask select the visit embeddings for true visits (not padding visits), weight
        them by the attention weights, and then sum them up.

    Arguments:
        alpha: the alpha attention weights of shape (batch_size, seq_length, 1)
        beta: the beta attention weights of shape (batch_size, seq_length, hidden_dim)
        rev_v: the visit embeddings in reversed time of shape (batch_size, # visits, embedding_dim)
        rev_masks: the padding masks in reversed time of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        c: the context vector of shape (batch_size, hidden_dim)
    """

    rev_masks = torch.sum(rev_masks, dim=-1).unsqueeze(dim=-1) > 0 
    rev_v = rev_v*rev_masks
    c = alpha*beta*rev_v
    return torch.sum(c, dim = -2)
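
The arithmetic inside `attention_sum()` can be followed for a single patient with plain lists. This is a minimal sketch with made-up numbers; `masked_context` is a hypothetical name, and the per-visit mask is assumed to be already reduced to one boolean per visit.

```python
def masked_context(alpha, beta, v, visit_mask):
    """c = sum_j alpha_j * (beta_j * v_j), taken over true visits only."""
    dim = len(v[0])
    c = [0.0] * dim
    for a_j, b_j, v_j, keep in zip(alpha, beta, v, visit_mask):
        if not keep:  # padding visit: contributes nothing to the sum
            continue
        for d in range(dim):
            c[d] += a_j * b_j[d] * v_j[d]  # element-wise product, scaled by alpha_j
    return c


# one patient, three visit slots (the last one is padding), embedding size 2
alpha = [0.5, 0.3, 0.2]                       # one scalar weight per visit
beta = [[1.0, -1.0], [0.5, 0.5], [1.0, 1.0]]  # one weight vector per visit
v = [[2.0, 4.0], [1.0, 1.0], [9.0, 9.0]]      # visit embeddings
visit_mask = [True, True, False]

c = masked_context(alpha, beta, v, visit_mask)
print(c)
```

One thing this toy case makes visible: the padded slot still holds an alpha weight (0.2 here). In the notebook the softmax in `AlphaAttention` likewise runs over all visit slots, so for padded patients the weights of the true visits do not sum to one. Masking the scores before the softmax would be one way to change that, if desired.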

### 2.4 Build RETAIN

Now, we can build the RETAIN model.

In [44]:
def sum_embeddings_with_mask(x, masks):
    """
    Mask select the embeddings for true visits (not padding visits) and then sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
    """
    
    x = x * masks.unsqueeze(-1)
    x = torch.sum(x, dim = -2)
    return x

In [45]:
class RETAIN(nn.Module):
    
    def __init__(self, num_codes, embedding_dim=128):
        super().__init__()
        # Define the embedding layer using `nn.Embedding`; `embedding_dim` defaults to 128.
        self.embedding = nn.Embedding(num_codes, embedding_dim)
        # Define the RNN-alpha using `nn.GRU()`; set `hidden_size` to `embedding_dim` and `batch_first` to True.
        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        # Define the RNN-beta using `nn.GRU()`; set `hidden_size` to `embedding_dim` and `batch_first` to True.
        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        # Define the alpha-attention using `AlphaAttention()`.
        self.att_a = AlphaAttention(embedding_dim)
        # Define the beta-attention using `BetaAttention()`.
        self.att_b = BetaAttention(embedding_dim)
        # Define the linear layer using `nn.Linear()`.
        self.fc = nn.Linear(embedding_dim, 1)
        # Define the final activation layer using `nn.Sigmoid()`.
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, masks, rev_x, rev_masks):
        """
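        Arguments:
            x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes); not used by this model
            masks: the padding masks of shape (batch_size, # visits, # diagnosis codes); not used by this model
            rev_x: the diagnosis sequence in reversed time of shape (batch_size, # visits, # diagnosis codes)
            rev_masks: the padding masks in reversed time of shape (batch_size, # visits, # diagnosis codes)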

        Outputs:
            probs: probabilities of shape (batch_size)
        """
        # 1. Pass the reversed sequence through the embedding layer;
        rev_x = self.embedding(rev_x)
        # 2. Sum the reversed embeddings for each diagnosis code up for a visit of a patient.
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        # 3. Pass the reversed embeddings through the RNN-alpha and RNN-beta layers separately;
        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)
        # 4. Obtain the alpha and beta attentions using `AlphaAttention()` and `BetaAttention()`;
        alpha = self.att_a(g)
        beta = self.att_b(h)
        # 5. Sum the attention up using `attention_sum()`;
        c = attention_sum(alpha, beta, rev_x, rev_masks)
        # 6. Pass the context vector through the linear and activation layers.
        logits = self.fc(c)
        probs = self.sigmoid(logits)
        # squeeze only the last dimension so that a batch of size 1 keeps shape (1,)
        return probs.squeeze(-1)
    

# load the model here
retain = RETAIN(num_codes = len(types))
retain

RETAIN(
  (embedding): Embedding(619, 128)
  (rnn_a): GRU(128, 128, batch_first=True)
  (rnn_b): GRU(128, 128, batch_first=True)
  (att_a): AlphaAttention(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (att_b): BetaAttention(
    (b_att): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

### Go through the steps

In [14]:
# get example data
for x, masks, rev_x, rev_masks, y in train_loader:
    break

In [25]:
loader_iter = iter(train_loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)
print(x.shape)

torch.Size([32, 6, 27])


In [27]:
num_codes=len(types) #619
embedding_dim=128
embedding = nn.Embedding(num_codes, embedding_dim)
print(rev_x.shape)
rev_x = embedding(rev_x)
print(rev_x.shape)
rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
print(rev_x.shape)

torch.Size([32, 6, 27])
torch.Size([32, 6, 27, 128])
torch.Size([32, 6, 128])


In [28]:
rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
g, _ = rnn_a(rev_x)
h, _ = rnn_b(rev_x)
print(g.shape)
print(h.shape)

torch.Size([32, 6, 128])
torch.Size([32, 6, 128])


In [39]:
att_a = AlphaAttention(embedding_dim)
att_b = BetaAttention(embedding_dim)
alpha = att_a(g)
beta = att_b(h)
print(alpha.shape)
print(beta.shape)

torch.Size([32, 6, 1])
torch.Size([32, 6, 128])


In [62]:
c = attention_sum(alpha, beta, rev_x, rev_masks)
c

tensor([[-0.7792, -0.2554, -0.0401,  ..., -1.7485,  0.2221, -0.2301],
        [ 0.0572, 14.0182,  3.5569,  ..., 11.2421,  8.2220, -3.2425],
        [ 0.5128,  0.0239,  0.1417,  ...,  0.4237, -0.1517, -0.0331],
        ...,
        [ 2.4871,  2.6028,  0.9678,  ..., -1.9011, -5.8530, -2.8644],
        [-0.2177, -1.8827, -1.1202,  ..., -1.6616, -3.6643,  0.3459],
        [ 3.1906, -6.0042,  3.5691,  ...,  0.8539,  2.6944, -1.5251]],
       grad_fn=<SumBackward1>)

In [63]:
c.shape

torch.Size([32, 128])

In [65]:
fc = nn.Linear(embedding_dim, 1)
logits = fc(c)
print(logits.shape)

torch.Size([32, 1])


In [67]:
logits

tensor([[  0.2872],
        [  3.0242],
        [  0.2081],
        [  1.8926],
        [ -2.0172],
        [  1.0088],
        [  0.2704],
        [  0.6865],
        [ -0.8555],
        [ -0.6037],
        [  0.0144],
        [ -0.2468],
        [ -1.7344],
        [ -3.2518],
        [  0.7698],
        [ -0.7277],
        [  1.3738],
        [ -1.3513],
        [-14.2405],
        [  0.2334],
        [ -0.1596],
        [  1.8194],
        [  0.3618],
        [ -0.4476],
        [ -1.4095],
        [ -1.3761],
        [ -0.7304],
        [  0.5540],
        [ -2.6352],
        [  2.9106],
        [ -1.8770],
        [  1.6818]], grad_fn=<AddmmBackward0>)

In [66]:
sigmoid = nn.Sigmoid()
probs = sigmoid(logits)
print(probs.shape)

torch.Size([32, 1])


In [68]:
probs

tensor([[5.7131e-01],
        [9.5365e-01],
        [5.5183e-01],
        [8.6905e-01],
        [1.1741e-01],
        [7.3279e-01],
        [5.6719e-01],
        [6.6518e-01],
        [2.9828e-01],
        [3.5350e-01],
        [5.0361e-01],
        [4.3860e-01],
        [1.5003e-01],
        [3.7262e-02],
        [6.8348e-01],
        [3.2571e-01],
        [7.9799e-01],
        [2.0567e-01],
        [6.5379e-07],
        [5.5809e-01],
        [4.6019e-01],
        [8.6050e-01],
        [5.8947e-01],
        [3.8992e-01],
        [1.9632e-01],
        [2.0164e-01],
        [3.2510e-01],
        [6.3507e-01],
        [6.6908e-02],
        [9.4837e-01],
        [1.3273e-01],
        [8.4314e-01]], grad_fn=<SigmoidBackward0>)

In [69]:
probs.squeeze()

tensor([5.7131e-01, 9.5365e-01, 5.5183e-01, 8.6905e-01, 1.1741e-01, 7.3279e-01,
        5.6719e-01, 6.6518e-01, 2.9828e-01, 3.5350e-01, 5.0361e-01, 4.3860e-01,
        1.5003e-01, 3.7262e-02, 6.8348e-01, 3.2571e-01, 7.9799e-01, 2.0567e-01,
        6.5379e-07, 5.5809e-01, 4.6019e-01, 8.6050e-01, 5.8947e-01, 3.8992e-01,
        1.9632e-01, 2.0164e-01, 3.2510e-01, 6.3507e-01, 6.6908e-02, 9.4837e-01,
        1.3273e-01, 8.4314e-01], grad_fn=<SqueezeBackward0>)

## 3 Training and Inferencing

First, let us implement the `eval()` function.

In [47]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

def eval(model, val_loader):
    
    """
    Evaluate the model.
    
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
        
    REFERENCE: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
    """
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    # gradients are not needed during evaluation
    with torch.no_grad():
        for x, masks, rev_x, rev_masks, y in val_loader:
            # the model already applies a sigmoid, so its output is a probability
            y_prob = model(x, masks, rev_x, rev_masks)
            y_hat = (y_prob > 0.5).long()
            y_score = torch.cat((y_score, y_prob.detach().to('cpu')), dim=0)
            y_pred = torch.cat((y_pred, y_hat.detach().to('cpu')), dim=0)
            y_true = torch.cat((y_true, y.detach().to('cpu').long()), dim=0)
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)
    return p, r, f, roc_auc
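
To see what the thresholded metrics measure, they can be computed by hand on a handful of predictions. This is a minimal sketch with made-up labels and probabilities; `binary_metrics` is a hypothetical helper, not the scikit-learn function used above, and returning 0.0 for an empty denominator is an assumption of this sketch.

```python
def binary_metrics(y_true, y_prob, threshold=0.5):
    """Precision, recall, and F1 for the positive class at a fixed threshold."""
    y_pred = [1 if p > threshold else 0 for p in y_prob]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    # fall back to 0.0 when a denominator is empty (assumption of this sketch)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# made-up labels and predicted probabilities
y_true = [1, 0, 1, 1, 0]
y_prob = [0.9, 0.6, 0.4, 0.8, 0.1]

precision, recall, f1 = binary_metrics(y_true, y_prob)
print(precision, recall, f1)
```

Here two of the three predicted positives are correct and two of the three true positives are found, so all three scores equal 2/3. ROC AUC differs in that it is computed from the raw scores rather than the thresholded predictions, which is why `eval()` keeps `y_score` separately.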

Now let us implement the `train()` function. Note that `train()` should call `eval()` at the end of each training epoch to see the results on the validation dataset.

In [56]:
def train(model, train_loader, val_loader, n_epochs):
    """
    Train the model.
    
    Arguments:
        model: the RNN model
        train_loader: training dataloader
        val_loader: validation dataloader
        n_epochs: total number of epochs
    """
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x, masks, rev_x, rev_masks, y in train_loader:
            optimizer.zero_grad()
            y_hat = model(x, masks, rev_x, rev_masks)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        
        p, r, f, roc_auc = eval(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'.format(epoch+1, p, r, f, roc_auc))
        
    return round(roc_auc, 2)

In [86]:
# load the model
retain = RETAIN(num_codes = len(types))

# load the loss function
criterion = nn.BCELoss()

# load the optimizer
optimizer = torch.optim.Adam(retain.parameters(), lr=1e-3)

n_epochs = 5
train(retain, train_loader, val_loader, n_epochs)

Epoch: 1 	 Training Loss: 12.402614
Epoch: 1 	 Validation p: 0.73, r:0.74, f: 0.73, roc_auc: 0.79
Epoch: 2 	 Training Loss: 6.872137
Epoch: 2 	 Validation p: 0.71, r:0.66, f: 0.68, roc_auc: 0.72
Epoch: 3 	 Training Loss: 3.879046
Epoch: 3 	 Validation p: 0.73, r:0.57, f: 0.64, roc_auc: 0.69
Epoch: 4 	 Training Loss: 2.832272
Epoch: 4 	 Validation p: 0.76, r:0.66, f: 0.71, roc_auc: 0.75
Epoch: 5 	 Training Loss: 2.150762
Epoch: 5 	 Validation p: 0.73, r:0.64, f: 0.68, roc_auc: 0.76


0.76
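
The loss reported above is binary cross-entropy, the mean of $-[y \log p + (1 - y)\log(1 - p)]$ over the batch. The following is a minimal sketch of that formula in plain Python; `bce` is a hypothetical helper, and the floor of -100 on the log terms mirrors the clamping documented for `nn.BCELoss`.

```python
import math


def bce(y_true, y_prob, log_floor=-100.0):
    """Mean binary cross-entropy with the log terms clamped at `log_floor`."""
    def safe_log(p):
        # log(0) is undefined, so fall back to the floor value
        return max(math.log(p), log_floor) if p > 0 else log_floor

    losses = [-(y * safe_log(p) + (1 - y) * safe_log(1 - p))
              for y, p in zip(y_true, y_prob)]
    return sum(losses) / len(losses)


# confident and correct predictions give a small loss
print(bce([1, 0], [0.9, 0.1]))
# a saturated, wrong prediction contributes the maximum of 100 for that sample
print(bce([1], [0.0]))
```

A few saturated and wrong probabilities in a batch are enough to push the mean well above 1, which is one plausible reading of the large training losses in the first epochs.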

## 4 Sensitivity analysis

We will train the same model with different hyperparameters. We will use 0.1 and 0.001 for the learning rate, and 8 and 128 for the embedding dimension. This shows how model performance varies with the learning rate and the embedding dimension.

In [88]:
lr_hyperparameter = [1e-1, 1e-3]
embedding_dim_hyperparameter = [8, 128]
n_epochs = 5
results = {}

for lr in lr_hyperparameter:
    for embedding_dim in embedding_dim_hyperparameter:
        print ('='*50)
        print ({'learning rate': lr, "embedding_dim": embedding_dim})
        print ('-'*50)

        retain = RETAIN(num_codes = len(types), embedding_dim=embedding_dim)
        criterion = nn.BCELoss()
        optimizer = torch.optim.Adam(retain.parameters(), lr=lr) 

        roc_auc = train(retain, train_loader, val_loader, n_epochs)
        results['lr:{},emb:{}'.format(str(lr), str(embedding_dim))] =  roc_auc

{'learning rate': 0.1, 'embedding_dim': 8}
--------------------------------------------------
Epoch: 1 	 Training Loss: 9.507623
Epoch: 1 	 Validation p: 0.65, r:0.68, f: 0.67, roc_auc: 0.66
Epoch: 2 	 Training Loss: 9.543952
Epoch: 2 	 Validation p: 0.67, r:0.91, f: 0.77, roc_auc: 0.79
Epoch: 3 	 Training Loss: 9.473220
Epoch: 3 	 Validation p: 0.61, r:0.80, f: 0.69, roc_auc: 0.69
Epoch: 4 	 Training Loss: 15.100709
Epoch: 4 	 Validation p: 0.72, r:0.77, f: 0.74, roc_auc: 0.78
Epoch: 5 	 Training Loss: 16.169873
Epoch: 5 	 Validation p: 0.66, r:0.88, f: 0.76, roc_auc: 0.75
{'learning rate': 0.1, 'embedding_dim': 128}
--------------------------------------------------
Epoch: 1 	 Training Loss: 34.746652
Epoch: 1 	 Validation p: 0.55, r:1.00, f: 0.71, roc_auc: 0.58
Epoch: 2 	 Training Loss: 37.935333
Epoch: 2 	 Validation p: 0.57, r:0.98, f: 0.72, roc_auc: 0.60
Epoch: 3 	 Training Loss: 36.611856
Epoch: 3 	 Validation p: 0.55, r:0.98, f: 0.70, roc_auc: 0.56
Epoch: 4 	 Training Loss: 39.