In [25]:
import csv
import json
import re
import pandas as pd
import numpy as np
import time
import os
import random
import http.client
from datetime import datetime
from collections import Counter
from SPARQLWrapper import SPARQLWrapper, JSON
from SPARQLWrapper.SPARQLExceptions import QueryBadFormed, EndPointInternalError, EndPointNotFound, Unauthorized, URITooLong, SPARQLWrapperException
from urllib.error import HTTPError

In [10]:
# Memoises SPARQL results: full query text -> list of triple records it produced.
sparql_query_cache = {}

In [26]:
def convert_date_string(date_string):
    """Convert an ISO-8601 UTC timestamp into a "DD Month YYYY" label.

    Example: ``"1990-05-07T00:00:00Z"`` -> ``"07 May 1990"`` (the day keeps
    its leading zero, as before). Presumably used so date-valued objects can
    match how dates are written in the TekGen triples -- TODO confirm.

    Args:
        date_string: Value to convert; non-strings are rejected.

    Returns:
        The reformatted string, or ``None`` when the input is not a
        ``YYYY-MM-DDTHH:MM:SSZ`` timestamp or does not name a real calendar
        date (e.g. year ``0000`` or ``2021-02-30``).
    """
    if not isinstance(date_string, str):
        return None
    match = re.match(r"^(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})Z$", date_string)
    if not match:
        return None
    year, month, day = match.group(1, 2, 3)
    try:
        # datetime() validates the date; an impossible one raises ValueError, which
        # previously escaped to the caller and aborted a whole SPARQL result batch.
        date = datetime(int(year), int(month), int(day))
    except ValueError:
        return None
    # %B is locale-dependent; it yields English month names in the default C locale.
    month_name = date.strftime("%B")
    return f"{day} {month_name} {year}"

def get_splits(triples, splits=(0.4, 0.3, 0.3), seed=None):
    """Randomly partition ``triples`` into train / validation / test lists.

    Records are selected by index rather than round-tripped through a NumPy
    array, so field types are preserved (no silent str coercion) and rows of
    differing length are accepted.

    Args:
        triples: Sequence of records (lists or tuples).
        splits: Train and validation fractions; test receives the remainder,
            so only ``splits[0]`` and ``splits[1]`` are read.
        seed: Optional seed for a private ``numpy.random.default_rng``. When
            ``None`` (default) NumPy's global RNG is used, as before.

    Returns:
        ``(train, val, test)``: three lists of fresh list copies of the records.
    """
    n = len(triples)
    if seed is None:
        indices = np.random.permutation(n)
    else:
        indices = np.random.default_rng(seed).permutation(n)
    train_count = int(n * splits[0])
    val_count = int(n * splits[1])

    def pick(idx):
        # Copy each record so callers can mutate the results without touching the input.
        return [list(triples[i]) for i in idx]

    return (pick(indices[:train_count]),
            pick(indices[train_count:train_count + val_count]),
            pick(indices[train_count + val_count:]))

def save_triples(onto_id, train_all, val_all, test_all, base_dir="../../data/wikidata_tekgen"):
    """Write one ontology's splits as JSON Lines files.

    Files written (directories are created when missing):
      * ``{base_dir}/train/{onto_id}_train.jsonl``: full records
      * ``{base_dir}/validation/{onto_id}_validation.jsonl``: full records
      * ``{base_dir}/ground_truth/{onto_id}_ground_truth.jsonl``: sentence plus
        a one-element ``triples`` list of labels
      * ``{base_dir}/test/{onto_id}_test.jsonl``: sentence only

    Each record ``tr`` is indexed as
    ``[sub_label, rel_label, obj_label, sub_id, rel_id, obj_id, sentence]``.

    Args:
        onto_id: Ontology identifier used in file names and record ids.
        train_all: Records for the train split.
        val_all: Records for the validation split.
        test_all: Records for the test split (feeds ground_truth and test).
        base_dir: Output root; the default keeps the original layout.
    """
    paths = {name: os.path.join(base_dir, name)
             for name in ("train", "validation", "ground_truth", "test")}
    for path in paths.values():
        os.makedirs(path, exist_ok=True)

    def full_record(record_id, tr):
        # Key order is kept stable so output files diff cleanly across runs.
        return {"id": record_id, "sub_label": tr[0], "rel_label": tr[1], "obj_label": tr[2],
                "sent": tr[6], "sub": tr[3], "rel": tr[4], "obj": tr[5]}

    def write_jsonl(path, records):
        with open(path, "w", encoding="utf-8") as out_file:
            for record in records:
                out_file.write(f"{json.dumps(record)}\n")

    write_jsonl(os.path.join(paths["train"], f"{onto_id}_train.jsonl"),
                (full_record(f"{onto_id}_train_{i}", tr) for i, tr in enumerate(train_all, start=1)))
    write_jsonl(os.path.join(paths["validation"], f"{onto_id}_validation.jsonl"),
                (full_record(f"{onto_id}_val_{i}", tr) for i, tr in enumerate(val_all, start=1)))

    # Ground truth uses the nested {"sent", "triples": [...]} layout with labels
    # (not ids); the earlier flat layout did not match the evaluation code.
    write_jsonl(os.path.join(paths["ground_truth"], f"{onto_id}_ground_truth.jsonl"),
                ({"id": f"{onto_id}_test_{i}", "sent": tr[6],
                  "triples": [{"sub": tr[0], "rel": tr[1], "obj": tr[2]}]}
                 for i, tr in enumerate(test_all, start=1)))
    write_jsonl(os.path.join(paths["test"], f"{onto_id}_test.jsonl"),
                ({"id": f"{onto_id}_test_{i}", "sent": tr[6]} for i, tr in enumerate(test_all, start=1)))

def _build_triples_query(relation_pid, rel_domain, rel_range, current_limit):
    """Return the SPARQL text selecting subject/object pairs for one relation.

    Subjects are instances (wdt:P31/wdt:P279*) of ``rel_domain`` with an English
    label; objects are restricted to ``rel_range`` when it is non-empty. The text
    is built exactly as before so existing ``sparql_query_cache`` keys still match.
    """
    query = "PREFIX wdt: <http://www.wikidata.org/prop/direct/> \n PREFIX wd: <http://www.wikidata.org/entity/> \n"
    query += "SELECT DISTINCT ?sub ?subEntity ?objEntity ?objLabel { \n ?subEntity wdt:P31/wdt:P279* wd:" + rel_domain + " . \n"
    query += '?subEntity rdfs:label ?sub . FILTER (lang(?sub) = "en") \n '
    query += '?subEntity wdt:' + relation_pid + ' ?objEntity . \n'
    if rel_range:
        query += '?objEntity wdt:P31*/wdt:P279* wd:' + rel_range + ' . \n '
    query += 'OPTIONAL { ?objEntity rdfs:label ?objLabel . FILTER (lang(?objLabel) = "en") } \n } '
    query += f"LIMIT {current_limit}"
    return query


def _run_triples_query(query, relation_pid, relation_label, limit, user_agent):
    """Execute ``query`` on Wikidata and convert the bindings to triple records.

    Each record is ``[sub_label, relation_label, obj_label, sub_id, relation_pid, obj_id]``.
    To favour diversity, records whose subject or object label has been seen more
    than ``limit / 10`` times are moved to the end of the returned list.
    """
    sparql = SPARQLWrapper("https://query.wikidata.org/sparql", agent=user_agent)
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    sparql.setTimeout(300)  # 5 minutes
    sparql.setMethod('POST')
    results = sparql.query().convert()

    bindings = results["results"]["bindings"]
    print(f'  {len(bindings)} SPARQL results.')
    entity_prefix = "http://www.wikidata.org/entity/"
    max_occurrences = limit / 10
    subject_counter, object_counter = Counter(), Counter()
    primary, secondary = [], []
    for result in bindings:
        t_subject = result['sub']['value']
        if 'objLabel' in result:
            t_object = result['objLabel']['value']
            t_object_id = result['objEntity']['value'].replace(entity_prefix, "")
        else:
            # No English label: objEntity may be a literal (e.g. a timestamp) or an
            # unlabeled entity URI; timestamps are rewritten as "DD Month YYYY".
            t_object = result['objEntity']['value']
            t_object = convert_date_string(t_object) or t_object
            t_object_id = None
        t_subject_id = result['subEntity']['value'].replace(entity_prefix, "")
        triple = [t_subject, relation_label, t_object, t_subject_id, relation_pid, t_object_id]
        subject_counter[t_subject] += 1
        object_counter[t_object] += 1
        if subject_counter[t_subject] > max_occurrences or object_counter[t_object] > max_occurrences:
            secondary.append(triple)
        else:
            primary.append(triple)
    return primary + secondary


def _attach_sentences(triples, sentence_index, limit):
    """Return up to ``limit`` copies of ``triples`` extended with their TekGen sentence.

    Records are copied (``tr + [sentence]``) so lists held in ``sparql_query_cache``
    are never mutated; previously a repeated query gained a duplicate sentence field.
    """
    matched = []
    for tr in triples:
        if len(matched) >= limit:
            break
        search_key = create_key(tr[0], tr[1], tr[2])
        if search_key in sentence_index:
            matched.append(tr + [sentence_index[search_key]])
    return matched


def get_triples_with_sentences(relation_pid: str, relation_label: str, rel_domain: str, rel_range: str,
                               limit: int = 200, max_retries: int = 10,
                               show_query=None, show_sample=None, sentence_index=None):
    """Fetch Wikidata triples for a relation and keep those with a TekGen sentence.

    Args:
        relation_pid: Wikidata property id (e.g. ``"P172"``); must be non-empty.
        relation_label: Human-readable relation label stored in each record.
        rel_domain: Wikidata class id subjects must be instances of; non-empty.
        rel_range: Optional class id objects must belong to (falsy = unrestricted).
        limit: Maximum number of records returned.
        max_retries: Maximum failed SPARQL attempts before giving up.
        show_query: Print each query; ``None`` falls back to the notebook global.
        show_sample: Print the first five triples; ``None`` falls back to the global.
        sentence_index: Triple-key -> sentence mapping; ``None`` falls back to the
            notebook global ``sent_index``.

    Returns:
        Records ``[sub_label, rel_label, obj_label, sub_id, rel_id, obj_id, sentence]``,
        or ``[]`` when the query is malformed or every attempt fails.

    Raises:
        NameError: if no sentence index is passed and ``sent_index`` is undefined.
    """
    assert relation_pid, "relation id can't be empty"
    assert rel_domain, "domain can't be empty"

    # Fall back to the notebook-level globals the original implementation read.
    if show_query is None:
        show_query = globals().get("show_query", False)
    if show_sample is None:
        show_sample = globals().get("show_sample", False)
    if sentence_index is None:
        sentence_index = globals().get("sent_index")
    if sentence_index is None:
        # Retrying cannot fix this, so fail loudly instead of returning [] silently.
        raise NameError("sent_index is not defined: load the TekGen corpus first or pass sentence_index")

    current_limit = 10000  # start high; halved on server errors / truncated responses
    retries = 0
    # NOTE(review): Wikimedia's User-Agent policy asks for a descriptive agent with
    # contact details; this browser-style string does not follow it.
    user_agent = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.11 (KHTML, like Gecko) Chrome/23.0.1271.64 Safari/537.11'

    while retries < max_retries:
        query = _build_triples_query(relation_pid, rel_domain, rel_range, current_limit)
        if show_query:
            print(query)
        try:
            if query in sparql_query_cache:
                triples = sparql_query_cache[query]
            else:
                triples = _run_triples_query(query, relation_pid, relation_label, limit, user_agent)
                sparql_query_cache[query] = triples
            break
        except QueryBadFormed as e:
            # HTTP 400: the identical query would fail again, so do not retry.
            print(f"Malformed SPARQL query for {relation_pid}: {e}. Skipping this relation.")
            return []
        except EndPointInternalError:
            # HTTP 500 (often a server-side timeout): retry with a smaller LIMIT.
            retries += 1
            print(f"HTTP 500 error encountered on attempt {retries}/{max_retries}. Reducing LIMIT and retrying...")
            current_limit = max(10, current_limit // 2)
        except SPARQLWrapperException as e:
            retries += 1
            print(f"SPARQL endpoint error encountered: {e}. Retrying attempt {retries}/{max_retries} after short wait...")
            time.sleep(5)
        except HTTPError as e:
            # SPARQLWrapper re-raises urllib's HTTPError for status codes it does not
            # map to its own exceptions -- notably 429 Too Many Requests.
            retries += 1
            if e.code == 429:
                retry_after = e.headers.get("Retry-After") if e.headers is not None else None
                try:
                    wait_time = max(0, int(retry_after))
                except (TypeError, ValueError):
                    wait_time = 60  # default when the header is absent or an HTTP-date
                print(f"HTTP 429 error encountered. Waiting for {wait_time} seconds before retrying ({retries}/{max_retries})...")
                time.sleep(wait_time)
            else:
                print(f"HTTP error {e.code} encountered: {e.reason}. Retrying attempt {retries}/{max_retries} after short wait...")
                time.sleep(5)
        except (http.client.IncompleteRead, json.JSONDecodeError) as e:
            # A truncated / unparsable response: ask for fewer rows next time.
            retries += 1
            print(f"An error occurred: {e}. Retrying attempt {retries}/{max_retries}...")
            current_limit = max(10, current_limit // 2)
            print(f"Reducing LIMIT to {current_limit} and retrying...")
        except Exception as e:
            # Best effort: keep processing other relations on unexpected fetch errors.
            retries += 1
            print(f"An error occurred: {e}. Retrying attempt {retries}/{max_retries}...")
    else:
        # Loop exhausted without a successful fetch (no break).
        print("Max retries reached. Skipping this relation.")
        return []

    print(f"  collected {len(triples)} triples")
    if show_sample:
        print("  sample:")
        for tr in triples[:5]:
            print(f"    {tr[:3]}")
    return _attach_sentences(triples, sentence_index, limit)

def create_key(sub_label, rel_label, obj_label):
    """Build the lookup key used to match a triple against the TekGen index.

    Every whitespace run is stripped from each label, each label is
    lower-cased, and the three results are concatenated without a separator.
    """
    labels = (sub_label, rel_label, obj_label)
    return "".join(re.sub(r"\s+", '', label).lower() for label in labels)

def ensure_directory_exists(path):
    """Create directory ``path`` (and any missing parents) if it is absent.

    A single ``makedirs(exist_ok=True)`` call avoids the race between a separate
    existence check and creation. Raises ``FileExistsError`` if ``path`` exists
    but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

In [4]:
# Load the TekGen corpus into a triple-key -> sentence index
TEKGEN_CSV_PATH = '../../tekgen.csv'
sent_index = dict()
start_time = time.time()
print("TekGen corpus processing started!")
# newline='' is what the csv module expects; an explicit encoding avoids the
# platform default (assumes the corpus is UTF-8 -- TODO confirm).
with open(TEKGEN_CSV_PATH, newline='', encoding='utf-8') as csv_in_file:
    sent_reader = csv.reader(csv_in_file)
    next(sent_reader, None)  # skip the header row; tolerates an empty file
    for row in sent_reader:
        # Columns 0-2 form the key (sub/rel/obj labels); column 4 is the sentence.
        # Later rows with the same key overwrite earlier ones.
        sent_index[create_key(row[0], row[1], row[2])] = row[4]
# Measured once after the loop (was recomputed per row and undefined for an empty file).
elapsed_time = (time.time() - start_time) / 60
print(f"\ttriple-to-sent index with {len(sent_index)} triples loaded in {elapsed_time:.2f} mins!")

TekGen corpus processing started!
	triple-to-sent index with 11358950 triples loaded in 2.20 mins!


In [27]:
# Update base path to use existing ontology directory
base_path = '../../data/wikidata_tekgen/ontologies'
SAMPLES_PER_RELATION = 200  # max triples-with-sentences collected per relation
RANDOM_SEED = 42            # makes the get_splits permutations reproducible
show_sample = True
show_query = False

# The original used .format(base_path) on a string without a placeholder,
# so the path never appeared in the message.
missing_ontologies_message = f"""
ERROR: Ontology files not found in {base_path}!
Please copy the original ontology files from the Text2KGBench repository
"""

# Create the directory so the user knows where to put the files, then stop.
if not os.path.exists(base_path):
    os.makedirs(base_path, exist_ok=True)
    raise FileNotFoundError(missing_ontologies_message)

# Sorted so ontologies are processed (and random splits drawn) in a stable order.
ontology_files = sorted(f for f in os.listdir(base_path) if f.endswith('_ontology.json'))
if not ontology_files:
    raise FileNotFoundError(missing_ontologies_message)

# Load existing ontologies
ontologies = []
for filename in ontology_files:
    with open(os.path.join(base_path, filename), encoding='utf-8') as in_file:
        ontologies.append(json.load(in_file))

np.random.seed(RANDOM_SEED)
for onto in ontologies:
    onto_id = onto['id']
    print(f"Ontology: {onto['title']} ({onto_id})")
    train_all, val_all, test_all = [], [], []
    for rel in onto['relations']:
        print(f"\nprocessing \"{rel['label']}\" ({rel['pid']}) relation:")
        start_time = time.time()
        # A missing/empty range means the object type is unrestricted.
        triples_with_sentences = get_triples_with_sentences(
            rel['pid'], rel['label'], rel['domain'], rel.get('range', ''), SAMPLES_PER_RELATION)
        elapsed_time = time.time() - start_time
        print(f"    {len(triples_with_sentences)} triples with sentences in {elapsed_time:.2f} seconds!")
        train, val, test = get_splits(triples_with_sentences)
        train_all += train
        val_all += val
        test_all += test
    save_triples(onto_id, train_all, val_all, test_all)

Ontology: Culture Ontology (ont_10_culture)

processing "ethnic group" (P172) relation:
  collected 10000 triples
  sample:
    ['Mohd Noh Rajab', 'ethnic group', 'Malays']
    ['Yusril Ihza Mahendra', 'ethnic group', 'Malays']
    ['Lucas Fernandez de Piedrahita', 'ethnic group', 'Quechua people']
    ['Eva Copa', 'ethnic group', 'Aymara']
    ['Aru Apaza', 'ethnic group', 'Aymara']
    70 triples with sentences in 0.08 seconds!

processing "religious order" (P611) relation:
  collected 0 triples
  sample:
    0 triples with sentences in 0.00 seconds!

processing "languages spoken, written or signed" (P1412) relation:
  collected 10000 triples
  sample:
    ['Chang Chih-Chia', 'languages spoken, written or signed', 'Standard Taiwanese Mandarin']
    ['Ang Ui-jin', 'languages spoken, written or signed', 'Standard Taiwanese Mandarin']
    ['Tsai Pei-huo', 'languages spoken, written or signed', 'Standard Taiwanese Mandarin']
    ['Wang Chien-shien', 'languages spoken, written or signed',