In [1]:
import csv
import json
import re
import pandas as pd
import numpy as np
import time
import os
import random
from datetime import datetime
from collections import Counter
from SPARQLWrapper import SPARQLWrapper, JSON

In [2]:
# In-memory cache of parsed SPARQL results, filled by get_triples_with_sentences so that
# repeated relations are not re-queried against Wikidata within this kernel session.
sparql_query_cache = {}

In [3]:
# English month names, fixed so the output does not depend on the process locale
# (TekGen sentences, and therefore the lookup keys, are English).
_ENGLISH_MONTH_NAMES = (
    "January", "February", "March", "April", "May", "June",
    "July", "August", "September", "October", "November", "December",
)


def convert_date_string(date_string):
    """Convert an ISO timestamp like ``"1952-03-11T00:00:00Z"`` to ``"11 March 1952"``.

    The day keeps the zero padding of the input (``"2020-01-05T..."`` -> ``"05 January 2020"``);
    the time-of-day part is ignored.

    Args:
        date_string: candidate timestamp string.

    Returns:
        The ``"<DD> <Month> <YYYY>"`` string, or ``None`` when ``date_string`` is not a
        timestamp of exactly the form ``YYYY-MM-DDThh:mm:ssZ`` or does not denote a valid
        calendar date (e.g. year 0000, month 00, 30 February).
    """
    match = re.fullmatch(r"(\d{4})-(\d{2})-(\d{2})T\d{2}:\d{2}:\d{2}Z", date_string)
    if match is None:
        return None
    year, month, day = match.groups()
    try:
        # Validation only: datetime rejects year 0 and out-of-range months/days, which
        # previously raised ValueError and aborted the caller's loop.
        datetime(int(year), int(month), int(day))
    except ValueError:
        return None
    return f"{day} {_ENGLISH_MONTH_NAMES[int(month) - 1]} {year}"

def get_splits(triples, splits=(0.4, 0.3, 0.3), rng=None):
    """Randomly partition ``triples`` into train / validation / test lists.

    Args:
        triples: sequence of triples (each a list or tuple of fields).
        splits: fractions for (train, validation, test). Only the first two are used;
            the test split receives every remaining triple.
        rng: optional numpy ``Generator`` / ``RandomState`` used for shuffling. ``None``
            uses numpy's global random state (seed it with ``np.random.seed``).

    Returns:
        ``(train, val, test)`` as lists of lists with ``int(n * splits[0])``,
        ``int(n * splits[1])`` and the remaining triples respectively.
    """
    n = len(triples)
    permutation = np.random.permutation if rng is None else rng.permutation
    indices = permutation(n)
    train_count = int(n * splits[0])
    val_count = int(n * splits[1])
    # Index the Python list directly instead of round-tripping through np.array, which
    # coerces mixed-type rows to one dtype and rejects ragged rows. Each triple is copied
    # into a fresh list so callers cannot mutate the input through the result.
    shuffled = [list(triples[i]) for i in indices]
    return (
        shuffled[:train_count],
        shuffled[train_count:train_count + val_count],
        shuffled[train_count + val_count:],
    )

def _triple_record(record_id, tr):
    """Map a positional triple row to the JSON record layout used by the labelled files."""
    return {"id": record_id, "sub_label": tr[0], "rel_label": tr[1], "obj_label": tr[2],
            "sent": tr[6], "sub": tr[3], "rel": tr[4], "obj": tr[5]}


def _write_jsonl(path, records):
    """Write an iterable of dicts to ``path`` as JSON Lines (one object per line)."""
    with open(path, "w", encoding="utf-8") as out_file:
        for record in records:
            out_file.write(f"{json.dumps(record)}\n")


def save_triples(onto_id, train_all, val_all, test_all, base_dir="../../data/wikidata_tekgen"):
    """Write one ontology's splits as JSON Lines files under ``base_dir``.

    Each triple row is expected as ``[sub_label, rel_label, obj_label, sub, rel, obj, sent]``.

    Files written (subdirectories are created if missing):
      * ``train/{onto_id}_train.jsonl`` - full records, ids ``{onto_id}_train_<n>``
      * ``validation/{onto_id}_validation.jsonl`` - full records, ids ``{onto_id}_val_<n>``
      * ``ground_truth/{onto_id}_ground_truth.jsonl`` - full test records, ids ``{onto_id}_test_<n>``
      * ``test/{onto_id}_test.jsonl`` - test ids and sentences only (labels withheld)

    Args:
        onto_id: ontology identifier used in file names and record ids.
        train_all, val_all, test_all: lists of triple rows for each split.
        base_dir: root output directory (defaults to the original notebook location).
    """
    paths = {name: os.path.join(base_dir, name)
             for name in ("train", "validation", "ground_truth", "test")}
    for path in paths.values():
        # exist_ok avoids a race between checking for and creating the directory.
        os.makedirs(path, exist_ok=True)

    _write_jsonl(os.path.join(paths["train"], f"{onto_id}_train.jsonl"),
                 (_triple_record(f"{onto_id}_train_{i}", tr) for i, tr in enumerate(train_all, start=1)))
    _write_jsonl(os.path.join(paths["validation"], f"{onto_id}_validation.jsonl"),
                 (_triple_record(f"{onto_id}_val_{i}", tr) for i, tr in enumerate(val_all, start=1)))
    _write_jsonl(os.path.join(paths["ground_truth"], f"{onto_id}_ground_truth.jsonl"),
                 (_triple_record(f"{onto_id}_test_{i}", tr) for i, tr in enumerate(test_all, start=1)))
    # Same ids as ground_truth so predictions on the test file can be matched back.
    _write_jsonl(os.path.join(paths["test"], f"{onto_id}_test.jsonl"),
                 ({"id": f"{onto_id}_test_{i}", "sent": tr[6]} for i, tr in enumerate(test_all, start=1)))

def execute_query_with_retries(sparql, max_retries=3, wait_time=60):
    """Run ``sparql.query().convert()``, retrying on any exception.

    Args:
        sparql: configured query object exposing ``query()``.
        max_retries: total number of attempts.
        wait_time: seconds to sleep between failed attempts (not after the last one).

    Returns:
        The converted query result, or ``None`` if every attempt failed.
    """
    for attempt in range(1, max_retries + 1):
        try:
            return sparql.query().convert()
        except Exception as err:
            print(f"Attempt {attempt} failed with error: {err}")
            if attempt == max_retries:
                print("Max retries reached. Moving on.")
            else:
                print(f"Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
    return None

def _build_triples_query(relation_pid, rel_domain, rel_range, sparql_limit):
    """Return the Wikidata SPARQL query selecting (subject, object) pairs for one relation.

    Subjects must be instances (P31/P279*) of ``rel_domain`` and carry an English label.
    When ``rel_range`` is truthy, objects are restricted to its class hierarchy. The
    object's English label is OPTIONAL, so literal objects (e.g. dates) are returned too.
    """
    query = "PREFIX wdt: <http://www.wikidata.org/prop/direct/> \n PREFIX wd: <http://www.wikidata.org/entity/> \n"
    query += ("SELECT DISTINCT ?sub ?subEntity ?objEntity ?objLabel { \n ?subEntity wdt:P31/wdt:P279* wd:"
              + rel_domain + " . \n")
    query += '?subEntity rdfs:label ?sub . FILTER (lang(?sub) = "en") \n '
    query += '?subEntity wdt:' + relation_pid + ' ?objEntity . \n'
    if rel_range:
        query += '?objEntity wdt:P31*/wdt:P279* wd:' + rel_range + ' . \n '
    query += 'OPTIONAL { ?objEntity rdfs:label ?objLabel . FILTER (lang(?objLabel) = "en") } \n } '
    # Over-fetch: most triples have no TekGen sentence and are filtered out later.
    query += f"LIMIT {sparql_limit}"
    return query


def _parse_sparql_bindings(bindings, relation_pid, relation_label, limit):
    """Turn SPARQL JSON bindings into ``[sub, rel, obj, sub_id, rel_pid, obj_id]`` lists.

    For diversity, a triple whose subject or object has been seen more than
    ``limit / 10`` times is moved after all other triples (kept, not dropped).
    """
    entity_prefix = "http://www.wikidata.org/entity/"
    primary, secondary = [], []
    subject_counter, object_counter = Counter(), Counter()
    max_repeats = limit / 10
    for result in bindings:
        t_subject = result['sub']['value']
        if 'objLabel' in result:
            t_object = result['objLabel']['value']
            t_object_id = result['objEntity']['value'].replace(entity_prefix, "")
        else:
            # No English label: the raw value is kept, reformatted if it is a timestamp.
            t_object = result['objEntity']['value']
            t_object = convert_date_string(t_object) or t_object
            t_object_id = None
        t_subject_id = result['subEntity']['value'].replace(entity_prefix, "")
        triple = [t_subject, relation_label, t_object, t_subject_id, relation_pid, t_object_id]
        subject_counter[t_subject] += 1
        object_counter[t_object] += 1
        if subject_counter[t_subject] > max_repeats or object_counter[t_object] > max_repeats:
            secondary.append(triple)
        else:
            primary.append(triple)
    return primary + secondary


def get_triples_with_sentences(relation_pid: str, relation_label: str, rel_domain: str, rel_range: str,
                               limit: int = 200, show_query: bool = None, show_sample: bool = None,
                               sparql_limit: int = 10000):
    """Collect Wikidata triples for one relation and keep those with a TekGen sentence.

    Args:
        relation_pid: Wikidata property id used in the query (e.g. ``"P172"``).
        relation_label: relation label stored in each triple and used in the TekGen lookup key.
        rel_domain: Wikidata class id that subjects must be instances of.
        rel_range: optional Wikidata class id restricting objects; falsy to skip.
        limit: maximum number of triples returned; also sets the diversity threshold
            (subjects/objects seen more than ``limit / 10`` times are ranked last).
        show_query: print the SPARQL query. ``None`` falls back to the notebook global
            ``show_query`` (False if undefined).
        show_sample: print the first five collected triples. ``None`` falls back to the
            notebook global ``show_sample`` (False if undefined).
        sparql_limit: ``LIMIT`` of the SPARQL query.

    Returns:
        At most ``limit`` new lists
        ``[sub_label, rel_label, obj_label, sub_id, rel_pid, obj_id, sentence]``,
        or ``[]`` if the SPARQL query failed after retries.

    Raises:
        AssertionError: if ``relation_pid`` or ``rel_domain`` is empty.
    """
    assert relation_pid, "relation id can't be empty"
    assert rel_domain, "domain can't be empty"
    if show_query is None:
        show_query = globals().get("show_query", False)
    if show_sample is None:
        show_sample = globals().get("show_sample", False)

    query = _build_triples_query(relation_pid, rel_domain, rel_range, sparql_limit)
    if show_query:
        print(query)

    # Parsed triples depend on relation_label and limit, not only on the query text.
    cache_key = (query, relation_label, limit)
    if cache_key in sparql_query_cache:
        triples = sparql_query_cache[cache_key]
    else:
        sparql = SPARQLWrapper("https://query.wikidata.org/sparql", agent='MyTool/1.0 (https://mytool.example.com)')
        sparql.setQuery(query)
        sparql.setReturnFormat(JSON)
        sparql.setTimeout(300)  # Set timeout to 5 minutes
        sparql.setMethod('POST')  # Use POST request
        sparql.addCustomHttpHeader('User-Agent', 'MyTool/1.0 (https://mytool.example.com)')  # Custom User-Agent

        results = execute_query_with_retries(sparql)
        if results is None:
            return []  # Query failed after retries; nothing is cached so a rerun retries it.

        bindings = results["results"]["bindings"]
        print(f'  {len(bindings)} SPARQL results.')
        triples = _parse_sparql_bindings(bindings, relation_pid, relation_label, limit)
        sparql_query_cache[cache_key] = triples

    print(f"  collected {len(triples)} triples")
    if show_sample:
        print("  sample:")
        for tr in triples[:5]:
            print(f"    {tr[:3]}")

    triples_with_sentences = []
    for tr in triples:
        sentence = sent_index.get(create_key(tr[0], tr[1], tr[2]))
        if sentence is None:
            continue
        # Build a new list rather than appending in place: `tr` may live in the cache,
        # and mutating it would add duplicate sentences on every cache hit.
        triples_with_sentences.append(tr + [sentence])
        # Stop once enough triples with sentences have been found.
        if len(triples_with_sentences) >= limit:
            break

    return triples_with_sentences

def create_key(sub_label, rel_label, obj_label):
    """Build the whitespace- and case-insensitive lookup key for a label triple.

    Every whitespace run is removed from each label, each label is lower-cased, and the
    three results are concatenated in subject, relation, object order.
    """
    parts = (sub_label, rel_label, obj_label)
    return "".join(re.sub(r"\s+", '', part).lower() for part in parts)

def ensure_directory_exists(path):
    """Create ``path`` (including missing parent directories) if it does not exist.

    An empty ``path`` refers to the current directory and is a no-op.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
        OSError: for other creation failures (e.g. permissions).
    """
    if not path:
        return
    # exist_ok removes the race between an exists() check and makedirs() when another
    # process creates the directory in between.
    os.makedirs(path, exist_ok=True)

In [4]:
# Load the TekGen corpus into a dict: normalized triple key (create_key) -> sentence.
# Later rows with the same key overwrite earlier ones.
TEKGEN_CSV_PATH = '../../tekgen.csv'

sent_index = dict()
start_time = time.time()
print("TekGen corpus processing started!")
# newline="" is what the csv module expects; an explicit encoding avoids platform defaults.
with open(TEKGEN_CSV_PATH, newline="", encoding="utf-8") as csv_in_file:
    sent_reader = csv.reader(csv_in_file)
    next(sent_reader, None)  # skip the header row; tolerate an empty file
    for row in sent_reader:
        # NOTE(review): columns 0-2 are used as subject/relation/object labels and column 4
        # as the sentence — confirm against the TekGen header.
        tr_key = create_key(row[0], row[1], row[2])
        sent = row[4]
        sent_index[tr_key] = sent
# Measure once after loading (previously recomputed for every row, and undefined
# when the file contained only a header).
elapsed_time = (time.time() - start_time) / 60
print(f"\ttriple-to-sent index with {len(sent_index)} triples loaded in {elapsed_time:.2f} mins!")

TekGen corpus processing started!
	triple-to-sent index with 11358950 triples loaded in 1.28 mins!


In [5]:
# Update base path to use existing ontology directory
base_path = '../../data/wikidata_tekgen/ontologies'
TRIPLES_PER_RELATION = 200  # max triples with sentences kept per relation
SEED = 42  # fixes the shuffles in get_splits so the splits are reproducible

# Load existing ontologies in sorted order: os.listdir order is arbitrary, which made the
# processing order (and hence the seeded splits) nondeterministic across machines.
ontologies = []
for filename in sorted(os.listdir(base_path)):
    if filename.endswith('_ontology.json'):
        with open(os.path.join(base_path, filename), encoding="utf-8") as in_file:
            ontologies.append(json.load(in_file))
show_sample = True
show_query = False

# get_splits shuffles with numpy's global random state.
np.random.seed(SEED)

for onto in ontologies:
    print(f"Ontology: {onto['title']} ({onto['id']})")
    onto_id = onto['id']
    train_all, val_all, test_all = [], [], []
    for rel in onto['relations']:
        print(f"\nprocessing \"{rel['label']}\" ({rel['pid']}) relation:")
        start_time = time.time()
        triples_with_sentences = get_triples_with_sentences(
            rel['pid'], rel['label'], rel['domain'], rel['range'], TRIPLES_PER_RELATION)
        elapsed_time = (time.time() - start_time)
        print(f"    {len(triples_with_sentences)} triples with sentences in {elapsed_time:.2f} seconds!")
        train, val, test = get_splits(triples_with_sentences)
        train_all += train
        val_all += val
        test_all += test
    save_triples(onto_id, train_all, val_all, test_all)

Ontology: Culture Ontology (ont_10_culture)

processing "ethnic group" (P172) relation:
  10000 SPARQL results.
  collected 10000 triples
  sample:
    ['Mohd Noh Rajab', 'ethnic group', 'Malays']
    ['Yusril Ihza Mahendra', 'ethnic group', 'Malays']
    ['Nazario Turpo', 'ethnic group', 'Quechua']
    ['Melania Canales Poma', 'ethnic group', 'Quechua']
    ['Elwin Huaman', 'ethnic group', 'Quechua']
    44 triples with sentences in 7.70 seconds!

processing "religious order" (P611) relation:
  0 SPARQL results.
  collected 0 triples
  sample:
    0 triples with sentences in 2.04 seconds!

processing "languages spoken, written or signed" (P1412) relation:
  10000 SPARQL results.
  collected 10000 triples
  sample:
    ['Hope Su', 'languages spoken, written or signed', 'Standard Taiwanese Mandarin']
    ['Chang Chih-Chia', 'languages spoken, written or signed', 'Standard Taiwanese Mandarin']
    ['Ang Ui-jin', 'languages spoken, written or signed', 'Standard Taiwanese Mandarin']
    ['