## Import packages

In [1]:
import pandas as pd
import numpy as np
from operator import itemgetter
from CFModel import CFModel

Using Theano backend.
Using gpu device 0: GRID K520 (CNMeM is disabled, cuDNN 5006)


## Define constants

In [2]:
TEST_CSV_FILE = 'fb15k_test.csv'
CVSC_ENTITIES_CSV_FILE = 'fb15k_cvsc_entities.csv'
CVSC_PAIRS_CSV_FILE = 'fb15k_cvsc_pairs.csv'
MODEL_WEIGHTS_FILE = 'test_weights.h5'
K_FACTORS = 20

## Load FB215-237 data

In [3]:
triples = pd.read_csv(TEST_CSV_FILE, 
                      sep='\t', 
                      usecols=['subj', 'rel', 'obj', 'pid', 'rid'])
entities = pd.read_csv(CVSC_ENTITIES_CSV_FILE, sep='\t', usecols=['entity'])['entity'].values[1:]
entity_pairs = pd.read_csv(CVSC_PAIRS_CSV_FILE, 
                           sep='\t', 
                           usecols=['subj', 'obj', 'pid'])

## Print basic dataset statistics

In [4]:
n_pairs = triples['pid'].max() + 1
m_relations = triples['rid'].max() + 1
l_entities = len(entities)
print n_pairs, 'pairs,', m_relations, 'relations,', l_entities, 'entities'

283868 pairs, 237 relations, 14281 entities


## Load model weights into evaluation model

In [5]:
model = CFModel(n_pairs, m_relations, K_FACTORS)
model.load_weights(MODEL_WEIGHTS_FILE)

## Execute evaluation protocol

From [[2]](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/main_cvsc2015.pdf):

> Given a set of triples in a set disjoint from a training
> knowledge graph, we test models on predicting the subject or
> object of each triple, given the relation type and the other
> argument. We rank all entities in the training knowledge base in
> order of their likelihood of filling the argument position. We
> report the mean reciprocal rank of the correct entity, as well as
> HITS@10 – the percent of test triples for which the correct
> argument was ranked in the top ten. We use filtered measures
> following the protocol proposed in Bordes et al. (2013) – that
> is, when we rank entities for a given position, we remove all
> other entities that are known to be part of an existing triple in
> the training, development, or test set. This avoids penalizing
> the model for ranking other correct fillers higher than the
> tested argument. We thus report filtered mean reciprocal
> rank (labeled MRR in the Figures), and filtered HITS@10. In the
> figures we present MRR values scaled by 100, so that the maximum
> possible MRR is 100.

** *Note: filtering not yet implemented, code neither complete nor debugged for non-filtering case anyways* **

In [52]:
def sp_query_reciprocal_rank(model, subj, rid, obj, entities):
    objs = [ result[0] for result in sp_query_results(model, subj, rid, entities) ]
    return reciprocal_rank(obj, objs)

def sp_query_hits_at_10(model, subj, rid, obj, entities):
    objs = [ result[0] for result in sp_query_results(model, subj, rid, entities) ]
    if obj in objs[:10]:
        return 1.0
    else:
        return 0.0

def sp_query_results(model, subj, rid, entities):
    return sorted([ [ subj, model.rank(pid, rid) ] for pid in sp_query_pairs(subj, entities) ], 
                  reverse=True, 
                  key=itemgetter(1))

def sp_query_pairs(subj, entities):
    return [ pair_id(subj, obj) for obj in entities if pair_id(subj, obj) > -1 ]

def po_query_reciprocal_rank(model, subj, rid, obj, entities):
    subjs = [ result[0] for result in po_query_results(model, obj, rid, entities) ]
    return reciprocal_rank(subj, subjs)

def po_query_hits_at_10(model, subj, rid, obj, entities):
    subjs = [ result[0] for result in po_query_results(model, obj, rid, entities) ]
    if subj in subjs[:10]:
        return 1.0
    else:
        return 0.0

def po_query_results(model, obj, rid, entities):
    return sorted([ [ obj, model.rank(pid, rid) ] for pid in sp_query_pairs(subj, entities) ], 
                  reverse=True, 
                  key=itemgetter(1))

def po_query_pairs(obj, entities):
    return [ pair_id(subj, obj) for subj in entities if pair_id(subj, obj) > -1]

def pair_id(subj, obj):
    pair = entity_pairs[(entity_pairs['subj'] == subj) & (entity_pairs['obj'] == obj)]
    if len(pair) > 0:
        return pair['pid'].values[0]
    else:
        return -1

def reciprocal_rank(correct_response, responses):
    return 1. / np.float(np.where(responses == correct_response)[0][0])

In [6]:
pairs = entity_pairs.to_dict(orient='records')
subj_idx = {}
obj_idx = {}
for pair in pairs:
    subj = pair['subj']
    obj = pair['obj']
    pid = pair['pid']
    if subj not in subj_idx.keys():
        subj_idx[subj] = {}
    if obj not in obj_idx.keys():
        obj_idx[obj] = {}
    subj_idx[subj][obj] = pid
    obj_idx[obj][subj] = pid

In [7]:
tuples = triples.to_dict(orient='records')

In [8]:
np.where(np.array(['fee', 'fi', 'foo', 'fum']) == 'wubba')

(array([], dtype=int64),)

In [9]:
for tuple in tuples:
    scores = []
    subj = tuple['subj']
    rid = tuple['rid']
    obj = tuple['obj']
    for entity in entities:
        if subj in subj_idx.keys() and entity in subj_idx[subj].keys():
            pid = subj_idx[subj][entity]
            score = model.rate(pid, rid)
            scores.append([entity, score])
    scores = sorted(scores, reverse=True, key=itemgetter(1))
    results = [ x[0] for x in scores ]
    print obj, np.where(np.array(results) == obj)[0]

/m/05lf_ []
/m/01q99h []
/m/0f8l9c []
/m/013t85 [3362]
/m/0m0bj [1282]
/m/04ghz4m []
/m/04y9mm8 []
/m/04j53 []
/m/082gq []
/m/029q3k [13]
/m/0jgd []
/m/0dxmyh []
/m/02lq67 []
/m/02hcv8 []
/m/02xry []
/m/01g5v []
/m/09nqf []
/m/0sxkh []
/m/027jk []
/m/04vjh [240]
/m/01mkq [49]
/m/0631_ [1634]
/m/0d060g [191]
/m/01bx35 []
/m/01s695 []
/m/04n6k []
/m/0121rx []
/m/07ssc []
/m/01_d4 []
/m/0lcx []
/m/04399 [333]
/m/01lj9 []
/m/01y9jr []
/m/09gq0x5 []
/m/011yn5 []
/m/01d_h8 []
/m/0dtfn [530]
/m/04zl8 []
/m/01gb54 [188]
/m/0ckd1 []
/m/015fr []
/m/0dky9n []
/m/0jm_ []
/m/0411q []
/m/030wkp [1375]
/m/09nqf []
/m/09jwl []
/m/04k4l []
/m/027l4q []
/m/0d1qmz []
/m/04jplwp []
/m/0g_rs_ []
/m/03h_yfh []
/m/0bdw6t [536]
/m/0k2cb []
/m/06c62 [25]
/m/09rwjly []
/m/05zppz []
/m/04f62k []
/m/013m43 [1777]
/m/01d_h8 []
/m/062zjtt [25]
/m/01c9dd []
/m/05dppk [9]
/m/013b2h []
/m/07yk1xz []
/m/02bjrlw []
/m/03z19 []
/m/03s6l2 [412]
/m/05z96 []
/m/02zk08 []
/m/0kz10 [7]
/m/0vbk []
/m/07ssc [21]
/m/0f4hc []
/m/

KeyboardInterrupt: 

In [None]:
triples.head(3)

In [None]:
entity_pairs.head(5)

In [27]:
x = []
for entity in entities:
    if entity in subj_idx['/m/01sl1q'].keys():
        print entity, subj_idx['/m/01sl1q'][entity]
len(x)

/m/09c7w0 171784
/m/01n7q 1574663
/m/02_286 1587424
/m/07b_l 222038
/m/0f2wj 1371196
/m/027rn 263890
/m/030qb3t 872911
/m/03gh4 648956
/m/01rzqj 1788071
/m/0gsgr 2827105
/m/0f2tj 2888182
/m/01vq3 2667718
/m/01vvyd8 1180479
/m/04h9h 2612222
/m/019pm_ 2424398
/m/0f2w0 183894
/m/0151ns 835523
/m/05r7t 189971
/m/0bth54 147213
/m/0f4vbz 1598898
/m/0227vl 1479822
/m/01vw20_ 1806086
/m/0jfx1 2887349
/m/029j_ 885065
/m/0cgbf 1068570
/m/01b7h8 1393620
/m/05qg6g 338119
/m/06w2sn5 595257
/m/03lt8g 1699776
/m/08phg9 2909950
/m/0hvb2 2668270
/m/03v3xp 2507885
/m/0828jw 152541
/m/06cv1 1068571
/m/012xdf 2873677
/m/0c1pj 2166231
/m/015f7 2801017
/m/036hf4 2618462
/m/016tw3 955114
/m/01pgzn_ 868489
/m/0bq2g 2507886
/m/049qx 151042
/m/01whg97 2777501
/m/01f6zc 2423221
/m/03_x5t 1699777
/m/012vd6 2925437
/m/01vw8mh 2028657
/m/01lly5 2821300
/m/01g23m 2873128
/m/0gn30 1767766
/m/01jfrg 1478730
/m/01wxyx1 1180480
/m/0436kgz 2285460
/m/03_gd 1235462
/m/06cgy 988787
/m/01vwllw 2438476
/m/05mkhs 1721264
/m/0

0

In [None]:
model.rank

In [25]:
len(triples)

20466

In [53]:
po_query_pairs('/m/027rn', entities)

KeyboardInterrupt: 

In [36]:
triples['sp_reciprocal_rank'] = sp_query_reciprocal_rank(model, triples['subj'], triples['rid'], triples['obj'], entities)
triples['po_reciprocal_rank'] = po_query_reciprocal_rank(model, triples['subj'], triples['rid'], triples['obj'], entities)
triples['sp_hits_at_10'] = sp_query_hits_at_10(model, triples['subj'], triples['rid'], triples['obj'], entities)
triples['po_hits_at_10'] = po_query_hits_at_10(model, triples['subj'], triples['rid'], triples['obj'], entities)

mrr = (triples['sp_reciprocal_rank'].sum() + triples['po_reciprocal_rank'].sum()) / (np.float(len(triples)) * 2.0)
hits_at_10 = (triples['sp_hits_at_10'].sum() + triples['po_hits_at_10'].sum()) / (np.float(len(triples)) * 2.0)

print 'Mean reciprocal rank:', mrr
print 'HITS@10:', hits_at_10

0          /m/08966
1         /m/01hww_
2        /m/09v3jyg
3          /m/02jx1
4          /m/02jx1
5         /m/02bfmn
6        /m/05zrvfd
7          /m/060bp
8         /m/07l450
9         /m/07h1h5
10       /m/0djb3vw
11         /m/031y2
12         /m/01d8l
13         /m/0ydpd
14        /m/0738b8
15         /m/070xg
16         /m/0mwk9
17        /m/09306z
18         /m/0kbws
19         /m/0kbws
20        /m/017j69
21         /m/07z1m
22       /m/03gqgt3
23        /m/01c9f2
24        /m/01c9f2
25         /m/01lp8
26        /m/01xwqn
27       /m/065ym0c
28        /m/06kxk2
29        /m/08q1tg
            ...    
20436    /m/043sct5
20437      /m/0c921
20438      /m/0d085
20439     /m/05mt_q
20440     /m/02hct1
20441     /m/084qpk
20442      /m/085h1
20443      /m/0glnm
20444       /m/0mkg
20445      /m/0gthm
20446    /m/0gvvf4j
20447      /m/0b3wk
20448      /m/0l15n
20449      /m/0_565
20450    /m/02ppm4q
20451     /m/03q3sy
20452    /m/0gl02yg
20453      /m/0fkvn
20454    /m/05q8pss


ValueError: Series lengths must match to compare

## References

[[1]](https://www.microsoft.com/en-us/download/details.aspx?id=52312) K. Toutanova, "FB215-237 Knowledge Base Completion Dataset," Web page https://www.microsoft.com/en-us/download/details.aspx?id=52312, May 2016. Last accessed 2016-08-14.

[[2]](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/main_cvsc2015.pdf) K. Toutanova and D. Chen, “Observed versus latent features for knowledge base and text inference,” in 3rd Workshop on Continuous Vector Space Models and Their Compositionality, Jul. 2015.