In [27]:
import os
import sys
import argparse
import pickle
import math
import unicodedata
import pandas as pd
import numpy as np

from fuzzywuzzy import fuzz
from nltk.tokenize.treebank import TreebankWordTokenizer
from nltk.corpus import stopwords

In [2]:
# arguments
# Hard-coded file locations for this notebook, all relative to the working directory.
# Only ent_resultpath and rel_resultpath are consumed by the cells visible in this
# notebook; the index_* paths and outpath are defined but not referenced here.
index_entpath = "../indexes/entity_2M.pkl"
index_reachpath = "../indexes/reachability_2M.pkl"
index_namespath = "../indexes/names_2M.pkl"
# " %%%% "-separated text files parsed by get_query_text / get_relations.
ent_resultpath = "../entity_detection/query-text/val.txt"
rel_resultpath = "../relation_prediction/results/topk-retrieval-valid-hits-3.txt"
outpath = "./tmp/results"

In [3]:
# Shared Treebank tokenizer instance used by tokenize_text().
tokenizer = TreebankWordTokenizer()

# Build the English stop-word set through an alias of the NLTK corpus reader.
# The original statement `stopwords = set(stopwords.words('english'))` rebound the
# imported module name to a set, so running this cell a second time failed with
# "AttributeError: 'set' object has no attribute 'words'". Reading from the alias
# keeps the public name `stopwords` (a set of str) and makes the cell re-runnable.
from nltk.corpus import stopwords as stopwords_corpus
stopwords = set(stopwords_corpus.words('english'))

def tokenize_text(text):
    """Split `text` into tokens with the module-level Treebank `tokenizer`.

    Returns whatever `tokenizer.tokenize(text)` returns.
    """
    return tokenizer.tokenize(text)

def www2fb(in_str):
    """Convert a 'www.freebase.com/...' identifier to the 'fb:a.b.c' form.

    Strings that do not start with "www.freebase.com" are returned unchanged.
    Otherwise the text after the last 'www.freebase.com/' is kept, every '/'
    in it is replaced by '.', and the result is prefixed with 'fb:'.
    """
    if not in_str.startswith("www.freebase.com"):
        return in_str
    tail = in_str.split('www.freebase.com/')[-1]
    return 'fb:' + tail.replace('/', '.')

def get_index(index_path):
    """Load and return the pickled object stored at `index_path`.

    Prints the path before loading. The file is opened in binary mode and
    closed when the `with` block exits.
    """
    print("loading index from: {}".format(index_path))
    # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
    with open(index_path, 'rb') as handle:
        return pickle.load(handle)

def strip_accents(text):
    """Return `text` NFKD-normalized with all 'Mn' category characters removed.

    NFKD decomposition separates base characters from their combining marks;
    dropping the nonspacing marks ('Mn') leaves the unaccented base characters.
    """
    decomposed = unicodedata.normalize('NFKD', text)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)


In [4]:
def get_query_text(ent_resultpath):
    """Read line ids and query texts from a " %%%% "-separated file.

    Each line is stripped and split on " %%%% ". The first field is the line
    id and the second is the query text; any further fields are ignored.
    Lines with fewer than two fields (for example, an empty query) are counted
    as "notfound" and skipped.

    Args:
        ent_resultpath: path of the text file to read.

    Returns:
        (lineids, id2query): the ids in file order (one entry per accepted
        line) and a dict mapping each id to its query text. If an id occurs
        more than once, the dict keeps the last query seen.
    """
    print("getting query text...")
    lineids = []
    id2query = {}
    notfound = 0
    with open(ent_resultpath, 'r') as f:
        for line in f:
            items = line.strip().split(" %%%% ")
            # Explicit field-count check instead of a bare `except:`, which also
            # swallowed unrelated errors such as KeyboardInterrupt.
            if len(items) < 2:
                notfound += 1
                continue
            lineid = items[0].strip()
            query = items[1].strip()
            lineids.append(lineid)
            id2query[lineid] = query
    print("notfound (empty query text): {}".format(notfound))
    return lineids, id2query

def get_relations(rel_resultpath):
    """Read per-question relation candidates from a " %%%% "-separated file.

    Every line must provide at least four fields: line id, relation, label and
    score. The relation is passed through www2fb(); label and score are kept
    as stripped strings. A line with fewer fields raises IndexError.

    Returns:
        (lineids, id2rels): ids in order of first appearance, and a dict
        mapping each id to its list of (rel, label, score) tuples in file order.
    """
    print("getting relations...")
    ordered_ids = []
    rels_by_id = {}
    with open(rel_resultpath, 'r') as handle:
        for raw in handle:
            fields = raw.strip().split(" %%%% ")
            key = fields[0].strip()
            entry = (www2fb(fields[1].strip()), fields[2].strip(), fields[3].strip())
            if key not in rels_by_id:
                # First time this id is seen: remember its position and start its list.
                rels_by_id[key] = []
                ordered_ids.append(key)
            rels_by_id[key].append(entry)
    return ordered_ids, rels_by_id

In [5]:
def find_ngrams(input_list, n):
    """Return the set of all length-`n` tuples of consecutive items.

    The sequence is zipped against `n` progressively shifted copies of itself,
    so each tuple is one contiguous window. When `n` is not positive or exceeds
    the sequence length, the result is an empty set.
    """
    shifted_views = (input_list[offset:] for offset in range(n))
    return set(zip(*shifted_views))

In [6]:
def calc_tf_idf(query, cand_ent_name, cand_ent_count, num_entities, index_ent):
    """Score a candidate entity name against a query.

    The "tf" factor is log10(cand_ent_count + 1). The "idf" factor is the sum,
    over the tokens shared by the query and the candidate name, of
    log10((num_entities - df + 0.5) / (df + 0.5)), where df is
    len(index_ent[token]). The returned score is tf * summed idf.

    Raises KeyError if a shared token is missing from `index_ent`.
    """
    shared_tokens = set(tokenize_text(query)) & set(tokenize_text(cand_ent_name))

    # Smoothing constants added to the numerator (k1) and denominator (k2).
    k1, k2 = 0.5, 0.5
    tf = math.log10(cand_ent_count + 1)

    idf_sum = 0
    for token in shared_tokens:
        doc_freq = len(index_ent[token])
        idf_sum += math.log10((num_entities - doc_freq + k1) / (doc_freq + k2))
    return tf * idf_sum

In [7]:
def pick_best_name(question, names_list):
    """Return the name in `names_list` with the highest fuzz.ratio against `question`.

    When several names tie for the highest score, the earliest one wins.
    Returns None when `names_list` is empty.
    """
    return max(
        names_list,
        key=lambda candidate: fuzz.ratio(candidate, question),
        default=None,
    )

In [8]:
# Parse the relation-prediction and entity-detection result files configured above.
# id2rels maps line id -> list of (rel, label, score); id2query maps line id -> query text.
rel_lineids, id2rels = get_relations(rel_resultpath)
ent_lineids, id2query = get_query_text(ent_resultpath)  # ent_lineids may have some examples missing

getting relations...
getting query text...
notfound (empty query text): 0


In [9]:
def get_questions(datapath, split_prefix="valid"):
    """Load the questions of one dataset split from a tab-separated file.

    Each line must have at least five tab-separated fields: line id, subject,
    predicate, object and question text. Lines with fewer fields raise
    IndexError. Only lines whose id starts with `split_prefix` are kept.

    Args:
        datapath: path of the tab-separated file to read.
        split_prefix: prefix a line id must start with to be included.
            Defaults to "valid", the value that used to be hard-coded.

    Returns:
        dict mapping line id -> (subject, predicate, question).
    """
    print("getting questions...")
    id2question = {}
    with open(datapath, 'r') as f:
        for line in f:
            items = line.strip().split("\t")
            lineid = items[0].strip()
            sub = items[1].strip()
            pred = items[2].strip()
            # items[3] (the object field) is not part of the result, so it is not
            # read; accessing items[4] still enforces the five-field minimum.
            question = items[4].strip()
            if lineid.startswith(split_prefix):
                id2question[lineid] = (sub, pred, question)
    return id2question

# Load the ground-truth (subject, predicate, question) triples for the ids that
# start with "valid", then sanity-check the count and the first example.
datapath = "../data/SimpleQuestions_v2_modified/all.txt"
id2question = get_questions(datapath)
print(len(id2question))
print(id2question['valid-1'])

getting questions...
10845
('fb:m.0f3xg_', 'fb:symbols.namesake.named_after', 'Who was the trump ocean club international hotel and tower named after')


In [10]:
# Number of line ids that have a query text and that have relation candidates.
print(len(id2query))
print(len(id2rels))

10845
10845


In [32]:
# Retrieval rate of the gold subject mid among the stored candidates, for each
# candidate-list size N and each similarity named in the pickle file names.
# A question counts as "found" when its gold mid appears in id2mids[lineid];
# questions with no entry in id2mids count as "notfound".
data = []
for N in [5, 20, 100]:
    print("N - {}".format(N))
    row = [N]
    for sim in ["tfidf", "fuzzy"]:
        fname = "id2mids_h-{}_s-{}.pkl".format(N, sim)
        # Use a context manager so the file handle is closed promptly
        # (the original `pickle.load(open(...))` leaked it).
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(fname, "rb") as pkl_file:
            id2mids = pickle.load(pkl_file)

        found = 0
        notfound = 0
        for lineid in id2question:
            if lineid not in id2mids:
                notfound += 1
                continue

            truth_mid, truth_rel, question = id2question[lineid]
            # any() stops at the first candidate whose mid matches the gold mid.
            if any(mid == truth_mid for (mid, mid_name, mid_score) in id2mids[lineid]):
                found += 1
            else:
                notfound += 1

        total = found + notfound
        # Guard against an empty id2question, which previously raised ZeroDivisionError.
        retrieval = found / total * 100.0 if total else 0.0
        row.append(retrieval)
        print(sim, retrieval)
    print("-" * 40)
    data.append(row)

df = pd.DataFrame(data)
df.columns = ['N', 'tfidf', 'fuzzy']
df.head()

N - 5
tfidf 69.75564776394651
fuzzy 76.87413554633473
----------------------------------------
N - 20
tfidf 77.48271092669434
fuzzy 83.28261871830337
----------------------------------------
N - 100
tfidf 84.96081143384048
fuzzy 88.27109266943292
----------------------------------------


Unnamed: 0,N,tfidf,fuzzy
0,5,69.755648,76.874136
1,20,77.482711,83.282619
2,100,84.960811,88.271093


In [16]:
# Difference between the number of questions and the number of ids present in id2mids.
# NOTE(review): id2mids is whichever pickle was loaded last; this cell's execution
# count (16) is lower than the evaluation loop's (32), so the recorded output may
# come from a different id2mids than a fresh top-to-bottom run would use — confirm.
mids_not_retrieved = len(id2question) - len(id2mids)
mids_not_retrieved

57

In [23]:
# Inspect the stored (mid, name, score) candidates for the first validation question.
id2mids['valid-1']

[('fb:m.0f3xg_', 'trump ocean club international hotel and tower', 1.0),
 ('fb:m.031n7n', 'trump international hotel and tower', 0.86),
 ('fb:m.05d9c4', 'trump international hotel and tower', 0.86),
 ('fb:m.08cbdd', 'trump international hotel and tower', 0.86),
 ('fb:m.07dwg4', 'trump international hotel and tower , las vegas', 0.75)]

In [17]:
# Inspect the (rel, label, score) candidates for the first validation question.
# The len() call below is not the cell's last expression, so its value is not displayed.
len(id2rels)
id2rels['valid-1']

[('fb:symbols.namesake.named_after', '1', '-0.9163265228271484'),
 ('fb:protected_sites.protected_site.governing_body',
  '0',
  '-2.0246400833129883'),
 ('fb:sports.professional_sports_team.owner_s', '0', '-2.1866111755371094')]