<a href="https://colab.research.google.com/github/mille-s/Build_KGs_entities/blob/main/DBpedia_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Clones, Installs and functions
from IPython.display import clear_output, HTML, display
import os
# Install third-party dependencies (SPARQLWrapper is used by all the DBpedia query cells below).
! pip install SPARQLWrapper
! pip install colored
! pip install xmltodict

# Clone the helper repositories into /content; later cells read their resources
# (e.g. Build_KGs_entities/resources/*.json) and import their code (WikipediaPage_Generator.code.*).
# Clone Build_KGs_entities repo
! git clone https://github.com/mille-s/Build_KGs_entities.git
# Delete locally to avoid confusion
! rm '/content/Build_KGs_entities/DBpedia_dataset.ipynb'

# clone wikipedia page generator repo
! git clone https://github.com/mille-s/WikipediaPage_Generator.git
# Delete locally to avoid confusion
! rm 'WikipediaPage_Generator/Wikipedia_generator.ipynb'

# clone dcu_tcd_webnlg repo (provides code/sorted_properties.txt used below)
! git clone https://github.com/mille-s/DCU_TCD_FORGe_WebNLG23.git
# Delete locally to avoid confusion
! rm 'DCU_TCD_FORGe_WebNLG23/DCU_TCD_FORGe_WebNLG23.ipynb'

# Property list from the cloned DCU_TCD_FORGe_WebNLG23 repo, passed to get_dbpedia_properties later on.
props_list_path = os.path.join('/content', 'DCU_TCD_FORGe_WebNLG23', 'code', 'sorted_properties.txt')

# Output folder for the generated XML files (the XML cells create one sub-folder per category in it).
triple2predArg = os.path.join('/content', 'XML')
# exist_ok=True so re-running this cell in the same session does not raise FileExistsError.
os.makedirs(triple2predArg, exist_ok=True)

def aggregate_info_and_get_propLabel(ontology_properties):
  """Merge per-row domain/range records into one entry per property.

  Args:
    ontology_properties: iterable of dicts with 'property', 'domain' and 'range' keys
      holding URIs; 'domain'/'range' may be the placeholder 'Unknown'.

  Returns:
    dict mapping the property's local name (the part after the last '/') to
    {'domain': [...], 'range': [...]}, each a list of distinct local names in
    first-seen order ('Unknown' is kept as is).
  """
  def _strip_prefix(uri):
    # 'Unknown' is a placeholder, not a URI, so it is passed through untouched.
    return uri if uri == 'Unknown' else os.path.split(uri)[1]

  merged = {}
  for record in ontology_properties:
    prop_name = os.path.split(record['property'])[1]
    dom = _strip_prefix(record['domain'])
    rng = _strip_prefix(record['range'])
    entry = merged.setdefault(prop_name, {'domain': [], 'range': []})
    # Only record domain/range values that have not been seen yet for this property.
    if dom not in entry['domain']:
      entry['domain'].append(dom)
    if rng not in entry['range']:
      entry['range'].append(rng)
  return merged

# Clear the (long) pip/git output of this cell, then confirm that the setup finished.
clear_output()
print('Working folder ready!\n--------------')

# Preliminary work: get info to build datasets

In [None]:
#@title SPARQL query to get all properties in DBpedia (from ChatGPT)

from SPARQLWrapper import SPARQLWrapper, JSON

def getDBpediaProperties():
  """Return the URIs of all rdf:Property resources in the DBpedia ontology namespace.

  Queries the public DBpedia SPARQL endpoint for distinct properties whose URI
  starts with "http://dbpedia.org/ontology/", ordered by URI.

  Returns:
    list[str]: the property URIs. If the query fails, the error is printed and
    an empty list is returned.
  """
  # Initialised before the query so a failure returns [] instead of raising
  # UnboundLocalError at the return statement.
  properties = []

  # Set up the SPARQL endpoint
  sparql = SPARQLWrapper("http://dbpedia.org/sparql")

  # Define the SPARQL query to retrieve properties with the prefix "http://dbpedia.org/ontology"
  query = """
  SELECT DISTINCT ?property
  WHERE {
    ?property a rdf:Property .
    FILTER(STRSTARTS(STR(?property), "http://dbpedia.org/ontology/"))
  }
  ORDER BY ?property
  """

  sparql.setQuery(query)
  sparql.setReturnFormat(JSON)

  try:
    results = sparql.query().convert()
    properties = [result["property"]["value"] for result in results["results"]["bindings"]]
    print(f"Number of properties used in DBpedia: {str(len(properties))}")
  except Exception as e:
    # Best effort: report the problem and fall through with the empty list.
    print(f"An error occurred: {e}")

  return properties

list_properties = getDBpediaProperties()

In [None]:
#@title SPARQL query to get the number of instances of each property
from SPARQLWrapper import SPARQLWrapper, JSON
import os
import json

def getNumInstancesProperty(property_label):
  """Count the triples on DBpedia that use a given ontology property.

  The first character of the property's local name (the part after the last '/')
  is lowercased before querying, e.g. ".../ontology/BirthDate" -> ".../ontology/birthDate".

  Args:
    property_label: full property URI.

  Returns:
    tuple (lowercased property URI, count as returned by the endpoint, a string).
    If the query fails, the error is printed and the count is '0', so callers
    can still apply int() to it.
  """
  # Slicing (tail[:1]) instead of indexing so an empty local name does not raise IndexError.
  head, tail = os.path.split(property_label)
  lowCase_tail = tail[:1].lower() + tail[1:]
  lowCase_property_label = os.path.join(head, lowCase_tail)

  # Set up the SPARQL endpoint
  sparql = SPARQLWrapper("http://dbpedia.org/sparql")

  # Count all triples whose predicate is the (lowercased) property.
  query = f"""
  SELECT (COUNT(*) AS ?count)
  WHERE {{
    ?subject <{lowCase_property_label}> ?object .
  }}
  """

  sparql.setQuery(query)
  sparql.setReturnFormat(JSON)

  # Default used when the query fails; without it `count` would be unbound at the return.
  count = '0'
  try:
    results = sparql.query().convert()
    count = results["results"]["bindings"][0]["count"]["value"]
  except Exception as e:
    print(f"An error occurred: {e}")

  return lowCase_property_label, count

def createDicoCountOccurrenceProperties(list_properties):
  """Count the DBpedia occurrences of each property and save them to JSON.

  For every URI in `list_properties`, getNumInstancesProperty is called and progress
  is printed. The result (lowercased property URI -> int count) is written to
  "dico_count_occurrences_dbp_props.json" in the working directory, sorted by
  decreasing count.
  """
  counts = {}
  for position, uri in enumerate(list_properties):
    normalised_uri, raw_count = getNumInstancesProperty(uri)
    counts[normalised_uri] = int(raw_count)
    print(f'{str(position)}/{str(len(list_properties))}: {uri} = {raw_count}')

  # sorted() is stable, so properties with equal counts keep their insertion order.
  ranked_counts = dict(sorted(counts.items(), key=lambda pair: pair[1], reverse=True))
  with open("dico_count_occurrences_dbp_props.json", "w") as outfile:
    json.dump(ranked_counts, outfile)

createDicoCountOccurrenceProperties(list_properties)

In [None]:
#@title SPARQL query for all properties getting domain/range class
from SPARQLWrapper import SPARQLWrapper, JSON

# Set up the DBpedia SPARQL endpoint
# DBpedia SPARQL endpoint returning JSON results (kept at module level as `sparql`).
sparql = SPARQLWrapper("http://dbpedia.org/sparql")
sparql.setReturnFormat(JSON)

# Ontology properties with their optional rdfs:domain / rdfs:range.
# This is supposed to return all properties, but a lot seem to be missing, not sure why.
# NOTE(review): possibly a row cap on the public endpoint (another query in this notebook is
# observed to return exactly 10k results) — confirm.
query = """
PREFIX dbo: <http://dbpedia.org/ontology/>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>

SELECT ?property ?domain ?range WHERE {
    ?property a rdf:Property .
    FILTER(STRSTARTS(STR(?property), "http://dbpedia.org/ontology/")) .
    OPTIONAL { ?property rdfs:domain ?domain . }
    OPTIONAL { ?property rdfs:range ?range . }
}
LIMIT 50000  # Increase this limit if needed
"""

sparql.setQuery(query)
results = sparql.query().convert()

# One record per result row; an unbound OPTIONAL variable becomes "Unknown".
ontology_properties = [
    {
        "property": row["property"]["value"],
        "domain": row.get("domain", {}).get("value", "Unknown"),
        "range": row.get("range", {}).get("value", "Unknown"),
    }
    for row in results["results"]["bindings"]
]

# Show the first 100 records.
for prop in ontology_properties[:100]:
    print(f"Property: {prop['property']}")
    print(f"  Domain: {prop['domain']}")
    print(f"  Range: {prop['range']}")
    print()




In [None]:
#@title SPARQL query for selected properties getting domain/range class
from SPARQLWrapper import SPARQLWrapper, JSON

# Load the WebNLG property labels (one per line) and build their DBpedia ontology URIs.
# The built-in open() is used because `codecs` is only imported in a later cell; the `with`
# closes the handle. Blank lines are skipped (they would yield the bare namespace URI).
with open('/content/all_properties.txt', 'r', encoding='utf-8') as props_file:
  WebNLG_properties_list = ['http://dbpedia.org/ontology/'+line.strip() for line in props_file if line.strip()]

def get_domain_range(properties, sparql=None):
    """Query the rdfs:domain and rdfs:range of each given DBpedia property.

    Args:
        properties: iterable of full property URIs.
        sparql: optional configured SPARQLWrapper to reuse. If omitted, one is
            created for http://dbpedia.org/sparql with JSON results, so the
            function does not depend on a `sparql` global defined in another cell.

    Returns:
        list[dict]: one {"property", "domain", "range"} dict per result row;
        an unbound domain/range is "Unknown". A property that is not typed
        rdf:Property on the endpoint produces no row.
    """
    if sparql is None:
        sparql = SPARQLWrapper("http://dbpedia.org/sparql")
        sparql.setReturnFormat(JSON)

    domain_range_info = []

    for property_uri in properties:
        print(f'Checking property {property_uri}')
        # SPARQL query to get domain and range for the specific property
        query = f"""
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>

        SELECT ?domain ?range WHERE {{
            <{property_uri}> a rdf:Property .
            OPTIONAL {{ <{property_uri}> rdfs:domain ?domain . }}
            OPTIONAL {{ <{property_uri}> rdfs:range ?range . }}
        }}
        """

        sparql.setQuery(query)
        results = sparql.query().convert()

        # A property with several domains/ranges yields several rows.
        for result in results["results"]["bindings"]:
            domain_range_info.append({
                "property": property_uri,
                "domain": result.get("domain", {}).get("value", "Unknown"),
                "range": result.get("range", {}).get("value", "Unknown"),
            })

    return domain_range_info

# Retrieve domain and range for each property in the list
ontology_properties = get_domain_range(WebNLG_properties_list)

In [None]:
#@title Aggregate possible domain/ranges for each property
import os
import codecs

# Property local name -> {'domain': [...], 'range': [...]}
dico_properties = aggregate_info_and_get_propLabel(ontology_properties)

# WebNLG property labels, one per line (the file handle is closed by `with`; blank lines skipped).
with open('/content/all_properties.txt', 'r', encoding='utf-8') as props_file:
  list_properties = [line.strip() for line in props_file if line.strip()]
missing_props = [WebNLG_property for WebNLG_property in list_properties if WebNLG_property not in dico_properties]

print('Missing properties: '+str(len(missing_props))+'/'+str(len(list_properties)), missing_props)
# We need to check the original property labels, not the modified ones; that way we should get them all.


In [None]:
#@title Get examples for properties
# Get a list of sample subject-object values given an input list of properties.

from SPARQLWrapper import SPARQLWrapper, JSON
import pandas as pd
import csv
import json

num_examples_desired = 30#@param

def _readable_resource_label(value):
  """Turn a DBpedia resource URI into a readable name ('.../Barack_Obama' -> 'Barack Obama'); return any other value unchanged."""
  if value.startswith("http://dbpedia.org/resource/"):
    return value.split("/")[-1].replace("_", " ")
  return value

def get_dbpedia_property_examples(property_uri, dataframe_examples, count_props, num_occ_uri, limit=10):
  """Write up to `limit` example rows for one property into `dataframe_examples` (in place).

  Args:
    property_uri: full URI of the property to sample.
    dataframe_examples: DataFrame with columns
      ["id", "num-occurrences", "Subject", "Property", "Object"]; rows are set via .loc.
    count_props: index of this property; also stored in the "id" column.
    num_occ_uri: number of occurrences of the property, stored in "num-occurrences".
    limit: maximum number of distinct (subject, object) pairs to fetch.

  Rows go to index count_props*limit + k, so each property gets its own index range
  as long as every call uses the same `limit`.
  """
  sparql = SPARQLWrapper("https://dbpedia.org/sparql")
  query = f"""
  SELECT DISTINCT ?subject ?object WHERE {{
    ?subject <{property_uri}> ?object .
  }} LIMIT {limit}
  """
  sparql.setQuery(query)
  sparql.setReturnFormat(JSON)
  results = sparql.query().convert()

  # Identical for every row of this property, so computed once.
  prop_label = property_uri.split("/")[-1]  # e.g., 'birthDate'
  for count_examples, result in enumerate(results["results"]["bindings"]):
    subject_label = _readable_resource_label(result["subject"]["value"])
    object_label = _readable_resource_label(result["object"]["value"])
    dataframe_examples.loc[count_props*limit+count_examples] = [count_props, num_occ_uri, subject_label, prop_label, object_label]

# Property URI -> number of occurrences, as saved by createDicoCountOccurrenceProperties.
with open('/content/dico_count_occurrences_dbp_props.json', 'r') as counts_file:
  list_props_dico = json.load(counts_file)
# Keep URIs with at least 10 occurrences of the property (to filter out possibly bad properties),
# together with their number of occurrences, which is stored in the final CSV.
properties_uri = []
num_occ_uris = []
for prop, num_occ in list_props_dico.items():
  if num_occ >= 10:
    properties_uri.append(prop)
    num_occ_uris.append(num_occ)
print(f'Number of properties with at least 10 occurrences: {len(properties_uri)}')

# Create dataframe and fill it with examples for each property
dataframe_examples = pd.DataFrame(columns=["id", "num-occurrences", "Subject", "Property", "Object"])
for counter_uris, (property_uri, num_occ_uri) in enumerate(zip(properties_uri, num_occ_uris)):
  print(f'{str(counter_uris)}/{str(len(properties_uri))}: {property_uri}...')
  get_dbpedia_property_examples(property_uri, dataframe_examples, counter_uris, num_occ_uri, int(num_examples_desired))

# Save dataframe as CSV
dataframe_examples.to_csv('dbp_props_examples.csv', index=False)

print(dataframe_examples)

In [None]:
#@title Process CSV created from annotated examples for properties and get lists of properties to ignore and that can happen only once.
# The produced files can be found in the GitHub WikipediaPage_Generator repo.
# I needed to extract which properties can happen multiple times for a given subject and which can't, as annotated manually in a spreadsheet.
# Guidelines for annotating a property are "is it confusing if ever a text contains two values for that property". E.g. "X was born in Finland and Britain" is weird, but "X is the sister of Y and Z" is not.
# It doesn't necessarily have to do with factual truth, e.g. it is almost certain that an ID for an entity in a database is unique, but it's not shocking to say "X has IDs A and B in database W".
# On the other hand, some cases with multiple object values are annotated as having a single value to avoid poorly entered DBpedia values: e.g. locations (frequently, people put as 3 values a town, the state, the country), subclasses/types (one value is more natural in a text), etc.

import pandas as pd
import csv
import json

# Load the CSV file into a DataFrame
# CSV file in C:\Users\sfmil\Desktop\ADAPT-2025-2026\MyPapers\2025-06_INLG-longText-D2T
df = pd.read_csv('/content/dbp_props_examples_annotated.csv')

list_props_that_can_happen_once_only = []
list_props_that_can_happen_more_than_once = []
list_props_to_filter = []
# Assumes the example rows of each property are consecutive (as written by the examples cell);
# the annotation ("num-possible\nvalues" column) is read from the first row of each property.
current_property = None
for index, df_row in df.iterrows():
  if df_row['Property'] == current_property:
    continue  # further example row of the property just classified
  current_property = df_row['Property']
  current_num_possible_values = df_row['num-possible\nvalues']
  if current_num_possible_values == '1':
    list_props_that_can_happen_once_only.append(current_property)
  elif current_num_possible_values == 'Multiple':
    list_props_that_can_happen_more_than_once.append(current_property)
  elif current_num_possible_values in ('?', 'IgnoreProp'):
    # Unclear ("?") or unusable properties are filtered out.
    list_props_to_filter.append(current_property)
  else:
    # Stop with an error: with a plain `break`, the incomplete lists would still be
    # printed and written to the JSON files below.
    raise ValueError(f'ERROR! Unexpected value {current_num_possible_values!r} for property {current_property!r} (row {index}).')

print(sorted(list_props_that_can_happen_once_only))
print(len(list_props_that_can_happen_once_only))
print()
print(sorted(list_props_that_can_happen_more_than_once))
print(len(list_props_that_can_happen_more_than_once))
print()
print(sorted(list_props_to_filter))
print(len(list_props_to_filter))

print(f'There are {len(list_props_that_can_happen_once_only)+len(list_props_that_can_happen_more_than_once)+len(list_props_to_filter)} properties collected (expected: 1208).')

# Save lists as json
with open("list_props_that_can_happen_once_only.json", "w") as outfile:
    json.dump(sorted(list_props_that_can_happen_once_only), outfile)
# with open("list_props_that_can_happen_more_than_once.json", "w") as outfile:
#     json.dump(list_props_that_can_happen_more_than_once, outfile)
with open("list_props_to_filter.json", "w") as outfile:
    json.dump(sorted(list_props_to_filter), outfile)

In [None]:
#@title SPARQL query for finding all the possible values for gold:hypernym on dbpedia
# Could use dbo:type or rdf:type, but both look a bit messy at first sight
from SPARQLWrapper import SPARQLWrapper, JSON

def get_types():
  """Return the raw SPARQL result bindings for the distinct gold:hypernym values on DBpedia.

  Each binding holds the hypernym URI under binding['type']['value'].
  When this was written the query returned 10k results, and repeated runs returned the
  same 10k, so the list is not necessarily complete.
  """
  endpoint = SPARQLWrapper("https://dbpedia.org/sparql")
  endpoint.setQuery("""
  SELECT DISTINCT ?type
  WHERE {
    ?s gold:hypernym ?type .
  }
  """)
  endpoint.setReturnFormat(JSON)
  response = endpoint.query().convert()
  return response["results"]["bindings"]

def count_entities_of_type(hypernym_type):
  """Return how many DBpedia subjects have `hypernym_type` as gold:hypernym.

  Args:
    hypernym_type: full hypernym URI (inserted verbatim between <...> in the query).

  Returns:
    int: the COUNT(?s) value reported by the endpoint.
  """
  count_query = f"""
  SELECT (COUNT(?s) AS ?count)
  WHERE {{
      ?s gold:hypernym <{hypernym_type}> .
  }}
  """
  endpoint = SPARQLWrapper("https://dbpedia.org/sparql")
  endpoint.setQuery(count_query)
  endpoint.setReturnFormat(JSON)
  first_row = endpoint.query().convert()["results"]["bindings"][0]
  return int(first_row["count"]["value"])

# Hypernym URI -> number of DBpedia entities having it.
dico_hypernym_types = {}
results_types = get_types()
total_types = len(results_types)
for i, result in enumerate(results_types):
  hypernym_url = result['type']['value']
  if hypernym_url in dico_hypernym_types:
    continue  # already counted
  count_occurrences = count_entities_of_type(hypernym_url)
  dico_hypernym_types[hypernym_url] = count_occurrences
  print(f'{str(i)}/{str(total_types)}: {hypernym_url} = {count_occurrences}')


In [None]:
#@title Save dico_hypernym_types to a json and download
import json
from google.colab import files

# Hypernyms ordered by decreasing number of instances (ties keep their original order).
sorted_dico_hypernym_types = dict(sorted(dico_hypernym_types.items(), key=lambda pair: pair[1], reverse=True))

with open("dico_hypernym_types.json", "w") as outfile:
  json.dump(sorted_dico_hypernym_types, outfile)

# Trigger a browser download of the saved file (Colab only).
files.download('dico_hypernym_types.json')

In [None]:
#@title Make a list of random entities to query for each hypernym (SPARQL query for getting n entities that have a specific hypernym)
import random
import json
from SPARQLWrapper import SPARQLWrapper, JSON

def get_random_entities(hypernym: str, limit: int, max_candidates: int = 50000):
  """Return up to `limit` randomly chosen DBpedia entities whose gold:hypernym is `hypernym`.

  Args:
    hypernym: full hypernym URI (inserted verbatim between <...> in the query).
    limit: maximum number of entities to return (sample size).
    max_candidates: SPARQL LIMIT on the number of candidate entities fetched
      before sampling (default 50000).

  Returns:
    list[str]: entity URIs sampled without replacement. The sample is uniform only over
    the candidates the endpoint returned, which may be a subset of a large class.
  """
  sparql = SPARQLWrapper("http://dbpedia.org/sparql")

  query = f"""
  SELECT DISTINCT ?entity WHERE {{
      ?entity gold:hypernym <{hypernym}> .
  }} LIMIT {max_candidates}
  """

  sparql.setQuery(query)
  sparql.setReturnFormat(JSON)
  results = sparql.query().convert()

  entities = [result['entity']['value'] for result in results['results']['bindings']]
  # min() so classes with fewer candidates than `limit` return all of them.
  return random.sample(entities, min(len(entities), limit))

# Load json that contains hypernyms as keys and count of instances of that hypernym on DBpedia as value
with open('/content/dico_hypernym_types_incomplete.json', 'r') as handle:
    dico_hypernym_types = json.load(handle)

# Sampling settings. NOTE(review): these look like test values (an earlier comment mentioned up to
# 1000 entities per class); raise them for a full run.
NUM_HYPERNYMS_TO_PROCESS = 5     # only the first N hypernyms of the JSON are considered
NUM_ENTITIES_PER_HYPERNYM = 10   # sample size per hypernym
MIN_HYPERNYM_MEMBERS = 100       # skip hypernyms with fewer instances on DBpedia

# Get up to NUM_ENTITIES_PER_HYPERNYM random entities for each processed hypernym with enough members
dico_hypernym_sample_entities = {}
for hypernym in list(dico_hypernym_types)[:NUM_HYPERNYMS_TO_PROCESS]:
  if dico_hypernym_types[hypernym] >= MIN_HYPERNYM_MEMBERS:
    print(hypernym, dico_hypernym_types[hypernym])
    dico_hypernym_sample_entities[hypernym] = get_random_entities(hypernym, NUM_ENTITIES_PER_HYPERNYM)

In [None]:
#@title Save dico_hypernym_sample_entities as JSON and download
from google.colab import files
# Serialise the sampled entities, then trigger a browser download of the file (Colab only).
sample_entities_filename = 'dico_hypernym_sample_entities.json'
with open(sample_entities_filename, "w") as outfile:
    json.dump(dico_hypernym_sample_entities, outfile)
files.download(sample_entities_filename)

# Build dataset

## Get properties for list of entities
Creates dico_input_contents_DBp.pickle file.
Skip if you want to use an already generated pickle file.

In [None]:
# @title Get DBpedia properties online for an entity list (about 1h for GREC entities)
import os
import codecs
import json
import re
import ipywidgets as widgets
from ipywidgets import Layout
from WikipediaPage_Generator.code.queryDBpediaProps import get_dbpedia_properties
from WikipediaPage_Generator.code.utils import removeReservedCharsFileName

# Input json should be a dico_1 with category names (urls or name) as keys, and a list of entities (urls or names) as value.
input_json_path = '/content/Build_KGs_entities/resources/GREC_NE.json'#@param{type:"string"}
# triple-source should be Ontology for this experiment
triple_source = 'Ontology' #@param['Infobox', 'Ontology', 'Wikidata']
# Store here "dirty" properties; string expected, it is later on split by ','
# ignore_properties = 'width,title' # Used for semantic accuracy experiments with Rudali
with open('/content/WikipediaPage_Generator/resources/list_props_to_filter.json', 'r') as filter_file:
  ignore_properties = ','.join(json.load(filter_file))
get_triples_where_entity_is_subj = True #@param {type:"boolean"}
get_triples_where_entity_is_obj = True #@param {type:"boolean"}

# Load json dico with sample entities for each hypernym
with open(input_json_path, 'r') as file:
    dico_hypernym_sample_entities_loaded = json.load(file)

# dico_input_contents will contain category keys, which contain entity keys, which contain a list of triple objects
dico_input_contents = {}
for hypernym in sorted(dico_hypernym_sample_entities_loaded):
  # Category label: last URL segment when a URL is given, otherwise the name as is.
  input_category = hypernym.rsplit('/', 1)[1] if '/' in hypernym else hypernym
  print(input_category)
  dico_input_contents[input_category] = {}
  # list_triple_object contains object with 3 attributes: DBsubj, DBprop, DBobj
  # list_propObj is used for UI (for triples selection by the user)
  # list_obj is used for getting class and gender info later on
  for sampled_entity in sorted(dico_hypernym_sample_entities_loaded[hypernym]):
    # Entity name: last URL segment for URLs, otherwise the name with spaces replaced by underscores.
    if '/' in sampled_entity:
      entity_name = sampled_entity.rsplit('/', 1)[1]
    else:
      entity_name = sampled_entity.replace(' ', '_')

    # Get the triples in which the entity is the subject and/or the object (see the two flags above)
    list_triple_objects, list_propObj, list_obj = get_dbpedia_properties(props_list_path, entity_name, triple_source, ignore_properties, get_triples_where_entity_is_subj, get_triples_where_entity_is_obj)

    if len(list_triple_objects) > 0:
      print(f'  {entity_name}: found {len(list_triple_objects)} properties.')
      dico_input_contents[input_category][entity_name] = list_triple_objects

In [None]:
#@title Serialise dico_input_contents using pickle and download
import pickle
from google.colab import files

# Serialise the collected triples with the highest pickle protocol, then download the file (Colab only).
pickle_filename = "dico_input_contents_DBp.pickle"
with open(pickle_filename, "wb") as handle:
    pickle.dump(dico_input_contents, handle, protocol=pickle.HIGHEST_PROTOCOL)
files.download(pickle_filename)

## Build knowledge graphs #1: WebNLG mirror input configuration

In [None]:
#@title Check which entities have property (sub)sets that match WebNLG inputs. Creates dico_entities_for_triple_configuration.json file.
import json
import pickle

input_contents_source = 'GitHub_GREC_NEs'#@param['GitHub_GREC_NEs', 'Made_with_this_notebook']
print_output = False #@param{type:'boolean'}
# NOTE: this option is deliberately not named `dico_input_contents`, which holds the dict built by the
# "Get DBpedia properties" cell; reusing that name would overwrite the dict with this string.

# dico_input_contents_loaded should be a dico_1 with category names as keys, and a dico_2 as value.
# dico_2 should have entity_name as key, and a list of Triple Objects as value.
# Triple Objects have 3 attributes: DBsubj, DBprop, DBobj
# NOTE: pickle.load can execute arbitrary code; only load pickles from this notebook or the cloned repo.
if input_contents_source == 'Made_with_this_notebook':
  contents_pickle_path = "dico_input_contents_DBp.pickle"
elif input_contents_source == 'GitHub_GREC_NEs':
  contents_pickle_path = "/content/Build_KGs_entities/resources/dico_input_contents_DBp_GREC_NEs.pickle"
else:
  raise ValueError(f'Unknown input_contents_source: {input_contents_source!r}')
with open(contents_pickle_path, "rb") as handle:
  dico_input_contents_loaded = pickle.load(handle)

# dico_triple_configs_WebNLG should be a dico_1 with category names as keys, and a dico_2 as value.
# dico_2 should have strings of properties separated by "##" as keys, and integers (occurrence counts in WebNLG train data) as values.
with open("/content/Build_KGs_entities/resources/dico_category_tripleConfigs_WebNLG.json", "r") as handle:
  dico_triple_configs_WebNLG = json.load(handle)

# WebNLG category label -> GREC category label
dico_mapping_categories = {'Person':'People', 'City':'Cities'}

# Extract which entities have the properties that match a WebNLG configuration. The dico will have: { category: { triple_config: [entity1, entity2, etc.] } }
dico_entities_for_triple_configuration = {}
print('Finding which entities have the properties that match a WebNLG configuration...')
for category_label_WebNLG, category_label_GREC in dico_mapping_categories.items():
  configs_for_category = {}
  dico_entities_for_triple_configuration[category_label_WebNLG] = configs_for_category
  # Set of DBpedia properties of each GREC entity, computed once per category
  # (it does not depend on the WebNLG configuration being checked).
  entity_property_sets = {
      entity_name: {triple_object.DBprop for triple_object in list_triple_objects_entity}
      for entity_name, list_triple_objects_entity in dico_input_contents_loaded[category_label_GREC].items()
  }
  for triple_config_WebNLG in dico_triple_configs_WebNLG[category_label_WebNLG]:
    # Input configuration (property set) extracted from WebNLG
    properties_WebNLG = set(triple_config_WebNLG.split('##'))
    for entity_name, properties_entity in entity_property_sets.items():
      # The configuration can be built if the entity has every one of its properties
      if properties_WebNLG.issubset(properties_entity):
        configs_for_category.setdefault(triple_config_WebNLG, []).append(entity_name)

# Save dico_entities_for_triple_configuration as json
with open("dico_entities_for_triple_configuration.json", "w") as outfile:
    json.dump(dico_entities_for_triple_configuration, outfile)

if print_output:
  for category_label, configs in dico_entities_for_triple_configuration.items():
    print('============')
    print(category_label)
    print('============')
    for triple_config_overlap, entity_names in configs.items():
      print('')
      print(triple_config_overlap)
      print('----------------------------')
      for entity_name in entity_names:
        print('-', entity_name)


In [None]:
#@title Save triple sets in XML format: WebNLG size and input config mirroring
# Here we're trying to build a new dataset that has the same properties as in WebNLG, and the same property configurations in the outputs.
import pickle
import json
from WikipediaPage_Generator.code.utils import create_xml, clear_folder

from collections import Counter

input_contents_source = 'GitHub_GREC_NEs'#@param['GitHub_GREC_NEs', 'Made_with_this_notebook']
dico_input_entities = 'GitHub_GREC_NEs'#@param['GitHub_GREC_NEs', 'Made_with_this_notebook']
# NOTE: the first option is deliberately not named `dico_input_contents`, which holds the dict built by
# the "Get DBpedia properties" cell; reusing that name would overwrite the dict with this string.

# WebNLG category label -> GREC category label
dico_mapping_categories = {'Person':'People', 'City':'Cities'}

# The dico has the following form: { category: { triple_config: [entity1, entity2, etc.] } }
if dico_input_entities == 'Made_with_this_notebook':
  entities_json_path = '/content/dico_entities_for_triple_configuration.json'
elif dico_input_entities == 'GitHub_GREC_NEs':
  entities_json_path = '/content/Build_KGs_entities/resources/dico_entities_for_triple_configuration_GREC_NEs.json'
else:
  raise ValueError(f'Unknown dico_input_entities: {dico_input_entities!r}')
with open(entities_json_path, 'r') as handle:
  dico_entities_for_triple_configuration_l = json.load(handle)

# dico_input_contents_loaded should be a dico_1 with category names as keys, and a dico_2 as value.
# dico_2 should have entity_name as key, and a list of Triple Objects as value.
# Triple Objects have 3 attributes: DBsubj, DBprop, DBobj
# NOTE: pickle.load can execute arbitrary code; only load pickles from this notebook or the cloned repo.
if input_contents_source == 'Made_with_this_notebook':
  contents_pickle_path = "dico_input_contents_DBp.pickle"
elif input_contents_source == 'GitHub_GREC_NEs':
  contents_pickle_path = "/content/Build_KGs_entities/resources/dico_input_contents_DBp_GREC_NEs.pickle"
else:
  raise ValueError(f'Unknown input_contents_source: {input_contents_source!r}')
with open(contents_pickle_path, "rb") as handle:
  dico_input_contents_loaded = pickle.load(handle)

counter_datapoints = 0
# Number of XMLs already written per entity, used to number the inputs of the same entity
# (an XML is named after the entity name: entity_0, entity_1, ...)
entity_counter = {}
for category_l in dico_entities_for_triple_configuration_l:
  print(category_l)
  category_folder = os.path.join(triple2predArg, category_l)
  # Prepare an empty output folder; exist_ok avoids FileExistsError if clear_folder keeps the folder itself.
  clear_folder(category_folder)
  os.makedirs(category_folder, exist_ok=True)
  category_grec = dico_mapping_categories[category_l]
  for triple_config, config_entities in dico_entities_for_triple_configuration_l[category_l].items():
    print('  ', triple_config, len(config_entities))
    property_list_l = triple_config.split('##')
    for entity_name in config_entities:
      # First use of an entity gets 0, then 1, 2, ...
      entity_counter[entity_name] = entity_counter.get(entity_name, -1) + 1
      filename = entity_name+'_'+str(entity_counter[entity_name])
      # Take one triple per property slot of the configuration (a property listed twice may get two
      # triples). Without this budget every triple of a matching property would be kept (entities often
      # have several values for one property) and the input would no longer mirror the configuration.
      remaining_slots = Counter(property_list_l)
      list_selected_triple_objects = []
      for triple_object in dico_input_contents_loaded[category_grec][entity_name]:
        if remaining_slots[triple_object.DBprop] > 0:
          list_selected_triple_objects.append(triple_object)
          remaining_slots[triple_object.DBprop] -= 1
      # create_xml expects the indices of the selected triples; here all selected triples are used.
      properties_selected = list(range(len(list_selected_triple_objects)))
      counter_datapoints += 1
      # create xml file passing the entity name to use as filename
      create_xml(list_selected_triple_objects, properties_selected, category_l, category_folder, entity_name=filename, eid = counter_datapoints)

print(f'----------\n{counter_datapoints} new datapoints were created in total.')

## Build knowledge graphs #2: WebNLG mirror input size distribution only

In [None]:
#@title Pseudo-code

# Initialize empty dictionary dico_length_ratio_entity

# For each entity in the dataset:
#     Initialize two lists:
#         subject_triples = []  // Triples where the entity is the subject
#         object_triples = []   // Triples where the entity is the object

#     For each triple related to the entity:
#         If entity is the subject:
#             Add triple to subject_triples
#         Else if entity is the object:
#             Add triple to object_triples

#     // Create property-based dictionaries for subjects and objects
#     subject_property_dict = Group subject_triples by property
#     object_property_dict = Group object_triples by property

#     For each desired input length (e.g., 1 to N triples, in our case N=7):
#         For each possible subject/object ratio (e.g., 1:2, 2:1, etc.):
#             // Ensure at least one triple has the entity as subject (Constraint 2)
#             subject_count = number of triples to select as subject
#             object_count = input_length - subject_count

#             possible_subject_triples = []
#             possible_object_triples = []

#             For subject_properties in subject_property_dict:
#                 Randomly select up to 2 triples per property (Constraint 3)
#                 Add to possible_subject_triples

#             For object_properties in object_property_dict:
#                 Randomly select up to 2 triples per property (Constraint 3)
#                 Add to possible_object_triples

#             selected_triples = []
#             Randomly select subject_count triples from possible_subject_triples
#             Add to selected_triples
#             Randomly select object_count triples from possible_object_triples
#             Add to selected_triples

#             Shuffle selected_triples

#             Store selected_triples in dico_length_ratio_entity

# Sample randomly from dico_length_ratio_entity using WebNLG's input length distribution

In [None]:
#@title Save triple sets in XML format: Implementation of the algorithm for the search+triples sampling
import pickle
import random
import json
from WikipediaPage_Generator.code.utils import create_xml, clear_folder

# Colab form parameters: keep the "#@param" annotations unchanged, they render the form.
# Which pickled triples dictionary to load (see the loading code below).
dico_input_contents = 'GitHub_GREC_NEs'#@param['GitHub_GREC_NEs', 'Made_with_this_notebook']
# Whether to shuffle the triples of each selected triple set (spelling kept: the name is read later in this cell).
suffle_selected_triples = True #@param{type:'boolean'}
# Number of triple sets to sample for the final dataset.
final_dataset_size = 20 #@param{type:'integer'}
# Seed for Python's `random` module, which drives all the sampling in this cell.
seed = 785 #@param{type:'integer'}
# Triple sets of every length from 1 to MAX_TRIPLES_SET_LENGTH are generated.
MAX_TRIPLES_SET_LENGTH = 7
# At most this many triples per property are drawn among the triples with the entity as subject,
# and separately among the triples with the entity as object.
MAX_TRIPLES_PER_PROPERTY = 2
# Print detailed progress information.
DEBUG = False#@param{type:'boolean'}

random.seed(seed)

def group_triples_by_property(triples):
  """Bucket triples by their DBpedia property.

  Returns a dict mapping each `DBprop` value to the list of triples using it.
  Keys appear in first-seen order; triples keep their input order in each bucket.
  """
  grouped = {}
  for tri in triples:
    grouped.setdefault(tri.DBprop, []).append(tri)
  return grouped

# Load the pickled input data: a dict {category_name: {entity_name: [Triple objects]}},
# where each Triple object exposes DBsubj, DBprop and DBobj.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted pickle files.
dico_input_contents_loaded = None
if dico_input_contents == 'Made_with_this_notebook':
  with open("dico_input_contents_DBp.pickle", "rb") as handle:
    dico_input_contents_loaded = pickle.load(handle)
elif dico_input_contents == 'GitHub_GREC_NEs':
  with open("/content/Build_KGs_entities/resources/dico_input_contents_DBp_GREC_NEs.pickle", "rb") as handle:
    dico_input_contents_loaded = pickle.load(handle)
else:
  # Fail clearly here instead of with a confusing error on the None value further down.
  raise ValueError(f'Unknown dico_input_contents option: {dico_input_contents}')

# Count entities over all categories, so data with other (or missing) categories
# than 'Cities' and 'People' does not raise a KeyError.
num_entities = sum(len(entities) for entities in dico_input_contents_loaded.values())
print(f'Triples for {num_entities} entities were found.')

def _draw_capped_per_property(property_dict, cap):
  """Pool up to `cap` randomly drawn triples from every property bucket.

  Buckets are visited in dict order with one `random.sample` call each, so for a
  given seed the random draws always happen in the same sequence.
  """
  pooled = []
  for bucket in property_dict.values():
    pooled += random.sample(bucket, min(cap, len(bucket)))
  return pooled


def _triples_as_text(triple_list):
  """Format triples as 'subj prop obj' strings (debug output only)."""
  return [f"{o.DBsubj} {o.DBprop} {o.DBobj}" for o in triple_list]


# Build one candidate triple set per (entity, triple-set length, subject count) combination.
all_triples_sets = []
for category_name, entities in dico_input_contents_loaded.items():
  for entity_name, triples in entities.items():
    if DEBUG:
      print('\n')
      print(f'Entity name: {entity_name}')
      print('', f'# Triples available: {len(triples)}')
    # Split the entity's triples according to the role the entity plays in them.
    as_subject = [t for t in triples if t.DBsubj == entity_name]
    as_object = [t for t in triples if t.DBobj == entity_name]
    if DEBUG:
      print('  ', f'# Triples with entity as subject: {len(as_subject)}')
      print('    ', _triples_as_text(as_subject[:10]))
      print('  ', f'# Triples with entity as object: {len(as_object)}')
      print('    ', _triples_as_text(as_object[:10]))

    # Buckets per property, so that at most MAX_TRIPLES_PER_PROPERTY triples of the
    # same property can be drawn on each side.
    subj_buckets = group_triples_by_property(as_subject)
    obj_buckets = group_triples_by_property(as_object)
    if DEBUG:
      print('  ', f'# Unique properties entity as subj: {len(subj_buckets)}')
      print('  ', f'# Unique properties entity as obj: {len(obj_buckets)}')

    for set_len in range(1, MAX_TRIPLES_SET_LENGTH + 1):
      if DEBUG:
        print("Triple set length:", set_len)
      # n_subj starts at 1: every triple set contains at least one triple with the
      # entity as subject; the remaining triples have the entity as object.
      for n_subj in range(1, set_len + 1):
        n_obj = set_len - n_subj
        if DEBUG:
          print(f"  Subject count: {n_subj}, Object count: {n_obj}")

        # Fresh capped pools are drawn for every subject/object ratio.
        subj_pool = _draw_capped_per_property(subj_buckets, MAX_TRIPLES_PER_PROPERTY)
        obj_pool = _draw_capped_per_property(obj_buckets, MAX_TRIPLES_PER_PROPERTY)

        if len(subj_pool) < n_subj or len(obj_pool) < n_obj:
          if DEBUG:
            print('   ', 'Cannot make an input with the subj/obj ratio.')
          continue

        # Subject-side triples are drawn before object-side ones (left-to-right evaluation).
        chosen = random.sample(subj_pool, n_subj) + random.sample(obj_pool, n_obj)
        if suffle_selected_triples:
          random.shuffle(chosen)
        all_triples_sets.append({
            'triples': chosen,
            'subj_count': n_subj,
            'obj_count': n_obj,
            'triples_len': set_len,
            'entity_name': entity_name,
            'category_name': category_name
        })
        if DEBUG:
          print('   ', _triples_as_text(chosen))

# Sample all_triples_sets so that triple-set sizes follow the WebNLG size distribution.
# TODO replace with automatic extraction of real WebNLG distribution
# IMPORTANT: Express distribution in % (keys are triple-set sizes).
distr = {1: 20.8, 2: 19.6, 3: 19.6, 4: 17.2, 5: 12, 6: 6.4, 7: 4.4}
total_prob = sum(distr.values())
# Compare with a tolerance: a float sum of decimal percentages can land just below 100,
# and int() would truncate it to 99.
assert abs(total_prob - 100) < 1e-6, "Total probability should be 100%"

# Exact (float) number of triple sets to draw for each size.
num_to_sample_list = [final_dataset_size * value / 100 for value in distr.values()]
# Largest-remainder apportionment: floor every share, then give the samples still missing
# to the sizes with the largest fractional parts (ties: smallest size first). Unlike
# rounding followed by a single +1 correction, this always sums to final_dataset_size,
# also for small datasets where rounding is off by two or more.
rounded_num_to_sample_list = [int(x) for x in num_to_sample_list]
missing = final_dataset_size - sum(rounded_num_to_sample_list)
by_remainder = sorted(range(len(num_to_sample_list)),
                      key=lambda idx: num_to_sample_list[idx] - rounded_num_to_sample_list[idx],
                      reverse=True)
for position in by_remainder[:missing]:
  rounded_num_to_sample_list[position] += 1
if DEBUG:
  print(num_to_sample_list)
  print(rounded_num_to_sample_list)
  print(f'Positions of numbers increased by one: {sorted(by_remainder[:missing])}')
assert sum(rounded_num_to_sample_list) == final_dataset_size, 'Total number of samples does not match final_dataset_size.'

# For each triple-set size, draw the number of triple sets computed above. Any triple set of
# the right size can be picked, so one entity may contribute several triple sets.
sampled_triple_sets = []
for triples_len, num_to_sample in zip(distr, rounded_num_to_sample_list):
  triples_of_len = [tri for tri in all_triples_sets if tri['triples_len'] == triples_len]
  if num_to_sample <= len(triples_of_len):
    selected = random.sample(triples_of_len, num_to_sample)
  else:
    # Not enough candidates: this size is skipped, so the dataset ends up smaller than requested.
    print(f'!!! Could not select triples sets of size {triples_len} (not enough triple sets).')
    selected = []
  sampled_triple_sets.extend(selected)
  print(f'Length: {triples_len}')
  print(f'  # Total triple sets of current size: {len(triples_of_len)}')
  # Report what was actually selected (the requested number may not have been available).
  print(f'  # Selected triple sets of current size: {len(selected)}/{num_to_sample}')
  print(f'  # Total Selected triples at this point: {len(sampled_triple_sets)}')

print(f'# Selected triple sets: {len(sampled_triple_sets)}')
print(sampled_triple_sets[:10])

## Number of triples for each category
# print(f'# Triple sets People: {len([tri for tri in sampled_triples if tri["category_name"]=="People"])}')
# print(f'# Triple sets Cities:{len([tri for tri in sampled_triples if tri["category_name"]=="Cities"])}')
## Total number of triple sets
# print(f'# Triple sets before sampling: {len(all_triples_sets)}')
# print(all_triples_sets[:10])

# Save each sampled triple set as an individual XML file, in one sub-folder per category.
counter_datapoints = 0
entity_counter = {}
folder_name = ''
clear_folder(triple2predArg)
# Always create the 'People' and 'Cities' folders (as before), plus one per other category
# present in the sample, since create_xml writes into these folders (presumably it does not
# create them itself - TODO confirm). exist_ok avoids a crash if clear_folder keeps sub-folders.
categories_in_sample = {s['category_name'] for s in sampled_triple_sets}
for category_folder in ['People', 'Cities'] + sorted(categories_in_sample - {'People', 'Cities'}):
  os.makedirs(os.path.join(triple2predArg, category_folder), exist_ok=True)
for sampled_triple_set in sampled_triple_sets:
  counter_datapoints += 1
  list_triple_objects = sampled_triple_set['triples']
  # Indices of the triples to put in the XML: all the triples of the set.
  properties_selected = list(range(len(list_triple_objects)))
  input_category = sampled_triple_set['category_name']
  folder_name = input_category
  entity_name = sampled_triple_set['entity_name']
  # Per-entity index (0 for the first triple set of an entity, then 1, 2, ...) appended to the
  # filename so that several triple sets of the same entity get distinct names.
  entity_counter[entity_name] = entity_counter.get(entity_name, -1) + 1
  filename = entity_name+'_'+str(entity_counter[entity_name])
  # create xml file passing the entity name to use as filename
  list_triples_text = create_xml(list_triple_objects, properties_selected, input_category, os.path.join(triple2predArg, folder_name), entity_name=filename, eid = counter_datapoints)

# print(sampled_triple_sets[:10])
# TODO save created dataset preprocessing the triples (we cannot save the object as it is)
# for i, tri in enumerate(sampled_triple_sets):
#   tri['triples'] = [f"{o.DBsubj} | {o.DBprop} | {o.DBobj}" for o in tri['triples']]
# with open('/content/sampled_triples.json', 'w') as f:
#   json.dump(sampled_triple_sets, f)

In [None]:
#@title Check triples extracted from DBpedia
import pickle
import matplotlib.pyplot as plt

# dico_input_contents_loaded maps each category name to a second dict, which maps
# entity_name to the list of Triple objects retrieved for that entity.
# Each Triple object has 3 attributes: DBsubj, DBprop, DBobj.
_pickle_path_by_option = {
  'Made_with_this_notebook': "dico_input_contents_DBp.pickle",
  'GitHub_GREC_NEs': "/content/Build_KGs_entities/resources/dico_input_contents_DBp_GREC_NEs.pickle",
}
dico_input_contents_loaded = None
if dico_input_contents in _pickle_path_by_option:
  with open(_pickle_path_by_option[dico_input_contents], "rb") as handle:
    dico_input_contents_loaded = pickle.load(handle)

# Per-entity statistics over the retrieved triples, all keyed by entity_name:
#   dico_properties:           total number of triples
#   dico_different_properties: number of distinct properties
#   dico_entity_as_subj:       number of triples with the entity as subject
#   dico_entity_as_obj:        number of triples with the entity as object
dico_properties = {}
dico_different_properties = {}
dico_entity_as_subj = {}
dico_entity_as_obj = {}
for entities in dico_input_contents_loaded.values():
  for entity_name, triple_objects in entities.items():
    dico_properties[entity_name] = len(triple_objects)
    # A set gives the distinct-property count in O(n) instead of O(n^2) list lookups.
    dico_different_properties[entity_name] = len({t.DBprop for t in triple_objects})
    dico_entity_as_subj[entity_name] = sum(1 for t in triple_objects if t.DBsubj == entity_name)
    dico_entity_as_obj[entity_name] = sum(1 for t in triple_objects if t.DBobj == entity_name)


def _sorted_by_count_desc(counts):
  """Return a copy of `counts` ordered by value, largest first (ties keep insertion order)."""
  return dict(sorted(counts.items(), key=lambda item: item[1], reverse=True))


dico_properties_sorted = _sorted_by_count_desc(dico_properties)
dico_different_properties_sorted = _sorted_by_count_desc(dico_different_properties)
dico_entity_as_subj_sorted = _sorted_by_count_desc(dico_entity_as_subj)
dico_entity_as_obj_sorted = _sorted_by_count_desc(dico_entity_as_obj)

# One bar chart per statistic, entities on the x axis (largest value first).
for stats, y_label in [
    (dico_properties_sorted, 'Number of Properties'),
    (dico_different_properties_sorted, 'Number of Different Properties'),
    (dico_entity_as_subj_sorted, 'Number of Properties with Entity as Subject'),
    (dico_entity_as_obj_sorted, 'Number of Properties with Entity as Object'),
]:
  plt.figure(figsize=(10, 6))
  plt.bar(stats.keys(), stats.values())
  plt.xlabel('Entity Name')
  plt.ylabel(y_label)
  plt.show()

In [None]:
#@title Check dataset
from colored import Fore, Back, Style

# Sanity checks and summary statistics over the sampled dataset (sampled_triple_sets).
sizes = {}
property_count_per_datapoint = {}
subjObj_ratios = {}
category_count = {}
for i, striple_set in enumerate(sampled_triple_sets):
  print(f'Datapoint {i+1}:')
  # Check that there is no error with the triple set size
  assert striple_set['subj_count'] + striple_set['obj_count'] == striple_set['triples_len'], 'subj_count + obj_count should be equal to triples_len.'
  print(f'{Fore.green}{Back.white}  subj_count and obj_count match triples_len{Style.reset}')
  assert striple_set['triples_len'] == len(striple_set['triples']), 'triples_len should be equal to the number of triples in the triple set.'
  print(f'{Fore.green}{Back.white}  triples_len and number of triples match{Style.reset}')

  # Count property occurrences for this datapoint, overall and per role of the entity.
  prop_counts = property_count_per_datapoint.setdefault(str(i), {})
  role_prop_counts = {}
  for triple in striple_set['triples']:
    prop_counts[triple.DBprop] = prop_counts.get(triple.DBprop, 0) + 1
    role = 'subj' if triple.DBsubj == striple_set['entity_name'] else 'obj'
    role_prop_counts[(role, triple.DBprop)] = role_prop_counts.get((role, triple.DBprop), 0) + 1
  # The sampler draws at most MAX_TRIPLES_PER_PROPERTY triples per property among the triples
  # with the entity as subject, and separately among those with the entity as object.
  # Check that bound for EVERY property of the set (the former check ran after the loop
  # and therefore only looked at the property of the last triple).
  assert max(role_prop_counts.values(), default=0) <= MAX_TRIPLES_PER_PROPERTY, f'No more than {MAX_TRIPLES_PER_PROPERTY} instances of a property per triple set (for the same subject/object role of the entity).'
  print(f'{Fore.green}{Back.white}  No more than {MAX_TRIPLES_PER_PROPERTY} instances of the same property in the triple set (per entity role){Style.reset}')

  # Collect all triple sizes
  sizes[striple_set['triples_len']] = sizes.get(striple_set['triples_len'], 0) + 1

  # Collect how many triple sets each entity has, per category
  entity_counts = category_count.setdefault(striple_set['category_name'], {})
  entity_counts[striple_set['entity_name']] = entity_counts.get(striple_set['entity_name'], 0) + 1

  # Collect configuration count of triples (i.e. for each size, what is the subj/obj ratio count)
  ratio_counts = subjObj_ratios.setdefault(striple_set['triples_len'], {})
  subj_obj_ratio = str(striple_set['subj_count'])+'/'+str(striple_set['obj_count'])
  ratio_counts[subj_obj_ratio] = ratio_counts.get(subj_obj_ratio, 0) + 1

# Sort entity names by frequency, once, after all counts are known.
category_count = {cat: dict(sorted(ents.items(), key=lambda item: item[1], reverse=True)) for cat, ents in category_count.items()}

print('')
# Check number of datapoints for each size and in total
assert sum(sizes.values()) == final_dataset_size, f'Expected {final_dataset_size} datapoints, got {sum(sizes.values())}.'
print(f'{Fore.green}{Back.white}Total number of datapoints: {sum(sizes.values())}{Style.reset}')
print(f'{Fore.green}{Back.white}Triple size distribution: {sizes}{Style.reset}')
print('')
print(f'Subject/Object ratios:    {subjObj_ratios}')
# A category absent from the sample is reported as empty instead of raising a KeyError.
people_count = category_count.get('People', {})
cities_count = category_count.get('Cities', {})
print(f"Category count:           'People': {sum(people_count.values())} ({len(people_count)} distinct), 'Cities': {sum(cities_count.values())} ({len(cities_count)} distinct)")
print(f'People instances:         {people_count}')
print(f'Cities instances:         {cities_count}')

## Build knowledge graphs #3: free distribution

In [None]:
#@title Test function
# def get_first_n_instances_of_props(list_triple_objects, max_num_of_instances_of_prop_desired, properties_that_can_happen_once_only):
#   """
#   Function that selects the first n occurrences of a triple that contains the same property.
#   Input list_triple_objects: a list of triple objects, each triple should have 3 attributes: DBsubj, DBprop, DBobj.
#   Input max_num_of_instances_of_prop_desired: an integer that specifies the maximum number of occurrences of each property in the triple set.
#   Input properties_that_can_happen_once_only: a list of property labels that cannot have 2 or more values for the same subject, even if the queried resource has more than one.
#   Output: a list of indices into list_triple_objects, e.g. [0, 1, 2, 3, 4, 5, 6, 8, 9].
#   """
#   # Dico to keep track of properties already added for each subject
#   dico_dbSubj_properties = {}
#   # Dico to keep track of how many times each property is found in the triple set
#   dico_num_instances_of_prop_found = {}

#   selected_properties = []
#   for i, triple_object in enumerate(list_triple_objects):
#     # Add property to the list of properties in the triple set and initialise count
#     if triple_object.DBprop not in dico_num_instances_of_prop_found.keys():
#       dico_num_instances_of_prop_found[triple_object.DBprop] = 1
#       selected_properties.append(i)
#     # For the second and more instance of a property, increase counter and if the resulting number is below that of the maximum number of instances of each property specified in the input, then add the id to the selected list
#     else:
#       dico_num_instances_of_prop_found[triple_object.DBprop] += 1
#       if dico_num_instances_of_prop_found[triple_object.DBprop] <= max_num_of_instances_of_prop_desired:
#         # Only add a 2nd or 3rd property if (i) that property is not in the list of props that can happen only once, or (ii) if it is in that list but the subject doesn't already have that property in the triple set
#         if (triple_object.DBprop not in properties_that_can_happen_once_only) or (triple_object.DBsubj not in dico_dbSubj_properties.keys()) or (triple_object.DBprop not in dico_dbSubj_properties[triple_object.DBsubj]):
#           selected_properties.append(i)

#     # Now that we processed a triple, fill up dico_dbSubj_properties
#     # For the first instance of a property, create dico entry and add id of triple_object to the list of selected properties
#     if triple_object.DBsubj not in dico_dbSubj_properties.keys():
#       dico_dbSubj_properties[triple_object.DBsubj] = []
#     # Add property to the list of properties used for a subject
#     if triple_object.DBprop not in dico_dbSubj_properties[triple_object.DBsubj]:
#       dico_dbSubj_properties[triple_object.DBsubj].append(triple_object.DBprop)
#   # print(dico_dbSubj_properties)
#   return selected_properties

In [None]:
#@title Save triple set in XML format: Large DBpedia dataset
from WikipediaPage_Generator.code.utils import get_first_n_instances_of_props, create_xml, clear_folder
# from WikipediaPage_Generator.code.utils import create_xml, clear_folder
import json
import pickle

# Used for long-input D2T experiments: C:\Users\sfmil\Desktop\ADAPT-2025-2026\MyPapers\2025-06_INLG-longText-D2T\files\triple_sets_full\dico_input_contents_DBp.pickle
dico_input_contents = 'Made_with_this_notebook'#@param['GitHub_GREC_NEs', 'Made_with_this_notebook']
# Map some categories that are found in WebNLG to the name used in WebNLG
dico_mapping_categories = {'People':'Person', 'Cities':'City'}

# Specifies the maximum number of occurrences of each property in the triple set
max_num_of_instances_of_prop_desired = "3" #@param[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Specifies the minimum and maximum size of input desired
min_input_size = "8" #@param[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 25, 30, 35, 40, 45, 50]
max_input_size = "100" #@param[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100, 150, 200, 250, 300, 500, 1000]
# List here properties that cannot have 2 values (to filter bad stuff from i.e. infobox)
# properties_that_can_happen_once_only = ['budget', 'gross', 'imdbId', 'length', 'runtime']
# Read with a context manager so the file handle is closed (it was previously left open).
with open('/content/WikipediaPage_Generator/resources/list_props_that_can_happen_once_only.json', 'r', encoding='utf-8') as props_file:
  properties_that_can_happen_once_only = json.load(props_file)
entities_for_dev = ['Abu_Dhabi', 'Accra', 'Islamabad', 'Lagos', 'Aneurin_Bevan', 'Anthony_Giddens', 'Antoine_Lavoisier', 'Antonio_Negri']

# The loaded data maps each category name to a second dict, which maps entity_name
# to the list of Triple objects retrieved for that entity.
# Each Triple object has 3 attributes: DBsubj, DBprop, DBobj.
counter_datapoints = 0
dico_input_contents_loaded = None
_pickle_path = {
  'Made_with_this_notebook': "dico_input_contents_DBp.pickle",
  'GitHub_GREC_NEs': "/content/Build_KGs_entities/resources/dico_input_contents_DBp_GREC_NEs.pickle",
}.get(dico_input_contents)
if _pickle_path is not None:
  with open(_pickle_path, "rb") as handle:
    dico_input_contents_loaded = pickle.load(handle)

# Write one XML datapoint per (non-dev) entity that has at least min_input_size candidate
# triples, keeping at most max_input_size of them.
total_num_selected_props = 0
longest_input_size = 0
for input_category, entities in dico_input_contents_loaded.items():
  # Categories without a WebNLG-specific name keep their own name (instead of raising KeyError).
  webnlg_input_category = dico_mapping_categories.get(input_category, input_category)
  print(webnlg_input_category)
  folder_name = webnlg_input_category+'_'+str(max_num_of_instances_of_prop_desired)
  clear_folder(os.path.join(triple2predArg, folder_name))
  # exist_ok: do not crash if clear_folder empties the folder without removing it.
  os.makedirs(os.path.join(triple2predArg, folder_name), exist_ok=True)
  for entity_name, list_triple_objects in entities.items():
    # Comment this "if" to get v1 data; use "if entity_name in entities_for_dev" to get DEV data.
    if entity_name not in entities_for_dev:
      # Generate list of indices of all properties that can be part of an input (index in the list of Triple objects that contains all retrieved triples)
      candidate_properties = get_first_n_instances_of_props(list_triple_objects, int(max_num_of_instances_of_prop_desired), properties_that_can_happen_once_only)
      print(f'  {entity_name}: pre-selected {len(candidate_properties)}/{len(list_triple_objects)} properties.')

      if len(candidate_properties) >= int(min_input_size):
        # Only the first max_input_size candidates go into the XML.
        selected_properties = candidate_properties[:int(max_input_size)]
        # Track the size actually written (after truncation), so the summary below never
        # reports a longest input above max_input_size.
        longest_input_size = max(longest_input_size, len(selected_properties))
        print(f'    {len(selected_properties)} included in dataset')
        total_num_selected_props += len(selected_properties)
        counter_datapoints += 1
        # create xml file passing the entity name to use as filename
        list_triples_text = create_xml(list_triple_objects, selected_properties, webnlg_input_category, os.path.join(triple2predArg, folder_name), entity_name=entity_name, eid = counter_datapoints)

print(f'----------\n{counter_datapoints} new datapoints were created in total (between {int(min_input_size)} and {longest_input_size} triples in input; up to {max_num_of_instances_of_prop_desired} instances of the same property in input).')
if counter_datapoints > 0:
  print(f'Average number of triples per input: {total_num_selected_props/counter_datapoints}')
else:
  # Avoid a ZeroDivisionError when no entity passed the filters.
  print('Average number of triples per input: n/a (no datapoint was created)')

## Process output XMLs (group, etc.)
Stratified sampling distribution (number of inputs of each triple-set size, per 100 inputs):
1:20.8 - 2:19.6 - 3:19.6 - 4:17.2 - 5:12 - 6:6.4 - 7:4.4

In [None]:
#@title Put all XMLs in the same file for sampling
import glob
import codecs
import json

XML_folder = 'XML_Split'#@param['XML', 'XML_Split']

paths_folders_categories = None
out_filename = None
out_entities_list = None

############# In case code is applied to unsplit files generated with the cells above
if XML_folder == 'XML':
  paths_folders_categories = glob.glob('/content/XML/*')
  out_filename = f'/content/XML/D2T-1-FA_same{max_num_of_instances_of_prop_desired}_min{min_input_size}_max{max_input_size}.xml'
  out_entities_list = f'/content/XML/list_entities_same{max_num_of_instances_of_prop_desired}_min{min_input_size}_max{max_input_size}.json'
############# In case code is applied to split files
elif XML_folder == 'XML_Split':
  paths_folders_categories = glob.glob('/content/XML_Split/*')
  out_filename = f'/content/XML_Split/D2T-1-FA_same3_min8_max{max_num_triples}_SPLIT.xml'
  out_entities_list = f'/content/XML_Split/list_entities_same3_min8_max100_SPLIT-{max_num_triples}.json'

# Concatenate the entry content of every per-datapoint XML file into one benchmark file.
# Distinct handle names: the entity-list file used to rebind the XML handle `f` inside its own `with`.
list_entities = []
with codecs.open(out_filename, 'w', 'utf-8') as out_xml:
  out_xml.write('<?xml version="1.0" encoding="UTF-8"?>\n')
  out_xml.write('<benchmark>\n')
  out_xml.write('  <entries>\n')
  for path_folder_category in sorted(paths_folders_categories):
    for xml_path in sorted(glob.glob(os.path.join(path_folder_category, '*.xml'))):
      list_entities.append(os.path.basename(xml_path).rsplit('.', 1)[0])
      with codecs.open(xml_path, 'r', 'utf-8') as in_xml:
        for line in in_xml:
          # Drop lines starting at column 0 or 2 (presumably the per-file XML declaration and
          # <benchmark>/<entries> wrappers) and keep the deeper-indented entry content.
          if not line.startswith('<') and not line.startswith('  <'):
            out_xml.write(line)
  out_xml.write('  </entries>\n')
  out_xml.write('</benchmark>\n')

# Save the list of datapoint file names (same order as in the merged XML) as JSON.
with codecs.open(out_entities_list, 'w', 'utf-8') as out_json:
  json.dump(list_entities, out_json)
print('Created XML file!')

In [None]:
#@title Zip and download XML folder
from google.colab import files
import os

XML_folder = 'XML_Split'#@param['XML', 'XML_Split']

zip_name_xml = None
folder_to_zip = None
if XML_folder == 'XML':
  zip_name_xml = f'/content/XMLs_{dico_input_contents}_same{max_num_of_instances_of_prop_desired}_min{min_input_size}_max{max_input_size}.zip'
  folder_to_zip = '/content/XML'
elif XML_folder == 'XML_Split':
  zip_name_xml = f'v4_XMLs_GREC_NEs_Updated_same3_min8_max100_SPLIT-{max_num_triples}.zip'
  folder_to_zip = '/content/XML_Split'

if os.path.exists(zip_name_xml):
  os.remove(zip_name_xml)

! zip -r {zip_name_xml} {folder_to_zip}
files.download(zip_name_xml)

In [None]:
#@title Split each triple set in an XML file into several smaller triple sets
from WikipediaPage_Generator.code.utils import TripleSet, Triple_withID, create_xml, clear_folder, balanced_split_with_max, extract_info_from_WebNLG_XML, sort_WebNLG_XMLs, split_XMLs
import os

##### INPUT VARIABLES
path_xml = "/content/v4_long-inputs_GREC_same3_min8_max100.xml"#@param{type:"string"}
path_DBprops_count = "/content/WikipediaPage_Generator/resources/dico_count_occurrences_dbp_props.json"#@param{type:"string"}
# Take into account a few triples can be added to a group in case boundaries between groups are changed (to keep same properties in the same group)
max_num_triples = 22#@param{type:"integer"}
path_save_XMLs = "/content/XML_Split"#@param{type:"string"}
debug = False#@param{type:"boolean"}

if os.path.exists(path_xml):
  split_XMLs(path_xml, path_DBprops_count, max_num_triples, path_save_XMLs, DEBUG = debug)
else:
  # Report a wrong/missing path instead of silently doing nothing.
  print(f'!!! Input XML file not found: {path_xml}')

In [None]:
#@title Compare Semantic accuracy pickle file and Long-input D2T pickle file
import pickle
import json

def _load_pickle(path):
  """Unpickle and return the object stored in the file at `path`."""
  with open(path, "rb") as handle:
    return pickle.load(handle)

# Semantic-accuracy (GREC NEs) data shipped in the Build_KGs_entities repository.
dico_input_contents_SemEx_loaded = _load_pickle("/content/Build_KGs_entities/resources/dico_input_contents_DBp_GREC_NEs.pickle")
# Long-input data, i.e. the pickle created with this notebook.
dico_input_contents_LongIn_loaded = _load_pickle("/content/dico_input_contents_DBp.pickle")

def showContentsDico(dico_input_contents_loaded):
  """Print a per-category overview of the triples and return them as plain strings.

  Prints each category name followed by every entity and its number of triples.
  Returns {category: {entity_name: ["subj prop obj", ...]}}, a JSON-serialisable copy
  of the input (the Triple objects themselves cannot be dumped as they are).
  """
  dico_contents = {}
  for category, entities in dico_input_contents_loaded.items():
    print(category)
    per_entity = dico_contents.setdefault(category, {})
    for entity_name, triple_objects in entities.items():
      print(f'  {entity_name} ({len(triple_objects)} properties)')
      per_entity.setdefault(entity_name, []).extend(
          f'{t.DBsubj} {t.DBprop} {t.DBobj}' for t in triple_objects)
  return dico_contents

dico_SemEx = showContentsDico(dico_input_contents_SemEx_loaded)
dico_LongIn = showContentsDico(dico_input_contents_LongIn_loaded)

# Save dicos as jsons
for json_path, contents in (('dico_input_contents_DBp_GREC_NEs_SemEx.json', dico_SemEx),
                            ('dico_input_contents_DBp_GREC_NEs_LongIn.json', dico_LongIn)):
  with open(json_path, 'w') as f:
    json.dump(contents, f)

# More stuff

In [None]:
#@title Test Functions

# To split large XML inputs that can trigger memory issues in FORGe
# 1 Check what the most frequent entity is (as subject or object)
# 2 order all triples in which this entity is subject by frequency of occurrence of the property in DBpedia
# 3 order all triples in which this entity is object by frequency of occurrence of the property in DBpedia
  # If the entity is the object of a property that ends with a preposition, put it at the bottom of the list.
# 4 Repeat 2 and 3 for the next most frequent entities as subject.
# 5 Create an XML with at most n entities. Do not cut between two same properties.

##### INSTALLS
# ! pip install xmltodict
##### IMPORTS
from WikipediaPage_Generator.code.utils import create_xml, clear_folder
import xmltodict
import json
import codecs
import os
from colored import Fore, Back, Style
##### INPUT VARIABLES
# Other input used before: "/content/v3_long-inputs_GREC_same3_min8_max100.xml"
path_xml = "/content/v3_long-inputs_GREC_same3_min8_just1.xml"
path_DBprops_count = "/content/WikipediaPage_Generator/resources/dico_count_occurrences_dbp_props.json"
path_save_XMLs = "/content/XML_Split"
# Maximum number of triples per split group; a few triples can still be added to a group
# when boundaries between groups are moved (to keep same properties in the same group).
max_num_triples = 22
debug = False

##### FUNCTIONS
class TripleSet:
  """
  A triple set (one <entry> of a WebNLG-style XML) plus its metadata.
  entities_by_frequency lists every entity found as subject or object, from the most to the least frequent;
  entities with the same count keep the order in which they were first seen (subject before object, triple by triple).
  """
  def __init__(self, triples_list, category, eid, shape, shape_type):
    self.triples = triples_list
    self.category = category
    self.eid = eid
    self.size = len(triples_list)
    self.shape = shape
    self.shape_type = shape_type
    # Count every mention of an entity, as subject and as object (a triple whose subject equals its object counts twice)
    occurrences = {}
    for current_triple in self.triples:
      for mentioned_entity in (current_triple.DBsubj, current_triple.DBobj):
        occurrences[mentioned_entity] = occurrences.get(mentioned_entity, 0) + 1
    # sorted() is stable, also with reverse=True, so ties keep their first-seen order; element 0 is the "main" entity
    self.entities_by_frequency = sorted(occurrences, key=occurrences.get, reverse=True)

class Triple_withID:
  """
  One subject | property | object triple, plus triple_id, its position within its triple set.
  Note the argument order: property first, then subject, then object.
  """
  def __init__(self, prop, subj_value, obj_value, triple_id):
    self.DBprop, self.DBsubj, self.DBobj = prop, subj_value, obj_value
    self.id = triple_id

def balanced_split_with_max(N1, N2):
  """
  Split the integer N1 as evenly as possible into the fewest parts that are each at most N2.
  For example, if N1 == 50 and N2 == 20, the output is [17, 17, 16].
  Larger parts come first; all parts differ by at most 1.
  Raises AssertionError if N1 < N2 or N2 < 1.
  """
  assert N1 >= N2, "N1 must be greater than or equal to N2"
  assert N2 >= 1, "N2 must be at least 1"
  # The fewest parts that respect the max is ceil(N1 / N2); with that many parts, the largest one,
  # ceil(N1 / num_parts), is guaranteed to be <= N2.
  num_parts = -(-N1 // N2)
  base, remainder = divmod(N1, num_parts)
  return [base + 1] * remainder + [base] * (num_parts - remainder)

def _mtriple_to_triple(mtriple, triple_id):
  """Build a Triple_withID from a WebNLG 'subject | property | object' string."""
  parts = mtriple.split(' | ')
  return Triple_withID(parts[1], parts[0], parts[2], triple_id)

def _triple_set_from_entry(entry):
  """Convert one parsed <entry> dict into a TripleSet, checking that its declared size matches its triples."""
  size = entry['@size']
  mtriples = entry['modifiedtripleset']['mtriple']
  # xmltodict returns a bare string (not a list) when an entry has a single triple;
  # it must still become a Triple_withID, otherwise TripleSet cannot read .DBsubj/.DBobj
  if not isinstance(mtriples, list):
    mtriples = [mtriples]
  # The ids are the positions in the list: sort_WebNLG_XMLs uses them as indices
  mtriples_list = [_mtriple_to_triple(mtriple, triple_id) for triple_id, mtriple in enumerate(mtriples)]
  assert int(size) == len(mtriples_list), f'Error: found size {size} but found {len(mtriples_list)} triples.'
  return TripleSet(mtriples_list, entry['@category'], entry['@eid'], entry['@shape'], entry['@shape-type'])

def extract_info_from_WebNLG_XML (path_input_XML):
  """
  path_input_XML: Path to an XML file that contains triple sets, e.g. as provided in the WebNLG shared tasks.
  Returns a list of TripleSet objects, one per <entry>, in file order. Each object has as attributes:
  triples (a list of Triple_withID objects whose id is their position in the list), category, eid, size, shape,
  shape_type and entities_by_frequency.
  Raises AssertionError if an entry's @size attribute does not match its number of modified triples.
  """
  with codecs.open(path_input_XML, 'r', 'utf-8') as file:
    XML_dict = xmltodict.parse(file.read())
  print(f'    Reading file {path_input_XML}..')
  entries = XML_dict['benchmark']['entries']['entry']
  # Same single-vs-list quirk as for triples: a file with one entry yields a dict, not a list
  if isinstance(entries, list):
    print(f"      There are {len(entries)} inputs in the original XML file.")
  else:
    print("      There is 1 input in the original XML file.")
    entries = [entries]
  triple_sets_list = [_triple_set_from_entry(entry) for entry in entries]
  total_number_of_triples = sum(triple_set.size for triple_set in triple_sets_list)
  print(f"      There are {total_number_of_triples} input triples in the original XML file.")
  return triple_sets_list

def sort_WebNLG_XMLs (path_input_XML, path_DBprops_count):
  """
  path_input_XML: Path to an XML file that contains triple sets, e.g. as provided in the WebNLG shared tasks. The code expects that all triples mention the same entity, as subject or object.
  path_DBprops_count: Path to a json file that contains DBpedia properties as keys (e.g. "http://dbpedia.org/ontology/birthPlace") and number of occurrences on DBpedia as values (e.g 1486579).
  Returns a list of TripleSet objects, one per input triple set and in the same order. In each one, the triples are re-ordered by "importance":
    - entities are processed from the most to the least frequent in the triple set (TripleSet.entities_by_frequency);
    - for each entity, its not-yet-placed triples where it is subject come first, then those where it is object;
    - within each of these two groups, triples are sorted by decreasing frequency of their property on DBpedia (ties keep their original order).
  Properties missing from the count file are treated as having 0 occurrences, so they are placed last within their group.
  """
  print('  Sorting triples sets by frequency of entity in the triple set, and by frequency of respective properties on DBpedia...')
  with codecs.open(path_DBprops_count, 'r', 'utf-8') as f:
    dico_count_occurrences_dbp_props = json.load(f)
  triple_sets_list = extract_info_from_WebNLG_XML (path_input_XML)

  def dbpedia_frequency(triple):
    # The count file is keyed by full property URIs, while the triples hold the bare property label
    return dico_count_occurrences_dbp_props.get(f'http://dbpedia.org/ontology/{triple.DBprop}', 0)

  new_triple_set_Objects_list = []
  total_number_of_triples = 0
  for triple_set in triple_sets_list:
    # Triple ids (= positions in triple_set.triples) in their new order, e.g. [0, 4, 5, 2, 3, 1]
    list_triple_indices = []
    placed_ids = set()
    for entity_name in triple_set.entities_by_frequency:
      as_subject = sorted((t for t in triple_set.triples if t.DBsubj == entity_name), key=dbpedia_frequency, reverse=True)
      as_object = sorted((t for t in triple_set.triples if t.DBobj == entity_name), key=dbpedia_frequency, reverse=True)
      for triple in as_subject + as_object:
        # A triple mentions two entities: keep only its first (most important) placement
        if triple.id not in placed_ids:
          placed_ids.add(triple.id)
          list_triple_indices.append(triple.id)

    # The create_xml function expects the triples ordered already
    new_triples_list = [triple_set.triples[i] for i in list_triple_indices]
    assert len(new_triples_list) == len(triple_set.triples), f'Expected {len(triple_set.triples)} triples, found {len(new_triples_list)}'
    total_number_of_triples += len(new_triples_list)
    new_triple_set_Objects_list.append(TripleSet(new_triples_list, triple_set.category, triple_set.eid, triple_set.shape, triple_set.shape_type))
  assert len(new_triple_set_Objects_list) == len(triple_sets_list), f'Expected {len(triple_sets_list)} triple sets, found {len(new_triple_set_Objects_list)}'
  print(f'    There are {len(new_triple_set_Objects_list)} sorted triple sets...')
  print(f'    There are {total_number_of_triples} input triples in the sorted XML file.')

  return new_triple_set_Objects_list

def _even_slice_boundaries(size, max_num_triples):
  """Return slice boundaries [0, b1, ..., size] cutting `size` triples into balanced groups of at most max_num_triples."""
  if size > max_num_triples:
    # balanced_split_with_max returns group sizes; turn them into cumulative boundaries: [10, 10, 5] -> [0, 10, 20, 25]
    groups = balanced_split_with_max(size, max_num_triples)
    return [sum(groups[:i]) for i in range(len(groups)+1)]
  return [0, size]

def _avoid_splitting_same_property(triples, even_slices, debug=False):
  """
  Move each intermediate boundary of even_slices to the left so that it does not fall between two consecutive triples
  that share the same property. A boundary never moves to or past the previous boundary, so every group stays non-empty
  and the slices stay increasing. If the run of identical properties covers the whole previous group, the boundary is kept.
  """
  new_slices = [0]
  # Exclude the first boundary (0) and the last one (no triple after it)
  for boundary in even_slices[1:-1]:
    lowest_allowed = new_slices[-1] + 1
    new_boundary = boundary
    # The last groups built by balanced_split_with_max are the smallest ones, so we move boundaries to the left
    while new_boundary > lowest_allowed and triples[new_boundary].DBprop == triples[new_boundary-1].DBprop:
      new_boundary -= 1
    if triples[new_boundary].DBprop == triples[new_boundary-1].DBprop:
      # No property change available between the previous boundary and this one: the split is unavoidable
      new_boundary = boundary
    if new_boundary < boundary and debug:
      print(f'  {Fore.red}{Back.yellow}!!! Changed split {boundary}, {new_boundary - boundary}!{Style.reset}')
    new_slices.append(new_boundary)
  new_slices.append(even_slices[-1])
  return new_slices

def split_XMLs (path_input_XML, path_DBprops_count, max_num_triples, path_save_XMLs, DEBUG = False):
  """
  path_input_XML: Path to an XML file that contains triple sets, e.g. as provided in the WebNLG shared tasks. The code expects that all triples mention the same entity, as subject or object.
  path_DBprops_count: Path to a json file that contains DBpedia properties as keys (e.g. "http://dbpedia.org/ontology/birthPlace") and number of occurrences on DBpedia as values (e.g 1486579).
  max_num_triples: the maximum number of triples desired in an XML (a group can end up a few triples larger, see _avoid_splitting_same_property)
  path_save_XMLs: the path where the output XMLs should be created
  This function creates individual XML files for each split triple set, via create_xml, in the folder
  <path_save_XMLs>/<category>_max<max_num_triples>, with the entity name "<most frequent entity>_<slice number>".
  """
  print('Splitting XML file...')
  clear_folder(path_save_XMLs)
  # exist_ok: works whether clear_folder removes the folder itself or only its contents
  os.makedirs(path_save_XMLs, exist_ok=True)
  # Get the list of TripleSet objects with the triples re-ordered. The object contains the following:
  # self.triples, self.category, self.eid, self.size, self.shape, self.shape_type, self.entities_by_frequency
  new_triple_set_Objects_list = sort_WebNLG_XMLs(path_input_XML, path_DBprops_count)
  total_number_of_XMLs = 0
  total_number_of_triples = 0
  for new_triple_set in new_triple_set_Objects_list:
    if DEBUG:
      print(new_triple_set.size, new_triple_set.entities_by_frequency[0])
    # Get "ideal" triple set split (see balanced_split_with_max function)
    even_slices = _even_slice_boundaries(new_triple_set.size, max_num_triples)
    if DEBUG:
      print(f'  Before: {even_slices}')
    # Avoid splitting between two occurrences of the same property
    new_slices = _avoid_splitting_same_property(new_triple_set.triples, even_slices, DEBUG)
    if DEBUG:
      print(f'  After: {new_slices}')

    # Set parameters for calling the function that creates XMLs
    input_category = new_triple_set.category
    folder_name = input_category+'_max'+str(max_num_triples)
    entity_name = new_triple_set.entities_by_frequency[0]
    eid = new_triple_set.eid
    output_folder = os.path.join(path_save_XMLs, folder_name)
    os.makedirs(output_folder, exist_ok=True)

    # For each slice of the triple set, create an XML file
    slice_bounds = list(zip(new_slices[:-1], new_slices[1:]))
    for count_files, (start, end) in enumerate(slice_bounds):
      list_triple_objects = new_triple_set.triples[start:end]
      properties_selected_by_user = list(range(len(list_triple_objects))) # Use all properties
      unique_entity_name = entity_name+'_'+str(count_files)
      create_xml(list_triple_objects, properties_selected_by_user, input_category, output_folder, entity_name=unique_entity_name, eid = eid)
      total_number_of_triples += len(list_triple_objects)
    total_number_of_XMLs += len(slice_bounds)

  print(f'  Created {total_number_of_XMLs} split XML files of approximate size {max_num_triples}.')
  print(f'  There are {total_number_of_triples} input triples in the split XML files.')

# Run the split on the input file and parameters set in the INPUT VARIABLES section of this cell
split_XMLs(path_xml, path_DBprops_count, max_num_triples, path_save_XMLs, DEBUG = debug)

In [None]:
#@title Regroup FORGe outputs for split long-input inputs.
# Above, some XMLs had to be split because FORGe can fail on very large inputs; here we rebuild one output file aligned with the original (unsplit) input file.
import codecs
import os
import glob
import json

# 1 - Create a folder FORGe and upload outputs there, e.g. all_EN_dev_out_aligned.txt or v4_long-inputs_GREC_same3_min8_max100_SPLIT-22_en_000-299__SMorphText.conll_out.txt
# 2 - Upload file list_entities_same3_min8_max100_SPLIT-22.json created at the same time as the split XML (and found in the same folder).

forge_out_files = glob.glob('/content/FORGe/*.txt')
# Each entry is "<entity>_<slice number>", one per split input, in the same order as the FORGe output lines
with codecs.open('/content/list_entities_same3_min8_max100_SPLIT-22.json', 'r', 'utf-8') as f:
  list_entities_split = json.load(f)

# Put all FORGe texts in a list, reading the files in name order
all_forge_lines = []
for forge_out_file in sorted(forge_out_files):
  with codecs.open(forge_out_file, 'r', 'utf-8') as f:
    all_forge_lines += f.readlines()

# Check the alignment before writing anything, rather than crashing half-way through the output file
if len(all_forge_lines) < len(list_entities_split):
  raise ValueError(f'Found {len(all_forge_lines)} FORGe output lines but {len(list_entities_split)} split inputs: the files are not aligned.')
if len(all_forge_lines) > len(list_entities_split):
  print(f'Warning: found {len(all_forge_lines)} FORGe output lines but only {len(list_entities_split)} split inputs; the extra lines are ignored.')

# Texts of consecutive slices of the same entity are joined with a space; each new entity starts a new line
current_entity = None
count_different_entities = 0
with codecs.open('FORGe-all.txt', 'w', 'utf-8') as f:
  for line_count, entity_with_count in enumerate(list_entities_split):
    entity_name = entity_with_count.rsplit('_', 1)[0]
    print(f'{line_count} - {entity_name}')
    text = all_forge_lines[line_count].strip()
    if entity_name != current_entity:
      current_entity = entity_name
      separator = '' if line_count == 0 else '\n'
      count_different_entities += 1
    else:
      separator = ' '
    f.write(separator + text)

print(f'Found {count_different_entities} different entities.')

In [None]:
#@title Make a shorter version of the reference texts
import os
import glob
import codecs

# Language of the reference texts; used to build the input and output folder names.
lang = 'EN'#@param['EN', 'GA']
# Fraction of the sentences of each text (counted from the start) that is kept.
proportion_kept = 0.9#@param{type:"slider", min:0, max:1, step:0.01}
folder_path = f'8_{lang}'
zip_path = f"/content/{folder_path}.zip"
# 9 is made with proportion_kept = 0.9
# 10 is made with proportion_kept = 0.7
# 11 is made with proportion kept = 0.5
# NOTE(review): folder_out_num is not derived from proportion_kept; update it by hand to match the table above.
folder_out_num = '9'
folder_out = f'{folder_out_num}_{lang}'

# Create the output folder if needed
if not os.path.exists(os.path.join('/content', folder_out)):
  os.makedirs(os.path.join('/content', folder_out))

#Unzip file
# ! rm -r {folder_path}
# Only unzip if the folder is not there yet. NOTE(review): this check uses a path relative to the current
# directory while unzip extracts to /content/...; presumably the working directory is /content — confirm.
if not os.path.exists(folder_path):
  ! unzip {zip_path} -d /content/{folder_path}


# Keep the first proportion_kept of the sentences of each reference text (sentences are naively delimited by '. ')
# and write them to folder_out, renaming "[8_..." files to "[<folder_out_num>_...".
for text_file_path in glob.glob(f'/content/{folder_path}/8/*.txt'):
  filename = os.path.basename(text_file_path)
  new_filename = filename.replace('[8_', '['+folder_out_num+'_')
  print(new_filename)
  with codecs.open(text_file_path, 'r', 'utf-8') as f_in:
    sentences_list = f_in.read().strip().split('. ')
  print(f'  {len(sentences_list)} sentences found.')
  # round() rounds halves to even (e.g. 4.5 -> 4)
  num_sentences_kept = int(round(len(sentences_list)*proportion_kept, 0))
  kept_sentences = sentences_list[:num_sentences_kept]
  print(f'  {len(kept_sentences)} sentences kept')
  # Re-join with the separator used for splitting. The last piece of the original text already ends with its own
  # punctuation, so only add a final period when it is missing (avoids "..")
  shortened_text = '. '.join(kept_sentences)
  if shortened_text and not shortened_text.endswith(('.', '!', '?')):
    shortened_text += '.'
  with codecs.open(os.path.join('/content', folder_out, new_filename), 'w', 'utf-8') as f:
    f.write(shortened_text)

# zip and download folder_out (Colab-only: uses google.colab.files to send the archive to the browser)
# NOTE(review): zip is given absolute paths, so the archive entries probably keep a "content/" prefix; confirm this is intended.
! zip -r /content/{folder_out}.zip /content/{folder_out}
from google.colab import files
files.download(f'/content/{folder_out}.zip')