In [1]:
import os
import sys
import pickle
import psutil
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Define data path
DATA_PATH = "data/"
GRAM_DATA_PATH = "../Project/code/processed_data/gram"

In [2]:
pids = pickle.load(open(os.path.join(DATA_PATH,'pids.pkl'), 'rb'))
vids = pickle.load(open(os.path.join(DATA_PATH,'vids.pkl'), 'rb'))
targs = pickle.load(open(os.path.join(DATA_PATH,'targets.pkl'), 'rb'))
seqs = pickle.load(open(os.path.join(DATA_PATH,'seqs.pkl'), 'rb'))
diags = pickle.load(open(os.path.join(DATA_PATH,'diags.pkl'), 'rb'))
codes = pickle.load(open(os.path.join(DATA_PATH,'icd9.pkl'), 'rb'))
categories = pickle.load(open(os.path.join(DATA_PATH,'categories.pkl'), 'rb'))
sub_categories = pickle.load(open(os.path.join(DATA_PATH,'subcategories.pkl'), 'rb'))
assert len(pids) == len(vids) == len(targs) == len(seqs)

In [3]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, seqs, targets):
        
        """
        TODO: Store `seqs`. to `self.x` and `hfs` to `self.y`.
        
        Note that you DO NOT need to covert them to tensor as we will do this later.
        Do NOT permute the data.
        """
#         x = []
#         for i,patient in enumerate(seqs):
#             for j,visit in enumerate(patient):
#                 if j == len(patient) - 1:
#                     break
#                 x.append(visit)
#         y = []
#         for i,patient in enumerate(targets):
#             for j,visit in enumerate(patient):
#                 if j == len(patient) - 1:
#                     break
#                 y.append(patient[j+1])

        self.x = seqs
        self.y = targets
    
    def __len__(self):
        
        """
        TODO: Return the number of samples (i.e. patients).
        """
        
        return(len(self.x))
    
    def __getitem__(self, index):
        
        """
        TODO: Generates one sample of data.
        
        Note that you DO NOT need to covert them to tensor as we will do this later.
        """
        return (self.x[index], self.y[index])

In [4]:
dataset = CustomDataset(seqs, targs)

In [5]:
def collate_fn(data):
    """
    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation
        is stored in `mask`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patiens) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of hf labels
        using: `sequences, labels = zip(*data)`
    """
    sequences, targets = zip(*data)

#     y = torch.tensor(targets, dtype=torch.float)
#     import pdb; pdb.set_trace()
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]
    batch_num_categories = [len(visit) for patient in targets for visit in patient]
    global sub_categories
# #     import pdb; pdb.set_trace()
    num_categories = len(sub_categories)

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    max_num_categories = max(batch_num_categories)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    y = torch.zeros((num_patients, max_num_categories), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    y_masks = torch.zeros((num_patients, max_num_categories), dtype=torch.bool)
#     import pdb; pdb.set_trace()
    for i_patient, patient in enumerate(sequences):   
        for j_visit, visit in enumerate(patient[:-1]):
            for k_code, code in enumerate(visit):
                x[i_patient, j_visit, k_code] = code
                x_masks[i_patient, j_visit, k_code] = 1
                if j_visit == len(patient) - 2 and k_code == len(visit) - 1:
                    rev_visit = x_masks[i_patient].any(dim=1)
                    rev_x[i_patient, rev_visit] = x[i_patient, rev_visit].flip(0)
                    rev_x_masks[i_patient, rev_visit] = x_masks[i_patient, rev_visit].flip(0)
  
#     for i_patient, patient in enumerate(targets):   
#         for j_visit, visit in enumerate(patient[1:]):
#             for k_code, code in enumerate(visit):
#                 y[i_patient, j_visit, k_code] = code
#                 y_masks[i_patient, j_visit, k_code] = 1
                
    for i_patient, patient in enumerate(targets):   
        for visit in patient[-1:]:
            for k_code, code in enumerate(visit):
                y[i_patient, k_code] = code
                y_masks[i_patient, k_code] = 1

#     for i_patient, patient in enumerate(sequences):   
#         for j_visit, visit in enumerate(patient):
#             for k_code, code in enumerate(visit):
#                 x[i_patient, j_visit, k_code] = code
#                 x_masks[i_patient, j_visit, k_code] = 1
#                 if j_visit == len(patient) - 1 and k_code == len(visit) - 1:
#                     rev_visit = x_masks[i_patient].any(dim=1)
#                     rev_x[i_patient, rev_visit] = x[i_patient, rev_visit].flip(0)
#                     rev_x_masks[i_patient, rev_visit] = x_masks[i_patient, rev_visit].flip(0)
  
#     for i_patient, patient in enumerate(targets):   
#         for j_visit, visit in enumerate(patient):
#             for k_code, code in enumerate(visit):
#                 y[i_patient, j_visit, k_code] = code
#                 y_masks[i_patient, j_visit, k_code] = 1
    
    return x, x_masks, rev_x, rev_x_masks, y, y_masks

In [6]:
train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)
val_split = int(len(dataset)*0.10)

In [7]:
from torch.utils.data.dataset import random_split

train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)

lengths = [train_split, test_split, len(dataset) - (train_split + test_split)]
train_dataset, test_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of test dataset:", len(test_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 6561
Length of test dataset: 1312
Length of val dataset: 875


In [8]:
from torch.utils.data import DataLoader

def load_data(train_dataset, test_dataset, val_dataset, collate_fn):
    
    '''
    TODO: Implement this function to return the data loader for  train and validation dataset. 
    Set batchsize to 32. Set `shuffle=True` only for train dataloader.
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    
    Note that you need to pass the collate function to the data loader `collate_fn()`.
    '''
    
    batch_size = 100
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               collate_fn=collate_fn,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           collate_fn=collate_fn,
                                           shuffle=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             collate_fn=collate_fn,
                                             shuffle=False)
    
    return train_loader, test_loader, val_loader


train_loader, test_loader, val_loader = load_data(train_dataset, test_dataset, val_dataset, collate_fn)

In [9]:
def sum_embeddings_with_mask(x, masks):
    """
    TODO: mask select the embeddings for true visits (not padding visits) and then
        sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
        
    NOTE: Do NOT use for loop.

    """
    x[~masks] = 0
    return x.sum(2)

In [23]:
def attention_sum(alpha, beta, rev_v, rev_masks):
    """
    TODO: mask select the hidden states for true visits (not padding visits) and then
        sum the them up.

    Arguments:
        alpha: the alpha attention weights of shape (batch_size, seq_length, 1)
        beta: the beta attention weights of shape (batch_size, seq_length, hidden_dim)
        rev_v: the visit embeddings in reversed time of shape (batch_size, # visits, embedding_dim)
        rev_masks: the padding masks in reversed time of shape (# visits, batch_size, # diagnosis codes)

    Outputs:
        c: the context vector of shape (batch_size, hidden_dim)
        
    NOTE: Do NOT use for loop.
    """
    rev_masks = rev_masks[:,:].any(dim=2)
    rev_v_masked = rev_v.clone()
    rev_v_masked[~rev_masks] = 0
    weights = alpha * beta
    a_sum = torch.sum(weights * rev_v_masked, dim=1)
    return a_sum

In [24]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        """
        Define the linear layer `self.a_att` for alpha-attention using `nn.Linear()`;
        
        Arguments:
            hidden_dim: the hidden dimension
        """
        
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        """
        Arguments:
            g: the output tensor from RNN-alpha of shape (batch_size, seq_length, hidden_dim) 
        
        Outputs:
            alpha: the corresponding attention weights of shape (batch_size, seq_length, 1)
        """
        m = nn.Softmax(dim=1)
        return m(self.a_att(g))

In [25]:
class BetaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        """
        Define the linear layer `self.b_att` for beta-attention using `nn.Linear()`;
        
        Arguments:
            hidden_dim: the hidden dimension
        """
        
        self.b_att = nn.Linear(hidden_dim, hidden_dim)


    def forward(self, h):
        """
        TODO: Implement the beta attention.
        
        Arguments:
            h: the output tensor from RNN-beta of shape (batch_size, seq_length, hidden_dim) 
        
        Outputs:
            beta: the corresponding attention weights of shape (batch_size, seq_length, hidden_dim)
            
        HINT: consider `torch.tanh`
        """
        return torch.tanh(self.b_att(h))

In [26]:
class BaselineRETAIN(nn.Module):
    
    """
    TODO: implement the naive RNN model above.
    """
    
    def __init__(self, num_codes, num_categories):
        super().__init__()
        """
        TODO: 
            1. Define the embedding layer using `nn.Embedding`. Set `embDimSize` to 128.
            2. Define the RNN using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
            2. Define the RNN for the reverse direction using `nn.GRU()`;
               Set `hidden_size` to 128. Set `batch_first` to True.
            3. Define the linear layers using `nn.Linear()`; Set `in_features` to 256, and `out_features` to 1.
            4. Define the final activation layer using `nn.Sigmoid().

        Arguments:
            num_codes: total number of diagnosis codes
        """
        
        self.embedding = nn.Embedding(num_codes, embedding_dim=128)
        self.rnn_a = nn.GRU(128, hidden_size=128, batch_first=True)
        self.rnn_b = nn.GRU(128, hidden_size=128, batch_first=True)
        self.att_a = AlphaAttention(128)
        self.att_b = AlphaAttention(128)
        self.fc = nn.Linear(128, num_categories)
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        
    
    def forward(self, x, masks, rev_x, rev_masks):
        """
        Arguments:
            x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

        Outputs:
            probs: probabilities of shape (batch_size)
        """
#         import pdb; pdb.set_trace()
       
        # 1. Pass the sequence through the embedding layer;
        rev_x = self.embedding(rev_x)
        # 2. Sum the embeddings for each diagnosis code up for a visit of a patient.
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        # 3. Pass the embegginds through the RNN layer;
        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)
        # 4. Obtain the hidden state at the last visit.
#         true_h_n = get_last_visit(output, masks)
        alpha = self.att_a(g)
        beta = self.att_b(h)
        c = attention_sum(alpha, beta, rev_x, rev_masks)
        logits = self.fc(c)
        
        # 6. Pass the hidden state through the linear and activation layers.
        #import pdb; pdb.set_trace()
        #logits = self.fc(torch.cat([true_h_n, true_h_n_rev], 1))        
        probs = self.softmax(logits)
        #probs = self.sigmoid(logits)
        
#         probs = probs.reshape(probs.shape[0]*probs.shape[1], probs.shape[2])
#         y_masks = y_masks.reshape(y_masks.shape[0]*y_masks.shape[1], y_masks.shape[2])
#         probs = probs[y_masks.any(dim=1)]
        return logits
    

# load the model here
baseline_retain = BaselineRETAIN(num_codes = len(codes), num_categories=len(sub_categories))
baseline_retain

BaselineRETAIN(
  (embedding): Embedding(4903, 128)
  (rnn_a): GRU(128, 128, batch_first=True)
  (rnn_b): GRU(128, 128, batch_first=True)
  (att_a): AlphaAttention(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (att_b): AlphaAttention(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (fc): Linear(in_features=128, out_features=184, bias=True)
  (softmax): Softmax(dim=1)
  (sigmoid): Sigmoid()
)

In [27]:
criterion = nn.CrossEntropyLoss()
#criterion = nn.BCELoss()
#criterion = nn.BCEWithLogitsLoss()
#optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=0.001)
optimizer = torch.optim.Adadelta(baseline_retain.parameters(), weight_decay=0.001)

In [29]:
def train(model, train_loader, test_loader, n_epochs):
    """
    TODO: train the model.
    
    Arguments:
        model: the RNN model
        train_loader: training dataloder
        val_loader: validation dataloader
        n_epochs: total number of epochs
        
    You need to call `eval_model()` at the end of each training epoch to see how well the model performs 
    on validation data.
        
    Note that please pass all four arguments to the model so that we can use this function for both 
    models. (Use `model(x, masks, rev_x, rev_masks)`.)
    """
    #base_cpu, base_ram = print_cpu_usage()
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x, x_masks, rev_x, rev_x_masks, y, y_masks in train_loader:
            """
            TODO:
                1. zero grad
                2. model forward
                3. calculate loss
                4. loss backward
                5. optimizer step
            """
#             import pdb; pdb.set_trace()
            y_hat = model(x, x_masks, rev_x, rev_x_masks)
#             import pdb; pdb.set_trace()
#             y[~y_masks] = criterion.ignore_index
#             last_y = y_masks.any(dim=2).sum(dim=1) - 1
#             indices = last_y.unsqueeze(-1)
#             indices = indices.repeat(1, y.shape[2])
#             indices = indices.unsqueeze(1)
#             y_filt = torch.gather(y, 1, indices)

#             n_visits = y_masks.any(dim=2).sum(dim=1)
#             for i_patient, j_visit in enumerate(n_visits):
#                 for visit in range(j_visit - 1):
#                     mask = y_masks[i_patient, visit+1]

#                     yh = y_hat[i_patient, visit]
#                     y_tmp = indices_to_multihot(
#                         y[i_patient, visit+1], mask, yh)
            
            # stack into visits
#             yh = y_hat.reshape(y_hat.shape[0] * y_hat.shape[1], y_hat.shape[2])
#             y_masks = y_masks.reshape(y_masks.shape[0] * y_masks.shape[1], y_masks.shape[2])
#             y = y.reshape(y.shape[0] * y.shape[1], y.shape[2])
        
                    
            
            y_mh = indices_to_multihot(y, y_masks, y_hat)
            loss = criterion(y_hat, y_mh)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print_cpu_usage()
        print(f'Epoch: {epoch+1} \t Training Loss: {train_loss:.6f}')
        for k in range(5, 31, 5):
            precision_k, accuracy_k = eval_model(model, test_loader, k=k)
            print(f'Epoch: {epoch+1} \t Validation precision@k{k}: {precision_k:.2f}, accuracy@k{k}: {accuracy_k:.2f}')

In [34]:
def eval_model(model, test_loader, k=15, n=-1):
    
    """
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
        
    """
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    all_precision = []
    all_accuracy = []
    
    model.eval()
    with torch.no_grad():
        for x, x_masks, rev_x, rev_x_masks, y, y_masks in test_loader:
#             import pdb; pdb.set_trace()
            n_eval = y.shape[0] - 1 if n == -1 else n
            y_hat = model(x, x_masks, rev_x, rev_x_masks)
            y_hat = F.softmax(y_hat, dim=-1)
#             num_labels = y_hat.shape[1]
#             num_categories = torch.count_nonzero(y, dim=1)
#             nz_rows, nz_cols = torch.nonzero(y, as_tuple=True)
            y_multihot = indices_to_multihot(y, y_masks, y_hat)
            k_correct = 0
#             predictions = 0
            total_precision = 0
            total_accuracy = 0
            for i in range(n_eval):
                visit_correct = 0
#                 y_true = nz_cols[nz_rows == i]
                y_true = y[i, y_masks[i]]
                _, y_pred = torch.topk(y_hat[i], k)
#                 for v in y_pred:
#                     if v in y_true:
#                         visit_correct += 1
                for v in y_true:
                    if v in y_pred:
                        visit_correct += 1
#                 predictions += len(y_true)
                visit_precision = visit_correct / min(k, len(y_true))
                visit_accuracy = visit_correct / len(y_true)
                #print(f'visit {i}: precision: {visit_precision:0.2f} accuracy: {visit_accuracy:0.2f}')
                k_correct += visit_correct
                total_precision += visit_precision
                total_accuracy += visit_accuracy
            #import pdb; pdb.set_trace()
#             precision_k = precision / k
#             accuracy_k = k_correct / predictions
            precision_k = total_precision / n_eval
            accuracy_k = total_accuracy / n_eval
            all_precision.append(precision_k)
            all_accuracy.append(accuracy_k)
            
#             y_score = torch.cat((y_score,  y_hat.detach().to('cpu')), dim=0)
#             y_hat = (y_hat > 0.5).int()
#             y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
#             y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
#     import pdb; pdb.set_trace()
    total_precision_k = np.mean(all_precision)
    total_accuracy_k = np.mean(all_accuracy)
    return total_precision_k, total_accuracy_k

In [35]:
def indices_to_multihot(indices, masks, y_hat):
#     import pdb; pdb.set_trace()
    #indices = indices[masks.any(dim=1)]
    multihot = torch.zeros_like(y_hat, dtype=torch.float)
    for idx, row in enumerate(indices):
        y_idx = row[masks[idx]].unique()
        multihot[idx] = F.one_hot(y_idx, y_hat.shape[1]).sum(0).float()
    return multihot

In [36]:
def print_cpu_usage():
    load = psutil.getloadavg()[2]
    cpu_usage = (load/os.cpu_count()) * 100
    ram = psutil.virtual_memory()[2]
    print(f"CPU: {cpu_usage:0.2f}")
    print(f"RAM %: {ram}")
    return cpu_usage, ram

In [None]:
n_epochs = 100
%time train(baseline_retain, train_loader, test_loader, n_epochs)

CPU: 23.71
RAM %: 61.9
Epoch: 1 	 Training Loss: 41.331719
Epoch: 1 	 Validation precision@k5: 0.66, accuracy@k5: 0.35
Epoch: 1 	 Validation precision@k10: 0.62, accuracy@k10: 0.52
Epoch: 1 	 Validation precision@k15: 0.65, accuracy@k15: 0.63
Epoch: 1 	 Validation precision@k20: 0.72, accuracy@k20: 0.72
Epoch: 1 	 Validation precision@k25: 0.77, accuracy@k25: 0.77
Epoch: 1 	 Validation precision@k30: 0.82, accuracy@k30: 0.82
CPU: 25.21
RAM %: 62.0
Epoch: 2 	 Training Loss: 40.443196
Epoch: 2 	 Validation precision@k5: 0.68, accuracy@k5: 0.36
Epoch: 2 	 Validation precision@k10: 0.63, accuracy@k10: 0.54
Epoch: 2 	 Validation precision@k15: 0.67, accuracy@k15: 0.65
Epoch: 2 	 Validation precision@k20: 0.73, accuracy@k20: 0.73
Epoch: 2 	 Validation precision@k25: 0.79, accuracy@k25: 0.79
Epoch: 2 	 Validation precision@k30: 0.83, accuracy@k30: 0.83
CPU: 25.27
RAM %: 61.9
Epoch: 3 	 Training Loss: 40.007895
Epoch: 3 	 Validation precision@k5: 0.69, accuracy@k5: 0.37
Epoch: 3 	 Validation p

In [None]:
for k in range(5, 31, 5):
    precision_k, accuracy_k = eval_model(baseline_dipole, test_loader, k=k)
    print(f'Validation precision@k{k}: {precision_k:.2f}, accuracy@k{k}: {accuracy_k:.2f}')