In [1]:
# Synchronize to google drive, define the root path
import os

google_colab = True
if google_colab:
  # Mount Google Drive so data/ and the local pygat package stored there are reachable.
  from google.colab import drive
  drive.mount('/content/gdrive')
  root_path = 'gdrive/My Drive/Colab Notebooks/'

  # Work from the Drive folder so the relative paths and local imports used
  # later in the notebook resolve against it.
  os.chdir(root_path)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
!pip install rdflib

Collecting rdflib
[?25l  Downloading https://files.pythonhosted.org/packages/3c/fe/630bacb652680f6d481b9febbb3e2c3869194a1a5fc3401a4a41195a2f8f/rdflib-4.2.2-py3-none-any.whl (344kB)
[K     |████████████████████████████████| 348kB 2.9MB/s 
[?25hCollecting isodate
[?25l  Downloading https://files.pythonhosted.org/packages/9b/9f/b36f7774ff5ea8e428fdcfc4bb332c39ee5b9362ddd3d40d9516a55221b2/isodate-0.6.0-py2.py3-none-any.whl (45kB)
[K     |████████████████████████████████| 51kB 6.5MB/s 
Installing collected packages: isodate, rdflib
Successfully installed isodate-0.6.0 rdflib-4.2.2


In [0]:
import os
import os.path as path
import re
import sys
import numpy as np
import torch

import rdflib as rdf
import gzip
import networkx as nx
import scipy.sparse as sp
import scipy
from collections import OrderedDict

class OrderedGraph(nx.MultiDiGraph):
    """Directed multigraph whose internal node mapping is an ``OrderedDict``.

    ``node_dict_factory`` is the networkx class-level hook that selects the
    mapping type used to store nodes. Not referenced by the visible notebook code.
    """

    node_dict_factory = OrderedDict

# Compiled once at definition time instead of once per parsed line.
_QUOTED_CONTENT_RE = re.compile('"(.*)"')
_AFTER_COLON_RE = re.compile(':(.*)')
_AFTER_HASH_RE = re.compile('#(.*)')


def _split_term(term):
    """Return ``(term, content)`` for a subject/object fragment.

    Quoted text yields what lies between the outermost quotes. If only one quote
    survived (e.g. a literal like "United States/Australian victory" cut at its
    last "/"), the quotes are stripped and ``term`` itself is updated too.
    Otherwise the text after the first ":" is used, else the fragment unchanged.
    """
    if '"' in term:
        quoted = _QUOTED_CONTENT_RE.findall(term)
        if quoted:
            return term, quoted[0]
        term = term.replace('"', "").strip()
        return term, term
    if ":" in term:
        return term, _AFTER_COLON_RE.findall(term)[0]
    return term, term


def parser(f):
    """Parse N-Triples lines into normalised triple tuples.

    Args:
        f: iterable of text lines, e.g. an open ``.nt`` file.

    Returns:
        list of ``(i, sub, pred, obj, sub_new, pred_new, obj_new)`` where ``i`` is
        the 0-based line index, ``sub``/``pred``/``obj`` are the fragments after
        the last "/" and the ``*_new`` values are the cleaned, space-free names
        (an empty object becomes "UNK"). Triples whose subject or predicate ends
        up empty are printed and dropped. Blank lines and "#" comment lines are
        skipped (previously they crashed with IndexError).
    """
    triples = []
    for i, line in enumerate(f):
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        parts = stripped.replace("<", "").split(">")

        # subject: last path segment of the IRI
        sub = parts[0]
        sub = sub[sub.rfind("/") + 1:]
        sub, sub_new = _split_term(sub)
        sub_new = sub_new.replace(" ", "")

        # object: drop one trailing "/" so ".../Foo/" still yields "Foo"
        obj = parts[2]
        if obj.rfind("/") + 1 == len(obj):
            obj = obj[:-1]
        obj = obj[obj.rfind("/") + 1:]
        obj, obj_new = _split_term(obj)
        obj_new = obj_new.replace(" ", "")
        if obj_new == "":
            obj_new = "UNK"

        # predicate: prefer the fragment after "#", then after ":"
        pred = parts[1]
        pred = pred[pred.rfind("/") + 1:]
        if "#" in pred:
            pred_new = _AFTER_HASH_RE.findall(pred)[0]
        elif ":" in pred:
            pred_new = _AFTER_COLON_RE.findall(pred)[0]
        else:
            pred_new = pred
        pred_new = pred_new.replace(" ", "")

        if sub_new and pred_new and obj_new:
            triples.append((i, sub, pred, obj, sub_new, pred_new, obj_new))
        else:
            print(line)
    return triples

# Parse the description triples of one entity.
def prepare_data(db_path, num):
    """Parse ``<db_path>/<num>/<num>_desc.nt`` with ``parser``.

    Returns the list of 7-tuples that ``parser`` produces for that entity.
    """
    entity_dir = path.join(db_path, "{}".format(num))
    desc_file = path.join(entity_dir, "{}_desc.nt".format(num))
    with open(desc_file, encoding="utf8") as nt_lines:
        return parser(nt_lines)

# Prepare the gold-label counts of one entity.
def prepare_label(db_path, num, top_n, file_n):
    """Count (predicate, object) pairs across an entity's gold summary files.

    Reads ``<db_path>/<num>/<num>_gold_top<top_n>_<k>.nt`` for every k in
    ``range(file_n)``.

    Returns:
        dict mapping ``"<pred_new>++$++<obj_new>"`` to the number of gold triples
        (over all files) with that key. The "++$++" separator must match the
        one used by ``tensor_from_weight``.
    """
    per_entity_label_dict = {}
    for i in range(file_n):
        # The old code applied a second, stray ``.format(num)`` to the already
        # formatted name; that was a no-op at best and broke on "{" in paths.
        gold_file = path.join(db_path,
                              "{}".format(num),
                              "{}_gold_top{}_{}.nt".format(num, top_n, i))
        with open(gold_file, encoding="utf8") as f:
            labels = parser(f)
        for _, _, _, _, _, pred_new, obj_new in labels:
            counter(per_entity_label_dict, "{}++$++{}".format(pred_new, obj_new))
    return per_entity_label_dict

# dict counter
def counter(cur_dict, word):
    """Increment ``cur_dict[word]`` in place, treating unseen words as 0."""
    cur_dict[word] = cur_dict.get(word, 0) + 1

# Build the description graph of one entity.
def build_graph(db_path, num):
    """Build the graph of ``<db_path>/<num>/<num>_desc.nt`` using RDFReader.

    Every triple becomes an edge subject -> object between integer node ids. An
    object already seen (including the subject itself) is renamed to
    ``str(o) + "_<n>"`` with the smallest free n, so each triple gets its own
    object node.

    Returns:
        (G, triples): G is an undirected ``nx.Graph`` over integer ids (id 0 is
        reserved for 'UNK' and never added to G). ``triples`` is the list of
        (s, p, o) in the order rdflib yielded them, with renamed objects.
        NOTE(review): this is rdflib's store order, not necessarily the file
        order that ``parser`` uses — confirm callers do not rely on them matching.
    """
    G = nx.Graph()
    triples = []
    nodes_dict = {'UNK': 0}
    nt_file = path.join(db_path, "{}".format(num), "{}_desc.nt".format(num))
    with RDFReader(nt_file) as reader:
        for s, p, o in reader.triples():
            if s not in nodes_dict:
                nodes_dict[s] = len(nodes_dict)
            if o in nodes_dict:
                n = 1
                while str(o) + '_{}'.format(n) in nodes_dict:
                    n += 1
                o = str(o) + '_{}'.format(n)
            nodes_dict[o] = len(nodes_dict)
            triples.append((s, p, o))

    # add_edge inserts missing endpoints in (s, o) order, same as add_node would.
    for s, _, o in triples:
        G.add_edge(nodes_dict[s], nodes_dict[o])
    return G, triples

def process_data(db_name, db_start, db_end, top_n=10, file_n=6):
    """Load descriptions, gold labels and graphs for the selected entities.

    Args:
        db_name: "dbpedia" or "lmdb"; selects ``data/<db_name>``.
        db_start, db_end: parallel sequences of half-open id ranges; ids
            ``range(db_start[k], db_end[k])`` are loaded for every k, in order.
        top_n, file_n: forwarded to ``prepare_label``.

    Returns:
        (data, data_for_transE, label, entity2ix, pred2ix, data_graph):
        per-entity ``[pred, obj]`` rows (padded with ['UNK', 'UNK'] up to the
        graph's node count), ``[sub, obj, pred]`` triples for TransE, per-entity
        gold-label dicts, entity/predicate -> id maps ('UNK' -> 0) and graphs.

    Raises:
        ValueError: if db_name is not "dbpedia" or "lmdb".
    """
    if db_name not in ("dbpedia", "lmdb"):
        raise ValueError("The database's name must be dbpedia or lmdb")
    db_path = path.join("data", db_name)

    # One flat id list replaces the two copy-pasted per-range loops.
    entity_ids = [i for start, end in zip(db_start, db_end) for i in range(start, end)]

    data, data_for_transE, data_graph = [], [], []
    for i in entity_ids:
        graph, _ = build_graph(db_path, i)
        per_entity_data = prepare_data(db_path, i)
        sub_data = [[pred_new, obj_new] for _, _, _, _, _, pred_new, obj_new in per_entity_data]
        # Pad so the row count can match the graph's node count used for the
        # adjacency matrix in ESGCN.
        n_missing = graph.number_of_nodes() - len(sub_data)
        sub_data.extend(['UNK', 'UNK'] for _ in range(n_missing))
        data.append(sub_data)
        data_for_transE.extend(
            [sub_new, obj_new, pred_new]
            for _, _, _, _, sub_new, pred_new, obj_new in per_entity_data)
        data_graph.append(graph)

    # Gold labels are read after all descriptions, in the same entity order.
    label = [prepare_label(db_path, i, top_n=top_n, file_n=file_n) for i in entity_ids]

    # entity dict: ids in first-seen order (subject before object)
    entity2ix = {'UNK': 0}
    for sub_new, obj_new, _ in data_for_transE:
        for entity in (sub_new, obj_new):
            if entity not in entity2ix:
                entity2ix[entity] = len(entity2ix)

    # pred dict: ids in first-seen order
    pred2ix = {'UNK': 0}
    for _, _, pred_new in data_for_transE:
        if pred_new not in pred2ix:
            pred2ix[pred_new] = len(pred2ix)

    return data, data_for_transE, label, entity2ix, pred2ix, data_graph

def gen_data_transE(db_name, entity_to_ix, pred_to_ix, data_for_transE):
    """Write TransE training files under ``data/<db_name>_transE/``.

    entity2id.txt and relation2id.txt hold a count line followed by
    ``name<TAB>id`` lines sorted by id. train2id.txt holds a count line followed
    by ``subject_id<TAB>object_id<TAB>predicate_id`` lines.

    Raises:
        ValueError: if db_name is not "dbpedia" or "lmdb".
    """
    if db_name not in ("dbpedia", "lmdb"):
        raise ValueError("The database's name must be dbpedia or lmdb")
    directory = path.join("data", db_name + "_transE")
    if not path.exists(directory):
        os.makedirs(directory)

    def dump_ids(file_name, name_to_id):
        with open(path.join(directory, file_name), "w", encoding="utf-8") as out:
            ordered = sorted(name_to_id.items(), key=lambda item: item[1])
            out.write("{}\n".format(len(name_to_id)))
            for name, idx in ordered:
                out.write("{}\t{}\n".format(name, idx))

    dump_ids("entity2id.txt", entity_to_ix)
    dump_ids("relation2id.txt", pred_to_ix)

    with open(path.join(directory, "train2id.txt"), "w", encoding="utf-8") as out:
        out.write("{}\n".format(len(data_for_transE)))
        for sub, obj, pred in data_for_transE:
            out.write("{}\t{}\t{}\n".format(entity_to_ix[sub], entity_to_ix[obj], pred_to_ix[pred]))

# load transE
def build_dict(f_path):
    """Read a ``name<TAB>id`` file (entity2id.txt / relation2id.txt) into a dict.

    Lines that do not split into a name and an integer id are printed and
    skipped. This includes the leading count line that ``gen_data_transE``
    writes.
    """
    word2ix = {}
    with open(f_path, "r", encoding="utf-8") as f:
        for pair in f:
            fields = pair.strip().split("\t")
            try:
                word2ix[fields[0]] = int(fields[1])
            except (IndexError, ValueError):
                # Narrowed from a bare ``except`` so KeyboardInterrupt and real
                # bugs are no longer swallowed.
                print(fields)
    return word2ix

def build_vec(word2ix, word_embedding):
    """Map every word to its embedding row ``word_embedding[int(word2ix[word])]``."""
    return {word: word_embedding[int(index)] for word, index in word2ix.items()}

def load_transE(db_name):
    """Load TransE vocabularies and embeddings for ``db_name``.

    Reads entity2id.txt and relation2id.txt plus transE_vec.npz from
    ``data/<db_name>_transE/``. The npz holds ``ent_embedding`` and
    ``rel_embedding`` arrays whose rows are indexed by those ids.

    Returns:
        (entity2vec, pred2vec, entity2ix, pred2ix)

    Raises:
        ValueError: if db_name is not "dbpedia" or "lmdb".
    """
    if db_name == "dbpedia":
        directory = path.join("data", "dbpedia_transE")
    elif db_name == "lmdb":
        directory = path.join("data", "lmdb_transE")
    else:
        raise ValueError("The database's name must be dbpedia or lmdb")

    entity2ix = build_dict(path.join(directory, "entity2id.txt"))
    pred2ix = build_dict(path.join(directory, "relation2id.txt"))

    # np.load on an .npz returns an NpzFile that keeps the archive open; the
    # context manager closes it once both arrays have been read into memory.
    with np.load(path.join(directory, "transE_vec.npz")) as embedding:
        entity_embedding = embedding["ent_embedding"]
        pred_embedding = embedding["rel_embedding"]

    entity2vec = build_vec(entity2ix, entity_embedding)
    pred2vec = build_vec(pred2ix, pred_embedding)
    return entity2vec, pred2vec, entity2ix, pred2ix

def tensor_from_data(entity2vec, pred2ix, data):
    """Turn ``[[pred, obj], ...]`` rows into model input tensors.

    Returns:
        pred_tensor: shape (n, 1), predicate ids looked up in ``pred2ix``.
        obj_tensor: object vectors from ``entity2vec`` with a dim inserted at
            axis 1, i.e. (n, 1, dim) when the vectors are 1-D.

    Raises:
        KeyError: if a predicate or object is missing from the lookups.
    """
    pred_list, obj_list = [], []
    for pred, obj in data:
        pred_list.append(pred2ix[pred])
        obj_list.append(entity2vec[obj])
    pred_tensor = torch.tensor(pred_list).view(-1, 1)
    # torch.tensor() on a list of ndarrays is very slow (PyTorch warns about it);
    # stacking into one array first gives the same values and dtype.
    if obj_list and all(isinstance(vec, np.ndarray) for vec in obj_list):
        obj_values = np.stack(obj_list)
    else:
        obj_values = obj_list
    obj_tensor = torch.tensor(obj_values).unsqueeze(1)
    return pred_tensor, obj_tensor

def tensor_from_weight(tensor_size, data, label):
    """Build the normalised gold-weight target for one entity.

    Args:
        tensor_size: length of the returned tensor.
        data: rows ``[pred, obj]``.
        label: dict ``"pred++$++obj" -> count`` (as built by ``prepare_label``).

    Returns:
        float tensor of shape (tensor_size,). The first row matching each label
        key receives that key's count, and the result is divided by the total.
        If no label matches any row, an all-zero tensor is returned; previously
        the 0/0 division produced NaN, which propagates into the loss.
    """
    weight_tensor = torch.zeros(tensor_size)
    # Index of the first row for each key (later duplicates get no weight, as
    # before); replaces the nested labels x rows scan.
    first_row = {}
    for order, (pred, obj) in enumerate(data):
        first_row.setdefault("{}++$++{}".format(pred, obj), order)
    for label_word, count in label.items():
        order = first_row.get(label_word)
        if order is not None:
            weight_tensor[order] += count
    total = torch.sum(weight_tensor)
    if total == 0:
        return weight_tensor
    return weight_tensor / total

# split data for cross validation
def split_data(base, num, data, label, data_graph):
    """Split parallel lists into train/test parts for cross-validation fold ``num``.

    The test part is indices ``[num*base, (num+1)*base)``; the train part is
    every other index, in original order.

    Returns:
        (train_data, train_label, train_graph, test_data, test_label, test_graph)
    """
    lo = num * base
    hi = (num + 1) * base
    held_out = range(lo, hi)

    def outside_fold(items):
        return [item for idx, item in enumerate(items) if idx not in held_out]

    return (outside_fold(data), outside_fold(label), outside_fold(data_graph),
            data[lo:hi], label[lo:hi], data_graph[lo:hi])

if __name__ == "__main__":
    # Intentionally a no-op. To (re)build the TransE training files:
    # dbpedia uses entities 1-100 and 141-165:
    # data, data_for_transE, label, entity2ix, pred2ix, data_graph = process_data("dbpedia", [1, 141], [101, 166])
    # gen_data_transE("dbpedia", entity2ix, pred2ix, data_for_transE)
    # lmdb uses entities 101-140 and 166-175:
    # data, data_for_transE, label, entity2ix, pred2ix, data_graph = process_data("lmdb", [101, 166], [141, 176])
    # gen_data_transE("lmdb", entity2ix, pred2ix, data_for_transE)
    pass

In [0]:
from collections import Counter
class RDFReader:
    """Context-manager wrapper around an rdflib graph parsed from an N-Triples file.

    Use as ``with RDFReader(path) as reader: ...``. On exit the graph's store
    is destroyed and closed.
    """

    __graph = None
    __freq = {}

    def __init__(self, file):
        # rdflib documentation: http://rdflib.readthedocs.io
        graph = rdf.Graph()
        graph.parse(file, format='nt')
        self.__graph = graph
        # Occurrence count of each predicate as yielded by graph.predicates().
        self.__freq = Counter(graph.predicates())

    def triples(self, relation=None):
        """Yield (s, p, o) for all triples, or only those whose predicate is ``relation``."""
        pattern = (None, relation, None)
        for subj, pred, obj in self.__graph.triples(pattern):
            yield subj, pred, obj

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        graph = self.__graph
        graph.destroy("store")
        graph.close(True)

    def subjectSet(self):
        """Distinct subjects of the graph."""
        return {subj for subj in self.__graph.subjects()}

    def objectSet(self):
        """Distinct objects of the graph."""
        return {obj for obj in self.__graph.objects()}

    def relationList(self):
        """Distinct predicates, most frequent first (ties keep set order)."""
        return sorted(set(self.__graph.predicates()), key=self.freq, reverse=True)

    def __len__(self):
        return len(self.__graph)

    def freq(self, relation):
        """Occurrence count of ``relation`` among the graph's predicates (0 if absent)."""
        return self.__freq[relation] if relation in self.__freq else 0


In [0]:
#data, data_for_transE, label, entity2ix, pred2ix, graph = process_data("dbpedia", [1, 141], [101, 166])
#gen_data_transE("dbpedia", entity2ix, pred2ix, data_for_transE)
#train_data, train_label, train_graph, _, _, _ = split_data(base, i, data, label, data_graph)

In [0]:
#import pandas as pd
#from sklearn.model_selection import train_test_split
#train_data,test_data = train_test_split(data_for_transE,test_size=0.2)
#test_data,valid_data = train_test_split(test_data,test_size=0.5)

In [0]:
#def gen_data_transE(db_name, entity_to_ix, pred_to_ix, data_for_transE):
#    # make dir
#    if db_name == "dbpedia":
#        directory = path.join(path.join("data"), "dbpedia_transE")
#    elif db_name == "lmdb":
#        directory = path.join(path.join("data"), "lmdb_transE")
#    else:
#        raise ValueError("The database's name must be dbpedia or lmdb")
#    if not path.exists(directory):
#        os.makedirs(directory)

#    with open(path.join(directory, "valid2id_.txt"), "w", encoding="utf-8") as f:    
        # train2id 
#        f.write("{}\n".format(len(data_for_transE)))
#        for [sub, obj, pred] in data_for_transE:
#            f.write("{}\t{}\t{}\n".format(entity_to_ix[sub], entity_to_ix[obj], pred_to_ix[pred]))

In [0]:
#gen_data_transE("dbpedia", entity2ix, pred2ix, valid_data)

In [0]:
#data, data_for_transE, label, entity2ix, pred2ix, graph = process_data("lmdb", [101, 166], [141, 176])
#train_data,test_data = train_test_split(data_for_transE,test_size=0.2)
#test_data,valid_data = train_test_split(test_data,test_size=0.5)
#gen_data_transE("lmdb", entity2ix, pred2ix, data_for_transE)

In [0]:
#gen_data_transE("lmdb", entity2ix, pred2ix, test_data)

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pygat.models import GAT, SpGAT
from pygat.utils import normalize_adj, normalize_features
import scipy.sparse as spT
from torch.autograd import Variable

class ESGCN(nn.Module):
    """Bi-LSTM encoder over (predicate, object) rows followed by a GAT scorer.

    Each row is embedded as [predicate embedding ; object vector]. The rows are
    run through a bidirectional LSTM, row-normalised, and scored by a GAT over
    the entity's graph.
    """

    def __init__(self, pred2ix_size, pred_embedding_dim, transE_dim, hidden_size, device):
        super(ESGCN, self).__init__()
        self.pred2ix_size = pred2ix_size
        self.pred_embedding_dim = pred_embedding_dim
        self.transE_dim = transE_dim
        self.input_size = self.transE_dim + self.pred_embedding_dim
        self.hidden_size = hidden_size
        print('hidden_size', hidden_size)
        self.embedding = nn.Embedding(self.pred2ix_size, self.pred_embedding_dim)
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, bidirectional=True)
        # The bidirectional LSTM emits 2 * hidden_size features per row. This was
        # hard-coded as 400, which only matched hidden_size == 200.
        self.gat = GAT(nfeat=2 * self.hidden_size, nhid=32, nclass=1, dropout=0.5, nheads=8, alpha=0.2)
        self.device = device
        # NOTE(review): this random (h0, c0) is not part of state_dict, so a model
        # rebuilt from a checkpoint starts from a different initial state than
        # the one that was trained. Consider zeros or a registered buffer.
        self.initial_hidden = self._init_hidden()

    def forward(self, input_tensor, G):
        """Score every row of one entity.

        Args:
            input_tensor: ``[pred_tensor, obj_tensor]`` as built by
                ``tensor_from_data`` (assumed shapes (n, 1) and (n, 1, transE_dim)).
            G: networkx graph whose adjacency defines the GAT neighbourhoods;
                assumed to have n nodes (TODO: confirm against callers).

        Returns:
            The GAT output for the n rows. NOTE(review): callers may apply
            BCELoss, which needs values in [0, 1]; confirm GAT's output activation.
        """
        pred_embedded = self.embedding(input_tensor[0])
        obj_embedded = input_tensor[1]
        embedded = torch.cat((pred_embedded, obj_embedded), 2)
        lstm_out, _ = self.lstm(embedded, self.initial_hidden)
        # (n, 1, 2 * hidden_size) -> (n, 2 * hidden_size)
        lstm_out = torch.flatten(lstm_out, start_dim=1)

        adj = nx.adjacency_matrix(G)
        # Symmetrise (a no-op for an undirected nx.Graph), add self-loops, normalise.
        adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
        adj = normalize_adj(adj + sp.eye(adj.shape[0]))
        adj = torch.FloatTensor(np.array(adj.todense())).to(lstm_out.device)

        # Normalise in torch instead of a numpy round-trip. The old
        # ``lstm_out.detach().numpy()`` cut the autograd graph, so the embedding
        # and LSTM never received gradients, and ``.numpy()`` fails on CUDA.
        features = self._row_normalize(lstm_out)

        logits = self.gat(features, adj)
        return logits

    @staticmethod
    def _row_normalize(mx):
        """Divide each row by its sum; rows summing to 0 become all-zero.

        Intended to mirror pygat's ``normalize_features`` (row-sum
        normalisation with 1/0 -> 0). TODO: verify against the local pygat copy.
        """
        row_sum = mx.sum(dim=1, keepdim=True)
        nonzero = row_sum != 0
        safe_sum = torch.where(nonzero, row_sum, torch.ones_like(row_sum))
        return mx / safe_sum * nonzero.to(mx.dtype)

    def _init_hidden(self):
        """Random (h0, c0), each of shape (2, 1, hidden_size), on ``self.device``."""
        return (torch.randn(2, 1, self.hidden_size, device=self.device),
                torch.randn(2, 1, self.hidden_size, device=self.device))

In [12]:
from __future__ import unicode_literals, print_function, division
import os
import os.path as path
import sys
import argparse
# Paths relative to the working directory; ROOTDIR is not referenced below.
ROOTDIR = path.dirname(os.getcwd())
DATADIR = "data"
# Presumably so that `import utils` can find a module in data/ (confirm).
sys.path.append(DATADIR)
import utils
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

def evaluate(model, g, features, labels):
    """Return the accuracy of ``model(g, features)`` against ``labels``.

    The prediction for each row is the argmax over dim 1. The model is put in
    eval mode and run without gradient tracking. The old body referenced an
    undefined name ``th`` and raised NameError.

    Note: not called anywhere in this notebook. Its ``model(g, features)``
    argument order differs from ``ESGCN.forward(input_tensor, G)``.
    """
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

def train(esgcn, data, label, criterion, optimizer, n_epoch, save_every, directory, device, clip, entity2vec, pred2ix, regularization, adj):
    """Train ``esgcn`` on one fold, taking one optimiser step per entity.

    Args:
        esgcn: model, called as ``esgcn(input_tensor, adj[i])``.
        data, label, adj: per-entity rows, gold-label dicts and graphs (same order).
        criterion: loss between predicted and gold weights.
        optimizer: optimiser over ``esgcn.parameters()``.
        n_epoch: number of passes over ``data``.
        save_every: save ``directory/checkpoint_epoch_<epoch>.pt`` whenever
            ``epoch % save_every == 0``.
        clip: maximum total gradient norm.
        regularization: if truthy, add the L1 norm of the predictions to the loss.
    """
    if not path.exists(directory):
        os.makedirs(directory)
    print('n epoch', n_epoch)
    for epoch in range(n_epoch):
        total_loss = 0
        for i in range(len(data)):
            esgcn.zero_grad()
            pred_tensor, obj_tensor = tensor_from_data(entity2vec, pred2ix, data[i])
            input_tensor = [pred_tensor.to(device), obj_tensor.to(device)]
            weight_tensor = tensor_from_weight(len(data[i]), data[i], label[i]).to(device)
            atten_weight = esgcn(input_tensor, adj[i])
            loss = criterion(atten_weight.view(-1), weight_tensor.view(-1))
            if regularization:
                loss = loss + torch.sum(torch.abs(atten_weight))

            loss.backward()
            # Clip *after* backward(). Previously clipping ran before this step's
            # gradients existed, so it never had any effect.
            nn.utils.clip_grad_norm_(esgcn.parameters(), clip)
            optimizer.step()
            total_loss += loss.item()

        total_loss = total_loss / len(data)
        if epoch % save_every == 0:
            torch.save({
                "epoch": epoch,
                "model_state_dict": esgcn.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": total_loss
                }, path.join(directory, "checkpoint_epoch_{}.pt".format(epoch)))
        print("epoch: {}".format(epoch), total_loss)

def train_iter(db_name, base, data, label, pred2ix, pred2ix_size, entity2vec, pred_embedding_dim, transe_dim, hidden_size, criterion, clip, lr, n_epoch, save_every, regularization, device, data_graph):
    """5-fold cross-validation: train a fresh ESGCN for each held-out fold.

    Fold k holds out entities ``[k*base, (k+1)*base)`` (see ``split_data``).
    Its checkpoints go to ``<cwd>/esgcn_checkpoint-<db_name>-<k>``.
    """
    if regularization == True:
        print("use regularization in training")
    for fold in range(5):
        train_data, train_label, train_graph, _, _, _ = split_data(base, fold, data, label, data_graph)
        model = ESGCN(pred2ix_size, pred_embedding_dim, transe_dim, hidden_size, device)
        model.to(device)
        print(model)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
        checkpoint_dir = os.path.join(os.getcwd(), "esgcn_checkpoint-{}-{}".format(db_name, fold))
        train(model, train_data, train_label, criterion, optimizer, n_epoch, save_every,
              checkpoint_dir, device, clip, entity2vec, pred2ix, regularization, train_graph)

def writer(DB_DIR, skip_i, directory, top_or_rank, output):
    """Write one entity's selected or ranked triples.

    Reads ``DB_DIR/<skip_i+1>/<skip_i+1>_desc.nt`` and writes
    ``directory/<skip_i+1>/<skip_i+1>_<top_or_rank>.nt``.

    Args:
        top_or_rank: "top5"/"top10" copy the lines whose 0-based index appears
            in ``output``, in file order. "rank" copies lines in the order given
            by ``output``. Any other value produces an empty file.
        output: tensor of shape (1, k) holding line indices (e.g. topk indices).
    """
    entity_id = skip_i + 1
    src = path.join(DB_DIR, "{}".format(entity_id), "{}_desc.nt".format(entity_id))
    dst = path.join(directory, "{}".format(entity_id), "{}_{}.nt".format(entity_id, top_or_rank))
    with open(src, encoding="utf8") as fin, open(dst, "w", encoding="utf8") as fout:
        if top_or_rank == "top5" or top_or_rank == "top10":
            top_list = output.squeeze(0).numpy().tolist()
            print('top list', top_list)
            selected = set(top_list)
            for t_num, triple in enumerate(fin):
                if t_num in selected:
                    fout.write(triple)
        elif top_or_rank == "rank":
            rank_list = output.squeeze(0).numpy().tolist()
            triples = fin.readlines()
            for rank in rank_list:
                # Indices can exceed the line count (e.g. 'UNK' padding rows added
                # in process_data). Skip those explicitly; the old bare
                # ``except: pass`` also hid every other error.
                if -len(triples) <= rank < len(triples):
                    fout.write(triples[rank])

def generator(DB_NAME, base, data, label, entity2vec, pred2ix, pred2ix_size, pred_embedding_dim, transE_dim, hidden_size, device, use_epoch, db_base, DB_DIR, skip_num, data_graph):
    """Write top-10/top-5/full-rank summaries for every entity using its fold's model.

    For each of the 5 folds, this loads
    ``esgcn_checkpoint-<DB_NAME>-<fold>/checkpoint_epoch_<use_epoch>.pt``,
    scores the fold's held-out entities, and writes the results under
    ``data_esa/<DB_NAME>/<entity id>/`` via ``writer``.

    Fold 4 is assumed to be the second id range, hence the extra ``skip_num``
    offset (e.g. dbpedia rows 100-124 map to entity ids 141-165).
    """
    directory = path.join("data_esa", DB_NAME)
    os.makedirs(directory, exist_ok=True)

    print("generating entity summarization results:")
    for num in tqdm(range(5)):
        CHECK_DIR = path.join("esgcn_checkpoint-{}-{}".format(DB_NAME, num))
        print("CHECK_DIR", CHECK_DIR)
        esgcn = ESGCN(pred2ix_size, pred_embedding_dim, transE_dim, hidden_size, device)
        # map_location lets checkpoints saved on CUDA load on a CPU-only machine.
        checkpoint = torch.load(path.join(CHECK_DIR, "checkpoint_epoch_{}.pt".format(use_epoch)),
                                map_location=device)
        esgcn.load_state_dict(checkpoint["model_state_dict"])
        esgcn.to(device)
        # Inference mode, so the GAT's dropout (0.5) is not applied, assuming GAT
        # honours self.training.
        esgcn.eval()
        # The fold split is the same for every entity in the fold; compute it once.
        _, _, _, test_data, _, test_graph = split_data(base, num, data, label, data_graph)
        for data_i, entity_data in enumerate(test_data):
            i = num * base + data_i
            pred_tensor, obj_tensor = tensor_from_data(entity2vec, pred2ix, entity_data)
            input_tensor = [pred_tensor.to(device), obj_tensor.to(device)]
            with torch.no_grad():
                atten_weight = esgcn(input_tensor, test_graph[data_i])
            atten_weight = atten_weight.view(1, -1).cpu()
            n_scores = atten_weight.size(1)
            # topk raises when k exceeds the number of scores, e.g. entities with
            # fewer than 10 rows.
            (_, output_top10) = torch.topk(atten_weight, min(10, n_scores))
            (_, output_top5) = torch.topk(atten_weight, min(5, n_scores))
            (_, output_rank) = torch.topk(atten_weight, min(len(entity_data), n_scores))
            if num == 4:
                skip_i = i + skip_num + db_base
            else:
                skip_i = i + db_base
            os.makedirs(path.join(directory, "{}".format(skip_i + 1)), exist_ok=True)
            writer(DB_DIR, skip_i, directory, "top10", output_top10)
            writer(DB_DIR, skip_i, directory, "top5", output_top5)
            writer(DB_DIR, skip_i, directory, "rank", output_rank)

def main(DB_NAME, mode, top_n, file_n, transE_dim, pred_embedding_dim, lr, clip, save_every, n_epoch, use_epoch, loss_function, regularization):
    """Load a benchmark, then train with 5-fold CV and/or generate summaries.

    Args:
        DB_NAME: "dbpedia" or "lmdb".
        mode: "train", "test" (generate from saved checkpoints) or "all".
        top_n, file_n: which gold-label files to read (see ``prepare_label``).
        transE_dim, pred_embedding_dim: input feature sizes; the LSTM hidden
            size is their sum.
        lr, clip, save_every, n_epoch, regularization: training settings.
        use_epoch: checkpoint epoch loaded in "test"/"all" mode.
        loss_function: "BCE" or "MSE".

    Raises:
        ValueError: for an unknown DB_NAME, loss_function or mode. These are
            checked before any data is loaded; the old code hit
            UnboundLocalError or called sys.exit() inside the notebook.
    """
    # DB name -> (range starts, range ends, fold size, extra id offset of fold 4, id offset)
    db_settings = {
        "dbpedia": ([1, 141], [101, 166], 25, 40, 0),
        "lmdb": ([101, 166], [141, 176], 10, 25, 100),
    }
    if DB_NAME not in db_settings:
        raise ValueError("The database's name must be dbpedia or lmdb")
    loss_classes = {"BCE": torch.nn.BCELoss, "MSE": torch.nn.MSELoss}
    if loss_function not in loss_classes:
        raise ValueError("please choose the correct loss function: BCE or MSE")
    if mode not in ("train", "test", "all"):
        raise ValueError('mode must be "train", "test" or "all"')

    DB_START, DB_END, base, skip_num, db_base = db_settings[DB_NAME]
    print("training model on {}".format(DB_NAME))
    DB_DIR = path.join(DATADIR, DB_NAME)
    print("DB_DIR", DB_DIR)

    # load data
    data, _, label, _, _, data_graph = process_data(DB_NAME, DB_START, DB_END, top_n, file_n)
    entity2vec, _, _, pred2ix = load_transE(DB_NAME)
    pred2ix_size = len(pred2ix)
    hidden_size = transE_dim + pred_embedding_dim

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("cuda or cpu: {}".format(device))

    criterion = loss_classes[loss_function]()
    print("loss function: {}".format(loss_function))

    if mode in ("train", "all"):
        # 5-fold cross-validation training
        train_iter(DB_NAME, base, data, label, pred2ix, pred2ix_size, entity2vec, pred_embedding_dim, transE_dim, hidden_size, criterion, clip, lr, n_epoch, save_every, regularization, device, data_graph)

    if mode in ("test", "all"):
        generator(DB_NAME, base, data, label, entity2vec, pred2ix, pred2ix_size, pred_embedding_dim, transE_dim, hidden_size, device, use_epoch, db_base, DB_DIR, skip_num, data_graph)

In [0]:
# Train the 5 fold models on DBpedia (gold top-5 summaries, files 0-5 per entity).
main("dbpedia", mode="train", top_n=5, file_n=6,
     transE_dim=100, pred_embedding_dim=100,
     lr=0.01, clip=50, save_every=2, n_epoch=50, use_epoch=24,
     loss_function="BCE", regularization=False)

training model on dbpedia
DB_DIR data/dbpedia


In [0]:
# Generate DBpedia summaries from the epoch-24 checkpoint of each fold model.
main("dbpedia", mode="test", top_n=5, file_n=6,
     transE_dim=100, pred_embedding_dim=100,
     lr=0.01, clip=50, save_every=2, n_epoch=50, use_epoch=24,
     loss_function="BCE", regularization=False)