In [30]:
import sys
sys.path.append('../../')
sys.path.append('../')

from Globals import UNIQUE_ATC_CSV, UNIQUE_ICD_CSV
from Utils import Vocab
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

from torch_geometric.nn import Sequential, GCNConv
 

In [60]:
GLOBAL_DATA_PATH = "../../Data"
ATC4 = "ATC4"
ICD9_CODE = "ICD9_CODE"
MULTI_VISIT_TEMPORAL_PKL = "multi_visit_temporal.pkl"
GAMENET_DATA_PATH = "./Data"
EHR_ADJ_PKL = 'ehr_adj_matrix.pkl'

In [63]:
def build_vocab_from_patient_data(data):
    vocab = Vocab()
    for patient in data:
        for visit in patient:
            vocab.add_sentence(visit)
    
    return vocab


In [62]:
# Build Vocab of ATC (medication) and ICD-9 (diagnoses) codes for each patient

multi_visit_records = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, MULTI_VISIT_TEMPORAL_PKL))

diag_vocab = build_vocab_from_patient_data(multi_visit_records[ICD9_CODE])
med_vocab = build_vocab_from_patient_data(multi_visit_records[ATC4])

diag_vocab_size = len(diag_vocab)
med_vocab_size = len(med_vocab)

records = []
for index, row in multi_visit_records.iterrows():
    patient = []
    admissions = []
    
    for visit in row[ICD9_CODE]:
        diag_codes = []
        for code in visit:
            diag_codes.append(diag_vocab.word2idx[code])
    
        admissions.append(diag_codes)
    patient.append(admissions)

    admissions = []
    for visit in row[ATC4]:
        med_codes = []
        for code in visit:
            med_codes.append(med_vocab.word2idx[code])
        
        admissions.append(med_codes)
    patient.append(admissions)

    records.append(patient)

# Format of records:
# Patient: [[diag_code, diag_code, diag_code], [med_code, med_code, med_code]]
# for row in multi_visit_records[ICD9_CODE]:
#     admission = []
#     patient = []

#     for visit in row:
#         admission.append([diag_vocab.word2idx[i] for i in visit])
    # admission.append([med_vocab.word2idx[i] for i in row[ATC4]])
    # patient.append(admission)
    # print(patient)
    # records.append(patient)


In [34]:
# Method to create adjacency matrix from medication codes
def create_adjacency_matrix(N, records):
    
    adj_matrix = torch.zeros(N, N)

    for patient in records:
        med_set = patient[1]
        #print(med_set)
        for visit in med_set:
            for i, code_i in enumerate(visit):
                for j, code_j in enumerate(visit):
                    if j <= i:
                        continue

                    adj_matrix[code_i, code_j] = 1
                    adj_matrix[code_j, code_i] = 1
                    
    return adj_matrix


In [35]:
ehr_adj_path = os.path.join(GAMENET_DATA_PATH, EHR_ADJ_PKL)

Uncomment the code below to create adjacency matrix and output as PKL file. 

In [36]:
# import dill

# N = len(med_vocab.word2idx)

# adj_matrix = create_adjacency_matrix(N, records)
# dill.dump(adj_matrix, open(ehr_adj_path, 'wb'))


In [37]:
ehr_adj = pd.read_pickle(ehr_adj_path)
# print(ehr_adj)
print(ehr_adj.shape)

tensor([[0., 1., 1.,  ..., 1., 0., 1.],
        [1., 0., 1.,  ..., 1., 0., 0.],
        [1., 1., 0.,  ..., 0., 1., 1.],
        ...,
        [1., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [1., 0., 1.,  ..., 0., 0., 0.]])


In [38]:
class RNN(nn.Module):
    def __init__(self, vocab_size, emb_dim=16):
        super(RNN, self).__init__()

        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=emb_dim, batch_first=True)
        self.init_weights()
    
    def forward(self, codes):        
        emb_list = []
        for code in codes:
            emb = None
            emb = self.embeddings(torch.tensor(code))
            
            # take the mean and make one sample per batch
            emb_mean = emb.mean(dim=0).unsqueeze(dim=0)
            emb_list.append(emb_mean)
            
        emb_seq = torch.cat(emb_list, dim=0).unsqueeze(dim=0)
        
        result, _ = self.rnn(emb_seq)
        
        return result
    
    def init_weights(self):
        torch.nn.init.normal_(self.embeddings.weight)
        for param in self.rnn.parameters():
            if len(param.shape) >= 2:
                torch.nn.init.orthogonal_(param.data)
            else:
                torch.nn.init.normal_(param.data)
    

In [39]:
class PatientQuery(nn.Module):
    def __init__(self, diag_vocab, emb_dim=16):
        super(PatientQuery, self).__init__()

        self.diag_rnn = RNN(diag_vocab, emb_dim)
        # self.linear = nn.Linear(in_features=emb_dim, out_features=emb_dim, bias=True)
        self.linear = nn.Linear(in_features=emb_dim, out_features=emb_dim, bias=True)
        
    def forward(self, codes_diag,):
        """
        Input:
            codes_diag: [[diag codes for visit1], [diag codes for visit2], ...]
                - e.g., [[0,1], [1,2]]  
            codes_prod: [[prod codes for visit1], [prod codes for visit2], ...]
                - e.g., [[0,1], [1,2]]  
        output:
            query embedding:
                - size: (#visits, emd_dim)
        """
        
        """
        TODO: get diag and prod embedding by self.rnn_diag, self.rnn_prod
        """
        diag_emb = self.diag_rnn(codes_diag)
        
        """
        TODO: concat emb_diag and emb_prod, then tranfrom by self.linear
        """
        result = self.linear(diag_emb)
        result = result.squeeze(0)

        return result

In [40]:
class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.mm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

class GCN(nn.Module):
    def __init__(self, vocab_med_size, adj, emb_dim=16):
        super(GCN, self).__init__()
        adj = self.normalize(adj + np.eye(adj.shape[0]).astype(np.float32))
        self.adj = torch.FloatTensor(adj.float())
        
        # self.gcn1 = GCNConv(in_channels=vocab_med_size, out_channels=emb_dim)
        # self.gcn2 = GCNConv(in_channels=emb_dim, out_channels=emb_dim)
        
        self.gcn1 = GraphConvolution(in_features=vocab_med_size, out_features=emb_dim)
        self.gcn2 = GraphConvolution(in_features=emb_dim, out_features=emb_dim)
        
        # the initial feature
        self.x = torch.eye(vocab_med_size)
    
    def forward(self):
        """
        TODO: 
            1 (already done). use self.gcn1 for the first graph convolution
                - remember to use self.adj
            2. use F.relu() as the activation function
            3. use self.gcn2
        """
        x = self.gcn1(self.x, self.adj)
        x = F.relu(x)
        x = self.gcn2(x, self.adj)
        
        return x
    
    def normalize(self, adj):
        adj = adj / (adj @ np.ones(adj.shape) + 1e-8)
        return adj

In [41]:
class MemoryBank(nn.Module):
    def __init__(self, vocab_med_size, adj_ehr, emb_dim=16):
        super(MemoryBank, self).__init__()
        """
        FUNCTIONALITY: conbime information from EHR graph and DDI graph
        """
        
        """
        TODO:
            1. use GCN for EHR graph
                - using adj_ehr
                - using emb_dim as the ouput dimension
            2. use GCN for DDI graph
                - using adj_ddi
                - using emb_dim as the ouput dimension
            3 (already done). design a learnable weight between them
        """
        
        self.gcn_ehr = GCN(vocab_med_size=vocab_med_size, adj=adj_ehr, emb_dim=emb_dim)
        
        self.weight = nn.Parameter(torch.FloatTensor(1))
        self.weight.data.uniform_(-0.1, 0.1)
        
    def forward(self):
        """
        TODO: 
            1. get ehr graph feature
                - using .forward() function
            2. get ddi graph feature
                - using .forward() function
            3 (already done). get weighted information
        """
        
        info_ehr = self.gcn_ehr.forward()
        
        info_comb = info_ehr * self.weight
        
        return info_comb

In [42]:
class DynamicMemory(nn.Module):
    def __init__(self, med_vocab_size):
        super(DynamicMemory, self).__init__()
        """
        FUNCTIONALITY: generate a historical mapping: query embedding -> multi-hot medication vector
        """
        self.med_vocab_size = med_vocab_size
    
    def forward(self, queries, codes_med):
        """
        Input:  queries
                    - this is the historical query embedding, given by PatientQuery Module
                    - size: (#visits - 1, emb_dim), delete the current query
                codes_med
                    - this is the historical groud truth med vector
                    - format: a list of length (#visits - 1)
        """
        
        DM_key = queries

        DM_value = np.zeros((queries.shape[0], self.med_vocab_size))
        
        """
        TODO: transform codes_med to multi-hot and filling DM_value row by row (visit by visit)
        """
        #print(DM_key)
        #print(codes_med)
        #print(DM_value)
        #print(self.med_vocab_size)
        #print(codes_med)
        for visit_i in range(len(codes_med)):
            for code_j in codes_med[visit_i]:                
                DM_value[visit_i][code_j] = 1
        
        # use torch.FloatTensor for DM_value
        DM_value = torch.FloatTensor(DM_value)
        
        return DM_key, DM_value

In [43]:
class Fact1(nn.Module):
    def __init__(self, queries):
        super(Fact1, self).__init__()
        """
        FUNCTIONALITY: extract the final embedding from input queries, q^t
        """
        self.queries = queries
    
    def forward(self):
        """
        TODO: assign the last embedding to result
                - final size: (1, emb_dim)
                - make sure the output size is **not** a vector format (emb_dim,)
        """
        result = self.queries[-1]

        result = result.unsqueeze(0)
        return result
    

class Fact2(nn.Module):
    def __init__(self, query, MB):
        super(Fact2, self).__init__()
        """
        FUNCTIONALITY: get attention information from MB, o^t_b
        Input:
            query
                - this is the last embedding
            MB
                - is the memory bank
        """
        self.query = query
        self.MB = MB
    
    def forward(self):
        """
        TODO: 
            1. get the attention weight between query and each row of MB, adding to attn_score
                - use torch.mm()
                - tip: need to transpose MB, using MB.t()
            2. use F.softmax(, dim=1) to compute the attention matrix, attn_matrix
            3. get the final result from attn_matrix and MB
                - use torch.mm(,)
        """
        attn_score = torch.mm(self.query, self.MB.t())
        attn_matrix = F.softmax(attn_score, dim=1)
        result = torch.mm(attn_matrix, self.MB)
        
        return result
    
class Fact3(nn.Module):
    def __init__(self, query, MB, DM_key, DM_value):
        super(Fact3, self).__init__()
        """
        FUNCTIONALITY: similar to Fact2, get information from the DM, o_d^t
        """
        self.query = query
        self.MB = MB
        self.DM_key = DM_key
        self.DM_value = DM_value
    
    def forward(self):
        attn = F.softmax(torch.mm(self.query, self.DM_key.t()), dim=1)
        value = torch.mm(attn, self.DM_value)
        out = torch.mm(value, self.MB)
        return out

In [44]:
class OutNet(nn.Module):
    def __init__(self, vocab_med, emb_dim=16):
        super(OutNet, self).__init__()
        """
        FUNCTIONALITY: combine fact1, fact2, fact3 to do final prediction
        """
        
        """
        TODO:
            1. build the first linear layer
                - input size: 3 * emb_dim
                - end size: 2 * emb_dim
                - use bias=True option
            2. build the second linear layer
                - input size: 2 * emb_dim
                - end size: vocab_med
                - use bias=True option
        """
        self.fc1 = None
        self.fc2 = None
        
        self.fc1 = nn.Linear(in_features=3 * emb_dim, out_features=2 * emb_dim, bias=True)
        self.fc2 = nn.Linear(in_features=2 * emb_dim, out_features=vocab_med, bias=True)
        
    def forward(self, fact1, fact2, fact3):
        """
        Input:
            fact1, fact2, fact3:
                - three facts q^t, o^t_b, o^t_d
                - each size: (1, emb_dim)
        """
        
        """
        TODO: 
            1. concat fact1, fact2, fact3, and assign to memory_out
                - size: (1, 3 * emb_dim)
            2. pass through the first linear layer
            3. use F.relu() as the activation
            3. pass through the second linear layer
        """
        memory_out = None
        result = None
        
        memory_out = torch.cat((fact1, fact2, fact3), dim=1)
        #print(memory_out.size())
        
        result = self.fc1(memory_out)
        result = F.relu(result)
        result = self.fc2(result)
        
        return result

In [45]:
class GAMENet(nn.Module):
    def __init__(self, diag_vocab_size, med_vocab_size, ehr_adj, emb_dim=128):
        super(GAMENet, self).__init__()
        """
        FUNCTIONALITY: integrate the whole model together
        """
        
        """
        TODO: 
            1. build a patient query network
                - PatientQuery
                - use diag_vocab_size, vocab_prod, emb_dim
            2. build a memory bank
                - MemoryBank
                - use vocab_med, adj_ehr, adj_ddi, emb_dim
            3. build a dynamic memory
                - DynamicMemory
                - use vocab_med
            4. build a output network
                - OutNet
                - vocab_med, emb_dim
        """        
        self.patient = PatientQuery(diag_vocab_size, emb_dim)
        self.memorybank = MemoryBank(med_vocab_size, ehr_adj, emb_dim)
        self.dynamicmemory = DynamicMemory(med_vocab_size)
        self.outnet = OutNet(med_vocab_size, emb_dim)
        
    def forward(self, codes_diag, codes_med):
        """
        input:
            codes_diag
                - a list of length #visits
                - each element is also itself a list
            codes_prod
                - a list of length #visits
                - each element is also itself a list
            codes_med
                - a list of length #visits - 1
                - each element is also itself a list
        """
        
        # get patient query embedding (#visit, emb_dim*2)
        # get patient query embedding (#visit, emb_dim) ??
        queries = self.patient(codes_diag)
        """
        TODO:
            1. build a memory bank, assign to MB
                - use self.memorybank
            2. build a dynamic memory, assign to DM_key and DM_value
                - use self.dynamicmemory
                - use queries and codes_med as features
        """
        MB = None
        DM_key, DM_value = None, None
           
        MB = self.memorybank()
        DM_key, DM_value = self.dynamicmemory(queries, codes_med)
        #print("queries")
        #print(queries.size())
        # extract three memory outputs, assign to fact1, fact2, fact3
        fact1 = Fact1(queries)()
        #print(MB.size())
        fact2 = Fact2(fact1, MB)()
        fact3 = Fact3(fact1, MB, DM_key, DM_value)()

        # get the final output
        result = self.outnet(fact1, fact2, fact3)
        
        return result

In [46]:
model = GAMENet(diag_vocab_size, med_vocab_size, ehr_adj)
output = model([[0,1], [1,2], [4,5]], [[0,1], [2,5]])

assert output.shape == (1, med_vocab_size)

In [47]:
import random

seed = 1234
random.seed(seed)

tmp = records.copy()
random.shuffle(tmp)

print(tmp[0])
# 80%: 20% split training and test sets randomly
train_set = tmp[:int(len(records)*0.8)]
test_set = tmp[int(len(records)*0.8):]

[[[426, 97, 26, 51, 68, 256, 490, 1000], [954, 34, 68, 36, 26, 7, 71]], [[82, 83, 100, 101, 23, 37, 20, 5, 0, 45, 18, 117, 224], [34, 21, 0, 5, 18, 19, 20, 22, 4, 100, 101, 84, 41, 42, 9, 10, 11, 6, 2, 26, 27, 28, 29, 30, 81, 35, 149]]]


In [50]:

model = model = GAMENet(diag_vocab_size, med_vocab_size, ehr_adj)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def dataFormatter(patient_list):
    print("patient list")
    print(patient_list)
    diag_list, med_list = [], []
    for diag, med in patient_list:
        print("diag")
        print(diag)
        print("med")
        print(med)
        diag_list.append(diag)
        med_list.append(med)
    
    target = np.zeros((1, med_vocab_size))
    target[0, med_list[-1]] = 1
    return diag_list, med_list[:-1], torch.FloatTensor(target)

# def dataFormatter(patient_list):
#     diag_list, med_list = [], []

#     for codes in patient_list:


# from sklearn.metrics import accuracy_score

def test(test_set):
    model.eval()
    pred_list, target_list = [], []
    for patient in test_set:
        for idx, visit in enumerate(patient):
            codes_diag, codes_med, target = dataFormatter(patient[:idx+1])
            pred = model(codes_diag, codes_med).detach().cpu().numpy()[0]
            pred[pred >= 0.5] = 1; pred[pred < 0.5] = 0
            pred_list += pred.tolist(); target_list += target.numpy().tolist()[0]
    return pred_list, target_list

def train(train_set):
    for i in range(2):
        model.train()
        for patient in train_set:
            loss = 0

            # codes_diag, codes_med, target = dataFormatter(patient)
            # print(patient)
            # Gets number of visits from diagnoses codes (should also be the same for medication codes)
            for idx, visit in enumerate(patient[0]):
                print(visit)
                #print("meds")
                #print(visit[0])
                #print("Diags")
                #print(visit[1])
                # get the training data and target of bce loss
                codes_diag, codes_med, target = dataFormatter(patient[:idx+1])
                # get the target of multilabel_margin_loss
                multi_target = np.full((1, med_vocab_size), -1)
                for idx, item in enumerate(visit[2]):
                    multi_target[0][idx] = item
                multi_target = torch.LongTensor(multi_target)
                
                """
                TODO: get the output of the model, assign to pred
                """
                pred = None
                pred = model(codes_diag, codes_med)
                
                loss += F.binary_cross_entropy_with_logits(pred, target) + \
                                    F.multilabel_margin_loss(torch.sigmoid(pred), multi_target)
            
            """ 
            TODO:
                1. set zero grad to the optimizer
                2. backward loss
                3. make one step of the optimizer
            """
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        # use the first 100 test patients to compute intermediate accuracy
        print(accuracy_score(*test(test_set[:100])))
    
# train the model with first 500 patients
# feel free to change the number of training patients
train(train_set[:50])

[426, 97, 26, 51, 68, 256, 490, 1000]
patient list
[[[426, 97, 26, 51, 68, 256, 490, 1000], [954, 34, 68, 36, 26, 7, 71]]]
diag
[426, 97, 26, 51, 68, 256, 490, 1000]
med
[954, 34, 68, 36, 26, 7, 71]


IndexError: index 954 is out of bounds for axis 1 with size 385

Ignore code below.

In [None]:
# import dill
# import pickle as pkl

# gamenet_data_path = '../GAMENet_Data'
# voc_file = 'voc_final.pkl'
# file_obj = open(os.path.join(gamenet_data_path, voc_file), 'rb')
# med_voc = pkl.load(file_obj)
# med_voc_size = len(med_voc.idx2word)

# ehr_adj = np.zeros((med_voc_size, med_voc_size))
# for patient in records:
#     for adm in patient:
#         med_set = adm[2]
#         for i, med_i in enumerate(med_set):
#             for j, med_j in enumerate(med_set):
#                 if j<=i:
#                     continue
#                 ehr_adj[med_i, med_j] = 1
#                 ehr_adj[med_j, med_i] = 1
# print(ehr_adj)
# #dill.dump(ehr_adj, open('ehr_adj_final.pkl', 'wb'))  

In [None]:
# records = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, MULTI_VISIT_PKL))

# # Retrieve unique medication codes and diagnosis codes
# med_vocab = pd.read_csv(os.path.join(GLOBAL_DATA_PATH, UNIQUE_ATC_CSV))[ATC4]
# diag_vocab = pd.read_csv(os.path.join(GLOBAL_DATA_PATH, UNIQUE_ICD_CSV))[ICD9_CODE]

# med_vocab = med_vocab.values.tolist()
# diag_vocab = diag_vocab.values.tolist()
# #print(med_vocab.values.tolist())

# print(records.head(5))
# print(records['ATC4'].head(5))

385
