In [1]:
import os
import sys
import pickle
import psutil
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# # set seed
# seed = 24
# random.seed(seed)
# np.random.seed(seed)
# torch.manual_seed(seed)
# os.environ["PYTHONHASHSEED"] = str(seed)

# Define data path
DATA_PATH = "data"

In [2]:
pids = pickle.load(open(os.path.join(DATA_PATH,'pids.pkl'), 'rb'))
vids = pickle.load(open(os.path.join(DATA_PATH,'vids.pkl'), 'rb'))
targets = pickle.load(open(os.path.join(DATA_PATH,'targets.pkl'), 'rb'))
prob_targets = pickle.load(open(os.path.join(DATA_PATH,'prob_targets.pkl'), 'rb'))
prob_targets_allvisits = pickle.load(open(os.path.join(DATA_PATH,'prob_targets_allvisits.pkl'), 'rb'))
seqs = pickle.load(open(os.path.join(DATA_PATH,'seqs.pkl'), 'rb'))
diags = pickle.load(open(os.path.join(DATA_PATH,'diags.pkl'), 'rb'))
categories = pickle.load(open(os.path.join(DATA_PATH,'categories.pkl'), 'rb'))
sub_categories = pickle.load(open(os.path.join(DATA_PATH,'subcategories.pkl'), 'rb'))
codes = pickle.load(open(os.path.join(DATA_PATH,'icd9.pkl'), 'rb'))
assert len(pids) == len(vids) == len(targets) == len(seqs)

In [3]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, seqs, targets):
        
        """
        TODO: Store `seqs`. to `self.x` and `hfs` to `self.y`.
        
        Note that you DO NOT need to covert them to tensor as we will do this later.
        Do NOT permute the data.
        """
        self.x = seqs
        self.y = targets
    
    def __len__(self):
        
        """
        TODO: Return the number of samples (i.e. patients).
        """
        
        return(len(self.x))
    
    def __getitem__(self, index):
        
        """
        TODO: Generates one sample of data.
        
        Note that you DO NOT need to covert them to tensor as we will do this later.
        """
        return (self.x[index], self.y[index])

In [4]:
dataset = CustomDataset(seqs, prob_targets_allvisits)

In [5]:
def collate_fn(data):
    """
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patiens) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of hf labels
        using: `sequences, labels = zip(*data)`
    """
    sequences, targets = zip(*data)

#     y = torch.tensor(targets, dtype=torch.float)
    #import pdb; pdb.set_trace()
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]
    num_categories = len(targets[0][0])

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    sum_visits = sum(num_visits)
    
    x = torch.zeros((sum_visits - num_patients, max_num_codes), dtype=torch.int)
    y = torch.zeros((sum_visits - num_patients, num_categories), dtype=torch.float32)
    x_masks = torch.zeros((sum_visits - num_patients, max_num_codes), dtype=torch.bool)

#     for i_patient, patient in enumerate(sequences):   
#         for j_visit, visit in enumerate(patient):
#             """
#             TODO: update `x`, `rev_x`, `masks`, and `rev_masks`
#             """ 
#             x[i_patient, j_visit] = torch.Tensor(visit)
#             #x_masks[i_patient, j_visit] = torch.Tensor(np.ones(num_codes, dtype=int))
#             x_masks[i_patient, j_visit] = 1
#     import pdb; pdb.set_trace()
    n = 0
    for i,patient in enumerate(sequences):
        for j,visit in enumerate(patient):
            if j == len(patient) - 1:
                break
            for k,code in enumerate(visit):
                x[n,k] = code
                x_masks[n,k] = 1
            n+=1
    n = 0
    for i,patient in enumerate(targets):
        for j,visit in enumerate(patient):
            if j == len(patient) - 1:
                break
            y[n] = torch.tensor(patient[j+1])
            n += 1
    
    
    return x, x_masks, y

In [6]:
train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)
val_split = int(len(dataset)*0.10)

In [7]:
from torch.utils.data.dataset import random_split

train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)

lengths = [train_split, test_split, len(dataset) - (train_split + test_split)]
train_dataset, test_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of test dataset:", len(test_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 6561
Length of test dataset: 1312
Length of val dataset: 875


In [8]:
from torch.utils.data import DataLoader

def load_data(train_dataset, test_dataset, val_dataset, collate_fn):
    
    '''
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    
    Note that you need to pass the collate function to the data loader `collate_fn()`.
    '''
    
    batch_size = 100
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               collate_fn=collate_fn,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                           batch_size=batch_size,
                                           collate_fn=collate_fn,
                                           shuffle=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             collate_fn=collate_fn,
                                             shuffle=False)
    
    return train_loader, test_loader, val_loader


train_loader, test_loader, val_loader = load_data(train_dataset, test_dataset, val_dataset, collate_fn)

In [9]:
def sum_embeddings_with_mask(x, masks):
    """
    Mask select the embeddings for true visits (not padding visits) and then sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
    """
    
    x = x * masks.unsqueeze(-1)
    x = torch.sum(x, dim = -2)
    return x

In [10]:
def indices_to_multihot(indices, masks, dim):
    #import pdb; pdb.set_trace()
    #indices = indices[masks.any(dim=1)]
    multihot = torch.zeros((indices.shape[0], dim), dtype=torch.int)
    for idx, row in enumerate(indices):
        y_idx = row[masks[idx]].unique()
        multihot[idx] = F.one_hot(y_idx.to(torch.int64), multihot.shape[1]).sum(0)
    return multihot

In [11]:
class BaselineMLP(nn.Module):
    
    """
    TODO: implement the naive RNN model above.
    """
    
    def __init__(self, num_codes, num_categories):
        super().__init__()
        """
        Arguments:
            num_codes: total number of diagnosis codes
        """
#         self.padding_idx = 0
        self.embedding = nn.Embedding(num_codes, embedding_dim=128, padding_idx=0)
        self.fc = nn.Linear(128, num_categories)
        self.softmax = nn.Softmax(dim=-1)
        
    
    def forward(self, x, masks):
        """
        Arguments:
            x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

        Outputs:
            probs: probabilities of shape (batch_size)
        """
#         import pdb; pdb.set_trace()
#         num_codes = self.embedding.weight.shape[0]
#         x = indices_to_multihot(x, masks, num_codes)
#         x[~masks] = self.padding_idx
        x[masks] += 1
        x = self.embedding(x)
        x = x.sum(dim=1)
        #x = sum_embeddings_with_mask(x, masks)
        logits = self.fc(x)
#         logits = logits.mean(dim=1)
        probs = self.softmax(logits)
        return logits
    

# load the model here
baseline_mlp = BaselineMLP(num_codes = len(codes), num_categories=len(sub_categories))
baseline_mlp

BaselineMLP(
  (embedding): Embedding(4903, 128, padding_idx=0)
  (fc): Linear(in_features=128, out_features=184, bias=True)
  (softmax): Softmax(dim=-1)
)

In [12]:
criterion = nn.CrossEntropyLoss()
#optimizer = torch.optim.Adam(baseline_mlp.parameters(), lr=0.001)
optimizer = torch.optim.Adadelta(baseline_mlp.parameters(), weight_decay=0.001)

In [13]:
def eval_model(model, test_loader, k=15, n=-1):
    
    """
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
        
    """
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    all_precision = []
    all_accuracy = []
    
    model.eval()
    with torch.no_grad():
        for x, masks, y in test_loader:
#             import pdb; pdb.set_trace()
            n_eval = y.shape[0] - 1 if n == -1 else n
            y_hat = model(x, masks)
            y_hat = F.softmax(y_hat, dim=-1)
#             num_labels = y_hat.shape[1]
#             num_categories = torch.count_nonzero(y, dim=1)
            nz_rows, nz_cols = torch.nonzero(y, as_tuple=True)
            k_correct = 0
#             predictions = 0
            total_precision = 0
            total_accuracy = 0
            for i in range(n_eval):
                visit_correct = 0
                y_true = nz_cols[nz_rows == i]
                _, y_pred = torch.topk(y_hat[i], k)
#                 for v in y_pred:
#                     if v in y_true:
#                         visit_correct += 1
                for v in y_true:
                    if v in y_pred:
                        visit_correct += 1
#                 predictions += len(y_true)
                visit_precision = visit_correct / min(k, len(y_true))
                visit_accuracy = visit_correct / len(y_true)
                #print(f'visit {i}: precision: {visit_precision:0.2f} accuracy: {visit_accuracy:0.2f}')
                k_correct += visit_correct
                total_precision += visit_precision
                total_accuracy += visit_accuracy
            #import pdb; pdb.set_trace()
#             precision_k = precision / k
#             accuracy_k = k_correct / predictions
            precision_k = total_precision / n_eval
            accuracy_k = total_accuracy / n_eval
            all_precision.append(precision_k)
            all_accuracy.append(accuracy_k)
            
#             y_score = torch.cat((y_score,  y_hat.detach().to('cpu')), dim=0)
#             y_hat = (y_hat > 0.5).int()
#             y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
#             y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
#     import pdb; pdb.set_trace()
    total_precision_k = np.mean(all_precision)
    total_accuracy_k = np.mean(all_accuracy)
    return total_precision_k, total_accuracy_k

In [14]:
def train(model, train_loader, test_loader, n_epochs):
    """

    """
    base_cpu, base_ram = print_cpu_usage()
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
#         import pdb; pdb.set_trace()
        for x, masks, y in train_loader:

            y_hat = model(x, masks)
#             mask_idxs = masks.sum(dim=1) - 1
#             y_hat = y_hat[range(len(masks)), mask_idxs]
            loss = criterion(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print_cpu_usage()
        print(f'Epoch: {epoch+1} \t Training Loss: {train_loss:.6f}')
        for k in range(5, 31, 5):
            precision_k, accuracy_k = eval_model(model, test_loader, k=k)
            print(f'Epoch: {epoch+1} \t Validation precision@k{k}: {precision_k:.4f}, accuracy@k{k}: {accuracy_k:.4f}')
    final_cpu, final_ram = print_cpu_usage()

In [15]:
def print_cpu_usage():
    load = psutil.getloadavg()[2]
    cpu_usage = (load/os.cpu_count()) * 100
    ram = psutil.virtual_memory()[2]
    print(f"CPU: {cpu_usage:0.2f}")
    print(f"RAM %: {ram}")
    return cpu_usage, ram

In [16]:
n_epochs = 100
%time train(baseline_mlp, train_loader, val_loader, n_epochs)

CPU: 25.40
RAM %: 66.8
CPU: 25.32
RAM %: 67.1
Epoch: 1 	 Training Loss: 4.786061
Epoch: 1 	 Validation precision@k5: 0.5410, accuracy@k5: 0.2711
Epoch: 1 	 Validation precision@k10: 0.5182, accuracy@k10: 0.4287
Epoch: 1 	 Validation precision@k15: 0.5592, accuracy@k15: 0.5373
Epoch: 1 	 Validation precision@k20: 0.6211, accuracy@k20: 0.6174
Epoch: 1 	 Validation precision@k25: 0.6826, accuracy@k25: 0.6823
Epoch: 1 	 Validation precision@k30: 0.7332, accuracy@k30: 0.7332
CPU: 25.35
RAM %: 67.2
Epoch: 2 	 Training Loss: 3.897191
Epoch: 2 	 Validation precision@k5: 0.5981, accuracy@k5: 0.3033
Epoch: 2 	 Validation precision@k10: 0.5672, accuracy@k10: 0.4689
Epoch: 2 	 Validation precision@k15: 0.6039, accuracy@k15: 0.5800
Epoch: 2 	 Validation precision@k20: 0.6674, accuracy@k20: 0.6633
Epoch: 2 	 Validation precision@k25: 0.7291, accuracy@k25: 0.7289
Epoch: 2 	 Validation precision@k30: 0.7773, accuracy@k30: 0.7773
CPU: 25.35
RAM %: 67.1
Epoch: 3 	 Training Loss: 3.790570
Epoch: 3 	 Vali

CPU: 25.09
RAM %: 67.2
Epoch: 19 	 Training Loss: 3.591748
Epoch: 19 	 Validation precision@k5: 0.6603, accuracy@k5: 0.3345
Epoch: 19 	 Validation precision@k10: 0.6193, accuracy@k10: 0.5136
Epoch: 19 	 Validation precision@k15: 0.6550, accuracy@k15: 0.6291
Epoch: 19 	 Validation precision@k20: 0.7186, accuracy@k20: 0.7143
Epoch: 19 	 Validation precision@k25: 0.7742, accuracy@k25: 0.7740
Epoch: 19 	 Validation precision@k30: 0.8184, accuracy@k30: 0.8184
CPU: 25.05
RAM %: 67.1
Epoch: 20 	 Training Loss: 3.586677
Epoch: 20 	 Validation precision@k5: 0.6616, accuracy@k5: 0.3356
Epoch: 20 	 Validation precision@k10: 0.6194, accuracy@k10: 0.5134
Epoch: 20 	 Validation precision@k15: 0.6553, accuracy@k15: 0.6294
Epoch: 20 	 Validation precision@k20: 0.7185, accuracy@k20: 0.7141
Epoch: 20 	 Validation precision@k25: 0.7759, accuracy@k25: 0.7757
Epoch: 20 	 Validation precision@k30: 0.8182, accuracy@k30: 0.8182
CPU: 25.20
RAM %: 67.0
Epoch: 21 	 Training Loss: 3.580750
Epoch: 21 	 Validation 

Epoch: 36 	 Validation precision@k30: 0.8340, accuracy@k30: 0.8340
CPU: 25.55
RAM %: 66.9
Epoch: 37 	 Training Loss: 3.517422
Epoch: 37 	 Validation precision@k5: 0.6851, accuracy@k5: 0.3481
Epoch: 37 	 Validation precision@k10: 0.6471, accuracy@k10: 0.5361
Epoch: 37 	 Validation precision@k15: 0.6776, accuracy@k15: 0.6504
Epoch: 37 	 Validation precision@k20: 0.7376, accuracy@k20: 0.7330
Epoch: 37 	 Validation precision@k25: 0.7901, accuracy@k25: 0.7899
Epoch: 37 	 Validation precision@k30: 0.8338, accuracy@k30: 0.8338
CPU: 25.47
RAM %: 66.9
Epoch: 38 	 Training Loss: 3.514594
Epoch: 38 	 Validation precision@k5: 0.6914, accuracy@k5: 0.3513
Epoch: 38 	 Validation precision@k10: 0.6479, accuracy@k10: 0.5370
Epoch: 38 	 Validation precision@k15: 0.6767, accuracy@k15: 0.6496
Epoch: 38 	 Validation precision@k20: 0.7352, accuracy@k20: 0.7307
Epoch: 38 	 Validation precision@k25: 0.7889, accuracy@k25: 0.7886
Epoch: 38 	 Validation precision@k30: 0.8340, accuracy@k30: 0.8340
CPU: 25.61
RAM 

Epoch: 54 	 Validation precision@k25: 0.7944, accuracy@k25: 0.7941
Epoch: 54 	 Validation precision@k30: 0.8364, accuracy@k30: 0.8364
CPU: 26.67
RAM %: 65.8
Epoch: 55 	 Training Loss: 3.491076
Epoch: 55 	 Validation precision@k5: 0.7005, accuracy@k5: 0.3562
Epoch: 55 	 Validation precision@k10: 0.6564, accuracy@k10: 0.5441
Epoch: 55 	 Validation precision@k15: 0.6838, accuracy@k15: 0.6564
Epoch: 55 	 Validation precision@k20: 0.7408, accuracy@k20: 0.7361
Epoch: 55 	 Validation precision@k25: 0.7920, accuracy@k25: 0.7918
Epoch: 55 	 Validation precision@k30: 0.8358, accuracy@k30: 0.8358
CPU: 26.62
RAM %: 65.9
Epoch: 56 	 Training Loss: 3.491958
Epoch: 56 	 Validation precision@k5: 0.6995, accuracy@k5: 0.3557
Epoch: 56 	 Validation precision@k10: 0.6589, accuracy@k10: 0.5456
Epoch: 56 	 Validation precision@k15: 0.6870, accuracy@k15: 0.6594
Epoch: 56 	 Validation precision@k20: 0.7427, accuracy@k20: 0.7380
Epoch: 56 	 Validation precision@k25: 0.7939, accuracy@k25: 0.7936
Epoch: 56 	 Val

Epoch: 72 	 Validation precision@k20: 0.7421, accuracy@k20: 0.7375
Epoch: 72 	 Validation precision@k25: 0.7946, accuracy@k25: 0.7943
Epoch: 72 	 Validation precision@k30: 0.8365, accuracy@k30: 0.8365
CPU: 28.66
RAM %: 60.1
Epoch: 73 	 Training Loss: 3.484259
Epoch: 73 	 Validation precision@k5: 0.7002, accuracy@k5: 0.3555
Epoch: 73 	 Validation precision@k10: 0.6573, accuracy@k10: 0.5443
Epoch: 73 	 Validation precision@k15: 0.6883, accuracy@k15: 0.6607
Epoch: 73 	 Validation precision@k20: 0.7431, accuracy@k20: 0.7385
Epoch: 73 	 Validation precision@k25: 0.7941, accuracy@k25: 0.7938
Epoch: 73 	 Validation precision@k30: 0.8359, accuracy@k30: 0.8359
CPU: 28.67
RAM %: 60.4
Epoch: 74 	 Training Loss: 3.484494
Epoch: 74 	 Validation precision@k5: 0.7033, accuracy@k5: 0.3578
Epoch: 74 	 Validation precision@k10: 0.6601, accuracy@k10: 0.5470
Epoch: 74 	 Validation precision@k15: 0.6892, accuracy@k15: 0.6616
Epoch: 74 	 Validation precision@k20: 0.7425, accuracy@k20: 0.7379
Epoch: 74 	 Val

Epoch: 90 	 Validation precision@k15: 0.6880, accuracy@k15: 0.6607
Epoch: 90 	 Validation precision@k20: 0.7420, accuracy@k20: 0.7374
Epoch: 90 	 Validation precision@k25: 0.7950, accuracy@k25: 0.7948
Epoch: 90 	 Validation precision@k30: 0.8365, accuracy@k30: 0.8365
CPU: 29.40
RAM %: 63.4
Epoch: 91 	 Training Loss: 3.482944
Epoch: 91 	 Validation precision@k5: 0.7066, accuracy@k5: 0.3594
Epoch: 91 	 Validation precision@k10: 0.6599, accuracy@k10: 0.5465
Epoch: 91 	 Validation precision@k15: 0.6887, accuracy@k15: 0.6613
Epoch: 91 	 Validation precision@k20: 0.7424, accuracy@k20: 0.7378
Epoch: 91 	 Validation precision@k25: 0.7947, accuracy@k25: 0.7945
Epoch: 91 	 Validation precision@k30: 0.8364, accuracy@k30: 0.8364
CPU: 29.38
RAM %: 63.2
Epoch: 92 	 Training Loss: 3.483685
Epoch: 92 	 Validation precision@k5: 0.7064, accuracy@k5: 0.3589
Epoch: 92 	 Validation precision@k10: 0.6591, accuracy@k10: 0.5462
Epoch: 92 	 Validation precision@k15: 0.6884, accuracy@k15: 0.6610
Epoch: 92 	 Val

In [23]:
for k in range(5, 31, 5):
    precision_k, accuracy_k = eval_model(baseline_mlp, test_loader, k=k)
    print(f'Validation precision@k{k}: {precision_k:.4f}, accuracy@k{k}: {accuracy_k:.4f}')

Validation precision@k5: 0.6980, accuracy@k5: 0.3523
Validation precision@k10: 0.6512, accuracy@k10: 0.5422
Validation precision@k15: 0.6808, accuracy@k15: 0.6541
Validation precision@k20: 0.7370, accuracy@k20: 0.7331
Validation precision@k25: 0.7892, accuracy@k25: 0.7890
Validation precision@k30: 0.8336, accuracy@k30: 0.8336
