In [41]:
import os
import sys
import pickle
import psutil
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Define data path
DATA_PATH = "data/"
GRAM_DATA_PATH = "../Project/code/processed_data/gram"

In [44]:
pids = pickle.load(open(os.path.join(DATA_PATH,'pids.pkl'), 'rb'))
vids = pickle.load(open(os.path.join(DATA_PATH,'vids.pkl'), 'rb'))
targets = pickle.load(open(os.path.join(DATA_PATH,'targets.pkl'), 'rb'))
seqs = pickle.load(open(os.path.join(DATA_PATH,'seqs.pkl'), 'rb'))
diags = pickle.load(open(os.path.join(DATA_PATH,'diags.pkl'), 'rb'))
codes = pickle.load(open(os.path.join(DATA_PATH,'icd9.pkl'), 'rb'))
categories = pickle.load(open(os.path.join(DATA_PATH,'categories.pkl'), 'rb'))
sub_categories = pickle.load(open(os.path.join(DATA_PATH,'subcategories.pkl'), 'rb'))
assert len(pids) == len(vids) == len(targets) == len(seqs)

In [43]:
ccs_level1 = pickle.load(open(os.path.join(GRAM_DATA_PATH,'gram_ccs.level1.pk'), 'rb'))
ccs_level2 = pickle.load(open(os.path.join(GRAM_DATA_PATH,'gram_ccs.level2.pk'), 'rb'))
ccs_level3 = pickle.load(open(os.path.join(GRAM_DATA_PATH,'gram_ccs.level3.pk'), 'rb'))
ccs_level4 = pickle.load(open(os.path.join(GRAM_DATA_PATH,'gram_ccs.level4.pk'), 'rb'))
ccs_level5 = pickle.load(open(os.path.join(GRAM_DATA_PATH,'gram_ccs.level5.pk'), 'rb'))
ccs_seqs = ccs_level1 = pickle.load(open(os.path.join(GRAM_DATA_PATH,'gram_ccs.seqs'), 'rb'))
ccs_types = pickle.load(open(os.path.join(GRAM_DATA_PATH,'gram_ccs.types'), 'rb'))


FileNotFoundError: [Errno 2] No such file or directory: '../Project/code/processed_data/gram/gram_ccs.level1.pk'

In [4]:
gram_3digit_seqs = pickle.load(open(os.path.join(GRAM_DATA_PATH,'processed_gram.3digitICD9.seqs'), 'rb'))
gram_3digit_types = pickle.load(open(os.path.join(GRAM_DATA_PATH,'processed_gram.3digitICD9.types'), 'rb'))
gram_dates = pickle.load(open(os.path.join(GRAM_DATA_PATH,'processed_gram.dates'), 'rb'), encoding='latin1')
gram_pids = pickle.load(open(os.path.join(GRAM_DATA_PATH,'processed_gram.pids'), 'rb'))
gram_seqs = pickle.load(open(os.path.join(GRAM_DATA_PATH,'processed_gram.seqs'), 'rb'))
gram_types = pickle.load(open(os.path.join(GRAM_DATA_PATH,'processed_gram.types'), 'rb'))


FileNotFoundError: [Errno 2] No such file or directory: '../Project/code/processed_data/gram/processed_gram.3digitICD9.seqs'

In [5]:
r_3digit_types = {v:k for k,v in gram_3digit_types.items()}

NameError: name 'gram_3digit_types' is not defined

In [6]:
rtypes_ccs = {v:k for k,v in ccs_types.items()}
rtypes = {v:k for k,v in gram_types.items()}

In [45]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, seqs, targets):
        
        """
        TODO: Store `seqs`. to `self.x` and `hfs` to `self.y`.
        
        Note that you DO NOT need to covert them to tensor as we will do this later.
        Do NOT permute the data.
        """
#         x = []
#         for i,patient in enumerate(seqs):
#             for j,visit in enumerate(patient):
#                 if j == len(patient) - 1:
#                     break
#                 x.append(visit)
#         y = []
#         for i,patient in enumerate(targets):
#             for j,visit in enumerate(patient):
#                 if j == len(patient) - 1:
#                     break
#                 y.append(patient[j+1])

        self.x = seqs
        self.y = targets
    
    def __len__(self):
        
        """
        TODO: Return the number of samples (i.e. patients).
        """
        
        return(len(self.x))
    
    def __getitem__(self, index):
        
        """
        TODO: Generates one sample of data.
        
        Note that you DO NOT need to covert them to tensor as we will do this later.
        """
        return (self.x[index], self.y[index])

In [46]:
dataset = CustomDataset(seqs, targets)

In [8]:
def collate_fn_YMask(data):
    """
    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation
        is stored in `mask`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patiens) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of hf labels
        using: `sequences, labels = zip(*data)`
    """
    sequences, targets = zip(*data)

#     y = torch.tensor(targets, dtype=torch.float)
#     import pdb; pdb.set_trace()
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]
    batch_num_categories = [len(visit) for patient in targets for visit in patient]
    global sub_categories
# #     import pdb; pdb.set_trace()
    num_categories = len(sub_categories)

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    #max_num_categories = max(num_categories)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    y = torch.zeros((num_patients, max_num_visits, num_categories), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    y_masks = torch.zeros((num_patients, max_num_visits, num_categories), dtype=torch.bool)
#     import pdb; pdb.set_trace()
    for i_patient, patient in enumerate(sequences):   
        for j_visit, visit in enumerate(patient[:-1]):
#             x[i_patient, j_visit] = torch.Tensor(visit)
#             x_masks[i_patient, j_visit] = torch.Tensor(np.ones(num_codes, dtype=int))
#             if j_visit == len(patient) - 2:
#                 rev_visit = x_masks[i_patient].any(dim=1)
#                 rev_x[i_patient, rev_visit] = x[i_patient, rev_visit].flip(0)
#                 rev_x_masks[i_patient, rev_visit] = x_masks[i_patient, rev_visit].flip(0)
            for k_code, code in enumerate(visit):
                x[i_patient, j_visit, k_code] = code
                x_masks[i_patient, j_visit, k_code] = 1
                if j_visit == len(patient) - 2 and k_code == len(visit) - 1:
                    rev_visit = x_masks[i_patient].any(dim=1)
                    rev_x[i_patient, rev_visit] = x[i_patient, rev_visit].flip(0)
                    rev_x_masks[i_patient, rev_visit] = x_masks[i_patient, rev_visit].flip(0)
  
    for i_patient, patient in enumerate(targets):   
        for j_visit, visit in enumerate(patient[1:]):
            for k_code, code in enumerate(visit):
                y[i_patient, j_visit, k_code] = code
                y_masks[i_patient, j_visit, k_code] = 1
#             y[i_patient, j_visit] = torch.Tensor(visit)
#             y_masks[i_patient, j_visit] = torch.Tensor(np.ones(num_codes, dtype=int))
    
    return x, x_masks, rev_x, rev_x_masks, y, y_masks

In [47]:
def collate_fn(data):
    """
    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation
        is stored in `mask`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patiens) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of hf labels
        using: `sequences, labels = zip(*data)`
    """
    sequences, targets = zip(*data)

#     y = torch.tensor(targets, dtype=torch.float)
#     import pdb; pdb.set_trace()
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]
    batch_num_categories = [len(visit) for patient in targets for visit in patient]
    global sub_categories
# #     import pdb; pdb.set_trace()
    num_categories = len(sub_categories)

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    max_num_categories = max(batch_num_categories)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    y = torch.zeros((num_patients, max_num_categories), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    y_masks = torch.zeros((num_patients, max_num_categories), dtype=torch.bool)
#     import pdb; pdb.set_trace()
    for i_patient, patient in enumerate(sequences):   
        for j_visit, visit in enumerate(patient[:-1]):
            for k_code, code in enumerate(visit):
                x[i_patient, j_visit, k_code] = code
                x_masks[i_patient, j_visit, k_code] = 1
                if j_visit == len(patient) - 2 and k_code == len(visit) - 1:
                    rev_visit = x_masks[i_patient].any(dim=1)
                    rev_x[i_patient, rev_visit] = x[i_patient, rev_visit].flip(0)
                    rev_x_masks[i_patient, rev_visit] = x_masks[i_patient, rev_visit].flip(0)
  
#     for i_patient, patient in enumerate(targets):   
#         for j_visit, visit in enumerate(patient[1:]):
#             for k_code, code in enumerate(visit):
#                 y[i_patient, j_visit, k_code] = code
#                 y_masks[i_patient, j_visit, k_code] = 1
                
    for i_patient, patient in enumerate(targets):   
        for visit in patient[-1:]:
            for k_code, code in enumerate(visit):
                y[i_patient, k_code] = code
                y_masks[i_patient, k_code] = 1

#     for i_patient, patient in enumerate(sequences):   
#         for j_visit, visit in enumerate(patient):
#             for k_code, code in enumerate(visit):
#                 x[i_patient, j_visit, k_code] = code
#                 x_masks[i_patient, j_visit, k_code] = 1
#                 if j_visit == len(patient) - 1 and k_code == len(visit) - 1:
#                     rev_visit = x_masks[i_patient].any(dim=1)
#                     rev_x[i_patient, rev_visit] = x[i_patient, rev_visit].flip(0)
#                     rev_x_masks[i_patient, rev_visit] = x_masks[i_patient, rev_visit].flip(0)
  
#     for i_patient, patient in enumerate(targets):   
#         for j_visit, visit in enumerate(patient):
#             for k_code, code in enumerate(visit):
#                 y[i_patient, j_visit, k_code] = code
#                 y_masks[i_patient, j_visit, k_code] = 1
    
    return x, x_masks, rev_x, rev_x_masks, y, y_masks

In [10]:
        x = []
        for i,patient in enumerate(seqs):
            for j,visit in enumerate(patient):
                if j == len(patient) - 1:
                    break
                x.append(visit)
        y = []
        for i,patient in enumerate(targets):
            for j,visit in enumerate(patient):
                if j == len(patient) - 1:
                    break
                y.append(patient[j+1])

In [11]:
def collate_fn_old(data):
    """
    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation
        is stored in `mask`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patiens) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of hf labels
        using: `sequences, labels = zip(*data)`
    """
    sequences, targets = zip(*data)

    y = torch.tensor(targets, dtype=torch.float)
    
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]
#     global categories
# #     import pdb; pdb.set_trace()
#     num_categories = len(categories)

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
#     y = torch.zeros((num_patients, num_categories), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_x_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
#     y_masks = torch.zeros((num_patients, num_categories), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):   
        for j_visit, visit in enumerate(patient):
            """
            TODO: update `x`, `rev_x`, `masks`, and `rev_masks`
            """ 
            for k_code, code in enumerate(visit):
                x[i_patient, j_visit, k_code] = code
                x_masks[i_patient, j_visit, k_code] = 1
                if j_visit == len(patient) - 1 and k_code == len(visit) - 1:
                    rev_visit = x_masks[i_patient].any(dim=1)
                    rev_x[i_patient, rev_visit] = x[i_patient, rev_visit].flip(0)
                    rev_x_masks[i_patient, rev_visit] = x_masks[i_patient, rev_visit].flip(0)
  
#     for j_visit, visit in enumerate(targets):
#         for k_cat, cat in enumerate(visit):
#             y[j_visit, k_cat] = cat
#             y_masks[j_visit, k_cat] = 1
    
    return x, x_masks, rev_x, rev_x_masks, y

In [48]:
train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)
val_split = int(len(dataset)*0.10)

In [49]:
from torch.utils.data.dataset import random_split

train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)

lengths = [train_split, test_split, len(dataset) - (train_split + test_split)]
train_dataset, test_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of test dataset:", len(test_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 6561
Length of test dataset: 1312
Length of val dataset: 875


In [50]:
from torch.utils.data import DataLoader

def load_data(train_dataset, test_dataset, val_dataset, collate_fn):
    
    '''
    TODO: Implement this function to return the data loader for  train and validation dataset. 
    Set batchsize to 32. Set `shuffle=True` only for train dataloader.
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    
    Note that you need to pass the collate function to the data loader `collate_fn()`.
    '''
    
    batch_size = 100
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               collate_fn=collate_fn,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           collate_fn=collate_fn,
                                           shuffle=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             collate_fn=collate_fn,
                                             shuffle=False)
    
    return train_loader, test_loader, val_loader


train_loader, test_loader, val_loader = load_data(train_dataset, test_dataset, val_dataset, collate_fn)

In [51]:
def sum_embeddings_with_mask(x, masks):
    """
    TODO: mask select the embeddings for true visits (not padding visits) and then
        sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
        
    NOTE: Do NOT use for loop.

    """
    x[~masks] = 0
    return x.sum(2)

In [52]:
def get_last_visit(hidden_states, masks):
    """
    TODO: obtain the hidden state for the last true visit (not padding visits)

    Arguments:
        hidden_states: the hidden states of each visit of shape (batch_size, # visits, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        last_hidden_state: the hidden state for the last true visit of shape (batch_size, embedding_dim)
        
    NOTE: DO NOT use for loop.
    
    HINT: Consider using `torch.gather()`.
    """
    idx_vector = masks.any(dim=2).sum(1) - 1
    p_idx = torch.arange(0,len(hidden_states), dtype=torch.int64)
    last_hidden_state = hidden_states[p_idx,idx_vector]
    return last_hidden_state

In [53]:
class NaiveRNN(nn.Module):
    
    """
    TODO: implement the naive RNN model above.
    """
    
    def __init__(self, num_codes, num_categories):
        super().__init__()
        """
        TODO: 
            1. Define the embedding layer using `nn.Embedding`. Set `embDimSize` to 128.
            2. Define the RNN using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
            2. Define the RNN for the reverse direction using `nn.GRU()`;
               Set `hidden_size` to 128. Set `batch_first` to True.
            3. Define the linear layers using `nn.Linear()`; Set `in_features` to 256, and `out_features` to 1.
            4. Define the final activation layer using `nn.Sigmoid().

        Arguments:
            num_codes: total number of diagnosis codes
        """
        
        self.embedding = nn.Embedding(num_codes, embedding_dim=128)
        self.rnn = nn.GRU(128, hidden_size=128, batch_first=True)
        self.rev_rnn = nn.GRU(128, hidden_size=128, batch_first=True)
        self.fc = nn.Linear(256, num_categories)
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        
    
    def forward(self, x, masks, rev_x, rev_masks, y_masks):
        """
        Arguments:
            x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

        Outputs:
            probs: probabilities of shape (batch_size)
        """
#         import pdb; pdb.set_trace()
        batch_size = x.shape[0]
        
        # 1. Pass the sequence through the embedding layer;
        x = self.embedding(x)
        # 2. Sum the embeddings for each diagnosis code up for a visit of a patient.
        x = sum_embeddings_with_mask(x, masks)
        
        # 3. Pass the embegginds through the RNN layer;
        output, _ = self.rnn(x)
        # 4. Obtain the hidden state at the last visit.
        true_h_n = get_last_visit(output, masks)
        
        """
        TODO:
            5. Do the step 1-4 again for the reverse order, and concatenate the hidden
               states for both directions;
        """
        rev_x = self.embedding(rev_x)
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        rev_output, _ = self.rev_rnn(rev_x)
        true_h_n_rev = get_last_visit(rev_output, rev_masks)
        
        # 6. Pass the hidden state through the linear and activation layers.
        #import pdb; pdb.set_trace()
        logits = self.fc(torch.cat([true_h_n, true_h_n_rev], 1))        
        # probs = self.softmax(logits)
        probs = self.sigmoid(logits)
        
#         probs = probs.reshape(probs.shape[0]*probs.shape[1], probs.shape[2])
#         y_masks = y_masks.reshape(y_masks.shape[0]*y_masks.shape[1], y_masks.shape[2])
#         probs = probs[y_masks.any(dim=1)]
        return logits
    

# load the model here
naive_rnn = NaiveRNN(num_codes = len(codes), num_categories=len(sub_categories))
naive_rnn

NaiveRNN(
  (embedding): Embedding(4903, 128)
  (rnn): GRU(128, 128, batch_first=True)
  (rev_rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=256, out_features=184, bias=True)
  (softmax): Softmax(dim=1)
  (sigmoid): Sigmoid()
)

In [54]:
#criterion = nn.CrossEntropyLoss()
#criterion = nn.BCELoss()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=0.001)
#optimizer = torch.optim.Adadelta(naive_rnn.parameters(), weight_decay=0.001)

In [73]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval_model(model, test_loader, threshold=0.5, k=15, n=-1):
    
    """
    TODO: evaluate the model.
    
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
        
    Note that please pass all four arguments to the model so that we can use this function for both 
    models. (Use `model(x, masks, rev_x, rev_masks)`.)
        
    HINT: checkout https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
    """
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    all_precision = []
    all_accuracy = []
    model.eval()
    with torch.no_grad():
        for x, masks, rev_x, rev_masks, y, y_masks in test_loader:
            #import pdb; pdb.set_trace()
            nn = y.shape[0] - 1 if n == -1 else n
            y_hat = model(x, masks, rev_x, rev_masks, y_masks)
#             num_categories = torch.count_nonzero(y, dim=2)
#             nz_rows, nz_cols = torch.nonzero(y, as_tuple=True)
            k_correct = 0
            num_predictions = 0
            num_targets = 0
            all_predictions = []
            all_targets = []
            precision = 0
            total_precision = 0
            total_accuracy = 0
#             y_masks = y_masks.reshape(
#                 y_masks.shape[0] * y_masks.shape[1], y_masks.shape[2])
#             y = y.reshape(y.shape[0] * y.shape[1], y.shape[2])
#             y = y[y_masks.any(dim=1)]
#             y_masks = y_masks[y_masks.any(dim=1)]


#             v_idx = masks.any(dim=2)
#             v_idx = v_idx.sum(dim=1)
#             v_idx = v_idx.unsqueeze(-1)
#             v_idx = v_idx.repeat(1,y_hat.shape[2])
#             v_idx = v_idx.unsqueeze(1)
#             y_hat = torch.gather(y_hat,1,v_idx).squeeze()
            for i in range(k):
                
                visit_correct = 0
#                 y_true = nz_cols[nz_rows == i]

                y_true = y[i, y_masks[i]].unique()
                all_targets.extend(y_true.tolist())
                _, y_pred = torch.topk(y_hat[i], len(y_true))
                #y_pred = torch.nonzero(y_hat[0] > threshold).squeeze()
                if y_pred.numel() > 0:
                    try:
                        all_predictions.extend(y_pred.tolist())
                    except TypeError:
                        y_pred = [y_pred.tolist()]
                    all_predictions.extend(y_pred)
#                     for v in y_pred:
#                         if v in y_true:
#                             visit_correct += 1
                    for v in y_true:
                        if v in y_pred:
                            visit_correct += 1
                    num_predictions += len(y_pred)

                num_targets += len(y_true)
                precision += visit_correct / min(k, len(y_true))
                k_correct += visit_correct
                visit_precision = visit_correct / min(k, len(y_true))
                visit_accuracy = visit_correct / len(y_true)
                total_precision += visit_precision
                total_accuracy += visit_accuracy
 #           import pdb; pdb.set_trace()
            precision_k = precision / k
#             precision_k1 = k_correct / min(k, num_targets)
            if num_predictions == 0:
                accuracy_k = 0
            else:
                accuracy_k = k_correct / num_predictions
            precision_k = total_precision / nn
            accuracy_k = total_accuracy / nn
            all_precision.append(precision_k)
            all_accuracy.append(accuracy_k)
            
#             y_score = torch.cat((y_score,  y_hat.detach().to('cpu')), dim=0)
#             y_hat = (y_hat > 0.5).int()
#             y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
#             y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    """
    TODO:
        Calculate precision, recall, f1, and roc auc scores.
        Use `average='binary'` for calculating precision, recall, and fscore.
    """
#     p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average='binary')
#     roc_auc = roc_auc_score(y_true, y_score)
    total_precision_k = np.mean(all_precision)
    total_accuracy_k = np.mean(all_accuracy)
    return total_precision_k, total_accuracy_k

In [74]:
 precision_k, accuracy_k = eval_model(naive_rnn, train_loader, k=5)

In [75]:
def train(model, train_loader, test_loader, n_epochs):
    """
    TODO: train the model.
    
    Arguments:
        model: the RNN model
        train_loader: training dataloder
        val_loader: validation dataloader
        n_epochs: total number of epochs
        
    You need to call `eval_model()` at the end of each training epoch to see how well the model performs 
    on validation data.
        
    Note that please pass all four arguments to the model so that we can use this function for both 
    models. (Use `model(x, masks, rev_x, rev_masks)`.)
    """
    #base_cpu, base_ram = print_cpu_usage()
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x, x_masks, rev_x, rev_x_masks, y, y_masks in train_loader:
            """
            TODO:
                1. zero grad
                2. model forward
                3. calculate loss
                4. loss backward
                5. optimizer step
            """
#             import pdb; pdb.set_trace()
            y_hat = model(x, x_masks, rev_x, rev_x_masks, y_masks)
#             import pdb; pdb.set_trace()
#             y[~y_masks] = criterion.ignore_index
#             last_y = y_masks.any(dim=2).sum(dim=1) - 1
#             indices = last_y.unsqueeze(-1)
#             indices = indices.repeat(1, y.shape[2])
#             indices = indices.unsqueeze(1)
#             y_filt = torch.gather(y, 1, indices)

#             n_visits = y_masks.any(dim=2).sum(dim=1)
#             for i_patient, j_visit in enumerate(n_visits):
#                 for visit in range(j_visit - 1):
#                     mask = y_masks[i_patient, visit+1]

#                     yh = y_hat[i_patient, visit]
#                     y_tmp = indices_to_multihot(
#                         y[i_patient, visit+1], mask, yh)
            
            # stack into visits
#             yh = y_hat.reshape(y_hat.shape[0] * y_hat.shape[1], y_hat.shape[2])
#             y_masks = y_masks.reshape(y_masks.shape[0] * y_masks.shape[1], y_masks.shape[2])
#             y = y.reshape(y.shape[0] * y.shape[1], y.shape[2])
        
                    
            
            y_mh = indices_to_multihot(y, y_masks, y_hat)
            loss = criterion(y_hat, y_mh)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print_cpu_usage()
        print(f'Epoch: {epoch+1} \t Training Loss: {train_loss:.6f}')
        for k in range(5, 31, 5):
            precision_k, accuracy_k = eval_model(model, test_loader, k=k)
            print(f'Epoch: {epoch+1} \t Validation precision@k{k}: {precision_k:.2f}, accuracy@k{k}: {accuracy_k:.2f}')

In [76]:
def indices_to_multihot(indices, masks, y_hat):
#     import pdb; pdb.set_trace()
    #indices = indices[masks.any(dim=1)]
    multihot = torch.zeros_like(y_hat, dtype=torch.float)
    for idx, row in enumerate(indices):
        y_idx = row[masks[idx]].unique()
        multihot[idx] = F.one_hot(y_idx, y_hat.shape[1]).sum(0).float()
    return multihot

In [77]:
def print_cpu_usage():
    load = psutil.getloadavg()[2]
    cpu_usage = (load/os.cpu_count()) * 100
    ram = psutil.virtual_memory()[2]
    print(f"CPU: {cpu_usage:0.2f}")
    print(f"RAM %: {ram}")
    return cpu_usage, ram

In [78]:
n_epochs = 100
%time train(naive_rnn, train_loader, test_loader, n_epochs)

CPU: 19.25
RAM %: 61.7
Epoch: 1 	 Training Loss: 0.136856
Epoch: 1 	 Validation precision@k5: 0.05, accuracy@k5: 0.03
Epoch: 1 	 Validation precision@k10: 0.06, accuracy@k10: 0.05
Epoch: 1 	 Validation precision@k15: 0.08, accuracy@k15: 0.08
Epoch: 1 	 Validation precision@k20: 0.10, accuracy@k20: 0.10
Epoch: 1 	 Validation precision@k25: 0.13, accuracy@k25: 0.13
Epoch: 1 	 Validation precision@k30: 0.15, accuracy@k30: 0.15
CPU: 19.32
RAM %: 61.7
Epoch: 2 	 Training Loss: 0.131998
Epoch: 2 	 Validation precision@k5: 0.06, accuracy@k5: 0.03
Epoch: 2 	 Validation precision@k10: 0.07, accuracy@k10: 0.05
Epoch: 2 	 Validation precision@k15: 0.08, accuracy@k15: 0.08
Epoch: 2 	 Validation precision@k20: 0.11, accuracy@k20: 0.11
Epoch: 2 	 Validation precision@k25: 0.13, accuracy@k25: 0.13
Epoch: 2 	 Validation precision@k30: 0.16, accuracy@k30: 0.16
CPU: 19.39
RAM %: 62.3
Epoch: 3 	 Training Loss: 0.128315
Epoch: 3 	 Validation precision@k5: 0.06, accuracy@k5: 0.03
Epoch: 3 	 Validation prec

CPU: 22.29
RAM %: 65.5
Epoch: 20 	 Training Loss: 0.088615
Epoch: 20 	 Validation precision@k5: 0.07, accuracy@k5: 0.04
Epoch: 20 	 Validation precision@k10: 0.09, accuracy@k10: 0.07
Epoch: 20 	 Validation precision@k15: 0.11, accuracy@k15: 0.11
Epoch: 20 	 Validation precision@k20: 0.14, accuracy@k20: 0.14
Epoch: 20 	 Validation precision@k25: 0.18, accuracy@k25: 0.18
Epoch: 20 	 Validation precision@k30: 0.21, accuracy@k30: 0.21
CPU: 22.03
RAM %: 53.7
Epoch: 21 	 Training Loss: 0.086547
Epoch: 21 	 Validation precision@k5: 0.07, accuracy@k5: 0.04
Epoch: 21 	 Validation precision@k10: 0.09, accuracy@k10: 0.07
Epoch: 21 	 Validation precision@k15: 0.11, accuracy@k15: 0.11
Epoch: 21 	 Validation precision@k20: 0.14, accuracy@k20: 0.14
Epoch: 21 	 Validation precision@k25: 0.18, accuracy@k25: 0.18
Epoch: 21 	 Validation precision@k30: 0.22, accuracy@k30: 0.22
CPU: 22.03
RAM %: 54.4
Epoch: 22 	 Training Loss: 0.084498
Epoch: 22 	 Validation precision@k5: 0.08, accuracy@k5: 0.04
Epoch: 22 

Epoch: 38 	 Validation precision@k30: 0.26, accuracy@k30: 0.26
CPU: 23.97
RAM %: 62.5
Epoch: 39 	 Training Loss: 0.053969
Epoch: 39 	 Validation precision@k5: 0.09, accuracy@k5: 0.04
Epoch: 39 	 Validation precision@k10: 0.11, accuracy@k10: 0.09
Epoch: 39 	 Validation precision@k15: 0.14, accuracy@k15: 0.13
Epoch: 39 	 Validation precision@k20: 0.17, accuracy@k20: 0.17
Epoch: 39 	 Validation precision@k25: 0.22, accuracy@k25: 0.22
Epoch: 39 	 Validation precision@k30: 0.26, accuracy@k30: 0.26
CPU: 23.99
RAM %: 62.7
Epoch: 40 	 Training Loss: 0.052539
Epoch: 40 	 Validation precision@k5: 0.09, accuracy@k5: 0.04
Epoch: 40 	 Validation precision@k10: 0.11, accuracy@k10: 0.09
Epoch: 40 	 Validation precision@k15: 0.14, accuracy@k15: 0.13
Epoch: 40 	 Validation precision@k20: 0.18, accuracy@k20: 0.17
Epoch: 40 	 Validation precision@k25: 0.22, accuracy@k25: 0.22
Epoch: 40 	 Validation precision@k30: 0.26, accuracy@k30: 0.26
CPU: 23.78
RAM %: 62.8
Epoch: 41 	 Training Loss: 0.051136
Epoch: 4

Epoch: 57 	 Validation precision@k25: 0.24, accuracy@k25: 0.24
Epoch: 57 	 Validation precision@k30: 0.29, accuracy@k30: 0.29
CPU: 25.93
RAM %: 67.3
Epoch: 58 	 Training Loss: 0.031287
Epoch: 58 	 Validation precision@k5: 0.10, accuracy@k5: 0.05
Epoch: 58 	 Validation precision@k10: 0.12, accuracy@k10: 0.10
Epoch: 58 	 Validation precision@k15: 0.15, accuracy@k15: 0.14
Epoch: 58 	 Validation precision@k20: 0.19, accuracy@k20: 0.19
Epoch: 58 	 Validation precision@k25: 0.24, accuracy@k25: 0.24
Epoch: 58 	 Validation precision@k30: 0.29, accuracy@k30: 0.29
CPU: 26.06
RAM %: 67.6
Epoch: 59 	 Training Loss: 0.030392
Epoch: 59 	 Validation precision@k5: 0.10, accuracy@k5: 0.05
Epoch: 59 	 Validation precision@k10: 0.12, accuracy@k10: 0.10
Epoch: 59 	 Validation precision@k15: 0.15, accuracy@k15: 0.15
Epoch: 59 	 Validation precision@k20: 0.19, accuracy@k20: 0.19
Epoch: 59 	 Validation precision@k25: 0.24, accuracy@k25: 0.24
Epoch: 59 	 Validation precision@k30: 0.29, accuracy@k30: 0.29
CPU:

Epoch: 76 	 Validation precision@k20: 0.20, accuracy@k20: 0.20
Epoch: 76 	 Validation precision@k25: 0.25, accuracy@k25: 0.25
Epoch: 76 	 Validation precision@k30: 0.30, accuracy@k30: 0.30
CPU: 28.28
RAM %: 51.6
Epoch: 77 	 Training Loss: 0.018270
Epoch: 77 	 Validation precision@k5: 0.10, accuracy@k5: 0.05
Epoch: 77 	 Validation precision@k10: 0.12, accuracy@k10: 0.10
Epoch: 77 	 Validation precision@k15: 0.16, accuracy@k15: 0.15
Epoch: 77 	 Validation precision@k20: 0.20, accuracy@k20: 0.20
Epoch: 77 	 Validation precision@k25: 0.25, accuracy@k25: 0.25
Epoch: 77 	 Validation precision@k30: 0.30, accuracy@k30: 0.30
CPU: 28.10
RAM %: 51.0
Epoch: 78 	 Training Loss: 0.017733
Epoch: 78 	 Validation precision@k5: 0.10, accuracy@k5: 0.05
Epoch: 78 	 Validation precision@k10: 0.12, accuracy@k10: 0.10
Epoch: 78 	 Validation precision@k15: 0.16, accuracy@k15: 0.15
Epoch: 78 	 Validation precision@k20: 0.20, accuracy@k20: 0.20
Epoch: 78 	 Validation precision@k25: 0.25, accuracy@k25: 0.25
Epoc

Epoch: 95 	 Validation precision@k15: 0.16, accuracy@k15: 0.15
Epoch: 95 	 Validation precision@k20: 0.20, accuracy@k20: 0.20
Epoch: 95 	 Validation precision@k25: 0.25, accuracy@k25: 0.25
Epoch: 95 	 Validation precision@k30: 0.30, accuracy@k30: 0.30
CPU: 42.32
RAM %: 59.5
Epoch: 96 	 Training Loss: 0.010425
Epoch: 96 	 Validation precision@k5: 0.10, accuracy@k5: 0.05
Epoch: 96 	 Validation precision@k10: 0.12, accuracy@k10: 0.10
Epoch: 96 	 Validation precision@k15: 0.16, accuracy@k15: 0.15
Epoch: 96 	 Validation precision@k20: 0.20, accuracy@k20: 0.20
Epoch: 96 	 Validation precision@k25: 0.25, accuracy@k25: 0.25
Epoch: 96 	 Validation precision@k30: 0.30, accuracy@k30: 0.30
CPU: 41.65
RAM %: 59.2
Epoch: 97 	 Training Loss: 0.010143
Epoch: 97 	 Validation precision@k5: 0.10, accuracy@k5: 0.05
Epoch: 97 	 Validation precision@k10: 0.12, accuracy@k10: 0.10
Epoch: 97 	 Validation precision@k15: 0.16, accuracy@k15: 0.15
Epoch: 97 	 Validation precision@k20: 0.20, accuracy@k20: 0.20
Epoc

In [None]:
for k in range(5, 31, 5):
    precision_k, accuracy_k = eval_model(naive_rnn, test_loader, k=k)
    print(f'Epoch: {epoch+1} \t Validation precision@k{k}: {precision_k:.2f}, accuracy@k{k}: {accuracy_k:.2f}')