In [1]:
from model import *
from reader import *

In [2]:
# Switch for the relation enhancement method (read by joint_eval / joint_eval_raw).
use_jeval = True

# Option container; its fields are filled in by the following cells.
opts = Options()

# Allocate GPU memory on demand instead of reserving it all at start-up.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.InteractiveSession(config=config)

In [3]:
# Hyper-parameters stored on the shared option object.
# NOTE(review): meanings are inferred from the names only (e.g. keep_prob is
# presumably the dropout keep probability) -- confirm against model.py.
opts.hidden_size = 512
opts.num_layers = 2
opts.keep_prob = 0.5
opts.num_samples = 2048 * 3
opts.learning_rate = 0.001

#### You have to select one of the three datasets.
#### The default dataset is FB15K-237 (datasets\[0\]);
#### you can also select either of the other two datasets.


In [4]:
# Select one dataset by changing the index used below.
datasets = [
    '237',    # FB15K-237 (default)
    'FB15K',
    'WN18',
]
used_dataset = datasets[0]

In [5]:
# Per-dataset settings: log filename prefix, data directory and model class
# (different datasets use different data parsers).
if used_dataset == '237':
    file_name = '237-dskg-hs512'
    opts.data_path = 'data/FB15k-237/'
    model = FBRespective(opts, sess)
elif used_dataset == 'FB15K':
    file_name = 'fb-dskg-hs512'
    # FB15K must not read the FB15k-237 directory, otherwise this branch
    # silently trains on the wrong dataset under the 'fb-...' log name.
    # NOTE(review): directory name assumed to be 'data/FB15k/' -- confirm.
    opts.data_path = 'data/FB15k/'
    model = FBRespective(opts, sess)
elif used_dataset == 'WN18':
    file_name = 'wn-dskg-hs512'
    opts.data_path = 'data/wordnet-mlj12/'
    model = WNRespective(opts, sess)
else:
    # Fail here rather than with a NameError on `model` in a later cell.
    raise ValueError('unknown dataset: %r' % (used_dataset,))

load file from local
start gen filter mat
Instructions for updating:
Please switch to tf.train.get_or_create_global_step


In [6]:
#calculate ranks
def cal_ranks(probs, method, label):
    """Return, for every row of ``probs``, the rank of the column given by ``label``.

    probs:  2-D score matrix, one row per sample.
    method: 'min' counts the scores in the row that are strictly greater than
            the label's score and adds one. Any other value is handed to
            ``pandas.DataFrame.rank`` as its tie-breaking ``method``, ranking
            each row in descending order.
    label:  one column index per row.
    """
    rows = np.arange(len(label))

    if method != 'min':
        ranked = pd.DataFrame(probs).rank(axis=1, ascending=False, method=method)
        return ranked.values[rows, label]

    # Shift every row so the label's score sits at zero; positives beat the label.
    label_scores = probs[rows, label].reshape(len(probs), 1)
    return ((probs - label_scores) > 0).sum(axis=1) + 1

#calculate performance
def cal_performance(ranks, top=10):
    """Summarise an array of ranks.

    Returns a tuple ``(mean rank, fraction of ranks <= top, mean reciprocal rank)``.
    ``ranks`` must be a numpy array, since it is compared and divided element-wise.
    """
    total = len(ranks)
    mean_rank = sum(ranks) * 1.0 / total
    hits_at_top = sum(ranks <= top) * 1.0 / total
    mean_reciprocal_rank = (1. / ranks).sum() / total
    return mean_rank, hits_at_top, mean_reciprocal_rank

def eval_entity_prediction(model, data, filter_mat, method='min', return_ranks=False, return_probs=False, return_label_probs=False):
    """Run entity prediction over ``data`` and return probabilities, ranks or metrics.

    ``data`` is read as column 0 -> fed to ``model._eval_e``, column 1 -> fed to
    ``model._eval_r``, column 2 -> label. The rows are padded with
    ``model.padding_data`` so they split into whole batches of
    ``model._options.batch_size``; ``model._entity_probs`` is fetched per batch
    through the module-level session ``sess`` and the concatenated result is cut
    back by ``padding_num`` rows.

    The flags are checked in this order:
      * return_label_probs: the probability of each row's label
      * return_probs:       the full probability matrix
      * return_ranks:       the filtered ranks of the labels
      * otherwise:          (m_r, h_10, mrr, f_m_r, f_h_10, f_mrr), raw metrics
                            followed by filtered metrics

    Filtered scores are ``probs * filter_mat`` with each label's own score restored.
    """
    batch_size = model._options.batch_size

    label = data[:, 2]
    rows = np.arange(len(label))

    padded, padding_num = model.padding_data(data)
    first_col, second_col = padded[:, 0], padded[:, 1]

    batch_probs = []
    for batch_idx in range(len(padded) // batch_size):
        start = batch_idx * batch_size
        stop = start + batch_size
        feed_dict = {
            model._eval_e: first_col[start:stop],
            model._eval_r: second_col[start:stop],
        }
        batch_probs.append(sess.run(model._entity_probs, feed_dict))
    probs = np.concatenate(batch_probs)[:len(padded) - padding_num]

    if return_label_probs:
        return probs[rows, label]

    if return_probs:
        return probs

    # Apply the filter, then put the label's own score back.
    filter_probs = probs * filter_mat
    filter_probs[rows, label] = probs[rows, label]
    filter_ranks = cal_ranks(filter_probs, method=method, label=label)
    if return_ranks:
        return filter_ranks

    raw_ranks = cal_ranks(probs, method=method, label=label)
    return cal_performance(raw_ranks) + cal_performance(filter_ranks)

def eval_relation_prediction(model, data, filter_mat, method='min', return_ranks=False, return_probs=False):
    """Return the relation probability matrix for the entities in column 0 of ``data``.

    Column 0 of ``data`` is fed to ``model._eval_e`` in batches of
    ``model._options.batch_size`` (after padding with ``model.padding_data``),
    ``model._relation_probs`` is fetched through the module-level session
    ``sess``, and the concatenated result is cut back by ``padding_num`` rows.

    ``filter_mat``, ``method``, ``return_ranks`` and ``return_probs`` are accepted
    for signature parity with ``eval_entity_prediction`` but are ignored: the
    probability matrix is always returned.
    """
    batch_size = model._options.batch_size

    padded, padding_num = model.padding_data(data)
    entities = padded[:, 0]

    batch_probs = []
    for batch_idx in range(len(padded) // batch_size):
        start = batch_idx * batch_size
        feed_dict = {model._eval_e: entities[start:start + batch_size]}
        batch_probs.append(sess.run(model._relation_probs, feed_dict))

    return np.concatenate(batch_probs)[:len(padded) - padding_num]


In [7]:
# Preprocess data: pull the (h_id, r_id, t_id) columns of every split as arrays.
train_data = model._train_data[['h_id', 'r_id', 't_id']].values
valid_data = model._valid_data[['h_id', 'r_id', 't_id']].values
test_data = np.array(model._test_data[['h_id', 'r_id', 't_id']].values)

# Filter matrices used for the filtered rankings on the test / validation split.
filter_mat = model._tail_test_filter_mat
vfilter_mat = model._tail_valid_filter_mat

all_data = np.concatenate([train_data, test_data, valid_data])
p_data = np.concatenate([test_data, valid_data])

def gen_rev_rel(test_data):
    """Return column 1 of ``test_data`` with its two halves swapped.

    The rows are split at ``len(test_data) // 2``; the result holds the
    relation ids (column 1) of the second part followed by those of the first.

    NOTE(review): presumably the first half holds forward triples and the
    second half their reversed counterparts, so entry k is the relation of
    row k's counterpart -- confirm against the data parser. With an odd number
    of rows the second part is one row longer.
    """
    half = len(test_data) // 2
    relations = test_data[:, 1]
    return np.concatenate([relations[half:], relations[:half]])

# Half-swapped relation ids for the test and validation splits.
rev_rel = gen_rev_rel(test_data)
vrev_rel = gen_rev_rel(valid_data)

# One row per entity id, repeated in both columns: (0, 0), (1, 1), ...
rev_rel_test_data = np.repeat(np.arange(model._entity_num).reshape(-1, 1), 2, axis=1)

In [8]:
def cal_r(probs, label, filter_mat):
    """Return the filtered 'min' ranks of ``label`` within ``probs``.

    Scores are multiplied by ``filter_mat``, each label's own score is then
    restored, and the result is ranked with ``cal_ranks(method='min')``.
    ``probs`` itself is not modified.
    """
    rows = np.arange(len(label))

    masked = probs * filter_mat
    masked[rows, label] = probs[rows, label]

    return cal_ranks(masked, method='min', label=label)


def joint_eval(test_data, filter_mat, rev_rel):
    """Filtered ranks with and without the relation enhancement method.

    Returns ``(joint_ranks, entity_ranks)``. ``entity_ranks`` comes from the
    entity-prediction probabilities alone. When the module-level flag
    ``use_jeval`` is set, ``joint_ranks`` ranks those probabilities multiplied
    by relation-based weights; otherwise both entries are the same array.

    The weights are the relation probabilities predicted for every row of the
    module-level ``rev_rel_test_data``, transposed, raised to the power 0.33
    and indexed by ``rev_rel`` so there is one weight row per test row.
    """
    label = test_data[:, 2]

    entity_probs = eval_entity_prediction(model, data=test_data, filter_mat=filter_mat, return_probs=True)
    entity_ranks = cal_r(entity_probs, label, filter_mat)

    if not use_jeval:
        return entity_ranks, entity_ranks

    relation_probs = eval_relation_prediction(model, rev_rel_test_data, filter_mat=None, return_probs=True).T
    weights = (relation_probs ** 0.33)[rev_rel]
    joint_ranks = cal_r(entity_probs * weights, label, filter_mat)
    return joint_ranks, entity_ranks

def joint_eval_raw(test_data, filter_mat, rev_rel):
    """Raw (unfiltered) ranks with and without the relation enhancement method.

    Returns ``(joint_ranks, entity_ranks)``, both computed with
    ``cal_ranks(method='min')`` directly on the probabilities. ``filter_mat``
    is only forwarded to ``eval_entity_prediction``. When the module-level flag
    ``use_jeval`` is off, both entries are the same array.

    The relation weights are built as in ``joint_eval``: relation probabilities
    for every row of ``rev_rel_test_data``, transposed, raised to the power
    0.33 and indexed by ``rev_rel``.
    """
    label = test_data[:, 2]

    entity_probs = eval_entity_prediction(model, data=test_data, filter_mat=filter_mat, return_probs=True)
    entity_ranks = cal_ranks(entity_probs, method='min', label=label)

    if not use_jeval:
        return entity_ranks, entity_ranks

    relation_probs = eval_relation_prediction(model, rev_rel_test_data, filter_mat=None, return_probs=True).T
    weights = (relation_probs ** 0.33)[rev_rel]
    joint_ranks = cal_ranks(entity_probs * weights, method='min', label=label)
    return joint_ranks, entity_ranks

def process_ranks(efr, i=0, last_mean_loss=1000, title=''):
    """Print one summary line for the ranks ``efr`` and return the figures.

    Returns ``(i, Hits@1, Hits@10, MR, MRR, last_mean_loss)``. ``title`` is
    left-aligned in a 15-character column at the start of the printed line.
    """
    mean_rank, hits_1, mrr = cal_performance(efr, top=1)
    hits_10 = cal_performance(efr, top=10)[1]

    report = '%s epoch:%i, Hits@1:%.3f, Hits@10:%.3f, MR:%.3f, MRR:%.3f, mean_loss:%.3f' % (
        title.ljust(15), i, hits_1, hits_10, mean_rank, mrr, last_mean_loss)
    print(report)

    return (i, hits_1, hits_10, mean_rank, mrr, last_mean_loss)

def _eval_split(split, data, split_filter_mat, split_rev_rel, i, last_mean_loss):
    """Evaluate one split, print its four result lines, return the filtered rows.

    Lines are printed in the order raw, raw + RH, filtered, filtered + RH, with
    titles built from ``split`` (e.g. 'Valid-R', 'Valid-F-RH'). Returns
    ``(filtered_row, filtered_rh_row)`` as produced by ``process_ranks``.
    """
    jfr, efr = joint_eval(test_data=data, filter_mat=split_filter_mat, rev_rel=split_rev_rel)
    jrr, rr = joint_eval_raw(test_data=data, filter_mat=split_filter_mat, rev_rel=split_rev_rel)

    process_ranks(rr, i, last_mean_loss, title=split + '-R')
    process_ranks(jrr, i, last_mean_loss, title=split + '-R-RH')

    msg = process_ranks(efr, i, last_mean_loss, title=split + '-F')
    jmsg = process_ranks(jfr, i, last_mean_loss, title=split + '-F-RH')
    return msg, jmsg


def handle_eval(i=0, last_mean_loss=1000, valid=True, test=False):
    """Evaluate on the validation and/or test split, print and save the results.

    i:              epoch number shown in the printed lines and stored in the CSV.
    last_mean_loss: most recent training loss, reported alongside the metrics.
    valid:          evaluate ``valid_data``; also updates the early-stop state
                    (module-level ``best_hits_1`` / ``dropped_time``) from the
                    filtered + RH Hits@1.
    test:           evaluate ``test_data``.

    The filtered rows are appended to the module-level ``valid_results`` /
    ``results`` lists and the whole list is written to
    'results/<file_name>valid' / 'results/<file_name>test' after every
    evaluation. Writing only when ``i`` is a multiple of 50 meant the final
    test evaluation was usually never saved. Returns None.
    """
    global best_hits_1, dropped_time
    columns = ['epoch', 'Hits@1', 'Hits@10', 'MR', 'MRR', 'mean_loss']

    if valid:
        msg, jmsg = _eval_split('Valid', valid_data, vfilter_mat, vrev_rel, i, last_mean_loss)

        # process early stop
        current_hits_1 = jmsg[1]
        if current_hits_1 > best_hits_1:
            best_hits_1 = current_hits_1
            dropped_time = 0
        else:
            dropped_time += 1

        valid_results.append(msg)
        valid_results.append(jmsg)
        pd.DataFrame(valid_results, columns=columns).to_csv('results/'+file_name+'valid')

    if test:
        msg, jmsg = _eval_split('Test', test_data, filter_mat, rev_rel, i, last_mean_loss)

        results.append(msg)
        results.append(jmsg)
        pd.DataFrame(results, columns=columns).to_csv('results/'+file_name+'test')
    return

In [9]:
# Training progress; the training cell resumes from `epoch`.
epoch = 0
last_mean_loss = 1000

# Rows collected by handle_eval: test rows and validation rows respectively.
results = []
valid_results = []

In [10]:
# Early-stop state, updated by handle_eval on every validation run.
best_hits_1 = 0
dropped_time = 0

# Training stops once dropped_time reaches this value.
max_dropped_time = 3

# Upper bound on the epoch counter.
max_epoch = 300

## The function handle_eval(i=i, last_mean_loss=last_mean_loss, valid=True, test=True) prints the evaluation results:

**Valid** and **Test** denote the datasets

**R** denotes Raw results

**F** denotes Filtered results

**RH** denotes results obtained with the relation enhancement method

In [11]:

for i in range(epoch, max_epoch):
    # Validate every 20 epochs, before training epoch i.
    if i % 20 == 0:
        handle_eval(i=i, last_mean_loss=last_mean_loss, valid=True, test=False)
    last_mean_loss = model.train()
    epoch += 1

    # early stop
    if dropped_time >= max_dropped_time:
        break

# Report the number of epochs actually trained. The loop variable `i` is one
# behind after the last training step and is undefined (or stale) when this
# cell is re-run with epoch >= max_epoch and the loop body never executes.
handle_eval(i=epoch, last_mean_loss=last_mean_loss, valid=True, test=True)

Valid-R         epoch:0, Hits@1:0.000, Hits@10:0.001, MR:7244.412, MRR:0.001, mean_loss:1000.000
Valid-R-RH      epoch:0, Hits@1:0.000, Hits@10:0.001, MR:7207.007, MRR:0.001, mean_loss:1000.000
Valid-F         epoch:0, Hits@1:0.000, Hits@10:0.001, MR:7123.872, MRR:0.001, mean_loss:1000.000
Valid-F-RH      epoch:0, Hits@1:0.000, Hits@10:0.001, MR:7086.761, MRR:0.001, mean_loss:1000.000
2048 265 0.001 100000
2048 265 0.001 9.09361203032
2048 265 0.001 7.27079046537
2048 265 0.001 6.80453052881
2048 265 0.001 6.51655452116
2048 265 0.001 6.30637829439
2048 265 0.001 6.15042848767
2048 265 0.001 6.03160143618
2048 265 0.001 5.93442380113
2048 265 0.001 5.85395814968
2048 265 0.001 5.78373944804
2048 265 0.001 5.72603296604
2048 265 0.001 5.67308706068
2048 265 0.001 5.63056590962
2048 265 0.001 5.58961372016
2048 265 0.001 5.554965676
2048 265 0.001 5.52393407282
2048 265 0.001 5.49358127162
2048 265 0.001 5.46647464104
2048 265 0.001 5.4425500096
Valid-R         epoch:20, Hits@1:0.115, Hi

2048 265 0.001 4.80037028115
2048 265 0.001 4.7982216979
2048 265 0.001 4.79662321918
2048 265 0.001 4.79763258268
2048 265 0.001 4.79423569193
2048 265 0.001 4.7947435397
2048 265 0.001 4.79476264738
2048 265 0.001 4.79317397531
2048 265 0.001 4.79188492073
2048 265 0.001 4.79028921487
2048 265 0.001 4.78686061895
2048 265 0.001 4.78787261135
2048 265 0.001 4.78506594964
2048 265 0.001 4.78369640494
Valid-R         epoch:180, Hits@1:0.095, Hits@10:0.308, MR:402.373, MRR:0.161, mean_loss:4.785
Valid-R-RH      epoch:180, Hits@1:0.096, Hits@10:0.312, MR:374.294, MRR:0.162, mean_loss:4.785
Valid-F         epoch:180, Hits@1:0.247, Hits@10:0.524, MR:188.091, MRR:0.339, mean_loss:4.785
Valid-F-RH      epoch:180, Hits@1:0.251, Hits@10:0.527, MR:160.151, MRR:0.342, mean_loss:4.785
2048 265 0.001 4.78535209692
2048 265 0.001 4.78461379285
2048 265 0.001 4.78269881482
2048 265 0.001 4.78196121612
2048 265 0.001 4.77994404199
2048 265 0.001 4.77947779062
2048 265 0.001 4.77817325052
2048 265 0.00