In [103]:
import sys
sys.path.append('../../')
sys.path.append('../')

from Utils import Vocab, multi_label_metric
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math 

In [104]:
GLOBAL_DATA_PATH = "../../Data"
GLOBAL_MODELS_PATH = "../../Models"
GAMENET_DATA_PATH = "./Data"
MULTI_VISIT_TEMPORAL_PKL = "multi_visit_temporal.pkl"
EHR_ADJ_PKL = 'ehr_adj_matrix.pkl'

ATC4 = "ATC4"
ICD9_CODE = "ICD9_CODE"

# <span style="text-decoration: underline">GAMENet</span>

The following GAMENet implementation references code from Homework 5. Adjustments were made in consideration to the GAMENet baseline implemented in the original G-BERT paper. 

This is a utility function to build a 'Vocabulary' from medical codes: ICD-9 (diagnoses) codes and ATC4 (medication) codes. The function takes in data in the following format:

data = [ patient1, patient2, patient3, ... ] \
patient = [ diagnoses_codes, medication_codes ] \
diagnoses_codes = [ diagnoses_visit1, diagnoses_visit2, diagnoses_visit3 ... ] \
medication_codes = [ medication_visit1, medication_visit2, medication_visit3 ... ] \
diagnoses_visit = [ code, code, code, ... ] \
medication_visit = [ code, code, code, ... ]

diagnoses_codes = data[0][0][0] \
medication_codes = data[0][1][0]

In [105]:
def build_vocab_from_patient_data(data):
    vocab = Vocab()
    for patient in data:
        for visit in patient:
            vocab.add_sentence(visit)
    
    return vocab

This utility function creates a list of patient records using MIMIC-III patient data, a diagnosis Vocab, and a medication Vocab as input. The data is formatted as explained above. The ICD-9 and ATC codes for each patient record are converted into indices defined in the Vocabs. 

In [106]:
def create_idx_records(multi_visit_data, diag_vocab, med_vocab):
    records = []
    for index, row in multi_visit_data.iterrows():
        patient = []
        admissions = []
        
        for visit in row[ICD9_CODE]:
            diag_codes = []
            for code in visit:
                diag_codes.append(diag_vocab.word2idx[code])
        
            admissions.append(diag_codes)
        patient.append(admissions)

        admissions = []
        for visit in row[ATC4]:
            med_codes = []
            for code in visit:
                med_codes.append(med_vocab.word2idx[code])
            
            admissions.append(med_codes)
        patient.append(admissions)

        records.append(patient)
        
    return records

This code builds the diagnoses Vocab, the medication Vocab, and the patient records as codes converted to indices. The lengths of each Vocab is also defined here. They will be used as parameters in the neural networks. 

In [107]:
multi_visit_data = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, MULTI_VISIT_TEMPORAL_PKL))

diag_vocab = build_vocab_from_patient_data(multi_visit_data[ICD9_CODE])
med_vocab = build_vocab_from_patient_data(multi_visit_data[ATC4])

diag_vocab_size = len(diag_vocab)
med_vocab_size = len(med_vocab)

records = create_idx_records(multi_visit_data, diag_vocab, med_vocab)


In [108]:
def create_adjacency_matrix(N, records):
    
    adj_matrix = torch.zeros(N, N)

    for patient in records:
        med_set = patient[1]
        for visit in med_set:
            for i, code_i in enumerate(visit):
                for j, code_j in enumerate(visit):
                    if j <= i:
                        continue

                    adj_matrix[code_i, code_j] = 1
                    adj_matrix[code_j, code_i] = 1
                    
    return adj_matrix


In [109]:
ehr_adj_path = os.path.join(GAMENET_DATA_PATH, EHR_ADJ_PKL)

This code below creates an adjacency matrix out of the medication codes and outputs them as a PKL file. Running create_adjacency_matrix method takes ~2 minutes. Only uncomment the code to generate a new PKL file. 

In [110]:
# import dill

# N = len(med_vocab.word2idx)

# adj_matrix = create_adjacency_matrix(N, records)
# dill.dump(adj_matrix, open(ehr_adj_path, 'wb'))


In [111]:
ehr_adj = pd.read_pickle(ehr_adj_path)
print("The shape of EHR adjacency matrix: ", ehr_adj.shape)

The shape of EHR adjacency matrix:  torch.Size([385, 385])


In [112]:
class RNN(nn.Module):
    def __init__(self, vocab_size, emb_dim=16):
        super(RNN, self).__init__()

        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=emb_dim, batch_first=True)
        self.init_weights()
    
    def forward(self, codes):        
        emb_list = []
        for code in codes:
            emb = None
            emb = self.embeddings(torch.tensor(code))
            
            # take the mean and make one sample per batch
            emb_mean = emb.mean(dim=0).unsqueeze(dim=0)
            emb_list.append(emb_mean)
        
        # Check emb_list, inform Dinesh
        emb_seq = torch.cat(emb_list, dim=0).unsqueeze(dim=0)
        #print("Embedding Sequence: ")
        #print(emb_seq)
        result, _ = self.rnn(emb_seq)
        
        return result
    
    def init_weights(self):
        torch.nn.init.normal_(self.embeddings.weight)
        for param in self.rnn.parameters():
            if len(param.shape) >= 2:
                torch.nn.init.orthogonal_(param.data)
            else:
                torch.nn.init.normal_(param.data)    

In [113]:
class PatientQuery(nn.Module):
    def __init__(self, diag_vocab, emb_dim=16):
        super(PatientQuery, self).__init__()

        self.diag_rnn = RNN(diag_vocab, emb_dim)
        self.linear = nn.Linear(in_features=emb_dim, out_features=emb_dim, bias=True)
        
    def forward(self, codes_diag,):
        """
        Input:
            codes_diag: [[diag codes for visit1], [diag codes for visit2], ...]
                - e.g., [[0,1], [1,2]]  
        output:
            query embedding:
                - size: (#visits, emd_dim)
        """
        # get diag and prod embedding by self.rnn_diag, self.rnn_prod
        diag_emb = self.diag_rnn(codes_diag)
        
        # concat emb_diag and emb_prod, then tranfrom by self.linear
        result = self.linear(diag_emb)
        result = result.squeeze(0)

        return result

In [114]:
class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.mm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

class GCN(nn.Module):
    def __init__(self, vocab_med_size, adj, emb_dim=16):
        super(GCN, self).__init__()
        adj = self.normalize(adj + np.eye(adj.shape[0]).astype(np.float32))
        self.adj = torch.FloatTensor(adj.float())
        
        # self.gcn1 = GCNConv(in_channels=vocab_med_size, out_channels=emb_dim)
        # self.gcn2 = GCNConv(in_channels=emb_dim, out_channels=emb_dim)
        
        self.gcn1 = GraphConvolution(in_features=vocab_med_size, out_features=emb_dim)
        self.gcn2 = GraphConvolution(in_features=emb_dim, out_features=emb_dim)
        
        # the initial feature
        self.x = torch.eye(vocab_med_size)
    
    def forward(self):
        """
        TODO: 
            1 (already done). use self.gcn1 for the first graph convolution
                - remember to use self.adj
            2. use F.relu() as the activation function
            3. use self.gcn2
        """
        x = self.gcn1(self.x, self.adj)
        x = F.relu(x)
        x = self.gcn2(x, self.adj)
        
        return x
    
    def normalize(self, adj):
        adj = adj / (adj @ np.ones(adj.shape) + 1e-8)
        return adj

In [115]:
class MemoryBank(nn.Module):
    def __init__(self, vocab_med_size, adj_ehr, emb_dim=16):
        super(MemoryBank, self).__init__()
        """
        FUNCTIONALITY: Combine information from EHR graph and DDI graph
        """
        
        # Use GCN for EHR graph using adj_ehr and emb_dim as output simension
        self.gcn_ehr = GCN(vocab_med_size=vocab_med_size, adj=adj_ehr, emb_dim=emb_dim)
        
        # TO-DO: Check if this is needed. 
        # Learnable weight between adj_ddi and adj_ehr 
        self.weight = nn.Parameter(torch.FloatTensor(1))
        self.weight.data.uniform_(-0.1, 0.1)
        
    def forward(self):
        # get ehr graph feature
        info_ehr = self.gcn_ehr.forward()
        
        # get weighted information
        info_comb = info_ehr * self.weight
        
        return info_comb

In [116]:
class DynamicMemory(nn.Module):
    def __init__(self, med_vocab_size):
        super(DynamicMemory, self).__init__()
        """
        FUNCTIONALITY: generate a historical mapping: query embedding -> multi-hot medication vector
        """
        self.med_vocab_size = med_vocab_size
    
    def forward(self, queries, codes_med):
        """
        Input:  queries
                    - this is the historical query embedding, given by PatientQuery Module
                    - size: (#visits - 1, emb_dim), delete the current query
                codes_med
                    - this is the historical groud truth med vector
                    - format: a list of length (#visits - 1)
        """
        
        DM_key = queries

        DM_value = np.zeros((queries.shape[0], self.med_vocab_size))
        
        # transform codes_med to multi-hot and filling DM_value row by row (visit by visit)
        for visit_i in range(len(codes_med)):
            for code_j in codes_med[visit_i]:                
                DM_value[visit_i][code_j] = 1
        
        # use torch.FloatTensor for DM_value
        DM_value = torch.FloatTensor(DM_value)
        
        return DM_key, DM_value

In [117]:
class Fact1(nn.Module):
    def __init__(self, queries):
        super(Fact1, self).__init__()
        """
        FUNCTIONALITY: extract the final embedding from input queries, q^t
        """
        self.queries = queries
    
    def forward(self):
        # assign last embedding to result
        result = self.queries[-1]

        # final size: (1, emb_dim)
        result = result.unsqueeze(0)
        return result
    

class Fact2(nn.Module):
    def __init__(self, query, MB):
        super(Fact2, self).__init__()
        """
        FUNCTIONALITY: get attention information from MB, o^b_t
        Input:
            query
                - this is the last embedding
            MB
                - is the memory bank
        """
        self.query = query
        self.MB = MB
    
    def forward(self):
        # get the attention weight between query and each row of MB, adding to attn_score
        # need to transpose MB, using MB.t()
        attn_score = torch.mm(self.query, self.MB.t())
        # use F.softmax(, dim=1) to compute the attention matrix, attn_matrix
        attn_matrix = F.softmax(attn_score, dim=1)
        # get the final result from attn_matrix and MB
        result = torch.mm(attn_matrix, self.MB)
        
        return result
    
class Fact3(nn.Module):
    def __init__(self, query, MB, DM_key, DM_value):
        super(Fact3, self).__init__()
        """
        FUNCTIONALITY: similar to Fact2, get information from the DM, o_d^t
        """
        self.query = query
        self.MB = MB
        self.DM_key = DM_key
        self.DM_value = DM_value
    
    def forward(self):
        attn = F.softmax(torch.mm(self.query, self.DM_key.t()), dim=1)
        value = torch.mm(attn, self.DM_value)
        out = torch.mm(value, self.MB)
        return out

In [118]:
class OutNet(nn.Module):
    def __init__(self, vocab_med, emb_dim=16):
        super(OutNet, self).__init__()
        """
        FUNCTIONALITY: combine fact1, fact2, fact3 to do final prediction
        """
        
        """
        TODO:
            1. build the first linear layer
                - input size: 3 * emb_dim
                - end size: 2 * emb_dim
                - use bias=True option
            2. build the second linear layer
                - input size: 2 * emb_dim
                - end size: vocab_med
                - use bias=True option
        """
        self.fc1 = nn.Linear(in_features=3 * emb_dim, out_features=2 * emb_dim, bias=True)
        self.fc2 = nn.Linear(in_features=2 * emb_dim, out_features=vocab_med, bias=True)
        
    def forward(self, fact1, fact2, fact3):
        """
        Input:
            fact1, fact2, fact3:
                - three facts q^t, o^t_b, o^t_d
                - each size: (1, emb_dim)
        """
        # concat fact1, fact2, fact3, and assign to memory_out
        # size: (1, 3 * emb_dim)
        memory_out = torch.cat((fact1, fact2, fact3), dim=1)
        
        result = self.fc1(memory_out)
        result = F.relu(result)
        result = self.fc2(result)
        
        return result

In [119]:
class GAMENet(nn.Module):
    def __init__(self, diag_vocab_size, med_vocab_size, ehr_adj, emb_dim=128):
        super(GAMENet, self).__init__()
        """
        FUNCTIONALITY: integrate the whole model together
        """
        self.patient = PatientQuery(diag_vocab_size, emb_dim)
        self.memorybank = MemoryBank(med_vocab_size, ehr_adj, emb_dim)
        self.dynamicmemory = DynamicMemory(med_vocab_size)
        self.outnet = OutNet(med_vocab_size, emb_dim)
        
    def forward(self, codes_diag, codes_med):
        """
        input:
            codes_diag
                - a list of length #visits
                - each element is also itself a list
            codes_med
                - a list of length #visits - 1
                - each element is also itself a list
        """
        
        # get patient query embedding (#visit, emb_dim*2)
        # get patient query embedding (#visit, emb_dim) ??
        queries = self.patient(codes_diag)
        """
        TODO:
            1. build a memory bank, assign to MB
                - use self.memorybank
            2. build a dynamic memory, assign to DM_key and DM_value
                - use self.dynamicmemory
                - use queries and codes_med as features
        """
        MB = None
        DM_key, DM_value = None, None
           
        MB = self.memorybank()
        DM_key, DM_value = self.dynamicmemory(queries, codes_med)

        # extract three memory outputs, assign to fact1, fact2, fact3
        fact1 = Fact1(queries)()
        fact2 = Fact2(fact1, MB)()
        fact3 = Fact3(fact1, MB, DM_key, DM_value)()

        # get the final output
        result = self.outnet(fact1, fact2, fact3)
        
        return result

In [120]:
model = GAMENet(diag_vocab_size, med_vocab_size, ehr_adj)
output = model([[0,1], [1,2], [4,5]], [[0,1], [2,5]])

assert output.shape == (1, med_vocab_size)

In [121]:
num_param = sum(p.numel() for p in model.parameters())
print("Number of Model Parameters: ", num_param)

Number of Model Parameters:  961538


In [122]:
import random

seed = 1234
random.seed(seed)

tmp = records.copy()
random.shuffle(tmp)

# 67%: 16.5% 16.5% split training and test sets randomly
records_size = len(records)

train_end = int(records_size * 0.67)
validation_end = int(records_size * 0.165) + train_end

train_set = tmp[:train_end]
validation_set = tmp[train_end:validation_end]
test_set = tmp[validation_end:]

In [127]:
EPOCHS = 75
LEARNING_RATE = 1e-3

In [128]:
from sklearn.metrics import accuracy_score, jaccard_score, average_precision_score, f1_score
from datetime import datetime

date_string = datetime.now().strftime("_%y%m%d_%H%M%S")

model = GAMENet(diag_vocab_size, med_vocab_size, ehr_adj)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

def dataFormatter(patient_list):

    diag_list, med_list = [], []
    for diag in patient_list[0]:
        diag_list.append(diag)

    for med in patient_list[1]:
        med_list.append(med)
    
    target = np.zeros((1, med_vocab_size))
    target[0, med_list[-1]] = 1
    return diag_list, med_list[:-1], torch.FloatTensor(target)

def test(model, test_set):
    model.eval()
    pred_list, pred_prob_list, target_list = [], [], []
    for patient in test_set:
        for idx, visit in enumerate(patient[1]):
            codes_diag, codes_med, target = dataFormatter(patient)
            pred = model(codes_diag, codes_med).detach().cpu().numpy()[0]
            pred_prob_list += pred.tolist()
            pred[pred >= 0.5] = 1; pred[pred < 0.5] = 0
            pred_list += pred.tolist(); target_list += target.numpy().tolist()[0]
    return pred_list, target_list, pred_prob_list

def train(train_set):
    for i in range(EPOCHS):
        model.train()

        jaccard_list = []
        pr_auc_list = []
        f1_list = []
        accuracy_list = []

        for patient in train_set:
            loss = 0

            for idx, visit in enumerate(patient[1]):
                codes_diag, codes_med, target = dataFormatter(patient)
                # get the target of multilabel_margin_loss
                multi_target = np.full((1, med_vocab_size), -1)
                
                for idx, item in enumerate(visit):
                    multi_target[0][idx] = item
                multi_target = torch.LongTensor(multi_target)

                pred = model(codes_diag, codes_med)
                
                loss += F.binary_cross_entropy_with_logits(pred, target) + \
                                    F.multilabel_margin_loss(torch.sigmoid(pred), multi_target)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        # use the first 100 test patients to compute intermediate accuracy
        pred_list, target_list, pred_prob_list = test(model, test_set[:100])

        # _, current_pr_auc, _ = multi_label_metric(target_list, pred_list, pred_prob_list)
        current_jaccard = jaccard_score(target_list, pred_list)
        current_pr_auc = average_precision_score(target_list, pred_prob_list)
        current_f1 = f1_score(target_list, pred_list)
        current_acc = accuracy_score(pred_list, target_list)

        jaccard_list.append(current_jaccard)
        pr_auc_list.append(current_pr_auc)
        f1_list.append(current_f1)
        accuracy_list.append(current_acc)
        
        print("Current Epoch: ", i)
        print("===================")
        print("Current Avg. Jaccard Score: ", sum(jaccard_list) / len(jaccard_list))
        print("Current Avg. PR-AUC Score: ", sum(pr_auc_list) / len(pr_auc_list))
        print("Current Avg. F1 Score: ", sum(f1_list) / len(f1_list))
        print("Current Avg. Accuracy Score: ", sum(accuracy_list) / len(accuracy_list))
        print("..................................................................")

    torch.save({
            "model_state_dict": model.state_dict()
            }, os.path.join(GLOBAL_MODELS_PATH,"gamenet" + date_string))

    return model


print("Training..........................................................")
eval_model = train(train_set)
checkpoint = torch.load(os.path.join(GLOBAL_MODELS_PATH, "gamenet" + date_string))
#eval_model = model.load_state_dict(checkpoint["model_state_dict"])
        
print("Evaluating........................................................")
pred_list, target_list, pred_prob_list = test(eval_model, test_set)
evaluation_jaccard_score = "Evaluation Jaccard Score: {0} ".format(jaccard_score(target_list, pred_list))
evaluation_pr_auc_score = "Evaluation PR-AUC Score: {0}".format(average_precision_score(target_list, pred_prob_list, average='macro'))
evaluation_f1_score = "Evaluation F1 Score: {0}".format(f1_score(target_list, pred_list))
evaluation_accuracy_score = "Evaluation Accuracy Score: {0}".format(accuracy_score(target_list, pred_list))

print(evaluation_jaccard_score)
print(evaluation_pr_auc_score)
print(evaluation_f1_score)
print(evaluation_accuracy_score)

results_file = os.path.join("Results", "results" + date_string + ".txt")
file_obj = open(results_file, 'w+')
file_obj.write("Number of Epochs: " + str(EPOCHS) + "\n")
file_obj.write("Learning Rate: " + str(LEARNING_RATE) + "\n")
file_obj.write("Number of Parameters: " + str(num_param) + "\n")
file_obj.write(evaluation_jaccard_score + "\n")
file_obj.write(evaluation_pr_auc_score + "\n")
file_obj.write(evaluation_f1_score + "\n")
file_obj.write(evaluation_accuracy_score + "\n")

file_obj.close()

Training..........................................................
Current Epoch:  0
Current Avg. Jaccard Score:  0.20279585751386706
Current Avg. PR-AUC Score:  0.34359816424183565
Current Avg. F1 Score:  0.3372074425547795
Current Avg. Accuracy Score:  0.8072184531886024
..................................................................
Current Epoch:  1
Current Avg. Jaccard Score:  0.22115518927224473
Current Avg. PR-AUC Score:  0.32235875771391054
Current Avg. F1 Score:  0.36220652577997653
Current Avg. Accuracy Score:  0.8270750145377012
..................................................................
Current Epoch:  2
Current Avg. Jaccard Score:  0.23346205204085316
Current Avg. PR-AUC Score:  0.3624451185256669
Current Avg. F1 Score:  0.3785476037216923
Current Avg. Accuracy Score:  0.835316921884086
..................................................................
Current Epoch:  3
Current Avg. Jaccard Score:  0.23671065384328635
Current Avg. PR-AUC Score:  0.388392845254122

Results of GAMENet Training
===============================
*** Stats *** \
Number of Epochs: 75 \
Learning Rate: 0.001 \
Number of Parameters: 961538 

Average Jaccard Score: 0.2633452915764412 \
Average PR-AUC Score: 0.32996452114879504 \
Average F1 Score: 0.4169015285565054 \
Average Accuracy Score: 0.8665114667369459