In [20]:
import pandas as pd
import numpy as np
from scipy.special import softmax # ver>=1.20

In [21]:
# !wget https://raw.githubusercontent.com/dhimmel/hsdn/gh-pages/data/symptoms-DO.tsv

In [71]:
def get_trans_mat():
    df = pd.read_csv("symptoms-DO.tsv",sep = '\t')
    d_set = list(df['disease_name'].unique())
    s_set = list(df['symptom_name'].unique())
    tran_mat = np.full((len(s_set), len(d_set)), -np.inf)
    for i, s_name in enumerate(s_set):
        d_subset = df[df['symptom_name'] == s_name]['disease_name'].unique()
        js = [d_set.index(x) for x in d_subset]
        tran_mat[i,js] = 1
    return softmax(tran_mat, axis=1), s_set, d_set

In [72]:
tran_mat, row_name, col_name = get_trans_mat()

In [73]:
tran_mat.shape

(316, 119)

In [76]:
import pandas as pd
import numpy as np
import pickle
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [88]:
class KnowledgeBase(nn.Module):
    def __init__(self, hidden_size, embed_size, row_name, col_name, transition):
        super(KnowledgeBase, self).__init__()
        self.dis_embed = nn.Embedding(len(col_name), embed_size)
        self.sym_embed = nn.Embedding(len(row_name), embed_size)
        self.n_sym = len(row_name)
        self.n_dis = len(col_name)
        self.W = transition
        self.attn = nn.Linear(hidden_size+embed_size, hidden_size)
        self.v = nn.Parameter(torch.rand(1, hidden_size))
        stdv = 1. / np.sqrt(self.v.size(0))
        self.v.data.normal_(mean=0, std=stdv)
        
    def forward(self, h, batch_size):
        sym_rep = self.sym_embed(torch.arange(self.n_sym)).repeat(batch_size,1,1)
        h = h.repeat(self.n_sym, 1, 1).transpose(0,1)
        energy =  torch.tanh(self.attn(torch.cat([h, sym_rep],dim=2)))
        energy = energy.transpose(1, 2)
        v = self.v.repeat(batch_size,1 , 1)
        score = torch.bmm(v, energy)
        attnweight_sym = torch.softmax(score, dim = 2)
        tran_mat = self.W.repeat(batch_size, 1, 1).type(torch.FloatTensor)
        attnweight_dis = torch.bmm(attnweight_sym, tran_mat)
        V = torch.bmm(self.dis_embed(torch.arange(self.n_dis)).repeat(batch_size,1,1).transpose(1,2),
                      attnweight_dis.transpose(1,2)).squeeze(2)
        return V
        
# tran_mat = torch.from_numpy(tran_mat).to(device)
kb = KnowledgeBase(100, 100, row_name, col_name, tran_mat)
kb(torch.rand(8, 100), 8)

tensor([[ 0.1623,  0.1545, -0.1456, -0.1516, -0.0378, -0.1204, -0.4996,  0.1679,
          0.3518,  0.0887,  0.0953,  0.0751, -0.3093,  0.0810,  0.2835,  0.2610,
         -0.6895, -0.5181,  0.0513, -0.1765, -0.0513, -0.1011,  0.0909,  0.5021,
          0.0364, -0.1343, -0.1381,  0.1234,  0.4001, -0.0421,  0.1451, -0.6458,
          0.0673, -0.2595, -0.2049, -0.4400, -0.3555,  0.3233,  0.0557, -0.0059,
         -0.1668,  0.7617,  0.2963,  0.0750,  0.2388, -0.4716, -0.1207, -0.3473,
          0.1337,  0.0343,  0.0648,  0.2415, -0.4162, -0.1515, -0.4681,  0.1095,
          0.0964,  0.0278, -0.0973, -0.2629,  0.0734, -0.2071,  0.2631, -0.4057,
         -0.0360, -0.0200, -0.0772,  0.1103, -0.0927,  0.3580, -0.0429, -0.0387,
          0.4007, -0.1783,  0.1926,  0.2801,  0.5725, -0.0466,  0.1270, -0.1231,
          0.0302,  0.2360, -0.2447,  0.4589, -0.3433,  0.0687,  0.4650,  0.1946,
          0.0086, -0.2577, -0.3497,  0.2546, -0.0512,  0.3302,  0.1868,  0.3761,
          0.0338,  0.2752,  