In [46]:
import os
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

### Reading Preprocessed Data

In [47]:
# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# define data path
DATA_PATH = "/Users/shunfan/Developer/MIMIC/output2/"

In [48]:
labels = pickle.load(open(os.path.join(DATA_PATH,'MIMICIIIPROCESSED.morts'), 'rb'))
icd_seqs = pickle.load(open(os.path.join(DATA_PATH,'MIMICIIIPROCESSED.3digitICD9.seqs'), 'rb'))

In [49]:
print("number of passed patients:", sum(labels))
print("ratio of passed patients: %.2f" % (sum(labels) / len(labels)))

number of passed patients: 2825
ratio of passed patients: 0.37


### Load ICD

In [50]:
print(len(icd_seqs))
print(icd_seqs[0])

7537
[[0, 1, 2, 3, 4, 5, 6, 7], [1, 4, 6, 8, 7, 9, 10, 11]]


### Build the Dataset

1. CustomDataset

In [51]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    
    def __init__(self, input_seqs, labels):
        self.x = input_seqs
        self.y = labels
    
    def __len__(self):
        
        """
        Return the number of samples (i.e. patients).
        """
        return len(self.y)
    
    def __getitem__(self, index):
        
        """
        Generates one sample of data.
        """
        return self.x[index], self.y[index]
        

dataset = CustomDataset(icd_seqs, labels)
print(len(dataset))

7537


2. Collate Function

In [52]:
def collate_fn(data):
    """
    Collate the the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation
        is stored in `mask`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patiens) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of mortality labels
        using: `sequences, labels = zip(*data)`
    """

    sequences, labels = zip(*data)

    y = torch.tensor(labels, dtype=torch.float)
    
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            """
            update `x`, `rev_x`, `masks`, and `rev_masks`
            """
            num_code_pad = max_num_codes - len(visit)
            x[i_patient][j_visit] = torch.tensor(visit + [0]*num_code_pad)
            masks[i_patient][j_visit] = torch.tensor([1]*len(visit) + [0]*num_code_pad)
        
        num_true_visit = len(patient)
        rev_x[i_patient][:num_true_visit] = torch.tensor(x[i_patient][:num_true_visit].tolist()[::-1])
        rev_masks[i_patient][:num_true_visit] = torch.tensor(masks[i_patient][:num_true_visit].tolist()[::-1])

    return x, masks, rev_x, rev_masks, y

In [53]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)

In [54]:
print(x.shape, rev_x.shape)
print(masks.shape, rev_masks.shape)
print(y.shape)

torch.Size([10, 5, 21]) torch.Size([10, 5, 21])
torch.Size([10, 5, 21]) torch.Size([10, 5, 21])
torch.Size([10])


3. Split dataset into train and test

In [55]:
from torch.utils.data.dataset import random_split

train_size= int(len(dataset)*0.8)

lengths = [train_size, len(dataset)-train_size]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 6029
Length of val dataset: 1508


4. Data loader

In [56]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    '''
    Implement this function to return the data loader for  train, validation and test dataset. 
    Set batchsize to 32. Set `shuffle=True` only for train dataloader.
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        test dataset: test dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader, test_loader: train, validation and test dataloaders
    
    Note that you need to pass the collate function to the data loader `collate_fn()`.
    '''
    
    batch_size = 32
   
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn = collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn = collate_fn)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [57]:
print(len(train_loader), len(val_loader))

189 48


### RETAIN 

1. Alpha Attention

In [58]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        """
        Define the linear layer `self.a_att` for alpha-attention using `nn.Linear()`;
        
        Arguments:
            hidden_dim: the hidden dimension
        """
        
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        """
        Implement the alpha attention.
        
        Arguments:
            g: the output tensor from RNN-alpha of shape (batch_size, seq_length, hidden_dim) 
        
        Outputs:
            alpha: the corresponding attention weights of shape (batch_size, seq_length, 1)
            
        HINT: consider `torch.softmax`
        """
        e = self.a_att(g)
        alpha = torch.softmax(e, dim = 1)
        return alpha

2. beta attation

In [59]:
class BetaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        """
        Define the linear layer `self.b_att` for beta-attention using `nn.Linear()`;
        
        Arguments:
            hidden_dim: the hidden dimension
        """
        
        self.b_att = nn.Linear(hidden_dim, hidden_dim)


    def forward(self, h):
        """
        Implement the beta attention.
        
        Arguments:
            h: the output tensor from RNN-beta of shape (batch_size, seq_length, hidden_dim) 
        
        Outputs:
            beta: the corresponding attention weights of shape (batch_size, seq_length, hidden_dim)
            
        HINT: consider `torch.tanh`
        """
        f = self.b_att(h)
        beta = torch.tanh(f)
        return beta

3. Attention Sum

In [60]:
def attention_sum(alpha, beta, rev_v, rev_masks):
    """
        mask select the hidden states for true visits (not padding visits) and then
        sum the them up.

    Arguments:
        alpha: the alpha attention weights of shape (batch_size, seq_length, 1)
        beta: the beta attention weights of shape (batch_size, seq_length, hidden_dim)
        rev_v: the visit embeddings in reversed time of shape (batch_size, # visits, embedding_dim)
        rev_masks: the padding masks in reversed time of shape (# visits, batch_size, # diagnosis codes)

    Outputs:
        c: the context vector of shape (batch_size, hidden_dim)
        
    NOTE: Do NOT use for loop.
    """
    
    b, v, e = rev_v.detach().numpy().shape
    rev_masks_true = rev_masks.sum(dim=2)>0
    rev_masks_true = rev_masks_true.reshape(b,v,1)
    rev_v_padded = rev_v * rev_masks_true
    c = torch.sum(alpha * beta * rev_v_padded, dim = 1)
    return c

4. Build Model

In [61]:
def sum_embeddings_with_mask(x, masks):
    """
    Mask select the embeddings for true visits (not padding visits) and then sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
    """
    
    x = x * masks.unsqueeze(-1)
    x = torch.sum(x, dim = -2)
    return x

In [62]:
class RETAIN(nn.Module):
    
    def __init__(self, num_codes, embedding_dim=128):
        super().__init__()
        # Define the embedding layer using `nn.Embedding`. Set `embDimSize` to 128.
        self.embedding = nn.Embedding(num_codes, embedding_dim)
        # Define the RNN-alpha using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        # Define the RNN-beta using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        # Define the alpha-attention using `AlphaAttention()`;
        self.att_a = AlphaAttention(embedding_dim)
        # Define the beta-attention using `BetaAttention()`;
        self.att_b = BetaAttention(embedding_dim)
        # Define the linear layers using `nn.Linear()`;
        self.fc = nn.Linear(embedding_dim, 1)
        # Define the final activation layer using `nn.Sigmoid().
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x, masks, rev_x, rev_masks):
        '''
        Arguments:
            rev_x: the diagnosis sequence in reversed time of shape (# visits, batch_size, # diagnosis codes)
            rev_masks: the padding masks in reversed time of shape (# visits, batch_size, # diagnosis codes)

        Outputs:
            probs: probabilities of shape (batch_size)
        '''
        # 1. Pass the reversed sequence through the embedding layer;
        rev_x = self.embedding(rev_x)
        # 2. Sum the reversed embeddings for each diagnosis code up for a visit of a patient.
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        # 3. Pass the reversed embegginds through the RNN-alpha and RNN-beta layer separately;
        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)
        # 4. Obtain the alpha and beta attentions using `AlphaAttention()` and `BetaAttention()`;
        alpha = self.att_a(g)
        beta = self.att_b(h)
        # 5. Sum the attention up using `attention_sum()`;
        c = attention_sum(alpha, beta, rev_x, rev_masks)
        # 6. Pass the context vector through the linear and activation layers.
        logits = self.fc(c)
        probs = self.sigmoid(logits)
        return probs.squeeze()
    

# load the model here
retain = RETAIN(num_codes = len(icd_seqs))
retain

RETAIN(
  (embedding): Embedding(7537, 128)
  (rnn_a): GRU(128, 128, batch_first=True)
  (rnn_b): GRU(128, 128, batch_first=True)
  (att_a): AlphaAttention(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (att_b): BetaAttention(
    (b_att): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

### Model Training

In [63]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval(model, val_loader):
    
    """
    Evaluate the model.
    
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
        
    REFERENCE: checkout https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
    """
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    model.eval()
    for x, masks, rev_x, rev_masks, y in val_loader:
        y_logit = model(x, masks, rev_x, rev_masks)
        """
        obtain the predicted class (0, 1) by comparing y_logit against 0.5, 
        assign the predicted class to y_hat.
        """
        y_hat = None
        # your code here
        y_hat = (y_logit > 0.5).int()
#         raise NotImplementedError
        y_score = torch.cat((y_score,  y_logit.detach().to('cpu')), dim=0)
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)
    return p, r, f, roc_auc

In [64]:
def train(model, train_loader, val_loader, n_epochs):
    """
    Train the model.
    
    Arguments:
        model: the RNN model
        train_loader: training dataloder
        val_loader: validation dataloader
        n_epochs: total number of epochs
    """
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x, masks, rev_x, rev_masks, y in train_loader:
            optimizer.zero_grad()
            y_hat = model(x, masks, rev_x, rev_masks)
            """ 
            TODO: calculate the loss using `criterion`, save the output to loss.
            """
            loss = None
            # your code here
            loss = criterion(y_hat, y)
#             raise NotImplementedError
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f, roc_auc = eval(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'.format(epoch+1, p, r, f, roc_auc))
    return round(roc_auc, 2)

In [65]:
# load the model
retain = RETAIN(num_codes = len(icd_seqs))

# load the loss function
criterion = nn.BCELoss()
# load the optimizer
optimizer = torch.optim.Adam(retain.parameters(), lr=1e-3)

n_epochs = 10
train(retain, train_loader, val_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.607088
Epoch: 1 	 Validation p: 0.66, r:0.58, f: 0.62, roc_auc: 0.76
Epoch: 2 	 Training Loss: 0.442579
Epoch: 2 	 Validation p: 0.66, r:0.58, f: 0.62, roc_auc: 0.77
Epoch: 3 	 Training Loss: 0.274551
Epoch: 3 	 Validation p: 0.67, r:0.46, f: 0.55, roc_auc: 0.74
Epoch: 4 	 Training Loss: 0.137491
Epoch: 4 	 Validation p: 0.65, r:0.58, f: 0.61, roc_auc: 0.75
Epoch: 5 	 Training Loss: 0.056937
Epoch: 5 	 Validation p: 0.65, r:0.56, f: 0.60, roc_auc: 0.75
Epoch: 6 	 Training Loss: 0.019754
Epoch: 6 	 Validation p: 0.66, r:0.58, f: 0.62, roc_auc: 0.76
Epoch: 7 	 Training Loss: 0.008140
Epoch: 7 	 Validation p: 0.67, r:0.55, f: 0.60, roc_auc: 0.76
Epoch: 8 	 Training Loss: 0.004866
Epoch: 8 	 Validation p: 0.68, r:0.55, f: 0.61, roc_auc: 0.76
Epoch: 9 	 Training Loss: 0.003395
Epoch: 9 	 Validation p: 0.67, r:0.56, f: 0.61, roc_auc: 0.76
Epoch: 10 	 Training Loss: 0.002544
Epoch: 10 	 Validation p: 0.68, r:0.55, f: 0.61, roc_auc: 0.76


0.76

1. Alpha Attention

In [66]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        """
        Define the linear layer `self.a_att` for alpha-attention using `nn.Linear()`;
        
        Arguments:
            hidden_dim: the hidden dimension
        """
        
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        """
        Implement the alpha attention.
        
        Arguments:
            g: the output tensor from RNN-alpha of shape (batch_size, seq_length, hidden_dim) 
        
        Outputs:
            alpha: the corresponding attention weights of shape (batch_size, seq_length, 1)
            
        HINT: consider `torch.softmax`
        """
        e = self.a_att(g)
        alpha = torch.softmax(e, dim = 1)
        return alpha

2. beta attation

In [67]:
class BetaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        """
        Define the linear layer `self.b_att` for beta-attention using `nn.Linear()`;
        
        Arguments:
            hidden_dim: the hidden dimension
        """
        
        self.b_att = nn.Linear(hidden_dim, hidden_dim)


    def forward(self, h):
        """
        Implement the beta attention.
        
        Arguments:
            h: the output tensor from RNN-beta of shape (batch_size, seq_length, hidden_dim) 
        
        Outputs:
            beta: the corresponding attention weights of shape (batch_size, seq_length, hidden_dim)
            
        HINT: consider `torch.tanh`
        """
        f = self.b_att(h)
        beta = torch.tanh(f)
        return beta

3. Attention Sum

In [68]:
def attention_sum(alpha, beta, rev_v, rev_masks):
    """
        mask select the hidden states for true visits (not padding visits) and then
        sum the them up.

    Arguments:
        alpha: the alpha attention weights of shape (batch_size, seq_length, 1)
        beta: the beta attention weights of shape (batch_size, seq_length, hidden_dim)
        rev_v: the visit embeddings in reversed time of shape (batch_size, # visits, embedding_dim)
        rev_masks: the padding masks in reversed time of shape (# visits, batch_size, # diagnosis codes)

    Outputs:
        c: the context vector of shape (batch_size, hidden_dim)
        
    NOTE: Do NOT use for loop.
    """
    
    b, v, e = rev_v.detach().numpy().shape
    rev_masks_true = rev_masks.sum(dim=2)>0
    rev_masks_true = rev_masks_true.reshape(b,v,1)
    rev_v_padded = rev_v * rev_masks_true
    c = torch.sum(alpha * beta * rev_v_padded, dim = 1)
    return c

4. Build Model

In [69]:
def sum_embeddings_with_mask(x, masks):
    """
    Mask select the embeddings for true visits (not padding visits) and then sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
    """
    
    x = x * masks.unsqueeze(-1)
    x = torch.sum(x, dim = -2)
    return x

In [70]:
class RETAIN(nn.Module):
    
    def __init__(self, num_codes, embedding_dim=128):
        super().__init__()
        # Define the embedding layer using `nn.Embedding`. Set `embDimSize` to 128.
        self.embedding = nn.Embedding(num_codes, embedding_dim)
        # Define the RNN-alpha using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        # Define the RNN-beta using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        # Define the alpha-attention using `AlphaAttention()`;
        self.att_a = AlphaAttention(embedding_dim)
        # Define the beta-attention using `BetaAttention()`;
        self.att_b = BetaAttention(embedding_dim)
        # Define the linear layers using `nn.Linear()`;
        self.fc = nn.Linear(embedding_dim, 1)
        # Define the final activation layer using `nn.Sigmoid().
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x, masks, rev_x, rev_masks):
        '''
        Arguments:
            rev_x: the diagnosis sequence in reversed time of shape (# visits, batch_size, # diagnosis codes)
            rev_masks: the padding masks in reversed time of shape (# visits, batch_size, # diagnosis codes)

        Outputs:
            probs: probabilities of shape (batch_size)
        '''
        # 1. Pass the reversed sequence through the embedding layer;
        rev_x = self.embedding(rev_x)
        # 2. Sum the reversed embeddings for each diagnosis code up for a visit of a patient.
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        # 3. Pass the reversed embegginds through the RNN-alpha and RNN-beta layer separately;
        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)
        # 4. Obtain the alpha and beta attentions using `AlphaAttention()` and `BetaAttention()`;
        alpha = self.att_a(g)
        beta = self.att_b(h)
        # 5. Sum the attention up using `attention_sum()`;
        c = attention_sum(alpha, beta, rev_x, rev_masks)
        # 6. Pass the context vector through the linear and activation layers.
        logits = self.fc(c)
        probs = self.sigmoid(logits)
        return probs.squeeze()
    

# load the model here
retain = RETAIN(num_codes = len(icd_seqs))
retain

RETAIN(
  (embedding): Embedding(7537, 128)
  (rnn_a): GRU(128, 128, batch_first=True)
  (rnn_b): GRU(128, 128, batch_first=True)
  (att_a): AlphaAttention(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (att_b): BetaAttention(
    (b_att): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [71]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    
    def __init__(self, input_seqs, labels):
        self.x = input_seqs
        self.y = labels
    
    def __len__(self):
        
        """
        Return the number of samples (i.e. patients).
        """
        return len(self.y)
#         raise NotImplementedError
    
    def __getitem__(self, index):
        
        """
        Generates one sample of data.
        """
        
        # your code here
        return self.x[index], self.y[index]
#         raise NotImplementedError
        

dataset = CustomDataset(icd_seqs, labels)
print(len(dataset))

7537


2. Collate Function

In [72]:
def collate_fn(data):
    """
    Collate the the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation
        is stored in `mask`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patiens) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of mortality labels
        using: `sequences, labels = zip(*data)`
    """

    sequences, labels = zip(*data)

    y = torch.tensor(labels, dtype=torch.float)
    
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            """
            update `x`, `rev_x`, `masks`, and `rev_masks`
            """
            # your code here
            num_code_pad = max_num_codes - len(visit)
            x[i_patient][j_visit] = torch.tensor(visit + [0]*num_code_pad)
            masks[i_patient][j_visit] = torch.tensor([1]*len(visit) + [0]*num_code_pad)
        
        num_true_visit = len(patient)
        rev_x[i_patient][:num_true_visit] = torch.tensor(x[i_patient][:num_true_visit].tolist()[::-1])
        rev_masks[i_patient][:num_true_visit] = torch.tensor(masks[i_patient][:num_true_visit].tolist()[::-1])

    return x, masks, rev_x, rev_masks, y

In [73]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)

In [74]:
print(x.shape, rev_x.shape)
print(masks.shape, rev_masks.shape)
print(y.shape)

torch.Size([10, 5, 21]) torch.Size([10, 5, 21])
torch.Size([10, 5, 21]) torch.Size([10, 5, 21])
torch.Size([10])


3. Split dataset into train and test

In [75]:
from torch.utils.data.dataset import random_split

train_size= int(len(dataset)*0.8)

lengths = [train_size, len(dataset)-train_size]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 6029
Length of val dataset: 1508


4. Data loader

In [76]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    '''
    Implement this function to return the data loader for  train, validation and test dataset. 
    Set batchsize to 32. Set `shuffle=True` only for train dataloader.
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        test dataset: test dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader, test_loader: train, validation and test dataloaders
    
    Note that you need to pass the collate function to the data loader `collate_fn()`.
    '''
    
    batch_size = 32
   
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn = collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn = collate_fn)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [77]:
print(len(train_loader), len(val_loader))

189 48
