In [1]:
from __future__ import division
import argparse
import math
import numpy as np

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import pickle
from tqdm import tqdm

from sklearn.metrics import roc_auc_score, classification_report

### Model Construction

In [8]:
class RNN(nn.Module):
    """LSTM classifier over per-visit multi-hot ICD / medication / lab vectors.

    Each modality is projected by its own linear layer followed by ReLU, the
    three projections are concatenated per visit, and the resulting sequence is
    run through a single-layer LSTM. A linear head produces one logit per visit.

    NOTE: the per-modality layer sizes are NOT taken from the constructor
    arguments. They are read from the notebook-level globals ``vocabsize_icd``,
    ``vocabsize_meds``, ``vocabsize_labs``, ``embsize_icd``, ``embsize_meds`` and
    ``embsize_labs``, which must exist before the model is instantiated.
    ``embsize`` is used as the LSTM input size, so it has to equal the sum of
    the three per-modality embedding sizes.
    """

    def __init__(self, epochs=5, batchsize=50, vocabsize=5, embsize=100):
        """Store the hyper-parameters and build the layers.

        Args:
            epochs: stored on the instance (not used by the model itself).
            batchsize: stored on the instance (not used by the model itself).
            vocabsize: stored on the instance (not used to size any layer).
            embsize: LSTM input and hidden size.
        """
        super(RNN, self).__init__()
        # Honour the constructor argument instead of a hard-coded 5.
        self.epochs = epochs
        self.batchsize = batchsize
        self.vocabsize = vocabsize
        self.embsize = embsize

        self.emb_icd = nn.Linear(vocabsize_icd, embsize_icd)
        self.emb_meds = nn.Linear(vocabsize_meds, embsize_meds)
        self.emb_labs = nn.Linear(vocabsize_labs, embsize_labs)

        self.rnn = nn.LSTM(input_size=embsize, hidden_size=embsize, num_layers=1)
        self.out = nn.Linear(embsize, 1)
        self.sig = nn.Sigmoid()

    def forward(self, input_icd, input_med, input_lab, hidden=None, force=True, steps=0):
        """Compute one logit per visit for a single patient.

        Args:
            input_icd, input_med, input_lab: 2-D float tensors with one row per
                visit and one column per code of the respective vocabulary.
            hidden: optional initial LSTM state, passed straight to ``nn.LSTM``.
            force, steps: accepted for backward compatibility only; they do not
                influence the computation.

        Returns:
            ``(logits, hidden)`` where ``logits`` is a 1-D tensor with one
            entry per visit (also for a single visit) and ``hidden`` is the
            final LSTM state.
        """
        input_icd = F.relu(self.emb_icd(input_icd))
        input_med = F.relu(self.emb_meds(input_med))
        input_lab = F.relu(self.emb_labs(input_lab))

        # Concatenate the modality embeddings along the feature axis.
        inputs = torch.cat((input_icd, input_med, input_lab), 1)

        # nn.LSTM expects (seq_len, batch, features); insert a batch axis of 1.
        inputs = inputs.view(inputs.size(0), 1, inputs.size(1))
        outputs, hidden = self.rnn(inputs, hidden)
        outputs = self.out(outputs)
        # (seq_len, 1, 1) -> (seq_len,). A bare squeeze() would also drop the
        # time axis for a one-visit sequence and make outputs[-1] fail.
        return outputs.reshape(-1), hidden

    def predict(self, input_icd, input_med, input_lab):
        """Return the sigmoid of the last visit's logit as a 0-dim tensor."""
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            out, _ = self.forward(input_icd, input_med, input_lab, None)
        return self.sig(out[-1]).data

### Data loading

In [9]:
# Directory holding the pre-processed MIMIC-III pickles.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
DATA_PATH = 'E:/CS_Master_Degree_UIUC/CS598_DeepLearning_for_Health_Data/Project/paper290/MIMIC_Processed/'

n_epochs = 5
# Vocabulary size per modality (width of the multi-hot visit vectors).
vocabsize_icd = 942
vocabsize_meds = 3202
vocabsize_labs = 284 #all 681
vocabsize = vocabsize_icd+vocabsize_meds+vocabsize_labs

# Embedding size per modality; their sum is the LSTM input size.
embsize_icd = 50
embsize_meds = 75
embsize_labs = 50
embsize = embsize_icd + embsize_labs + embsize_meds


def _load_pickle(filename):
    """Unpickle ``DATA_PATH + filename`` and close the file afterwards.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    with open(DATA_PATH + filename, 'rb') as handle:
        return pickle.load(handle)


input_seqs_icd = _load_pickle('MIMICIIIPROCESSED.3digitICD9.seqs')
input_seqs_meds = _load_pickle('MIMICIIIPROCESSED.meds.seqs')
input_seqs_labs = _load_pickle('MIMICIIIPROCESSED.abnlabs.seqs')
input_seqs_fullicd = _load_pickle('MIMICIIIPROCESSED.seqs')

labels = _load_pickle('MIMICIIIPROCESSED.morts')

### Model Training

In [10]:
# Fractions of the patient list assigned to each split. Only the train and
# validation fractions are used below; the test split is whatever remains.
trainratio, validratio, testratio = 0.7, 0.1, 0.2

# Exclusive end positions of the train and validation slices.
trainlindex = int(trainratio * len(input_seqs_icd))
validlindex = int((trainratio + validratio) * len(input_seqs_icd))

print('Data prepared..')

def convert_to_one_hot(code_seqs, len):
    """Encode a sequence of visits as a 2-D multi-hot matrix.

    Args:
        code_seqs: iterable of visits, each an iterable of integer code indices.
        len: number of columns (vocabulary size). The name shadows the builtin
            and is kept only for backward compatibility with existing callers.

    Returns:
        A float64 array of shape ``(n_visits, len)`` holding 1.0 at every
        listed code and 0.0 elsewhere. An empty ``code_seqs`` yields an array
        of shape ``(0, len)``.
    """
    width = len  # alias: the parameter hides the builtin len() in this scope
    rows = []
    for code_seq in code_seqs:
        row = np.zeros(width)
        for code in code_seq:
            row[code] = 1
        rows.append(row)
    if not rows:
        # np.array([]) would be 1-D with shape (0,); keep the result 2-D.
        return np.zeros((0, width))
    return np.array(rows)

Data prepared..


In [14]:
import time


def _predict_all(model, seqs_icd, seqs_meds, seqs_labs):
    """Return an array holding ``model.predict`` for every patient.

    The three arguments are parallel lists (one entry per patient) of visit
    sequences for the ICD, medication and lab modalities.
    """
    scores = np.zeros(len(seqs_icd))
    for idx in range(len(seqs_icd)):
        icd = torch.from_numpy(convert_to_one_hot(seqs_icd[idx], vocabsize_icd)).float()
        med = torch.from_numpy(convert_to_one_hot(seqs_meds[idx], vocabsize_meds)).float()
        lab = torch.from_numpy(convert_to_one_hot(seqs_labs[idx], vocabsize_labs)).float()
        scores[idx] = model.predict(icd, med, lab)
    return scores


# process_time() counts CPU time of this process, not wall-clock time.
start = time.process_time()

print('Starting training..')

batchsize = 50

# One test AUC per run, taken at the epoch with the best validation AUC.
best_aucrocs = []
for run in range(10):
    print('Run', run)

    # Fresh random train/validation/test split for every run.
    perm = np.random.permutation(len(input_seqs_icd))
    rinput_seqs_icd = [input_seqs_icd[i] for i in perm]
    rinput_seqs_meds = [input_seqs_meds[i] for i in perm]
    rinput_seqs_labs = [input_seqs_labs[i] for i in perm]
    rinput_seqs_fullicd = [input_seqs_fullicd[i] for i in perm]
    rlabels = [labels[i] for i in perm]
    rlabels = torch.tensor(rlabels)

    train_input_seqs_icd = rinput_seqs_icd[:trainlindex]
    train_input_seqs_meds = rinput_seqs_meds[:trainlindex]
    train_input_seqs_labs = rinput_seqs_labs[:trainlindex]
    train_labels = rlabels[:trainlindex]
    # Shape (n, 1) so that train_labels[j] is a 1-element target tensor.
    train_labels = train_labels.reshape(train_labels.shape[0],1)

    valid_input_seqs_icd = rinput_seqs_icd[trainlindex:validlindex]
    valid_input_seqs_meds = rinput_seqs_meds[trainlindex:validlindex]
    valid_input_seqs_labs = rinput_seqs_labs[trainlindex:validlindex]
    valid_labels = rlabels[trainlindex:validlindex]

    test_input_seqs_icd = rinput_seqs_icd[validlindex:]
    test_input_seqs_meds = rinput_seqs_meds[validlindex:]
    test_input_seqs_labs = rinput_seqs_labs[validlindex:]
    test_input_seqs_fullicd = rinput_seqs_fullicd[validlindex:]

    test_labels = rlabels[validlindex:]

    n_iters = len(train_input_seqs_icd)

    model = RNN(n_epochs, 1, vocabsize, embsize)
    criterion = nn.BCEWithLogitsLoss(reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    valid_aucrocs = []
    aucrocs = []

    for epoch in range(n_epochs):

        epoch_loss = 0

        print('Epoch', (epoch+1))

        for i in (range(0, n_iters, batchsize)):
            batch_icd = train_input_seqs_icd[i:i+batchsize]
            batch_meds = train_input_seqs_meds[i:i+batchsize]
            batch_labs = train_input_seqs_labs[i:i+batchsize]

            batch_train_labels = train_labels[i:i+batchsize]

            optimizer.zero_grad()
            losses = []

            # Patients have different numbers of visits, so each one is run
            # through the model separately and the losses are averaged.
            for j in range(len(batch_icd)):
                icd_inputs = torch.from_numpy(convert_to_one_hot(batch_icd[j], vocabsize_icd)).float()
                med_inputs = torch.from_numpy(convert_to_one_hot(batch_meds[j], vocabsize_meds)).float()
                lab_inputs = torch.from_numpy(convert_to_one_hot(batch_labs[j], vocabsize_labs)).float()

                targets = batch_train_labels[j].float()

                # forward()'s `force`/`steps` arguments do not influence the
                # computation, so the defaults are used.
                outputs, hidden = model(icd_inputs, med_inputs, lab_inputs)

                # Only the logit of the last visit is trained against the label.
                losses.append(criterion(outputs[-1].view(1), targets))

            loss = sum(losses)/len(batch_icd)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        ## Validation phase
        vpredictions = _predict_all(model, valid_input_seqs_icd, valid_input_seqs_meds, valid_input_seqs_labs)
        valid_auc = roc_auc_score(valid_labels, vpredictions)
        valid_aucrocs.append(valid_auc)
        print("Validation AUC_ROC: ", valid_auc)

        ## Testing phase
        predictions = _predict_all(model, test_input_seqs_icd, test_input_seqs_meds, test_input_seqs_labs)
        test_auc = roc_auc_score(test_labels, predictions)
        print("Test AUC_ROC: ", test_auc)

        aucrocs.append(test_auc)
        actual_predictions = (predictions>0.5)*1
        print(classification_report(test_labels, actual_predictions))

    # Select the epoch on the validation split and report the test AUC of that
    # epoch. Taking max() over the test AUCs would be model selection on the
    # test set and would bias the reported average upwards.
    best_epoch = int(np.argmax(valid_aucrocs))
    best_aucrocs.append(aucrocs[best_epoch])

print("Average AUCROC:", np.mean(best_aucrocs), "+/-", np.std(best_aucrocs))

end = time.process_time()
print('The training is complete!')
print('The time used is: ', end - start)

Starting training..
Run 0
Epoch 1
Validation AUC_ROC:  0.8531713486375656
Test AUC_ROC:  0.8594790306599623
              precision    recall  f1-score   support

           0       0.82      0.81      0.82       923
           1       0.70      0.73      0.72       585

    accuracy                           0.78      1508
   macro avg       0.76      0.77      0.77      1508
weighted avg       0.78      0.78      0.78      1508

Epoch 2
Validation AUC_ROC:  0.8377265933392096
Test AUC_ROC:  0.8518691372429184
              precision    recall  f1-score   support

           0       0.82      0.79      0.80       923
           1       0.69      0.73      0.71       585

    accuracy                           0.77      1508
   macro avg       0.75      0.76      0.76      1508
weighted avg       0.77      0.77      0.77      1508

Epoch 3
Validation AUC_ROC:  0.8140114206875407
Test AUC_ROC:  0.817923715865211
              precision    recall  f1-score   support

           0       0

Validation AUC_ROC:  0.8550256178420735
Test AUC_ROC:  0.863828218458691
              precision    recall  f1-score   support

           0       0.76      0.95      0.84       946
           1       0.84      0.49      0.62       562

    accuracy                           0.78      1508
   macro avg       0.80      0.72      0.73      1508
weighted avg       0.79      0.78      0.76      1508

Epoch 2
Validation AUC_ROC:  0.8493595539481615
Test AUC_ROC:  0.852138240804135
              precision    recall  f1-score   support

           0       0.78      0.92      0.84       946
           1       0.81      0.56      0.66       562

    accuracy                           0.79      1508
   macro avg       0.80      0.74      0.75      1508
weighted avg       0.79      0.79      0.78      1508

Epoch 3
Validation AUC_ROC:  0.827448764315853
Test AUC_ROC:  0.8322116722969161
              precision    recall  f1-score   support

           0       0.78      0.90      0.84       946
  

Validation AUC_ROC:  0.8591376609222806
Test AUC_ROC:  0.8546267573830448
              precision    recall  f1-score   support

           0       0.84      0.76      0.80       922
           1       0.67      0.77      0.72       586

    accuracy                           0.77      1508
   macro avg       0.76      0.77      0.76      1508
weighted avg       0.78      0.77      0.77      1508

Epoch 2
Validation AUC_ROC:  0.8548691051926526
Test AUC_ROC:  0.842166828307656
              precision    recall  f1-score   support

           0       0.80      0.81      0.81       922
           1       0.70      0.69      0.69       586

    accuracy                           0.76      1508
   macro avg       0.75      0.75      0.75      1508
weighted avg       0.76      0.76      0.76      1508

Epoch 3
Validation AUC_ROC:  0.833889607883205
Test AUC_ROC:  0.8229050217289908
              precision    recall  f1-score   support

           0       0.80      0.79      0.80       922
 

### Interpretation

In [20]:
# NOTE(review): absolute, machine-specific paths; adjust before running elsewhere.
MIMIC_PATH = 'E:/CS_Master_Degree_UIUC/CS598_DeepLearning_for_Health_Data/Project/paper290/MIMIC data/'
out_path = 'E:/CS_Master_Degree_UIUC/CS598_DeepLearning_for_Health_Data/Project/paper290/Output/'


def _read_code_names(csv_path):
    """Read a dictionary CSV and map its 2nd column to its 3rd column.

    The first row is treated as a header and skipped. Parsing goes through the
    ``csv`` module, so quoted fields may contain commas and surrounding quotes
    are removed from both keys and values. Rows with fewer than three columns
    (e.g. blank lines) are ignored.
    """
    import csv  # local import keeps this cell self-contained

    names = {}
    with open(csv_path, 'r', newline='') as handle:
        reader = csv.reader(handle)
        next(reader, None)  # header row
        for row in reader:
            if len(row) < 3:
                continue
            names[row[1]] = row[2]
    return names


# Code-string -> integer-index dictionaries produced by the preprocessing step.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open(DATA_PATH + 'MIMICIIIPROCESSED.types', 'rb') as handle:
    icditems = pickle.load(handle)
with open(DATA_PATH + 'MIMICIIIPROCESSED.meds.types', 'rb') as handle:
    meditems = pickle.load(handle)
with open(DATA_PATH + 'MIMICIIIPROCESSED.abnlabs.types', 'rb') as handle:
    labitems = pickle.load(handle)

model_name = 'CAE'

# NOTE(review): this handle is never written to or closed in this cell, and
# mode 'w' truncates an existing file of that name. Confirm it is still wanted.
interpretation_file = open(out_path + "RNN_CLatent_Interpretations_" + model_name + ".txt", 'w')

# Human-readable names keyed by the id found in the 2nd CSV column.
labnames = _read_code_names(MIMIC_PATH + 'D_LABITEMS.csv')
icdnames = _read_code_names(MIMIC_PATH + 'D_ICD_DIAGNOSES.csv')

# Running sum of risk scores per factor name, filled by get_factors().
icd_scores = {}
med_scores = {}
lab_scores = {}

# Number of times each factor name was scored, filled by get_factors().
icd_totals = {}
med_totals = {}
lab_totals = {}

def get_ICD(icd):
    '''
    Given icd integer index, return the string name of that icd code:
    e.g. get_ICD(1) returns "Hypertension NOS"

    The index is resolved by value in ``icditems`` (code string -> index), so
    the result does not depend on the dictionary's insertion order. If the
    code has no entry in ``icdnames`` the raw ``icditems`` key is returned.
    Raises IndexError when no code carries the given index.
    '''
    for icd_str, index in icditems.items():
        if index == icd:
            break
    else:
        raise IndexError("no ICD code with index %r" % (icd,))
    # Drop the dots and the first two characters of the key to obtain the
    # lookup key for icdnames (presumably a type prefix such as "D_" - confirm).
    actual_key = icd_str.replace(".", "")[2:]
    return icdnames.get(actual_key, icd_str)

def get_med(med):
    '''
    Given med integer index, return the string name of that med code:
    e.g. get_med(1) returns "Phenylephrine HCl"

    The index is resolved by value in ``meditems`` (name -> index), so the
    result does not depend on the dictionary's insertion order.
    Raises IndexError when no medication carries the given index.
    '''
    for med_name, index in meditems.items():
        if index == med:
            return med_name
    raise IndexError("no medication with index %r" % (med,))

def get_lab(lab):
    '''
    Given lab integer index, return the string name of that lab code:
    e.g. get_lab(1) returns "Hemoglobin"

    The index is resolved by value in ``labitems`` (lab id -> index), so the
    result does not depend on the dictionary's insertion order; the id is then
    translated through ``labnames``.
    Raises IndexError when no lab item carries the given index and KeyError
    when the lab id is missing from ``labnames``.
    '''
    for lab_id, index in labitems.items():
        if index == lab:
            return labnames[lab_id]
    raise IndexError("no lab item with index %r" % (lab,))

In [21]:
def get_factors(icd_seq, med_seq, lab_seq, model, actual_score, full_icd):
    """Rank single codes by how much removing them changes the prediction.

    For every ICD, medication and lab code of the patient, the prediction is
    recomputed with just that one code removed. The factor score is
    ``actual_score - prediction_without_code``.

    Side effects: every factor score is added to the notebook-level
    ``icd_scores`` / ``med_scores`` / ``lab_scores`` dictionaries and the
    matching ``*_totals`` counter is incremented.

    Args:
        icd_seq, med_seq, lab_seq: lists of visits, each a list of code indices.
        model: object exposing ``predict(icd, med, lab)``.
        actual_score: prediction for the unmodified patient.
        full_icd: per-visit ICD indices used only to look up the display name.

    Returns:
        Up to ten ``(description, score)`` tuples, highest score first. Scores
        are plain Python floats.
    """
    def _without(seqs, visit, pos):
        """Return a copy of ``seqs`` lacking the code at ``seqs[visit][pos]``."""
        return seqs[:visit] + [seqs[visit][:pos] + seqs[visit][pos+1:]] + seqs[visit+1:]

    # Each entry: (kind, code for naming, visit index, icd seqs, med seqs, lab seqs)
    potential_test_data = []

    for seq in range(len(icd_seq)):
        for i in range(len(icd_seq[seq])):
            # NOTE(review): the display code is read from full_icd at the same
            # position as the removed icd_seq code; this assumes both lists are
            # aligned element by element - confirm against the preprocessing.
            potential_test_data.append(("icd", full_icd[seq][i], seq, _without(icd_seq, seq, i), med_seq, lab_seq))
    for seq in range(len(med_seq)):
        for i in range(len(med_seq[seq])):
            potential_test_data.append(("med", med_seq[seq][i], seq, icd_seq, _without(med_seq, seq, i), lab_seq))
    for seq in range(len(lab_seq)):
        for i in range(len(lab_seq[seq])):
            potential_test_data.append(("lab", lab_seq[seq][i], seq, icd_seq, med_seq, _without(lab_seq, seq, i)))

    # kind -> (name lookup, report prefix, running score sums, occurrence counts)
    trackers = {
        "icd": (get_ICD, "ICD-", icd_scores, icd_totals),
        "med": (get_med, "Med-", med_scores, med_totals),
        "lab": (get_lab, "Lab-", lab_scores, lab_totals),
    }

    risk_scores = []

    for kind, code, visit, icd_variant, med_variant, lab_variant in potential_test_data:
        test_input_icd = torch.from_numpy(convert_to_one_hot(icd_variant, vocabsize_icd)).float()
        test_input_med = torch.from_numpy(convert_to_one_hot(med_variant, vocabsize_meds)).float()
        test_input_lab = torch.from_numpy(convert_to_one_hot(lab_variant, vocabsize_labs)).float()
        # float() keeps tensors out of the tallies and out of the written report.
        factor_score = float(actual_score - model.predict(test_input_icd, test_input_med, test_input_lab))

        lookup, prefix, scores, totals = trackers[kind]
        tag = lookup(code)
        scores[tag] = scores.get(tag, 0.0) + factor_score
        totals[tag] = totals.get(tag, 0) + 1

        risk_scores.append(("Encounter-"+str(visit)+": "+prefix+tag, factor_score))

    risk_scores.sort(key=lambda tup: tup[1], reverse=True)

    return risk_scores[:10]


In [None]:
# print "Final testing and interpretations"

# Score every test patient with the last trained model and write the top risk
# factors of the patients predicted positive (score > 0.5) to a text report.
predictions = np.zeros(len(test_input_seqs_icd))
# The with-block closes the report even if scoring raises part-way through.
with open(out_path + "RNN_Concat_Interpretations.txt", 'w') as interpretation_file:
    for i in range(len(test_input_seqs_icd)):
        test_input_icd = torch.from_numpy(convert_to_one_hot(test_input_seqs_icd[i], vocabsize_icd)).float()
        test_input_med = torch.from_numpy(convert_to_one_hot(test_input_seqs_meds[i], vocabsize_meds)).float()
        test_input_lab = torch.from_numpy(convert_to_one_hot(test_input_seqs_labs[i], vocabsize_labs)).float()

        test_score = model.predict(test_input_icd, test_input_med, test_input_lab)
        predictions[i] = test_score
        # Called for every patient, not only the positives, because
        # get_factors() also updates the notebook-level score tallies.
        top_risk_factors = get_factors(test_input_seqs_icd[i], test_input_seqs_meds[i], test_input_seqs_labs[i], model, test_score, test_input_seqs_fullicd[i])
        if test_score > 0.5:
            # int() writes the label as a plain number instead of a tensor repr.
            interpretation_file.write("ID: " + str(i) + " True label: "+str(int(test_labels[i]))+"\n")
            for rf in top_risk_factors:
                interpretation_file.write(str(rf)+"\n")
            interpretation_file.write("\n")