In [203]:
import os
import sys
import argparse
import pickle
import math
import unicodedata
import pandas as pd
import numpy as np
import math

%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt

from collections import defaultdict
from fuzzywuzzy import fuzz
from nltk.tokenize.treebank import TreebankWordTokenizer
from nltk.corpus import stopwords
from sklearn.linear_model import LogisticRegression
from random import shuffle

import pprint
pp = pprint.PrettyPrinter(indent=4)

In [204]:
# read the lineids
data_path = 'SimpleQuestions_v2_augmented/'


def _read_lines(path):
    """Return the lines of the text file at ``path`` without line terminators.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the content has been read, instead of being left for garbage collection.
    """
    with open(path, 'r') as f:
        return f.read().splitlines()


# one list of line ids per split
train_lineids = _read_lines(os.path.join(data_path, "train_lineids.txt"))
valid_lineids = _read_lines(os.path.join(data_path, "valid_lineids.txt"))
test_lineids = _read_lines(os.path.join(data_path, "test_lineids.txt"))

In [205]:
# paths
# Entity-linking result files, one per split (parsed with get_mids further down).
ent_path = 'entity-linking-results/'
train_ent_path = os.path.join(ent_path, "train-h100.txt")
valid_ent_path = os.path.join(ent_path, "valid-h100.txt")
test_ent_path = os.path.join(ent_path, "test-h100.txt")

# Relation-prediction result files, one per split (parsed with get_rels further down).
rel_path = 'relation-pred-results/'
train_rel_path = os.path.join(rel_path, "topk-retrieval-train-hits-5.txt")
valid_rel_path = os.path.join(rel_path, "topk-retrieval-valid-hits-5.txt")
test_rel_path = os.path.join(rel_path, "topk-retrieval-test-hits-5.txt")

In [206]:
def get_questions(datapath):
    """Read a tab-separated question file into a lookup keyed by line id.

    Each line must carry at least six tab-separated fields, in the order
    line id, subject, name, predicate, object, question. The object field is
    not kept in the result.

    Args:
        datapath: path of the tab-separated file to read.

    Returns:
        dict mapping line id -> (subject, name, predicate, question), with
        surrounding whitespace stripped from every stored value.

    Raises:
        IndexError: if a line has fewer than six fields.
    """
    print("getting questions...")
    questions_by_id = {}
    with open(datapath, 'r') as handle:
        for raw in handle:
            fields = raw.strip().split("\t")
            # fields[4] is the object, which is not part of the stored tuple
            lineid, sub, name, pred = (fields[k].strip() for k in range(4))
            question = fields[5].strip()
            questions_by_id[lineid] = (sub, name, pred, question)
    return questions_by_id

# Build the lineid -> (subject, name, predicate, question) lookup from all.txt.
id2question = get_questions(os.path.join(data_path, "all.txt"))
# Sanity check: number of questions loaded and one sample entry.
print(len(id2question))
print(id2question['valid-1'])

getting questions...
107808
('fb:m.0f3xg_', 'trump ocean club international hotel and tower', 'fb:symbols.namesake.named_after', 'who was the trump ocean club international hotel and tower named after')


In [207]:
def get_mids(fpath, hits):
    """Parse an entity-linking result file into per-question candidate lists.

    Every line has the form
    ``lineid %%%% mid<TAB>name<TAB>score %%%% mid<TAB>name<TAB>score ...``.

    Args:
        fpath: path of the result file.
        hits: keep only the first ``hits`` candidates of each line.

    Returns:
        defaultdict(list) mapping line id -> [(mid, mid_name, score), ...] in
        file order. Scores are kept as the strings found in the file. A line
        without any candidate creates no entry for its line id.

    Raises:
        ValueError: if a candidate does not consist of exactly three
            tab-separated fields.
    """
    mids_by_id = defaultdict(list)
    with open(fpath, 'r') as handle:
        for raw in handle:
            lineid, *candidates = raw.strip().split(" %%%% ")
            for entry in candidates[:hits]:
                mid, mid_name, score = entry.split("\t")
                mids_by_id[lineid].append((mid, mid_name, score))
    return mids_by_id

def www2fb(in_str):
    """Convert a ``www.freebase.com/...`` identifier to its ``fb:`` form.

    The text after the last ``www.freebase.com/`` has every ``/`` replaced by
    ``.`` and is prefixed with ``fb:``. Strings that do not start with
    ``www.freebase.com`` are returned unchanged.
    """
    if not in_str.startswith("www.freebase.com"):
        return in_str
    # rpartition yields the whole string as the tail when the separator is absent
    tail = in_str.rpartition('www.freebase.com/')[2]
    return 'fb:%s' % tail.replace('/', '.')

def get_rels(rel_resultpath, hits):
    """Parse a relation-prediction result file into per-question lists.

    Every line has the form ``lineid %%%% relation %%%% label %%%% score``;
    the relation is passed through ``www2fb``.

    Args:
        rel_resultpath: path of the result file.
        hits: an entry is appended to an existing list only while that list
            holds fewer than ``hits`` entries. The first entry of a line id is
            always stored.

    Returns:
        dict mapping line id -> [(rel, label, score), ...] in file order, with
        label and score kept as the strings found in the file.

    Raises:
        IndexError: if a line has fewer than four `` %%%% ``-separated fields.
    """
    rels_by_id = {}
    with open(rel_resultpath, 'r') as handle:
        for raw in handle:
            fields = raw.strip().split(" %%%% ")
            lineid = fields[0].strip()
            entry = (www2fb(fields[1].strip()), fields[2].strip(), fields[3].strip())
            bucket = rels_by_id.get(lineid)
            if bucket is None:
                rels_by_id[lineid] = [entry]
            elif len(bucket) < hits:
                bucket.append(entry)
    return rels_by_id

In [208]:
def get_index(index_path):
    """Unpickle and return the object stored at ``index_path``.

    A progress message naming the path is printed first.

    NOTE: unpickling can execute arbitrary code contained in the file, so this
    must only be pointed at index files from a trusted source.
    """
    print("loading index from: {}".format(index_path))
    with open(index_path, 'rb') as handle:
        return pickle.load(handle)

# load up graph reachability index
# NOTE(review): get_index unpickles this file, which can execute arbitrary code —
# only load an index file that comes from a trusted source.
index_reachpath = "../indexes/reachability_2M.pkl"
index_reach = get_index(index_reachpath)

loading index from: ../indexes/reachability_2M.pkl


### Load up the reachability graph, i.e. what predicates exist for which MIDs

In [209]:
# Spot check: the entry stored in the reachability index for a single MID.
index_reach['fb:m.0n1vy1h']

{'fb:common.topic.notable_types',
 'fb:people.person.gender',
 'fb:people.person.profession'}

### Load up the predicted MIDs and relations for each question in train/valid/test set

In [210]:
# load up the linking, relation prediction results
# For each split, get_mids is asked for at most 20 entity candidates per question
# and get_rels for up to 5 relation candidates per question.
train_id2mids = get_mids(train_ent_path, hits=20)
train_id2rels = get_rels(train_rel_path, hits=5)

valid_id2mids = get_mids(valid_ent_path, hits=20)
valid_id2rels = get_rels(valid_rel_path, hits=5)

test_id2mids = get_mids(test_ent_path, hits=20)
test_id2rels = get_rels(test_rel_path, hits=5)

In [211]:
# Spot check: (mid, name, score) entity candidates of one validation question.
valid_id2mids['valid-1']

[('fb:m.0f3xg_', 'trump ocean club international hotel and tower', '1.0'),
 ('fb:m.031n7n', 'trump international hotel and tower', '0.86'),
 ('fb:m.08cbdd', 'trump international hotel and tower', '0.86'),
 ('fb:m.05d9c4', 'trump international hotel and tower', '0.86'),
 ('fb:m.07dwg4', 'trump international hotel and tower , las vegas', '0.75')]

In [212]:
# Spot check: (relation, label, log-score) relation candidates of the same question.
valid_id2rels['valid-1']

[('fb:symbols.namesake.named_after', '1', '-0.26024723052978516'),
 ('fb:aviation.aircraft_model.manufacturer', '0', '-2.970513343811035'),
 ('fb:award.award.presented_by', '0', '-3.0647497177124023'),
 ('fb:organization.organization.founders', '0', '-3.5357484817504883'),
 ('fb:time.event.instance_of_recurring_event', '0', '-4.868573188781738')]

### For the validation set, do cross linking and check retrieval rates

The retrieval rate is checked in total and also for the top 1, 2 and 3 cross-linked results. The results are sorted by the combined score of each (mid, rel) pair, where the relation prediction score is the exponentiated log-score produced by the relation model:

combined score = entity linking score * exp(relation prediction log-score)

In [213]:
# Cross-link the entity and relation candidates of the validation split and
# measure how often the gold (mid, rel) pair is recovered.
lineids = valid_lineids
id2mids = valid_id2mids
id2rels = valid_id2rels

# lineid -> [(mid, rel, mid_name, mid_score, rel_score, comb_score), ...],
# sorted by comb_score in descending order once the question is processed.
id2answers = defaultdict(list)
found, notfound_both, notfound_mid, notfound_rel = 0, 0, 0, 0
retrieved, retrieved_top1, retrieved_top2, retrieved_top3 = 0, 0, 0, 0

# lineids whose gold pair is ranked exactly 1st / 2nd / 3rd
lineids_found1 = []
lineids_found2 = []
lineids_found3 = []
found_by_rank = (lineids_found1, lineids_found2, lineids_found3)

# for every lineid
for i, lineid in enumerate(lineids):
    if i % 10000 == 0:
        print("line {}".format(i))

    # sanity checks: skip questions that lack either kind of candidate
    has_mids = lineid in id2mids
    has_rels = lineid in id2rels
    if not has_mids and not has_rels:
        notfound_both += 1
        continue
    if not has_mids:
        notfound_mid += 1
        continue
    if not has_rels:
        notfound_rel += 1
        continue

    found += 1
    truth_mid, truth_name, truth_rel, question = id2question[lineid]
    mids = id2mids[lineid]
    rels = id2rels[lineid]
    answers = id2answers[lineid]

    # Tracked per question: `retrieved` is later divided by the number of
    # questions, so a gold pair that shows up several times among the
    # candidates must still count only once.
    truth_is_candidate = False
    for (mid, mid_name, mid_score) in mids:
        # the index lookup depends only on the mid, so do it once per mid
        reachable_rels = index_reach[mid]
        for (rel, rel_label, rel_log_score) in rels:
            # keep the pair only if this (mid, rel) exists in the index
            if rel in reachable_rels:
                rel_score = math.exp(float(rel_log_score))
                comb_score = float(mid_score) * rel_score
                answers.append((mid, rel, mid_name, mid_score, rel_score, comb_score))
            # checked on the full cross product, i.e. before the index filter
            if mid == truth_mid and rel == truth_rel:
                truth_is_candidate = True
    if truth_is_candidate:
        retrieved += 1
    answers.sort(key=lambda t: t[5], reverse=True)

    # 0-based rank of the gold pair within the three best answers, if present
    truth_rank = next(
        (rank for rank, answer in enumerate(answers[:3])
         if answer[0] == truth_mid and answer[1] == truth_rel),
        None,
    )
    if truth_rank is not None:
        # a hit at rank k also counts for every cutoff larger than k
        retrieved_top3 += 1
        if truth_rank <= 1:
            retrieved_top2 += 1
        if truth_rank == 0:
            retrieved_top1 += 1
        found_by_rank[truth_rank].append(lineid)
               

line 0
line 10000


In [214]:
# Spot check: ranked (mid, rel, mid_name, mid_score, rel_score, comb_score) answers of one question.
id2answers['valid-1']

[('fb:m.0f3xg_',
  'fb:symbols.namesake.named_after',
  'trump ocean club international hotel and tower',
  '1.0',
  0.7708609818740425,
  0.7708609818740425),
 ('fb:m.031n7n',
  'fb:symbols.namesake.named_after',
  'trump international hotel and tower',
  '0.86',
  0.7708609818740425,
  0.6629404444116765),
 ('fb:m.05d9c4',
  'fb:symbols.namesake.named_after',
  'trump international hotel and tower',
  '0.86',
  0.7708609818740425,
  0.6629404444116765),
 ('fb:m.07dwg4',
  'fb:symbols.namesake.named_after',
  'trump international hotel and tower , las vegas',
  '0.75',
  0.7708609818740425,
  0.5781457364055319)]

### Results - retrieval rate

The retrieval rate is 81.28% overall. All rates are relative to the full validation set.

- At top = 1, the retrieval rate is 69.93%
- At top = 2, the retrieval rate is 76.14%
- At top = 3, the retrieval rate is 77.95%

In [215]:
# Share of validation questions that have both entity and relation candidates.
found / len(valid_lineids)

0.9955452436194896

In [216]:
# Gold (mid, rel) pair present in the candidate cross product (not restricted to
# pairs found in the reachability index), relative to all validation questions.
retrieved / len(valid_lineids)

0.8128074245939675

In [217]:
# Gold pair is the best-scoring cross-linked answer, relative to all validation questions.
retrieved_top1 / len(valid_lineids)

0.6993039443155452

In [218]:
# Gold pair is among the two best-scoring cross-linked answers.
retrieved_top2 / len(valid_lineids)

0.7614849187935034

In [219]:
# Gold pair is among the three best-scoring cross-linked answers.
retrieved_top3 / len(valid_lineids)

0.7795823665893271

### Inspecting the samples not retrieved at top = 1

In [220]:
def print_info(lineids):
    """Print the question text and the ranked answers for each given line id.

    Reads the notebook-level ``id2question`` and ``id2answers`` lookups and
    pretty-prints through ``pp``; every question is followed by a separator
    line of 40 dashes.
    """
    separator = "-" * 40
    for lineid in lineids:
        question_text = id2question[lineid][3]
        print("Question: {}".format(question_text))
        pp.pprint(id2answers[lineid])
        print(separator)

In [222]:
# Ranked answers of a single question, shown raw (same lookup that print_info formats).
id2answers['valid-1']

[('fb:m.0f3xg_',
  'fb:symbols.namesake.named_after',
  'trump ocean club international hotel and tower',
  '1.0',
  0.7708609818740425,
  0.7708609818740425),
 ('fb:m.031n7n',
  'fb:symbols.namesake.named_after',
  'trump international hotel and tower',
  '0.86',
  0.7708609818740425,
  0.6629404444116765),
 ('fb:m.05d9c4',
  'fb:symbols.namesake.named_after',
  'trump international hotel and tower',
  '0.86',
  0.7708609818740425,
  0.6629404444116765),
 ('fb:m.07dwg4',
  'fb:symbols.namesake.named_after',
  'trump international hotel and tower , las vegas',
  '0.75',
  0.7708609818740425,
  0.5781457364055319)]

### Which mistakes are most common at the top=1 position (for questions whose correct answer is ranked second)?

In [223]:
# For questions whose gold pair is ranked second, classify what the top-ranked
# answer got wrong: the entity, the relation, or both.
incorrect_both, incorrect_mid, incorrect_rel = 0, 0, 0
for lineid in lineids_found2:
    cand_answers = id2answers[lineid]
    # Slice the first three fields instead of unpacking the rest into `_`:
    # rebinding `_` at notebook top level shadows IPython's last-output variable.
    top_mid, top_rel, top_mid_name = cand_answers[0][:3]
    # for lineids_found2 the answer at index 1 is the gold pair
    right_mid, right_rel, right_mid_name = cand_answers[1][:3]
    mid_differs = top_mid != right_mid
    rel_differs = top_rel != right_rel
    if mid_differs and rel_differs:
        incorrect_both += 1
    elif mid_differs:
        incorrect_mid += 1
    elif rel_differs:
        incorrect_rel += 1

In [224]:
# Number of validation questions whose gold pair is ranked exactly second.
len(lineids_found2)

670

In [227]:
# % of rank-2 questions where the top answer has both a different mid and a different relation.
incorrect_both / len(lineids_found2) * 100.0

8.507462686567164

In [228]:
# % of rank-2 questions where the top answer has the right mid but a different relation.
incorrect_rel / len(lineids_found2) * 100.0

39.850746268656714

In [229]:
# % of rank-2 questions where the top answer has the right relation but a different mid.
incorrect_mid / len(lineids_found2) * 100.0

51.64179104477612