In [1]:
import sys
import time
import matplotlib.pyplot as plt
import wordninja
import requests as r
import networkx as nx

from pullnet import PullNet
from graftnet import GraftNet
from pullnet_data_loader import DataLoader
from fpnet_data_loader import FpNetDataLoader
from relreasoner_data_loader import RelReasonerDataLoader
from fpnet import FactsPullNet
from util import *
from multiprocessing.pool import Pool
from preprocessing import use_helper
from collections import defaultdict
from string import punctuation


import warnings
warnings.filterwarnings("ignore")

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
cfg = get_config('config/webqsp.yml')
word2id = load_dict(cfg['data_folder'] + cfg['word2id'])
relation2id = load_dict(cfg['data_folder'] + cfg['relation2id'])
entity2id = load_dict(cfg['data_folder'] + cfg['entity2id'])
num_hop = 1

In [3]:
with open(cfg['data_folder'] + cfg['train_data']) as f:
    train_data = [json.loads(l) for l in f]
#     for e in train_data:
#         del e['passages']
#         del e['subgraph']
with open(cfg['data_folder'] + cfg['dev_data']) as f:
    dev_data = [json.loads(l) for l in f]
#     for e in dev_data:
#         del e['passages']
#         del e['subgraph']
with open(cfg['data_folder'] + cfg['test_data']) as f:
    test_data = [json.loads(l) for l in f]
#     for e in test_data:
#         del e['passages']
#         del e['subgraph']

In [19]:
# Stastistic of the training+dev data pattern.
is_entity = lambda x: x.startswith('m.') or x.startswith('g.')

def _bfs(g, start):

    graph = nx.Graph()
    for s in g:
        for p in g[s]:
            for o in g[s][p]:
                graph.add_edge(s, o, data=p)
    graph = graph.to_undirected()
    shortest_path = nx.shortest_path(graph, start, '?x')
    relation = []
    for i in range(len(shortest_path) - 1):
        u, v = shortest_path[i], shortest_path[i + 1]
        data = graph.get_edge_data(u, v)['data']
        if i == 0:
            relation.append(u)
        relation.append(data)
        relation.append(v)

    return relation

def apply_mask(paths):
    max_len = 0
    for i in range(len(paths)):
        paths[i][0] = chr(ord('A') + i)
        for j in range(1, len(paths[i]), 2):
            paths[i][j] = ('r_%d^{(%d)}' % (i, j // 2 + 1))
        # for j in range(2, len(paths[i]) - 1, 2):
        #     paths[i][j] = ('e_%d^{(%d)}' % (i, j // 2))
        max_len = max(max_len, len(paths[i]))
    
    tables = [['' for _ in range(max_len)] for _ in range(len(paths))]
    for i in range(len(paths)):
        for j in range(max_len - 1, -1, -1):
            if (max_len - 1 - j) < len(paths[i]):
                tables[i][j] = paths[i][-(max_len - j)]
    for j in range(max_len):
        lens = defaultdict(int)
        for i in range(len(paths)):
            if tables[i][j] == '':
                continue
            lens[tables[i][j]] += 1
        for i in range(len(paths)):
            if lens[tables[i][j]] > 1:
                tables[i][j] = '|%s|' % tables[i][j]
    for i in range(len(paths)):
        tables[i] = ' '.join(tables[i]).strip()
    return '\n'.join(list(sorted(tables)))

def _analyse_pattern(pattern):
    pattern = pattern.replace('PREFIX ns: <http://rdf.freebase.com/ns/>', '').replace('SELECT DISTINCT ?x', '') \
        .replace('FILTER (!isLiteral(?x) OR lang(?x) = \'\' OR langMatches(lang(?x), \'en\'))', '')\
        .replace('FILTER (?x != ?c)', '')\
        .replace('{', '{\n')\
        .replace('}', '}\n')
    pattern = '\n'.join([e.strip().replace('ns:', '') for e in pattern.split('\n') if e.strip() and not e.strip().startswith('FILTER (?x !=') and not e.strip().startswith('WHERE') and e.strip() != '}'])

    result = -1
    order_by_pattern = 0
    filter_pattern = 1
    union_pattern = 2
    adj_map = dict()
    entities = set()
    for p in pattern.split('\n'):
        if p.startswith('ORDER BY'):
            result = order_by_pattern
            break
        if p.startswith('FILTER('):
            result = filter_pattern
            break
        if p.startswith('UNION'):
            result = union_pattern
            break
        tup = p.split()
        if len(tup) < 4:
            continue
        tup = tup[:4]
        s, p, o, _ = tup
        if is_entity(s):
            entities.add(s)
        if is_entity(o):
            entities.add(o)
        if s not in adj_map:
            adj_map[s] = dict()
        if p not in adj_map[s]:
            adj_map[s][p] = dict()
        if o not in adj_map[s][p]:
            adj_map[s][p][o] = 0
        if o not in adj_map:
            adj_map[o] = dict()
        if p not in adj_map[o]:
            adj_map[o][p] = dict()
        if s not in adj_map[o][p]:
            adj_map[o][p][s] = 1
    entities = list(entities)
    paths = []
    avg_len_path = 0
    total_len = 0
    for e in entities:
        try:
            path = _bfs(adj_map, e)
        except Exception as e2:
            continue
        paths.append(path)
        avg_len_path += (len(path) - 1) // 2
        total_len += 1
    return apply_mask(paths), result, (avg_len_path / total_len if total_len > 0 else -1)
    

pattern_cnt = defaultdict(int)
p_cnt = defaultdict(int)
avg_len_path = 0
total_len = 0
for i, e in tqdm(enumerate((train_data[0] + dev_data[0]))):
    # print(e['sparql'])
    # print(_analyse_pattern(e['sparql']))
    # print()
    try:
        masked, p, len_path = _analyse_pattern(e['sparql'])
        if len_path != -1:
            avg_len_path += len_path
            total_len += 1
        pattern_cnt[masked] += 1
        p_cnt[p] += 1
    except Exception as e2:
        print(e['sparql'])
        print(e2)
        print(i)
        continue

for l in list(sorted(pattern_cnt.items(), key=lambda x: x[1], reverse=True)):
    print()
    print(l[0])
    print('%.2f%%' % (100 * l[1] / len(train_data[0] + dev_data[0])))

print(avg_len_path / len(train_data[0] + dev_data[0]))
# for l in list(sorted(p_cnt.items(), key=lambda x: x[1], reverse=True)):
#     print()
#     print(l[0])
#     print(l[1], '%.2f%%' % (100 * l[1] / len(train_data[0] + dev_data[0])))

3098it [00:00, 14288.84it/s]

#MANUAL SPARQL
PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?dt
WHERE {
  ?e ns:government.election.winner  ns:m.0gzh ; # Abraham Lincoln
     ns:government.election.office  ns:m.060d2 ;  # President of the United States
     ns:government.election.election_year ?dt  .
}
division by zero
244
#MANUAL SPARQL
PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT ?x
WHERE {
  ?x ns:people.deceased_person.date_of_death ?d .
  FILTER (?x = ns:m.07pzc OR ?x = ns:m.01vz0g4)
}
ORDER BY ?d
LIMIT 1
division by zero
246
#MANUAL SPARQL
PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT  ?camp
WHERE {
  ns:m.0271_s ns:government.politician.election_campaigns ?camp . # Mitt Romney
  ?camp ns:government.election_campaign.election ?e . 
  ?e ns:government.election.election_year ?yr .
}
ORDER BY ?yr
LIMIT 1
division by zero
348
#MANUAL SPARQL
PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?dt
WHERE {
  ?e ns:government.election.winner  ns:m.03_js ; # John Adams
     ns:government.e




In [14]:
avg_entities = 0
for i, e in tqdm(enumerate((train_data[0] + dev_data[0]))):
    avg_entities += len(e['entities'])
avg_entities / len(train_data[0] + dev_data[0])

3098it [00:00, 2173992.60it/s]


1.2272433828276308

In [16]:
avg_len_relation_path = 0
total_relation_paths = 0
for i, e in tqdm(enumerate((train_data[0] + dev_data[0]))):
    for p in e['ground_truth_path']:
        avg_len_relation_path += (len(p) - 1) // 2
        total_relation_paths += 1
avg_len_relation_path / total_relation_paths

0it [00:00, ?it/s]


KeyError: 'ground_truth_path'

In [31]:
entities = set()
print(len(train_data[0] + dev_data[0] + test_data[0]))
for q in train_data[0] + dev_data[0] + test_data[0]:
    if 'm.0h5qcr7' in q['entities']:
        print(q['sparql'])
        break
    entities.update(q['entities'])

4737
PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?x
WHERE {
FILTER (?x != ns:m.0h5qcr7)
FILTER (!isLiteral(?x) OR lang(?x) = '' OR langMatches(lang(?x), 'en'))
ns:m.0h5qcr7 ns:film.film_character.portrayed_in_films ?y .
?y ns:film.performance.actor ?x .
}



In [20]:
base_path = '/data/huangxin/data_most'
facts = dict()
for file in tqdm(os.listdir(base_path)):
    file_path = os.path.join(base_path, file)
    with open(file_path) as f:
        triples = f.readlines()
    triples = list(map(lambda x: x.strip().split('\t')[:3], triples))
    triples = [list(map(lambda x: x.replace('<http://rdf.freebase.com/ns/', '').replace('>', '').replace(' .', ''), triple)) for triple in triples]
    triples = list(filter(lambda x: len(x) >= 3 and (not x[1].startswith('<')) and (x[0].startswith('m.') or x[0].startswith('g.')) and (x[2].startswith('m.') or x[2].startswith('g.')), triples))
    for s, p, o in triples:
        if s not in facts:
            facts[s] = dict()
        if p not in facts[s]:
            facts[s][p] = dict()
        if o not in facts[s][p]:
            facts[s][p][o] = 0
        if o not in facts:
            facts[o] = dict()
        if p not in facts[o]:
            facts[o][p] = dict()
        if s not in facts[o][p]:
            facts[o][p][s] = 1

 39%|███▉      | 43/110 [14:48<22:43, 20.35s/it]

KeyboardInterrupt: 

In [37]:
answers = set()
for q in train_data[0] + dev_data[0] + test_data[0]:
    answers.update(q['answers'])
print(answers)

{'m.0f6l2j', 'm.0jwr2v8', 'm.0bhjdsf', 'm.027lcnq', 'm.04vfk8v', 'm.03cjtxj', 'm.04vkj23', 'm.0bhjdl9', 'm.0gmdq9d', 'm.0267rp', 'm.0cnymms', 'm.04myq1', 'm.0488dx7', 'm.0jw2wy', 'm.04t_z0b', 'm.0c480qq', 'm.06mnpwf', 'm.04kd5d', 'm.0nqq_g', 'm.027h4z2', 'm.06y3b4', 'm.0135g77g', 'm.0d1jrn', 'm.04vzwg2', 'm.04vrrsp', 'm.0fygr7', 'm.04b42p5', 'm.0zv5njx', 'm.0_6mh8q', 'm.0bmwns', 'm.04v27b4', 'm.0488ndq', 'm.04v881h', 'm.0f4qmd', 'm.0_68y65', 'm.0488nl9', 'm.0y497_m', 'm.0257nv', 'm.09gp77', 'm.07nqppm', 'm.0dsnv4l', 'm.0zr2nkk', 'm.0h032j', 'm.0_7hpww', 'm.0zhv2fh', 'm.04b4g66', 'm.027kqw', 'm.0zl72s1', 'm.021_0p', 'm.0s5yg3', 'm.0288tbx', 'm.053d7b', 'm.0g3y9q', 'm.0n_hp', 'm.0488b42', 'm.04b4ktf', 'm.016r1v', 'm.0fb18', 'm.0488891', 'm.0zvpg7j', 'm.0cmn12', 'm.04vzwhg', 'm.0f0w6n', 'm.0zm0g6b', 'm.05jtyv', 'm.0488r6b', 'm.05qd9r', 'm.0zjfz_9', 'm.028d1l', 'm.04v883g', 'm.0bsqg4', 'm.03zb8b', '1995-01-20', 'm.04t46v', 'm.06w3k_d', 'm.04v1w6s', 'm.044dm9', 'm.0n4g1dj', 'm.06vwxkr', 'm.

In [25]:
avg_relations_len = []
for e in facts:
    avg_relations_len.append(len(e))
print(np.max(avg_relations_len))

12


In [40]:
recall = 0
wrong_entities = set()
for e in entities | answers:
    if e in facts:
        recall += 1
    else:
        wrong_entities.add(e)
print(recall / len(entities | answers))
print(list(wrong_entities)[0])

0.9928548148519568
1929-03-04


In [39]:
facts = load_json('datasets/webqsp/facts.json')

In [36]:
facts['m.0h5qcr7']

{'film.film_character.portrayed_in_films': {'m.applynewint571': 1}}

In [9]:
train_data = load_json(cfg['data_folder'] + cfg['train_data'])
dev_data = load_json(cfg['data_folder'] + cfg['dev_data'])
test_data = load_json(cfg['data_folder'] + cfg['test_data'])

train_raw_data = load_json(cfg['data_folder'] + 'WebQSP.train.json')
test_raw_data = load_json(cfg['data_folder'] + 'WebQSP.test.json')

In [227]:
with open(cfg['data_folder'] + 'train2.json') as f:
    train_data2 = [json.loads(l) for l in f]
with open(cfg['data_folder'] + 'dev2.json') as f:
    dev_data2 = [json.loads(l) for l in f]
with open(cfg['data_folder'] + 'test2.json') as f:
    test_data2 = [json.loads(l) for l in f]

In [116]:
train_raw_data['Questions'][0]

{'QuestionId': 'WebQTrn-0',
 'RawQuestion': 'what is the name of justin bieber brother?',
 'ProcessedQuestion': 'what is the name of justin bieber brother',
 'Parses': [{'ParseId': 'WebQTrn-0.P0',
   'AnnotatorId': 1,
   'AnnotatorComment': {'ParseQuality': 'Complete',
    'QuestionQuality': 'Good',
    'Confidence': 'Normal',
    'FreeFormComment': 'First-round parse verification'},
   'Sparql': "PREFIX ns: <http://rdf.freebase.com/ns/>\nSELECT DISTINCT ?x\nWHERE {\nFILTER (?x != ns:m.06w2sn5)\nFILTER (!isLiteral(?x) OR lang(?x) = '' OR langMatches(lang(?x), 'en'))\nns:m.06w2sn5 ns:people.person.sibling_s ?y .\n?y ns:people.sibling_relationship.sibling ?x .\n?x ns:people.person.gender ns:m.05zppz .\n}\n",
   'PotentialTopicEntityMention': 'justin bieber',
   'TopicEntityName': 'Justin Bieber',
   'TopicEntityMid': 'm.06w2sn5',
   'InferentialChain': ['people.person.sibling_s',
    'people.sibling_relationship.sibling'],
   'Constraints': [{'Operator': 'Equal',
     'ArgumentType': 'En

In [228]:
avg_recall = 0
qid2answers = dict()
for e in train_data2 + dev_data2 + test_data2:
    answers = set(map(lambda x: x['kb_id'].replace('<fb:', '').replace('>', ''), e['answers']))
    if len(answers) == 0:
        avg_recall += 1
        continue
    recall = sum(answer in facts for answer in answers)
    recall /= len(answers)
    avg_recall += recall
    qid2answers[e['id']] = answers
print(avg_recall / len(train_data + dev_data + test_data))

0.9705600800285678


In [468]:
def extract_tuples(sparql):
    def is_entity(mid):
        return mid.startswith('ns:m.') or mid.startswith('ns:g.') or mid.startswith('?')
    
    def clean_entity(mid):
        mid = mid.strip()
        if mid.startswith('ns:'):
            return ''.join(mid[3:]) 
        return mid
    
    sparql = sparql.replace('{', '\n{\n').replace('}', '\n}\n').replace('(', '\n(\n').replace(')', '\n)\n').replace('||', '\n||\n')
    res = set()
    for line in sparql.split('\n'):
        line = line.strip()
        line_spt = line.split()
        if not line_spt or len(line_spt) < 3:
            continue
        if is_entity(line_spt[0]) and is_entity(line_spt[2]) and line_spt[1].strip()[3:].strip():
            res.add((clean_entity(line_spt[0]), ''.join(line_spt[1].strip()[3:]).strip(), clean_entity(line_spt[2])))
    return res
    
sparqls = [(e['QuestionId'], e['Parses'][0]['Sparql']) for e in (train_raw_data['Questions'] + test_raw_data['Questions'])]
# print(sparqls[0])
ground_truth_tuples = dict()
for e in sparqls:
    qid = e[0]
    sparql = e[1]
    answer = None
    for line in sparql.split('\n'):
        if line.strip().startswith('SELECT DISTINCT'):
            answer = line.split()[2]
            break
    ground_truth_tuples[qid] = {
        'tuples': extract_tuples(sparql),
        'answer': answer
    }
    
print(ground_truth_tuples['WebQTrn-2349'])

{'tuples': {('?y', 'location.location_symbol_relationship.Kind_of_symbol', '?k'), ('m.04rrx', 'government.governmental_jurisdiction.official_symbols', '?y'), ('?y', 'location.location_symbol_relationship.symbol', '?x')}, 'answer': '?x'}


## Build KB facts

In [504]:
base_dir = '/home/hxssg1124/Downloads/webqsp/freebase_2hops/stagg.neighborhoods'
facts = load_json('datasets/webqsp/facts.json')
# facts = dict()
# for file in tqdm(os.listdir(base_dir)):
#     qid = file.replace('.nxhd', '')
#     with open(os.path.join(base_dir, file)) as f:
#         lines = f.readlines()
#     for line in lines:
#         line_spt = line.split(maxsplit=3)
#         s = line_spt[0]
#         rel = line_spt[1]
#         o = line_spt[2]
#         if not rel.startswith('<fb:') or not s.startswith('<fb:') or not o.startswith('<fb:'):
#             continue
#         s = s.replace('<fb:', '').replace('>', '')
#         rel = rel.replace('<fb:', '').replace('>', '')
#         o = o.replace('<fb:', '').replace('>', '')
#         if rel.startswith('m.') or rel.startswith('g.'):
#             continue
#         if (not (s.startswith('m.') or s.startswith('g.'))) or (not (o.startswith('m.') or o.startswith('g.'))):
#             continue
#         if s not in facts:
#             facts[s] = dict()
#         if rel not in facts[s]:
#             facts[s][rel] = dict()
#         if o not in facts[s][rel]:
#             facts[s][rel][o] = 0
#         if o not in facts:
#             facts[o] = dict()
#         if rel not in facts[o]:
#             facts[o][rel] = dict()
#         if s not in facts[o][rel]:
#             facts[o][rel][s] = 1 

In [505]:
facts2 = load_json('datasets/complexwebq/' + 'all_facts_all_new2.json')

In [509]:
# merge two facts:
for k in tqdm(facts2):
    if k not in facts:
        facts[k] = facts2[k]
    else:
        for k2 in facts2[k]:
            if k2 not in facts[k]:
                facts[k][k2] = facts2[k][k2]




  0%|          | 0/7624604 [00:00<?, ?it/s][A[A[A


  0%|          | 29110/7624604 [00:00<00:26, 291094.73it/s][A[A[A


  1%|          | 80041/7624604 [00:00<00:22, 334028.13it/s][A[A[A


  2%|▏         | 132465/7624604 [00:00<00:19, 374827.52it/s][A[A[A


  2%|▏         | 187383/7624604 [00:00<00:17, 414285.10it/s][A[A[A


  3%|▎         | 243309/7624604 [00:00<00:16, 449218.80it/s][A[A[A


  4%|▍         | 300128/7624604 [00:00<00:15, 479327.19it/s][A[A[A


  5%|▍         | 356458/7624604 [00:00<00:14, 501766.34it/s][A[A[A


  5%|▌         | 413492/7624604 [00:00<00:13, 520540.33it/s][A[A[A


  6%|▌         | 470254/7624604 [00:00<00:13, 533817.68it/s][A[A[A


  7%|▋         | 526678/7624604 [00:01<00:13, 542593.65it/s][A[A[A


  8%|▊         | 583852/7624604 [00:01<00:12, 551019.32it/s][A[A[A


  8%|▊         | 640309/7624604 [00:01<00:12, 555011.59it/s][A[A[A


  9%|▉         | 696179/7624604 [00:01<00:12, 556110.77it/s][A[A[A


 10%|▉  

In [511]:
save_json(facts, 'datasets/webqsp/facts_new2.json')

In [469]:
def is_real_entity(mid):
    return mid.startswith('m.') or mid.startswith('g.')


qid2gt = dict()
for q_id, sparql in ground_truth_tuples.items():
    tuples = sparql['tuples']
    answer = sparql['answer']
    g = nx.Graph()
    topic_entities = set()
    for t in tuples:
        if is_real_entity(t[0]):
            topic_entities.add(t[0])
        if is_real_entity(t[2]):
            topic_entities.add(t[2])
        g.add_edge(t[0], t[2], data=t[1])
    g = g.to_undirected()
    all_paths = []
    for topic_entity in topic_entities:
        paths = []
        try:
            path = nx.shortest_path(g, topic_entity, answer)
            for i in range(len(path) - 1):
                s = path[i]
                o = path[i+1]
                p = g.get_edge_data(s, o)['data']
                if i == 0:
                    paths.append(s)
                paths.append(p)
                paths.append(o)
        except Exception as e:
            continue
        all_paths.append(paths)
    qid2gt[q_id] = {
        'path': all_paths,
        'entities': topic_entities,
    }
    
qid2gt['WebQTrn-2349']

{'path': [['m.04rrx',
   'government.governmental_jurisdiction.official_symbols',
   '?y',
   'location.location_symbol_relationship.symbol',
   '?x']],
 'entities': {'m.04rrx'}}

In [None]:
for q_id, sparql in ground_truth_tuples.items():
    print(sparql)
    print('-----------------------')

In [99]:
list(ground_truth_rels.items())[:10]

[('WebQTrn-0',
  {'people.person.gender',
   'people.person.sibling_s',
   'people.sibling_relationship.sibling'}),
 ('WebQTrn-1',
  {'film.actor.film', 'film.performance.character', 'film.performance.film'}),
 ('WebQTrn-3',
  {'base.biblioness.bibs_location.loc_type', 'location.location.containedby'}),
 ('WebQTrn-4', {'location.country.currency_used'}),
 ('WebQTrn-5',
  {'film.actor.film', 'film.performance.character', 'film.performance.film'}),
 ('WebQTrn-6',
  {'sports.pro_athlete.teams',
   'sports.sports_team_roster.from',
   'sports.sports_team_roster.team',
   'sports.sports_team_roster.to'}),
 ('WebQTrn-7', {'sports.sports_team.location'}),
 ('WebQTrn-8', {'people.person.place_of_birth'}),
 ('WebQTrn-9', {'people.person.date_of_birth'}),
 ('WebQTrn-11', {'location.location.time_zones'})]

In [110]:
all_rels = set()
for q in train_data:
    qid = q['id']
    rels = ground_truth_rels[qid]
    all_rels.update(rels)
dev_test_rels = set()
for q in dev_data + test_data:
    qid = q['id']
    rels = ground_truth_rels[qid]
    dev_test_rels.update(rels)
print(len(all_rels | dev_test_rels))

680


In [499]:
base_dir = '/home/hxssg1124/Downloads/webqsp/freebase_2hops/stagg.neighborhoods'
avg_recall = 0
avg_recall2 = 0
recall_cnt2 = 0
total = 0
idx_new_entity = 0
qid2real_int_entities_map = dict()
qid2rel_chain_map = dict()
block_rels = {'Equals', 'GreaterThan', 'GreaterThanOrEqual', 'LessThan', 'LessThanOrEqual', 'NotEquals'}
T = 2
all_used_rels = set()
for q_id, gt in list(qid2gt.items()):
#     if '2349' not in q_id:
#         continue
    gt_rels = ground_truth_rels[q_id]
    gt_tup = ground_truth_tuples[q_id]
    symbolic_answer = gt_tup['answer']
    entitites = gt['entities']
    paths = gt['path']
    if q_id not in qid2answers or not qid2answers[q_id]:
        continue
    answers = qid2answers[q_id]
    # Map from virtual ?y intermediate entities to real entities.
    real_int_entities_map = dict()
    o2new_int_entities = dict()
    for hop in range(1, 3):
        spo_list = list()
        for path in paths:
            if hop * 2 + 1 > len(path):
                continue
            spo_list.append((path[(hop - 1) * 2], path[2*hop - 1], path[2*hop]))
        merged_sp_map = defaultdict(list)
        for s, p, o in spo_list:
            if s in real_int_entities_map:
                for real_s in real_int_entities_map[s]:
                    merged_sp_map[o].append((real_s, p))
            else:
                merged_sp_map[o].append((s, p))
        for o in merged_sp_map:
            int_entities = set() # Find intermediate entities
            sp_list = merged_sp_map[o]
            sp_list = [(o2new_int_entities[s] if s in o2new_int_entities else s, p) for s, p in sp_list]
            
            # Initialize facts to update with original facts
            new_facts = dict()
            for s, p in sp_list:
                if not s.startswith('m.') and not s.startswith('g.'):
                    continue
                if s in facts:
                    new_facts[s] = facts[s]
                    if p in facts[s]:
                        new_facts[s][p] = facts[s][p]
            
            # If the symbolic object is the final answer, simply replace symbolic answers with actual answers.
            if o == symbolic_answer:
                for s, p in sp_list:
                    if s not in new_facts:
                        new_facts[s] = dict()
                    if p not in new_facts[s]:
                        new_facts[s][p] = dict()
                    for answer in answers:
                        new_facts[s][p][answer] = 0
                        if answer not in new_facts:
                            new_facts[answer] = dict()
                        if p not in new_facts[answer]:
                            new_facts[answer][p] = dict()
                        if s not in new_facts[answer][p]:
                            new_facts[answer][p][s] = 1
            # However, if the symbolic object is the intermediate entity, we may need to create a new bridge entity.
            else:
                # First check if s, p all are in new_facts. And all s-p have mutal intermediate entities.
                need_create_new_int_entities = False
                tmp_int_entities = None
                for s, p in sp_list:
                    if s not in new_facts:
                        new_facts[s] = dict()
                        need_create_new_int_entities = True
                    if p not in new_facts[s]:
                        new_facts[s][p] = dict()
                        need_create_new_int_entities = True
                    if tmp_int_entities is None:
                        tmp_int_entities = set(new_facts[s][p].keys())
                    else:
                        tmp_int_entities = set(new_facts[s][p].keys()) & tmp_int_entities
                if not need_create_new_int_entities:
                    need_create_new_int_entities = not tmp_int_entities
                if need_create_new_int_entities:
                    if o not in o2new_int_entities:
                        new_int_entity = 'm.applynewint%d' % idx_new_entity
                        o2new_int_entities[o] = new_int_entity
                        idx_new_entity += 1
                    new_o = o2new_int_entities[o]
                    if new_o not in new_facts:
                        new_facts[new_o] = dict()
                    for s, p in sp_list:
                        if p not in new_facts[new_o]:
                            new_facts[new_o][p] = dict()
                        new_facts[s][p][new_o] = 1
                        new_facts[new_o][p][s] = 0
            facts.update(new_facts)
            for s, p in sp_list:
                if not int_entities:
                    try:
                        int_entities = set(new_facts[s][p].keys())
                    except Exception as e:
                        raise Exception(e)
                else:
                    int_entities = set(new_facts[s][p].keys()) & int_entities
#             print(hop, sp_list)
#             print(q)
#             print(hop, o, '--->', int_entities)
            real_int_entities_map[o] = set(answers) if o == symbolic_answer else int_entities
    qid2real_int_entities_map[q_id] = real_int_entities_map
    
    rel_chain_map = dict()
    rel_chain_entity_map = dict()
    # Fill the rel chain data.
    for hop in range(1, T+1):
        rel_chain_with_hop_map =  dict()
        rel_chain_entity_with_hop_map = dict()
        # Find the following relations based on the previous entity-relation chain.
        local_entity_rel_chain_map = defaultdict(set)
        # Find the intermediate entities based on the topic entity with the given hop
        local_entity_inter_entities_map = defaultdict(set)
        for gt in paths:
            topic_entity = gt[0]
            target_relation = gt[2*hop-1] if (2*hop-1 < len(gt)) else 'EOD'
            prev_relations = [(gt[idx] if idx < len(gt) else 'EOD') for idx in range(1, 2*hop-1, 2)]
#             if 'EOD' in prev_relations:
#                 continue
            local_entity_rel_chain_map[topic_entity].add(tuple(prev_relations + [target_relation]))
            if 2*(hop-1) < len(gt):
                int_entities = gt[2*(hop-1)]
                if int_entities.startswith('?'):
                    int_entities = real_int_entities_map[int_entities]
                if isinstance(int_entities, str):
                    int_entities = {int_entities}
                for int_entity in int_entities:
                    local_entity_inter_entities_map[topic_entity].add(tuple([int_entity] + [gt_e for idx_e, gt_e in enumerate(gt[:2*(hop-1)]) if idx_e % 2 == 1]))
#         print('hop:', hop)
#         print('local_entity_rel_chain_map', local_entity_rel_chain_map)
#         print('local_entity_inter_entities_map', local_entity_inter_entities_map)
#         print('entitysssss:', entities)
        for k, v in local_entity_rel_chain_map.items():
#             print('entity:', entity)
            rel_chain_with_hop_entity_map = dict()
            ground_truth = list(v)
#             print('asdasdasd',ground_truth)
            new_cand_rels = set()
            for inter_entity_prev_relations in local_entity_inter_entities_map[k]:
                inter_entity = inter_entity_prev_relations[0]
                prev_relations = inter_entity_prev_relations[1:]
                temp_rels = set(facts[inter_entity].keys())
                temp_rels = list(filter(lambda x: x not in block_rels, temp_rels))
                temp_rels.append('EOD')
                all_used_rels.update(temp_rels)
                temp_rels = set(map(lambda x: tuple(list(prev_relations) + [x]), temp_rels))
                new_cand_rels.update(temp_rels)
            new_cand_rels = list(new_cand_rels)
            
            rel_chain_with_hop_entity_map['ground_truth'] = ground_truth
            rel_chain_with_hop_entity_map['cands'] = list(new_cand_rels)
            rel_chain_with_hop_map[k] = rel_chain_with_hop_entity_map
            
            recall = len(set(new_cand_rels) & set([tuple(e) for e in ground_truth])) / len(ground_truth)
#             if recall < 1:
#                 print(idx)
#                 print(new_cand_rels)
#                 print('-----------------------------------')
#                 print(ground_truth)
#                 print()
#                 asdsdsds
            avg_recall2 += recall
            recall_cnt2 += 1
        rel_chain_map[str(hop)] = rel_chain_with_hop_map
        
    qid2rel_chain_map[q_id] = rel_chain_map
#     print(rel_chain_map)
#     print(path)
#     print(avg_recall2 / recall_cnt2)
#     break
    for entity in entitites:
        if entity in facts:
            cand_rels.update(set(facts[entity].keys()))
    recall = (len(gt_rels & cand_rels) / len(gt_rels)) if len(gt_rels) > 0 else 1
    recall = sum(e in facts for e in entitites) / len(entitites) if len(entitites) > 0 else 0
    avg_recall += recall
    total += 1
#     print(q_id)
#     print(entitites)
#     print(list(e for e in sparqls if e[0] == q_id)[0][1])
#     print(cand_rels)
#     print(gt_rels)
#     break
print(avg_recall / total)
print(avg_recall2 / recall_cnt2)

0.9993617021276596
0.9995646116335771


In [494]:
sparqls_map = {e[0]: e[1] for e in sparqls}

In [502]:
for idx, q in tqdm(enumerate(train_data2 + dev_data2 + test_data2)):
    if 'passages' in q:
        del q['passages']
    if 'subgraph' in q:
        del q['subgraph']
    answers = qid2answers[q['id']] if q['id'] in qid2answers else []
    rel_chain_map = qid2rel_chain_map[q['id']] if q['id'] in qid2rel_chain_map else dict()
    gt = qid2gt[q['id']]
    entities = gt['entities']
    sparql = sparqls_map[q['id']]
    q['ID'] = q['id']
    q['entities'] = list(entities)
    q['sparql'] = sparql
    q['answers'] = list(answers)
    q['rel_chain_map'] = rel_chain_map




4737it [00:00, 122077.12it/s]


In [503]:
save_json(train_data2, 'datasets/webqsp/train_new.json')
save_json(dev_data2, 'datasets/webqsp/dev_new.json')
save_json(test_data2, 'datasets/webqsp/test_new.json')

In [491]:
facts['m.01j_cy']

{'education.education.institution': {'m.0n1mx4n': 1, 'm.02wn3kr': 1},
 'sports.school_sports_team.school': {'m.0ft5vs': 0}}

In [389]:
for q_id, v in list(qid2real_int_entities_map.items())[-200:]:
    print(q_id, '----->', v)

WebQTest-1791 -----> {'?x': {'1440'}}
WebQTest-1792 -----> {'?x': {'m.07ssc', 'm.06q1r', 'm.0hzc9md'}}
WebQTest-1793 -----> {'?y': {'m.0z3vl6r', 'm.0z3vp21', 'm.04gd4np', 'm.0z9nl0h', 'm.0z9p26m'}, '?x': {'m.0h9814q', 'm.0755sb'}}
WebQTest-1794 -----> {'?x': {'m.02dtg'}}
WebQTest-1795 -----> {'?x': {'m.0j51lwj', 'm.0j51lwr'}}
WebQTest-1796 -----> {'?x': {'m.03b12'}}
WebQTest-1797 -----> {'?x': {'m.09c7w0', 'm.02fp48', 'm.020d5', 'm.02_qg_', 'm.07t2k'}}
WebQTest-1799 -----> {'?x': {'m.0cbd2', 'm.02hv44_', 'm.02xhgwq'}}
WebQTest-1800 -----> {'?x': {'m.018n8'}}
WebQTest-1801 -----> {'?y': {'m.0k39m5'}, '?x': {'m.0cp8vl'}}
WebQTest-1802 -----> {'?x': {'m.02lcqs', 'm.02fqwt', 'm.027wjl3', 'm.02lcrv', 'm.027wj2_', 'm.02lctm', 'm.02hcv8', 'm.02hczc', 'm.042g7t'}}
WebQTest-1803 -----> {'?x': {'m.015smg', 'm.03qtfw8'}}
WebQTest-1804 -----> {'?x': {'m.03x42', 'm.070zw', 'm.02jcw', 'm.0h407', 'm.01v0g', 'm.0ct8m', 'm.083tk', 'm.02h40lc'}}
WebQTest-1805 -----> {'?x': {'m.01428y', 'm.04ygk0'}}
WebQ

In [412]:
for q_id, v in list(qid2rel_chain_map.items())[:100]:
    print(q_id, '----->', v)

WebQTrn-0 -----> {1: {'m.05zppz': {'ground_truth': [('people.person.gender',)], 'cands': [('fictional_universe.fictional_character.gender',), ('EOD',), ('people.person.gender',)]}, 'm.06w2sn5': {'ground_truth': [('people.person.sibling_s',)], 'cands': [('people.person.education',), ('music.group_membership.member',), ('film.actor.film',), ('award.award_winner.awards_won',), ('tv.tv_guest_role.actor',), ('tv.tv_actor.guest_roles',), ('people.sibling_relationship.sibling',), ('music.composer.compositions',), ('EOD',), ('film.producer.film',), ('music.artist.origin',), ('people.person.children',), ('celebrities.celebrity.sexual_relationships',), ('common.topic.notable_types',), ('broadcast.artist.content',), ('music.group_member.membership',), ('music.artist.genre',), ('common.topic.notable_for',), ('influence.influence_node.influenced',), ('people.person.nationality',), ('music.artist.track_contributions',), ('base.popstra.dated.participant',), ('influence.influence_node.influenced_by',)

In [384]:
print([e[1] for e in sparqls if '228' in e[0]][0])

PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?x
WHERE {
FILTER (?x != ns:m.082db)
FILTER (!isLiteral(?x) OR lang(?x) = '' OR langMatches(lang(?x), 'en'))
ns:m.082db ns:music.composer.compositions ?x .
}



In [484]:
facts['m.0gzh'].keys()

dict_keys(['government.government_position_held.office_holder', 'government.politician.government_positions_held', 'people.person.spouse_s', 'people.deceased_person.place_of_burial', 'influence.influence_node.influenced_by', 'government.us_president.vice_president', 'book.author.works_written', 'medicine.notable_person_with_medical_condition.condition', 'user.alexander.misc.murdered_person.murder_method', 'government.politician.party'])

In [500]:
# save all relations
all_rels = set()
for k in facts:
    if k == 'm.01j_cy':
        print('wow')
        print(facts[k])
    for r in facts[k]:
        all_rels.add(r)
all_rels.add('EOD')
all_rels.update(all_used_rels)
all_rels = list(sorted(all_rels))
with open('datasets/webqsp/relations.txt', 'w') as f:
    for r in all_rels:
        f.writelines(r + '\n')

wow
{'education.education.institution': {'m.0n1mx4n': 1, 'm.02wn3kr': 1}}


In [None]:
qid2answers['WebQTrn-228']

In [None]:
facts['m.01y2hn6']['common.topic.notable_types']

In [486]:
save_json(facts, 'datasets/webqsp/facts_new.json')

In [None]:
facts['m.09l3p']['film.actor.film']

In [66]:
def find_ground_truth(sparql):
    base_url = 'http://10.2.0.27:8891/sparql'
    params = {
        'default-graph-uri:': None,
        'query': sparql,
        'format': 'application/sparql-results+json',
        'timeout': 0,
    }
    res = r.get(base_url, params=params)
    return json.loads(res.content)['results']['bindings']

answer_dict = dict()
for e in tqdm(train_raw_data['Questions'][248:]):
    answer_dict[e['QuestionId']] = find_ground_truth(e['Parses'][0]['Sparql'])



  0%|          | 0/2850 [00:00<?, ?it/s][A[A

  0%|          | 6/2850 [00:00<00:58, 48.37it/s][A[A

  0%|          | 8/2850 [00:00<01:41, 28.02it/s][A[A

  0%|          | 10/2850 [00:00<02:34, 18.35it/s][A[A


JSONDecodeError: Expecting value: line 1 column 1 (char 0)

## Find inference chain

In [4]:
train_sp = load_json(cfg['data_folder'] + 'WebQSP.train.json')
test_sp = load_json(cfg['data_folder'] + 'WebQSP.test.json')

In [5]:
train_chain = {e['QuestionId']: 
             [{'chain': p['InferentialChain'], 'topic': p['TopicEntityMid'], 'answers': [a['AnswerArgument'] for a in p['Answers']]} for p in e['Parses']]
             for e in train_sp['Questions']}
test_chain = {e['QuestionId']: 
             [{'chain': p['InferentialChain'], 'topic': p['TopicEntityMid'], 'answers': [a['AnswerArgument'] for a in p['Answers']]} for p in e['Parses']]
             for e in test_sp['Questions']}

In [127]:
avg_recall = 0.0
wrong = 0
for q in train_data + dev_data + test_data:
    q_c = train_chain[q['id']] if q['id'] in train_chain else test_chain[q['id']]
    act_ans = set(map(lambda x: x['kb_id'].replace('<fb:', '').replace('>', ''), q['answers']))
    q_c_ans = set()
    for c in q_c:
        q_c_ans.update(c['answers'])
    q['chain'] = q_c
    q['rel_path'] = list(set(tuple(c['chain']) for c in q_c if c['chain']))
#     for c in q_c:
#         if 'chain' not in c or not c['chain']:
#             wrong += 1
#             q_id = q['id']
#             with open('/home/hxssg1124/Developer/GraftNet/preprocessing/freebase_2hops/stagg.neighborhoods/%s.nxhd' % q_id) as f:
#                 facts = f.readlines()
#             kb = defaultdict(dict)
#             for fact in facts:
#                 try:
#                     s, p, o = fact.strip().split('\t')
#                 except Exception as e:
#                     print(fact)
#                 s = s.replace('<fb:', '').replace('>', '')
#                 p = p.replace('<fb:', '').replace('>', '')
#                 o = o.replace('<fb:', '').replace('>', '')
#                 kb[s][o] = p
#                 kb[o][s] = p
#             topic = c['topic']
#             answers = set(c['answers'])
#             print(topic)
#             hop1entiites = set(kb[topic].keys())
#             rel_chain = []
#             for answer in answers:
#                 if answers in hop1entiites:
#                     rel_chain.append([kb[topic][answer]])
#             for hop1entity in hop1entiites:
#                 hop2entities = set(kb[hop1entity].keys())
#                 for answer in answers:
#                     if answers in hop1entiites:
#                         rel_chain.append([kb[topic][hop1entity], kb[hop1entity][answer]])
#             if wrong > 2:
#                 print(rel_chain)
#             break
#     if wrong > 2:
#         break
    #q['rel_path'] = list(set(tuple(c['chain']) for c in q_c))
print(wrong)

0


In [33]:
wrong = 0
for q in train_data:
    entities = set(map(lambda x: x['kb_id'].replace('<fb:', '').replace('>', ''), q['entities']))
    chain_topic = set(map(lambda x: x['topic'], q['chain']))
    if chain_topic != entities:
        wrong += 1
        if wrong > 20:
            print(q)
            print(entities)
            print(chain_topic)
            break
wrong

{'question': 'what bible does the catholic church follow', 'answers': [{'kb_id': '<fb:m.05crg>', 'text': 'New Testament'}, {'kb_id': '<fb:m.01dfl>', 'text': 'Book of Nehemiah'}, {'kb_id': '<fb:m.05ld9>', 'text': 'Old Testament'}, {'kb_id': '<fb:m.015j7>', 'text': 'The Bible'}], 'entities': [{'kb_id': '<fb:m.02vxy_>', 'text': '<fb:m.02vxy_>'}], 'id': 'WebQTrn-138', 'chain': [{'chain': ['religion.religion.texts'], 'topic': 'm.0c8wxp', 'answers': ['m.015j7', 'm.01dfl', 'm.05crg', 'm.05ld9']}], 'rel_path': [['religion.religion.texts']]}
{'m.02vxy_'}
{'m.0c8wxp'}


21

In [40]:
train_sp['Questions'][106]

{'QuestionId': 'WebQTrn-138',
 'RawQuestion': 'what bible does the catholic church follow?',
 'ProcessedQuestion': 'what bible does the catholic church follow',
 'Parses': [{'ParseId': 'WebQTrn-138.P0',
   'AnnotatorId': 2,
   'AnnotatorComment': {'ParseQuality': 'Complete',
    'QuestionQuality': 'Good',
    'Confidence': 'Low',
    'FreeFormComment': '?'},
   'Sparql': "PREFIX ns: <http://rdf.freebase.com/ns/>\nSELECT DISTINCT ?x\nWHERE {\nFILTER (?x != ns:m.0c8wxp)\nFILTER (!isLiteral(?x) OR lang(?x) = '' OR langMatches(lang(?x), 'en'))\nns:m.0c8wxp ns:religion.religion.texts ?x .\n}\n",
   'PotentialTopicEntityMention': 'catholic',
   'TopicEntityName': 'Catholicism',
   'TopicEntityMid': 'm.0c8wxp',
   'InferentialChain': ['religion.religion.texts'],
   'Constraints': [],
   'Time': None,
   'Order': None,
   'Answers': [{'AnswerType': 'Entity',
     'AnswerArgument': 'm.015j7',
     'EntityName': 'The Bible'},
    {'AnswerType': 'Entity',
     'AnswerArgument': 'm.01dfl',
     'E

In [131]:
save_json(train_data, cfg['data_folder'] + cfg['train_data'])
save_json(dev_data, cfg['data_folder'] + cfg['dev_data'])
save_json(test_data, cfg['data_folder'] + cfg['test_data'])

In [135]:
rels = set()
for q in train_data + dev_data + test_data:
    for p in q['rel_path']:
        for r in p:
            rels.add(r)
res = []
with open('datasets/webqsp/relations.txt') as f:
    lines = f.readlines()
    for line in lines:
        line = line.strip()
        line = line.replace('<fb:', '').replace('>', '')
        res.append(line + '\n')
        rels.add(line)
with open('datasets/webqsp/relations_new.txt', 'w') as f:
    for r in rels:
        f.writelines(r + '\n')

In [513]:
train_data2[0]

{'question': 'where is the denver broncos stadium located',
 'answers': ['m.02hxv8'],
 'entities': ['m.0289q'],
 'id': 'WebQTrn-1994',
 'sparql': "PREFIX ns: <http://rdf.freebase.com/ns/>\nSELECT DISTINCT ?x\nWHERE {\nFILTER (?x != ns:m.0289q)\nFILTER (!isLiteral(?x) OR lang(?x) = '' OR langMatches(lang(?x), 'en'))\nns:m.0289q ns:sports.sports_team.arena_stadium ?x .\n}\n",
 'rel_chain_map': {'1': {'m.0289q': {'ground_truth': [('sports.sports_team.arena_stadium',)],
    'cands': [('sports.sports_team.arena_stadium',),
     ('EOD',),
     ('sports.sports_team_coach_tenure.team',)]}},
  '2': {'m.0289q': {'ground_truth': [('sports.sports_team.arena_stadium',
      'EOD')],
    'cands': [('sports.sports_team.arena_stadium',
      'sports.sports_team.arena_stadium'),
     ('sports.sports_team.arena_stadium', 'EOD')]}}},
 'ID': 'WebQTrn-1994'}

In [87]:
for e in train_sp['Questions']:
    if len(e['Parses']) > 4:
        print(e)
        print(e['RawQuestion'])
        print(e['Parses'][2])
        break
len(train_sp['Questions'][427]['Parses'])

{'QuestionId': 'WebQTrn-587', 'RawQuestion': 'where are you if you re in zagreb?', 'ProcessedQuestion': 'where are you if you re in zagreb', 'Parses': [{'ParseId': 'WebQTrn-587.P0', 'AnnotatorId': 1, 'AnnotatorComment': {'ParseQuality': 'Complete', 'QuestionQuality': 'Good', 'Confidence': 'Normal', 'FreeFormComment': 'First-round parse verification'}, 'Sparql': "PREFIX ns: <http://rdf.freebase.com/ns/>\nSELECT DISTINCT ?x\nWHERE {\nFILTER (?x != ns:m.0fhzy)\nFILTER (!isLiteral(?x) OR lang(?x) = '' OR langMatches(lang(?x), 'en'))\nns:m.0fhzy ns:base.aareas.schema.administrative_area.administrative_parent ?x .\n}\n", 'PotentialTopicEntityMention': 'zagreb', 'TopicEntityName': 'Zagreb', 'TopicEntityMid': 'm.0fhzy', 'InferentialChain': ['base.aareas.schema.administrative_area.administrative_parent'], 'Constraints': [], 'Time': None, 'Order': None, 'Answers': [{'AnswerType': 'Entity', 'AnswerArgument': 'm.01pj7', 'EntityName': 'Croatia'}]}, {'ParseId': 'WebQTrn-587.P1', 'AnnotatorId': 1, 'A

1