In [None]:
import os
#import jsondim
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import *
from sklearn import metrics
import numpy as np
from collections import defaultdict
import sys

In [None]:
class Loss(nn.Module):
    """Class-balanced BCE loss bundled with binary-classification metrics.

    ``forward`` splits the batch into positive (label >= 0.5) and negative
    (label < 0.5) entries. It applies ``nn.BCELoss`` (default mean reduction)
    to each group separately and adds the two terms. Each class therefore
    contributes one averaged term regardless of its size. It then moves
    everything to CPU/numpy and computes ROC AUC, PR AUC, accuracy, precision
    and confusion counts with ``sklearn.metrics``.
    """

    def __init__(self, device):
        super(Loss, self).__init__()
        self.classify_loss = nn.BCELoss()
        self.device = device

    def _half_loss(self, prob, labels, mask):
        """BCE over the entries selected by ``mask``; the int 0 when nothing is selected."""
        chosen_prob = prob[mask]
        if not len(chosen_prob):
            return 0
        chosen_label = labels[mask]
        return self.classify_loss(chosen_prob.to(self.device),
                                  chosen_label.to(self.device))

    def forward(self, prob, labels, train=True):
        """Compute the balanced loss and evaluation statistics for one batch.

        Args:
            prob: predicted probabilities (tensor), indexable by a boolean
                mask built from ``labels``.
            labels: target tensor; presumably 0/1 (the confusion counts below
                compare against exactly 0 and 1) -- TODO confirm.
            train: unused.

        Returns:
            ``[classify_loss, TP, #label==1, TN, #label==0, roc_auc, pr_auc,
            positive_rate, accuracy, precision, FN, FP, prob_np, labels_np]``.
            ``classify_loss`` is a tensor, or the int 0 if both groups are empty.
        """
        # Positive group first, then negative -- same summation order as before.
        classify_loss = (self._half_loss(prob, labels, labels >= 0.5)
                         + self._half_loss(prob, labels, labels < 0.5))

        y_true = labels.data.cpu().numpy()
        y_score = prob.data.cpu().numpy()

        fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
        auc = metrics.auc(fpr, tpr)

        pos_l = (y_true == 1).sum()
        neg_l = (y_true == 0).sum()
        base = pos_l / y_true.shape[0]

        precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
        apr = metrics.auc(recall, precision)

        # Both masks are built explicitly (not via ~) so NaN scores fall in neither.
        pred_pos = y_score >= 0.5
        pred_neg = y_score < 0.5
        accur = metrics.accuracy_score(y_true, pred_pos)
        prec = metrics.precision_score(y_true, pred_pos)

        # bool + label sums to 2 only for (True, 1) and to 0 only for (False, 0).
        pos_p = (pred_pos + y_true == 2).sum()   # true positives
        neg_p = (pred_pos + y_true == 0).sum()   # true negatives
        fn = (pred_neg + y_true == 2).sum()      # predicted negative, label 1
        fp = (pred_neg + y_true == 0).sum()      # predicted positive, label 0

        return [classify_loss, pos_p, pos_l, neg_p, neg_l, auc, apr, base,
                accur, prec, fn, fp, y_score, y_true]


In [None]:
class LSTMBase(nn.Module):
    """Binary classifier over medication codes (LSTM) and condition codes (flat embeddings).

    ``forward`` works in four steps:
      1. Encode ``meds`` with a ``CodeBase`` LSTM encoder.
      2. Embed ``conds`` and flatten the embeddings per sample.
      3. Concatenate ``[conds, meds]`` and apply a bias-free linear layer.
      4. Return the sigmoid of that output, shape ``(batch, 1)``.

    Only the med, cond, time, embed, rnn and batch-size arguments are used to
    build layers. The proc/lab/bmi/census arguments are stored on the instance
    but not used by this class.
    """

    def __init__(self, device, cond_vocab_size, cond_seq_len, proc_vocab_size, proc_seq_len,
                 med_vocab_size, med_seq_len, lab_vocab_size, lab_seq_len, bmi_vocab_size,
                 bmi_seq_len, time_vocab_size, census_vocab_size, embed_size, rnn_size, batch_size):
        super(LSTMBase, self).__init__()
        self.embed_size = embed_size
        self.rnn_size = rnn_size
        self.cond_vocab_size = cond_vocab_size
        self.cond_seq_len = cond_seq_len
        self.proc_vocab_size = proc_vocab_size
        self.proc_seq_len = proc_seq_len
        self.med_vocab_size = med_vocab_size
        self.med_seq_len = med_seq_len
        self.lab_vocab_size = lab_vocab_size
        self.lab_seq_len = lab_seq_len
        self.bmi_vocab_size = bmi_vocab_size
        self.bmi_seq_len = bmi_seq_len
        self.time_vocab_size = time_vocab_size
        self.census_vocab_size = census_vocab_size
        self.batch_size = batch_size
        self.padding_idx = 0
        self.modalities = 1
        self.device = device
        self.build()

    def build(self):
        """Create the sub-modules used by ``forward``."""
        self.med = CodeBase(self.device, self.embed_size, self.rnn_size, self.med_vocab_size,
                            self.med_seq_len, self.time_vocab_size, self.batch_size, bmi_flag=False)

        # Index 0 is the padding entry (zero vector, no gradient).
        self.condEmbed = nn.Embedding(self.cond_vocab_size, self.embed_size, self.padding_idx)
        # Not used by forward(); left in place because removing it would change
        # the module's state_dict keys.
        self.cond_fc = nn.Linear(self.rnn_size, 1, False)

        # in_features = flattened condition embeddings + medication encoding,
        # i.e. conds is expected to carry cond_seq_len indices per sample.
        self.fc = nn.Linear((self.embed_size * self.cond_seq_len) + self.rnn_size, 1, False)
        self.sig = nn.Sigmoid()

    def forward(self, meds, conds, contrib):
        """Return per-sample probabilities of shape ``(batch, 1)``.

        Args:
            meds: medication code indices, passed straight to the ``CodeBase`` encoder.
            conds: condition code indices; flattened per sample after embedding.
            contrib: unused; kept so existing call sites keep working.
        """
        med_h_n = self.med(meds)
        med_h_n = med_h_n.view(med_h_n.shape[0], -1)

        conds = self.condEmbed(conds.to(self.device))
        conds = conds.view(conds.shape[0], -1)

        logits = self.fc(torch.cat((conds, med_h_n), 1))
        return self.sig(logits)
        
            

In [None]:
class CodeBase(nn.Module):
    """Embed a rank-3 code-index tensor, sum over dim 1, and encode it with a 1-layer LSTM.

    ``forward`` returns the LSTM's final hidden state with shape
    ``(batch, rnn_size)``; the batch axis is kept even when batch == 1.
    ``code_seq_len``, ``time_vocab_size`` and ``bmi_flag`` are stored but not
    used by this class.
    """

    def __init__(self, device, embed_size, rnn_size, code_vocab_size, code_seq_len,
                 time_vocab_size, batch_size, bmi_flag=False):
        super(CodeBase, self).__init__()
        self.embed_size = embed_size
        self.rnn_size = rnn_size
        self.code_vocab_size = code_vocab_size
        self.code_seq_len = code_seq_len
        self.batch_size = batch_size
        self.padding_idx = 0
        self.device = device
        self.time_vocab_size = time_vocab_size
        self.bmi_flag = bmi_flag
        self.build()

    def build(self):
        """Create the embedding table (index 0 = padding) and the batch-first LSTM."""
        self.codeEmbed = nn.Embedding(self.code_vocab_size, self.embed_size, self.padding_idx)
        self.codeRnn = nn.LSTM(input_size=self.embed_size, hidden_size=self.rnn_size,
                               num_layers=1, batch_first=True)

    def forward(self, code):
        """Encode ``code`` and return the final hidden state, ``(batch, rnn_size)``.

        ``code`` must be rank 3 with batch first: embedding adds an
        ``embed_size`` axis, and summing over dim 1 brings it back to the
        rank-3 ``(batch, seq, embed)`` input that a batch_first LSTM expects.
        """
        code = code.to(self.device)
        # Size the initial state from the actual batch so a final, smaller
        # batch does not break the LSTM's hidden-state shape check.
        h_0, c_0 = self.init_hidden(code.shape[0])
        h_0, c_0 = h_0.to(self.device), c_0.to(self.device)

        code = torch.sum(self.codeEmbed(code), 1)

        _, (code_h_n, _) = self.codeRnn(code, (h_0, c_0))

        # code_h_n is (num_layers=1, batch, rnn_size). Squeeze only the layer
        # axis: a bare squeeze() also dropped the batch axis when batch == 1.
        return code_h_n.squeeze(0)

    def init_hidden(self, batch_size=None):
        """Return zero ``(h, c)`` states of shape ``(1, batch_size, rnn_size)``.

        Args:
            batch_size: batch dimension of the states; defaults to ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        h = torch.zeros(1, batch_size, self.rnn_size)
        c = torch.zeros(1, batch_size, self.rnn_size)
        return (h, c)
    

In [None]:
class LSTMAttn(nn.Module):
    """Binary classifier over medication codes (attention-pooled LSTM) and condition codes.

    ``forward`` works in four steps:
      1. Encode ``meds`` with a ``CodeAttn`` encoder (LSTM plus attention pooling).
      2. Embed ``conds`` and flatten the embeddings per sample.
      3. Concatenate ``[conds, meds]`` and apply a bias-free linear layer.
      4. Return the sigmoid of that output, shape ``(batch, 1)``.

    The proc/lab/bmi/census constructor arguments are stored on the instance
    but not used by this class.
    """

    def __init__(self, device, cond_vocab_size, cond_seq_len, proc_vocab_size, proc_seq_len,
                 med_vocab_size, med_seq_len, lab_vocab_size, lab_seq_len, bmi_vocab_size,
                 bmi_seq_len, time_vocab_size, census_vocab_size, embed_size, rnn_size, batch_size):
        super(LSTMAttn, self).__init__()
        self.embed_size = embed_size
        self.rnn_size = rnn_size
        self.cond_vocab_size = cond_vocab_size
        self.cond_seq_len = cond_seq_len
        self.proc_vocab_size = proc_vocab_size
        self.proc_seq_len = proc_seq_len
        self.med_vocab_size = med_vocab_size
        self.med_seq_len = med_seq_len
        self.lab_vocab_size = lab_vocab_size
        self.lab_seq_len = lab_seq_len
        self.bmi_vocab_size = bmi_vocab_size
        self.bmi_seq_len = bmi_seq_len
        self.time_vocab_size = time_vocab_size
        self.census_vocab_size = census_vocab_size
        self.batch_size = batch_size
        self.padding_idx = 0
        self.modalities = 1
        self.device = device
        self.build()

    def build(self):
        """Create the sub-modules used by ``forward``."""
        self.med = CodeAttn(self.device, self.embed_size, self.rnn_size, self.med_vocab_size,
                            self.med_seq_len, self.time_vocab_size, self.batch_size, bmi_flag=False)

        # Index 0 is the padding entry (zero vector, no gradient).
        self.condEmbed = nn.Embedding(self.cond_vocab_size, self.embed_size, self.padding_idx)
        # Not used by forward(); left in place because removing it would change
        # the module's state_dict keys.
        self.cond_fc = nn.Linear(self.rnn_size, 1, False)

        # in_features = flattened condition embeddings + medication encoding,
        # i.e. conds is expected to carry cond_seq_len indices per sample.
        self.fc = nn.Linear((self.embed_size * self.cond_seq_len) + self.rnn_size, 1, False)
        self.sig = nn.Sigmoid()

    def forward(self, meds, conds, contrib):
        """Return per-sample probabilities of shape ``(batch, 1)``.

        Args:
            meds: medication code indices, passed straight to the ``CodeAttn`` encoder.
            conds: condition code indices; flattened per sample after embedding.
            contrib: unused; kept so existing call sites keep working.
        """
        med_h_n = self.med(meds)
        med_h_n = med_h_n.view(med_h_n.shape[0], -1)

        conds = self.condEmbed(conds.to(self.device))
        conds = conds.view(conds.shape[0], -1)

        logits = self.fc(torch.cat((conds, med_h_n), 1))
        return self.sig(logits)
        
            

In [None]:
class CodeAttn(nn.Module):
    """Embed a rank-3 code-index tensor, run a 1-layer LSTM, and attention-pool over time.

    Each LSTM output step is scored by a bias-free linear layer. Scores are
    softmax-normalized over the sequence axis of each sample, and the
    weighted sum of LSTM outputs is returned with shape ``(batch, rnn_size)``.
    ``code_seq_len``, ``time_vocab_size`` and ``bmi_flag`` are stored but not
    used by this class.
    """

    def __init__(self, device, embed_size, rnn_size, code_vocab_size, code_seq_len,
                 time_vocab_size, batch_size, bmi_flag=False):
        super(CodeAttn, self).__init__()
        self.embed_size = embed_size
        self.rnn_size = rnn_size
        self.code_vocab_size = code_vocab_size
        self.code_seq_len = code_seq_len
        self.batch_size = batch_size
        self.padding_idx = 0
        self.device = device
        self.time_vocab_size = time_vocab_size
        self.bmi_flag = bmi_flag
        self.build()

    def build(self):
        """Create the embedding (index 0 = padding), batch-first LSTM and attention scorer."""
        self.codeEmbed = nn.Embedding(self.code_vocab_size, self.embed_size, self.padding_idx)
        self.codeRnn = nn.LSTM(input_size=self.embed_size, hidden_size=self.rnn_size,
                               num_layers=1, batch_first=True)
        self.code_fc = nn.Linear(self.rnn_size, 1, False)

    def forward(self, code):
        """Return the attention-pooled LSTM outputs, shape ``(batch, rnn_size)``.

        ``code`` must be rank 3 with batch first: embedding adds an
        ``embed_size`` axis, and summing over dim 1 yields the rank-3
        ``(batch, seq, embed)`` LSTM input.
        """
        code = code.to(self.device)
        # Size the initial state from the actual batch so a final, smaller
        # batch does not break the LSTM's hidden-state shape check.
        h_0, c_0 = self.init_hidden(code.shape[0])
        h_0, c_0 = h_0.to(self.device), c_0.to(self.device)

        code = torch.sum(self.codeEmbed(code), 1)

        code_output, _ = self.codeRnn(code, (h_0, c_0))   # (batch, seq, rnn_size)

        scores = self.code_fc(code_output)                # (batch, seq, 1)
        # Normalize over the sequence axis. The previous implicit-dim softmax
        # picked dim 0 for 3-D input, mixing attention weights across the batch.
        weights = F.softmax(scores, dim=1)
        return torch.sum(code_output * weights, dim=1)

    def init_hidden(self, batch_size=None):
        """Return zero ``(h, c)`` states of shape ``(1, batch_size, rnn_size)``.

        Args:
            batch_size: batch dimension of the states; defaults to ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        h = torch.zeros(1, batch_size, self.rnn_size)
        c = torch.zeros(1, batch_size, self.rnn_size)
        return (h, c)
    