In [1]:
import tensorflow as tf

class AggrE(object):
    def __init__(self,
                 sampled_contexts,
                 masks,
                 hops,
                 epoch=20,
                 batch_size=128,
                 dim=64,
                 l2=1e-7,
                 lr=5e-3,
                 negtive_num=-1):
        
        self.epoch = epoch
        self.batch_size = batch_size
        self.output_dim = dim
        self.l2 = l2
        self.lr = lr
        self.negtive_num = negtive_num
        self.e_sampled_contexts = sampled_contexts[0]
        self.masks = masks[0]
        self.r_sampled_contexts = sampled_contexts[1]
        self.r_masks = masks[1]
        self.hops = hops

        self.n_entities = n_entity
        self.n_relations = n_relation

        self._build_inputs()
        self._build_embedding()  
        self._build_aggre()
        self._build_train()
        self._build_eval()

    def _build_inputs(self):

        self.heads = tf.placeholder(tf.int32, [self.batch_size], name='heads')
        self.relations = tf.placeholder(tf.int32, [self.batch_size], name='relations')
        self.tails = tf.placeholder(tf.int32, [self.batch_size], name='labels')
        self.is_training = tf.placeholder(tf.bool, name="is_training")

        self.labels = tf.expand_dims(self.relations, -1)

    def _build_embedding(self):
        
        self.entities_emb = tf.get_variable(
            'entities', [self.n_entities, self.output_dim], tf.float32,
            tf.contrib.layers.xavier_initializer())
        self.relations_emb = tf.get_variable(
            'relations', [self.n_relations, self.output_dim], tf.float32,
            tf.contrib.layers.xavier_initializer())

    def _build_aggre(self):
        
        self.hn_e = self.e_sampled_contexts[:, :, 0]
        self.hn_r = self.e_sampled_contexts[:, :, 1]
        
        self.e_embs = []
        self.e_embs.append(self.entities_emb)
        
        self.rn_h = self.r_sampled_contexts[:, :, 0]
        self.rn_t = self.r_sampled_contexts[:, :, 1]
        
        self.r_embs = []
        self.r_embs.append(self.relations_emb)
        
        #self.output_keep_prob = 1
        for i in range(self.hops):
                        
            self.en_h_emb = tf.nn.embedding_lookup(self.e_embs[-1], self.hn_e)  # n_entity, 4, 64
            self.en_r_emb = tf.nn.embedding_lookup(self.r_embs[-1], self.hn_r)  # n_entity, 4, 64  
            self.e_context_info = self.e_context_aggregation(self.en_h_emb, self.en_r_emb, self.masks, i)
            #self.e_context_info = tf.layers.dropout(self.e_context_info, 1-self.output_keep_prob, training = self.is_training)
            self.new_e_emb = self.e_update(self.e_embs[-1], self.e_context_info, i)
            
            
            self.rn_h_emb = tf.nn.embedding_lookup(self.e_embs[-1], self.rn_h)
            self.rn_t_emb = tf.nn.embedding_lookup(self.e_embs[-1], self.rn_t)
            self.r_context_info = self.r_context_aggregation(self.rn_h_emb, self.rn_t_emb, self.r_masks, i)
            #self.r_context_info = tf.layers.dropout(self.r_context_info, 1-self.output_keep_prob, training = self.is_training)
            self.new_r_emb = self.r_update(self.r_embs[-1], self.r_context_info, i)
            
            self.r_embs.append(self.new_r_emb)
            self.e_embs.append(self.new_e_emb)
                
        self.h_emb = tf.nn.embedding_lookup(self.e_embs[-1], self.heads)
        self.r_emb = tf.nn.embedding_lookup(self.r_embs[-1], self.relations)
        self.t_emb = tf.nn.embedding_lookup(self.e_embs[-1], self.tails)

        self.q_emb = self.h_emb * self.t_emb
        self.b = tf.get_variable('bias', [self.n_relations], tf.float32,tf.contrib.layers.xavier_initializer())
        self.scores = tf.matmul(self.q_emb, tf.transpose(self.r_embs[-1])) + self.b


    def _build_train(self):
        if self.negtive_num != -1:
            self.base_loss = tf.reduce_mean(
                tf.nn.sampled_softmax_loss(weights=self.r_embs[-1],
                                           biases=self.b,
                                           labels=self.labels,
                                           inputs=self.q_emb,
                                           num_sampled=self.negtive_num,
                                           num_classes=self.n_relations))
        else:
            self.base_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.relations, logits=self.scores))

        self.l2_loss = self.l2 * sum(tf.nn.l2_loss(var)
            for var in tf.trainable_variables() if 'bias' not in var.name)
        self.loss = self.base_loss + self.l2_loss
        self.optimizer = tf.train.AdamOptimizer(self.lr).minimize(self.loss)

    def _build_eval(self):
        self.scores_normalized = tf.nn.sigmoid(self.scores)
        correct_predictions = tf.equal(self.relations, tf.cast(tf.argmax(self.scores, axis=-1), tf.int32))
        self.acc = tf.reduce_mean(tf.cast(correct_predictions, tf.float64))

    def train(self, sess, feed_dict):
        return sess.run([self.optimizer, self.loss, self.acc], feed_dict)

    def _eval(self, sess, feed_dict):
        return sess.run([self.acc, self.scores_normalized], feed_dict)
    
    def r_context_aggregation(self, rn_t_emb, rn_h_emb, r_masks, hopi):
        context_info = rn_h_emb * rn_t_emb
        weight = tf.nn.softmax(tf.reduce_sum(context_info * tf.expand_dims(self.r_embs[-1],1), -1, True) * r_masks, 1)
        return tf.reduce_sum(context_info * weight, 1)

    def r_update(self, r_embs, r_context_info, hopi):
        return r_embs + r_context_info 
    
    def e_context_aggregation(self, en_h_emb, hn_r_emb, e_masks, hopi):
        context_info = en_h_emb * hn_r_emb
        weight = tf.nn.softmax(tf.reduce_sum(context_info * tf.expand_dims(self.e_embs[-1],1), -1, True) * e_masks, 1)
        return tf.reduce_sum(context_info * weight, 1)
    
    def e_update(self, e_embs, e_context_info, hopi):
        return e_embs + e_context_info
    

In [2]:
import os
import re
import pickle
import numpy as np
import pickle
import math


def dump_data(obj, wfpath, wfname):
    with open(os.path.join(wfpath, wfname), 'wb') as wf:
        pickle.dump(obj, wf)


def load_file(rfpath, rfname):
    with open(os.path.join(rfpath, rfname), 'rb') as rf:
        return pickle.load(rf)


def dim_factorization(d):
    half = int(math.sqrt(d)) + 1
    while d % half > 0:
        half -= 1
    x = half
    y = d // half
    assert x * y == d
    print("dim factorization", x, y)
    return x, y


def read_entities(file_name):
    d = {}
    file = open(file_name)
    for line in file:
        index, name = line.strip().split('\t')
        d[name] = int(index)
    file.close()

    return d


def read_relations(file_name):
    d = {}
    file = open(file_name)
    for line in file:
        index, name = line.strip().split('\t')
        d[name] = int(index)
    file.close()

    return d


def read_triplets(file_name):
    data = []

    file = open(file_name)
    for line in file:
        head, relation, tail = line.strip().split('\t')

        head_idx = entity_dict[head]
        relation_idx = relation_dict[relation]
        tail_idx = entity_dict[tail]

        data.append((head_idx, relation_idx, tail_idx))
        data.append((tail_idx, relation_idx, head_idx))

    file.close()

    return data

    
def load_data(dataset, context_samples_num):
    global entity_dict, relation_dict
    global n_relation, n_entity
    directory = 'data/' + dataset + '/'

    print('reading entity dict and relation dict ...')
    entity_dict = read_entities(directory + 'entities.dict')
    relation_dict = read_relations(directory + 'relations.dict')
    
    n_entity = len(entity_dict)
    n_relation = len(relation_dict)
        
    print('entitiy  num:', n_entity)
    print('relation num:', n_relation)

    print('reading train, validation, and test data ...')
    train_triplets = read_triplets(directory + 'train.txt')
    valid_triplets = read_triplets(directory + 'valid.txt')
    test_triplets = read_triplets(directory + 'test.txt')

    triplets = [train_triplets, valid_triplets, test_triplets]

    print('sampling contexts ...')
    e_contexts = dict()
    r_contexts = dict()
    for (head_idx, relation_idx, tail_idx) in train_triplets:
        if tail_idx not in e_contexts:
            e_contexts[tail_idx] = []
        e_contexts[tail_idx].append([head_idx, relation_idx])
        
        if relation_idx not in r_contexts:
            r_contexts[relation_idx] = []
        r_contexts[relation_idx].append([head_idx, tail_idx])
    
    e_sampled_contexts = []
    e_masks = []
    for h in range(n_entity):
        mask = []
        if h in e_contexts:
            context_list = e_contexts[h]
            if len(context_list) >= context_samples_num[0]:
                idxs = np.random.choice(len(context_list), size=context_samples_num[0], replace=False) 
                e_ns = np.array(context_list)[idxs]
                for n in e_ns:
                    mask.append([1])
            else:
                lenth = context_samples_num[0] - len(context_list)
                e_ns = np.pad(context_list, [[0,lenth],[0,0]])
                for ii,n in enumerate(e_ns):
                    if ii <len(context_list):
                        mask.append([1])
                    else:
                        mask.append([0])
        else:
            e_ns = np.zeros((context_samples_num[0],2),dtype=np.int32)
            mask = [[0]]*context_samples_num[0]
            
        e_sampled_contexts.append(e_ns)
        e_masks.append(mask)
    
    r_sampled_contexts = []
    r_masks = []
    for r in range(n_relation):
        mask = []
        if r in r_contexts:
            context_list = r_contexts[r]
            if len(context_list) >= context_samples_num[1]:
                idxs = np.random.choice(len(context_list), size=context_samples_num[1], replace=False) 
                r_ns = np.array(context_list)[idxs]
                for n in r_ns:
                    mask.append([1])
            else:
                lenth = context_samples_num[1] - len(context_list)
                r_ns = np.pad(context_list, [[0,lenth],[0,0]])
                for ii,n in enumerate(r_ns):
                    if ii <len(context_list):
                        mask.append([1])
                    else:
                        mask.append([0])
        else:
            r_ns = np.zeros((context_samples_num[1],2),dtype=np.int32)
            mask = [[0]]*context_samples_num[1]
            
        r_sampled_contexts.append(r_ns)
        r_masks.append(mask)

    return triplets, (np.array(e_sampled_contexts,dtype=np.int32), np.array(r_sampled_contexts,dtype=np.int32)), (e_masks, r_masks)

def evaluate(entity_pairs, labels, return_score=False):
    acc_list = []
    if return_score:
        scores_list = []
        
    s = 0
    while s + model.batch_size <= len(labels):
        acc, scores = model._eval(
            sess,
            get_feed_dict(entity_pairs,
                          labels,
                          s,
                          s + model.batch_size,
                          training=False))
        acc_list.append(acc)
        if return_score:
            scores_list.extend(scores)
        s += model.batch_size

    if return_score:
        return float(np.mean(acc_list)), np.array(scores_list, np.float32)
    else:
        return float(np.mean(acc_list))


def calculate_ranking_metrics(triplets, scores, true_relations):
    for i in range(scores.shape[0]):
        head, relation, tail = triplets[i]
        for j in true_relations[head, tail] - {relation}:
            scores[i, j] -= 1.0

    sorted_indices = np.argsort(-scores, axis=1)
    relations = np.array(triplets)[0:scores.shape[0], 1]
    sorted_indices -= np.expand_dims(relations, 1)
    zero_coordinates = np.argwhere(sorted_indices == 0)
    rankings = zero_coordinates[:, 1] + 1

    mrr = float(np.mean(1 / rankings))
    mr = float(np.mean(rankings))
    hit1 = float(np.mean(rankings <= 1))
    hit3 = float(np.mean(rankings <= 3))
    hit5 = float(np.mean(rankings <= 5))
    hit10 = float(np.mean(rankings <= 10))

    return mrr, mr, hit1, hit3, hit5, hit10


def get_feed_dict(entity_pairs, labels, start, end, training):
    feed_dict = {}
    feed_dict[model.heads] = entity_pairs[start:end, 0]
    feed_dict[model.tails] = entity_pairs[start:end, 1]
    feed_dict[model.relations] = labels[start:end]
    feed_dict[model.is_training] = training

    return feed_dict


In [3]:
import numpy as np
from collections import defaultdict
global model, sess

triplets, sampled_contexts, masks = load_data('wn18rr', context_samples_num = [4,32]) 
#triplets, sampled_contexts, masks = load_data('FB15k-237', context_samples_num = [8,4])

train_triplets, valid_triplets, test_triplets = triplets

train_entity_pairs = np.array([[triplet[0], triplet[2]]
                               for triplet in train_triplets], np.int32)
valid_entity_pairs = np.array([[triplet[0], triplet[2]]
                               for triplet in valid_triplets], np.int32)
test_entity_pairs = np.array([[triplet[0], triplet[2]]
                              for triplet in test_triplets], np.int32)

train_labels = np.array([triplet[1] for triplet in train_triplets], np.int32)
valid_labels = np.array([triplet[1] for triplet in valid_triplets], np.int32)
test_labels  = np.array([triplet[1] for triplet in test_triplets], np.int32)

# prepare for top-k evaluation
true_relations = defaultdict(set)
for head, relation, tail in train_triplets + valid_triplets + test_triplets:
    true_relations[(head, tail)].add(relation)

model = AggrE(epoch=20,
             batch_size=512,
             dim=256,
             l2=1e-7,
             lr=5e-3,
             negtive_num=-1,
             sampled_contexts=sampled_contexts,
             masks = masks,
             hops=2)
'''
model = AggrE(epoch=20,
             batch_size=1024,
             dim=256,
             l2=1e-6,
             lr=5e-3,
             negtive_num=-1,
             sampled_contexts=sampled_contexts,
             masks = masks,
             hops=4) #FB15k-237
'''

best_valid_acc = 0.0
final_res = None  # acc, mrr, mr, hit1, hit3, hit5

with tf.Session() as sess:
    print('start training ...')
    sess.run(tf.global_variables_initializer())

    for step in range(model.epoch):

        # shuffle training data
        index = np.arange(len(train_labels))
        np.random.shuffle(index)
        train_entity_pairs = train_entity_pairs[index]
        train_labels = train_labels[index]

        # training
        s = 0
        while s + model.batch_size <= len(train_labels):
            _, loss, acc = model.train(
                sess,
                get_feed_dict(train_entity_pairs,
                              train_labels,
                              s,
                              s + model.batch_size,
                              training=True))
            s += model.batch_size

            if s % (model.batch_size*300) == 0:
                # evaluation
                print('epoch %2d   ' % step, end='')
                train_acc = evaluate(train_entity_pairs, train_labels)
                valid_acc = evaluate(valid_entity_pairs, valid_labels)
                test_acc, test_scores = evaluate(test_entity_pairs, test_labels, return_score=True)

                # show evaluation result for current epoch
                current_res = 'acc: %.4f' % test_acc
                print('train acc: %.4f   valid acc: %.4f   test acc: %.4f' % (train_acc, valid_acc, test_acc))

                mrr, mr, hit1, hit3, hit5, hit10 = calculate_ranking_metrics(
                    test_triplets, test_scores, true_relations)
                current_res += '   mrr: %.4f   mr: %.4f   h1: %.4f   h3: %.4f   h5: %.4f   h10: %.4f' % (
                    mrr, mr, hit1, hit3, hit5, hit10)
                print('           mrr: %.4f   mr: %.4f   h1: %.4f   h3: %.4f   h5: %.4f   h10: %.4f'
                    % (mrr, mr, hit1, hit3, hit5, hit10))
                print()
                              
                # update final results according 2to validation accuracy
                if valid_acc > best_valid_acc:
                    best_valid_acc = valid_acc
                    final_res = current_res

    # show final evaluation result
    print('final results\n%s' % final_res)

reading entity dict and relation dict ...
entitiy  num: 14541
relation num: 237
reading train, validation, and test data ...
sampling contexts ...
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

start training ...
epoch  0   train acc: 0.7816   valid acc: 0.9189   test acc: 0.9181
           mrr: 0.9566   mr: 1.2269   h1: 0.9250   h3: 0.9881   h5: 0.9949   h10: 0.9975

epoch  1   train acc: 0.8092   valid acc: 0.9160   test acc: 0.9161
           mrr: 0.9607   mr: 1.1785   h1: 0.9328   h3: 0.9884   h5: 0.9950   h10: 0.9982

epoch  2   train acc: 0.8163   valid acc: 0.9184   test acc: 0.9191
           mrr: 0.9630   mr: 1.1833   h1: 0.9376   h3: 0.9877   h5: 0.9940   h10: 0.99

KeyboardInterrupt: 