In [1]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Location of the preprocessed pickle files, relative to the notebook.
DATA_PATH = "../Project/code/processed_data/"

# Seed the Python, NumPy and PyTorch random number generators with one value.
seed = 24
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
# Also export the seed as PYTHONHASHSEED in the process environment.
os.environ["PYTHONHASHSEED"] = f"{seed}"

In [2]:
def _load_train_pickle(filename):
    """Load one pickled object from the `train/` folder under `DATA_PATH`.

    Arguments:
        filename: name of the pickle file inside `train/` (e.g. 'pids.pkl')

    Outputs:
        the unpickled Python object

    The file is opened in a `with` block so the handle is closed as soon as
    the object has been read (the previous bare `open(...)` calls leaked it).
    """
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(os.path.join(DATA_PATH, 'train/' + filename), 'rb') as handle:
        return pickle.load(handle)


pids = _load_train_pickle('pids.pkl')
vids = _load_train_pickle('vids.pkl')
targets = _load_train_pickle('targets.pkl')
seqs = _load_train_pickle('seqs.pkl')
diags = _load_train_pickle('diags.pkl')
categories = _load_train_pickle('categories.pkl')
sub_categories = _load_train_pickle('subcategories.pkl')
# The per-sample collections must line up one-to-one.
assert len(pids) == len(vids) == len(targets) == len(seqs)

In [3]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Map-style dataset that pairs each stored sequence with its target."""

    def __init__(self, seqs, targets):
        """
        Keep references to the raw inputs; nothing is converted, copied or permuted.

        Arguments:
            seqs: indexable collection of samples, stored as `self.x`
            targets: indexable collection of labels, stored as `self.y`
        """
        self.x, self.y = seqs, targets

    def __len__(self):
        """
        Return the number of samples, i.e. the length of `self.x`.
        """
        n_samples = len(self.x)
        return n_samples

    def __getitem__(self, index):
        """
        Return the `(sequence, target)` tuple stored at position `index`.

        The values are returned as stored; tensor conversion happens later.
        """
        sample = self.x[index]
        label = self.y[index]
        return sample, label

In [4]:
# Wrap the loaded visit sequences and their targets in a CustomDataset.
dataset = CustomDataset(seqs, targets)

In [39]:
def collate_fn(data):
    """
    Collate a list of samples into padded batch tensors.

    Every visit is copied whole into one row of `x`, so all visits in the batch
    must have the same length (the code treats each visit as a fixed-width
    vector). Patients with fewer visits than the longest one are padded with
    all-zero visits, and the padding is recorded in the masks.

    Arguments:
        data: a list of `(sequence, target)` samples fetched from `CustomDataset`

    Outputs:
        x: a tensor of shape (# patients, max # visits, visit width) of type torch.long
        masks: a tensor of the same shape of type torch.bool; True on real visits
        rev_x: same as x but with the real visits of each patient in reversed time
        rev_masks: same as masks but in reversed time
        y: the targets of the batch as a tensor of type torch.float
    """
    sequences, targets = zip(*data)

    y = torch.tensor(targets, dtype=torch.float)

    num_patients = len(sequences)
    max_num_visits = max(len(patient) for patient in sequences)
    # Read the visit width from the first visit found anywhere in the batch, so
    # a leading patient without visits no longer raises an IndexError.
    num_codes = next((len(patient[0]) for patient in sequences if len(patient) > 0), 0)

    batch_shape = (num_patients, max_num_visits, num_codes)
    x = torch.zeros(batch_shape, dtype=torch.long)
    rev_x = torch.zeros(batch_shape, dtype=torch.long)
    x_masks = torch.zeros(batch_shape, dtype=torch.bool)
    rev_x_masks = torch.zeros(batch_shape, dtype=torch.bool)

    for i_patient, patient in enumerate(sequences):
        n_visits = len(patient)
        for j_visit, visit in enumerate(patient):
            # Convert straight to long: going through torch.Tensor (float32)
            # would silently round integers above 2**24.
            x[i_patient, j_visit] = torch.as_tensor(visit, dtype=torch.long)
        # Real visits occupy the first n_visits rows; the rest stays padding.
        x_masks[i_patient, :n_visits] = True
        # Reverse only the real visits so the padding stays at the end.
        rev_x[i_patient, :n_visits] = x[i_patient, :n_visits].flip(0)
        rev_x_masks[i_patient, :n_visits] = True

    return x, x_masks, rev_x, rev_x_masks, y

In [40]:
def collate_fn_old(data):
    """
    Collate a list of samples into padded batch tensors (variable-length visits).

    Each visit is a list of codes of arbitrary length. Visits are padded with
    zeros up to the longest visit in the batch and patients are padded with
    empty visits up to the largest number of visits in the batch. The padding
    information is stored in the masks.

    Arguments:
        data: a list of `(sequence, target)` samples fetched from `CustomDataset`

    Outputs:
        x: a tensor of shape (# patients, max # visits, max # codes) of type torch.long
        masks: a tensor of the same shape of type torch.bool; True on real codes
        rev_x: same as x but with each patient's non-empty visits in reversed time
        rev_masks: same as masks but in reversed time
        y: the targets of the batch as a tensor of type torch.float
    """
    sequences, targets = zip(*data)

    y = torch.tensor(targets, dtype=torch.float)

    num_patients = len(sequences)
    max_num_visits = max(len(patient) for patient in sequences)
    # default=0 keeps a batch without any visit from crashing max().
    max_num_codes = max((len(visit) for patient in sequences for visit in patient), default=0)

    batch_shape = (num_patients, max_num_visits, max_num_codes)
    x = torch.zeros(batch_shape, dtype=torch.long)
    rev_x = torch.zeros(batch_shape, dtype=torch.long)
    x_masks = torch.zeros(batch_shape, dtype=torch.bool)
    rev_x_masks = torch.zeros(batch_shape, dtype=torch.bool)

    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            n_codes = len(visit)
            if n_codes:
                x[i_patient, j_visit, :n_codes] = torch.as_tensor(visit, dtype=torch.long)
                x_masks[i_patient, j_visit, :n_codes] = True
        # Reverse once per patient, after all visits are written. Doing it on
        # the last code of the last visit skipped the reversal whenever that
        # last visit was empty.
        rev_visit = x_masks[i_patient].any(dim=1)
        rev_x[i_patient, rev_visit] = x[i_patient, rev_visit].flip(0)
        rev_x_masks[i_patient, rev_visit] = x_masks[i_patient, rev_visit].flip(0)

    return x, x_masks, rev_x, rev_x_masks, y

In [41]:
# Split sizes as fractions of the dataset length (75% / 15% / 10%), truncated to int.
# NOTE(review): the next cell recomputes train_split and test_split, and val_split
# is not read by any visible cell -- this cell looks redundant.
train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)
val_split = int(len(dataset)*0.10)

In [42]:
from torch.utils.data.dataset import random_split

# 75% of the samples go to training and 15% to testing (both truncated to int);
# whatever is left over becomes the validation set.
train_split, test_split = (int(len(dataset) * fraction) for fraction in (0.75, 0.15))

lengths = [train_split, test_split, len(dataset) - train_split - test_split]
train_dataset, test_dataset, val_dataset = random_split(dataset, lengths)

# Report the size of every split.
for _label, _subset in (
    ("Length of train dataset:", train_dataset),
    ("Length of test dataset:", test_dataset),
    ("Length of val dataset:", val_dataset),
):
    print(_label, len(_subset))

Length of train dataset: 6561
Length of test dataset: 1312
Length of val dataset: 875


In [43]:
from torch.utils.data import DataLoader

def load_data(train_dataset, test_dataset, val_dataset, collate_fn, batch_size=100):
    '''
    Build the data loaders for the train, test and validation datasets.

    Only the train loader shuffles its samples; the test and validation
    loaders iterate in dataset order.

    Arguments:
        train_dataset: dataset used for training
        test_dataset: dataset used for testing
        val_dataset: dataset used for validation
        collate_fn: function that turns a list of samples into one batch
        batch_size: number of samples per batch (default 100)

    Outputs:
        train_loader, test_loader, val_loader: one dataloader per dataset
    '''

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               collate_fn=collate_fn,
                                               shuffle=True)
    # The test loader must wrap test_dataset; it previously wrapped
    # train_dataset, so "test" metrics were computed on training samples.
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              collate_fn=collate_fn,
                                              shuffle=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             collate_fn=collate_fn,
                                             shuffle=False)

    return train_loader, test_loader, val_loader


# Build the three loaders, batching with `collate_fn`.
train_loader, test_loader, val_loader = load_data(train_dataset, test_dataset, val_dataset, collate_fn)

In [44]:
def sum_embeddings_with_mask(x, masks):
    """
    Zero out the embeddings at masked-out positions and sum the remaining
    embeddings of each visit.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the boolean padding masks of shape (batch_size, # visits, # diagnosis codes);
            True marks positions to keep

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)

    The input `x` is left untouched: the masking is done on a copy rather than
    by assigning into the caller's tensor.
    """
    # Broadcast the mask over the embedding dimension; masked_fill returns a new tensor.
    masked = x.masked_fill(~masks.unsqueeze(-1), 0)
    return masked.sum(2)

In [45]:
def get_last_visit(hidden_states, masks):
    """
    Pick, for every sequence in the batch, the hidden state at its last true visit.

    A visit counts as true when at least one entry of its mask row is set; the
    index of the last true visit is taken to be (number of true visits - 1).

    Arguments:
        hidden_states: the hidden states of each visit of shape (batch_size, # visits, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        last_hidden_state: the selected hidden states of shape (batch_size, embedding_dim)
    """
    is_true_visit = masks.any(dim=2)
    last_visit_idx = is_true_visit.sum(1) - 1
    batch_rows = torch.arange(0, hidden_states.shape[0], dtype=torch.int64)
    # Advanced indexing pairs row i with its own last-visit index.
    return hidden_states[batch_rows, last_visit_idx]

In [59]:
class NaiveRNN(nn.Module):

    """
    Two-direction GRU classifier over visit sequences.

    The codes of each visit are embedded and summed, the visit embeddings are
    fed to one GRU in forward time and to a second GRU in reversed time, and
    the two last-visit hidden states are concatenated and mapped to
    `num_categories` scores that are normalized with a softmax.
    """

    def __init__(self, num_codes, num_categories):
        super().__init__()
        """
        Layers:
            1. `nn.Embedding` with `num_codes` rows and embedding size 128.
            2. `nn.GRU` with `hidden_size` 128 and `batch_first=True` for the forward direction.
            3. A second `nn.GRU` with the same settings for the reverse direction.
            4. `nn.Linear` from the 256 concatenated features to `num_categories` outputs.
            5. `nn.Softmax` over the category dimension.

        Arguments:
            num_codes: number of rows of the embedding table
            num_categories: number of output categories
        """

        self.embedding = nn.Embedding(num_codes, embedding_dim=128)
        self.rnn = nn.GRU(128, hidden_size=128, batch_first=True)
        self.rev_rnn = nn.GRU(128, hidden_size=128, batch_first=True)
        self.fc = nn.Linear(256, num_categories)
        self.softmax = nn.Softmax(dim=1)


    def forward(self, x, masks, rev_x, rev_masks):
        """
        Arguments:
            x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)
            rev_x: same as x but in reversed time
            rev_masks: same as masks but in reversed time

        Outputs:
            probs: softmax probabilities of shape (batch_size, num_categories)

        NOTE(review): this notebook trains the model with `nn.CrossEntropyLoss`,
        which expects unnormalized logits; passing these probabilities applies
        softmax twice -- confirm whether logits should be returned instead.
        """
        # (The leftover `pdb.set_trace()` that halted every forward pass was removed.)

        # 1. Pass the sequence through the embedding layer;
        x = self.embedding(x)
        # 2. Sum the embeddings for each diagnosis code up for a visit of a patient.
        x = sum_embeddings_with_mask(x, masks)

        # 3. Pass the embeddings through the RNN layer;
        output, _ = self.rnn(x)
        # 4. Obtain the hidden state at the last visit.
        true_h_n = get_last_visit(output, masks)

        # 5. Repeat steps 1-4 on the time-reversed input with the second GRU.
        rev_x = self.embedding(rev_x)
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        rev_output, _ = self.rev_rnn(rev_x)
        true_h_n_rev = get_last_visit(rev_output, rev_masks)

        # 6. Concatenate both directions and pass them through the linear and softmax layers.
        logits = self.fc(torch.cat([true_h_n, true_h_n_rev], 1))
        probs = self.softmax(logits)
        return probs
    

# load the model here
# NOTE(review): `num_codes` is not assigned at module level in any visible cell
# (it only appears as a local inside `collate_fn`), so on a fresh kernel this line
# would raise NameError unless a cell not shown defines it -- confirm and define it
# explicitly before instantiating the model.
naive_rnn = NaiveRNN(num_codes = num_codes, num_categories=len(categories))
naive_rnn

NaiveRNN(
  (embedding): Embedding(4903, 128)
  (rnn): GRU(128, 128, batch_first=True)
  (rev_rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=256, out_features=19, bias=True)
  (softmax): Softmax(dim=1)
)

In [61]:
# Loss and optimizer; `train()` reads both as module-level globals.
# NOTE(review): nn.CrossEntropyLoss expects unnormalized logits, while
# NaiveRNN.forward returns softmax probabilities -- confirm this is intended.
criterion = nn.CrossEntropyLoss()
# Alternative optimizer kept for reference:
# optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=0.001)
optimizer = torch.optim.Adadelta(naive_rnn.parameters(), weight_decay=0.001)

In [68]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval_model(model, test_loader, k=15):

    """
    Evaluate the model with precision@k and accuracy@k over the whole loader.

    For every sample the `k` highest-scoring categories are compared with the
    non-zero entries of its target row:
        * precision@k of a sample = (# true categories among the top k) / min(k, # true categories);
          the returned value is the mean over all samples that have at least one true category.
        * accuracy@k = (# true categories found in the top k, summed over all samples)
          / (total # true categories).

    Arguments:
        model: the RNN model; called as `model(x, masks, rev_x, rev_masks)`
        test_loader: dataloader yielding `(x, masks, rev_x, rev_masks, y)` batches,
            where `y` has one row per sample and one column per category
        k: number of top-scoring categories considered per sample
           (capped at the number of categories)

    Outputs:
        precision_k: mean per-sample precision@k (0.0 if no sample has a true category)
        accuracy_k: overall accuracy@k (0.0 if there is no true category at all)
    """
    precision_sum = 0.0   # sum of per-sample precision@k
    n_scored = 0          # samples with at least one true category
    n_hits = 0            # true categories recovered in the top k
    n_labels = 0          # true categories overall

    model.eval()
    with torch.no_grad():
        for x, masks, rev_x, rev_masks, y in test_loader:
            y_hat = model(x, masks, rev_x, rev_masks)

            is_true = y != 0
            n_true = is_true.sum(dim=1)

            # topk cannot return more entries than there are categories.
            top_k = min(k, y_hat.shape[1])
            _, top_idx = torch.topk(y_hat, top_k, dim=1)
            hits = is_true.long().gather(1, top_idx).sum(dim=1)

            # Samples without any true category have no defined precision; skip them.
            has_labels = n_true > 0
            denominators = n_true.clamp(max=k)[has_labels].float()
            precision_sum += (hits[has_labels].float() / denominators).sum().item()

            n_scored += int(has_labels.sum())
            n_hits += int(hits.sum())
            n_labels += int(n_true.sum())

    # Aggregate over every batch (previously only the last batch was reported).
    precision_k = precision_sum / n_scored if n_scored else 0.0
    accuracy_k = n_hits / n_labels if n_labels else 0.0
    return precision_k, accuracy_k

In [69]:
def train(model, train_loader, test_loader, n_epochs):
    """
    Run the training loop and print metrics after every epoch.

    Arguments:
        model: the RNN model; called as `model(x, masks, rev_x, rev_masks)`
        train_loader: training dataloader
        test_loader: dataloader passed to `eval_model()` at the end of each epoch
        n_epochs: total number of epochs

    The loss function and the optimizer are taken from the module-level names
    `criterion` and `optimizer`. After each epoch the mean training loss is
    printed, followed by the `eval_model()` metrics for k = 5, 10, ..., 30.
    """
    eval_ks = range(5, 31, 5)

    for epoch_idx in range(n_epochs):
        epoch_no = epoch_idx + 1
        model.train()

        batch_losses = []
        for x, x_masks, rev_x, rev_x_masks, y in train_loader:
            # forward pass -> loss -> reset grads -> backward pass -> parameter update
            y_hat = model(x, x_masks, rev_x, rev_x_masks)
            loss = criterion(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        # Mean of the per-batch losses over the number of batches.
        train_loss = sum(batch_losses) / len(train_loader)
        print(f'Epoch: {epoch_no} \t Training Loss: {train_loss:.6f}')

        for k in eval_ks:
            precision_k, accuracy_k = eval_model(model, test_loader, k=k)
            print(f'Epoch: {epoch_no} \t Validation precision@k{k}: {precision_k:.2f}, accuracy@k{k}: {accuracy_k:.2f}')

In [70]:
# Train for 100 epochs; the per-epoch metrics printed by train() are computed on
# the loader passed as its third argument (test_loader here).
n_epochs = 100
train(naive_rnn, train_loader, test_loader, n_epochs)

TypeError: forward() takes 2 positional arguments but 5 were given

In [None]:
# Final evaluation of the trained model on the validation loader for k = 5, 10, ..., 30.
# Uses `naive_rnn` directly: the names `model` and `epoch` only exist as locals
# inside train(), so referencing them here would raise NameError.
for k in range(5, 31, 5):
    precision_k, accuracy_k = eval_model(naive_rnn, val_loader, k=k)
    print(f'Validation precision@k{k}: {precision_k:.2f}, accuracy@k{k}: {accuracy_k:.2f}')