In [None]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import time
import numpy as np
from torch.utils.data import DataLoader, Dataset
import random
import torch.nn.init as init
from functools import partial
import math
import os
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
from multiprocessing import JoinableQueue, Queue, Process

In [None]:
RELNUM = 1345   # number of relations in the dataset
ENTNUM = 14951  # number of entities in the dataset
EMBDIM = 200    # embedding dimension
DROPOUT = 0.2   # dropout rate


def loadTriples(path):
    """Read a tab-separated id file into a list of ``[int, int, int]`` triples.

    Only the first three tab-separated fields of each line are used, and blank
    lines are skipped. Column order is kept exactly as in the file; the rest of
    the notebook unpacks each triple as (head, tail, relation).
    """
    triples = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            triples.append([int(fields[0]), int(fields[1]), int(fields[2])])
    return triples


trainTriple = loadTriples('./fb15k/train.txt')
testTriple = loadTriples('./fb15k/test.txt')
validTriple = loadTriples('./fb15k/valid.txt')

In [None]:
def tri2pair(triList):
    """Index triples in both directions.

    Args:
        triList: iterable of (head, tail, relation) id triples.

    Returns:
        (hrtPair, trhPair) where ``hrtPair[head][relation]`` is the set of tails
        seen with that pair and ``trhPair[tail][relation]`` is the set of heads.
    """
    byHead, byTail = {}, {}
    for head, tail, rel in triList:
        byHead.setdefault(head, {}).setdefault(rel, set()).add(tail)
        byTail.setdefault(tail, {}).setdefault(rel, set()).add(head)
    return byHead, byTail

# Known-fact indices: training facts only (passed to the training dataset),
# and all three splits together (passed to evaluation for filtered ranking).
trainHrtPair, trainTrhPair = tri2pair(trainTriple)
allHrtPair, allTrhPair = tri2pair([*trainTriple, *testTriple, *validTriple])

In [None]:
class dset(Dataset):
    """Triple dataset; in training mode it also samples candidates for a listwise loss.

    Args:
        triple: list of [head, tail, relation] id triples.
        hrtPair: dict head -> relation -> set of known tails.
        trhPair: dict tail -> relation -> set of known heads.
        train: if True, ``__getitem__`` also returns sampled candidates and labels.
        prob: fraction of all entities drawn as candidates per direction.
        numEnt: total entity count; defaults to the global ``ENTNUM``.
    """
    def __init__(self, triple, hrtPair, trhPair, train, prob, numEnt=None):
        self.hrtPair = hrtPair
        self.trhPair = trhPair
        self.triple = triple
        self.train = train
        self.numEnt = ENTNUM if numEnt is None else numEnt
        self.sampleNum = int(self.numEnt * prob)

    def _sampleCandidates(self, positives):
        """Return (candidates, labels): positives first, then random negatives.

        The result always has exactly ``self.sampleNum`` entries so batches
        collate. If there are more positives than that, a random subset of the
        positives is kept and no negatives are drawn.
        """
        pos = list(positives)
        if len(pos) > self.sampleNum:
            pos = random.sample(pos, self.sampleNum)
        # random.sample() requires a sequence; passing a set raises TypeError
        # on Python 3.11+.
        negPool = [e for e in range(self.numEnt) if e not in positives]
        neg = random.sample(negPool, self.sampleNum - len(pos))
        # int64 ids so torch.gather accepts them on every platform.
        candidates = np.array(pos + neg, dtype=np.int64)
        labels = np.array([1.] * len(pos) + [0.] * len(neg))
        return candidates, labels

    def __getitem__(self, index):
        row = self.triple[index]
        h, t, r = row[0], row[1], row[2]
        if not self.train:
            return h, t, r
        hrtSample, hrtlPos = self._sampleCandidates(self.hrtPair[h][r])
        trhSample, trhlPos = self._sampleCandidates(self.trhPair[t][r])
        return h, t, r, hrtSample, hrtlPos, trhSample, trhlPos

    def __len__(self):
        return len(self.triple)

In [None]:
bound = 6 / math.sqrt(EMBDIM)  # half-width of the uniform init range used below


class naiveModel(nn.Module):
    """Element-wise composition model that scores every entity as a tail and as a head.

    ``forward`` returns (hrt, trh): scores of all entities as the tail of
    (h, r) and as the head of (t, r).
    """

    def __init__(self):
        super(naiveModel, self).__init__()

        init1 = partial(init.uniform_, a=-bound, b=bound)

        self.entEmbedding = nn.Embedding(ENTNUM, EMBDIM, sparse=False)
        self.relEmbedding = nn.Embedding(RELNUM, EMBDIM, sparse=False)
        # init.uniform_ fills the parameter in place; no re-assignment needed.
        init1(self.entEmbedding.weight)
        init1(self.relEmbedding.weight)

        # Per-dimension weights and bias for the (head, relation) -> tail direction.
        self.entHrtW = nn.Parameter(init1(torch.empty(EMBDIM)))
        self.relHrtW = nn.Parameter(init1(torch.empty(EMBDIM)))
        self.hrtB = nn.Parameter(torch.zeros(EMBDIM))

        # Per-dimension weights and bias for the (tail, relation) -> head direction.
        self.entTrhW = nn.Parameter(init1(torch.empty(EMBDIM)))
        self.relTrhW = nn.Parameter(init1(torch.empty(EMBDIM)))
        self.trhB = nn.Parameter(torch.zeros(EMBDIM))

        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, hID, rID, tID):
        """Score all entities as tails of (hID, rID) and as heads of (tID, rID).

        torch.mm requires 2-D operands, so the id tensors must be 1-D (one id
        per batch row). Returns two (batch, ENTNUM) score tensors.
        """
        hEmbedding = self.entEmbedding(hID)
        rEmbedding = self.relEmbedding(rID)
        tEmbedding = self.entEmbedding(tID)
        allEnt = self.entEmbedding.weight.t()

        # torch.tanh replaces the deprecated F.tanh.
        hrt = self.dropout(torch.tanh(hEmbedding * self.entHrtW + rEmbedding * self.relHrtW + self.hrtB))
        trh = self.dropout(torch.tanh(tEmbedding * self.entTrhW + rEmbedding * self.relTrhW + self.trhB))

        return torch.mm(hrt, allEnt), torch.mm(trh, allEnt)

In [None]:
bound = 6 / math.sqrt(EMBDIM)  # half-width of the uniform init range used below


class concatModel(nn.Module):
    """Concatenation model: a tanh-activated linear layer over [entity; relation].

    ``forward`` returns (hrt, trh): scores of all entities as the tail of
    (h, r) and as the head of (t, r).
    """

    def __init__(self):
        super(concatModel, self).__init__()

        init1 = partial(init.uniform_, a=-bound, b=bound)

        self.entEmbedding = nn.Embedding(ENTNUM, EMBDIM, sparse=False)
        self.relEmbedding = nn.Embedding(RELNUM, EMBDIM, sparse=False)
        # init.uniform_ fills the parameter in place; no re-assignment needed.
        init1(self.entEmbedding.weight)
        init1(self.relEmbedding.weight)

        # Only the weights are re-initialised; biases keep nn.Linear's default init.
        self.hrtLinear = nn.Linear(EMBDIM*2, EMBDIM)
        init1(self.hrtLinear.weight)

        self.trhLinear = nn.Linear(EMBDIM*2, EMBDIM)
        init1(self.trhLinear.weight)

        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, hID, rID, tID):
        """Score all entities as tails of (hID, rID) and as heads of (tID, rID).

        torch.cat(..., 1) and torch.mm require 1-D id tensors (one id per batch
        row). Returns two (batch, ENTNUM) score tensors.
        """
        hEmbedding = self.entEmbedding(hID)
        rEmbedding = self.relEmbedding(rID)
        tEmbedding = self.entEmbedding(tID)
        allEnt = self.entEmbedding.weight.t()

        # torch.tanh replaces the deprecated F.tanh.
        hrt = self.dropout(torch.tanh(self.hrtLinear(torch.cat((hEmbedding, rEmbedding), 1))))
        hrt = torch.mm(hrt, allEnt)

        trh = self.dropout(torch.tanh(self.trhLinear(torch.cat((tEmbedding, rEmbedding), 1))))
        trh = torch.mm(trh, allEnt)

        return hrt, trh


In [None]:
def test_evaluation(hList, tList, hrtSort, trhSort, hr_t, tr_h):

    mean_rank_h = list()
    filtered_mean_rank_h = list()
    mean_rank_t = list()
    filtered_mean_rank_t = list()


    for i in range(len(hList)):
        h = hList
        t = tList

        mr = 1
        for val in trhSort[i]:
            if val == hList[i]:
                mean_rank_h.append(mr)
                break
            mr += 1

        mr = 1
        for val in hrtSort[i]:
            if val == tList[i]:
                mean_rank_t.append(mr)
                break
            mr += 1

        fmr = 1
        for val in trhSort[i]:
            if val == hList[i]:
                filtered_mean_rank_h.append(fmr)
                break
            
            if t[i].item() in tr_h and r[i].item() in tr_h[t[i].item()] and val.item() in tr_h[t[i].item()][r[i].item()]:
                continue
            else:
                fmr += 1

        fmr = 1
        for val in hrtSort[i]:
            if val == tList[i]:
                filtered_mean_rank_t.append(fmr)
                break
            if h[i].item() in hr_t and r[i].item() in hr_t[h[i].item()] and val.item() in hr_t[h[i].item()][r[i].item()]:
                continue
            else:
                fmr += 1

    return mean_rank_h, filtered_mean_rank_h, mean_rank_t, filtered_mean_rank_t

In [None]:
def eva_func(in_queue, out_queue, hr_t, tr_h):
    while True:
        dat = in_queue.get()
        if dat is None:
            in_queue.task_done()
            continue
        hList, tList, hrtSort, trhSort = dat
        out_queue.put(test_evaluation(hList, tList, hrtSort, trhSort, hr_t, tr_h))
        in_queue.task_done()


def testFunc(model, data_loader, hrtPair, trhPair):
    accu_rank_h = []
    accu_rank_fh = []
    accu_rank_t = []
    accu_rank_ft = []
    count = 0
    
    evaluation_queue = JoinableQueue()
    result_queue = Queue()
    
    model.train(False)
    r = []
    for i in range(20):
        worker = Process(target=eva_func, args=(evaluation_queue, result_queue, hrtPair, trhPair))
        worker.start()
        
    for i, (h, t, r) in enumerate(data_loader):
        with torch.no_grad():
            hV = h.cuda()
            tV = t.cuda()
            rV = r.cuda()
            hrtRes, trhRes = model(hV, rV, tV)
    
        _, hrtSort = torch.sort(hrtRes, descending=True)
        _, trhSort = torch.sort(trhRes, descending=True)
        
        evaluation_queue.put((h, t, hrtSort.cpu().data, trhSort.cpu().data))
        count += 1
        
    for i in range(20):
        evaluation_queue.put(None)
    
    evaluation_queue.join()
    
    while count > 0:
        count -= 1

        rh, frh, rt, frt = result_queue.get()

        accu_rank_h += rh
        accu_rank_t += rt
        accu_rank_fh += frh
        accu_rank_ft += frt
    
    # print MR, MRR, top10, top3, top1
    
    print(
        "MEAN RANK: %.2f %.4f %.4f %.4f %.4f" %
        (np.mean(accu_rank_h), np.mean(1 / np.asarray(accu_rank_h, dtype=np.int32)), np.mean(np.asarray(accu_rank_h, dtype=np.int32) <= 10), np.mean(np.asarray(accu_rank_h, dtype=np.int32) <= 3), np.mean(np.asarray(accu_rank_h, dtype=np.int32) <= 1)))
         
    print(
        "MEAN RANK: %.2f %.4f %.4f %.4f %.4f" %
        (np.mean(accu_rank_fh), np.mean(1 / np.asarray(accu_rank_fh, dtype=np.int32)), np.mean(np.asarray(accu_rank_fh, dtype=np.int32) <= 10), np.mean(np.asarray(accu_rank_fh, dtype=np.int32) <= 3), np.mean(np.asarray(accu_rank_fh, dtype=np.int32) <= 1)))

    print(
        "MEAN RANK: %.2f %.4f %.4f %.4f %.4f" %
        (np.mean(accu_rank_t), np.mean(1 / np.asarray(accu_rank_t, dtype=np.int32)), np.mean(np.asarray(accu_rank_t, dtype=np.int32) <= 10), np.mean(np.asarray(accu_rank_t, dtype=np.int32) <= 3), np.mean(np.asarray(accu_rank_t, dtype=np.int32) <= 1)))

    print(
        "MEAN RANK: %.2f %.4f %.4f %.4f %.4f" %
        (np.mean(accu_rank_ft), np.mean(1 / np.asarray(accu_rank_ft, dtype=np.int32)), np.mean(np.asarray(accu_rank_ft, dtype=np.int32) <= 10), np.mean(np.asarray(accu_rank_ft, dtype=np.int32) <= 3), np.mean(np.asarray(accu_rank_ft, dtype=np.int32) <= 1)))



In [None]:
# Only the training split samples candidate lists (train=True). The test and
# valid splits use the all-split indices so evaluation can filter known facts.
trainData = dset(trainTriple, trainHrtPair, trainTrhPair, train=True, prob=0.25)
testData = dset(testTriple, allHrtPair, allTrhPair, train=False, prob=0)
validData = dset(validTriple, allHrtPair, allTrhPair, train=False, prob=0)

train_loader = DataLoader(trainData, batch_size=512, shuffle=True, num_workers=20)
test_loader = DataLoader(testData, batch_size=2048, shuffle=False, num_workers=2)

model = naiveModel().cuda()
ttt = time.time()

In [None]:
LR = 0.01
optim = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)


def listwiseLoss(scores, candidates, labels):
    """Listwise loss over each row's sampled candidate entities.

    Args:
        scores: 2-D score tensor from the model (one row per triple).
        candidates: 2-D int64 entity ids selected from each row of ``scores``.
        labels: float 0/1 tensor, same shape as ``candidates``; 1 marks positives.

    Returns:
        Scalar tensor. For each row, take the log-softmax over that row's
        candidates (clamped to [-23, 0]) and average it over the positives.
        The result is the negated sum of these averages over the batch.
    """
    sampled = torch.gather(scores, 1, candidates)
    # Explicit dim: an implicit log_softmax dim is deprecated (it resolved to 1 for 2-D input).
    logProb = F.log_softmax(sampled, dim=1).clamp(-23., 0.)
    return (torch.sum(logProb * labels, 1) / labels.sum(1)).sum().neg()


ttt = time.time()
for epoch in range(80):
    model.train(True)
    for i, (h, t, r, hrtSample, hrtlPos, trhSample, trhlPos) in enumerate(train_loader):
        hV, rV, tV = h.cuda(), r.cuda(), t.cuda()

        optim.zero_grad()
        hrtRes, trhRes = model(hV, rV, tV)

        hrt_loss = listwiseLoss(hrtRes, hrtSample.cuda(), hrtlPos.float().cuda())
        trh_loss = listwiseLoss(trhRes, trhSample.cuda(), trhlPos.float().cuda())
        loss = hrt_loss + trh_loss
        loss.backward()

        optim.step()
        if i % 200 == 0:
            print(loss.item())

    if (epoch + 1) % 10 == 0:
        model.train(False)
        # Evaluate one test triple per batch (batch_size=1 is the DataLoader default).
        test_loader = DataLoader(testData, batch_size=1, num_workers=2, shuffle=False)
        testFunc(model, test_loader, allHrtPair, allTrhPair)

    print(time.time() - ttt)
    print(epoch)
