In [1]:
!pwd

/home/kevin/code/rycolab/measureLM/data/YagoECQ


In [2]:
from typing import List, Dict, Tuple
from SPARQLWrapper import SPARQLWrapper2
from SPARQLWrapper.SPARQLExceptions import EndPointInternalError
from collections import namedtuple
import re
from time import sleep
from urllib.error import HTTPError
import os
import json
from tqdm import tqdm
import time
# from preprocessing.utils import extract_name_from_yago_uri

In [3]:
# Lightweight record for one knowledge-base predicate: its full URI, the short
# knowledge-base name, and the relation part of the URI.
Predicate = namedtuple("Predicate", "uri kb_name relation")

# Every cached artifact is read from / written to a file directly under DATA_ROOT.
DATA_ROOT = "."
PRED_URI_TO_S_OBJ_CLASSES_PATH, PRED_URI_TO_SO_PAIRS_PATH, YAGO_QEC_PATH = (
    os.path.join(DATA_ROOT, filename)
    for filename in (
        "yago_pred_uri_to_s_obj_classes.json",
        "yago_pred_uri_to_so_pairs_randomized_1k.json",
        "yago_qec.json",
    )
)

# When True, predicates absent from the cached subject/object pairs are queried again.
TRY_QUERYING_MISSING_PREDS = True

In [4]:
def extract_name_from_yago_uri(uri: str):
    """Split a (possibly ``reverse-``-prefixed) predicate URI into KB name and relation.

    Args:
        uri: A URI of the form ``http://[www.]<domain>.org/<relation>``, optionally
            prefixed with ``reverse-`` to mark the inverted predicate.

    Returns:
        A ``(kb_name, relation)`` tuple. ``kb_name`` is the short name mapped from the
        domain (``schema``, ``yago`` or ``w3``), itself prefixed with ``reverse-`` when
        the input was; ``relation`` is everything after ``.org/``.

    Raises:
        ValueError: If the URI does not match the expected pattern.
        KeyError: If the domain is not one of the known knowledge bases.
    """
    reverse_prefix = "reverse-"
    is_reversed = uri.startswith(reverse_prefix)
    if is_reversed:
        # Strip only the leading marker. Splitting on the marker would also cut the
        # URI at any later occurrence of "reverse-" inside the relation itself.
        uri = uri[len(reverse_prefix):]

    pattern = r"http://(?:www\.)?([^\/]+)\.org\/(.+)$"
    matches = re.match(pattern, uri)
    if not matches:
        raise ValueError(f"Could not find match containing kb_domain and relation for uri {uri}.")
    kb_domain, relation = matches.group(1), matches.group(2)

    domain_to_name = {"schema": "schema", "yago-knowledge": "yago", "w3": "w3"}
    kb_name = domain_to_name[kb_domain]
    kb_name = (reverse_prefix if is_reversed else "") + kb_name

    return kb_name, relation


In [5]:
##################################
# 1. Get all relevant predicates #
##################################
# Single SPARQL client, reused by every query in this notebook.
sparql = SPARQLWrapper2("https://yago-knowledge.org/sparql/query")

# Lists every distinct predicate that occurs in at least one triple.
query_p = """
    SELECT DISTINCT ?p WHERE {
     ?s ?p ?obj . 
    }  ORDER BY ?p
"""
sparql.setQuery(query_p)

# Relations that are dropped even when their knowledge base is eligible.
ineligible_relations = ["schema#fromClass", "schema#fromProperty", "logo", "image"]
relevant_preds: List[Predicate] = []

for result in sparql.query().bindings:
    uri = result["p"].value
    kb_name, relation = extract_name_from_yago_uri(uri)
    # Skip every w3 predicate and every explicitly ineligible relation.
    if kb_name == "w3" or relation in ineligible_relations:
        continue
    relevant_preds.append(Predicate(uri=uri, kb_name=kb_name, relation=relation))

In [6]:
###########################################################
# 1b. Get all subject and object types for each predicate #
###########################################################
from SPARQLWrapper.SPARQLExceptions import QueryBadFormed
# Up to 10 example (subject, object) pairs for one predicate.
query_single = """
    SELECT DISTINCT ?s ?obj WHERE {{
     ?s <{}> ?obj . 
    }}  LIMIT 10
"""

# Every class reachable from an entity through rdf:type followed by any number of
# rdfs:subClassOf steps.
query_superclasses = """
    SELECT DISTINCT ?superclasses WHERE {{
     <{}> rdf:type/rdfs:subClassOf* ?superclasses .
    }}  
"""

# The coarse classes that subject/object types are restricted to.
types = [
    "http://schema.org/CreativeWork",
    "http://schema.org/Event",
    "http://schema.org/Intangible",
    "http://schema.org/Organization",
    "http://schema.org/Person",
    "http://schema.org/Place",
    "http://schema.org/Product",
    "http://schema.org/Taxon",
    "http://schema.org/FictionalEntity",
]


def _query_superclasses(entity: str, role: str) -> List[str]:
    """Return the superclass URIs of ``entity``, or ``[]`` if the query is malformed.

    ``role`` ("s" or "obj") is only used in the log line. A malformed query is
    expected whenever ``entity`` is a literal value rather than a URI, because the
    value is substituted between angle brackets in ``query_superclasses``.
    """
    try:
        sparql.setQuery(query_superclasses.format(entity))
        return [x["superclasses"].value for x in sparql.query().bindings]
    except QueryBadFormed:
        print(f"Raising error for {role}:", entity)
        return []


if os.path.exists(PRED_URI_TO_S_OBJ_CLASSES_PATH):
    print(f"Loading pred_to_s_and_obj_types from file {PRED_URI_TO_S_OBJ_CLASSES_PATH}.")
    # The cache is written as UTF-8 below, so read it back with the same encoding.
    with open(PRED_URI_TO_S_OBJ_CLASSES_PATH, encoding="utf-8") as f:
        pred_to_s_and_obj_types_with_reverse = json.load(f)
else:
    pred_uri_to_superclass_sets = dict()
    for pred in relevant_preds:
        sparql.setQuery(query_single.format(pred.uri))
        superclasses_s = set()
        superclasses_obj = set()
        # `.bindings` is materialised before the loop body replaces the query.
        for result in sparql.query().bindings:
            superclasses_s.update(_query_superclasses(result["s"].value, "s"))
            superclasses_obj.update(_query_superclasses(result["obj"].value, "obj"))
        pred_uri_to_superclass_sets[pred.uri] = (superclasses_s, superclasses_obj)

    # Keep only the coarse classes listed in `types`; sort so the result does not
    # depend on set iteration order.
    pred_to_s_and_obj_types = {
        pred_uri: (sorted(s_set.intersection(types)), sorted(obj_set.intersection(types)))
        for pred_uri, (s_set, obj_set) in pred_uri_to_superclass_sets.items()
    }
    # A reversed predicate swaps the subject and object types.
    pred_to_s_and_obj_types_with_reverse = {
        **pred_to_s_and_obj_types,
        **{f"reverse-{k}": (obj, s) for k, (s, obj) in pred_to_s_and_obj_types.items()},
    }

    # Persist the result so the cache branch above is actually reachable next run.
    with open(PRED_URI_TO_S_OBJ_CLASSES_PATH, "w", encoding="utf-8") as fp:
        json.dump(pred_to_s_and_obj_types_with_reverse, fp, ensure_ascii=False, indent=4)

Raising error for obj: 20, place des Terreaux
Raising error for obj: Neuschwansteinstraße 20
Raising error for obj: Viale del Fante 11, I-90146 Palermo
Raising error for obj: The University of Hong Kong, Pokfulam, Hong Kong
Raising error for obj: Piazza Giuseppe Mazzini, 21
Raising error for obj: Statene 2, 5970 Ærøskøbing
Raising error for obj: Strada Comunale Santa Margherita, 79
Raising error for obj: Strada Santa Margherita 79, 10131 Torino
Raising error for obj: Ареспублика Мали
Raising error for obj: Ареспублика Сиерра-Леоне
Raising error for obj: Ареспублика Судан
Raising error for obj: Китаитәи Жәлар рреспублика
Raising error for obj: اللوغة الإنڨليزية
Raising error for obj: Andrei Dmitrijewitsj Sacharof
Raising error for obj: Point(-1.0833333333333 53.958333333333)
Raising error for obj: Point(-1.6808333333333 48.114166666667)
Raising error for obj: Point(-106 34)
Raising error for obj: Point(-106 34)
Raising error for obj: Point(-113.5 53.533333333333)
Raising error for obj: 

In [6]:
#######################################################################
# 2. Extract all subject-object pairs for each pred in relevant preds #
#######################################################################
def get_so_pairs_for_pred(pred: Predicate) -> List[Tuple[str, str, str, str]]:
    """Fetch up to 1000 subject/object pairs for ``pred`` in a single SPARQL query.

    Rows are ordered by the MD5 of the concatenated subject and object, which gives
    a fixed pseudo-random selection. If the endpoint fails with ``HTTPError`` or
    ``EndPointInternalError``, the work is delegated to
    ``get_so_pairs_for_pred_chained``.

    Args:
        pred: The predicate whose triples are queried (only ``pred.uri`` is used).

    Returns:
        A list of ``(subject_name, object_name, subject_uri, object_uri)`` tuples,
        or whatever the chained fallback returns (which may be ``None``). On any
        other ``Exception`` the pairs collected so far are returned after the error
        is printed.
    """
    start_time = time.time()
    print(f"Querying {pred.uri}.")
    so_pairs = []

    query = """
SELECT DISTINCT ?output_s ?output_obj ?s ?obj WHERE {{
  ?s <{}> ?obj . # which predicate to use
  OPTIONAL {{ 
    ?s rdfs:label ?s_label .
    FILTER (LANG(?s_label) = 'en')
  }} # Get the label (name) for the subject if it exists
  BIND(COALESCE(?s_label, ?s) AS ?output_s) # if the label does not exist, stick with the URI
  
  OPTIONAL {{
    ?obj rdfs:label ?obj_label . 
    FILTER (LANG(?obj_label) = 'en')
  }} # Get the label (name) for the object if it exists
  OPTIONAL {{
    ?obj rdf:type ?obj_type . 
    ?obj_type rdfs:label ?obj_type_name .
    FILTER (LANG(?obj_type_name) = 'en')
  }} # get the name of the type for the object if the type exists
  BIND(COALESCE(IF(STR(?obj_label) != "Generic instance", ?obj_label, ?obj_type_name), ?obj) AS ?output_obj) 
  # if the label exists, go with the label, but if it's "Generic instance", then go with the type; if that does not exist, then stick with the OG object.
  BIND(MD5(CONCAT(STR(?s), STR(?obj))) AS ?sortkey) .
}}
ORDER BY ?sortkey
LIMIT 1000
""".format(
        pred.uri
    )

    try:
        sparql.setQuery(query)
        for result in sparql.query().bindings:
            so_pairs.append((result["output_s"].value, result["output_obj"].value, result["s"].value, result["obj"].value))
    except (HTTPError, EndPointInternalError):
        print(f"HTTPerror for uri {pred.uri}. Trying chained query.")
        so_pairs = get_so_pairs_for_pred_chained(pred)
    except Exception as e:
        # Stay best-effort as before (return what was collected), but report the
        # error instead of discarding it silently.
        print(f"Unexpected error for uri {pred.uri}: {e!r}")
    finally:
        # Only log here: a `return` inside `finally` would also swallow
        # KeyboardInterrupt and other BaseExceptions.
        time_elapsed = time.time() - start_time
        print(f"Time elapsed: {time_elapsed}")
    return so_pairs
    
def get_so_pairs_for_pred_chained(pred: Predicate) -> List[Tuple[str, str, str, str]]:
    """Fetch up to 1000 subject/object pairs for ``pred`` using one query per entity.

    A first query retrieves only the subject/object URIs, ordered by the MD5 of
    their concatenation. The display name of each subject and each object is then
    resolved with a separate small query.

    Args:
        pred: The predicate whose triples are queried (only ``pred.uri`` is used).

    Returns:
        A list of ``(subject_name, object_name, subject_uri, object_uri)`` tuples.
        ``None`` if the endpoint fails with ``HTTPError`` or
        ``EndPointInternalError``. On any other ``Exception`` the error is printed
        and the pairs collected so far are returned.
    """
    start_time = time.time()
    print(f"Querying {pred.uri} (chained).")
    so_pairs = []

    query = """
SELECT DISTINCT ?s ?o WHERE {{
    ?s <{}> ?o
    BIND(MD5(CONCAT(STR(?s), STR(?o))) AS ?sortkey) .
}}
ORDER BY ?sortkey
LIMIT 1000
""".format(
        pred.uri
    )
    
    query_get_name_of_obj = """
SELECT DISTINCT ?output_obj WHERE {{
    OPTIONAL {{
      <{0}> rdfs:label ?obj_label . 
      FILTER (LANG(?obj_label) = 'en')
    }} # Get the label (name) for the object if it exists
    OPTIONAL {{
      <{0}> rdf:type ?obj_type . 
      ?obj_type rdfs:label ?obj_type_name .
      FILTER (LANG(?obj_type_name) = 'en')
    }} # get the name of the type for the object if the type exists
    BIND(COALESCE(IF(STR(?obj_label) != "Generic instance", ?obj_label, ?obj_type_name), <{0}>) AS ?output_obj) 
}}  
"""

    try:
        sparql.setQuery(query)
        # `.bindings` is materialised before the loop body replaces the query.
        for result in sparql.query().bindings:
            s_uri = result["s"].value
            o_uri = result["o"].value

            sparql.setQuery(query_get_name_of_obj.format(s_uri))
            res_s = sparql.query().bindings
            if len(res_s) != 1:
                raise ValueError(f"Expected exactly 1 label row for subject {s_uri}, got {len(res_s)}: {res_s}")
            s_name = res_s[0]["output_obj"].value

            sparql.setQuery(query_get_name_of_obj.format(o_uri))
            res_obj = sparql.query().bindings
            if len(res_obj) != 1:
                raise ValueError(f"Expected exactly 1 label row for object {o_uri}, got {len(res_obj)}: {res_obj}")
            obj_name = res_obj[0]["output_obj"].value

            so_pairs.append((s_name, obj_name, s_uri, o_uri))
    except (HTTPError, EndPointInternalError):
        print(f"HTTPerror for uri {pred.uri} when chaining. Skipping.")
        so_pairs = None
    except Exception as e:
        # The original clause named an undefined `Error` class, which raised a
        # NameError that the `return` in `finally` then hid. Catch Exception so the
        # message is printed and the partial result is still returned.
        print(e)
    finally:
        # Only log here: a `return` inside `finally` would also swallow
        # KeyboardInterrupt and other BaseExceptions.
        time_elapsed = time.time() - start_time
        print(f"Time elapsed: {time_elapsed}")
    return so_pairs

In [8]:
if os.path.exists(PRED_URI_TO_SO_PAIRS_PATH):
    print(f"Loading pred_uri_to_so_pairs from file {PRED_URI_TO_SO_PAIRS_PATH}.")
    # The cache is written as UTF-8 with ensure_ascii=False below, so it must be
    # read back as UTF-8 rather than with the platform's default encoding.
    with open(PRED_URI_TO_SO_PAIRS_PATH, encoding="utf-8") as f:
        pred_uri_to_so_pairs = json.load(f)
else:
    pred_to_so_pairs: Dict[Predicate, List[Tuple[str, str, str, str]]] = {
        pred: get_so_pairs_for_pred(pred) for pred in tqdm(relevant_preds)
    }
    # A None value means the query failed for that predicate; drop those once.
    pred_to_so_pairs = {k: v for k, v in pred_to_so_pairs.items() if v is not None}
    pred_uri_to_so_pairs = {k.uri: v for k, v in pred_to_so_pairs.items()}

if TRY_QUERYING_MISSING_PREDS:
    # Retry every relevant predicate that has no entry yet.
    missing_preds = [p for p in relevant_preds if p.uri not in pred_uri_to_so_pairs]
    print("Missing preds:", missing_preds)
    missing_pred_to_so_pairs: Dict[Predicate, List[Tuple[str, str, str, str]]] = {
        pred: get_so_pairs_for_pred(pred) for pred in tqdm(missing_preds)
    }
    missing_pred_to_so_pairs = {k: v for k, v in missing_pred_to_so_pairs.items() if v is not None}
    missing_pred_uri_to_so_pairs = {k.uri: v for k, v in missing_pred_to_so_pairs.items()}

    pred_uri_to_so_pairs = {**missing_pred_uri_to_so_pairs, **pred_uri_to_so_pairs}

with open(PRED_URI_TO_SO_PAIRS_PATH, "w", encoding='utf-8') as fp:
    json.dump(pred_uri_to_so_pairs, fp, ensure_ascii=False, indent=4)

def augment_pred_uri_to_so_pairs_with_reverse(pred_uri_to_so_pairs):
    """Return a copy of the mapping extended with a ``reverse-<uri>`` entry per key.

    Each value is a sequence of 4-item rows ``(label_1, label_2, uri_1, uri_2)``.
    The reversed entry holds the same rows as tuples with the two labels swapped
    and the two URIs swapped. The input mapping is not modified.
    """
    augmented = dict(pred_uri_to_so_pairs)
    for pred_uri, rows in pred_uri_to_so_pairs.items():
        swapped_rows = []
        for first_label, second_label, first_uri, second_uri in rows:
            swapped_rows.append((second_label, first_label, second_uri, first_uri))
        augmented[f"reverse-{pred_uri}"] = swapped_rows
    return augmented

# Extend the forward mapping with a "reverse-<uri>" entry for every predicate.
pred_uri_to_so_pairs_with_reverse = augment_pred_uri_to_so_pairs_with_reverse(pred_uri_to_so_pairs)

  0%|          | 0/1 [00:00<?, ?it/s]

Querying http://schema.org/about.


100%|██████████| 1/1 [00:13<00:00, 13.02s/it]


Time elapsed: 13.01859974861145
Missing preds: []


0it [00:00, ?it/s]


In [10]:
########################################################################################################
# 4. Construct queries containing entities, corresponding answers, query forms, and context templates. #
########################################################################################################
from yago_questions import yago_topic_to_qfs

# Only predicates that have both query forms and subject/object pairs are kept.
keys = set(yago_topic_to_qfs).intersection(set(pred_uri_to_so_pairs_with_reverse))
print(keys)

yago_qec = {}
# Iterate in sorted order so the saved JSON has the same key order on every run.
for k in sorted(keys):
    so_pairs = pred_uri_to_so_pairs_with_reverse[k]
    # Transpose the rows into four parallel columns once. zip(*[]) yields nothing,
    # so a predicate without pairs needs an explicit set of empty columns.
    if so_pairs:
        entities, answers, entity_uris, answer_uris = zip(*so_pairs)
    else:
        entities, answers, entity_uris, answer_uris = (), (), (), ()
    yago_qec[k] = {
        "query_forms": yago_topic_to_qfs[k],
        "entities": entities,
        "answers": answers,
        "entity_uris": entity_uris,
        "answer_uris": answer_uris,
        # The context template is the last "open" query form followed by the answer slot.
        "context_templates": [yago_topic_to_qfs[k]["open"][-1] + " {answer}.\n"],
        "entity_types": pred_to_s_and_obj_types_with_reverse[k][0],
        "answer_types": pred_to_s_and_obj_types_with_reverse[k][1],
    }

with open(YAGO_QEC_PATH, "w", encoding='utf-8') as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4)

{'http://schema.org/about'}


In [None]:
# The file is written as UTF-8 with ensure_ascii=False, so decode it as UTF-8
# instead of relying on the platform's default text encoding.
with open(YAGO_QEC_PATH, "r", encoding="utf-8") as fp:
    yago_qec_reloaded = json.load(fp)

In [None]:
# For each query id, print the number of entities next to the number of distinct ones.
for qid, v in yago_qec.items():
    total_count = len(v["entities"])
    distinct_count = len(set(v["entities"]))
    print(qid, total_count, distinct_count)

In [None]:
# pandas is first imported much further down the notebook, so import it here;
# otherwise this cell raises NameError on a fresh top-to-bottom run.
import pandas as pd

# How often each entity name occurs for the reversed homeLocation predicate.
pd.DataFrame(
    yago_qec["reverse-http://schema.org/homeLocation"]["entities"]
).value_counts()

In [None]:
# yago_qec_reloaded

In [None]:
# Show the first reloaded icaoCode entity whose name contains "Egilssta", with its type.
icao_entities = yago_qec_reloaded['http://schema.org/icaoCode']["entities"]
s = [x for x in icao_entities if "Egilssta" in x][0]
print(s, type(s))

In [None]:
#################################
# Get the degree of each entity #
#################################
# Placeholder mapping; nothing in this notebook fills it in yet.
entity_name_to_degree = {}


In [5]:
##############################
# 5. Construct fake entities #
##############################
# This section requires GPUs #
##############################

# Can be run separate from previous section #
from transformers import ReformerModelWithLMHead
import json
import re
import os
import itertools
from transformers import ReformerModelWithLMHead

import torch
import random
import numpy as np

# Use the GPU when torch reports one is available; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Encoding
def encode(list_of_strings, pad_token_id=0, device="cpu"):
    """Turn strings into a padded batch of byte-level token ids.

    Each string is UTF-8 encoded (``bytes`` items are used as-is) and every byte
    value ``b`` becomes the token id ``b + 2``. Rows shorter than the longest item
    are right-padded with ``pad_token_id``.

    Args:
        list_of_strings: Sequence of ``str`` or ``bytes`` items; may be empty.
        pad_token_id: Id written into padding positions.
        device: Device the returned tensors are moved to.

    Returns:
        ``(input_ids, attention_masks)``: two ``torch.long`` tensors of shape
        ``(len(list_of_strings), longest_item_in_bytes)``. The mask is 1 on real
        tokens and 0 on padding.
    """
    # Convert to bytes first: the row width must be measured in bytes, not in
    # characters, otherwise any non-ASCII string overflows its row.
    byte_strings = [
        string if isinstance(string, bytes) else str.encode(string)
        for string in list_of_strings
    ]
    max_length = max((len(byte_string) for byte_string in byte_strings), default=0)

    # create empty tensors
    attention_masks = torch.zeros((len(byte_strings), max_length), dtype=torch.long)
    input_ids = torch.full((len(byte_strings), max_length), pad_token_id, dtype=torch.long)

    for idx, byte_string in enumerate(byte_strings):
        input_ids[idx, :len(byte_string)] = torch.tensor([x + 2 for x in byte_string], dtype=torch.long)
        attention_masks[idx, :len(byte_string)] = 1

    return input_ids.to(device), attention_masks.to(device)
    
# Decoding
def decode(outputs_ids):
    """Turn a batch of byte-level token ids back into strings.

    Ids below 2 are dropped. Every other id ``x`` stands for the byte ``x - 2``,
    the inverse of the ``+ 2`` shift applied by ``encode``. The resulting byte
    sequence is decoded as UTF-8 as a whole, so multi-byte characters survive;
    mapping each byte through ``chr`` on its own would yield mojibake.

    Args:
        outputs_ids: 2-D tensor of token ids, one row per sequence.

    Returns:
        A list with one decoded string per row. Bytes that do not form valid
        UTF-8 (for example a sequence cut off mid-character) become U+FFFD.

    Raises:
        ValueError: If a row contains an id above 257, which has no byte value.
    """
    decoded_outputs = []
    for output_ids in outputs_ids.tolist():
        raw_bytes = bytes(x - 2 for x in output_ids if x > 1)
        decoded_outputs.append(raw_bytes.decode("utf-8", errors="replace"))
    return decoded_outputs

def set_seed(seed=0):
    """Seed torch, Python's ``random`` module and NumPy's global RNG with ``seed``."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)
    
def extract_entity(text):
    """Return the target of the first ``[[...]]`` link in ``text``, or ``None``.

    The captured name stops before an optional parenthesised suffix and before any
    ``|alias`` parts; surrounding whitespace is stripped.
    """
    found = re.search(r'\[\[([^\[|\]]+?)(?:\([^)]*\))?(?:\|[^|\]]+)*\]\]', text)
    return found.group(1).strip() if found else None

In [6]:
from collections import defaultdict
from typing import Dict, Set

# The file is written as UTF-8 with ensure_ascii=False, so decode it as UTF-8
# instead of relying on the platform's default text encoding.
with open(YAGO_QEC_PATH, "r", encoding="utf-8") as fp:
    yago_qec = json.load(fp)

# For every entity type, collect the real entity names of all queries with that type.
entity_types_to_real_entities: Dict[str, Set[str]] = defaultdict(set)
for query_id, qec in yago_qec.items():
    for et in qec["entity_types"]:
        # Update in place rather than building a new set with union() each time.
        entity_types_to_real_entities[et].update(qec["entities"])
print(json.dumps({k: len(v) for k, v in entity_types_to_real_entities.items()}, indent=4))

# All real entity names across every query, regardless of type.
real_entities = set(itertools.chain.from_iterable(v["entities"] for v in yago_qec.values()))
print(len(real_entities))

{
    "http://schema.org/Person": 17352,
    "http://schema.org/Product": 10715,
    "http://schema.org/CreativeWork": 13410,
    "http://schema.org/Place": 12723,
    "http://schema.org/Organization": 14291,
    "http://schema.org/Event": 1990
}
45054


In [7]:
# Generation prompt for each entity type. Every prompt ends with an opening "[["
# so that the model's continuation supplies the entity name.
entity_types_to_prompts = dict(
    [
        ("http://schema.org/CreativeWork", "The creative work called [["),
        ("http://schema.org/Event", "The event called [["),
        ("http://schema.org/Intangible", "The concept called [["),
        ("http://schema.org/Organization", "The organization is called [["),
        ("http://schema.org/Person", "The person named [["),
        ("http://schema.org/Place", "The place is named [["),
        ("http://schema.org/Product", "The product is called [["),
        ("http://schema.org/Taxon", "The taxon named [["),
        ("http://schema.org/FictionalEntity", "The fictional entity named [["),
    ]
)
# Filled by the generation loop: entity type -> list of generated names.
entity_types_to_fake_entities = dict()
# Load the pretrained "google/reformer-enwik8" checkpoint, then move it onto `device`.
model = ReformerModelWithLMHead.from_pretrained("google/reformer-enwik8")
model = model.to(device)

In [8]:
# Generate candidate names per entity type and keep up to 1000 that are not real entities.
for et, p in entity_types_to_prompts.items():
    set_seed(0)
    print(f"***{et}***")
    encoded, attention_masks = encode([p], device=device)
    res = decode(model.generate(encoded, do_sample=True, num_return_sequences=1200, max_length=100))

    # Run the extraction once per sequence and drop sequences without a link.
    extracted_res = {entity for entity in map(extract_entity, res) if entity is not None}
    print(f"# unique fake ents: {len(extracted_res)}")
    extracted_res_without_reals = extracted_res.difference(real_entities)
    print(f"# unique fake ents removing reals: {len(extracted_res_without_reals)}")

    # Sample from the set with real entities removed (the unfiltered set was
    # sampled before, letting real names through). Sort first: set iteration order
    # changes between interpreter sessions, which would make the seeded sample
    # irreproducible. Cap the sample size so a small pool does not raise.
    candidates = sorted(extracted_res_without_reals)
    extracted_res = random.sample(candidates, min(1000, len(candidates)))
    entity_types_to_fake_entities[et] = extracted_res

***http://schema.org/CreativeWork***
# unique fake ents: 1148
# unique fake ents removing reals: 1126
***http://schema.org/Event***
# unique fake ents: 1159
# unique fake ents removing reals: 1141
***http://schema.org/Intangible***
# unique fake ents: 1071
# unique fake ents removing reals: 1068
***http://schema.org/Organization***
# unique fake ents: 1039
# unique fake ents removing reals: 1017
***http://schema.org/Person***
# unique fake ents: 1177
# unique fake ents removing reals: 1152
***http://schema.org/Place***
# unique fake ents: 1124
# unique fake ents removing reals: 1082
***http://schema.org/Product***
# unique fake ents: 1109
# unique fake ents removing reals: 1097
***http://schema.org/Taxon***
# unique fake ents: 1147
# unique fake ents removing reals: 1126
***http://schema.org/FictionalEntity***
# unique fake ents: 1149
# unique fake ents removing reals: 1118


In [9]:
YAGO_FAKE_ENTITIES_PATH = os.path.join(DATA_ROOT, "fake_entities.json") 

# Load yago_qec. The file is written as UTF-8, so read it back as UTF-8.
with open(YAGO_QEC_PATH, "r", encoding="utf-8") as fp:
    yago_qec = json.load(fp)

# Save fake entities per each entity type.
with open(YAGO_FAKE_ENTITIES_PATH, "w", encoding='utf-8') as fp:
    json.dump(entity_types_to_fake_entities, fp, ensure_ascii=False, indent=4)

# Randomly sample fake entities that are eligible according to entity type for each relation and save to yago_qec
for k, qec in yago_qec.items():
    pooled_fake_entities = itertools.chain.from_iterable(
        entity_types_to_fake_entities[et] for et in qec["entity_types"]
    )
    # Drop names generated for more than one type, keeping first-seen order, so
    # the same fake entity cannot be sampled twice.
    eligible_fake_entities = list(dict.fromkeys(pooled_fake_entities))
    # Cap the sample at the pool size, as the ChatGPT-based cell does, so a
    # relation with few or no eligible fake entities does not raise ValueError.
    qec["fake_entities"] = random.sample(
        eligible_fake_entities, min(len(qec["entities"]), len(eligible_fake_entities))
    )

# Save yago_qec including fake entities
with open(YAGO_QEC_PATH, "w", encoding='utf-8') as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4)

In [7]:
###########################################
# 6. Construct fake entities with chatgpt #
###########################################
import pandas as pd
import random
YAGO_GPT_FAKE_ENTITIES_PATH = os.path.join(DATA_ROOT, "chatgpt_fake_entities.csv") 

# Full names are built as "<FirstName> <LastName>". Rows missing either part give
# NaN and are dropped.
fake_entities_gpt = pd.read_csv(YAGO_GPT_FAKE_ENTITIES_PATH)
fake_entities_gpt = set((fake_entities_gpt["FirstName"] + " " + fake_entities_gpt["LastName"]).dropna())

# Load yago_qec. The file is written as UTF-8, so read it back as UTF-8.
with open(YAGO_QEC_PATH, "r", encoding="utf-8") as fp:
    yago_qec = json.load(fp)

# The candidate pool is the same for every relation, so build it once. Sort it so
# sampling does not depend on set iteration order, which varies between sessions.
eligible_fake_entities = sorted(fake_entities_gpt)

# Randomly sample fake entities for each relation and save to yago_qec
for k, qec in yago_qec.items():
    qec["gpt_fake_entities"] = random.sample(
        eligible_fake_entities, min(len(qec["entities"]), len(eligible_fake_entities))
    )

# Save yago_qec including fake entities
with open(YAGO_QEC_PATH, "w", encoding='utf-8') as fp:
    json.dump(yago_qec, fp, ensure_ascii=False, indent=4)