In [22]:
import os
import sys
import pickle
import psutil
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Define data path
DATA_PATH = "data/"
CHECKPOINT_PATH = "models/"

In [23]:
pids = pickle.load(open(os.path.join(DATA_PATH,'pids.pkl'), 'rb'))
vids = pickle.load(open(os.path.join(DATA_PATH,'vids.pkl'), 'rb'))
targets = pickle.load(open(os.path.join(DATA_PATH,'targets.pkl'), 'rb'))
prob_targets = pickle.load(open(os.path.join(DATA_PATH,'prob_targets.pkl'), 'rb'))
prob_targets_allvisits = pickle.load(open(os.path.join(DATA_PATH,'prob_targets_allvisits.pkl'), 'rb'))
seqs = pickle.load(open(os.path.join(DATA_PATH,'seqs.pkl'), 'rb'))
diags = pickle.load(open(os.path.join(DATA_PATH,'diags.pkl'), 'rb'))
categories = pickle.load(open(os.path.join(DATA_PATH,'categories.pkl'), 'rb'))
sub_categories = pickle.load(open(os.path.join(DATA_PATH,'subcategories.pkl'), 'rb'))
codes = pickle.load(open(os.path.join(DATA_PATH,'icd9.pkl'), 'rb'))
assert len(pids) == len(vids) == len(targets) == len(seqs)

In [24]:
embedding_matrix = torch.load(os.path.join(DATA_PATH, 'embedding_matrix.pt'))

In [25]:
class CustomDataset(Dataset):
    def __init__(self, seqs, targets):
        self.x = seqs
        self.y = targets
    
    def __len__(self):
        return(len(self.x))
    
    def __getitem__(self, index):
        return (self.x[index], self.y[index])

In [26]:
dataset = CustomDataset(seqs, prob_targets_allvisits)

In [27]:
def collate_fn(data):
    """
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# total visits excluding last visit per patient, max # diagnosis codes) of
            type torch.long
        x_masks: a tensor of shape (# total visits excluding last visit per patient, max # diagnosis codes)
            of type torch.bool
        y: a tensor of shape (# total visits excluding first visit per patient, num higher level categories
            to predict) of type torch.float
    """
    sequences, targets = zip(*data)
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]
    num_categories = len(targets[0][0])

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    sum_visits = sum(num_visits)
    
    x = torch.zeros((sum_visits - num_patients, max_num_codes), dtype=torch.long)
    y = torch.zeros((sum_visits - num_patients, num_categories), dtype=torch.float)
    x_masks = torch.zeros((sum_visits - num_patients, max_num_codes), dtype=torch.bool)

    n = 0
    for i,patient in enumerate(sequences):
        for j,visit in enumerate(patient):
            if j == len(patient) - 1:
                break
            for k,code in enumerate(visit):
                x[n,k] = code
                x_masks[n,k] = 1
            n+=1
    n = 0
    for i,patient in enumerate(targets):
        for j,visit in enumerate(patient):
            if j == len(patient) - 1:
                break
            y[n] = torch.tensor(patient[j+1])
            n += 1
    
    
    return x, x_masks, y

In [28]:
train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)
val_split = int(len(dataset)*0.10)

In [29]:
from torch.utils.data.dataset import random_split

train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)*0.15)

lengths = [train_split, test_split, len(dataset) - (train_split + test_split)]
train_dataset, test_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of test dataset:", len(test_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 6561
Length of test dataset: 1312
Length of val dataset: 875


In [30]:
def load_data(train_dataset, test_dataset, val_dataset, collate_fn):
    '''
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        test dataset: test dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, test_loader, val_loader: train, test and validation dataloaders
    '''
    batch_size = 100
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               collate_fn=collate_fn,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                           batch_size=batch_size,
                                           collate_fn=collate_fn,
                                           shuffle=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             collate_fn=collate_fn,
                                             shuffle=False)
    
    return train_loader, test_loader, val_loader


train_loader, test_loader, val_loader = load_data(train_dataset, test_dataset, val_dataset, collate_fn)

In [31]:
def indices_to_multihot(indices, masks, dim):
    multihot = torch.zeros((indices.shape[0], dim), dtype=torch.float)
    for idx, row in enumerate(indices):
        y_idx = row[masks[idx]].unique()
        multihot[idx] = F.one_hot(y_idx.to(torch.int64), multihot.shape[1]).sum(0)
    return multihot

In [32]:
class EnhancedMLP(nn.Module):
    
    def __init__(self, num_codes, num_categories, embedding_matrix):
        super().__init__()
        """
        Arguments:
            num_codes: total number of diagnosis codes
            num_categories: number of higher level categories to predict
            embedding_matrix: learned embedding matrix of icd9 descriptions
        """
        self.embedding = nn.Linear(4903, 300)
        self.embedding.weight.data = embedding_matrix
        self.fc = nn.Linear(300, num_categories)
    
    def forward(self, x, masks):
        """
        Arguments:
            x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)
        Outputs:
            logits: logits of shape (batch_size, # diagnosis codes)
        """
        x = indices_to_multihot(x, masks, 4903)
        x = self.embedding(x)
        x = torch.tanh(x)
        logits = self.fc(x)
        return logits
    
# load the model here
enhanced_mlp = EnhancedMLP(num_codes = len(codes), num_categories=len(sub_categories), embedding_matrix=embedding_matrix)
enhanced_mlp

EnhancedMLP(
  (embedding): Linear(in_features=4903, out_features=300, bias=True)
  (fc): Linear(in_features=300, out_features=184, bias=True)
)

In [33]:
criterion = nn.CrossEntropyLoss()
#optimizer = torch.optim.Adam(baseline_mlp.parameters(), lr=0.001)
optimizer = torch.optim.Adadelta(enhanced_mlp.parameters(), weight_decay=0.001)

In [34]:
def eval_model(model, test_loader, k=15, n=-1):    
    """
    Arguments:
        model: the EnhancedMLP model
        test_loader: validation dataloader
        k: value for top k predictions
        n: num of records to evaluate in the batch, value -1 evaulates all records
        
    Outputs:
        precision_k: visit-level precison@k
        accuracy_k: code-level accuracy@k
    """
    y_pred = torch.LongTensor()
    y_true = torch.LongTensor()
    all_precision = []
    all_accuracy = []
    
    model.eval()
    with torch.no_grad():
        for x, masks, y in test_loader:
            n_eval = y.shape[0] - 1 if n == -1 else n
            y_hat = model(x, masks)
            y_hat = F.softmax(y_hat, dim=-1)
            nz_rows, nz_cols = torch.nonzero(y, as_tuple=True)
            k_correct = 0
            total_precision = 0
            total_accuracy = 0
            for i in range(n_eval):
                visit_correct = 0
                y_true = nz_cols[nz_rows == i]
                _, y_pred = torch.topk(y_hat[i], k)
                    
                for v in y_true:
                    if v in y_pred:
                        visit_correct += 1
                        
                visit_precision = visit_correct / min(k, len(y_true))
                visit_accuracy = visit_correct / len(y_true)
                k_correct += visit_correct
                total_precision += visit_precision
                total_accuracy += visit_accuracy

            precision_k = total_precision / n_eval
            accuracy_k = total_accuracy / n_eval
            all_precision.append(precision_k)
            all_accuracy.append(accuracy_k)
            
    total_precision_k = np.mean(all_precision)
    total_accuracy_k = np.mean(all_accuracy)
    return total_precision_k, total_accuracy_k

In [35]:
def train(model, train_loader, test_loader, n_epochs):
    """
    Arguments:
        model: the EnhancedMLP model
        train_loader: training dataloader
        test_loader: validation dataloader
        n_epochs: num epochs to train
    """
    max_cpu, max_ram = print_cpu_usage()
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x, masks, y in train_loader:
            y_hat = model(x, masks)
            loss = criterion(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        cpu, ram = print_cpu_usage()
        max_cpu = cpu if cpu > max_cpu else max_cpu
        max_ram = ram if ram > max_ram else max_ram
        print(f'Epoch: {epoch+1} \t Training Loss: {train_loss:.6f}')
        for k in range(5, 31, 5):
            precision_k, accuracy_k = eval_model(model, test_loader, k=k)
            print(f'Epoch: {epoch+1} \t Validation precision@k{k}: {precision_k:.4f}, accuracy@k{k}: {accuracy_k:.4f}')
    final_cpu, final_ram = print_cpu_usage()
    print(f"Max CPU usage: {max_cpu:.3f}\tMax RAM % usage: {max_ram}")

In [36]:
def print_cpu_usage():
    load = psutil.getloadavg()[2]
    cpu_usage = (load/os.cpu_count()) * 100
    ram = psutil.virtual_memory()[2]
    print(f"CPU: {cpu_usage:0.2f}")
    print(f"RAM %: {ram}")
    return cpu_usage, ram

In [16]:
n_epochs = 
%time train(enhanced_mlp, train_loader, val_loader, n_epochs)

CPU: 12.53
RAM %: 57.0
CPU: 12.60
RAM %: 57.0
Epoch: 1 	 Training Loss: 4.054320
Epoch: 1 	 Validation precision@k5: 0.4018, accuracy@k5: 0.1895
Epoch: 1 	 Validation precision@k10: 0.3749, accuracy@k10: 0.3063
Epoch: 1 	 Validation precision@k15: 0.4664, accuracy@k15: 0.4474
Epoch: 1 	 Validation precision@k20: 0.6029, accuracy@k20: 0.5993
Epoch: 1 	 Validation precision@k25: 0.6707, accuracy@k25: 0.6704
Epoch: 1 	 Validation precision@k30: 0.7468, accuracy@k30: 0.7468
CPU: 13.00
RAM %: 57.0
Epoch: 2 	 Training Loss: 4.016588
Epoch: 2 	 Validation precision@k5: 0.3892, accuracy@k5: 0.1892
Epoch: 2 	 Validation precision@k10: 0.4677, accuracy@k10: 0.3874
Epoch: 2 	 Validation precision@k15: 0.5084, accuracy@k15: 0.4899
Epoch: 2 	 Validation precision@k20: 0.5947, accuracy@k20: 0.5912
Epoch: 2 	 Validation precision@k25: 0.6644, accuracy@k25: 0.6642
Epoch: 2 	 Validation precision@k30: 0.7134, accuracy@k30: 0.7134
CPU: 13.10
RAM %: 57.1
Epoch: 3 	 Training Loss: 4.003275
Epoch: 3 	 Vali

CPU: 14.69
RAM %: 57.1
Epoch: 19 	 Training Loss: 3.725864
Epoch: 19 	 Validation precision@k5: 0.5816, accuracy@k5: 0.2935
Epoch: 19 	 Validation precision@k10: 0.5431, accuracy@k10: 0.4542
Epoch: 19 	 Validation precision@k15: 0.5923, accuracy@k15: 0.5711
Epoch: 19 	 Validation precision@k20: 0.6649, accuracy@k20: 0.6610
Epoch: 19 	 Validation precision@k25: 0.7299, accuracy@k25: 0.7296
Epoch: 19 	 Validation precision@k30: 0.7853, accuracy@k30: 0.7853
CPU: 14.87
RAM %: 57.4
Epoch: 20 	 Training Loss: 3.709797
Epoch: 20 	 Validation precision@k5: 0.5887, accuracy@k5: 0.3034
Epoch: 20 	 Validation precision@k10: 0.5664, accuracy@k10: 0.4768
Epoch: 20 	 Validation precision@k15: 0.6095, accuracy@k15: 0.5880
Epoch: 20 	 Validation precision@k20: 0.6714, accuracy@k20: 0.6677
Epoch: 20 	 Validation precision@k25: 0.7323, accuracy@k25: 0.7321
Epoch: 20 	 Validation precision@k30: 0.7811, accuracy@k30: 0.7811
CPU: 14.86
RAM %: 57.4
Epoch: 21 	 Training Loss: 3.695081
Epoch: 21 	 Validation 

Epoch: 36 	 Validation precision@k30: 0.8194, accuracy@k30: 0.8194
CPU: 16.00
RAM %: 52.6
Epoch: 37 	 Training Loss: 3.537119
Epoch: 37 	 Validation precision@k5: 0.6896, accuracy@k5: 0.3543
Epoch: 37 	 Validation precision@k10: 0.6381, accuracy@k10: 0.5353
Epoch: 37 	 Validation precision@k15: 0.6674, accuracy@k15: 0.6433
Epoch: 37 	 Validation precision@k20: 0.7254, accuracy@k20: 0.7214
Epoch: 37 	 Validation precision@k25: 0.7795, accuracy@k25: 0.7792
Epoch: 37 	 Validation precision@k30: 0.8211, accuracy@k30: 0.8211
CPU: 16.03
RAM %: 52.7
Epoch: 38 	 Training Loss: 3.529459
Epoch: 38 	 Validation precision@k5: 0.6891, accuracy@k5: 0.3530
Epoch: 38 	 Validation precision@k10: 0.6396, accuracy@k10: 0.5368
Epoch: 38 	 Validation precision@k15: 0.6690, accuracy@k15: 0.6451
Epoch: 38 	 Validation precision@k20: 0.7244, accuracy@k20: 0.7205
Epoch: 38 	 Validation precision@k25: 0.7772, accuracy@k25: 0.7769
Epoch: 38 	 Validation precision@k30: 0.8185, accuracy@k30: 0.8185
CPU: 16.09
RAM 

Epoch: 54 	 Validation precision@k25: 0.7884, accuracy@k25: 0.7882
Epoch: 54 	 Validation precision@k30: 0.8287, accuracy@k30: 0.8287
CPU: 17.18
RAM %: 53.5
Epoch: 55 	 Training Loss: 3.477466
Epoch: 55 	 Validation precision@k5: 0.7064, accuracy@k5: 0.3602
Epoch: 55 	 Validation precision@k10: 0.6543, accuracy@k10: 0.5470
Epoch: 55 	 Validation precision@k15: 0.6818, accuracy@k15: 0.6565
Epoch: 55 	 Validation precision@k20: 0.7354, accuracy@k20: 0.7312
Epoch: 55 	 Validation precision@k25: 0.7866, accuracy@k25: 0.7863
Epoch: 55 	 Validation precision@k30: 0.8294, accuracy@k30: 0.8294
CPU: 17.16
RAM %: 53.5
Epoch: 56 	 Training Loss: 3.477667
Epoch: 56 	 Validation precision@k5: 0.7014, accuracy@k5: 0.3583
Epoch: 56 	 Validation precision@k10: 0.6560, accuracy@k10: 0.5481
Epoch: 56 	 Validation precision@k15: 0.6823, accuracy@k15: 0.6570
Epoch: 56 	 Validation precision@k20: 0.7363, accuracy@k20: 0.7321
Epoch: 56 	 Validation precision@k25: 0.7878, accuracy@k25: 0.7876
Epoch: 56 	 Val

Epoch: 72 	 Validation precision@k20: 0.7403, accuracy@k20: 0.7361
Epoch: 72 	 Validation precision@k25: 0.7896, accuracy@k25: 0.7894
Epoch: 72 	 Validation precision@k30: 0.8288, accuracy@k30: 0.8288
CPU: 18.12
RAM %: 53.3
Epoch: 73 	 Training Loss: 3.469834
Epoch: 73 	 Validation precision@k5: 0.7026, accuracy@k5: 0.3580
Epoch: 73 	 Validation precision@k10: 0.6574, accuracy@k10: 0.5497
Epoch: 73 	 Validation precision@k15: 0.6858, accuracy@k15: 0.6601
Epoch: 73 	 Validation precision@k20: 0.7382, accuracy@k20: 0.7339
Epoch: 73 	 Validation precision@k25: 0.7910, accuracy@k25: 0.7907
Epoch: 73 	 Validation precision@k30: 0.8312, accuracy@k30: 0.8312
CPU: 18.23
RAM %: 53.4
Epoch: 74 	 Training Loss: 3.469497
Epoch: 74 	 Validation precision@k5: 0.7077, accuracy@k5: 0.3606
Epoch: 74 	 Validation precision@k10: 0.6592, accuracy@k10: 0.5508
Epoch: 74 	 Validation precision@k15: 0.6852, accuracy@k15: 0.6597
Epoch: 74 	 Validation precision@k20: 0.7371, accuracy@k20: 0.7328
Epoch: 74 	 Val

Epoch: 90 	 Validation precision@k15: 0.6882, accuracy@k15: 0.6623
Epoch: 90 	 Validation precision@k20: 0.7389, accuracy@k20: 0.7346
Epoch: 90 	 Validation precision@k25: 0.7882, accuracy@k25: 0.7879
Epoch: 90 	 Validation precision@k30: 0.8285, accuracy@k30: 0.8285
CPU: 19.03
RAM %: 53.9
Epoch: 91 	 Training Loss: 3.466563
Epoch: 91 	 Validation precision@k5: 0.7053, accuracy@k5: 0.3602
Epoch: 91 	 Validation precision@k10: 0.6586, accuracy@k10: 0.5505
Epoch: 91 	 Validation precision@k15: 0.6851, accuracy@k15: 0.6594
Epoch: 91 	 Validation precision@k20: 0.7371, accuracy@k20: 0.7328
Epoch: 91 	 Validation precision@k25: 0.7893, accuracy@k25: 0.7891
Epoch: 91 	 Validation precision@k30: 0.8289, accuracy@k30: 0.8289
CPU: 19.17
RAM %: 54.0
Epoch: 92 	 Training Loss: 3.465803
Epoch: 92 	 Validation precision@k5: 0.7014, accuracy@k5: 0.3578
Epoch: 92 	 Validation precision@k10: 0.6573, accuracy@k10: 0.5496
Epoch: 92 	 Validation precision@k15: 0.6886, accuracy@k15: 0.6629
Epoch: 92 	 Val

In [17]:
for k in range(5, 31, 5):
    precision_k, accuracy_k = eval_model(enhanced_mlp, test_loader, k=k)
    print(f'Validation precision@k{k}: {precision_k:.4f}, accuracy@k{k}: {accuracy_k:.4f}')

Validation precision@k5: 0.7004, accuracy@k5: 0.3605
Validation precision@k10: 0.6480, accuracy@k10: 0.5432
Validation precision@k15: 0.6779, accuracy@k15: 0.6532
Validation precision@k20: 0.7303, accuracy@k20: 0.7263
Validation precision@k25: 0.7817, accuracy@k25: 0.7815
Validation precision@k30: 0.8252, accuracy@k30: 0.8252


In [21]:
torch.save(enhanced_mlp, os.path.join(CHECKPOINT_PATH, "EnhancedMLP_100.pth"))