In [1]:
import os
import json
import argparse
import glob


import multiprocessing
import time
import math
import networkx as nx
import os
import random
from tkinter import ALL
import numpy as np
import pandas as pd
from tqdm import tqdm
from copy import deepcopy
from func_timeout import func_set_timeout, FunctionTimedOut





In [2]:
# Sentinel "relation" appended to the end of every relation path (see
# generate_data_list / run_negative_sampling) to mark that no further hop is taken.
# It is also offered as a negative candidate at intermediate steps.
END_REL = "END OF HOP"


In [3]:
from utils import load_jsonl, dump_jsonl
from knowledge_graph.knowledge_graph import KnowledgeGraph
from knowledge_graph.knowledge_graph_cache import KnowledgeGraphCache
from knowledge_graph.knowledge_graph_freebase import KnowledgeGraphFreebase
from config import cfg



In [4]:
# Load WebQSP
def load_webqsp():
    """Convert the raw WebQSP JSON into one JSON-lines record per usable parse.

    Reads ``cfg.preprocessing["step0"]["load_data_path"]`` (a JSON object whose
    ``"Questions"`` list holds questions with ``"Parses"``) and writes records of
    the form ``{"question", "topic_entities", "answers"}`` to
    ``cfg.preprocessing["step0"]["dump_data_path"]``, creating
    ``dump_data_folder`` first if needed.

    A parse is kept only if it has a topic entity and at least one answer whose
    ``AnswerType`` is ``"Entity"``; answers of any other type are dropped.
    """
    step_cfg = cfg.preprocessing["step0"]
    load_data_path = step_cfg["load_data_path"]
    dump_data_path = step_cfg["dump_data_path"]
    folder_path = step_cfg["dump_data_folder"]

    os.makedirs(folder_path, exist_ok=True)

    # Read and close the source before writing the output, rather than nesting
    # the write inside the read's ``with`` (which also shadowed the handle name).
    with open(load_data_path, "r", encoding="utf-8") as src:
        raw_dataset = json.load(src)

    data_list = []
    for json_obj in raw_dataset["Questions"]:
        question = json_obj["ProcessedQuestion"]
        for parse in json_obj["Parses"]:
            topic_entity = parse.get("TopicEntityMid")
            # A parse without a topic entity would otherwise yield [None], which
            # downstream steps pass to the KG as if it were an entity id.
            if not topic_entity:
                continue
            answers = [
                answer_json_obj["AnswerArgument"]
                for answer_json_obj in parse["Answers"]
                if answer_json_obj["AnswerType"] == "Entity"
            ]
            if not answers:
                continue
            data_list.append({
                "question": question,
                "topic_entities": [topic_entity],
                "answers": answers,
            })

    with open(dump_data_path, "w") as dst:
        for record in data_list:
            dst.write(json.dumps(record) + "\n")

In [5]:
@func_set_timeout(10)
def generate_paths(item, kg: KnowledgeGraphFreebase, pair_max: int = 20, path_max: int = 100):
    """Collect one-hop and two-hop relation paths linking topic entities to answers.

    Args:
        item: record with ``"topic_entities"`` and ``"answers"`` lists.
        kg: knowledge-graph client used for the path searches.
        pair_max: currently unused; kept for signature compatibility.
        path_max: maximum number of paths returned.

    Returns:
        At most ``path_max`` paths in discovery order. For each (topic entity,
        answer) pair, its one-hop results precede its two-hop results.

    The call is wrapped in a 10 s ``func_set_timeout``; callers catch
    ``FunctionTimedOut``.
    """
    paths = []
    for src in item['topic_entities']:
        for tgt in item['answers']:
            # Once path_max paths are collected, further searches cannot change
            # the returned prefix, so stop querying the KG (each query eats into
            # the 10 s timeout budget).
            if len(paths) >= path_max:
                return paths[:path_max]
            # NOTE(review): "relaiotn" presumably matches the method's spelling
            # on KnowledgeGraphFreebase — verify before renaming.
            paths.extend(kg.search_one_hop_relaiotn(src, tgt))
            paths.extend(kg.search_two_hop_relation(src, tgt))
    return paths[:path_max]

In [6]:
def run_search_to_get_path():
    """Step 1: search KG paths for every step-0 record and dump ``[item, paths]`` lines.

    Each output line is the JSON list ``[item, paths]``. Records whose search
    hits the ``generate_paths`` timeout are skipped, and the number skipped is
    printed at the end.
    """
    load_data_path = cfg.preprocessing["step1"]["load_data_path"]
    dump_data_path = cfg.preprocessing["step1"]["dump_data_path"]
    kg = KnowledgeGraphFreebase()
    train_dataset = load_jsonl(load_data_path)

    timeout_count = 0
    # ``with`` guarantees the file is closed even if the loop is interrupted
    # (e.g. KeyboardInterrupt during a long run).
    with open(dump_data_path, 'w') as outf:
        for item in tqdm(train_dataset):
            try:
                paths = generate_paths(item, kg)
            except FunctionTimedOut:
                timeout_count += 1
                continue
            print(json.dumps([item, paths], ensure_ascii=False), file=outf)
            # Flush per record so partial progress survives an interrupted run.
            outf.flush()

    print("timeout_count:", timeout_count)

In [7]:
def run_score_path():
    """Step 2: score each candidate path by answer precision and dump scored records.

    Processing for every ``(item, paths)`` line from step 1:

    - Paths are de-duplicated. A bare string is treated as a one-relation path.
    - The ``type.object.type -> type.type.instance`` path is dropped.
    - Each remaining path is scored as ``|leaves & answers| / |leaves|``, where
      ``leaves`` are the entities reached by following the path from a topic
      entity. The best score over all topic entities is kept.

    Output records are deep copies of ``item`` with an added
    ``path_and_score_list`` of ``{"path", "score"}`` dicts.
    """
    load_data_path = cfg.preprocessing["step2"]["load_data_path"]
    dump_data_path = cfg.preprocessing["step2"]["dump_data_path"]

    kg = KnowledgeGraphFreebase()

    data_list = load_jsonl(load_data_path)

    data_with_path_list = []
    for (item, paths) in tqdm(data_list):
        m = set()
        for path in paths:
            if isinstance(path, str):
                path = (path,)
            path = tuple(path)
            if path == ("type.object.type", "type.type.instance"):
                continue
            m.add(path)
        # sorted() makes the output order reproducible; iterating a set of
        # strings depends on per-process hash randomization.
        data_with_path_list.append((item, tuple(sorted(m))))

    def cal_path_val(topic_entity, path, answers):
        """Precision of the leaves reached from ``topic_entity`` along ``path``."""
        preds = set(kg.deduce_leaves_by_path(topic_entity, path))
        if not preds:
            # A path that reaches nothing from this entity is no evidence for the
            # answers. Scoring it as perfect (1) would let max() over topic
            # entities turn a dead-end path into a positive.
            return 0.0
        return len(preds & answers) / len(preds)

    m_list = []
    for item, paths in tqdm(data_with_path_list):
        answers = set(item['answers'])
        topic_entities = item['topic_entities']
        path_and_score_list = []
        for path in paths:
            # default=0.0 guards an empty topic_entities list (plain max() raises).
            p_val = max(
                (cal_path_val(topic_entity, path, answers) for topic_entity in topic_entities),
                default=0.0,
            )
            path_and_score_list.append(dict(path=path, score=p_val))
        m_item = deepcopy(item)
        m_item['path_and_score_list'] = path_and_score_list
        m_list.append(m_item)

    dump_jsonl(m_list, dump_data_path)

In [8]:

@func_set_timeout(10)
def generate_data_list(path_json_obj, json_obj, pos_rels, kg: KnowledgeGraphFreebase):
    """Build negative-sampled training rows for one positive relation path.

    Walks ``path_json_obj["path"] + [END_REL]`` one relation at a time and emits
    a row ``[question, rel, neg_1, ..., neg_15]`` per step. At each step:

    - Negatives are drawn from the relations of the current frontier entities
      plus END_REL, minus every positive next relation for the current prefix
      (``pos_rels[prefix]``).
    - After each real hop, the question is extended with ``" {rel} #"`` and the
      frontier is advanced along ``rel``.

    Returns ``None`` when following the path from the topic entities fans out
    too fast (more than 5, 25, ... entities after hop 1, 2, ...); the caller
    skips such paths. Wrapped in a 10 s ``func_set_timeout``.
    """
    neg_num = 15
    filter_threshold = 5

    path = path_json_obj["path"] + [END_REL]
    question = json_obj["question"] + " [SEP]"
    topic_entities = json_obj["topic_entities"]

    if _path_fans_out(kg, topic_entities, path[:-1], filter_threshold):
        return None

    new_data_list = []
    prefix_list = []
    for rel in path:
        prefix = ",".join(prefix_list)
        prefix_list.append(rel)

        neg_rels = set()
        for h in topic_entities:
            neg_rels |= set(kg.get_relation(h, limit=100))
            if len(neg_rels) > 100:
                break
        neg_rels = list(neg_rels)
        neg_rels.append(END_REL)
        neg_rels = [r for r in neg_rels if r not in pos_rels[prefix]]

        if neg_rels:
            # Repeat the pool until it holds at least neg_num items, so sampling
            # without replacement still yields neg_num relations (with repeats
            # when fewer distinct negatives exist).
            sample_rels = []
            while len(sample_rels) < neg_num:
                sample_rels.extend(neg_rels)
            new_data_list.append([question, rel] + random.sample(sample_rels, neg_num))

        # update for next step
        if rel != END_REL:
            question = question + f" {rel} #"
            topic_entities = kg.deduce_leaves_from_src_list_and_relation(topic_entities, rel)

    return new_data_list


def _path_fans_out(kg, topic_entities, relations, filter_threshold):
    """Return True if following ``relations`` from ``topic_entities`` reaches more
    than ``filter_threshold ** k`` entities at some hop ``k``."""
    limit = 1
    frontier = list(topic_entities)
    for rel in relations:
        limit *= filter_threshold
        candidates = set()
        for h in frontier:
            # Asking for limit + 1 tails is enough to detect overflow.
            candidates |= set(kg.get_hr2t_with_limit(h, rel, limit + 1))
            if len(candidates) > limit:
                return True
        # No head overflowed, so ``candidates`` is the complete next hop; expand
        # the following relation from it rather than from the topic entities.
        frontier = candidates
    return False



In [9]:

def run_negative_sampling():
    """Step 3: turn scored paths into negative-sampled training rows and dump a CSV.

    Paths with ``score >= 0.5`` are treated as positives. All positives of a
    question are merged into ``pos_rels`` (prefix "r1,r2,..." -> set of positive
    next relations), so no positive continuation is ever drawn as a negative.

    Each positive path is expanded by ``generate_data_list``. Paths that time
    out are skipped, and their count is printed. Rows are written header-less,
    without an index, to ``cfg.preprocessing["step3"]["dump_data_path"]``.
    """
    step_cfg = cfg.preprocessing["step3"]
    load_data_path = step_cfg["load_data_path"]
    dump_data_path = step_cfg["dump_data_path"]
    folder_path = step_cfg["dump_data_folder"]

    os.makedirs(folder_path, exist_ok=True)

    kg = KnowledgeGraphFreebase()
    threshold = 0.5

    data_list = load_jsonl(load_data_path)

    new_data_list = []
    timeout_count = 0
    for json_obj in tqdm(data_list, desc="negative-sampling"):
        path_and_score_list = [
            path_json_obj for path_json_obj in json_obj["path_and_score_list"]
            if path_json_obj["score"] >= threshold
        ]

        pos_rels = {}  # 1-hop positive, 2-hop positive, ... keyed by prefix
        for path_json_obj in path_and_score_list:
            prefix_list = []
            for rel in path_json_obj["path"] + [END_REL]:
                pos_rels.setdefault(",".join(prefix_list), set()).add(rel)
                prefix_list.append(rel)

        for path_json_obj in path_and_score_list:
            try:
                data = generate_data_list(path_json_obj, json_obj, pos_rels, kg)
            except FunctionTimedOut:
                timeout_count += 1
                continue
            if data is not None:
                new_data_list.extend(data)

    print("timeout_count:", timeout_count)

    df = pd.DataFrame(new_data_list)
    df.to_csv(dump_data_path, header=False, index=False)


In [10]:
def run():
    """Execute the preprocessing pipeline, steps 0 through 3, in order."""
    pipeline = (
        load_webqsp,             # step 0: raw WebQSP -> JSONL records
        run_search_to_get_path,  # step 1: KG path search
        run_score_path,          # step 2: path scoring
        run_negative_sampling,   # step 3: negative sampling -> CSV
    )
    for stage in pipeline:
        stage()

In [11]:
# Runs the whole pipeline end to end. In the recorded output below, the path-search
# stage (apparently step 1, judging by the printed one-hop SPARQL query) ran at
# ~10 s per record over 3351 records before being interrupted manually.
run()

  4%|▎         | 125/3351 [21:07<9:05:07, 10.14s/it]


KeyboardInterrupt: 


                PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                PREFIX ns: <http://rdf.freebase.com/ns/>
        
            select distinct ?r1 where {
                ns:m.05v8c ?r1_ ns:m.011_7j . 
                FILTER regex(?r1_, "http://rdf.freebase.com/ns/")
                bind(strafter(str(?r1_),str(ns:)) as ?r1)
            }
        
