In [1]:
!pwd

/cluster/work/cotterell/kdu/measureLM/data/YagoECQ


In [2]:
from typing import List, Dict, Tuple
from SPARQLWrapper import SPARQLWrapper2
from SPARQLWrapper.SPARQLExceptions import EndPointInternalError
from collections import namedtuple
import re
from time import sleep
from urllib.error import HTTPError
import os
import json
from tqdm import tqdm
import time
# from preprocessing.utils import extract_name_from_yago_uri

In [3]:
Predicate = namedtuple("Predicate", ["uri", "kb_name", "relation"])
DATA_ROOT = "."
PRED_URI_TO_S_OBJ_CLASSES_PATH = os.path.join(DATA_ROOT, "yago_pred_uri_to_s_obj_classes.json")
PRED_URI_TO_SO_PAIRS_PATH = os.path.join(DATA_ROOT, "yago_pred_uri_to_so_pairs_randomized_1k.json")
YAGO_QEC_PATH = os.path.join(DATA_ROOT, "yago_qec.json")
TRY_QUERYING_MISSING_PREDS = False

In [4]:
def extract_name_from_yago_uri(uri: str):
    is_reversed = False
    if uri.startswith("reverse-"):
        uri = uri.split("reverse-")[1]
        is_reversed = True
    pattern = r"http://(?:www\.)?([^\/]+)\.org\/(.+)$"
    matches = re.match(pattern, uri)

    if matches:
        kb_domain = matches.group(1)
        relation = matches.group(2)
    else:
        raise ValueError(f"Could not find match containing kb_domain and relation for uri {uri}.")

    domain_to_name = {"schema": "schema", "yago-knowledge": "yago", "w3": "w3"}
    kb_name = domain_to_name[kb_domain]
    kb_name = ("reverse-" if is_reversed else "") + kb_name

    return kb_name, relation


In [5]:
##################################
# 1. Get all relevant predicates #
##################################
sparql = SPARQLWrapper2("https://yago-knowledge.org/sparql/query")
query_p = """
    SELECT DISTINCT ?p WHERE {
     ?s ?p ?obj . 
    }  ORDER BY ?p
"""

# Sparql query
sparql.setQuery(query_p)

# Adding values
relevant_preds: List[Predicate] = []
ineligible_relations = ["schema#fromClass", "schema#fromProperty", "logo", "image"]

for result in sparql.query().bindings:
    uri = result["p"].value
    kb_name, relation = extract_name_from_yago_uri(uri)
    if kb_name != "w3" and relation not in ineligible_relations:
        relevant_preds.append(Predicate(uri=uri, kb_name=kb_name, relation=relation))

In [6]:
###########################################################
# 1b. Get all subject and object types for each predicate #
###########################################################
from SPARQLWrapper.SPARQLExceptions import QueryBadFormed
query_single = """
    SELECT DISTINCT ?s ?obj WHERE {{
     ?s <{}> ?obj . 
    }}  LIMIT 10
"""

query_superclasses = """
    SELECT DISTINCT ?superclasses WHERE {{
     <{}> rdf:type/rdfs:subClassOf* ?superclasses .
    }}  
"""

types = [
    "http://schema.org/CreativeWork",
    "http://schema.org/Event",
    "http://schema.org/Intangible",
    "http://schema.org/Organization",
    "http://schema.org/Person",
    "http://schema.org/Place",
    "http://schema.org/Product",
    "http://schema.org/Taxon",
    "http://schema.org/FictionalEntity",
]
if os.path.exists(PRED_URI_TO_S_OBJ_CLASSES_PATH):
    print(f"Loading pred_to_s_and_obj_types from file {PRED_URI_TO_S_OBJ_CLASSES_PATH}.")
    with open(PRED_URI_TO_S_OBJ_CLASSES_PATH, "r") as f:
        pred_to_s_and_obj_types_with_reverse = json.load(f)
else:
    pred_to_s_and_obj_types = dict()
    for pred in relevant_preds:
        concrete_query_single = query_single.format(pred.uri)
        # print(query)
        sparql.setQuery(concrete_query_single)
        superclasses_s = []
        superclasses_obj = []
        for result in sparql.query().bindings:
            s = result["s"].value
            obj = result["obj"].value
            try:
                concrete_query_superclasses_s = query_superclasses.format(s)
                sparql.setQuery(concrete_query_superclasses_s)
                superclasses_s += [x["superclasses"].value for x in sparql.query().bindings]
            except QueryBadFormed as e:
                print("Raising error for s:", s)
                # raise ValueError
            try:
                concrete_query_superclasses_obj = query_superclasses.format(obj)
                sparql.setQuery(concrete_query_superclasses_obj)
                superclasses_obj += [x["superclasses"].value for x in sparql.query().bindings]
            except QueryBadFormed:
                print("Raising error for obj:", obj)
                # raise ValueError
            
        pred_to_s_and_obj_types[pred] = (set(superclasses_s), set(superclasses_obj))
            
    pred_to_s_and_obj_types = {k.uri: (list(s_set.intersection(types)), list(obj_set.intersection(types))) for k, (s_set, obj_set) in pred_to_s_and_obj_types.items()}
    pred_to_s_and_obj_types_with_reverse = {**pred_to_s_and_obj_types, **{f"reverse-{k}": (obj, s) for k, (s, obj) in pred_to_s_and_obj_types.items()}}

with open(PRED_URI_TO_S_OBJ_CLASSES_PATH, "w", encoding='utf-8') as fp:
    json.dump(pred_to_s_and_obj_types_with_reverse, fp, ensure_ascii=False, indent=4)

Loading pred_to_s_and_obj_types from file ./yago_pred_uri_to_s_obj_classes.json.


In [7]:
#######################################################################
# 2. Extract all subject-object pairs for each pred in relevant preds #
#######################################################################
def get_so_pairs_for_pred(pred: Predicate) -> List[Tuple[str, str, str, str]]:
    start_time = time.time()
    print(f"Querying {pred.uri}.")
    so_pairs = []

    query = """
SELECT DISTINCT ?output_s ?output_obj ?s ?obj WHERE {{
  ?s <{}> ?obj . # which predicate to use
  OPTIONAL {{ 
    ?s rdfs:label ?s_label .
    FILTER (LANG(?s_label) = 'en')
  }} # Get the label (name) for the subject if it exists
  BIND(COALESCE(?s_label, ?s) AS ?output_s) # if the label does not exist, stick with the URI
  
  OPTIONAL {{
    ?obj rdfs:label ?obj_label . 
    FILTER (LANG(?obj_label) = 'en')
  }} # Get the label (name) for the object if it exists
  OPTIONAL {{
    ?obj rdf:type ?obj_type . 
    ?obj_type rdfs:label ?obj_type_name .
    FILTER (LANG(?obj_type_name) = 'en')
  }} # get the name of the type for the object if the type exists
  BIND(COALESCE(IF(STR(?obj_label) != "Generic instance", ?obj_label, ?obj_type_name), ?obj) AS ?output_obj) 
  # if the label exists, go with the label, but if it's "Generic instance", then go with the type; if that does not exist, then stick with the OG object.
  BIND(MD5(CONCAT(STR(?s), STR(?obj))) AS ?sortkey) .
}}
ORDER BY ?sortkey
LIMIT 1000
""".format(
        pred.uri
    )

    try:
        sparql.setQuery(query)
        for result in sparql.query().bindings:
            so_pairs.append((result["output_s"].value, result["output_obj"].value, result["s"].value, result["obj"].value))
    # sleep(1)
    except (HTTPError, EndPointInternalError) as e:
        print(f"HTTPerror for uri {pred.uri}. Trying chained query.")
        so_pairs = get_so_pairs_for_pred_chained(pred)
    finally:
        time_elapsed = time.time() - start_time
        print(f"Time elapsed: {time_elapsed}")
        return so_pairs
    
def get_so_pairs_for_pred_chained(pred: Predicate) -> List[Tuple[str, str, str, str]]:
    start_time = time.time()
    print(f"Querying {pred.uri} (chained).")
    so_pairs = []

    query = """
SELECT DISTINCT ?s ?o WHERE {{
    ?s <{}> ?o
    BIND(MD5(CONCAT(STR(?s), STR(?o))) AS ?sortkey) .
}}
ORDER BY ?sortkey
LIMIT 1000
""".format(
        pred.uri
    )
    
    query_get_name_of_obj = """
SELECT DISTINCT ?output_obj WHERE {{
    OPTIONAL {{
      <{0}> rdfs:label ?obj_label . 
      FILTER (LANG(?obj_label) = 'en')
    }} # Get the label (name) for the object if it exists
    OPTIONAL {{
      <{0}> rdf:type ?obj_type . 
      ?obj_type rdfs:label ?obj_type_name .
      FILTER (LANG(?obj_type_name) = 'en')
    }} # get the name of the type for the object if the type exists
    BIND(COALESCE(IF(STR(?obj_label) != "Generic instance", ?obj_label, ?obj_type_name), <{0}>) AS ?output_obj) 
}}  
"""

    try:
        sparql.setQuery(query)
        for result in sparql.query().bindings:
            # print(result)
            s_uri = result["s"].value
            o_uri = result["o"].value
            query_s = query_get_name_of_obj.format(s_uri)            
            sparql.setQuery(query_s)
            res_s = sparql.query().bindings
            if len(res_s) != 1:
                raise ValueError(f">1 label returned for subject {s_uri}: {res_s}")
            s_name = res_s[0]["output_obj"].value

            query_obj = query_get_name_of_obj.format(o_uri)
            sparql.setQuery(query_obj)
            res_obj = sparql.query().bindings
            if len(res_obj) != 1:
                raise ValueError(f">1 label returned for subject {o_uri}: {res_obj}")
            obj_name = res_obj[0]["output_obj"].value

            so_pairs.append((s_name, obj_name, s_uri, o_uri))
        
    # sleep(1)
    except (HTTPError, EndPointInternalError) as e:
        print(f"HTTPerror for uri {pred.uri} when chaining. Skipping.")
        so_pairs = None
    except Error as e:
        print(e)
    finally:
        time_elapsed = time.time() - start_time
        print(f"Time elapsed: {time_elapsed}")
        return so_pairs

In [8]:
if os.path.exists(PRED_URI_TO_SO_PAIRS_PATH):
    print(f"Loading pred_uri_to_so_pairs from file {PRED_URI_TO_SO_PAIRS_PATH}.")
    with open(PRED_URI_TO_SO_PAIRS_PATH) as f:
        pred_uri_to_so_pairs = json.load(f)
else:
    pred_to_so_pairs: Dict[Predicate, List[Tuple[str, str, str, str]]] = {
        pred: get_so_pairs_for_pred(pred) for pred in tqdm(relevant_preds)
    }
    pred_to_so_pairs = {k: v for k, v in pred_to_so_pairs.items() if v is not None}
    pred_uri_to_so_pairs = {
        k.uri: v for k, v in pred_to_so_pairs.items() if v is not None
    }

if TRY_QUERYING_MISSING_PREDS:
    missing_preds = [p for p in relevant_preds if p.uri not in pred_uri_to_so_pairs]
    print("Missing preds:", missing_preds)
    missing_pred_to_so_pairs: Dict[Predicate, List[Tuple[str, str, str, str]]] = {
        pred: get_so_pairs_for_pred(pred) for pred in tqdm(missing_preds)
    }
    missing_pred_to_so_pairs = {k: v for k, v in missing_pred_to_so_pairs.items() if v is not None}
    missing_pred_uri_to_so_pairs = {
        k.uri: v for k, v in missing_pred_to_so_pairs.items() if v is not None
    }

    pred_uri_to_so_pairs = {**missing_pred_uri_to_so_pairs, **pred_uri_to_so_pairs}

with open(PRED_URI_TO_SO_PAIRS_PATH, "w", encoding='utf-8') as fp:
    json.dump(pred_uri_to_so_pairs, fp, ensure_ascii=False, indent=4)

def augment_pred_uri_to_so_pairs_with_reverse(pred_uri_to_so_pairs):
    return {
        **pred_uri_to_so_pairs,
        **{
            f"reverse-{k}": [(a_label, e_label, a_uri, e_uri) for (e_label, a_label, e_uri, a_uri) in v]
            for k, v in pred_uri_to_so_pairs.items()
        },
    }

pred_uri_to_so_pairs_with_reverse = augment_pred_uri_to_so_pairs_with_reverse(
    pred_uri_to_so_pairs
)

Loading pred_uri_to_so_pairs from file ./yago_pred_uri_to_so_pairs_randomized_1k.json.


In [9]:
########################################################################################################
# 4. Construct queries containing entities, corresponding answers, query forms, and context templates. #
########################################################################################################
from yago_questions import yago_topic_to_qfs

keys = set(yago_topic_to_qfs).intersection(set(pred_uri_to_so_pairs_with_reverse))
print(keys)
yago_qec = {
    k: {
        "query_forms": yago_topic_to_qfs[k],
        "entities": list(zip(*pred_uri_to_so_pairs_with_reverse[k]))[0],
        "answers": list(zip(*pred_uri_to_so_pairs_with_reverse[k]))[1],
        "entity_uris": list(zip(*pred_uri_to_so_pairs_with_reverse[k]))[2],
        "answer_uris": list(zip(*pred_uri_to_so_pairs_with_reverse[k]))[3],
        "context_templates": [yago_topic_to_qfs[k]["open"][-1] + " {answer}.\n"],
        "entity_types": pred_to_s_and_obj_types_with_reverse[k][0], 
        "answer_types": pred_to_s_and_obj_types_with_reverse[k][1], 
    }
    for k in keys
}
yago_qec
with open(YAGO_QEC_PATH, "w", encoding='utf-8') as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4)

{'reverse-http://schema.org/organizer', 'http://schema.org/children', 'http://yago-knowledge.org/resource/consumes', 'http://yago-knowledge.org/resource/appearsIn', 'http://schema.org/leader', 'reverse-http://schema.org/owns', 'http://schema.org/startDate', 'http://yago-knowledge.org/resource/terminus', 'http://schema.org/contentLocation', 'http://schema.org/icaoCode', 'http://schema.org/deathPlace', 'http://schema.org/locationCreated', 'http://yago-knowledge.org/resource/flowsInto', 'http://yago-knowledge.org/resource/studentOf', 'http://yago-knowledge.org/resource/conferredBy', 'http://schema.org/material', 'http://yago-knowledge.org/resource/follows', 'http://schema.org/parentTaxon', 'http://schema.org/populationNumber', 'http://schema.org/administrates', 'reverse-http://yago-knowledge.org/resource/terminus', 'reverse-http://yago-knowledge.org/resource/doctoralAdvisor', 'http://schema.org/lowestPoint', 'http://yago-knowledge.org/resource/radialVelocity', 'http://yago-knowledge.org/r

In [10]:
with open(YAGO_QEC_PATH, "r") as fp:
    yago_qec_reloaded = json.load(fp)

In [11]:
for qid, v in yago_qec.items():
    print(qid, len(v["entities"]), len(set(v["entities"])))

reverse-http://schema.org/organizer 1000 282
http://schema.org/children 1000 994
http://yago-knowledge.org/resource/consumes 202 89
http://yago-knowledge.org/resource/appearsIn 1000 917
http://schema.org/leader 1000 943
reverse-http://schema.org/owns 1000 919
http://schema.org/startDate 1000 999
http://yago-knowledge.org/resource/terminus 1000 926
http://schema.org/contentLocation 1000 992
http://schema.org/icaoCode 1000 1000
http://schema.org/deathPlace 1000 1000
http://schema.org/locationCreated 1000 997
http://yago-knowledge.org/resource/flowsInto 1000 985
http://yago-knowledge.org/resource/studentOf 722 508
http://yago-knowledge.org/resource/conferredBy 1000 970
http://schema.org/material 1000 954
http://yago-knowledge.org/resource/follows 1000 1000
http://schema.org/parentTaxon 1000 1000
http://schema.org/populationNumber 1000 995
http://schema.org/administrates 1000 800
reverse-http://yago-knowledge.org/resource/terminus 1000 899
reverse-http://yago-knowledge.org/resource/doctora

In [12]:
# import pandas as pd
# pd.DataFrame(
#     yago_qec["reverse-http://schema.org/homeLocation"]["entities"]
# ).value_counts()

In [None]:
####################################
# 5. Get the degree of each entity #
####################################
import itertools
from collections import defaultdict

ENTITY_NAME_TO_POSSIBLE_ENTITY_URIS_PATH = os.path.join(DATA_ROOT, "entity_name_to_possible_entity_uris.json")
ENTITY_URI_TO_DEGREE_PATH = os.path.join(DATA_ROOT, "entity_uri_to_degree.json")
ENTITY_URI_TO_DEGREE_INCLUDING_AMBIGUOUS_ENTITIES_PATH = os.path.join(DATA_ROOT, "entity_uri_to_degree_including_ambiguous_entities.json")
ENTITY_URI_TO_PREDICATE_DEGREE_PATH = os.path.join(DATA_ROOT, "entity_uri_to_predicate_degree_path.json")
ENTITY_NAMESAKE_TO_DEGREE_PATH = os.path.join(DATA_ROOT, "entity_namesake_to_degree.json")
ENTITY_NAMESAKE_TO_NUM_URIS_PATH = os.path.join(DATA_ROOT, "entity_namesake_to_num_uris.json")

# Load cached jsons
# {entity_name: List[entity_uri]}
entity_name_to_possible_entity_uris: Dict[str, List[str]] = dict()
if os.path.exists(ENTITY_NAME_TO_POSSIBLE_ENTITY_URIS_PATH):
    print(f"Loading entity_name_to_possible_entity_uris from file {ENTITY_NAME_TO_POSSIBLE_ENTITY_URIS_PATH}.")
    with open(ENTITY_NAME_TO_POSSIBLE_ENTITY_URIS_PATH) as f:
        entity_name_to_possible_entity_uris = json.load(f)

# {entity_uri: degree}
entity_uri_to_degree: Dict[str, int] = dict()
if os.path.exists(ENTITY_URI_TO_DEGREE_PATH):
    print(f"Loading entity_uri_to_degree from file {ENTITY_URI_TO_DEGREE_PATH}.")
    with open(ENTITY_URI_TO_DEGREE_PATH) as f:
        entity_uri_to_degree = json.load(f)    

# {entity_uri: degree}
entity_uri_to_degree_including_ambiguous_entities: Dict[str, int] = dict()
if os.path.exists(ENTITY_URI_TO_DEGREE_INCLUDING_AMBIGUOUS_ENTITIES_PATH):
    print(f"Loading entity_uri_to_degree_including_ambiguous_entities from file {ENTITY_URI_TO_DEGREE_INCLUDING_AMBIGUOUS_ENTITIES_PATH}.")
    with open(ENTITY_URI_TO_DEGREE_INCLUDING_AMBIGUOUS_ENTITIES_PATH) as f:
        entity_uri_to_degree_including_ambiguous_entities = json.load(f)

# {predicate: {entity: degree}}
entity_uri_to_predicate_degree: Dict[str, Dict[str, int]] = defaultdict(dict) 
if os.path.exists(ENTITY_URI_TO_PREDICATE_DEGREE_PATH):
    print(f"Loading entity_uri_to_predicate_degree from file {ENTITY_URI_TO_PREDICATE_DEGREE_PATH}.")
    with open(ENTITY_URI_TO_PREDICATE_DEGREE_PATH) as f:
        entity_uri_to_predicate_degree = defaultdict(dict, json.load(f))

In [14]:
def get_degree_for_entity_uri(entity_uri: str) -> int:
    start_time = time.time()
    # print(f"Querying degree of {entity_uri}.")

    query = """
SELECT (COUNT(?edge) as ?degree)
WHERE {{
  {{
    <{0}> ?edge ?object.
  }}
  UNION
  {{
    ?subject ?edge <{0}>.
  }}
}}
""".format(
        entity_uri
    )
    try:
        sparql.setQuery(query)
        res_degree = sparql.query().bindings
        if len(res_degree) != 1:
            raise ValueError(f">1 degree returned for entity {entity_uri}: {res_degree}")
        degree = res_degree[0]["degree"].value
    # sleep(1)
    except (HTTPError, EndPointInternalError) as e:
        print(f"HTTPerror for uri {entity_uri} when chaining. Skipping.")
        degree = None
    except Exception as e:
        print(e)
        degree = None
    finally:
        time_elapsed = time.time() - start_time
        # print(f"Time elapsed: {time_elapsed}")
        return degree
    
def get_possible_entity_uris_per_entity(entity: str) -> int:
    start_time = time.time()
    # print(f"Querying number of possible uris for entity {entity}.")

    query = """
SELECT ?entity_uri
WHERE {{
  ?entity_uri rdfs:label "{}"@en.
}}
""".format(
        entity
    )
    # print(query)
    entity_uris = []
    try:
        sparql.setQuery(query)
        for result in sparql.query().bindings:
            # print(result)
            entity_uris.append(result["entity_uri"].value)
    # sleep(1)
    except (HTTPError, EndPointInternalError) as e:
        print(f"HTTPerror for uri {entity}. Skipping.")
        entity_uris = None
    except Exception as e:
        print(e)
        entity_uris = None
    finally:
        time_elapsed = time.time() - start_time
        # print(f"Time elapsed: {time_elapsed}")
        return entity_uris
    
def get_predicate_degree_for_entity_uri(entity_uri: str, predicate: str) -> int:
    start_time = time.time()
    # print(f"Querying degree of {entity_uri}.")

    if predicate.startswith("reverse-"):
        predicate = predicate.split("reverse-")[1]
        query = """
                SELECT (COUNT(?subject) as ?degree) WHERE {{
                    {{
                        ?subject <{predicate}> <{entity_uri}>.
                    }}
                }}
                """
    else:
        query = """
                SELECT (COUNT(?object) as ?degree) WHERE {{
                    {{
                        <{entity_uri}> <{predicate}> ?object.
                    }}
                }}
                """
        
    query = query.format(
        entity_uri=entity_uri,
        predicate=predicate
    )
    # print(query)
    try:
        sparql.setQuery(query)
        res_degree = sparql.query().bindings
        if len(res_degree) != 1:
            raise ValueError(f">1 degree returned for entity {entity_uri}: {res_degree}")
        degree = res_degree[0]["degree"].value
    # sleep(1)
    except (HTTPError, EndPointInternalError) as e:
        print(f"HTTPerror for uri {entity_uri} when chaining. Skipping.")
        degree = None
    except Exception as e:
        print(e)
        degree = None
    finally:
        time_elapsed = time.time() - start_time
        # print(f"Time elapsed: {time_elapsed}")
        return degree

In [15]:
# Construct the relevant list of entity uris and entity names.
entity_uris: List[str] = list(set(itertools.chain.from_iterable([v["entity_uris"] for _, v in yago_qec.items()])))
entities: List[str] = list(set(itertools.chain.from_iterable([v["entities"] for _, v in yago_qec.items()])))
predicate_to_entity_uris: Dict[str, List[str]] = {predicate: qec["entity_uris"] for predicate, qec in yago_qec.items()}
answers: List[str] = list(set(itertools.chain.from_iterable([v["answers"] for _, v in yago_qec.items()])))

# # Test run examples
# entity_uris = ["http://yago-knowledge.org/resource/Paul_McCartney", "http://yago-knowledge.org/resource/Paul_Allen__u0028_editor_u0029_"]
# entities = ["Paul McCartney", "Paul Allen"]
# predicate_to_entity_uris = defaultdict(
#     dict,
#     {
#         "reverse-http://schema.org/author": ["http://yago-knowledge.org/resource/Anton_Chekhov", "http://yago-knowledge.org/resource/Alexander_Hamilton"],
#         "http://schema.org/author": ["http://yago-knowledge.org/resource/A_Marriage_Proposal"],
#     }
# )
len(entity_uris), len(entities), sum(len(v) for v in predicate_to_entity_uris.values()), len(answers)

(96271, 95030, 112536)

In [36]:
ENTITIES_PATH = os.path.join(DATA_ROOT, "entities.json")
ANSWERS_PATH = os.path.join(DATA_ROOT, "answers.json")

with open(ENTITIES_PATH, "w", encoding='utf-8') as fp:
    json.dump(entities, fp, ensure_ascii=False, indent=4)
    
with open(ANSWERS_PATH, "w", encoding='utf-8') as fp:
    json.dump(answers, fp, ensure_ascii=False, indent=4)

In [16]:
missing_entities = set(entities).difference(set(entity_name_to_possible_entity_uris))
len(missing_entities), len(set(entities)), len(set(entity_name_to_possible_entity_uris))

(49991, 95030, 45040)

In [17]:
missing_entity_uris = set(entity_uris).difference(set(entity_uri_to_degree))
len(missing_entity_uris), len(set(entity_uris)), len(set(entity_uri_to_degree))

(50808, 96271, 45464)

In [18]:
# Construct the jsons mapping (a) from entity names to uris, (b) from entity uris to degrees, and (c) from entity uris (including those of the namesake) to degrees
# This will additively build from the cached files.
missing_entities = set(entities).difference(set(entity_name_to_possible_entity_uris))
entity_name_to_possible_entity_uris = {
    **entity_name_to_possible_entity_uris, 
    **{
        entity: get_possible_entity_uris_per_entity(entity) for entity in tqdm(missing_entities)
    }
}
with open(ENTITY_NAME_TO_POSSIBLE_ENTITY_URIS_PATH, "w", encoding='utf-8') as fp:
    json.dump(entity_name_to_possible_entity_uris, fp, ensure_ascii=False, indent=4)

missing_entity_uris = set(entity_uris).difference(set(entity_uri_to_degree))
entity_uri_to_degree = {
    **entity_uri_to_degree,
    **{entity_uri: get_degree_for_entity_uri(entity_uri) for entity_uri in tqdm(missing_entity_uris)}
}
with open(ENTITY_URI_TO_DEGREE_PATH, "w", encoding='utf-8') as fp:
    json.dump(entity_uri_to_degree, fp, ensure_ascii=False, indent=4)

missing_entity_uris_ambiguous = {entity_uri for entity in entities for entity_uri in entity_name_to_possible_entity_uris[entity]}.difference(set(entity_uri_to_degree_including_ambiguous_entities))
entity_uri_to_degree_including_ambiguous_entities = {
    **entity_uri_to_degree_including_ambiguous_entities,
    **{
        entity_uri: get_degree_for_entity_uri(entity_uri) if entity_uri not in entity_uri_to_degree else entity_uri_to_degree[entity_uri] for entity_uri in missing_entity_uris_ambiguous
    },
}
with open(ENTITY_URI_TO_DEGREE_INCLUDING_AMBIGUOUS_ENTITIES_PATH, "w", encoding='utf-8') as fp:
    json.dump(entity_uri_to_degree_including_ambiguous_entities, fp, ensure_ascii=False, indent=4)


missing_entity_predicate_uris = defaultdict(
    dict, 
    {
        predicate: {
            entity_uri for entity_uri in set(predicate_to_entity_uris[predicate]).difference(set(entity_uri_to_predicate_degree[predicate]))
        } for predicate in tqdm(yago_qec_reloaded)
    }
)
entity_uri_to_predicate_degree = {
    predicate: {
        **entity_uri_to_predicate_degree[predicate],
        **{
            entity_uri: get_predicate_degree_for_entity_uri(entity_uri, predicate) for entity_uri in missing_entity_predicate_uris[predicate]
        } 
    } for predicate in tqdm(yago_qec_reloaded)
}
with open(ENTITY_URI_TO_PREDICATE_DEGREE_PATH, "w", encoding='utf-8') as fp:
    json.dump(entity_uri_to_predicate_degree, fp, ensure_ascii=False, indent=4)    

len(entity_name_to_possible_entity_uris), len(entity_uri_to_degree), len(entity_uri_to_degree_including_ambiguous_entities), len(entity_uri_to_predicate_degree)

 10%|▉         | 4789/49991 [08:30<1:20:21,  9.37it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

 31%|███       | 15590/49991 [27:43<1:01:15,  9.36it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

 54%|█████▎    | 26816/49991 [47:45<41:15,  9.36it/s]  IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_r

({'Paul McCartney': ['http://yago-knowledge.org/resource/Paul_McCartney'],
  'Paul Allen': ['http://yago-knowledge.org/resource/Paul_Allen__u0028_editor_u0029_',
   'http://yago-knowledge.org/resource/Paul_Allen__u0028_sports_commentator_u0029_',
   'http://yago-knowledge.org/resource/Paul_Allen__u0028_footballer_u0029_'],
  'Joan Warburton': ['http://yago-knowledge.org/resource/Joan_Warburton'],
  'Artvin Province': ['http://yago-knowledge.org/resource/Artvin_Province'],
  'Grigory Zass': ['http://yago-knowledge.org/resource/Grigory_Zass'],
  'Jimmy McHugh': ['http://yago-knowledge.org/resource/Jimmy_McHugh'],
  'Panteleimon Romanov': ['http://yago-knowledge.org/resource/Panteleimon_Romanov'],
  'Vliet Street Commons': ['http://yago-knowledge.org/resource/Vliet_Street_Commons'],
  'Yolande Donlan': ['http://yago-knowledge.org/resource/Yolande_Donlan'],
  'Framatome': ['http://yago-knowledge.org/resource/Framatome'],
  'Robert E. Kraut': ['http://yago-knowledge.org/resource/Robert_E_u0

In [31]:
len(entity_name_to_possible_entity_uris), len(entity_uri_to_degree), len(entity_uri_to_degree_including_ambiguous_entities), len(entity_uri_to_predicate_degree)

(95031, 96272, 169229, 127)

In [19]:
# Construct and save entity_namesake_to_* jsons.
# {entity_namesake: degree}
entity_namesake_to_degree: Dict[str, int] = {
    k: sum([int(entity_uri_to_degree_including_ambiguous_entities[uri]) for uri in uris]) for k, uris in entity_name_to_possible_entity_uris.items()
}
with open(ENTITY_NAMESAKE_TO_DEGREE_PATH, "w", encoding='utf-8') as fp:
    json.dump(entity_namesake_to_degree, fp, ensure_ascii=False, indent=4)

# {entity_namesake: number of uris with that name}
entity_namesake_to_num_uris: Dict[str, int] = {
    k: len(uris) for k, uris in entity_name_to_possible_entity_uris.items()
}
with open(ENTITY_NAMESAKE_TO_NUM_URIS_PATH, "w", encoding='utf-8') as fp:
    json.dump(entity_namesake_to_num_uris, fp, ensure_ascii=False, indent=4)

In [20]:
# Incorporate above entity_namesake_to_* and entity_uri_to_* stats into yago_qec

# Load yago entity namesake to degree stats (including all entity uris sharing that namesake)
ENTITY_NAMESAKE_TO_DEGREE_PATH = os.path.join(DATA_ROOT, "entity_namesake_to_degree.json")
with open(ENTITY_NAMESAKE_TO_DEGREE_PATH, "r") as fp:
    entity_namesake_to_degree = json.load(fp)
    
for k, v in yago_qec.items():
    yago_qec[k]["entity_namesake_to_degree"] = [
        int(entity_namesake_to_degree[entity]) for entity in yago_qec[k]["entities"]
    ]
    
# Save yago_qec including fake entities
with open(YAGO_QEC_PATH, "w", encoding='utf-8') as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4)

k, sorted(list(zip(yago_qec[k]["entity_namesake_to_degree"], yago_qec[k]["entities"])), key=lambda x: x[0])[:10]

('http://yago-knowledge.org/resource/length',
 [(6, 'Mendip Way'),
  (7, 'Usk Valley Walk'),
  (8, 'Cheddington to Aylesbury line'),
  (8, 'Superior National Forest Scenic Byway'),
  (8, 'Wild Azalea Trail'),
  (8, 'Cedar Valley Trail'),
  (8, 'St. Charles Rock Road'),
  (9, 'Plymouth to Yealmpton Branch'),
  (9, 'Lakeshore Road'),
  (9, 'Maine State Route 183')])

In [21]:
# Load yago entity uri to degree stats
ENTITY_URI_TO_DEGREE_PATH = os.path.join(DATA_ROOT, "entity_uri_to_degree.json")
with open(ENTITY_URI_TO_DEGREE_PATH, "r") as fp:
    entity_uri_to_degree = json.load(fp)
    
for k, v in yago_qec.items():
    yago_qec[k]["entity_uri_to_degree"] = [
        int(entity_uri_to_degree[uri]) for uri in yago_qec[k]["entity_uris"]
    ]
    
# Save yago_qec including fake entities
with open(YAGO_QEC_PATH, "w", encoding='utf-8') as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4)

k, sorted(list(zip(yago_qec[k]["entity_uri_to_degree"], yago_qec[k]["entities"])), key=lambda x: x[0])[-10:]

('http://yago-knowledge.org/resource/length',
 [(108, 'European route E372'),
  (108, 'European route E98'),
  (111, 'European route E232'),
  (117, 'Circuit Ricardo Tormo'),
  (120, 'Bundesautobahn 860'),
  (128, 'Northern line'),
  (128, 'St. Lawrence Seaway'),
  (160, 'Polcevera Viaduct'),
  (167, 'U.S. Route 66'),
  (188, 'Bahrain International Circuit')])

In [22]:
# Load yago entity namesake to number of different uris
ENTITY_NAMESAKE_TO_NUM_URIS_PATH = os.path.join(DATA_ROOT, "entity_namesake_to_num_uris.json")
with open(ENTITY_NAMESAKE_TO_NUM_URIS_PATH, "r") as fp:
    entity_namesake_to_num_uris = json.load(fp)
    
for k, v in yago_qec.items():
    yago_qec[k]["entity_namesake_to_num_uris"] = [
        int(entity_namesake_to_num_uris[entity]) for entity in yago_qec[k]["entities"]
    ]
    
# Save yago_qec including fake entities
with open(YAGO_QEC_PATH, "w", encoding='utf-8') as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4)

k, sorted(list(zip(yago_qec[k]["entity_namesake_to_num_uris"], yago_qec[k]["entities"])), key=lambda x: x[0])[-10:]

('http://yago-knowledge.org/resource/length',
 [(6, 'National Highway 3'),
  (6, 'Yellow Line'),
  (6, 'A5 motorway'),
  (6, 'A7 motorway'),
  (8, 'State Highway 2'),
  (10, 'Stone Bridge'),
  (11, 'Veterans Memorial Bridge'),
  (16, 'Main Street'),
  (19, 'Blue Line'),
  (29, 'Line 4')])

In [23]:
# Load yago entity namesake to number of different uris
with open(ENTITY_URI_TO_PREDICATE_DEGREE_PATH, "r") as fp:
    entity_uri_to_predicate_degree = json.load(fp)
    
for k, v in yago_qec.items():
    yago_qec[k]["entity_uri_to_predicate_degree"] = [
        int(entity_uri_to_predicate_degree[k][uri]) for uri in yago_qec[k]["entity_uris"]
    ]
    
# Save yago_qec including fake entities
with open(YAGO_QEC_PATH, "w", encoding='utf-8') as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4)

k, sorted(list(zip(yago_qec[k]["entity_uri_to_predicate_degree"], yago_qec[k]["entities"])), key=lambda x: x[0])[-10:]

('http://yago-knowledge.org/resource/length',
 [(1, 'U.S. Route 27 in Michigan'),
  (1, 'National Highway 329'),
  (1, 'New York State Route 293'),
  (1, 'Nevada State Route 3C'),
  (1, 'British Columbia Highway 118'),
  (1, 'Ohio State Route 307'),
  (1, 'Nebraska Highway 65'),
  (1, 'Gwydir Highway'),
  (1, 'Scioto Trail'),
  (1, 'Caiyuanba Bridge')])

In [33]:
# predicate_to_entity_uris["reverse-http://schema.org/organizer"]

In [25]:
# ##############################
# # 5. Construct fake entities #
# ##############################
# # This section requires GPUs #
# ##############################

# # Can be run separate from previous section #
# from transformers import ReformerModelWithLMHead
# import json
# import re
# import os
# import itertools
# from transformers import ReformerModelWithLMHead

# import torch
# import random
# import numpy as np

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # Encoding
# def encode(list_of_strings, pad_token_id=0, device="cpu"):
#     max_length = max([len(string) for string in list_of_strings])

#     # create emtpy tensors
#     attention_masks = torch.zeros((len(list_of_strings), max_length), dtype=torch.long)
#     input_ids = torch.full((len(list_of_strings), max_length), pad_token_id, dtype=torch.long)

#     for idx, string in enumerate(list_of_strings):
#         # make sure string is in byte format
#         if not isinstance(string, bytes):
#             string = str.encode(string)

#         input_ids[idx, :len(string)] = torch.tensor([x + 2 for x in string])
#         attention_masks[idx, :len(string)] = 1

#     return input_ids.to(device), attention_masks.to(device)
    
# # Decoding
# def decode(outputs_ids):
#     decoded_outputs = []
#     for output_ids in outputs_ids.tolist():
#         # transform id back to char IDs < 2 are simply transformed to ""
#         decoded_outputs.append("".join([chr(x - 2) if x > 1 else "" for x in output_ids]))
#     return decoded_outputs

# def set_seed(seed=0):
#     torch.manual_seed(seed)
#     random.seed(seed)
#     np.random.seed(seed)
    
# def extract_entity(text):
#     regex_pattern = r'\[\[([^\[|\]]+?)(?:\([^)]*\))?(?:\|[^|\]]+)*\]\]'
#     match = re.search(regex_pattern, text)

#     if match:
#         result = match.group(1).strip()
#         return result
    
#     return None

In [26]:
# from collections import defaultdict
# from typing import Dict, Set

# with open(YAGO_QEC_PATH, "r") as fp:
#     yago_qec = json.load(fp)
# entity_types_to_real_entities: Dict[str, Set[str]] = defaultdict(set)
# for query_id, qec in yago_qec.items():
#     for et in qec["entity_types"]:
#         entity_types_to_real_entities[et] = entity_types_to_real_entities[et].union(qec["entities"])
# print(json.dumps({k: len(v) for k, v in entity_types_to_real_entities.items()}, indent=4))

# real_entities = set(itertools.chain.from_iterable([v["entities"] for _, v in yago_qec.items()]))
# print(len(real_entities))

In [27]:
# entity_types_to_prompts = {
#     "http://schema.org/CreativeWork": "The creative work called [[",
#     "http://schema.org/Event": "The event called [[",
#     "http://schema.org/Intangible": "The concept called [[",
#     "http://schema.org/Organization": "The organization is called [[",
#     "http://schema.org/Person": "The person named [[",
#     "http://schema.org/Place": "The place is named [[",
#     "http://schema.org/Product": "The product is called [[",
#     "http://schema.org/Taxon": "The taxon named [[",
#     "http://schema.org/FictionalEntity": "The fictional entity named [[",
# }
# entity_types_to_fake_entities = {}
# model = ReformerModelWithLMHead.from_pretrained("google/reformer-enwik8").to(device)

In [28]:
# for et, p in entity_types_to_prompts.items():
#     set_seed(0)
#     print(f"***{et}***")
#     encoded, attention_masks = encode([p], device=device)
#     res = decode(model.generate(encoded, do_sample=True, num_return_sequences=1200, max_length=100))

#     extracted_res = {extract_entity(s) for s in res if extract_entity(s) is not None}
#     print(f"# unique fake ents: {len(extracted_res)}")
#     extracted_res_without_reals = extracted_res.difference(real_entities)
#     print(f"# unique fake ents removing reals: {len(extracted_res_without_reals)}")
#     extracted_res = random.sample(list(extracted_res), 1000)
#     entity_types_to_fake_entities[et] = extracted_res

In [29]:
# YAGO_FAKE_ENTITIES_PATH = os.path.join(DATA_ROOT, "fake_entities.json") 

# # Load yago_qec
# with open(YAGO_QEC_PATH, "r") as fp:
#     yago_qec = json.load(fp)

# # Save fake entities per each entity type.
# with open(YAGO_FAKE_ENTITIES_PATH, "w", encoding='utf-8') as fp:
#     json.dump(entity_types_to_fake_entities, fp, ensure_ascii=False, indent=4)

# # Randomly sample fake entities that are eligible according to entity type for each relation and save to yago_qec
# for k, v in yago_qec.items():
#     entity_types = yago_qec[k]["entity_types"]
#     eligible_fake_entities = list(itertools.chain.from_iterable([entity_types_to_fake_entities[et] for et in entity_types]))
#     yago_qec[k]["fake_entities"] = random.sample(eligible_fake_entities, len(yago_qec[k]["entities"]))

# # Save yago_qec including fake entities
# with open(YAGO_QEC_PATH, "w", encoding='utf-8') as fp:
#     json.dump(yago_qec, fp, ensure_ascii=False, indent=4)

In [30]:
###########################################
# 6. Construct fake entities with chatgpt #
###########################################
import pandas as pd
import random
import torch
import numpy as np
import itertools

def set_seed(seed=0):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

set_seed(0)
YAGO_GPT_FAKE_ENTITIES_PATH = os.path.join(DATA_ROOT, "chatgpt_fake_entities_all.csv") 

fake_entities_gpt_df = pd.read_csv(YAGO_GPT_FAKE_ENTITIES_PATH)
fake_entities_gpt_df = fake_entities_gpt_df.add_prefix("http://schema.org/")

# Load yago_qec
with open(YAGO_QEC_PATH, "r") as fp:
    yago_qec = json.load(fp)

# Randomly sample fake entities that are eligible according to entity type for each relation and save to yago_qec
for k, v in yago_qec.items():
    entity_types = yago_qec[k]["entity_types"]
    eligible_fake_entities = list(itertools.chain.from_iterable([fake_entities_gpt_df[et] for et in entity_types]))
    yago_qec[k]["gpt_fake_entities"] = random.sample(eligible_fake_entities, min(len(yago_qec[k]["entities"]), len(eligible_fake_entities)))

# Save yago_qec including fake entities
with open(YAGO_QEC_PATH, "w", encoding='utf-8') as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4)

KeyError: 'http://schema.org/Taxon'

In [None]:
len(yago_qec[k]["entities"])