<a href="https://colab.research.google.com/github/mille-s/GEM24_D2T_StratifiedSampling/blob/main/GEM_D2T_data_selection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# importing required libraries

import os
import xml.etree.ElementTree as ET
from sklearn.model_selection import train_test_split
import pandas as pd
import csv
import random
import json

In [None]:
# filepath definition
# project_dir_path: root working directory (alternative Drive location kept below for reference).
# rdf_path:         directory that is scanned for the input files.
# output_path:      CSV file the selected sample is written to.
# NOTE: the CSV is written into the same directory that holds the input files,
# so any code that scans rdf_path must tolerate a non-XML file being present there.

# project_dir_path = os.path.join('/', 'content', 'drive', 'MyDrive', 'WebNLG_data_selection')
project_dir_path = '/content'
rdf_path = os.path.join(project_dir_path, 'testdata')
output_path = os.path.join(rdf_path, 'output_data.csv')

# exist_ok=True creates the directory only when missing, without a separate
# (race-prone) existence check; it still raises if the path exists as a regular file.
os.makedirs(rdf_path, exist_ok=True)

In [None]:
# Random seed; passed as random_state to train_test_split so the sampling can be repeated.
seed = 49

In [None]:
def extract_data(rdf_filepath, stratify_categories, exclude_size):

  '''
      Build one record per <entry> found in the XML files of a directory.

      This method:
      a. extracts the required fields (RDF triple(s), number of triples, properties and category) from each XML file.
      b. labels each entry as seen/unseen: an entry is 'unseen' when its category is one of a fixed list of
         categories (Athlete, Artist, CelestialBody, MeanOfTransportation, Politician), 'seen' otherwise.
      c. builds the 'strat_field' key used for stratified selection (number of triples + category label).
      d. drops entries whose 'strat_field' value occurs only once (stratification needs at least 2 members).

      Parameters:
        rdf_filepath:        directory containing the XML files. Files are processed in sorted name order so
                             that the generated ids are stable across runs and machines. Sub-directories are
                             ignored; files without an .xml extension that are not well-formed XML are skipped
                             with a notice.
        stratify_categories: 'seenUnseen' (stratify on size + seen/unseen label) or
                             'allCategories' (stratify on size + original category).
        exclude_size:        'none' (keep every entry), '1 only' (keep entries with size > 1) or
                             '1 and 2' (keep entries with size > 2).

      Returns:
        A list of dicts with keys 'id', 'triples', 'property', 'num_triples', 'category',
        'category_all' and 'strat_field'.

      Raises:
        ValueError: if stratify_categories or exclude_size is not one of the supported values.
        xml.etree.ElementTree.ParseError: if a file with an .xml extension is not well-formed.
  '''

  # Exclusive lower bound on the entry size for each option (None keeps every size).
  size_thresholds = {'none': None, '1 only': 1, '1 and 2': 2}
  # Which record field is appended to the triple count to form the stratification key.
  strat_sources = {'seenUnseen': 'category', 'allCategories': 'category_all'}

  # Fail early with a clear message instead of a late KeyError / a silently empty result.
  if exclude_size not in size_thresholds:
    raise ValueError(f"Unknown exclude_size {exclude_size!r}; expected one of {list(size_thresholds)}")
  if stratify_categories not in strat_sources:
    raise ValueError(f"Unknown stratify_categories {stratify_categories!r}; expected one of {list(strat_sources)}")

  min_size = size_thresholds[exclude_size]
  strat_source = strat_sources[stratify_categories]
  unseen_categories = ('Athlete', 'Artist', 'CelestialBody', 'MeanOfTransportation', 'Politician')

  data = []
  count = 0
  # os.listdir returns names in arbitrary order; sorting keeps ids (and the seeded sample) reproducible.
  for filename in sorted(os.listdir(rdf_filepath)):
    file_path = os.path.join(rdf_filepath, filename)
    if not os.path.isfile(file_path):
      continue
    try:
      root = ET.parse(file_path).getroot()
    except ET.ParseError:
      # A broken .xml file is a real data problem: keep failing loudly.
      if filename.lower().endswith('.xml'):
        raise
      # Anything else (e.g. a CSV produced by a previous run) is simply not input data.
      print(f"  Skipped {filename}: not an XML file")
      continue

    for entry in root.findall('./entries/entry'):
      num_triples = int(entry.attrib['size'])
      if min_size is not None and num_triples <= min_size:
        continue

      # extract triples and, for each one, the middle field of 'subject | property | object'
      triples = [triple.text for triple in entry.find('modifiedtripleset').findall('mtriple')]
      pred = [str_triple.split('|')[1] for str_triple in triples]

      curr_entry = {
          'id': count,
          'triples': triples,
          'property': pred,
          'num_triples': num_triples,
          'category': 'unseen' if entry.attrib['category'] in unseen_categories else 'seen',
          'category_all': entry.attrib['category']
      }
      curr_entry['strat_field'] = str(num_triples) + curr_entry[strat_source]
      data.append(curr_entry)
      count += 1

  # Remove data points for which there is only one member in a stratify category (triggers an error when stratifying, needs 2 members min)
  # Count the instances of each strat_field
  count_strat_field_instances = {}
  for datapoint in data:
    key = datapoint['strat_field']
    count_strat_field_instances[key] = count_strat_field_instances.get(key, 0) + 1

  # If the count of a strat_field is one, do not include it in the final dataset
  clean_data = []
  for datapoint in data:
    if count_strat_field_instances[datapoint['strat_field']] == 1:
      print(f"  Removed datapoint  {datapoint['strat_field']} because there is only one member!")
    else:
      clean_data.append(datapoint)

  return clean_data

In [None]:
# Form parameters for the run (the trailing #@param annotations list the selectable values).
# stratify_categories: which category label is combined with the triple count for stratification.
stratify_categories = 'allCategories'#@param['seenUnseen', 'allCategories']
# number_samples: size of the selected sample; kept as a string by the form and converted to int below.
number_samples = "180"#@param[50, 100, 120, 150, 180, 200, 300, 400, 500]
num_samples = int(number_samples)
# exclude_size: entry sizes to leave out ('1 only' keeps size > 1, '1 and 2' keeps size > 2).
exclude_size = '1 only'#@param['none', '1 only', '1 and 2']
# Get data: list of entry records read from the files in rdf_path
data=extract_data(rdf_path, stratify_categories, exclude_size)

In [None]:
# Stratified selection with train_test_split: the "test" part (X_test) is the
# sample of num_samples rows, drawn proportionally to each 'strat_field' value;
# X_train holds the remaining rows.

tset = pd.DataFrame(data)
X_train, X_test = train_test_split(
  tset,
  test_size=num_samples,
  stratify=tset['strat_field'],
  shuffle=True,
  random_state=seed,
)
print(len(X_train), len(X_test))

In [None]:
# Exploratory snippets kept for reference (not executed):
# tset['num_triples']
# len(tset.loc[tset['category'] == 'unseen'])
# print(X_test['num_triples'])

# Show mean of column that contains triple number in each input (https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.loc.html)
# print(X_test.loc[:, 'num_triples'].mean())
# Report the average number of triples per selected input, rounded to 2 decimals.
print(f"{round(X_test['num_triples'].mean(), 2)} triples per input on average")

def count_num_instances(pd_column):
  '''
      Print how often each distinct value occurs in pd_column.

      One "value<TAB>count" line is printed per distinct value, in sorted order of the
      values, followed by a dashed separator line. Nothing is returned.
  '''
  tally = {}
  for value in pd_column:
    tally[value] = tally.get(value, 0) + 1

  for value in sorted(tally.keys()):
    occurrences = tally[value]
    print(f'{value}\t{occurrences}')
  print('-----------------')

# Distribution of the selected sample: first by number of triples, then by original category.
count_num_instances(X_test['num_triples'])
count_num_instances(X_test['category_all'])

In [None]:
# Write the selected sample to output_path; index=False leaves the DataFrame index out of the CSV.
X_test.to_csv(output_path, index=False)