## Import packages

In [1]:
import pandas as pd
import numpy as np
from operator import itemgetter
from CFModel import CFModel

Using Theano backend.
Using gpu device 0: GRID K520 (CNMeM is disabled, cuDNN 5006)


## Define constants

In [2]:
TEST_CSV_FILE = 'fb15k_test.csv'
CVSC_ENTITIES_CSV_FILE = 'fb15k_cvsc_entities.csv'
CVSC_PAIRS_CSV_FILE = 'fb15k_cvsc_pairs.csv'
MODEL_WEIGHTS_FILE = 'fb15k_cvsc_weights.h5'
K_FACTORS = 50

## Load FB215-237 data

In [16]:
triples = pd.read_csv(TEST_CSV_FILE, 
                      sep='\t', 
                      usecols=['subj', 'rel', 'obj', 'pid', 'rid'])
entities = pd.read_csv(CVSC_ENTITIES_CSV_FILE, sep='\t', usecols=['entity'])['entity'].values[1:]
entity_pairs = pd.read_csv(CVSC_PAIRS_CSV_FILE, 
                           sep='\t', 
                           usecols=['subj', 'obj', 'pid'])

## Print basic dataset statistics

In [51]:
n_pairs = triples['pid'].max() + 1
m_relations = triples['rid'].max() + 1
l_entities = len(entities)
print n_pairs, 'pairs,', m_relations, 'relations,', l_entities, 'entities'

283868 pairs, 237 relations, 14281 entities


## Load model weights into evaluation model

In [19]:
model = CFModel(n_pairs, m_relations, K_FACTORS)
model.load_weights(MODEL_WEIGHTS_FILE)

## Execute evaluation protocol

From [[2]](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/main_cvsc2015.pdf):

> Given a set of triples in a set disjoint from a training
> knowledge graph, we test models on predicting the subject or
> object of each triple, given the relation type and the other
> argument. We rank all entities in the training knowledge base in
> order of their likelihood of filling the argument position. We
> report the mean reciprocal rank of the correct entity, as well as
> HITS@10 – the percent of test triples for which the correct
> argument was ranked in the top ten. We use filtered measures
> following the protocol proposed in Bordes et al. (2013) – that
> is, when we rank entities for a given position, we remove all
> other entities that are known to be part of an existing triple in
> the training, development, or test set. This avoids penalizing
> the model for ranking other correct fillers higher than the
> tested argument. We thus report filtered mean reciprocal
> rank (labeled MRR in the Figures), and filtered HITS@10. In the
> figures we present MRR values scaled by 100, so that the maximum
> possible MRR is 100.

** *Note: filtering not yet implemented, code neither complete nor debugged for non-filtering case anyways* **

In [52]:
def sp_query_reciprocal_rank(model, subj, rid, obj, entities):
    objs = [ result[0] for result in sp_query_results(model, subj, rid, entities) ]
    return reciprocal_rank(obj, objs)

def sp_query_hits_at_10(model, subj, rid, obj, entities):
    objs = [ result[0] for result in sp_query_results(model, subj, rid, entities) ]
    if obj in objs[:10]:
        return 1.0
    else:
        return 0.0

def sp_query_results(model, subj, rid, entities):
    return sorted([ [ subj, model.rank(pid, rid) ] for pid in sp_query_pairs(subj, entities) ], 
                  reverse=True, 
                  key=itemgetter(1))

def sp_query_pairs(subj, entities):
    return [ pair_id(subj, obj) for obj in entities if pair_id(subj, obj) > -1 ]

def po_query_reciprocal_rank(model, subj, rid, obj, entities):
    subjs = [ result[0] for result in po_query_results(model, obj, rid, entities) ]
    return reciprocal_rank(subj, subjs)

def po_query_hits_at_10(model, subj, rid, obj, entities):
    subjs = [ result[0] for result in po_query_results(model, obj, rid, entities) ]
    if subj in subjs[:10]:
        return 1.0
    else:
        return 0.0

def po_query_results(model, obj, rid, entities):
    return sorted([ [ obj, model.rank(pid, rid) ] for pid in sp_query_pairs(subj, entities) ], 
                  reverse=True, 
                  key=itemgetter(1))

def po_query_pairs(obj, entities):
    return [ pair_id(subj, obj) for subj in entities if pair_id(subj, obj) > -1]

def pair_id(subj, obj):
    pair = entity_pairs[(entity_pairs['subj'] == subj) & (entity_pairs['obj'] == obj)]
    if len(pair) > 0:
        return pair['pid'].values[0]
    else:
        return -1

def reciprocal_rank(correct_response, responses):
    return 1. / np.float(np.where(responses == correct_response)[0][0])

In [47]:
entity_pairs.head(3)

Unnamed: 0,subj,obj,pid
0,/m/027rn,/m/06cx9,0
1,/m/017dcd,/m/06v8s0,1
2,/m/07s9rl0,/m/0170z3,2


In [53]:
po_query_pairs('/m/027rn', entities)

KeyboardInterrupt: 

In [36]:
triples['sp_reciprocal_rank'] = sp_query_reciprocal_rank(model, triples['subj'], triples['rid'], triples['obj'], entities)
triples['po_reciprocal_rank'] = po_query_reciprocal_rank(model, triples['subj'], triples['rid'], triples['obj'], entities)
triples['sp_hits_at_10'] = sp_query_hits_at_10(model, triples['subj'], triples['rid'], triples['obj'], entities)
triples['po_hits_at_10'] = po_query_hits_at_10(model, triples['subj'], triples['rid'], triples['obj'], entities)

mrr = (triples['sp_reciprocal_rank'].sum() + triples['po_reciprocal_rank'].sum()) / (np.float(len(triples)) * 2.0)
hits_at_10 = (triples['sp_hits_at_10'].sum() + triples['po_hits_at_10'].sum()) / (np.float(len(triples)) * 2.0)

print 'Mean reciprocal rank:', mrr
print 'HITS@10:', hits_at_10

0          /m/08966
1         /m/01hww_
2        /m/09v3jyg
3          /m/02jx1
4          /m/02jx1
5         /m/02bfmn
6        /m/05zrvfd
7          /m/060bp
8         /m/07l450
9         /m/07h1h5
10       /m/0djb3vw
11         /m/031y2
12         /m/01d8l
13         /m/0ydpd
14        /m/0738b8
15         /m/070xg
16         /m/0mwk9
17        /m/09306z
18         /m/0kbws
19         /m/0kbws
20        /m/017j69
21         /m/07z1m
22       /m/03gqgt3
23        /m/01c9f2
24        /m/01c9f2
25         /m/01lp8
26        /m/01xwqn
27       /m/065ym0c
28        /m/06kxk2
29        /m/08q1tg
            ...    
20436    /m/043sct5
20437      /m/0c921
20438      /m/0d085
20439     /m/05mt_q
20440     /m/02hct1
20441     /m/084qpk
20442      /m/085h1
20443      /m/0glnm
20444       /m/0mkg
20445      /m/0gthm
20446    /m/0gvvf4j
20447      /m/0b3wk
20448      /m/0l15n
20449      /m/0_565
20450    /m/02ppm4q
20451     /m/03q3sy
20452    /m/0gl02yg
20453      /m/0fkvn
20454    /m/05q8pss


ValueError: Series lengths must match to compare

## References

[[1]](https://www.microsoft.com/en-us/download/details.aspx?id=52312) K. Toutanova, "FB215-237 Knowledge Base Completion Dataset," Web page https://www.microsoft.com/en-us/download/details.aspx?id=52312, May 2016. Last accessed 2016-08-14.

[[2]](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/main_cvsc2015.pdf) K. Toutanova and D. Chen, “Observed versus latent features for knowledge base and text inference,” in 3rd Workshop on Continuous Vector Space Models and Their Compositionality, Jul. 2015.