<a href="https://colab.research.google.com/github/mille-s/Text_Structuring/blob/main/WebNLG_TextStructuring.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Shared packages and functions
from IPython.display import clear_output

# Package for parsing xml files (WebNLG 23 and Enhanced WebNLG)
! pip install xmltodict

# Install SPARQLWrapper for making queries to DBpedia/Wikidata
! pip install SPARQLWrapper

# datasets is for loading datasets from HuggingFace (WebNLG 17, 18, 20)
! pip install datasets
from datasets import load_dataset

! pip install --upgrade gdown

! pip install dicttoxml

# Clone repos containing WebNLG processing modules and processed data
! git clone 'https://github.com/mille-s/Mod-D2T.git'
! git clone 'https://github.com/mille-s/M-FleNS_NLG-Pipeline.git'
! git clone 'https://github.com/mille-s/UD_Converter.git'
# Delete locally to avoid confusion
! rm '/content/UD_Converter/UD_Converter_release.ipynb'

clear_output()

def extractTripleElements(dataset, element):
  """Return the unique subjects, properties or objects found in the triple sets.

  Args:
    dataset: iterable of entries; entry[0] must be an iterable of triple
      strings formatted as 'subject | property | object'.
    element: which part of each triple to collect: 'subject', 'property'
      or 'object'.

  Returns:
    A list of the distinct values, in order of first appearance.

  Raises:
    ValueError: if element is not one of the three supported names.
  """
  positions = {'subject': 0, 'property': 1, 'object': 2}
  if element not in positions:
    # Fail right away instead of printing and then crashing on a bad list index
    raise ValueError('Error, the second argument of extractTripleElements must be "subject", "property" or "object".')
  n = positions[element]
  # A dict keeps insertion order and gives constant-time duplicate detection
  unique_elements = {}
  for entry in dataset:
    for input_triple in entry[0]:
      unique_elements.setdefault(input_triple.split(' | ')[n], None)
  return list(unique_elements)

In [None]:
# @title Functions for wikidata and dbpedia queries
# Function list

import requests
import csv
import re
import progressbar
from SPARQLWrapper import SPARQLWrapper, JSON

# Placeholder handed to createProgressBar by callers; the function ignores its value
bar = ''
def createProgressBar(bar, max):
  """Return a new progressbar.ProgressBar whose maximum value is max.

  The bar argument is accepted for compatibility with existing calls but is
  not used: a fresh ProgressBar is always created.
  """
  return progressbar.ProgressBar(max_value=max)

# Reserved characters that get a backslash in front of them. Parentheses are
# kept and escaped rather than stripped, and ampersands are kept as is
# (replacing them by 'and' seemed to affect results).
_DBP_CHARS_TO_ESCAPE = "/.+,&-()'"
# Quotes, semi-colons and other characters which are usually errors or hacks
_DBP_CHARS_TO_REMOVE = '";~<>'
_DBP_TRANSLATION = str.maketrans({
  **{char: '\\' + char for char in _DBP_CHARS_TO_ESCAPE},
  **{char: None for char in _DBP_CHARS_TO_REMOVE},
})

def format_entity_dbp(entity):
  """
  Used for the 2024 experiments

  Prepare an entity name for insertion in a DBpedia SPARQL query: the
  characters / . + , & - ( ) ' are prefixed with a backslash, and the
  characters " ; ~ < > are removed. Underscores and spaces are left untouched.

  Args:
    entity: the entity name (string).

  Returns:
    The escaped entity name.
  """
  # A single translation pass replaces the former chain of re.sub calls, whose
  # non-raw patterns ('\/', '\.', ...) were invalid string escape sequences.
  return entity.translate(_DBP_TRANSLATION)

def format_entity_wkd(entity):
  """
  Used for the GEM 2023-2024 data

  Simplify an entity name for a Wikidata search: keep only what precedes the
  first comma and the first '_(' sequence, turn underscores into spaces and
  drop double quotes.
  """
  before_comma = entity.partition(',')[0]
  # Only parentheses introduced by an underscore ('_(') are cut off
  base_name = before_comma.partition('_(')[0]
  return base_name.replace('_', ' ').replace('"', '')

# Ordered (pattern, class) rules: the first pattern found in the entity wins,
# so more specific rules must stay above more general ones (e.g. 'millimetres'
# and 'centimetres' above 'metres', the concentration units above everything).
_MONTH_NAMES = 'January|February|March|April|May|June|July|August|September|October|November|December'
_CLASS_REGEX_RULES = tuple((re.compile(pattern), class_label) for pattern, class_label in (
  (r'gramPerCubicCentimetres', 'concentration_gPerCubCm'),
  (r'kilogramPerCubicMetres', 'concentration_kgPerCubM'),
  (r'inhabitants per square kilometre', 'populationDensity'),
  (r'[0-9\.,]+.*square.*metre', 'area_measurement'),
  (r'bombing', 'event'),
  (r'[Uu]niversity', 'university'),
  (r'Dodge', 'car'),
  (r'^"*[0-9]{4}-[0-9]{2}-[0-9]{2}"*$', 'date'),
  (r'^"*[0-9]+\s*-*(January|Jan|February|Feb|March|Mar|April|Apr|May|June|Jun|July|Jul|August|Aug|September|Sept|October|Oct|November|Nov|December|Dec)\s*-*[0-9]+"*$', 'date'),
  (r'^"*(' + _MONTH_NAMES + r')\s*-*[0-9]{4}"*$', 'month'),
  (r'^"*(' + _MONTH_NAMES + r')"*$', 'month'),
  (r'^"*[0-9\.,]+.*(litres|cubic)', 'volume_measurement'),
  (r'^"*[0-9\.,]+\s*m"*$', 'distance_meters'),
  (r'^"*[0-9\.,]+\s*in"*$', 'distance_inches'),
  (r'^"*[0-9\.,]+\s*yd"*$', 'distance_yards'),
  (r'^"*[0-9\.,]+\s*ft"*$', 'distance_feet'),
  (r'[0-9\.,]+.*millimetres', 'distance_millimetres'),
  (r'[0-9\.,]+.*centimetres', 'distance_centimetres'),
  (r'[0-9\.,]+.*metres', 'distance_metres'),
  (r'[0-9\.,]+.*inches', 'distance_inches'),
  (r'[0-9\.,]+.*yards', 'distance_yards'),
  (r'[0-9\.,]+.*feet', 'distance_feet'),
  (r'[0-9\.,]+.*seconds', 'duration_seconds'),
  (r'[0-9\.,]+.*minutes', 'duration_minutes'),
  (r'[0-9\.,]+.*hours', 'duration_hours'),
  (r'[0-9\.,]+.*days', 'duration_days'),
  (r'[0-9\.,]+.*weeks', 'duration_weeks'),
  (r'[0-9\.,]+.*months', 'duration_months'),
  (r'[0-9\.,]+.*years', 'duration_years'),
  (r'[0-9\.,]+.* (engine|horsepower)', 'engine'),
  (r'[0-9\.,]+.*euros', 'moneyQuantity_euros'),
  (r'[0-9\.,]+.*dollars', 'moneyQuantity_dollars'),
  (r'[0-9\.,]+.*kilometrePerSeconds', 'speed_kmPerSec'),
  (r'[0-9\.,]+.*degreeCelsius', 'temperature_celsius'),
  (r'[0-9\.,]+.*kelvins', 'temperature_kelvin'),
  (r'[0-9\.,]+-speed', 'transmission'),
  (r'^"*[0-9\.,]+.*(\sg|grams)', 'weight_grams'),
  (r'^"*[0-9\.,]+.*\skg', 'weight_kilograms'),
  (r'^"*[0-9\.,]+.*tonnes', 'weight_tonnes'),
  (r'^"*[0-9\.,]+.*pounds', 'weight_pounds'),
  (r'^"*[0-9]+/[0-9]+"*$', 'fraction'),
  (r'^"*[0-9a-zA-Z]+/[0-9a-zA-Z\s\']+"*$', 'runwayName'),
  (r'^"*[0-9]{4}[-–][0-9]{4}"*$', 'issnNumber'),
  (r'^"*[0-9]+[-–][0-9]+[-–][0-9]+[-–][0-9]+[-–]*[0-9]*"*$', 'isbnNumber'),
  (r'^"*[0-9]+[-–][0-9]+[-–]*[0-9]*[-–]*[0-9]*[-–]*[0-9]*"*$', 'unknownIdentifier'),
  (r'^"*[0-9]{2} [a-zA-Z]+"*$', 'celestialBody'),
  (r'^"*[0-9]{3} [a-zA-Z]+"*$', 'celestialBody'),
  (r'^"*[0-9]+_[a-zA-Z]+"*$', 'celestialBody'),
  (r'_FC_', 'footballClub'),
  (r'(season|EPSTH|league|League|Liga|Season|Bundesliga|Eredivisie|Football_Conference|Lega_Pro|Regionalliga|Serie_A|Serie_B|Topklasse|Campeonato)', 'sportsSeason'),
  (r'[Mm]onument', 'monument'),
  (r'^"*[0-9]{4} [a-zA-Z]+', 'celestialBody'),
  # Refined in assign_classRegEx: becomes 'date_epoch' when 'JD2457600' is present
  (r'^"*[0-9-]+[stndr]*[\s\-_][^:]*[\(\)a-zA-Z\']+"*$', 'address'),
  (r'^"*[\+-]*[0-9\.,]+"*$', 'unknownQuantity'),
  (r'[0-9\.,]+, [0-9\.,]+', 'unknownQuantity_multiple'),
))

def assign_classRegEx(entity):
  """Assign a coarse class to an entity based on the surface form of its name.

  The entity is tested against an ordered list of regular expressions and the
  class of the first matching rule is returned.

  Args:
    entity: the entity string as found in a triple (may be wrapped in double
      quotes).

  Returns:
    The class name (e.g. 'date', 'distance_metres', 'university'), or an
    empty string when no rule matches.
  """
  for pattern, class_label in _CLASS_REGEX_RULES:
    if pattern.search(entity):
      # Strings shaped like addresses but containing this epoch marker are dates
      if class_label == 'address' and re.search('JD2457600', entity):
        return 'date_epoch'
      return class_label
  return ''

# Code below adapted from ChatGPT
def get_wikidata_id(entity_label, timeout=30):
  """Return the Wikidata ID of the first search result for a label.

  Args:
    entity_label: the English label to search for with the Wikidata
      'wbsearchentities' API action.
    timeout: number of seconds to wait for the API before giving up, so that
      a stalled connection cannot block a bulk run forever.

  Returns:
    The ID of the first entity returned by the search (assumed to be the most
    relevant one), or None if nothing was found or the request failed.
  """
  # Define the Wikidata API endpoint
  wikidata_api_url = "https://www.wikidata.org/w/api.php"

  # Set the parameters for the API request
  params = {
    "action": "wbsearchentities",
    "format": "json",
    "language": "en",  # You can change the language if needed
    "search": entity_label,
  }

  try:
    # Send a GET request to the Wikidata API and parse the JSON response
    response = requests.get(wikidata_api_url, params=params, timeout=timeout)
    response.raise_for_status()
    data = response.json()
  except requests.exceptions.RequestException as e:
    print("Error connecting to the Wikidata API:", e)
    return None

  # Check if any entities were found
  if "search" in data and data["search"]:
    # Get the first entity (assuming it's the most relevant)
    return data["search"][0]["id"]
  return None  # Entity not found

# Example usage:
# entity_label = '23 g'
# wikidata_id = get_wikidata_id(entity_label)
# if wikidata_id:
#   print(f"The Q-ID for {entity_label} is {wikidata_id}.")
# else:
#   print(f"No entity found for {entity_label}.")
# print(assign_classRegEx(entity_label))

def get_wikidata_id_bulk(rows, list_entities, bar):
  """Look up the Wikidata ID of every entity and append one row per entity to rows.

  Each appended row is [wikidata ID or '???', cleaned entity name, original
  entity name, regex-based class]. The rows list is modified in place and
  nothing is returned.
  """
  progress = createProgressBar(bar, len(list_entities)-1)
  for position, entity in enumerate(list_entities):
    progress.update(position)
    # Hard-coded search label for one entity; all others are cleaned generically
    if entity == 'School of Business and Social Sciences at the Aarhus University':
      clean_entity = 'Aarhus School of Business'
    else:
      clean_entity = format_entity_wkd(entity)
    wikidata_id = get_wikidata_id(clean_entity)
    # '???' marks entities for which no ID was found
    rows.append([
      wikidata_id if wikidata_id else '???',
      clean_entity,
      entity,
      assign_classRegEx(entity),
    ])

# ChatGPT prompt: Please write some Python code to get the value of an named entity's "gold:hypernym" property according to DBpedia
def get_dbpedia_hypernym(entity_name):
  """Return the first "gold:hypernym" value of a DBpedia resource, or None.

  The entity name is escaped with format_entity_dbp and its spaces are
  replaced by underscores (needed to avoid query errors on DBpedia). When the
  value found contains slashes, only the part after the last one is returned.
  """
  resource_name = format_entity_dbp(entity_name).replace(' ', '_')

  # SPARQL query asking for the hypernym(s) of the resource
  query = f"""
  SELECT ?hypernym
  WHERE {{
    dbr:{resource_name} gold:hypernym ?hypernym
  }}
    """

  endpoint = SPARQLWrapper("http://dbpedia.org/sparql")
  endpoint.setQuery(query)
  endpoint.setReturnFormat(JSON)
  response = endpoint.query().convert()

  if 'results' not in response or 'bindings' not in response['results']:
    return None
  bindings = response['results']['bindings']
  if not bindings:
    return None
  # Only the first value is used; rsplit leaves a value without '/' unchanged
  return bindings[0]['hypernym']['value'].rsplit('/', 1)[-1]

# ChatGPT prompt: Please write some Python code to get the value of an named entity's "gold:hypernym" property according to Wikidata
def get_wikidata_hypernym(entity_ID, timeout=30):
  """Return the label of the first class a Wikidata entity is an instance of.

  The query reads the `wdt:P31` property of the entity and asks the label
  service for the label of each value.

  Args:
    entity_ID: the Wikidata ID of the entity, inserted after `wd:` in the query.
    timeout: number of seconds to wait for the query service before giving
      up, so that a stalled connection cannot block a bulk run forever.

  Returns:
    The label of the first value returned, or None if there is none.
  """
  # Define the Wikidata Query Service endpoint URL
  wikidata_endpoint = "https://query.wikidata.org/sparql"

  # Define the SPARQL query
  query = f"""
  SELECT ?hypernymLabel
  WHERE {{
    wd:{entity_ID} wdt:P31 ?hypernym.
    SERVICE wikibase:label {{ bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". }}
  }}
  """

  # Set up the request headers
  headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
    'Accept': 'application/json'
  }

  # Set up the request parameters
  params = {
    'query': query,
    'format': 'json'
  }

  # Make the API request (bounded by the timeout) and parse the JSON response
  response = requests.get(wikidata_endpoint, headers=headers, params=params, timeout=timeout)
  data = response.json()

  # Extract and return the first value only
  if 'results' in data and 'bindings' in data['results']:
    bindings = data['results']['bindings']
    if bindings:
      return bindings[0]['hypernymLabel']['value']

  return None

def get_Wikidata_id_property(dbpedia_prop):
  """Return the Wikidata IDs declared equivalent to a DBpedia property.

  DBpedia is queried for the owl:equivalentProperty values of dbpedia_prop
  (a full property URI) that start with "http://www.wikidata.org/entity/".
  The last path segment of each value (e.g. "P569") is kept, without
  duplicates. If anything goes wrong, the error is printed and the IDs
  collected so far (possibly none) are returned.
  """
  query = f"""
  SELECT ?wikidataProperty
  WHERE {{
    <{dbpedia_prop}> owl:equivalentProperty ?wikidataProperty .
    FILTER(STRSTARTS(STR(?wikidataProperty), "http://www.wikidata.org/entity/"))
  }}
  """

  endpoint = SPARQLWrapper("http://dbpedia.org/sparql")
  endpoint.setQuery(query)
  endpoint.setReturnFormat(JSON)

  wikidata_ids = []
  try:
    bindings = endpoint.query().convert()["results"]["bindings"]
    property_urls = [binding["wikidataProperty"]["value"] for binding in bindings]
    for property_url in property_urls:
      # The ID is whatever follows the last slash of the URL
      property_id = property_url.rsplit('/', 1)[-1]
      if property_id not in wikidata_ids:
        wikidata_ids.append(property_id)
  except Exception as e:
    print(f"An error occurred: {e}")

  return(wikidata_ids)

# Enhanced WebNLG

## Make text planning dataset

In [None]:
#@title Load and download data
# For enhanced WebNLG
from IPython.display import clear_output
import os
import os.path
import shutil

def clear_folder(folder):
  """Delete a folder and all its contents.

  Nothing happens if folder is not an existing directory. A failure to
  delete is reported with a printed message instead of an exception.
  """
  if not os.path.isdir(folder):
    return
  try:
    shutil.rmtree(folder)
  except Exception as e:
    print('Failed to delete %s. Reason: %s' % (folder, e))

# Load the "en" configuration of the "enriched_web_nlg" dataset from HuggingFace
dataset = load_dataset("enriched_web_nlg", "en", trust_remote_code=True)

# Clone original data
!git clone https://github.com/ThiagoCF05/webnlg.git
# Location of the English v2.0 data inside the cloned repository
path_data_en = '/content/webnlg/data/v2.0/en'

# Names of the data splits; the cells below use them as sub-folder names
splits = ['train', 'dev', 'test']

# Hide the output of the commands above
clear_output()

# Problem: the HuggingFace dataset does not contain the sentence groupings...
# data = dataset['dev']

# # Accessing by brackets does not have a default but the "get" method does and the default is None.
# for c, entry in enumerate(data):
#   if c == 401:
#     print(f'Entry #{c}')
#     texts = entry.get('lex').get('text')
#     print(texts)
#     sorted_triple_sets = entry.get('lex').get('sorted_triple_sets')
#     for sorted_triple_set in sorted_triple_sets:
#       print(sorted_triple_set)
#       #extractTripleElements(sorted_triple_set)

In [None]:
#@title Clean data
# There seems to be a problem with the original data, so I need to clean it
import re
import codecs
import glob

new_path = '/content/webnlg/data/v2.1/en'
# Always start from an empty output tree, then (re)create one folder per split
if os.path.exists(new_path):
  clear_folder(new_path)
for split in splits:
  os.makedirs(os.path.join(new_path, split), exist_ok=True)

path_original_data = '/content/webnlg/data/v2.0/en'
# Lines that look suspicious: nothing but one self-closing tag without attributes
suspicious_line = re.compile(r'^\s*<[^\s]+/>\n')
# Get to the 3 subfolders for each data split
for split in splits:
  path_to_open = os.path.join(path_original_data, split)
  # Get to the (up to) 7 subfolders for each input size
  for triple_size_folder in glob.glob(os.path.join(path_to_open,'*triples')):
    # Create the matching folder in the new tree by swapping the root folder
    # (safer than a regex substitution of '2.0', in which '.' matches any character)
    new_triple_size_folder = os.path.join(new_path, os.path.relpath(triple_size_folder, path_original_data))
    os.makedirs(new_triple_size_folder, exist_ok=True)
    # Get to the input files
    for xml_file in glob.glob(os.path.join(triple_size_folder, '*.xml')):
      out_path = os.path.join(new_path, os.path.relpath(xml_file, path_original_data))
      # Read inside a with-block so that the input file is closed
      with codecs.open(xml_file, 'r', 'utf-8') as fi:
        lines = fi.readlines()
      # Create new file, leaving out the lines that look suspicious
      with codecs.open(out_path, 'w', 'utf-8') as fo:
        fo.writelines(line for line in lines if not suspicious_line.search(line))

In [None]:
#@title Save a version of the test data for running Thiago's evaluation
import glob
import codecs
import os
import re

# The test data should contain inputs of size 2 and above, ordered by triple size, and for each size ordered alphabetically by category, and for each category ordered by eid.
# I don't remember if this was used in the end or not; no trace of out_path_testFile_tcf or of Enhanced_WebNLG17 in my code...

# All the cleaned stuff came from the lex field; in the last version of the data we remove this field so no more difference between 2.0 and 2.1 now.
data_to_use = 'v2.0'#@param['v2.0', 'v2.1']

in_path_testFile_tcf = os.path.join('/content/webnlg/data', data_to_use, 'en', 'test')
out_path_testFile_tcf = os.path.join('/content/webnlg/data', data_to_use, 'en', 'test', 'Enhanced_WebNLG17-test_2-7.xml')

# Per-file XML wrapper lines (and empty lines) that must not be repeated in the merged file
lines_to_ignore = ['<?xml version="1.0" ?>', '<benchmark>', '</benchmark>', '<entries>', '</entries>', '']

with codecs.open(out_path_testFile_tcf, 'w', 'utf-8') as fo_test_tcf:
  fo_test_tcf.write('<?xml version="1.0" encoding="UTF-8" standalone="yes"?>\n<benchmark>\n  <entries>\n')
  # Keep folders sorted to maintain alignment with Thiago's files
  for triple_size_folder in sorted(glob.glob(os.path.join(in_path_testFile_tcf,'*triples'))):
    # Exclude 1 triples (not useful for sentence packaging).
    if re.search('1triple', triple_size_folder):
      continue
    # Get to the input files (keep them sorted to maintain alignment with Thiago's files)
    for xml_file_tcf in sorted(glob.glob(os.path.join(triple_size_folder, '*.xml'))):
      print(f'------------------------\n{xml_file_tcf}\n------------------------')
      # Read inside a with-block so that the input file is closed
      with codecs.open(xml_file_tcf, 'r', 'utf-8') as fi_test_tcf:
        lines_xml_file_tcf = fi_test_tcf.readlines()
      within_modified_triple_field = False
      within_lex_field = False
      for line_xml_file_tcf in lines_xml_file_tcf:
        # Track which field the current line belongs to. The closing checks come
        # after the opening ones so that a field opened and closed on the same
        # line is not considered open for the following lines.
        if '<modifiedtripleset>' in line_xml_file_tcf:
          within_modified_triple_field = True
        if '</modifiedtripleset>' in line_xml_file_tcf:
          within_modified_triple_field = False
        if '<lex ' in line_xml_file_tcf:
          within_lex_field = True
        if '</lex>' in line_xml_file_tcf:
          within_lex_field = False
        # Need to replace <otriple> by <mtriple> in <modifiedtripleset>
        if within_modified_triple_field:
          line_xml_file_tcf = line_xml_file_tcf.replace('<otriple>', '<mtriple>').replace('</otriple>', '</mtriple>')
        # Indent with two spaces instead of tabs
        line_xml_file_tcf = line_xml_file_tcf.replace('\t', '  ')
        # Write out file: skip wrapper lines and the whole content of the lex fields
        if line_xml_file_tcf.strip() in lines_to_ignore:
          continue
        if within_lex_field or '</lex>' in line_xml_file_tcf:
          continue
        fo_test_tcf.write(line_xml_file_tcf)
  fo_test_tcf.write('  </entries>\n</benchmark>')

In [None]:
#@title Extract sentence groupings
import glob
import xmltodict
import xml.etree.ElementTree as ET
import xmltodict
import json

# splits = ['dev']
# Data splits to process; each one is a sub-folder of the data folder
splits = ['dev', 'test', 'train']
# Input sizes to leave out of the text planning data (Colab form field):
# 'size1' skips the 1-triple folders, 'size1&2' the 1- and 2-triple folders
exclude_input_size = 'size1'#@param['none', 'size1', 'size1&2']

def getOriginalInputTriples(entry, split):
  """Return the triples of an entry's 'modifiedtripleset' field as a new list.

  The triples are read from 'mtriple' for the train and dev splits and from
  'otriple' for the test split (in the test data, the mtriple field was
  replaced by otriple). A single triple, which is not stored in a list, is
  wrapped in one. Any other split name yields an empty list.
  """
  if split in ('train', 'dev'):
    triple_tag = 'mtriple'
  elif split == 'test':
    triple_tag = 'otriple'
  else:
    return []
  triples = entry['modifiedtripleset'][triple_tag]
  return list(triples) if isinstance(triples, list) else [triples]

def _as_list(value):
  """Return value unchanged if it is a list, otherwise wrap it in a one-element list."""
  return value if isinstance(value, list) else [value]

def extract_groups(text_planning_data, split, data_point, dtp_count, count_input_num, xml_file, metadata):
  """Add the sentence grouping of one reference text to text_planning_data[split].

  A new data point is appended to text_planning_data[split]: one list of
  triples per sentence of data_point['sortedtripleset'], followed by the
  metadata list. Ill-formed data points (no 'sortedtripleset' or no
  'sentence' field) are reported with a printed message and add nothing.

  Args:
    text_planning_data: dictionary mapping each split to its list of data points.
    split: key of text_planning_data to extend.
    data_point: parsed content of one reference text.
    dtp_count: number of data points already stored for this split, i.e. the
      index the new data point will get.
    count_input_num: number of the input in the file (only used in messages).
    xml_file: path of the file being processed (only used in messages).
    metadata: metadata of the input, stored as last element of the data point.

  Returns:
    The updated data point count: dtp_count + 1 if a data point was added,
    dtp_count otherwise.
  """
  # To catch ill-formed data points
  if 'sortedtripleset' not in data_point:
    print(f'ERROR Data "sortedtripleset" field not found: {xml_file}, input#{count_input_num}')
    return(dtp_count)
  if 'sentence' not in data_point['sortedtripleset']:
    print(f'ERROR Data "sentence" field not found: {xml_file}, input#{count_input_num}')
    return(dtp_count)

  text_planning_data[split].append([])
  text_plan = text_planning_data[split][dtp_count]
  # There is a list of sentences only if the text has 2 or more sentences, and
  # a list of triples only if the sentence has 2 or more triples; single items
  # are wrapped so that the whole dataset has a consistent format.
  for ref_sentence in _as_list(data_point['sortedtripleset']['sentence']):
    if 'striple' in ref_sentence:
      text_plan.append(list(_as_list(ref_sentence['striple'])))
    else:
      print(f'ERROR Data "striple" field not found: {xml_file}, input#{count_input_num}')

  # Sanity checks, run before the metadata is added so that a data point
  # without any sentence is actually detected
  if len(text_plan) == 0:
    print(f'AAAAAAAAAAAAAAH empty data point (no text): {xml_file}, input#{count_input_num}')
  else:
    # Shouldn't happen given how these lists are built
    for sentence in text_plan:
      if len(sentence) == 0:
        print(f'OOOH empty sentence: {xml_file}, input#{count_input_num}')

  # add metadata
  text_plan.append(metadata)
  # Update counter after each text
  return(dtp_count + 1)

# Dictionary with one key per split; each value is a list of data points, one per reference text.
# A data point is a list of sentences (each sentence being a list of one or more triples), followed by the metadata of the input the text comes from: [split, category, eid, size, original input triples].
# E.g. text_planning_data['train'][0] = [['103_Colmore_Row | location | "Colmore Row, Birmingham, England"', '103_Colmore_Row | completionDate | 1976', '103_Colmore_Row | architect | John_Madin'], ['103_Colmore_Row | floorCount | 23'], ['train', 'Building', 'Id1', '4', [...]]]
text_planning_data = {}

path_data_en = '/content/webnlg/data/v2.1/en'
# One subfolder per data split
for split in splits:
  text_planning_data[split] = []
  path_to_open = os.path.join(path_data_en, split)
  # Counter of data points: one per reference text rather than one per input, since each text has its own text plan.
  # The data is not separated into size batches, so the counter runs over the whole split.
  dtp_count = 0
  # One subfolder per input size (sorted to maintain alignment with Thiago's files)
  for triple_size_folder in sorted(glob.glob(os.path.join(path_to_open,'*triples'))):
    # Size 1 is not useful for sentence packaging; size 2 may create noise because in many cases one will want one sentence.
    skip_size1 = exclude_input_size in ('size1', 'size1&2') and re.search('1triple', triple_size_folder)
    skip_size2 = exclude_input_size == 'size1&2' and re.search('2triples', triple_size_folder)
    if skip_size1 or skip_size2:
      continue
    # Input files (sorted to maintain alignment with Thiago's files)
    for xml_file in sorted(glob.glob(os.path.join(triple_size_folder, '*.xml'))):
      print(f'------------------------\n{xml_file}\n------------------------')
      # Input number for debugging (numbering starts at 1 in files)
      count_input_num = 1
      # Go through ElementTree to control the encoding of the string given to xmltodict
      tree = ET.parse(xml_file)
      xml_data = tree.getroot()
      xmlstr = ET.tostring(xml_data, encoding='utf-8', method='xml')
      input_file_dict = dict(xmltodict.parse(xmlstr))
      # A file with only one entry (e.g. test 5 triples comicsCharacters) has no list of entries
      entries = input_file_dict['benchmark']['entries']['entry']
      if not isinstance(entries, list):
        entries = [entries]
      for entry in entries:
        # Metadata of the input (the original input is kept as part of the metadata; ugly but easier at this point)
        metadata = [str(split), entry['@category'], entry['@eid'], entry['@size'], getOriginalInputTriples(entry, split)]
        # One data point per reference text; entry['lex'] is a list only if there are several texts
        reference_texts = entry['lex'] if isinstance(entry['lex'], list) else [entry['lex']]
        for data_point in reference_texts:
          # The function returns the updated data point count
          dtp_count = extract_groups(text_planning_data, split, data_point, dtp_count, count_input_num, xml_file, metadata)
        count_input_num += 1

clear_output()

# for dtpX in text_planning_data['test']:
#   print(dtpX)

In [None]:
#@title Debug: Print test data
# Print every data point of the test split (sentence groupings followed by the metadata)
for dtpX in text_planning_data['test']:
  print(dtpX)

In [None]:
#@title Create raw input/output pairs
from collections import Counter

# Show the first data point of the training split as an example of the raw format
print(f"Sample raw data point: {text_planning_data['train'][0]}")

def camelCaseClean(text):
  """Split a camel-cased string into space-separated words.

  A space is inserted wherever a lowercase character is followed by an
  uppercase one, and that uppercase character is lowercased
  (e.g. 'birthPlace' -> 'birth place'). Other characters are kept as is.

  Args:
    text: the string to split.

  Returns:
    The cleaned string; an empty input gives an empty string.
  """
  if not text:
    return ''
  characters = [text[0]]
  for c in text[1:]:
    # The comparison uses the character as written to the output, i.e. after lowercasing
    if characters[-1].islower() and c.isupper():
      characters.append(' ')
      characters.append(c.lower())
    else:
      characters.append(c)
  return ''.join(characters)

class TP_Datapoint:
  """One text planning data point: an input, its text plan and the original input."""

  def __init__(self, input_triple_list, text_plan, orig_input):
    # original_input: flat list of Triple objects, e.g. [<Triple>, <Triple>, <Triple>, <Triple>]
    self.original_input = orig_input
    # input: flat list of Triple objects, e.g. [<Triple>, <Triple>, <Triple>, <Triple>]
    self.input = input_triple_list
    # output: the same Triple objects grouped in one list per sentence, e.g. [[<Triple>], [<Triple>, <Triple>, <Triple>]]
    self.output = text_plan

class Triple:
  """One 'subject | property | object' triple with its metadata and class slots.

  The constructor parses the triple string, stores the metadata of the triple
  set and derives a readable property label and a property lemma. All class
  attributes (wkd/dbp/regex/merged/abstract) start as '_' and all ID
  attributes start as None.
  """

  # Common prepositions (list found on the net); used when choosing the property lemma
  _PREPOSITIONS = frozenset([
    'above', 'across', 'against', 'along', 'among', 'around', 'at', 'before',
    'behind', 'below', 'beneath', 'beside', 'between', 'by', 'down', 'from',
    'in', 'into', 'near', 'of', 'off', 'on', 'to', 'toward', 'under', 'upon',
    'with', 'within',
  ])

  def camelCaseCleanProperty(self, triple_property):
    """Return the property label with camelCase split and underscores replaced by spaces.

    A new word starts at every uppercase character that directly follows a
    lowercase one (that character is lowercased), e.g. 'birthPlace' ->
    'birth place' and 'was_a_crew_member_of' -> 'was a crew member of'.
    An empty label yields '' instead of raising IndexError.
    """
    if not triple_property:
      return ''
    word_letters = [[triple_property[0]]]
    for c in triple_property[1:]:
      if word_letters[-1][-1].islower() and c.isupper():
        word_letters.append(list(c.lower()))
      else:
        word_letters[-1].append(c)
    cleaned = ' '.join(''.join(letters) for letters in word_letters)
    # Remove underscore too
    return cleaned.replace('_', ' ')

  def __init__(self, triple, metadata):
    """Parse a triple string and initialise all attributes.

    Args:
      triple: string of the form 'subject | property | object'.
      metadata: sequence whose first four items are read as data split,
        category, eid and size of the triple set.

    Raises:
      IndexError: if the triple has fewer than three ' | '-separated parts
        or metadata has fewer than four items.
    """
    # Metadata of the triple set
    self.data_split = metadata[0]
    self.category = metadata[1]
    self.eid = metadata[2]
    self.size = metadata[3]
    # Will be used to sort the data as used by Thiago (note the order: split, category, size, eid)
    self.uniqueInputID = f'{metadata[0]}_{metadata[1]}_{metadata[3]}_{metadata[2]}'
    # Split once instead of once per element
    triple_parts = triple.split(' | ')
    self.subject_label = triple_parts[0]
    self.property_label = triple_parts[1]
    self.object_label = triple_parts[2]
    self.subject_class_wkd = '_'
    self.subject_class_dbp = '_'
    self.subject_class_regex = '_'
    self.subject_class_merged = '_'
    self.subject_class_abstract = '_'
    self.subject_id_wkd = None
    # To store the entity ID in the input (and keep track of coreference between entities)
    self.subject_id_input = None
    self.property_class = '_'
    self.property_lemma = '_'
    self.object_class_wkd = '_'
    self.object_class_dbp = '_'
    self.object_class_regex = '_'
    self.object_class_merged = '_'
    self.object_class_abstract = '_'
    self.object_id_wkd = None
    # To store the entity ID in the input (and keep track of coreference between entities)
    self.object_id_input = None
    # Split camel casing and replace underscores by spaces
    self.property_label_split = self.camelCaseCleanProperty(self.property_label)
    # Assign a lemma to the property
    label_words = self.property_label_split.split(' ')
    if len(label_words) == 1:
      # If there is only one word in the property name, just use that word
      self.property_lemma = self.property_label
    elif any(label_word in self._PREPOSITIONS for label_word in label_words):
      # If there is a preposition somewhere in the label, take the first word as the lemma (that's very simplistic but works in most cases)
      self.property_lemma = label_words[0]
    else:
      # Otherwise, take the last word
      self.property_lemma = label_words[-1]

# final_data is a dictionary with one key per split of text_planning_data.
# Each split is a list of datapoints.
# Each datapoint is a TP_Datapoint object with 3 features: input (list of Triple objects), output (list of lists of Triple objects, one list per sentence) and original_input (list of Triple objects built from the last metadata element)
# Sample input/output pair (Triple objects print as <__main__.Triple object at 0x...>):
# 	Input: [t1, t2, t3, t4, t5, t6]
# 	Output: [[t1, t2], [t3, t4, t5, t6]]

final_data = {}
# Store subjects and objects to get their hypernyms later on and query dbpedia/wikidata once per entity instead of once per entity instance in an input
all_entities_input = []
# Companion set for constant-time membership tests; the list above keeps the first-seen order
seen_entities_input = set()
removed_dtp = []
for data_split in text_planning_data:
  # Create a list for each split
  final_data[data_split] = []
  for text_data in text_planning_data[data_split]:
    # The last element in text_data is the metadata, not a sentence
    metadata = text_data[-1]
    sentences_data = text_data[:-1]
    # If there is nothing before the metadata, remove the datapoint
    # Update: these datapoints were not removed in Thiago's experiments, so do I need to include them (for my test/dev data, otherwise I have different test/dev files from references?)
    if len(sentences_data) == 0:
      removed_dtp.append(f'{data_split}_{metadata[1]}_{metadata[3]}_{metadata[2]}')
      continue
    # text_plan contains the plan (i.e. sentence groupings) of the text, input_triples the triples of all sentences
    text_plan = []
    input_triples = []
    for sentence_data in sentences_data:
      # Create a list for each sentence grouping
      sentence_triples = []
      text_plan.append(sentence_triples)
      for triple_data in sentence_data:
        # The last metadata element (original input triples) is not passed to the Triple
        triple_object = Triple(triple_data, metadata[:-1])
        # The same object is stored in its sentence and in the flat "input" list
        sentence_triples.append(triple_object)
        input_triples.append(triple_object)
        # Store all subjects and objects to get their hypernyms/IDs later on
        for entity_value in (triple_object.subject_label, triple_object.object_label):
          if entity_value not in seen_entities_input:
            seen_entities_input.add(entity_value)
            all_entities_input.append(entity_value)
    # The original input triples get their own Triple objects
    orig_input_objects = [Triple(orig_input_triple, metadata[:-1]) for orig_input_triple in metadata[-1]]
    # Append input/output pair in final data structure
    final_data[data_split].append(TP_Datapoint(input_triples, text_plan, orig_input_objects))

# Number of collected datapoints per split, in the order of the splits in final_data
dtp_count_all = [len(final_data[dt_split]) for dt_split in final_data]

print(f"Sample input/output pair:\n\tInput: {final_data['train'][0].input}\n\tOutput: {final_data['train'][0].output}\n\tOrig input: {final_data['train'][0].original_input}")

# Compare the counts with the values expected for the selected input-size filter
if exclude_input_size == 'size1&2':
  print(f'Removed {len(removed_dtp)} empty data points.\n  Expected: 14; {Counter(removed_dtp)}.')
  print(f'{str(dtp_count_all)} input/output pairs were collected (3 triples and more, expected: [1311, 2860, 10363]).')
  print(f'There are {len(all_entities_input)} different subject/object values in Enhanced WebNLG (expected: 2356).')

elif exclude_input_size == 'size1':
  print(f'Removed {len(removed_dtp)} empty data points.\n  Expected: 18; {Counter(removed_dtp)}.')
  print(f'{str(dtp_count_all)} input/output pairs were collected (2 triples and more, expected: [1725, 3834, 13726]).')
  print(f'There are {len(all_entities_input)} different subject/object values in Enhanced WebNLG (expected: 2542).')

elif exclude_input_size == 'none':
  print(f'Removed {len(removed_dtp)} empty data points.\n  Expected: 28; {Counter(removed_dtp)}.')
  print(f'{str(dtp_count_all)} input/output pairs were collected (all triples sizes, expected: [2254, 4916, 18071]).')
  print(f'There are {len(all_entities_input)} different subject/object values in Enhanced WebNLG (expected: 2730).')

# print(len(final_data['dev']))



In [None]:
#@title Debug: Print sample datapoints
# Property labels of three sample test inputs, each sample followed by a separator
for sample_index in (0, 2, 5):
  for itp in final_data['test'][sample_index].input:
    print(itp.property_label)
  print('\n')
# For the last sample, also print each original input triple in full (property, subject, object)
for itp in final_data['test'][5].original_input:
  for triple_part in (itp.property_label, itp.subject_label, itp.object_label):
    print(triple_part)

In [None]:
#@title Get class information from Wikidata and DBpedia for input pairs
get_classes_entities = 'Via saved JSON'#@param['Via saved JSON', 'Via live query']

import sys
import codecs
import json

# Store here problematic inputs to test them quickly
# all_entities_input = ['Albert_Einstein', 'Lippincott_Williams_&_Wilkins', 'Washington,_D.C.', 'Am._J._Math.', '18R/36L', '-3.3528', '30843.8 (square metres)', 'Atatürk_Monument_(İzmir)', '"52.0"(minutes)', '"A894 VA; A904 VD;"', '~500', '<http://www.ghampara.gov.lk/>', 'Edwin E. Aldrin, Jr.']
#  Extract all entities in the dataset (now done during the building of final_data dico)
# all_entities_input = []
# for split in final_data:
#   # print(split)
#   for data_point in final_data[split]:
#     for triple in data_point.input:
#       subject_value = triple.subject_label
#       if subject_value not in all_entities_input:
#         all_entities_input.append(subject_value)
#       object_value = triple.object_label
#       if object_value not in all_entities_input:
#         all_entities_input.append(object_value)

# Maps each entity label to a dictionary of class information.
# The live-query branch below writes the keys 'class_dbp', 'QID', 'class_wkd' and 'class_regex';
# presumably the saved JSON has the same keys (they are read when class info is added to the triples) - TODO confirm.
all_entities_classes_by_entity = {}

# 'Via saved JSON' loads a previously exported dictionary; 'Via live query' rebuilds it by querying DBpedia and Wikidata (takes a while, about 25min)

if get_classes_entities == 'Via saved JSON':
  # Download the saved file with gdown (the argument is a file ID); presumably this produces the JSON file opened below - TODO confirm
  ! gdown 11MYilE-SC43dBR1ux_7Yw-iCq9lNk0lO
  # Alternative file ID, currently disabled:
  # ! gdown 1bT4TsE3MjTCsyvc18b_l6z9SmkhMzqyG
  with open('webnlg_dbp-wkd-classes.json') as json_file:
    all_entities_classes_by_entity = json.load(json_file)
    # print(all_entities_classes_by_entity['American_Civil_War'])

elif get_classes_entities == 'Via live query':
  # First get the DBpedia hypernym/class; this loop creates the dictionary entry of each entity, and the Wikidata loop further down adds its info to that entry
  print(f'Querying DBpedia for hypernyms...')
  # One line per entity is also logged in dbpedia_info.txt
  with codecs.open('dbpedia_info.txt', 'w', 'utf-8') as fo2:
    for count3, entity_label in enumerate(all_entities_input):
      print(f'{count3+1}/{len(all_entities_input)}')
      # Only process each entity once
      if entity_label not in all_entities_classes_by_entity:
        hypernym_value = get_dbpedia_hypernym(entity_label)
        all_entities_classes_by_entity[entity_label] = {}
        if hypernym_value == None:
          # So all empty values for class have the same format
          hypernym_value = ''
        all_entities_classes_by_entity[entity_label]['class_dbp'] = hypernym_value
        # Log line format: '<entity>: class_dbp[<class>]'; the class part is left out when empty
        line = f'{entity_label}: '
        if not hypernym_value == '':
          line = line + f'class_dbp[{hypernym_value}]'
        line = line + '\n'
        fo2.write(line)

  print(f'\nQuerying Wikidata for QIDs...')
  # Now get Wikidata QID
  # wikidata_ids_and_labels is filled in place by get_wikidata_id_bulk with one list per entity.
  # Items are read below by index: [0] Wikidata ID ('???' means no QID), [2] original entity name, [3] basic class assigned using regex;
  # presumably [1] is the clean entity name, as in the example (which omits the regex class):
  # ex: [['Q937', 'Albert Einstein', 'Albert_Einstein'], ['Q76', 'Barack Obama', 'Barack_Obama']]
  wikidata_ids_and_labels = []
  get_wikidata_id_bulk(wikidata_ids_and_labels, all_entities_input, bar)
  # print(wikidata_ids_and_labels)
  # print(f'\nThere are {len(all_entities_input):,} different subject and object values in the dataset\n')

  # Original entity names already handled in the Wikidata loop
  entities_seen_wikidata = []
  # Now get the hypernym/class of each entity according to Wikidata
  print(f'\nQuerying Wikidata for hypernyms...')
  # One line per entity is also logged in wikidata_info.txt
  with codecs.open('wikidata_info.txt', 'w', 'utf-8') as fo:
    for count2, wikidata_id_and_labels in enumerate(wikidata_ids_and_labels):
      print(f'{count2+1}/{len(wikidata_ids_and_labels)}')
      # Only process each original entity name once
      if wikidata_id_and_labels[2] not in entities_seen_wikidata:
        hypernym_value = ''
        entities_seen_wikidata.append(wikidata_id_and_labels[2])
        # If there is no QID, make it an empty value for the class
        if wikidata_id_and_labels[0] == '???':
          hypernym_value = ''
        else:
          # Otherwise, go check Wikidata
          hypernym_value = get_wikidata_hypernym(wikidata_id_and_labels[0])
        # Fill in the dictionary with the wikidata and regex info
        # wikidata_id_and_labels[2] is expected to be the same as entity_label in the DBpedia loop, so the key should already exist (a KeyError is raised here otherwise)
        all_entities_classes_by_entity[wikidata_id_and_labels[2]]['QID'] = wikidata_id_and_labels[0]
        all_entities_classes_by_entity[wikidata_id_and_labels[2]]['class_wkd'] = hypernym_value
        all_entities_classes_by_entity[wikidata_id_and_labels[2]]['class_regex'] = wikidata_id_and_labels[3]
        # Log line format: '<entity>: class_wkd[<class>] class_regex[<class>] '; empty classes are left out
        line = f'{wikidata_id_and_labels[2]}: '
        if not hypernym_value == '':
          line = line + f'class_wkd[{hypernym_value}] '
        if not wikidata_id_and_labels[3] == '':
          line = line + f'class_regex[{wikidata_id_and_labels[3]}] '
        line = line + '\n'
        fo.write(line)
else:
  # get_classes_entities has neither of the two supported values
  print('Select parameter a the top of the cell!')

print(f'--------\nDone!')

In [None]:
#@title Debug: Compare two json files
import json

# Class dictionaries of two saved files; the suffixes 17 and 37 follow the file names (1-7 and 3-7)
all_entities_classes_by_entity17, all_entities_classes_by_entity37 = {}, {}

with open('webnlg_dbp-wkd-classes_1-7.json') as json_file:
  all_entities_classes_by_entity17 = json.load(json_file)

with open('webnlg_dbp-wkd-classes_3-7.json') as json_file:
  all_entities_classes_by_entity37 = json.load(json_file)

# Counters: entities found in both files, identical class values per source, and entities missing from 3-7
entities_in_common = wkd_in_common = dbp_in_common = regex_in_common = entities_only_in_17 = 0
for entity_label, classes_17 in all_entities_classes_by_entity17.items():
  if entity_label not in all_entities_classes_by_entity37:
    entities_only_in_17 += 1
    continue
  classes_37 = all_entities_classes_by_entity37[entity_label]
  entities_in_common += 1
  if classes_17['class_wkd'] == classes_37['class_wkd']:
    wkd_in_common += 1
  else:
    # Only differing Wikidata classes are printed
    print(f"1-7: {classes_17['class_wkd']}: 3-7: {classes_37['class_wkd']}")
  if classes_17['class_dbp'] == classes_37['class_dbp']:
    dbp_in_common += 1
  if classes_17['class_regex'] == classes_37['class_regex']:
    regex_in_common += 1

print(f'{entities_in_common} entities are in both files.')
print(f'{wkd_in_common} wkd classes are in both files.')
print(f'{dbp_in_common} dbp classes are in both files.')
print(f'{regex_in_common} regex classes are in both files.')
print(f'{entities_only_in_17} entities are only in 1-7.')

In [None]:
#@title Export class info in a json file
import json

# Save the entity -> class-info dictionary; this is the same file name that the 'Via saved JSON' option of the class-retrieval cell opens
with open('webnlg_dbp-wkd-classes.json', 'w') as fp:
  json.dump(all_entities_classes_by_entity, fp)

In [None]:
#@title Get class mapping information into a dico via pandas dataframe
import requests
import pandas as pd

# Two tabs of the same Google spreadsheet (only the gid differs): property classes and entity classes
sheet_url_props = 'https://docs.google.com/spreadsheets/d/14hnO_ci5LqGIUYoCLAupW6J82gCGlXXTtbOMTauCGrA/edit#gid=947115834'
sheet_url_entities = 'https://docs.google.com/spreadsheets/d/14hnO_ci5LqGIUYoCLAupW6J82gCGlXXTtbOMTauCGrA/edit#gid=2130476960'
# Change the url to specify export format
_edit_part, _export_part = '/edit#gid=', '/export?format=csv&gid='
csv_export_url_props = sheet_url_props.replace(_edit_part, _export_part)
csv_export_url_entities = sheet_url_entities.replace(_edit_part, _export_part)

# header=0: the first row of each sheet provides the column names
# usecols: columns 1-9 of the property sheet and 1-23 of the entity sheet (column 0 is not loaded)
mappings_properties = pd.read_csv(csv_export_url_props, header=0, usecols=list(range(1, 10)))
mappings_entities = pd.read_csv(csv_export_url_entities, header=0, usecols=list(range(1, 24)))

def fill_dico_fine2coarse(pd_dataframe, dico_mapping_fine2coarse, propORentity):
  """Fill dico_mapping_fine2coarse[propORentity] with label -> column-header pairs.

  Every string cell of pd_dataframe is mapped to the header of its column
  (the coarse class). Non-string cells are skipped. A label that is already
  a key is reported with an ERROR message and keeps its first class. For
  each label, a variant with spaces replaced by underscores is added too,
  unless that key already exists (underscores are used when building the
  merged entity class, and Thiago's data seems to use properties with
  underscores).

  The dictionary is modified in place; nothing is returned.
  """
  fine2coarse = dico_mapping_fine2coarse[propORentity]
  for column_prop in pd_dataframe:
    for propentity_label_s in pd_dataframe[column_prop]:
      if not isinstance(propentity_label_s, str):
        continue
      if propentity_label_s in fine2coarse:
        print(f'ERROR, key already exists: {propentity_label_s}')
      else:
        fine2coarse[propentity_label_s] = column_prop
      # Underscore variant, added for both entities and properties; an existing key is never overwritten
      fine2coarse.setdefault(propentity_label_s.replace(" ", "_"), column_prop)

# Looking up the class of each entity directly in the dataframes seems to be very slow, so the info is stored in a dictionary with one sub-dictionary per kind
dico_mapping_fine2coarse = {'properties': {}, 'entities': {}}
for _mapping_kind, _mapping_frame in (('properties', mappings_properties), ('entities', mappings_entities)):
  fill_dico_fine2coarse(_mapping_frame, dico_mapping_fine2coarse, _mapping_kind)

# Sanity checks: a few known labels, with spaces and with underscores
_properties_f2c = dico_mapping_fine2coarse['properties']
_entities_f2c = dico_mapping_fine2coarse['entities']
print(f"Class of property comparable: {_properties_f2c['comparable']}")
print(f"Class of property was a crew member of: {_properties_f2c['was a crew member of']}")
print(f"Class of property was_a_crew_member_of: {_properties_f2c['was_a_crew_member_of']}")
print(f"Class of property air base: {_entities_f2c['air base']}")
print(f"Class of property air_base: {_entities_f2c['air_base']}")
print(f"Class of entityt class version, edition or translation: {_entities_f2c['version, edition or translation']}")

In [None]:
#@title Add class info to triple objects for input pairs
import re
import random

random.seed(42)

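# I also add a "merged" (fused) class that is supposed to take what looks most reliable from the different sources.
# Priority order: regex class > wkd class, if it is one of the wkd classes that have n >= 2 dbp classes mapping to them > dbp class > any other wkd class > '_' (no class found).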

def choose_superClass(propOrEntity_label, dico_mapping_fine2coarse, propORentity):
  """Return the coarse class of a label, or None when the label has no mapping.

  Args:
    propOrEntity_label: fine-grained label to look up.
    dico_mapping_fine2coarse: dictionary of dictionaries, indexed first by
      propORentity and then by label.
    propORentity: key of the sub-dictionary to search.
  """
  return dico_mapping_fine2coarse[propORentity].get(propOrEntity_label)

def choose_merged_class(class_wkd, class_dbp, class_regex):
  """Pick one class label among the regex, Wikidata and DBpedia classes.

  Each argument is a string, '' or None. The first applicable rule wins:
    1. a non-empty regex class, lowercased;
    2. the Wikidata class if it is in the list of "general" classes below;
    3. a non-empty DBpedia class;
    4. a non-empty Wikidata class;
    5. '_' when nothing is available.
  For rules 2-4 spaces are replaced by underscores and the label is lowercased.
  """
  # Wikidata classes that generalise a bit over the DBpedia classes.
  # The idea is that DBpedia classes look better, but are sometimes too fine-grained, so the more general Wikidata classes are used instead when available
  wkd_2dbp_classes = (
    'academic major', 'air base', 'air force', 'aircraft family', 'airport',
    'archaeological site', 'architectural firm', 'architectural style', 'army',
    'art museum', 'association football club', 'association football league',
    'asteroid', 'automobile manufacturer', 'automobile model series',
    'automobile model', 'autonomous community of Spain', 'award',
    'basketball team', 'battle', 'big city', 'book publisher', 'business',
    'capital city', 'car classification', 'city council',
    'city in the United States', 'city', 'clade', 'coachwork type', 'color',
    'comics character', 'commune of France', 'commune of Italy', 'country',
    'county seat', 'division', 'ethnic group in Indonesia', 'ethnic group',
    'film', 'fixed-base operator', 'food ingredient', 'food',
    'geographic region', 'government agency', 'high-rise building',
    'historical country', 'historical region', 'hotel',
    'human biblical figure', 'human population', 'human settlement', 'human',
    'international airport', 'island', 'language', 'launch site',
    'literary work', 'material', 'medallion', 'metropolitan area',
    'municipality of Belgium', 'municipality of Spain', 'music genre',
    'musical group', 'nation', 'national anthem',
    'national association football team', 'Olympic stadium', 'organization',
    'pandemic', 'panethnicity', 'parliament', 'peninsula', 'political party',
    'position', 'private university', 'profession',
    'professional sports league', 'public office', 'record label', 'river',
    'rocket model', 'sculpture', 'shipyard', 'skyscraper', 'sovereign state',
    'spaceport', 'sports season', 'street', 'subsidiary', 'taxon', 'town',
    'township of Illinois', 'township of New Jersey', 'trademark',
    'type of food or dish', 'type of musical instrument', 'university',
    'urban municipality in Germany', 'village', 'war', 'written work',
  )
  empty_values = ('', None)
  # Regex classes come first: they target specific values such as dates, quantities, runway names, etc.
  if class_regex not in empty_values:
    return class_regex.lower()
  # No regex class: prefer a "general" wkd class
  if class_wkd in wkd_2dbp_classes:
    return class_wkd.replace(" ", "_").lower()
  # Then the dbp class, if any
  if class_dbp not in empty_values:
    return class_dbp.replace(" ", "_").lower()
  # Then whatever wkd class there is
  if class_wkd not in empty_values:
    return class_wkd.replace(" ", "_").lower()
  return '_'

def add_class_info_triple_objects(triple, all_entities_classes_by_entity, list_entities_data_point):
  """Fill the class and ID attributes of a triple in place.

  The property class is always looked up in the module-level
  dico_mapping_fine2coarse. The subject and the object are each updated
  only when their label is a key of all_entities_classes_by_entity.

  Args:
    triple: Triple object to update.
    all_entities_classes_by_entity: dictionary from entity label to a
      dictionary with the keys 'class_wkd', 'class_dbp', 'class_regex' and 'QID'.
    list_entities_data_point: list of the unique entity labels of the data
      point; the position of a label in this list is used as its input ID.
  """
  # Get class of property
  triple.property_class = choose_superClass(triple.property_label, dico_mapping_fine2coarse, 'properties')
  # Subject and object are handled in the same way; only the attribute prefix differs
  for role in ('subject', 'object'):
    role_label = getattr(triple, f'{role}_label')
    if role_label not in all_entities_classes_by_entity:
      continue
    class_info = all_entities_classes_by_entity[role_label]
    setattr(triple, f'{role}_class_wkd', class_info['class_wkd'])
    setattr(triple, f'{role}_class_dbp', class_info['class_dbp'])
    setattr(triple, f'{role}_class_regex', class_info['class_regex'])
    setattr(triple, f'{role}_id_wkd', class_info['QID'])
    # For the ID of an entity, get its position in the list of unique entities in the input
    setattr(triple, f'{role}_id_input', list_entities_data_point.index(role_label))
    # Choose a class among the available ones
    merged_class = choose_merged_class(class_info['class_wkd'], class_info['class_dbp'], class_info['class_regex'])
    setattr(triple, f'{role}_class_merged', merged_class)
    # Assign the abstract label of the merged class
    setattr(triple, f'{role}_class_abstract', choose_superClass(merged_class, dico_mapping_fine2coarse, 'entities'))

# Add class and ID information to every triple of every data point.
# The outcome depends on the sequence of random.shuffle calls (the seed is fixed at the top of the cell), so the statement order matters.
# data_split = train/dev/test
for data_split in final_data:
  # data_point is an object with three features: input, output and original_input
  for data_point in final_data[data_split]:
    # Create a list to keep track of entities and use the list index as coreferring ID
    # Update: We need to shuffle this list so as not to encode the original order in the IDs
    list_entities_data_point = []
    # data_point.input contains a list of triples
    for triple_input in data_point.input:
      # Add subjects to the list of entities
      if triple_input.subject_label not in list_entities_data_point:
        list_entities_data_point.append(triple_input.subject_label)
      # Add objects to the list of entities
      if triple_input.object_label not in list_entities_data_point:
        list_entities_data_point.append(triple_input.object_label)
      # Shuffle the list of entities to detach the ID from the original position in the reference order
      # NOTE(review): the shuffle runs once per input triple, while the list is still growing, so the IDs assigned in this pass come from different orderings of the list
      random.shuffle(list_entities_data_point)
      # Update the triple object with the class information
      add_class_info_triple_objects(triple_input, all_entities_classes_by_entity, list_entities_data_point)
    # data_point.output contains a list of list of triples; triples are grouped into sentences
    # This pass uses the final order of the entity list; where final_data is built, the output lists hold the same Triple objects as input, so the IDs of the first pass are presumably overwritten here - TODO confirm
    for sentence_group in data_point.output:
      for triple_output in sentence_group:
        add_class_info_triple_objects(triple_output, all_entities_classes_by_entity, list_entities_data_point)
    # Also add the class info to the original input triples that we use to repair incomplete data points later on
    # I had forgotten that part in the original experiments; not sure it will be used but it's cleaner to have it
    for triple_input_orig in data_point.original_input:
      # Add subjects to the list of entities (entities not seen before are appended after the shuffled ones)
      if triple_input_orig.subject_label not in list_entities_data_point:
        list_entities_data_point.append(triple_input_orig.subject_label)
      # Add objects to the list of entities
      if triple_input_orig.object_label not in list_entities_data_point:
        list_entities_data_point.append(triple_input_orig.object_label)
      # Update the triple object with the class information
      # No need to shuffle more the list of entities, since the original input is already independent from the reference order
      add_class_info_triple_objects(triple_input_orig, all_entities_classes_by_entity, list_entities_data_point)

# Make another version of all the keys where the entity labels have underscores (for matching Thiago's labels)
# Build a new dico with entities with underscores as keys: every run of whitespace in a label becomes one underscore
# The pattern is a raw string: '\s' in a normal string literal is an invalid escape sequence (SyntaxWarning in recent Python versions)
entities_with_underscores_classes_by_entity = {re.sub(r'\s+', '_', key_orig): classes_orig for key_orig, classes_orig in all_entities_classes_by_entity.items()}
# Merge old and new dicos
# https://stackoverflow.com/questions/38987/how-do-i-merge-two-dictionaries-in-a-single-expression-in-python
# The result is a new dictionary; for keys present in both, the values of the second (original) dictionary overwrite those of the first.
new_all_entities_classes_by_entity = entities_with_underscores_classes_by_entity | all_entities_classes_by_entity

# print(len(new_all_entities_classes_by_entity))

In [None]:
#@title Get number of different merged classes check for missing
def _append_if_new(values, value):
  """Append value to the list values unless it is already in it."""
  if value not in values:
    values.append(value)

# Entities already covered by a sample triple, and the sample triples themselves
seen_entities_print_label = []
seen_entities_print_triple = []
# Unique values found in the input triples, in order of first appearance
unique_merged_classes = []
unique_properties = []
unique_properties_train = []
unique_properties_dev = []
unique_properties_test = []
unique_property_classes = []
unique_entity_classes_abstract = []
unique_property_lemmas = []
# Labels/classes without a mapping to a coarse-grained class
missing_property_class = []
missing_entity_class = []
# Per-split property lists; other split names are only counted in unique_properties
_unique_properties_by_split = {'train': unique_properties_train, 'dev': unique_properties_dev, 'test': unique_properties_test}
for data_split in final_data:
  for data_point in final_data[data_split]:
    for triple in data_point.input:
      if triple.subject_label not in seen_entities_print_label:
        # If the subject hasn't been seen already, save triple in list, mark object as seen too
        seen_entities_print_triple.append(triple)
        seen_entities_print_label.extend([triple.subject_label, triple.object_label])
      elif triple.object_label not in seen_entities_print_label:
        # If the subject has been seen already but the object has not, add triple to the list
        seen_entities_print_triple.append(triple)
        seen_entities_print_label.append(triple.object_label)
      merged_classes = (triple.subject_class_merged, triple.object_class_merged)
      for merged_class in merged_classes:
        _append_if_new(unique_merged_classes, merged_class)
      # Check abstract classes
      for abstract_class in (triple.subject_class_abstract, triple.object_class_abstract):
        _append_if_new(unique_entity_classes_abstract, abstract_class)
      _append_if_new(unique_properties, triple.property_label)
      if data_split in _unique_properties_by_split:
        _append_if_new(_unique_properties_by_split[data_split], triple.property_label)
      _append_if_new(unique_property_classes, triple.property_class)
      _append_if_new(unique_property_lemmas, triple.property_lemma)
      # Check if we have all property and entity mappings to coarse-grained classes
      if triple.property_label is not None and triple.property_label not in dico_mapping_fine2coarse['properties']:
        _append_if_new(missing_property_class, triple.property_label)
      for merged_class in merged_classes:
        if merged_class is not None and merged_class not in dico_mapping_fine2coarse['entities']:
          _append_if_new(missing_entity_class, merged_class)

# print(len(seen_entities_print_triple))
# Each count is followed by the sorted values (converted to strings, so that None can be sorted too)
print(f'{len(unique_properties)} unique properties')
print(sorted(map(str, unique_properties)))
print(f'  ->{len(unique_properties_train)} unique properties in the train data')
print(sorted(map(str, unique_properties_train)))
print(f'  ->{len(unique_properties_dev)} unique properties in the dev data')
print(sorted(map(str, unique_properties_dev)))
print(f'  ->{len(unique_properties_test)} unique properties in the test data')
print(sorted(map(str, unique_properties_test)))
print(f'{len(unique_property_lemmas)} property lemmas')
print(sorted(map(str, unique_property_lemmas)))
print(f'{len(unique_property_classes)} property classes')
print(sorted(map(str, unique_property_classes)))
print(f'{len(unique_merged_classes)} merged entity classes')
print(sorted(map(str, unique_merged_classes)))
print(f'{len(unique_entity_classes_abstract)} abstract entity classes')
print(sorted(map(str, unique_entity_classes_abstract)))
print('--------------')
print(f'Missing property classes: {missing_property_class}')
print(f'Missing entity classes: {missing_entity_class}')

# print(f'There are {len(seen_entities_print_triple)} entities.')
# for triple in seen_entities_print_triple[:10]:
#   print('------------------')
#   print(triple.object_label)
#   print('------------------')
#   print(f'wkd: {triple.object_class_wkd}')
#   print(f'dbp: {triple.object_class_dbp}')
#   print(f'regex: {triple.object_class_regex}')
#   print(f'merged: {triple.object_class_merged}')
#   print(f'qid: {triple.object_id_wkd}')
#   print(f'coref_id: {triple.object_id_input}')


In [None]:
#@title Get class mappings (and stats)

def make_dic_map_wkd_dbp (dico_mapping, key_sbj, key_obj, val_sbj, val_obj):
  """
  Registers two key/value associations (one for the subject, one for the object) in dico_mapping.
  Each key maps to the list of distinct values seen with it, in order of first appearance.
  The dictionary is updated in place and also returned.
  """
  for key, value in ((key_sbj, val_sbj), (key_obj, val_obj)):
    # Create the list on first sight of the key, then add the value only if it is new
    known_values = dico_mapping.setdefault(key, [])
    if value not in known_values:
      known_values.append(value)
  return dico_mapping

# Build the class mappings in both directions from every input triple of every split:
# each Wikidata class maps to the list of DBpedia classes seen for the same entity, and vice versa
dico_mapping_wkd2dbp = {}
dico_mapping_dbp2wkd = {}
for data_split in final_data:
  for data_point in final_data[data_split]:
    for triple in data_point.input:
      # sbj, obj and regex are not used by the mappings; they are only kept for inspection
      sbj = triple.subject_label
      # Fixed: this used to read the subject label a second time
      obj = triple.object_label
      wkd_sbj = triple.subject_class_wkd
      dbp_sbj = triple.subject_class_dbp
      wkd_obj = triple.object_class_wkd
      dbp_obj = triple.object_class_dbp
      regex = triple.object_class_regex
      dico_mapping_wkd2dbp = make_dic_map_wkd_dbp(dico_mapping_wkd2dbp, wkd_sbj, wkd_obj, dbp_sbj, dbp_obj)
      dico_mapping_dbp2wkd = make_dic_map_wkd_dbp(dico_mapping_dbp2wkd, dbp_sbj, dbp_obj, wkd_sbj, wkd_obj)

# The keys are the classes found on each side
print(dico_mapping_wkd2dbp.keys())
print(dico_mapping_dbp2wkd.keys())
# print(dico_wkd2dbp['human'])

In [None]:
#@title Debug: Print sample datapoints
# Property labels of two sample test inputs, each followed by a separator
for sample_index in (0, 2):
  for itp in final_data['test'][sample_index].input:
    print(itp.property_label)
  print('\n')
# For a third sample, also show label, abstract class and input ID of the subject and of the object,
# first for the actual input and then for the original input
detailed_sample = final_data['test'][5]
for itp in detailed_sample.input:
  print(itp.property_label)
  print(f'{itp.subject_label} - {itp.subject_class_abstract} - {itp.subject_id_input}')
  print(f'{itp.object_label} - {itp.object_class_abstract} - {itp.object_id_input}')
print('\n')
for itp in detailed_sample.original_input:
  print(itp.property_label)
  print(f'{itp.subject_label} - {itp.subject_class_abstract} - {itp.subject_id_input}')
  print(f'{itp.object_label} - {itp.object_class_abstract} - {itp.object_id_input}')

In [None]:
#@title Export mapping info in a json file (not needed to continue with code)
import json

# Save the DBpedia-class -> Wikidata-classes mapping, with its keys in sorted order, in the current working directory
with open('webnlg_dbp-wkd-map.json', 'w') as fp1:
  json.dump(dict(sorted(dico_mapping_dbp2wkd.items())), fp1)
  # json.dump(dico_mapping_dbp2wkd, fp1)
# Disabled: same export for the Wikidata-class -> DBpedia-classes mapping
# with open('webnlg_wkd-dbp-map.json', 'w') as fp2:
  # json.dump(dico_mapping_wkd2dbp, fp2)
  # json.dump(dict(sorted(dico_mapping_wkd2dbp.items())), fp2)

### Extract needed info from our final dataset to create training data (parsing style)

In [None]:
#@title Create CoNLL line contents
from collections import Counter

# Keep all input/output pairs (there can be as many as there are reference texts)
# or only one per input triple set (i.e. only 1 per reference text)
# or only unique ones (if e.g. a 4-triple input has only 3 properties in the data and these 3 properties have already been seen, do not use this input)
inOutPairs_kept = '1perTripleSet'#@param['all', '1perTripleSet', 'onlyTrulyUnique']

# What info do we use in the input?
# Label, for "form" column
property_label_in_conll = 'original' #@param ['none', 'original', 'split']
# Lemma, for "lemma" column
property_lemma_in_conll = False#@param {type:"boolean"}
# Class of property, for PoS column
property_class_in_conll = False#@param {type:"boolean"}
# Subject/Domain and Object/Range classes, for feats column
subject_in_conll = 'none' #@param ['none', 'original', 'wkd', 'dbp', 'merged', 'abstract']
object_in_conll = 'none' #@param ['none', 'original', 'wkd', 'dbp', 'merged', 'abstract']
# Coreference information for feats column
coreference_in_conll = 'sbj&obj' #@param ['none', 'sbj', 'obj', 'sbj&obj']
# If True, this cell only defines its class and functions; the CoNLL contents of the splits are not built
use_only_conllNode_class = False#@param {type:"boolean"}
# Some outputs have fewer properties than the original input; 'True' here adds the missing properties in the structure
add_missing_properties = False#@param{type:"boolean"}
# If True, the sentence groupings of every kept data point are printed while the CoNLL contents are built
print_groupings = False#@param {type:"boolean"}

def checkMismatchProperties(data_point, unique_input_ID):
  """
  Returns the property labels of the original input that are absent from the actual input of a data point.
  Duplicates are taken into account: a property found twice in the original input and once in the actual input is returned once.
  The list is empty when both inputs have the same number of triples.
  unique_input_ID is not used in the computation (it only served in debug messages).
  """
  # Same size: nothing is considered missing
  if len(data_point.input) == len(data_point.original_input):
    return []
  count_props_orig = Counter(f'{orig_triple_object.property_label}' for orig_triple_object in data_point.original_input)
  count_props_real = Counter(f'{real_triple_object.property_label}' for real_triple_object in data_point.input)
  # Multiset difference: what is left of the original properties once the real ones are taken out
  return list((count_props_orig - count_props_real).elements())

class CoNLL_Node:
  """
  One line (token) of a CoNLL structure.
  Spaces in the form are replaced by underscores.
  feats_list is turned into the content of the feats column: '_' if the list is empty,
  the single element as is if there is only one, and the elements joined with '|' otherwise.
  """
  def __init__(self, idx, form, lemma, pos, cpos, feats_list, head, deprel):
    self.idx = idx
    self.form = form.replace(' ', '_')
    self.lemma = lemma
    self.pos = pos
    self.cpos = cpos
    feats_count = len(feats_list)
    if feats_count > 1:
      self.feats = '|'.join(map(str, feats_list))
    elif feats_count == 1:
      # A single feature is stored untouched (no string conversion)
      self.feats = feats_list[0]
    else:
      self.feats = '_'
    self.head = head
    self.deprel = deprel

def extract_formLemmaPosFeats(unique_properties_train, triple_object, property_label_in_conll, property_lemma_in_conll, property_class_in_conll, subject_in_conll, object_in_conll, coreference_in_conll):
  """
  Computes the contents of the form, lemma, PoS and feats columns of the CoNLL node that stands for one triple.

  unique_properties_train: property labels seen in the training data; the property class is only used as PoS for these.
  triple_object: the triple, read through its property_*, subject_* and object_* attributes.
  property_label_in_conll: 'original', 'split' or 'none' (form column).
  property_lemma_in_conll, property_class_in_conll: booleans (lemma and PoS columns).
  subject_in_conll, object_in_conll: 'none', 'original', 'wkd', 'dbp', 'merged' or 'abstract' (dom_class/ran_class features).
  coreference_in_conll: 'none', 'sbj', 'obj' or 'sbj&obj' (dom_ID/ran_ID features).

  Returns (form, lemma, pos, feats), feats being a list of 'name=value' strings ordered as dom_class, ran_class, dom_ID, ran_ID.
  """
  node_conll_feats = []

  def add_feat(feat_name, feat_value):
    # Features with a missing value (None or empty string) are left out
    if feat_value == None or feat_value == '':
      return
    node_conll_feats.append(feat_name + '=' + str(feat_value))

  # End of the attribute name that holds the value selected by each option; any other option adds no class feature
  class_attr_by_option = {'merged': 'class_merged', 'abstract': 'class_abstract', 'dbp': 'class_dbp', 'wkd': 'class_wkd', 'original': 'label'}

  # Form
  if property_label_in_conll == 'none':
    node_conll_form = '_'
  elif property_label_in_conll == 'original':
    node_conll_form = str(triple_object.property_label)
  elif property_label_in_conll == 'split':
    node_conll_form = str(triple_object.property_label_split)

  # Lemma
  if property_lemma_in_conll == False:
    node_conll_lemma = '_'
  elif property_lemma_in_conll == True:
    node_conll_lemma = str(triple_object.property_lemma)

  # PoS
  if property_class_in_conll == False:
    node_conll_pos = '_'
  elif property_class_in_conll == True:
    # Only copy the class of properties seen during training
    # (in Thiago's files, spaces are replaced by underscores, hence the second lookup)
    property_label = triple_object.property_label
    seen_in_train = property_label in unique_properties_train or property_label.replace('_', ' ') in unique_properties_train
    # Need to find a better strategy to assign classes to unknown properties
    node_conll_pos = str(triple_object.property_class) if seen_in_train else '_'

  # Feat class subj/domain
  if subject_in_conll in class_attr_by_option:
    add_feat('dom_class', getattr(triple_object, 'subject_' + class_attr_by_option[subject_in_conll]))
  # Feat class obj/range
  if object_in_conll in class_attr_by_option:
    add_feat('ran_class', getattr(triple_object, 'object_' + class_attr_by_option[object_in_conll]))

  # Feat coreference
  if coreference_in_conll in ('sbj', 'sbj&obj'):
    add_feat('dom_ID', triple_object.subject_id_input)
  if coreference_in_conll in ('obj', 'sbj&obj'):
    add_feat('ran_ID', triple_object.object_id_input)

  return node_conll_form, node_conll_lemma, node_conll_pos, node_conll_feats

def build_CoNLL_Lines(whole_data, data_split, unique_properties_train, inOutPairs_kept, print_groupings):
  """
  Builds the CoNLL structures (lists of CoNLL_Node objects) of one data split.

  whole_data: dictionary of data points by split name; each data point has an input (list of triples),
    an original_input and an output (list of sentence groupings, each being a list of triples).
  data_split: name of the split to process ('train', 'dev' or 'test').
  unique_properties_train: property labels seen in the training data (passed on to extract_formLemmaPosFeats).
  inOutPairs_kept: 'all' keeps every data point, '1perTripleSet' keeps the first data point of each input ID
    (split_category_size_eid), 'onlyTrulyUnique' keeps the first data point of each combination of property labels.
    With any other value, no data point is kept.
  print_groupings: if True, prints the sentence groupings of every kept data point.

  The contents of the columns are controlled by the notebook parameters property_label_in_conll,
  property_lemma_in_conll, property_class_in_conll, subject_in_conll, object_in_conll, coreference_in_conll
  and add_missing_properties, which are read as global variables.

  In each structure, the first node is the ROOT (head 0); the first node of each following sentence is attached
  to the first node of the previous sentence with "inter", and every other node is attached to the node that
  precedes it in the same sentence with "intra".

  Returns the list of structures, one per kept data point.
  """
  conll_contents = []
  # Store what has been seen to avoid duplicates if required (sets, for constant-time lookups)
  seen_property_combinations = set()
  # Input ID in this context is split_category_size_eid
  seen_input_IDs = set()
  count_duplicate_inputs = 0
  for data_point in whole_data[data_split]:
    first_triple = data_point.input[0]
    unique_input_ID = f'{first_triple.data_split}_{first_triple.category}_{first_triple.size}_{first_triple.eid}'
    property_combination = tuple(sorted(input_triple.property_label for input_triple in data_point.input))
    is_kept = (
      inOutPairs_kept == 'all'
      or (inOutPairs_kept == 'onlyTrulyUnique' and property_combination not in seen_property_combinations)
      or (inOutPairs_kept == '1perTripleSet' and unique_input_ID not in seen_input_IDs)
    )
    if not is_kept:
      count_duplicate_inputs += 1
      continue
    seen_property_combinations.add(property_combination)
    seen_input_IDs.add(unique_input_ID)

    if print_groupings == True:
      print('------------------------')
      print(f'{data_split}-{str(len(conll_contents))}: {first_triple.category}-{first_triple.size}-id{first_triple.eid}')
      print('------------------------')
    missing_properties = checkMismatchProperties(data_point, unique_input_ID)
    conll_content = []
    # Starting at 1 because the first line in a conll has ID = 1
    i_line_struct = 1
    # Latest node within the current sentence
    intra_head = None
    # First node of the latest non-empty sentence
    inter_head = None
    for output_sentence_grouping in data_point.output:
      if print_groupings == True:
        print('<BOUNDARY>')
      # Re-initialise the internal head for each new sentence
      intra_head = None
      for output_triple_object in output_sentence_grouping:
        if print_groupings == True:
          print(f'  {output_triple_object.property_label}')

        # In this block, get dependency information
        if inter_head is None:
          # Very first node of the structure. Testing inter_head (rather than the position of the
          # sentence) keeps the structure valid when the first sentence grouping is empty.
          node_conll_head = '0'
          node_conll_deprel = 'ROOT'
          inter_head = i_line_struct
        elif intra_head is None:
          # First node of a new sentence: attached to the head of the previous sentence,
          # and new head for the next sentence
          node_conll_head = str(inter_head)
          node_conll_deprel = 'inter'
          inter_head = i_line_struct
        else:
          # Same sentence as before: attached to the previous node
          node_conll_head = str(intra_head)
          node_conll_deprel = 'intra'
        node_conll_idx = i_line_struct
        intra_head = i_line_struct
        i_line_struct += 1

        # Get form, lemma, pos and feats
        node_conll_form, node_conll_lemma, node_conll_pos, node_conll_feats = extract_formLemmaPosFeats(unique_properties_train, output_triple_object, property_label_in_conll, property_lemma_in_conll, property_class_in_conll, subject_in_conll, object_in_conll, coreference_in_conll)

        # Create conll node object
        conll_content.append(CoNLL_Node(node_conll_idx, node_conll_form, node_conll_lemma, node_conll_pos, '_', node_conll_feats, node_conll_head, node_conll_deprel))

    # Add missing properties for test set after the last grouping; for these, is it possible to get the feats too?
    # If this is activated, there should be a part of the code to do the same thing on the gold data (see last cell below)
    if add_missing_properties == True and data_split == 'test' and len(missing_properties) > 0:
      print(f'  Added {len(missing_properties)} missing properties to #{str(len(conll_contents))} ({unique_input_ID})')
      for missing_property in missing_properties:
        # Each added node is attached to the node right before it (i_line_struct - 1), which is what
        # intra_head holds unless the last grouping was empty; a node added to an empty structure becomes the ROOT
        missing_deprel = 'ROOT' if i_line_struct == 1 else 'intra'
        conll_content.append(CoNLL_Node(i_line_struct, missing_property, '_', '_', '_', [], str(i_line_struct - 1), missing_deprel))
        i_line_struct += 1
    conll_contents.append(conll_content)

  message = f'There are {len(conll_contents)} data points in {data_split} ({count_duplicate_inputs} duplicate inputs'
  if inOutPairs_kept == 'all':
    print(message+').')
  else:
    print(message+' were removed).')
  return conll_contents

# Placeholders, so that the three names exist even when only the class and functions of this cell are wanted
conll_contents_train = conll_contents_dev = conll_contents_test = ''

if use_only_conllNode_class == False:
  # The three splits are built with the same settings; only the split name changes
  shared_build_args = (unique_properties_train, inOutPairs_kept, print_groupings)
  conll_contents_train = build_CoNLL_Lines(final_data, 'train', *shared_build_args)
  conll_contents_dev = build_CoNLL_Lines(final_data, 'dev', *shared_build_args)
  conll_contents_test = build_CoNLL_Lines(final_data, 'test', *shared_build_args)

# print(len(conll_contents_test))
# Print output
# for conll_2be in conll_contents_test:
#   for line_2be in conll_2be:
#     print(f'{line_2be.idx} {line_2be.form} {line_2be.lemma} {line_2be.pos} {line_2be.feats} {line_2be.head} {line_2be.deprel}')


In [None]:
#@title Debug: Print sample datapoint
# Show the nodes of the first five test structures, one node per line with space-separated columns
for i, test_check_data in enumerate(conll_contents_test[:5]):
  print(f'---{i}---')
  for tcl in test_check_data:
    print(tcl.idx, tcl.form, tcl.lemma, tcl.pos, tcl.feats, tcl.head, tcl.deprel)

In [None]:
#@title Mount drive if you want to save data to drive in the next cells
from google.colab import drive
# Mount Google Drive under /content/drive; force_remount=True asks for a remount even if it is already mounted
drive.mount('/content/drive', force_remount=True)

In [None]:
#@title Write file in CoNLL format
import codecs
import os

# Format of the output files: 14-column CoNLL09 or 10-column CoNLL-U
conll_format = 'CoNLL-U'#@param['CoNLL09', 'CoNLL-U']
# If True, will put all the contents in the same conll with a bubble around each sentence (False for Text Structuring)
same_conll = False#@param {type:"boolean"}
# In case you want to just load the function without running the code in this cell
use_only_instantiate_conll_function = False#@param {type:"boolean"}

# Folder for the CoNLL files; the files of this cell are written in its conllu2conll sub-folder
conllFolder = '/content/conllOut'
conllSplitFolder = os.path.join(conllFolder, 'conllu2conll')
# exist_ok=True creates the folders only if needed, so the cell can be re-run safely
os.makedirs(conllSplitFolder, exist_ok=True)

# To keep here the code of the dataset
final_folder_ID = ''

# Moved in cell above
# class CoNLL_Node:
#   def __init__(self, idx, form, feats, head, deprel):
#     self.idx = idx
#     self.form = form
#     self.feats = feats
#     self.head = head
#     self.deprel = deprel

def instantiate_conll_template(sentences, conll_format, same_conll):
  """
  Takes as input a list of sentences. Each sentence is a list of tokens of the CoNLL_Node class.
  Returns a CoNLL structure in the 14-column 2009 format (tree or graph, with or without bubbles), or in the 10-column CoNLL-U format (tree, no bubbles).

  With bubbles (conll_format 'CoNLL09' and same_conll True), all sentences are in the same structure and extra
  columns mark the sentence each token belongs to; otherwise the sentences are separated by an empty line.
  The idx of each token is overwritten (as a string) with its position in the structure it ends up in.
  A line break is always appended at the end of the returned string.
  """
  use_bubbles = conll_format == 'CoNLL09' and same_conll == True
  if conll_format == 'CoNLL-U' and same_conll == True:
    print('Cannot create bubbles in CoNL-U format; keeping sentences separated.')

  # Template of a token line for each format; any other format gives empty token lines
  templates_by_format = {
    'CoNLL09': "{0.idx}\t{0.form}\t{0.lemma}\t_\t{0.pos}\t_\t{0.feats}\t_\t{0.head}\t_\t{0.deprel}\t_\t_\t_",
    'CoNLL-U': "{0.idx}\t{0.form}\t{0.lemma}\t{0.pos}\t{0.cpos}\t{0.feats}\t{0.head}\t{0.deprel}\t_\t_",
  }
  contentTemplate = templates_by_format.get(conll_format, '')

  # The pieces of the output are collected here and joined at the end
  pieces = []
  if use_bubbles:
    # Line 0 (14 columns), to which the bubble column headers are appended below
    pieces.append("0" + "\t_" * 13)

  sentence_count = len(sentences)
  if sentence_count == 0:
    pieces.append("\tslex=Sentence\tslex=Text\n1\tNo_relevant_information_found\t_\t_\t_\t_\t_\t_\t_\t_\t_\t_\t_\t_\ttrue\ttrue")

  bubbles = []
  offset = 0
  if use_bubbles:
    # One "Text" bubble (only if there are several sentences), followed by one bubble per sentence
    offset = 1 if sentence_count > 1 else 0
    pieces.append("\tslex=Text" * offset + "\tslex=Sentence" * sentence_count + "\n")
    bubbles = ["true"] * offset + ["_"] * sentence_count

  last_sentence = sentence_count - 1
  line_count = 0
  for idx, sentence in enumerate(sentences):
    if use_bubbles:
      # Bubble columns of the tokens of this sentence: only the column of this sentence (and of the text) is true
      sentence_bubbles = list(bubbles)
      sentence_bubbles[idx + offset] = "true"
      bubbles_str = "\t" + "\t".join(sentence_bubbles)
    last_token = len(sentence) - 1
    for id_tok, token in enumerate(sentence):
      line_count += 1
      token.idx = str(line_count)
      pieces.append(contentTemplate.format(token))
      if use_bubbles:
        pieces.append(bubbles_str)
      elif id_tok == last_token and idx < last_sentence:
        # Last token of a sentence that is not the last structure of the file: add an empty line
        # and restart the numbering for the next structure
        pieces.append("\n")
        line_count = 0
      pieces.append("\n")
  # Line break after the last structure too (leaving it out seems to break the scrambling)
  pieces.append("\n")

  return ''.join(pieces)

# Toy data structure to test conll building
# word1 = CoNLL_Node(1, 'born', 'tense=PAST', '0', 'root')
# word2 = CoNLL_Node(2, 'Jesus', 'class=Person|dpos=NP', '1', 'A1')
# word3 = CoNLL_Node(3, 'Dec.25', '_', '1', 'Time')
# sentences = [[word1, word2, word3], [word1, word2]]
# conll = instantiate_conll_template(sentences, conll_format, same_conll)

# print(conll)

# ext1: infix of the file names, made of one short code per selected option; ext2: file extension
ext1 = ''
ext2 = ''
if use_only_instantiate_conll_function == False:
  # Options without a code (e.g. 'none') add nothing to the infix
  # NOTE(review): 'wkd' and 'dbp' have no code for subject_in_conll/object_in_conll, so they give the same infix as 'none'
  # Smallest input size kept
  ext1 = {'size1&2&3': '4', 'size1&2': '3', 'size1': '2'}.get(exclude_input_size, '')
  # Input/output pairs kept
  ext1 += {'all': 'A', '1perTripleSet': 'T', 'onlyTrulyUnique': 'U'}.get(inOutPairs_kept, '')
  # Property label, lemma and class
  ext1 += {'original': 'O', 'split': 'S'}.get(property_label_in_conll, '')
  if property_lemma_in_conll == True:
    ext1 += 'e1'
  if property_class_in_conll == True:
    ext1 += 'p1'
  # Subject/domain and object/range classes
  ext1 += {'merged': 'd1', 'abstract': 'd2', 'original': 'd3'}.get(subject_in_conll, '')
  ext1 += {'merged': 'r1', 'abstract': 'r2', 'original': 'r3'}.get(object_in_conll, '')
  # Coreference
  ext1 += {'sbj': 'c1', 'obj': 'c2', 'sbj&obj': 'c3'}.get(coreference_in_conll, '')
  # Extension
  ext2 = {'CoNLL-U': '.conllu', 'CoNLL09': '.conll'}.get(conll_format, '')

  final_folder_ID = ext1

  # Generate the three CoNLL texts first, then write one file per split
  conll_train = instantiate_conll_template(conll_contents_train, conll_format, same_conll)
  conll_dev = instantiate_conll_template(conll_contents_dev, conll_format, same_conll)
  conll_test = instantiate_conll_template(conll_contents_test, conll_format, same_conll)

  for split_name, conll_text in (('train', conll_train), ('dev', conll_dev), ('test', conll_test)):
    with codecs.open(os.path.join(conllSplitFolder, split_name+'-TextStruct_'+ext1+ext2), 'w', 'utf-8') as fo_split:
      fo_split.write(conll_text)

# print('Download your file from the left!')

In [None]:
#@title Debug: Print conll structures
# Show the whole CoNLL text generated for the test split in the previous cell
print(conll_test)

In [None]:
#@title Scramble the CoNLLs for learning and move to final folder

# If True, the generated files are moved to sub-folders of custom_folder_conllu at the end of this cell
move_to_custom_folder = True#@param{type:"boolean"}

# Paths to update:
custom_folder_conllu = '/content/drive/MyDrive/M-FleNS/Papers-Slides/M-FleNS_papers/2024-05_Fluency_Improvements/Parsing4'#@param{type:"string"}

# Scrambling script of the cloned UD_Converter repository, and the path of the modified copy created below
path_jars = '/content/UD_Converter/Resources'
path_conllScramble = os.path.join(path_jars, 'conllScramble.py')
new_path_conllScramble = os.path.join(path_jars, 'conllScramble_TS.py')

# Need to deactivate the introduction of original_id feat in the script:
# a copy of the script is written in which the "if dType == '1'" statement (indented with 4 tabs)
# and the lines of its body (indented with 5 tabs) that directly follow it are commented out.
# The original script is read within a "with" block so that the file is closed after reading.
with codecs.open(path_conllScramble, 'r', 'utf-8') as fi_scram:
  conll_scram_lines = fi_scram.readlines()
with codecs.open(new_path_conllScramble, 'w', 'utf-8') as fo_scram:
  # Index of the last line that was commented out; None as long as the "if dType" line has not been found,
  # so that no line can be commented out before it
  id_last_commented_line = None
  for i, conll_scram_line in enumerate(conll_scram_lines):
    if conll_scram_line.startswith("\t\t\t\tif dType == '1'"):
      id_last_commented_line = i
      fo_scram.write('#' + conll_scram_line)
    # All next lines that are indented should be commented too
    elif id_last_commented_line is not None and i == id_last_commented_line + 1 and conll_scram_line.startswith("\t\t\t\t\t"):
      id_last_commented_line = i
      fo_scram.write('#' + conll_scram_line)
    else:
      fo_scram.write(conll_scram_line)

# Names of the files written in the "Write file in CoNLL format" cell
# NOTE(review): the extension is always .conllu here, whereas that cell uses .conll when conll_format is 'CoNLL09'
trainfile = 'train-TextStruct_'+ext1+'.conllu'
devfile = 'dev-TextStruct_'+ext1+'.conllu'
testfile = 'test-TextStruct_'+ext1+'.conllu'

# Run the modified scrambling script once per split (shell commands; the values in braces are filled in by the notebook)
!python {new_path_conllScramble} {trainfile} {conllFolder} 't1' '1' {conllFolder}
!python {new_path_conllScramble} {devfile} {conllFolder} 't1' '1' {conllFolder}
!python {new_path_conllScramble} {testfile} {conllFolder} 't1' '1' {conllFolder}


# Imported here so that this cell does not depend on another cell having loaded these modules
import glob
import shutil

# Destination folders: one for the scrambled files, one (with the "-lin" suffix) for the original files
final_folder_path_scrambled = os.path.join(custom_folder_conllu, final_folder_ID)
final_folder_path_lin = os.path.join(custom_folder_conllu, final_folder_ID+'-lin')
# exist_ok=True creates the folders only if needed
os.makedirs(final_folder_path_scrambled, exist_ok=True)
os.makedirs(final_folder_path_lin, exist_ok=True)

# Rename original files with the extension "-lin" so as to avoid confusions and move to final folder
for filepath in glob.glob(os.path.join(conllFolder, 'conllu2conll', '*.conllu')):
  filepath_noext = filepath.rsplit('.', 1)[0]
  new_filepath = filepath_noext+'-lin.conllu'
  os.rename(filepath, new_filepath)
  if move_to_custom_folder == True:
    shutil.move(new_filepath, final_folder_path_lin)

# Finally, move the scrambled files too (the .conllu files found directly in conllFolder)
if move_to_custom_folder == True:
  for filepath2 in glob.glob(os.path.join(conllFolder, '*.conllu')):
    shutil.move(filepath2, final_folder_path_scrambled)



In [None]:
#@title Convert gold structuring file into CoNLL for eval with LAS and move to drive (only use if 1 input per triple set (T) above)
import json
import codecs
import random
import sys
import re

# Fixed seed, so that the shuffling done below gives the same result at each run
random.seed(4242)

# Just to make sure we have the right file infix at this point (we want a T file).
# Could also just manipulate the ext1 string but simpler like this for now, since we always need the T files anyway
if 'T' not in ext1:
  sys.exit("Please use this cell only after generating the 'T' files for a combination of feats. ")


# This is for making the reference file for the evaluation with LAS of the M-FleNS structuring step using Gold ordering as input (in a similar fashion to what was done in Thiago's experiments)
# This file is used as input for the parser too (the first columns only; the gov and deprel columns are only used for evaluation as far as I know)
data_split_gold_str = 'test'#@param['dev', 'test']

# The following file is in C:\Users\sfmil\OneDrive\Desktop\DCU\Papers\2024-05_Fluency-improvements\TextPlanning\thiago_eval
path_gold_structuring = f'/content/drive/MyDrive/M-FleNS/Papers-Slides/M-FleNS_papers/2024-05_Fluency_Improvements/Thiago-files/structuring_gold-{data_split_gold_str}.json'
# The json looks like this inside:
# [{"eid": "Id285", "category": "Airport", "size": "2", "source": ["<TRIPLE>", "Aarhus_Airport", "operatingOrganisation", "\"Aarhus_Lufthavn_A/S\"", "</TRIPLE>", "<TRIPLE>", "Aarhus_Airport", "runwayLength", "2777.0", "</TRIPLE>"], "targets": [{"lid": "Id1", "comment": "good", "output": ["<SNT>", "operatingOrganisation", "runwayLength", "</SNT>"]}, {"lid": "Id2", "comment": "good", "output": ["<SNT>", "operatingOrganisation", "runwayLength", "</SNT>"]}]}...]
# To control if we add feats or not in the file, it will use the feats selected in the cell "Create CoNLL line contents" above.
use_feats = 'yes'#@param['yes', 'no']

# Load the gold structuring data (a list with one dictionary per input)
contents_gold_struct = ''
with codecs.open(path_gold_structuring, 'r', 'utf-8') as json_file:
  contents_gold_struct = json.load(json_file)

# Build one CoNLL structure per gold datapoint: one node per property of the first target, in target order.
# The first node of the first sentence is the ROOT (head 0); the first node of each following sentence is attached
# to the first node of the previous sentence with "inter"; the other nodes are attached to the previous node
# of the same sentence with "intra".
conll_contents_gold_struct = []
for i, datapoint in enumerate(contents_gold_struct):
  # print(f'---{i}---')
  # Create a list to keep track of entities and use the list index as coreferring ID
  list_entities_data_point = []
  # Create conll content with the deprel and the property name
  conll_content = []
  inter_gov = 0
  intra_gov = 0
  last_element = ''
  # Use the first target as reference
  # (Sample target ['<SNT>', 'location', 'elevationAboveTheSeaLevel_(in_metres)', '</SNT>', '<SNT>', '</SNT>'])
  idy = 0
  prop_label = ''
  gov = ''
  deprel = ''
  source_elements = datapoint['source']
  # Positions of the source properties already turned into a node, so that an input in which the same
  # property appears in several triples gets one node per triple (and not several nodes built from the first one)
  used_source_positions = set()
  for target_element in datapoint['targets'][0]['output']:
    # If an element is a <SNT> tag, update last_element
    if target_element == '<SNT>':
      last_element = '<SNT>'
    # If it's a property, create a line with the available data
    elif not target_element == '</SNT>':
      idy += 1
      prop_label = target_element
      # Get position of property in the source field, where the subject (-1) and object (+1) can be found too.
      # A property is found two positions after a <TRIPLE> tag; take the first such position that is still unused.
      id_prop_in_source = next(
        (pos for pos, source_element in enumerate(source_elements)
         if source_element == target_element
         and pos >= 2 and source_elements[pos-2] == '<TRIPLE>'
         and pos not in used_source_positions),
        None)
      if id_prop_in_source is None:
        # No unused position left: use the first occurrence (raises a ValueError if the property is not in the source)
        id_prop_in_source = source_elements.index(target_element)
      used_source_positions.add(id_prop_in_source)
      # Subject and Object labels all have underscores, instead of spaces in the original data, so I reestablish the spaces here
      # Actually I added a version of all entities with underscores in the all_entities_classes_by_entity list
      subject_label = source_elements[id_prop_in_source-1]#.replace('_', ' ')
      object_label = source_elements[id_prop_in_source+1]#.replace('_', ' ')
      if last_element == '<SNT>':
        # First property of a sentence
        gov = inter_gov
        if inter_gov == 0:
          deprel = 'ROOT'
        else:
          deprel = 'inter'
        inter_gov = idy
        intra_gov = idy
      else:
        gov = intra_gov
        deprel = 'intra'
        intra_gov = idy
      last_element = target_element

      if use_feats == 'yes':
        # Make the current data into a triple object
        goldStruct_metadata = [data_split_gold_str, datapoint['category'], datapoint['eid'], datapoint['size']]
        goldStruct_triple_object = Triple(f'{subject_label} | {target_element} | {object_label}', goldStruct_metadata)
        # Add subjects to the list of entities
        if goldStruct_triple_object.subject_label not in list_entities_data_point:
          list_entities_data_point.append(goldStruct_triple_object.subject_label)
        # Add objects to the list of entities
        if goldStruct_triple_object.object_label not in list_entities_data_point:
          list_entities_data_point.append(goldStruct_triple_object.object_label)
        # Shuffle the list of entities so the coref IDs have the same shape as the training/dev data
        # NOTE(review): the list is reshuffled for every triple, so the index of an entity can differ from one
        # triple of the datapoint to the next; confirm this is intended for coreference IDs
        random.shuffle(list_entities_data_point)
        # Update the triple object with the class information
        add_class_info_triple_objects(goldStruct_triple_object, new_all_entities_classes_by_entity, list_entities_data_point)
        # Get form, lemma, pos and feats as defined earlier in the pipeline
        node_conll_form, node_conll_lemma, node_conll_pos, node_conll_feats = extract_formLemmaPosFeats(unique_properties_train, goldStruct_triple_object,  property_label_in_conll, property_lemma_in_conll, property_class_in_conll, subject_in_conll, object_in_conll, coreference_in_conll)
        conll_content.append(CoNLL_Node(idy, node_conll_form, node_conll_lemma, node_conll_pos, '_', node_conll_feats, str(gov), deprel))
      else:
        # Without feats, only the property label and the dependency information are used
        conll_content.append(CoNLL_Node(idy, prop_label, '_', '_', '_', [], str(gov), deprel))
  conll_contents_gold_struct.append(conll_content)

# for line_conll_struct in conll_contents_gold_struct[2000]:
#   print (line_conll_struct.idx, line_conll_struct.form)

# Create conll file (CoNLL-U format, one structure per datapoint)
conll_gold_struct = instantiate_conll_template(conll_contents_gold_struct, 'CoNLL-U', False)
# Make sure that the destination folder exists before writing into it
gold_struct_folder = os.path.join(custom_folder_conllu, '0_structuring_gold_input')
os.makedirs(gold_struct_folder, exist_ok=True)
with codecs.open(os.path.join(gold_struct_folder, 'conll_gold_struct_'+ext1+'-'+data_split_gold_str+'.conllu'), 'w', 'utf-8') as fo1:
  fo1.write(conll_gold_struct)