# Generate the YAGO-Query-Entity-Contexts (YagoQEC) Dataset

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
!pwd

/cluster/work/cotterell/kdu/measureLM/preprocessing/YagoECQ


In [3]:
from collections import namedtuple
import os
import itertools
import json
import random
import re
import time
from tqdm import tqdm
from typing import List, Dict, Tuple
from urllib.error import HTTPError

import numpy as np
import pandas as pd
from SPARQLWrapper import SPARQLWrapper2
from SPARQLWrapper.SPARQLExceptions import EndPointInternalError
from utils import extract_name_from_yago_uri, negate_template, lowercase_first_letter

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [4]:
def set_seed(seed=0):
    """Seed both Python's `random` module and NumPy's global RNG with `seed`.

    Called once here and again before the fake-entity sampling in section 5 so the
    generated dataset is reproducible.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)


set_seed(0)

In [5]:
# One YAGO predicate: its full URI plus the (kb_name, relation) pair parsed from it.
Predicate = namedtuple("Predicate", "uri kb_name relation")

# Every cached intermediate and the final dataset live under this directory.
DATA_ROOT = "../../data/YagoECQ/"
PRED_URI_TO_S_OBJ_CLASSES_PATH, PRED_URI_TO_SO_PAIRS_PATH, YAGO_QEC_PATH = (
    os.path.join(DATA_ROOT, filename)
    for filename in (
        # predicate URI -> (subject class distribution, object class distribution)
        "yago_pred_uri_to_s_obj_classes.json",
        # predicate URI -> up to 1000 (s_name, o_name, s_uri, o_uri) pairs
        "yago_pred_uri_to_so_pairs_randomized_1k.json",
        # the final YagoQEC dataset
        "yago_qec.json",
    )
)
# When True, re-query predicates that are absent from the cached pairs file.
TRY_QUERYING_MISSING_PREDS = False

### 1. Get all relevant predicates from Yago

In [6]:
##################################
# 1. Get all relevant predicates #
##################################
# Single SPARQL client for the public YAGO endpoint; later cells and helper functions
# reuse it by calling setQuery() before each query.
sparql = SPARQLWrapper2("https://yago-knowledge.org/sparql/query")
# Every distinct predicate used in at least one triple, ordered by URI.
query_p = """
    SELECT DISTINCT ?p WHERE {
     ?s ?p ?obj . 
    }  ORDER BY ?p
"""

# Sparql query
sparql.setQuery(query_p)

# Adding values
relevant_preds: List[Predicate] = []
# Relations excluded by hand (presumably schema-definition / media-link relations
# that make poor question topics).
ineligible_relations = ["schema#fromClass", "schema#fromProperty", "logo", "image"]

for result in sparql.query().bindings:
    uri = result["p"].value
    kb_name, relation = extract_name_from_yago_uri(uri)
    # Skip kb_name "w3" predicates (presumably W3C rdf/rdfs/owl vocabulary; the exact
    # mapping depends on extract_name_from_yago_uri) and the hand-excluded relations.
    if kb_name != "w3" and relation not in ineligible_relations:
        relevant_preds.append(Predicate(uri=uri, kb_name=kb_name, relation=relation))

In [7]:
# Sanity check: the first few predicates kept after filtering.
relevant_preds[:3]

[Predicate(uri='http://schema.org/about', kb_name='schema', relation='about'),
 Predicate(uri='http://schema.org/actor', kb_name='schema', relation='actor'),
 Predicate(uri='http://schema.org/address', kb_name='schema', relation='address')]

In [8]:
###########################################################
# 1b. Get all subject and object types for each predicate #
###########################################################
# For each relevant predicate: sample up to 100 (s, obj) pairs, look up every class
# of each subject/object (rdf:type followed by any number of rdfs:subClassOf hops),
# keep only the top-level schema.org `types` listed below, and normalize the counts
# into a frequency distribution. The cached mapping additionally contains
# "reverse-<uri>" keys whose subject/object distributions are swapped.
from SPARQLWrapper.SPARQLExceptions import QueryBadFormed
from collections import Counter

query_single = """
    SELECT DISTINCT ?s ?obj WHERE {{
     ?s <{}> ?obj . 
    }}  LIMIT 100
"""
# All classes of a node: its rdf:type(s) plus every transitive rdfs:subClassOf ancestor.
query_superclasses = """
    SELECT DISTINCT ?superclasses WHERE {{
     <{}> rdf:type/rdfs:subClassOf* ?superclasses .
    }}  
"""

types = [
    "http://schema.org/CreativeWork",
    "http://schema.org/Event",
    "http://schema.org/Intangible",
    "http://schema.org/Organization",
    "http://schema.org/Person",
    "http://schema.org/Place",
    "http://schema.org/Product",
    "http://schema.org/Taxon",
    "http://schema.org/FictionalEntity",
]  # Extracted from https://yago-knowledge.org/schema

# Reuse the cached result when available: the sweep below issues one query per
# sampled subject and per sampled object.
if os.path.exists(PRED_URI_TO_S_OBJ_CLASSES_PATH):
    print(
        f"Loading pred_to_s_and_obj_types from file {PRED_URI_TO_S_OBJ_CLASSES_PATH}."
    )
    with open(PRED_URI_TO_S_OBJ_CLASSES_PATH, "r") as f:
        pred_to_s_and_obj_types_with_reverse = json.load(f)
else:
    pred_to_s_and_obj_types = dict()
    for pred in tqdm(relevant_preds):
        concrete_query_single = query_single.format(pred.uri)
        # print(query)
        sparql.setQuery(concrete_query_single)
        superclasses_s = []
        superclasses_obj = []
        for result in sparql.query().bindings:
            s = result["s"].value
            obj = result["obj"].value
            try:
                concrete_query_superclasses_s = query_superclasses.format(s)
                sparql.setQuery(concrete_query_superclasses_s)
                superclasses_s += [
                    x["superclasses"].value for x in sparql.query().bindings
                ]
            # A node that cannot be written as <...> (presumably a literal such as a
            # date or number) makes the query malformed. It is reported and skipped;
            # despite the message, nothing is re-raised.
            except QueryBadFormed as e:
                print("Raising error for s:", s)
                # raise ValueError
            try:
                concrete_query_superclasses_obj = query_superclasses.format(obj)
                sparql.setQuery(concrete_query_superclasses_obj)
                superclasses_obj += [
                    x["superclasses"].value for x in sparql.query().bindings
                ]
            except QueryBadFormed:
                print("Raising error for obj:", obj)
                # raise ValueError

        # Counts accumulate over all sampled pairs, so a node that appears in several
        # pairs is counted several times.
        s_counter, obj_counter = Counter(superclasses_s), Counter(superclasses_obj)

        # Restrict to the top-level `types` and normalize to frequencies (an empty
        # dict when no sampled node has one of them). Counter.total() needs Python >= 3.10.
        s_counter = Counter({k: v for k, v in s_counter.items() if k in types})
        s_counter = {k: v / s_counter.total() for k, v in s_counter.items()}
        obj_counter = Counter({k: v for k, v in obj_counter.items() if k in types})
        obj_counter = {k: v / obj_counter.total() for k, v in obj_counter.items()}

        pred_to_s_and_obj_types[pred] = (s_counter, obj_counter)

    # Key by URI string (JSON-serializable) instead of the Predicate namedtuple ...
    pred_to_s_and_obj_types = {
        k.uri: (s_counter, obj_counter)
        for k, (s_counter, obj_counter) in pred_to_s_and_obj_types.items()
    }
    # ... and add the inverse relations, whose subject/object distributions are swapped.
    pred_to_s_and_obj_types_with_reverse = {
        **pred_to_s_and_obj_types,
        **{
            f"reverse-{k}": (obj_counter, s_counter)
            for k, (s_counter, obj_counter) in pred_to_s_and_obj_types.items()
        },
    }

# Written even when the cache was just loaded (rewrites the same mapping).
with open(PRED_URI_TO_S_OBJ_CLASSES_PATH, "w", encoding="utf-8") as fp:
    json.dump(pred_to_s_and_obj_types_with_reverse, fp, ensure_ascii=False, indent=4)

Loading pred_to_s_and_obj_types from file ../../data/YagoECQ/yago_pred_uri_to_s_obj_classes.json.


### 2. Extract all subject-object pairs for each predicate

In [9]:
#######################################################################
# 2. Extract all subject-object pairs for each pred in relevant preds #
#######################################################################
def get_so_pairs_for_pred(pred: Predicate) -> List[Tuple[str, str, str, str]]:
    """Fetch up to 1000 subject-object pairs for `pred` from YAGO in one joined query.

    Each pair is (subject_name, object_name, subject_uri, object_uri). Names are
    English rdfs:labels when available. An object labelled "Generic instance" is named
    by the English label of its rdf:type instead. When no usable name exists, the raw
    URI/literal is used. Rows are ordered by MD5(subject || object), so the 1000
    pairs are a deterministic pseudo-random sample.

    Args:
        pred: predicate whose triples are sampled.

    Returns:
        The list of pairs. If the joined query fails with an HTTP/endpoint error, the
        result of `get_so_pairs_for_pred_chained` is returned instead (which may be
        None). Any other error is printed, and the pairs collected so far are returned,
        so one bad predicate does not abort a sweep over all predicates.
    """
    start_time = time.time()
    print(f"Querying {pred.uri}.")
    so_pairs = []

    query = """
SELECT DISTINCT ?output_s ?output_obj ?s ?obj WHERE {{
  ?s <{}> ?obj . # which predicate to use
  OPTIONAL {{ 
    ?s rdfs:label ?s_label .
    FILTER (LANG(?s_label) = 'en')
  }} # Get the label (name) for the subject if it exists
  BIND(COALESCE(?s_label, ?s) AS ?output_s) # if the label does not exist, stick with the URI
  
  OPTIONAL {{
    ?obj rdfs:label ?obj_label . 
    FILTER (LANG(?obj_label) = 'en')
  }} # Get the label (name) for the object if it exists
  OPTIONAL {{
    ?obj rdf:type ?obj_type . 
    ?obj_type rdfs:label ?obj_type_name .
    FILTER (LANG(?obj_type_name) = 'en')
  }} # get the name of the type for the object if the type exists
  BIND(COALESCE(IF(STR(?obj_label) != "Generic instance", ?obj_label, ?obj_type_name), ?obj) AS ?output_obj) 
  # if the label exists, go with the label, but if it's "Generic instance", then go with the type; if that does not exist, then stick with the OG object.
  BIND(MD5(CONCAT(STR(?s), STR(?obj))) AS ?sortkey) .
}}
ORDER BY ?sortkey
LIMIT 1000
""".format(
        pred.uri
    )

    try:
        sparql.setQuery(query)
        for result in sparql.query().bindings:
            so_pairs.append(
                (
                    result["output_s"].value,
                    result["output_obj"].value,
                    result["s"].value,
                    result["obj"].value,
                )
            )
    except (HTTPError, EndPointInternalError):
        # The joined query is expensive; retry with one cheap query plus per-node lookups.
        print(f"HTTPerror for uri {pred.uri}. Trying chained query.")
        so_pairs = get_so_pairs_for_pred_chained(pred)
    except Exception as e:
        # Best effort: report and keep whatever was collected. Catching Exception here
        # (rather than returning from `finally`) lets KeyboardInterrupt stop the sweep.
        print(f"Error for uri {pred.uri}: {e!r}")
    finally:
        print(f"Time elapsed: {time.time() - start_time}")
    return so_pairs


def get_so_pairs_for_pred_chained(pred: Predicate) -> List[Tuple[str, str, str, str]]:
    """Fallback for `get_so_pairs_for_pred`: fetch URI pairs first, then name each node.

    Runs one cheap query for up to 1000 (subject, object) pairs of `pred`, using the
    same MD5 ordering as the joined query. It then runs one naming query per subject
    and per object, with the same label / "Generic instance" type-name / raw-value
    fallback.

    Args:
        pred: predicate whose triples are sampled.

    Returns:
        A list of (subject_name, object_name, subject_uri, object_uri). Returns None if
        an HTTP/endpoint error occurs. On any other error the message is printed and
        the pairs collected so far are returned. One such error is a naming query that
        does not yield exactly one row.
    """
    start_time = time.time()
    print(f"Querying {pred.uri} (chained).")
    so_pairs = []

    query = """
SELECT DISTINCT ?s ?o WHERE {{
    ?s <{}> ?o
    BIND(MD5(CONCAT(STR(?s), STR(?o))) AS ?sortkey) .
}}
ORDER BY ?sortkey
LIMIT 1000
""".format(
        pred.uri
    )

    query_get_name_of_obj = """
SELECT DISTINCT ?output_obj WHERE {{
    OPTIONAL {{
      <{0}> rdfs:label ?obj_label . 
      FILTER (LANG(?obj_label) = 'en')
    }} # Get the label (name) for the object if it exists
    OPTIONAL {{
      <{0}> rdf:type ?obj_type . 
      ?obj_type rdfs:label ?obj_type_name .
      FILTER (LANG(?obj_type_name) = 'en')
    }} # get the name of the type for the object if the type exists
    BIND(COALESCE(IF(STR(?obj_label) != "Generic instance", ?obj_label, ?obj_type_name), <{0}>) AS ?output_obj) 
}}  
"""

    def _name_of(node_uri: str, role: str) -> str:
        # Resolve the display name of one node; exactly one row is expected.
        sparql.setQuery(query_get_name_of_obj.format(node_uri))
        bindings = sparql.query().bindings
        if len(bindings) != 1:
            raise ValueError(
                f"Expected exactly 1 label for {role} {node_uri}, "
                f"got {len(bindings)}: {bindings}"
            )
        return bindings[0]["output_obj"].value

    try:
        sparql.setQuery(query)
        # `.bindings` is fully materialized, so the naming queries issued inside the
        # loop (which call setQuery again) do not disturb this iteration.
        for result in sparql.query().bindings:
            s_uri = result["s"].value
            o_uri = result["o"].value
            s_name = _name_of(s_uri, "subject")
            obj_name = _name_of(o_uri, "object")
            so_pairs.append((s_name, obj_name, s_uri, o_uri))
    except (HTTPError, EndPointInternalError):
        print(f"HTTPerror for uri {pred.uri} when chaining. Skipping.")
        so_pairs = None
    except Exception as e:
        # Previously `except Error`, an undefined name whose NameError was hidden by a
        # `return` in `finally`. Report the real error and keep the partial result.
        print(e)
    finally:
        print(f"Time elapsed: {time.time() - start_time}")
    return so_pairs

In [10]:
# Load cached subject-object pairs if present; otherwise query every relevant
# predicate (slow: at least one SPARQL query per predicate).
if os.path.exists(PRED_URI_TO_SO_PAIRS_PATH):
    print(f"Loading pred_uri_to_so_pairs from file {PRED_URI_TO_SO_PAIRS_PATH}.")
    with open(PRED_URI_TO_SO_PAIRS_PATH) as f:
        pred_uri_to_so_pairs = json.load(f)
else:
    pred_to_so_pairs: Dict[Predicate, List[Tuple[str, str, str, str]]] = {
        pred: get_so_pairs_for_pred(pred) for pred in tqdm(relevant_preds)
    }
    # Drop predicates whose queries failed entirely (None); an empty list is kept.
    pred_to_so_pairs = {k: v for k, v in pred_to_so_pairs.items() if v is not None}
    # Re-key by URI string so the mapping is JSON-serializable.
    pred_uri_to_so_pairs = {
        k.uri: v for k, v in pred_to_so_pairs.items() if v is not None
    }

# Optionally retry predicates absent from the (possibly cached) mapping.
if TRY_QUERYING_MISSING_PREDS:
    missing_preds = [p for p in relevant_preds if p.uri not in pred_uri_to_so_pairs]
    print("Missing preds:", missing_preds)
    missing_pred_to_so_pairs: Dict[Predicate, List[Tuple[str, str, str, str]]] = {
        pred: get_so_pairs_for_pred(pred) for pred in tqdm(missing_preds)
    }
    missing_pred_to_so_pairs = {
        k: v for k, v in missing_pred_to_so_pairs.items() if v is not None
    }
    missing_pred_uri_to_so_pairs = {
        k.uri: v for k, v in missing_pred_to_so_pairs.items() if v is not None
    }

    # Later unpacking wins on key clashes, so existing entries take precedence.
    pred_uri_to_so_pairs = {**missing_pred_uri_to_so_pairs, **pred_uri_to_so_pairs}

# Written unconditionally so the cache includes any newly fetched predicates.
with open(PRED_URI_TO_SO_PAIRS_PATH, "w", encoding="utf-8") as fp:
    json.dump(pred_uri_to_so_pairs, fp, ensure_ascii=False, indent=4)


def augment_pred_uri_to_so_pairs_with_reverse(pred_uri_to_so_pairs):
    """Return a new mapping that also contains the inverse of every relation.

    For each predicate URI `k` whose pairs are (s_label, o_label, s_uri, o_uri), the
    key "reverse-{k}" maps to the swapped pairs (o_label, s_label, o_uri, s_uri).
    This makes the former object the queried entity. The original entries are kept
    as-is (same list objects), and the reversed pairs are tuples. The input mapping
    is not modified.
    """
    augmented = dict(pred_uri_to_so_pairs)
    for pred_uri, pairs in pred_uri_to_so_pairs.items():
        augmented[f"reverse-{pred_uri}"] = [
            (o_label, s_label, o_uri, s_uri)
            for s_label, o_label, s_uri, o_uri in pairs
        ]
    return augmented


# Add inverse relations ("reverse-<uri>") so every predicate can be queried from the
# object's side as well.
pred_uri_to_so_pairs_with_reverse = augment_pred_uri_to_so_pairs_with_reverse(
    pred_uri_to_so_pairs
)

Loading pred_uri_to_so_pairs from file ../../data/YagoECQ/yago_pred_uri_to_so_pairs_randomized_1k.json.


In [11]:
# Check the distribution of entity types seems reasonable
# ([0] = subject/entity class distribution, [1] = object/answer class distribution)
pred_to_s_and_obj_types_with_reverse["http://schema.org/highestPoint"][0]

{'http://schema.org/Place': 0.7518796992481203,
 'http://schema.org/Organization': 0.24812030075187969}

### 3. Construct queries containing entities, corresponding answers, query forms, and context templates

In [12]:
########################################################################################################
# 3. Construct queries containing entities, corresponding answers, query forms, and context templates. #
########################################################################################################
from yago_questions import yago_topic_to_qfs


def _build_qec_entry(pred_uri: str) -> dict:
    """Assemble the YagoQEC record for one (possibly "reverse-") predicate URI.

    Combines the hand-written query forms, the sampled (entity, answer) pairs and
    their URIs, context templates built from the last "open" query form, and the
    subject/object class distributions from step 1b.
    """
    query_forms = yago_topic_to_qfs[pred_uri]
    # Unzip [(entity, answer, entity_uri, answer_uri), ...] once into four parallel
    # tuples. A predicate without sampled pairs yields four empty tuples instead of
    # raising IndexError.
    entities, answers, entity_uris, answer_uris = tuple(
        zip(*pred_uri_to_so_pairs_with_reverse[pred_uri])
    ) or ((), (), (), ())

    # The statement that "{answer}" completes, reused by every template variant.
    statement = query_forms["open"][-1]
    lowered = lowercase_first_letter(statement)
    answer_suffix = " {answer}.\n"
    s_and_obj_types = pred_to_s_and_obj_types_with_reverse[pred_uri]
    return {
        "query_forms": query_forms,
        "entities": entities,
        "answers": answers,
        "entity_uris": entity_uris,
        "answer_uris": answer_uris,
        "context_templates": {
            "base": statement + answer_suffix,
            "assertive": "Definitely, " + lowered + answer_suffix,
            "super_assertive": "Most certainly and definitely, "
            + lowered
            + answer_suffix,
            "ignore_prior": (
                "Forget everything you think you know about {entity}. "
                "Most certainly and definitely, "
            )
            + lowered
            + answer_suffix,
            "believe_me": "Believe me, " + lowered + answer_suffix,
            "negation": negate_template(statement + answer_suffix),
        },
        "entity_types": s_and_obj_types[0],
        "answer_types": s_and_obj_types[1],
    }


# Only relations that have both hand-written query forms and sampled pairs.
keys = set(yago_topic_to_qfs).intersection(pred_uri_to_so_pairs_with_reverse)
print(sorted(keys))  # sorted so the log is identical across runs
yago_qec = {k: _build_qec_entry(k) for k in keys}
with open(YAGO_QEC_PATH, "w", encoding="utf-8") as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4, sort_keys=True)

{'http://schema.org/author', 'http://schema.org/deathDate', 'reverse-http://yago-knowledge.org/resource/studentOf', 'http://schema.org/director', 'reverse-http://schema.org/author', 'http://yago-knowledge.org/resource/radialVelocity', 'http://yago-knowledge.org/resource/sportNumber', 'http://schema.org/dissolutionDate', 'reverse-http://yago-knowledge.org/resource/participant', 'reverse-http://yago-knowledge.org/resource/playsIn', 'reverse-http://schema.org/homeLocation', 'http://schema.org/icaoCode', 'reverse-http://schema.org/performer', 'http://schema.org/numberOfPages', 'http://schema.org/iswcCode', 'http://schema.org/editor', 'http://yago-knowledge.org/resource/length', 'http://yago-knowledge.org/resource/terminus', 'reverse-http://schema.org/leader', 'http://schema.org/duns', 'http://schema.org/url', 'http://schema.org/recordLabel', 'http://schema.org/parentTaxon', 'http://schema.org/nationality', 'http://schema.org/birthDate', 'http://yago-knowledge.org/resource/distanceFromEarth



In [13]:
# Reload from disk so later cells see exactly what was saved (JSON turns the
# entity/answer tuples into lists, and key order is the sorted order written above).
with open(YAGO_QEC_PATH) as f:
    yago_qec = json.load(f)

In [14]:
# Per relation: number of sampled pairs vs. number of distinct entity names.
# Duplicates come from one subject having several objects, or from distinct URIs
# sharing the same label.
for qid, v in yago_qec.items():
    print(qid, len(v["entities"]), len(set(v["entities"])))

http://schema.org/about 1000 989
http://schema.org/actor 1000 991
http://schema.org/address 1000 997
http://schema.org/administrates 1000 800
http://schema.org/affiliation 1000 968
http://schema.org/alumniOf 1000 999
http://schema.org/author 1000 997
http://schema.org/award 1000 995
http://schema.org/birthDate 1000 999
http://schema.org/birthPlace 1000 1000
http://schema.org/children 1000 994
http://schema.org/contentLocation 1000 992
http://schema.org/dateCreated 1000 999
http://schema.org/deathDate 1000 999
http://schema.org/deathPlace 1000 1000
http://schema.org/demonym 1000 884
http://schema.org/director 1000 998
http://schema.org/dissolutionDate 1000 999
http://schema.org/duns 119 119
http://schema.org/duration 1000 999
http://schema.org/editor 378 263
http://schema.org/elevation 1000 998
http://schema.org/endDate 1000 999
http://schema.org/founder 1000 987
http://schema.org/gtin 21 18
http://schema.org/highestPoint 1000 991
http://schema.org/homeLocation 1000 995
http://schema.or

### 4. Get the degree and related features of each entity

In [15]:
####################################
# 4. Get the degree of each entity #
####################################
import itertools
from collections import defaultdict

ENTITY_NAME_TO_POSSIBLE_ENTITY_URIS_PATH = os.path.join(
    DATA_ROOT, "entity_name_to_possible_entity_uris.json"
)
ENTITY_URI_TO_DEGREE_PATH = os.path.join(DATA_ROOT, "entity_uri_to_degree.json")
ENTITY_URI_TO_DEGREE_INCLUDING_AMBIGUOUS_ENTITIES_PATH = os.path.join(
    DATA_ROOT, "entity_uri_to_degree_including_ambiguous_entities.json"
)
ENTITY_URI_TO_PREDICATE_DEGREE_PATH = os.path.join(
    DATA_ROOT, "entity_uri_to_predicate_degree_path.json"
)
ENTITY_NAMESAKE_TO_DEGREE_PATH = os.path.join(
    DATA_ROOT, "entity_namesake_to_degree.json"
)
ENTITY_NAMESAKE_TO_NUM_URIS_PATH = os.path.join(
    DATA_ROOT, "entity_namesake_to_num_uris.json"
)


def _load_cached_json(path, label, default):
    """Return the JSON content of `path` if that file exists, otherwise `default`.

    On a cache hit, announces "Loading <label> from file <path>." first.
    """
    if not os.path.exists(path):
        return default
    print(f"Loading {label} from file {path}.")
    with open(path) as cache_file:
        return json.load(cache_file)


# Load cached jsons; each mapping starts empty when its file does not exist yet.
# {entity_name: List[entity_uri]}
entity_name_to_possible_entity_uris: Dict[str, List[str]] = _load_cached_json(
    ENTITY_NAME_TO_POSSIBLE_ENTITY_URIS_PATH,
    "entity_name_to_possible_entity_uris",
    dict(),
)
# {entity_uri: degree}
entity_uri_to_degree: Dict[str, int] = _load_cached_json(
    ENTITY_URI_TO_DEGREE_PATH, "entity_uri_to_degree", dict()
)
# {entity_uri: degree}, also covering every URI that shares a name with an entity
entity_uri_to_degree_including_ambiguous_entities: Dict[str, int] = _load_cached_json(
    ENTITY_URI_TO_DEGREE_INCLUDING_AMBIGUOUS_ENTITIES_PATH,
    "entity_uri_to_degree_including_ambiguous_entities",
    dict(),
)
# {predicate: {entity: degree}}; defaultdict so unseen predicates read as {}
entity_uri_to_predicate_degree: Dict[str, Dict[str, int]] = defaultdict(
    dict,
    _load_cached_json(
        ENTITY_URI_TO_PREDICATE_DEGREE_PATH, "entity_uri_to_predicate_degree", dict()
    ),
)

Loading entity_name_to_possible_entity_uris from file ../../data/YagoECQ/entity_name_to_possible_entity_uris.json.
Loading entity_uri_to_degree from file ../../data/YagoECQ/entity_uri_to_degree.json.
Loading entity_uri_to_degree_including_ambiguous_entities from file ../../data/YagoECQ/entity_uri_to_degree_including_ambiguous_entities.json.
Loading entity_uri_to_predicate_degree from file ../../data/YagoECQ/entity_uri_to_predicate_degree_path.json.


In [16]:
def get_degree_for_entity_uri(entity_uri: str) -> "int | None":
    """Count the triples in which `entity_uri` appears as subject or as object.

    Outgoing (<entity> ?edge ?object) and incoming (?subject ?edge <entity>) matches
    are combined with UNION. A triple with the entity on both sides is therefore
    counted twice.

    Args:
        entity_uri: full URI of the entity.

    Returns:
        The degree as an int. Returns None if the query fails; the error is printed.
        KeyboardInterrupt is no longer swallowed.
    """
    query = """
SELECT (COUNT(?edge) as ?degree)
WHERE {{
  {{
    <{0}> ?edge ?object.
  }}
  UNION
  {{
    ?subject ?edge <{0}>.
  }}
}}
""".format(
        entity_uri
    )
    try:
        sparql.setQuery(query)
        res_degree = sparql.query().bindings
        if len(res_degree) != 1:
            raise ValueError(
                f"Expected exactly 1 degree for entity {entity_uri}, "
                f"got {len(res_degree)}: {res_degree}"
            )
        # SPARQL binding values are strings; return the int the annotation promises.
        return int(res_degree[0]["degree"].value)
    except (HTTPError, EndPointInternalError):
        print(f"HTTPerror for uri {entity_uri}. Skipping.")
    except Exception as e:
        print(e)
    return None


def sparql_escape_string(text: str) -> str:
    """Escape `text` for use inside a double-quoted SPARQL string literal.

    The SPARQL 1.1 grammar (STRING_LITERAL2 / ECHAR) forbids an unescaped double
    quote, backslash, line feed, or carriage return inside "...".
    """
    return text.translate(
        str.maketrans({"\\": "\\\\", '"': '\\"', "\n": "\\n", "\r": "\\r"})
    )


def get_possible_entity_uris_per_entity(entity: str) -> "List[str] | None":
    """Return every URI whose English rdfs:label is exactly `entity`.

    Several URIs can share a label (namesakes, e.g. many places named "Paris"), so
    the list may have several entries or be empty. The label is escaped before being
    embedded in the query. Names containing quotes or backslashes therefore no longer
    produce malformed queries.

    Returns:
        The list of URIs, or None if the query fails (the error is printed).
    """
    query = """
SELECT ?entity_uri
WHERE {{
  ?entity_uri rdfs:label "{}"@en.
}}
""".format(
        sparql_escape_string(entity)
    )
    try:
        sparql.setQuery(query)
        return [result["entity_uri"].value for result in sparql.query().bindings]
    except (HTTPError, EndPointInternalError):
        print(f"HTTPerror for uri {entity}. Skipping.")
    except Exception as e:
        print(e)
    return None


def build_predicate_degree_query(entity_uri: str, predicate: str) -> str:
    """Build the SPARQL query counting `entity_uri`'s edges along `predicate`.

    A forward predicate URI counts the objects of <entity_uri> <predicate> ?object.
    "reverse-<uri>" counts the subjects of ?subject <uri> <entity_uri>, i.e. the
    triples that point at the entity. Only the leading "reverse-" is stripped, so a
    URI that itself contains "reverse-" stays intact.
    """
    reverse_prefix = "reverse-"
    if predicate.startswith(reverse_prefix):
        predicate = predicate[len(reverse_prefix) :]
        template = """
                SELECT (COUNT(?subject) as ?degree) WHERE {{
                    {{
                        ?subject <{predicate}> <{entity_uri}>.
                    }}
                }}
                """
    else:
        template = """
                SELECT (COUNT(?object) as ?degree) WHERE {{
                    {{
                        <{entity_uri}> <{predicate}> ?object.
                    }}
                }}
                """
    return template.format(entity_uri=entity_uri, predicate=predicate)


def get_predicate_degree_for_entity_uri(entity_uri: str, predicate: str) -> "int | None":
    """Return how many triples link `entity_uri` via `predicate`.

    The direction is taken from `predicate`; see `build_predicate_degree_query`.

    Returns:
        The count as an int, or None if the query fails (the error is printed).
    """
    query = build_predicate_degree_query(entity_uri, predicate)
    try:
        sparql.setQuery(query)
        res_degree = sparql.query().bindings
        if len(res_degree) != 1:
            raise ValueError(
                f"Expected exactly 1 degree for entity {entity_uri}, "
                f"got {len(res_degree)}: {res_degree}"
            )
        # SPARQL binding values are strings; return the int the annotation promises.
        return int(res_degree[0]["degree"].value)
    except (HTTPError, EndPointInternalError):
        print(f"HTTPerror for uri {entity_uri}. Skipping.")
    except Exception as e:
        print(e)
    return None

In [17]:
# Construct the relevant list of entity uris and entity names.
# Each list is de-duplicated and sorted. Iterating a set of strings depends on the
# per-process hash seed, so sorting keeps the order (and the JSON files written from
# these lists) reproducible across kernels.
entity_uris: List[str] = sorted(
    set(itertools.chain.from_iterable(v["entity_uris"] for v in yago_qec.values()))
)
entities: List[str] = sorted(
    set(itertools.chain.from_iterable(v["entities"] for v in yago_qec.values()))
)
# {predicate: entity uris in sampled-pair order (duplicates kept)}
predicate_to_entity_uris: Dict[str, List[str]] = {
    predicate: qec["entity_uris"] for predicate, qec in yago_qec.items()
}
answers: List[str] = sorted(
    set(itertools.chain.from_iterable(v["answers"] for v in yago_qec.values()))
)

# # Test run examples
# entity_uris = ["http://yago-knowledge.org/resource/Paul_McCartney", "http://yago-knowledge.org/resource/Paul_Allen__u0028_editor_u0029_"]
# entities = ["Paul McCartney", "Paul Allen"]
# predicate_to_entity_uris = defaultdict(
#     dict,
#     {
#         "reverse-http://schema.org/author": ["http://yago-knowledge.org/resource/Anton_Chekhov", "http://yago-knowledge.org/resource/Alexander_Hamilton"],
#         "http://schema.org/author": ["http://yago-knowledge.org/resource/A_Marriage_Proposal"],
#     }
# )
len(entity_uris), len(entities), sum(
    len(v) for v in predicate_to_entity_uris.values()
), len(answers)

(96271, 95030, 112536, 78946)

In [18]:
ENTITIES_PATH = os.path.join(DATA_ROOT, "entities.json")
ANSWERS_PATH = os.path.join(DATA_ROOT, "answers.json")

# Persist the flat entity-name and answer-name lists, one JSON array per file.
for out_path, names in ((ENTITIES_PATH, entities), (ANSWERS_PATH, answers)):
    with open(out_path, "w", encoding="utf-8") as fp:
        json.dump(names, fp, ensure_ascii=False, indent=4)

In [19]:
# Sanity check: entity names not yet in the name -> URIs cache
# (0 missing means no new label lookups are needed).
missing_entities = set(entities).difference(set(entity_name_to_possible_entity_uris))
len(missing_entities), len(set(entities)), len(set(entity_name_to_possible_entity_uris))

(0, 95030, 95031)

In [20]:
# Sanity check: entity URIs not yet in the URI -> degree cache
# (0 missing means no new degree queries are needed).
missing_entity_uris = set(entity_uris).difference(set(entity_uri_to_degree))
len(missing_entity_uris), len(set(entity_uris)), len(set(entity_uri_to_degree))

(0, 96271, 96272)

In [21]:
# Construct the jsons mapping (a) from entity names to uris, (b) from entity uris to
# degrees, (c) from entity uris (including those of the namesake) to degrees, and
# (d) from predicates to each entity's degree along that predicate.
# This will additively build from the cached files. A lookup that failed (None) counts
# as missing: it is re-queried on the next run and never written to a cache, so it
# cannot reach the URI iteration and int(...) conversions further below.


def _drop_failed(results: dict) -> dict:
    """Return a copy of `results` without the entries whose query failed (value None)."""
    return {key: value for key, value in results.items() if value is not None}


# (a) entity name -> every URI carrying that English label
missing_entities = {
    entity
    for entity in entities
    if entity_name_to_possible_entity_uris.get(entity) is None
}
entity_name_to_possible_entity_uris = _drop_failed(
    {
        **entity_name_to_possible_entity_uris,
        **{
            entity: get_possible_entity_uris_per_entity(entity)
            for entity in tqdm(missing_entities)
        },
    }
)
with open(ENTITY_NAME_TO_POSSIBLE_ENTITY_URIS_PATH, "w", encoding="utf-8") as fp:
    json.dump(entity_name_to_possible_entity_uris, fp, ensure_ascii=False, indent=4)

# (b) entity uri -> total degree
missing_entity_uris = {
    entity_uri
    for entity_uri in entity_uris
    if entity_uri_to_degree.get(entity_uri) is None
}
entity_uri_to_degree = _drop_failed(
    {
        **entity_uri_to_degree,
        **{
            entity_uri: get_degree_for_entity_uri(entity_uri)
            for entity_uri in tqdm(missing_entity_uris)
        },
    }
)
with open(ENTITY_URI_TO_DEGREE_PATH, "w", encoding="utf-8") as fp:
    json.dump(entity_uri_to_degree, fp, ensure_ascii=False, indent=4)

# (c) degree of every URI sharing a name with one of our entities, reusing (b) where
# possible. Names whose lookup failed contribute no URIs until they succeed.
missing_entity_uris_ambiguous = {
    entity_uri
    for entity in entities
    for entity_uri in entity_name_to_possible_entity_uris.get(entity) or []
    if entity_uri_to_degree_including_ambiguous_entities.get(entity_uri) is None
}
entity_uri_to_degree_including_ambiguous_entities = _drop_failed(
    {
        **entity_uri_to_degree_including_ambiguous_entities,
        **{
            entity_uri: entity_uri_to_degree[entity_uri]
            if entity_uri in entity_uri_to_degree
            else get_degree_for_entity_uri(entity_uri)
            for entity_uri in missing_entity_uris_ambiguous
        },
    }
)
with open(
    ENTITY_URI_TO_DEGREE_INCLUDING_AMBIGUOUS_ENTITIES_PATH, "w", encoding="utf-8"
) as fp:
    json.dump(
        entity_uri_to_degree_including_ambiguous_entities,
        fp,
        ensure_ascii=False,
        indent=4,
    )

# (d) predicate -> {entity uri -> degree along that predicate}. `.get(predicate, {})`
# keeps this working on re-runs, when the mapping is a plain dict, not a defaultdict.
missing_entity_predicate_uris = defaultdict(
    dict,
    {
        predicate: {
            entity_uri
            for entity_uri in set(predicate_to_entity_uris[predicate])
            if entity_uri_to_predicate_degree.get(predicate, {}).get(entity_uri)
            is None
        }
        for predicate in tqdm(yago_qec)
    },
)
entity_uri_to_predicate_degree = {
    predicate: _drop_failed(
        {
            **entity_uri_to_predicate_degree.get(predicate, {}),
            **{
                entity_uri: get_predicate_degree_for_entity_uri(entity_uri, predicate)
                for entity_uri in missing_entity_predicate_uris[predicate]
            },
        }
    )
    for predicate in tqdm(yago_qec)
}
with open(ENTITY_URI_TO_PREDICATE_DEGREE_PATH, "w", encoding="utf-8") as fp:
    json.dump(entity_uri_to_predicate_degree, fp, ensure_ascii=False, indent=4)

0it [00:00, ?it/s]
0it [00:00, ?it/s]
100%|██████████| 127/127 [00:00<00:00, 8660.70it/s]
100%|██████████| 127/127 [00:00<00:00, 54399.16it/s]


In [22]:
# Sizes of the four caches: names, entity URIs, entity + namesake URIs, predicates.
len(entity_name_to_possible_entity_uris), len(entity_uri_to_degree), len(
    entity_uri_to_degree_including_ambiguous_entities
), len(entity_uri_to_predicate_degree)

(95031, 96272, 169229, 127)

In [23]:
# Construct and save entity_namesake_to_* jsons.
# {entity_namesake: summed degree over every URI that carries that name}
entity_namesake_to_degree: Dict[str, int] = {
    name: sum(
        int(entity_uri_to_degree_including_ambiguous_entities[uri]) for uri in uris
    )
    for name, uris in entity_name_to_possible_entity_uris.items()
}
# {entity_namesake: number of uris with that name}
entity_namesake_to_num_uris: Dict[str, int] = {
    name: len(uris) for name, uris in entity_name_to_possible_entity_uris.items()
}

for out_path, payload in (
    (ENTITY_NAMESAKE_TO_DEGREE_PATH, entity_namesake_to_degree),
    (ENTITY_NAMESAKE_TO_NUM_URIS_PATH, entity_namesake_to_num_uris),
):
    with open(out_path, "w", encoding="utf-8") as fp:
        json.dump(payload, fp, ensure_ascii=False, indent=4)

In [24]:
# Incorporate above entity_namesake_to_* and entity_uri_to_* stats into yago_qec

# Load yago entity namesake to degree stats (including all entity uris sharing that namesake)
ENTITY_NAMESAKE_TO_DEGREE_PATH = os.path.join(
    DATA_ROOT, "entity_namesake_to_degree.json"
)
with open(ENTITY_NAMESAKE_TO_DEGREE_PATH, "r") as fp:
    entity_namesake_to_degree = json.load(fp)

# One value per sampled pair, aligned index-by-index with qec["entities"].
for pred_uri, qec in yago_qec.items():
    qec["entity_namesake_to_degree"] = [
        int(entity_namesake_to_degree[entity]) for entity in qec["entities"]
    ]

# Save yago_qec with the new feature (fake entities are only added in section 5).
# sort_keys=True matches every other yago_qec dump in this notebook.
with open(YAGO_QEC_PATH, "w", encoding="utf-8") as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4, sort_keys=True)

# Inspect the highest-degree entities of the last relation in key order.
example_pred = list(yago_qec)[-1]
example_pred, sorted(
    list(
        zip(
            yago_qec[example_pred]["entity_namesake_to_degree"],
            yago_qec[example_pred]["entities"],
        )
    ),
    key=lambda x: x[0],
    reverse=True,
)[:10]

('reverse-http://yago-knowledge.org/resource/terminus',
 [(38923, 'London'),
  (29297, 'Paris'),
  (29297, 'Paris'),
  (29297, 'Paris'),
  (29297, 'Paris'),
  (29297, 'Paris'),
  (22183, 'Los Angeles'),
  (22183, 'Los Angeles'),
  (22183, 'Los Angeles'),
  (16524, 'Moscow')])

In [25]:
# Load yago entity uri to degree stats
ENTITY_URI_TO_DEGREE_PATH = os.path.join(DATA_ROOT, "entity_uri_to_degree.json")
with open(ENTITY_URI_TO_DEGREE_PATH, "r") as fp:
    entity_uri_to_degree = json.load(fp)

# One degree per sampled pair, aligned index-by-index with yago_qec[k]["entity_uris"].
for k, v in yago_qec.items():
    yago_qec[k]["entity_uri_to_degree"] = [
        int(entity_uri_to_degree[uri]) for uri in yago_qec[k]["entity_uris"]
    ]

# Save yago_qec with the new per-entity feature (fake entities are added later, in section 5)
with open(YAGO_QEC_PATH, "w", encoding="utf-8") as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4, sort_keys=True)

# `k` is the last relation left over from the loop above; show its top-degree entities.
k, sorted(
    list(zip(yago_qec[k]["entity_uri_to_degree"], yago_qec[k]["entities"])),
    key=lambda x: x[0],
    reverse=True,
)[:10]

('reverse-http://yago-knowledge.org/resource/terminus',
 [(36564, 'London'),
  (25899, 'Paris'),
  (25899, 'Paris'),
  (25899, 'Paris'),
  (25899, 'Paris'),
  (25899, 'Paris'),
  (21919, 'Los Angeles'),
  (21919, 'Los Angeles'),
  (21919, 'Los Angeles'),
  (15445, 'Moscow')])

In [26]:
# Load yago entity namesake to number of different uris
ENTITY_NAMESAKE_TO_NUM_URIS_PATH = os.path.join(
    DATA_ROOT, "entity_namesake_to_num_uris.json"
)
with open(ENTITY_NAMESAKE_TO_NUM_URIS_PATH, "r") as fp:
    entity_namesake_to_num_uris = json.load(fp)

# One count per sampled pair, aligned index-by-index with yago_qec[k]["entities"].
for k, v in yago_qec.items():
    yago_qec[k]["entity_namesake_to_num_uris"] = [
        int(entity_namesake_to_num_uris[entity]) for entity in yago_qec[k]["entities"]
    ]

# Save yago_qec with the new per-entity feature (fake entities are added later, in section 5)
with open(YAGO_QEC_PATH, "w", encoding="utf-8") as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4, sort_keys=True)

# `k` is the last relation left over from the loop above; show its most ambiguous names.
k, sorted(
    list(zip(yago_qec[k]["entity_namesake_to_num_uris"], yago_qec[k]["entities"])),
    key=lambda x: x[0],
    reverse=True,
)[:10]

('reverse-http://yago-knowledge.org/resource/terminus',
 [(81, 'Victoria'),
  (70, 'Lincoln'),
  (64, 'Mount Pleasant'),
  (62, 'Richmond'),
  (58, 'Paris'),
  (58, 'Paris'),
  (58, 'Paris'),
  (58, 'Paris'),
  (58, 'Paris'),
  (56, 'Newton')])

In [27]:
# Load the per-query entity predicate degrees and attach them to yago_qec.
# The JSON is a nested mapping keyed first by query id, then by entity URI
# (see the double lookup below).
# Defined here (same value as elsewhere in the notebook) so this cell does not depend
# on hidden kernel state, matching how the sibling cells define their input paths.
ENTITY_URI_TO_PREDICATE_DEGREE_PATH = os.path.join(
    DATA_ROOT, "entity_uri_to_predicate_degree_path.json"
)
with open(ENTITY_URI_TO_PREDICATE_DEGREE_PATH, "r") as fp:
    entity_uri_to_predicate_degree = json.load(fp)

for k, v in yago_qec.items():
    # `v` is the same dict object as `yago_qec[k]`
    v["entity_uri_to_predicate_degree"] = [
        int(entity_uri_to_predicate_degree[k][uri]) for uri in v["entity_uris"]
    ]

# Persist the annotated yago_qec back to disk
with open(YAGO_QEC_PATH, "w", encoding="utf-8") as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4, sort_keys=True)

# Preview: top-10 entities by predicate degree for the last query visited by the loop,
# plus the number of distinct entity names in that query
k, sorted(
    list(zip(yago_qec[k]["entity_uri_to_predicate_degree"], yago_qec[k]["entities"])),
    key=lambda x: x[0],
    reverse=True,
)[:10], len(set(yago_qec[k]["entities"]))

('reverse-http://yago-knowledge.org/resource/terminus',
 [(21, 'Nebraska'),
  (21, 'Nebraska'),
  (21, 'Nebraska'),
  (21, 'Nebraska'),
  (21, 'Nebraska'),
  (20, 'Beijing'),
  (20, 'Beijing'),
  (20, 'Beijing'),
  (19, 'Paris'),
  (19, 'Paris')],
 899)

### 5. Augment each query in yago_qec with fake entities (gpt-4 generated)

Before running this section, run `gpt_generate_entities.ipynb` to generate `chatgpt_fake_entities_all.csv`; the fake entities in that file are then merged into yago_qec below.

In [28]:
##################################################
# 5. Combine chatgpt fake entities with yago_qec #
##################################################
set_seed(0)
YAGO_GPT_FAKE_ENTITIES_PATH = os.path.join(DATA_ROOT, "chatgpt_fake_entities_all.csv")

# One column of fake names per entity type; prefix the column names so they match the
# schema.org type URIs used as keys in yago_qec[k]["entity_types"]
fake_entities_gpt_df = pd.read_csv(YAGO_GPT_FAKE_ENTITIES_PATH)
fake_entities_gpt_df = fake_entities_gpt_df.add_prefix("http://schema.org/")

# Load yago_qec
with open(YAGO_QEC_PATH, "r") as fp:
    yago_qec = json.load(fp)

# Randomly sample fake entities that are eligible according to entity type for each relation and save to yago_qec
for k, v in yago_qec.items():
    entity_types: Dict[str, float] = yago_qec[k]["entity_types"]
    # Only queries whose entity types all have a fake-entity column can be augmented
    if set(entity_types).issubset(set(fake_entities_gpt_df.columns)):
        total_num_entities = len(yago_qec[k]["entities"])
        # Split the entity budget across types in proportion to each type's fraction
        ents_per_type: Dict[str, int] = {
            et: round(total_num_entities * frac) for et, frac in entity_types.items()
        }

        # Rounding can leave the per-type counts off by a few; absorb the difference in
        # the type with the largest count so the total matches the real entity count.
        # sum(...) replaces Counter(...).total(): no extra import, no Python>=3.10 need.
        diff = total_num_entities - sum(ents_per_type.values())
        if diff:
            biggest_et = max(ents_per_type, key=ents_per_type.get)
            print(f"Increasing count of {biggest_et} for query {k} by {diff}")
            ents_per_type[biggest_et] += diff
        assert sum(ents_per_type.values()) == total_num_entities

        # Sample the desired number of entities from each type's column. Empty cells are
        # dropped first so a ragged CSV column can never contribute NaN as an entity;
        # with no NaNs present the draws (seeded above) are unchanged.
        yago_qec[k]["gpt_fake_entities"] = list(
            itertools.chain.from_iterable(
                fake_entities_gpt_df[et].dropna().sample(n=count)
                for et, count in ents_per_type.items()
            )
        )

# Keep only the queries that could be augmented with fake entities
yago_qec = {k: v for k, v in yago_qec.items() if "gpt_fake_entities" in v}

# Save yago_qec including fake entities
with open(YAGO_QEC_PATH, "w", encoding="utf-8") as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4, sort_keys=True)

Increasing count of http://schema.org/CreativeWork for query http://schema.org/about by 1
Increasing count of http://schema.org/Organization for query http://schema.org/humanDevelopmentIndex by 1
Increasing count of http://schema.org/Organization for query http://schema.org/unemploymentRate by -1
Increasing count of http://schema.org/Person for query reverse-http://schema.org/author by -1
Increasing count of http://schema.org/Organization for query reverse-http://schema.org/ownedBy by -1


In [29]:
# Peek at the first ten GPT fake entities sampled for the `affiliation` query
affiliation_fake_entities = yago_qec["http://schema.org/affiliation"]["gpt_fake_entities"]
affiliation_fake_entities[:10]

['Andrej Pešić',
 'Ngozi Abubakar',
 'Philippe Giroux',
 'Emilija Petrović',
 'Lindiwe Mkhize',
 'Hakan Çelik',
 'Kashvi Singh',
 'Orlando Ferrara',
 'Assia Mokhtar',
 'Eytan Zamir']

In [30]:
# Flatten every GPT fake entity (across all entity-type columns) into one JSON list
YAGO_GPT_FAKE_ENTITIES_PATH = os.path.join(DATA_ROOT, "chatgpt_fake_entities_all.csv")
YAGO_GPT_FAKE_ENTITIES_LIST_PATH = os.path.join(
    DATA_ROOT, "chatgpt_fake_entities_list.json"
)
fake_entities_gpt_df = pd.read_csv(YAGO_GPT_FAKE_ENTITIES_PATH)
# Row-major flatten. Empty cells (NaN, present if the columns have unequal lengths) are
# skipped: they are not entities, and json.dump would emit them as the non-standard `NaN`.
fake_entities = [
    entity
    for entity in fake_entities_gpt_df.values.ravel().tolist()
    if pd.notna(entity)
]
# Save the flat list of fake entities
with open(YAGO_GPT_FAKE_ENTITIES_LIST_PATH, "w", encoding="utf-8") as fp:
    json.dump(fake_entities, fp, ensure_ascii=False, indent=4)
# Number of distinct fake entity names
len(set(fake_entities))

6999

In [31]:
# Number of queries remaining in yago_qec after the fake-entity filtering above
num_queries = len(yago_qec)
num_queries

125

In [32]:
# Queries whose entity-type distribution includes schema.org/Intangible
set(
    query_id
    for query_id, query_info in yago_qec.items()
    if "http://schema.org/Intangible" in query_info["entity_types"]
)

set()

### 6. Filter out obviously bad entities (entities or answers that are raw URIs)

In [33]:
# Filter out entities which have URIs as the entity name
import re

# Matches "http:" or "https:" anywhere in the string (entity names or answers that are
# raw URIs rather than human-readable names)
uri_regex = r"https?:"

# Load yago_qec
with open(YAGO_QEC_PATH, "r") as fp:
    yago_qec = json.load(fp)

for query_id, v in yago_qec.items():
    og_len = len(v["entities"])
    # Positions whose entity or answer looks like a URI (a set for O(1) membership below)
    inds_to_delete = {
        i
        for i, (e, a) in enumerate(zip(v["entities"], v["answers"]))
        if re.search(uri_regex, e) or re.search(uri_regex, a)
    }
    # Per-entity fields are the lists aligned with v["entities"]. Only lists qualify:
    # a string or dict that merely has the same length must not be rewritten into a
    # list of characters/keys, and scalar fields (no len()) must not crash the check.
    keys_to_delete_from = [
        k for k, vals in v.items() if isinstance(vals, list) and len(vals) == og_len
    ]
    for k in keys_to_delete_from:
        v[k] = [val for ind, val in enumerate(v[k]) if ind not in inds_to_delete]
    print(query_id, og_len, len(inds_to_delete), len(v["entities"]))

# Save the filtered yago_qec
with open(YAGO_QEC_PATH, "w", encoding="utf-8") as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4, sort_keys=True)

http://schema.org/about 1000 46 954
http://schema.org/actor 1000 0 1000
http://schema.org/address 1000 0 1000
http://schema.org/administrates 1000 2 998
http://schema.org/affiliation 1000 6 994
http://schema.org/alumniOf 1000 0 1000
http://schema.org/author 1000 0 1000
http://schema.org/award 1000 334 666
http://schema.org/birthDate 1000 0 1000
http://schema.org/birthPlace 1000 0 1000
http://schema.org/children 1000 0 1000
http://schema.org/contentLocation 1000 0 1000
http://schema.org/dateCreated 1000 0 1000
http://schema.org/deathDate 1000 0 1000
http://schema.org/deathPlace 1000 0 1000
http://schema.org/demonym 1000 0 1000
http://schema.org/director 1000 0 1000
http://schema.org/dissolutionDate 1000 0 1000
http://schema.org/duns 119 0 119
http://schema.org/duration 1000 0 1000
http://schema.org/editor 378 0 378
http://schema.org/elevation 1000 0 1000
http://schema.org/endDate 1000 0 1000
http://schema.org/founder 1000 0 1000
http://schema.org/gtin 21 0 21
http://schema.org/highestPo

### Appendix: Load manually curated country-capital examples for comparison/debugging

In [34]:
# Debug aid: attach the hand-curated country examples from the CountryCapital CSV to the
# YAGO `capital` query so they can be compared with the YAGO-derived entities
with open(YAGO_QEC_PATH, "r") as fp:
    yago_qec = json.load(fp)

df = pd.read_csv("../../data/CountryCapital/real-fake-historical-country-capital.csv")
real_rows = df[df["type"] == "countryCapital"]
fake_rows = df[df["type"] == "fakeCountryCapital"]
countries = real_rows["country"].tolist()
capitals = real_rows["capital"].tolist()
my_fake_countries = fake_rows["country"].tolist()
my_fake_capitals = fake_rows["capital"].tolist()

capital_query = yago_qec["http://yago-knowledge.org/resource/capital"]
capital_query["my_famous_entities"] = countries
capital_query["my_fake_entities"] = my_fake_countries
# The capital lists are loaded but not currently attached; to attach them:
# capital_query["famous_answers"] = capitals
# capital_query["my_fake_answers"] = my_fake_capitals

# Persist yago_qec with the hand-curated entities
with open(YAGO_QEC_PATH, "w", encoding="utf-8") as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4, sort_keys=True)

### Final: Log yago_qec.json artifact to wandb

In [35]:
import os
import wandb

# wandb stuff
PROJECT_NAME = "context-vs-bias"
GROUP_NAME = None
TAGS = ["yago", "generate_data"]
LOG_DATASETS = True
DATA_ROOT = "../../data/YagoECQ/"
YAGO_QEC_PATH = os.path.join(DATA_ROOT, "yago_qec.json")

os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.getcwd(), "yago_generate_qec.ipynb")

# Log every UPPER_CASE name in the notebook namespace (paths, flags, ...) as run config
params_to_log = {k: v for k, v in locals().items() if k.isupper()}

run = wandb.init(
    project=PROJECT_NAME,
    group=GROUP_NAME,
    config=params_to_log,
    tags=TAGS,
    mode="online",
)
print(dict(wandb.config))
artifact = wandb.Artifact(name="YagoECQ-yago_qec", type="dataset")
artifact.add_file(local_path=YAGO_QEC_PATH)
logged_artifact = run.log_artifact(artifact)
# Finish explicitly: in a notebook the run otherwise stays open after the cell, and
# finishing makes wandb sync the logged artifact before the run is closed
run.finish()
logged_artifact

[34m[1mwandb[0m: Currently logged in as: [33mkdu[0m ([33methz-rycolab[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'DATA_ROOT': '../../data/YagoECQ/', 'PRED_URI_TO_S_OBJ_CLASSES_PATH': '../../data/YagoECQ/yago_pred_uri_to_s_obj_classes.json', 'PRED_URI_TO_SO_PAIRS_PATH': '../../data/YagoECQ/yago_pred_uri_to_so_pairs_randomized_1k.json', 'YAGO_QEC_PATH': '../../data/YagoECQ/yago_qec.json', 'TRY_QUERYING_MISSING_PREDS': False, 'ENTITY_NAME_TO_POSSIBLE_ENTITY_URIS_PATH': '../../data/YagoECQ/entity_name_to_possible_entity_uris.json', 'ENTITY_URI_TO_DEGREE_PATH': '../../data/YagoECQ/entity_uri_to_degree.json', 'ENTITY_URI_TO_DEGREE_INCLUDING_AMBIGUOUS_ENTITIES_PATH': '../../data/YagoECQ/entity_uri_to_degree_including_ambiguous_entities.json', 'ENTITY_URI_TO_PREDICATE_DEGREE_PATH': '../../data/YagoECQ/entity_uri_to_predicate_degree_path.json', 'ENTITY_NAMESAKE_TO_DEGREE_PATH': '../../data/YagoECQ/entity_namesake_to_degree.json', 'ENTITY_NAMESAKE_TO_NUM_URIS_PATH': '../../data/YagoECQ/entity_namesake_to_num_uris.json', 'ENTITIES_PATH': '../../data/YagoECQ/entities.json', 'ANSWERS_PATH': '../../data/YagoE

<Artifact YagoECQ-yago_qec>