In [21]:
import os
import sys
import argparse
import pickle
import math
import unicodedata
import pandas as pd
import numpy as np

from collections import defaultdict
from fuzzywuzzy import fuzz
from nltk.tokenize.treebank import TreebankWordTokenizer
from nltk.corpus import stopwords

In [22]:
# Directory containing the entity-linking outputs, one "<split>-h100.txt" file per split.
results_path = "linking-results/"

In [23]:
def get_questions(datapath):
    """Load the tab-separated question file into a dict keyed by line id.

    Each non-blank line must have at least six tab-separated fields:
    ``lineid, subject, name, predicate, object, question``. The object
    column (index 4) is not needed by this notebook and is skipped.

    Args:
        datapath: path to the question file, decoded as UTF-8.

    Returns:
        dict mapping ``lineid`` -> ``(subject, name, predicate, question)``,
        every field whitespace-stripped. A repeated line id overwrites the
        earlier entry.

    Raises:
        IndexError: if a non-blank line has fewer than six fields.
    """
    print("getting questions...")
    id2question = {}
    # Entity names / questions may contain non-ASCII text; don't depend on
    # the platform's default encoding.
    with open(datapath, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # failing with an IndexError.
            if not stripped:
                continue
            items = stripped.split("\t")
            lineid = items[0].strip()
            sub = items[1].strip()
            name = items[2].strip()
            pred = items[3].strip()
            question = items[5].strip()
            id2question[lineid] = (sub, name, pred, question)
    return id2question

# Load every question once; the evaluation cell looks up gold subject MIDs by line id.
datapath = "../data/SimpleQuestions_v2_augmented/all.txt"
id2question = get_questions(datapath)

# Sanity check: how many questions were loaded, plus one validation sample.
print(len(id2question))
print(id2question['valid-1'])

getting questions...
107808
('fb:m.0f3xg_', 'trump ocean club international hotel and tower', 'fb:symbols.namesake.named_after', 'who was the trump ocean club international hotel and tower named after')


In [24]:
def get_mids(fpath):
    """Read candidate entity MIDs from an entity-linking results file.

    Each line holds a line id followed by zero or more candidate entries,
    all separated by the literal delimiter ``" %%%% "``. Each candidate
    entry is ``mid<TAB>name<TAB>score``; only the MID is kept.

    Args:
        fpath: path to the results file, decoded as UTF-8.

    Returns:
        ``defaultdict(list)`` mapping ``lineid`` -> candidate MIDs in the
        order they appear on the line. Line ids with no candidates are not
        inserted.

    Raises:
        ValueError: if a candidate entry does not have exactly three
            tab-separated fields.
    """
    id2mids = defaultdict(list)
    # Candidate names can be non-ASCII; read explicitly as UTF-8.
    with open(fpath, 'r', encoding='utf-8') as f:
        for line in f:
            lineid, *cand_entries = line.strip().split(" %%%% ")
            for entry in cand_entries:
                # Unpack all three fields so malformed entries still fail loudly.
                mid, _name, _score = entry.split("\t")
                id2mids[lineid].append(mid)
    return id2mids

def get_retrieval_rate(id2mids, id2question, hits, idspath):
    """Percentage of questions whose gold subject MID is in the top-``hits`` candidates.

    Args:
        id2mids: mapping ``lineid`` -> list of candidate MIDs (as built by
            ``get_mids``); only the first ``hits`` entries are considered.
        id2question: mapping ``lineid`` -> tuple whose first element is the
            gold subject MID (as built by ``get_questions``).
        hits: number of leading candidates to check (k in recall@k).
        idspath: file listing one line id per line; defines the evaluation set.

    Returns:
        Recall@hits as a float in [0, 100]. Line ids absent from ``id2mids``
        count as misses but remain in the denominator.

    Raises:
        KeyError: if a line id from ``idspath`` is missing from ``id2question``.
        ValueError: if ``idspath`` contains no line ids.
    """
    # Use a context manager so the handle is closed (this is called once per
    # split/hits combination).
    with open(idspath, 'r', encoding='utf-8') as f:
        lineids = f.read().splitlines()
    if not lineids:
        raise ValueError("no line ids found in {}".format(idspath))

    n_retrieved = 0
    for lineid in lineids:
        truth_mid = id2question[lineid][0]
        # Test membership before indexing: indexing a defaultdict would
        # insert an empty entry for every missing line id.
        if lineid not in id2mids:
            continue
        if truth_mid in id2mids[lineid][:hits]:
            n_retrieved += 1
    return (n_retrieved / len(lineids)) * 100.0

In [26]:
# Recall@k of the candidate entity lists, for every split and several cutoffs k.
datasets = ['train', 'valid', 'test']
hits = [1, 5, 20, 50, 100]
for dataset in datasets:
    idspath = '../data/SimpleQuestions_v2_augmented/{}_lineids.txt'.format(dataset)
    fpath = "{}{}-h100.txt".format(results_path, dataset)
    # Parse the candidate file once per split and reuse it for every cutoff.
    id2mids = get_mids(fpath)
    for hit in hits:
        retrieval = get_retrieval_rate(id2mids, id2question, hit, idspath)
        print("data: {}, hits: {}, retrieval: {}".format(dataset, hit, retrieval))

data: train, hits: 1, retrieval: 64.062168751325
data: train, hits: 5, retrieval: 78.87958448166206
data: train, hits: 20, retrieval: 84.93216027135891
data: train, hits: 50, retrieval: 88.04059783760864
data: train, hits: 100, retrieval: 90.30236379054483
data: valid, hits: 1, retrieval: 63.508120649651964
data: valid, hits: 5, retrieval: 78.37587006960557
data: valid, hits: 20, retrieval: 84.45475638051045
data: valid, hits: 50, retrieval: 87.53596287703016
data: valid, hits: 100, retrieval: 89.70765661252901
data: test, hits: 1, retrieval: 62.43216919437874
data: test, hits: 5, retrieval: 77.25059134548489
data: test, hits: 20, retrieval: 83.50262047214879
data: test, hits: 50, retrieval: 86.81879319141042
data: test, hits: 100, retrieval: 89.01256899030658
