In [60]:
import argparse
import math
import os
import pickle
import sys
import unicodedata
from collections import Counter

import numpy as np
import pandas as pd
from fuzzywuzzy import fuzz
from nltk.corpus import stopwords
from nltk.corpus import stopwords as nltk_stopwords  # stable handle; `stopwords` is rebound to a set later
from nltk.tokenize.treebank import TreebankWordTokenizer

In [18]:
# arguments
# Pickled lookup tables, loaded in later cells via get_index().
index_entpath = "../indexes/entity_2M.pkl"
index_reachpath = "../indexes/reachability_2M.pkl"
index_namespath = "../indexes/names_2M.pkl"
# Text files parsed by get_query_text() and get_relations() respectively.
ent_resultpath = "../entity_detection/query-text/val.txt"
rel_resultpath = "../relation_prediction/results/topk-retrieval-valid-hits-3.txt"
# Output directory; in the visible cells it is only referenced from a commented-out line.
outpath = "./tmp/results"

In [61]:
# Shared tokenizer instance used by tokenize_text().
tokenizer = TreebankWordTokenizer()
# Build the stopword set from the aliased corpus import rather than from the
# name `stopwords` itself: the original rebinding replaced the corpus reader
# with a set, so re-running this cell raised AttributeError ('set' object has
# no attribute 'words'). Downstream cells still see `stopwords` as a set.
stopwords = set(nltk_stopwords.words('english'))

def tokenize_text(text):
    """Split ``text`` into tokens using the module-level ``tokenizer``."""
    return tokenizer.tokenize(text)

def www2fb(in_str):
    """Rewrite a ``www.freebase.com/...`` string into ``fb:`` dotted notation.

    Strings that do not start with ``www.freebase.com`` are returned unchanged.
    Otherwise the text after the last ``www.freebase.com/`` has every ``/``
    replaced by ``.`` and is prefixed with ``fb:``.
    """
    if not in_str.startswith("www.freebase.com"):
        return in_str
    tail = in_str.split('www.freebase.com/')[-1]
    dotted = tail.replace('/', '.')
    return 'fb:%s' % (dotted)

def get_index(index_path):
    """Unpickle and return the object stored at ``index_path``.

    The path is printed before the file is read.

    NOTE: ``pickle.load`` can execute arbitrary code while loading, so only
    point this at index files you trust.
    """
    print("loading index from: {}".format(index_path))
    with open(index_path, 'rb') as handle:
        return pickle.load(handle)

def strip_accents(text):
    """Return ``text`` NFKD-normalized with every ``Mn``-category character removed."""
    decomposed = unicodedata.normalize('NFKD', text)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)


In [80]:
def get_query_text(ent_resultpath):
    """Read the entity-detection output and map each line id to its query text.

    Each line is split on ``" %%%% "``; the first field is the line id and the
    second is the query text (any further fields are ignored). Lines with fewer
    than two fields are counted and skipped, and the count is printed.

    Args:
        ent_resultpath: path of the text file to parse.

    Returns:
        (lineids, id2query): the ids of the parsed lines in file order, and a
        dict mapping each id to its query text.
    """
    print("getting query text...")
    lineids = []
    id2query = {}
    notfound = 0
    with open(ent_resultpath, 'r') as f:
        for line in f:
            items = line.strip().split(" %%%% ")
            # A missing query field is the only expected failure here. Checking
            # for it explicitly replaces a bare `except`, which would also have
            # swallowed KeyboardInterrupt and genuine bugs.
            if len(items) < 2:
                notfound += 1
                continue
            lineid = items[0].strip()
            query = items[1].strip()
            lineids.append(lineid)
            id2query[lineid] = query
    print("notfound (empty query text): {}".format(notfound))
    return lineids, id2query

def get_relations(rel_resultpath):
    """Parse the relation-prediction output into per-line candidate lists.

    Every line is split on ``" %%%% "`` into line id, relation, label and
    score. The relation is passed through ``www2fb``; label and score are kept
    as stripped strings. A line with fewer than four fields raises IndexError.

    Returns:
        (lineids, id2rels): line ids in order of first appearance, and a dict
        mapping each id to its list of ``(rel, label, score)`` tuples in file
        order.
    """
    print("getting relations...")
    lineids = []
    id2rels = {}
    with open(rel_resultpath, 'r') as f:
        for raw in f:
            fields = raw.strip().split(" %%%% ")
            lineid = fields[0].strip()
            entry = (www2fb(fields[1].strip()), fields[2].strip(), fields[3].strip())
            if lineid not in id2rels:
                # First time this id is seen: remember its position.
                id2rels[lineid] = []
                lineids.append(lineid)
            id2rels[lineid].append(entry)
    return lineids, id2rels

In [81]:
# Sanity check: parse the relation file and show the candidates stored for one id.
_, tmp = get_relations(rel_resultpath)
tmp['valid-1']

getting relations...


[('fb:symbols.namesake.named_after', '1', '-0.9163265228271484'),
 ('fb:protected_sites.protected_site.governing_body',
  '0',
  '-2.0246400833129883'),
 ('fb:sports.professional_sports_team.owner_s', '0', '-2.1866111755371094')]

In [6]:
def find_ngrams(input_list, n):
    """Return the set of all contiguous ``n``-item tuples found in ``input_list``."""
    # Zipping the list against its own shifted copies yields one tuple per window.
    shifted = [input_list[offset:] for offset in range(n)]
    return set(zip(*shifted))

In [7]:
def calc_tf_idf(query, cand_ent_name, cand_ent_count, num_entities, index_ent):
    """Score a candidate entity name against the query text.

    The result is ``log10(cand_ent_count + 1)`` multiplied by the sum, over the
    tokens shared by ``query`` and ``cand_ent_name``, of
    ``log10((num_entities - df + 0.5) / (df + 0.5))`` with
    ``df = len(index_ent[token])``.

    Any lookup error raised by ``index_ent[token]`` propagates to the caller.
    """
    shared_terms = set(tokenize_text(query)).intersection(set(tokenize_text(cand_ent_name)))

    k1 = k2 = 0.5  # smoothing constants of the idf term
    count_weight = math.log10(cand_ent_count + 1)
    idf_sum = 0
    for term in shared_terms:
        doc_freq = len(index_ent[term])
        idf_sum += math.log10((num_entities - doc_freq + k1) / (doc_freq + k2))
    return count_weight * idf_sum

In [8]:
#outfile = open(os.path.join(outpath, "linking-results.txt"), 'w')
# Counters; both are re-initialised by the batch-linking cell further down.
notfound_ent = 0
notfound_c = 0

# Index queried with n-gram strings; each lookup yields a collection of candidate entity MIDs.
index_ent = get_index(index_entpath)
# Passed to calc_tf_idf() as `num_entities`.
num_entities_fbsubset = 1959820  # 2M - 1959820 , 5M - 1972702

loading index from: ../indexes/entity_2M.pkl


In [9]:
# Lookup from entity MID to its list of names.
index_names = get_index(index_namespath)

loading index from: ../indexes/names_2M.pkl


In [10]:
# Reachability index; the visible cells only test whether a MID is one of its keys.
index_reach = get_index(index_reachpath)

loading index from: ../indexes/reachability_2M.pkl


In [86]:
def pick_best_name(question, names_list):
    """Return the entry of ``names_list`` that scores highest with ``fuzz.ratio`` against ``question``.

    Ties go to the earliest entry; an empty ``names_list`` yields ``None``.
    """
    return max(names_list, key=lambda name: fuzz.ratio(name, question), default=None)

In [82]:
# Parse both model outputs; the returned id lists drive the batch-linking loop further down.
rel_lineids, id2rel = get_relations(rel_resultpath)
ent_lineids, id2query = get_query_text(ent_resultpath)  # ent_lineids may have some examples missing

getting relations...
getting query text...
notfound (empty query text): 0


In [30]:
def get_questions(datapath, id_prefix="valid"):
    """Load gold (subject, predicate, question) triples from a tab-separated file.

    Each line must hold at least five tab-separated fields: line id, subject,
    predicate, object and question text. The object field is not used.

    Args:
        datapath: path of the tab-separated file.
        id_prefix: only lines whose id starts with this string are kept.
            Defaults to ``"valid"``, the value that used to be hard-coded;
            pass e.g. ``"test"`` to load another split, or ``""`` for all lines.

    Returns:
        dict mapping line id to ``(subject, predicate, question)``.

    Raises:
        IndexError: if a line has fewer than five fields.
    """
    print("getting questions...")
    id2question = {}
    with open(datapath, 'r') as f:
        for line in f:
            items = line.strip().split("\t")
            lineid = items[0].strip()
            sub = items[1].strip()
            pred = items[2].strip()
            # items[3] is the object, which no caller needs.
            question = items[4].strip()
            if lineid.startswith(id_prefix):
                id2question[lineid] = (sub, pred, question)
    return id2question

# Gold data: build the id -> (subject, predicate, question) lookup and spot-check one entry.
datapath = "../data/SimpleQuestions_v2_modified/all.txt"
id2question = get_questions(datapath)
print(len(id2question))
print(id2question['valid-1'])

getting questions...
10845
('fb:m.0f3xg_', 'fb:symbols.namesake.named_after', 'Who was the trump ocean club international hotel and tower named after')


In [31]:
def get_docs(query):
    """Return ``(mid, names)`` pairs for every entity indexed under ``query``.

    Looks ``query`` up in the global ``index_ent`` and pairs each returned id
    with its entry in the global ``index_names``. Ids that have no entry in
    ``index_names`` are skipped.

    A failed ``index_ent[query]`` lookup is not handled here and propagates
    (presumably as KeyError, assuming ``index_ent`` is a plain dict).
    """
    docs = []
    for mid in index_ent[query]:
        # Skip only ids that are missing from index_names. The previous bare
        # `except` hid every other error as well, and the loop variable used
        # to shadow the builtin `id`.
        if mid in index_names:
            docs.append((mid, index_names[mid]))
    return docs

In [32]:
# Example lookup: every entity indexed under "sasha" together with its name list.
get_docs("sasha")

[('fb:m.04y6901', ['sasha siem']),
 ('fb:m.027qk4',
  ['sasha',
   'александр борисович годунов',
   'aleksander borisovich godunov',
   'alexander borisovich godunov',
   'aleksandr godunov',
   'alexander godunov',
   'sascha']),
 ('fb:m.0g7s7dd', ['i am... sasha fierce']),
 ('fb:m.0q35ylm', ['alexandra lipskaia', 'sasha lipskaia']),
 ('fb:m.04cgjl', ['sasha victorine']),
 ('fb:m.0c489mr', ['sasha moorsom young', 'sasha moorsom']),
 ('fb:m.03ckj5b', ['sasha toperich', 'sasa toperic', 'sasa toperic']),
 ('fb:m.03pkzn',
  ['alexander popov',
   'the russian rocket',
   'aleksandr popov',
   'the tsar of swimming',
   'alexander popov',
   'alexander vladimirovich popov',
   'sasha']),
 ('fb:m.05zqvbh', ['sasha andrews']),
 ('fb:m.03yg5gn', ['3680 sasha', '1987 my']),
 ('fb:m.01hw40t',
  ['global underground 009',
   'global underground 003 : sasha in san francisco',
   'global underground 009 : sasha in san francisco']),
 ('fb:m.0hy_626',
  ['ante marich darko',
   "sasha d'ark",
   'h

In [33]:
# Compare how many line ids each lookup covers.
print(len(id2query))
print(len(id2rel))

10845
10845


In [180]:
# Inspect the raw candidate MIDs indexed under the unigram "carlos".
index_ent["carlos"]

{'fb:m.0cmc07w',
 'fb:m.0fzbtm',
 'fb:m.02x5rzh',
 'fb:m.0ypyl2t',
 'fb:m.0ks7xv',
 'fb:m.06zlybp',
 'fb:m.0bq9t12',
 'fb:m.04mncd0',
 'fb:m.05f9byq',
 'fb:m.06x9qp',
 'fb:m.0dfyj60',
 'fb:m.091rks',
 'fb:m.09v2l7b',
 'fb:m.05zmgk1',
 'fb:m.03nxdj2',
 'fb:m.0jsszsq',
 'fb:m.0b_khsh',
 'fb:m.0_skgc1',
 'fb:m.071rhy',
 'fb:m.01s86h',
 'fb:m.0ryln9h',
 'fb:m.07r4f1',
 'fb:m.0k7gv7b',
 'fb:m.0w5040w',
 'fb:m.01dtnl',
 'fb:m.0j3cw9x',
 'fb:m.04nl9g1',
 'fb:m.09jw3l',
 'fb:m.0qgqxh_',
 'fb:m.0gb4zpj',
 'fb:m.0vp3m68',
 'fb:m.0nby3zv',
 'fb:m.0ngk659',
 'fb:m.09tbhbg',
 'fb:m.0b8vgm',
 'fb:m.08lrqt',
 'fb:m.027vhb2',
 'fb:m.0jt34d9',
 'fb:m.0jtny0',
 'fb:m.0cgzktk',
 'fb:m.0ngkbdk',
 'fb:m.03f08s8',
 'fb:m.0_yw3rf',
 'fb:m.0c5ltt',
 'fb:m.0wjcrq2',
 'fb:m.027wz1p',
 'fb:m.0jv9gj',
 'fb:m.0zdgq_5',
 'fb:m.02wbcqp',
 'fb:m.0h4nk49',
 'fb:m.082l9t',
 'fb:m.02pjqvs',
 'fb:m.0gfgv1v',
 'fb:m.09k81c_',
 'fb:m.0921x2',
 'fb:m.0zwb6ws',
 'fb:m.075xnpx',
 'fb:m.047rq25',
 'fb:m.0bxl72g',
 'fb:m.0277ys

In [66]:
# explore
# Set up the inputs of the linking steps for a single example.
lineid = 'valid-1'
question = id2question[lineid]  # NOTE: a (subject, predicate, question text) tuple, not a plain string
print(question)
cand_relations = id2rel[lineid]
print(cand_relations)
query_text = id2query[lineid].lower()  # lowercase the query
print(query_text)
query_tokens = tokenize_text(query_text)
print(query_tokens)
N = min(len(query_tokens), 3)  # longest n-gram length tried by the candidate search
print(N)

('fb:m.0f3xg_', 'fb:symbols.namesake.named_after', 'Who was the trump ocean club international hotel and tower named after')
[('fb:symbols.namesake.named_after', '1', '-0.9163265228271484'), ('fb:protected_sites.protected_site.governing_body', '0', '-2.0246400833129883'), ('fb:sports.professional_sports_team.owner_s', '0', '-2.1866111755371094')]
trump ocean club international hotel and tower
['trump', 'ocean', 'club', 'international', 'hotel', 'and', 'tower']
3


In [67]:
C = []  # candidate entities
# Try the longest n-grams first and fall back to shorter ones only when they
# produce no candidates at all.
for n in range(N, 0, -1):
    ngrams_set = find_ngrams(query_tokens, n)
    print("ngrams_set: {}".format(ngrams_set))
    for ngram_tuple in ngrams_set:
        ngram = " ".join(ngram_tuple)
        ngram = strip_accents(ngram)
        # unigram stopwords have too many candidates so just skip over
        if ngram in stopwords:
            continue
        print("ngram: {}".format(ngram))
        # Some n-grams are absent from the index (e.g. 'p.a.r.c.e. parce' at
        # test-2592). Skip exactly that case instead of swallowing every
        # exception with a bare `except`.
        try:
            cand_mids = index_ent[ngram]  # search entities
        except KeyError:
            continue
        C.extend(cand_mids)
    if (len(C) > 0):
        print("early termination...")
        break
    # An unconditional `break` used to follow here. It stopped the search after
    # the first n-gram length and made the fallback to shorter n-grams
    # unreachable, unlike the batch-linking loop later in the notebook.
print(C)

ngrams_set: {('international', 'hotel', 'and'), ('hotel', 'and', 'tower'), ('club', 'international', 'hotel'), ('ocean', 'club', 'international'), ('trump', 'ocean', 'club')}
ngram: international hotel and
ngram: hotel and tower
ngram: club international hotel
ngram: ocean club international
ngram: trump ocean club
early termination...
['fb:m.08cbdd', 'fb:m.0f3xg_', 'fb:m.07dwg4', 'fb:m.031n7n', 'fb:m.05d9c4', 'fb:m.08cbdd', 'fb:m.0f3xg_', 'fb:m.07dwg4', 'fb:m.031n7n', 'fb:m.05d9c4', 'fb:m.0f3xg_', 'fb:m.0f3xg_', 'fb:m.0f3xg_']


In [68]:
# Look up the stored names for one MID.
print(index_names['fb:m.0504s2'])

['scotty', 'gomer', 'scott gomez', 'scott carlos gomez']


In [69]:
# Keep only candidates that are keys of the reachability index, paired with
# the number of times each one was retrieved.
mid_counts = Counter(C)  # one pass over C instead of C.count() per distinct mid
C_pruned = []
for mid in set(C):
    if mid in index_reach:  # PROBLEM: don't know why this may not exist??
        C_pruned.append((mid, mid_counts[mid]))

print(C_pruned)

[('fb:m.08cbdd', 2), ('fb:m.0f3xg_', 5), ('fb:m.07dwg4', 2), ('fb:m.031n7n', 2), ('fb:m.05d9c4', 2)]


In [74]:
# `question` comes from id2question and is a (subject, predicate, text) tuple.
# Name matching must compare against the question text only, so unpack it
# here instead of handing the whole tuple to pick_best_name().
question_text = question[2] if isinstance(question, tuple) else question

C_tfidf_pruned = []
for mid, count_mid in C_pruned:
    if mid in index_names:
        cand_ent_name = pick_best_name(question_text, index_names[mid])
        tfidf = calc_tf_idf(query_text, cand_ent_name, count_mid, num_entities_fbsubset, index_ent)
        C_tfidf_pruned.append((mid, cand_ent_name, tfidf))
# print("C_tfidf_pruned[:10]: {}".format(C_tfidf_pruned[:10]))
print(C_tfidf_pruned)

# Highest score first; keep the top three.
C_tfidf_pruned.sort(key=lambda t: -t[2])

cand_mids = C_tfidf_pruned[:3]
print(cand_mids)

[('fb:m.08cbdd', 'trump international hotel and tower', 7.349073098354211), ('fb:m.0f3xg_', 'trump ocean club international hotel and tower', 17.035447732786867), ('fb:m.07dwg4', 'trump international hotel and tower , las vegas', 7.349073098354211), ('fb:m.031n7n', 'trump international hotel and tower , chicago', 7.349073098354211), ('fb:m.05d9c4', 'trump international hotel and tower', 7.349073098354211)]
[('fb:m.0f3xg_', 'trump ocean club international hotel and tower', 17.035447732786867), ('fb:m.08cbdd', 'trump international hotel and tower', 7.349073098354211), ('fb:m.07dwg4', 'trump international hotel and tower , las vegas', 7.349073098354211)]


In [87]:
def fuzzy_match_score(name, question):
    """Return ``fuzz.ratio(name, question)``."""
    return fuzz.ratio(name, question)

In [None]:
# Batch entity linking: for every question with relation predictions, collect
# candidate entities from the n-gram index, filter them, rank them by fuzzy
# name similarity to the query text and keep the top HITS_TOP_ENTITIES.
notfound_ent = 0            # ids with relation predictions but no query text
notfound_c = 0              # ids for which no candidate survived filtering
notfound_c_lineids = []
notcorrect_ent_lineids = []
candidate_mids = {}         # lineid -> list of (mid, name, score), best first
HITS_TOP_ENTITIES = 5

# Set membership is O(1); testing against the `ent_lineids` list was a linear
# scan for every question.
ent_lineid_set = set(ent_lineids)

for lineid in rel_lineids:
    if lineid not in ent_lineid_set:
        notfound_ent += 1
        continue

    # id2question values are (subject, predicate, question text) tuples. Only
    # the text is meaningful for name matching; the whole tuple used to be
    # passed to pick_best_name() by mistake.
    question = id2question[lineid]
    question_text = question[2]
    query_text = id2query[lineid].lower()  # lowercase the query
    query_tokens = tokenize_text(query_text)

    # Candidate search: longest n-grams first, shorter ones only as fallback.
    N = min(len(query_tokens), 3)
    C = []  # candidate entities (with repeats)
    for n in range(N, 0, -1):
        for ngram_tuple in find_ngrams(query_tokens, n):
            ngram = strip_accents(" ".join(ngram_tuple))
            # unigram stopwords have too many candidates so just skip over
            if ngram in stopwords:
                continue
            # Some n-grams are absent from the index (e.g. 'p.a.r.c.e. parce'
            # at test-2592); skip exactly that case rather than every error.
            try:
                cand_mids = index_ent[ngram]  # search entities
            except KeyError:
                continue
            C.extend(cand_mids)
        if C:
            break  # early termination: this n-gram length produced candidates

    # Keep candidates that are keys of the reachability index, with the number
    # of times each was retrieved (counted once instead of C.count() per mid).
    mid_counts = Counter(C)
    C_pruned = [(mid, mid_counts[mid]) for mid in set(C) if mid in index_reach]

    C_tfidf_pruned = []
    for mid, count_mid in C_pruned:
        if mid in index_names:
            cand_ent_name = pick_best_name(question_text, index_names[mid])
            # score = calc_tf_idf(query_text, cand_ent_name, count_mid, num_entities_fbsubset, index_ent)
            score = fuzzy_match_score(cand_ent_name, query_text)
            C_tfidf_pruned.append((mid, cand_ent_name, score))

    if len(C_tfidf_pruned) == 0:
        notfound_c_lineids.append(lineid)
        notfound_c += 1
        continue

    # Highest score first; the sort is stable, so ties keep their order.
    C_tfidf_pruned.sort(key=lambda t: -t[2])
    cand_mids = C_tfidf_pruned[:HITS_TOP_ENTITIES]

    candidate_mids[lineid] = cand_mids

In [None]:
# Spot-check the ranked (mid, name, score) candidates stored for one question.
candidate_mids['valid-3']

In [201]:
# NOTE(review): hard-coded ids from a run on the test split. They replace the
# list built by the linking loop, and id2question / id2rel / id2query must
# have been built from the test files for the lookups below to succeed.
notfound_c_lineids = ['test-40', 'test-146', 'test-312', 'test-414', 'test-578', 'test-742', 'test-848']
names = ['question', 'gold_entity_id', 'gold_entity_name', 'gold_relation', 'query_text', 'predicted_relation']
rows = []
pred_rel_fault = 0   # predicted relation differs from the gold relation
names_fault = 0      # gold entity has no entry in index_names
other_fault = 0      # relation was right, so the miss has another cause
for lineid in notfound_c_lineids[:50]:
    ent, rel, question = id2question[lineid]
    rels = id2rel[lineid]
    # get_relations() stores a list of (relation, label, score) tuples per id,
    # already converted by www2fb; calling www2fb on that list raises
    # AttributeError. Use the first listed candidate, presumably the
    # top-scoring one (the 'valid-1' sample output lists scores in descending
    # order). A plain string is still accepted for older id2rel layouts.
    pred_relation = www2fb(rels) if isinstance(rels, str) else rels[0][0]
    query_text = id2query[lineid].lower()  # lowercase the query
    if not pred_relation == rel:
        pred_rel_fault += 1
    else:
        other_fault += 1
    if ent in index_names:
        # Append only when a row was actually built; previously a missing
        # name re-appended the previous row (or hit NameError on the first).
        rows.append([question, ent, pick_best_name(question, index_names[ent]), rel, query_text, pred_relation])
    else:
        names_fault += 1

# Passing `columns` also works when no rows were collected.
df = pd.DataFrame(rows, columns=names)

print(pred_rel_fault)
print(names_fault)
print(other_fault)

3
0
4


In [202]:
# Display the error-analysis table built in the previous cell.
df

Unnamed: 0,question,gold_entity_id,gold_entity_name,gold_relation,query_text,predicted_relation
0,Which label is somevelvetsidewalk signed to,fb:m.01pm4nb,some velvet sidewalk,fb:music.artist.label,somevelvetsidewalk,fb:music.artist.label
1,what is a short-lived British sitcom series,fb:m.0c4xc,situation comedy,fb:tv.tv_genre.programs,short-lived,fb:tv.tv_genre.programs
2,Who is the focus of uttar pradesh has more tha...,fb:m.0j2hj_0,uttar pradesh has more than one capital.,fb:base.uncommon.exception.focus,than,fb:base.culturalevent.event.entity_involved
3,what genre of music is locd out,fb:m.01rrs9n,loc 'd out,fb:music.album.genre,locd out,fb:music.album.genre
4,what is cassiesteele's gender?,fb:m.03_fby,cassie steele,fb:people.person.gender,cassiesteele,fb:people.person.gender
5,what is the genre of gusgofficial,fb:m.03f3bp7,kostas karamitroudis,fb:music.artist.genre,gusgofficial,fb:music.album.genre
6,What is the title of the netlix film in the ge...,fb:m.03_3d,land of the rising sun,fb:media_common.netflix_genre.titles,netlix,fb:film.film_genre.films_in_this_genre


In [208]:
# NOTE(review): hard-coded ids from a run on the test split; id2question,
# id2rel and id2query must have been built from the test files for the
# lookups below to succeed.
notcorrect_ent_lineids = ['test-2', 'test-6', 'test-7', 'test-8', 'test-11', 'test-12', 'test-15', 'test-18', 'test-19', 'test-20', 'test-22', 'test-23', 'test-24', 'test-25', 'test-26', 'test-27', 'test-29', 'test-30', 'test-31', 'test-32', 'test-39', 'test-41', 'test-43', 'test-45', 'test-47', 'test-48', 'test-49', 'test-50', 'test-52', 'test-61', 'test-65', 'test-66', 'test-67', 'test-68', 'test-69', 'test-71', 'test-73', 'test-75', 'test-76', 'test-77', 'test-78', 'test-79', 'test-81', 'test-82', 'test-85', 'test-87', 'test-88', 'test-91', 'test-92', 'test-93', 'test-94', 'test-95', 'test-96', 'test-97', 'test-98', 'test-100', 'test-104', 'test-105', 'test-110', 'test-112', 'test-113', 'test-118', 'test-125', 'test-127', 'test-128']

names = ['question', 'gold_ent_name', 'query_text', 'predicted_ent_name']
rows = []
pred_rel_fault = 0   # predicted relation differs from the gold relation
names_fault = 0      # gold entity has no entry in index_names
other_fault = 0      # relation was right, so the miss has another cause
for lineid in notcorrect_ent_lineids[:50]:
    ent, rel, question = id2question[lineid]
    rels = id2rel[lineid]
    # get_relations() stores a list of (relation, label, score) tuples per id,
    # already converted by www2fb; calling www2fb on that list raises
    # AttributeError. Use the first listed candidate, presumably the
    # top-scoring one. A plain string is still accepted for older layouts.
    pred_relation = www2fb(rels) if isinstance(rels, str) else rels[0][0]
    query_text = id2query[lineid].lower()  # lowercase the query
    # NOTE(review): id2pred_ent is not defined in any visible cell; it must be
    # populated (lineid -> predicted entity MID) before this cell is run.
    pred_ent_mid = id2pred_ent[lineid]
    if not pred_relation == rel:
        pred_rel_fault += 1
    else:
        other_fault += 1
    if ent in index_names:
        # Append only when a row was actually built; previously a missing
        # name re-appended the previous row (or hit NameError on the first).
        rows.append([question, pick_best_name(question, index_names[ent]), query_text, pick_best_name(question, index_names[pred_ent_mid])])
    else:
        names_fault += 1

# Passing `columns` also works when no rows were collected.
df = pd.DataFrame(rows, columns=names)

print(pred_rel_fault)
print(names_fault)
print(other_fault)

13
1
37


In [209]:
# Display the error-analysis table built in the previous cell.
df

Unnamed: 0,question,gold_ent_name,query_text,predicted_ent_name
0,what format is fearless,fearless,fearless,fearless
1,what was the cause of death of yves klein,yves klein,yves klein,yves klein blue
2,Which equestrian was born in dublin?,"dublin , republic of ireland",dublin,dublin
3,What is a tv action show?,action,action show,junit in action
4,What's a song by jean grae,jean grae,jean grae,grae fruits : the jean grae compilation
5,What position does carlos gomez play?,carlos argelis gomez pena,carlos gomez,scott carlos gomez
6,What's a release on pretty in pink,pretty in pink,pretty in pink,pretty in pink
7,Who created the typeface chicago?,chicago,typeface chicago,typeface
8,what position does pee wee reese play in baseball,pee wee reese,pee wee reese,pee wee & jackie : pee wee reese & jackie robi...
9,which artist recorded one life to live,one life to live,one life to live,lady in the dark : one life to live


In [128]:
# Save the current error-analysis table as CSV (relative path, so it lands in the working directory).
df.to_csv('incorrect_ents_without_rel_correction.csv')