In [72]:
import os
import json
import numpy as np
import xml.etree.ElementTree as ET

In [197]:
def list_files_recursive(directory):
    file_paths = list()
    for root, dirs, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith(".xml") and ("3triples" in file_path or "rdf-to-text-generation-test-data-with-refs-en" in file_path):
                file_paths.append(file_path)
    # for direc in dirs:
    #     file_paths += list_files_recursive(direc)
    return file_paths

def process_file(file_path, categories, cat_to_rels, cat_to_sents):
    tree = ET.parse(file_path)
    root = tree.getroot()
    entry_list = list()
    entries = root.find('entries')
    for entry in entries.findall('entry'):
        category = entry.get('category')
        categories.add(category)
        # eid = entry.get('eid')
        # shape = entry.get('shape')
        # shape_type = entry.get('shape_type')
        # size = entry.get('size')
        triples = list()
        modifiedtripleset = entry.find('modifiedtripleset')
        for mtriple in modifiedtripleset.findall('mtriple'):
            splits = mtriple.text.split("|")
            rel = splits[1].strip()
            rel_set = cat_to_rels.get(category, set())
            rel_set.add(rel)
            cat_to_rels[category] = rel_set
            triples.append([splits[0].strip(), rel, splits[2].strip()])
            
            
        # if len(new_subjects.intersection(seen_subjects)) > 0:
        #     continue
        # else:
        #     seen_subjects.update(new_subjects)
            
        lex = entry.find('lex')  
        if triples and lex.text:
            tr_sents = cat_to_sents.get(category, list())
            tr_sents.append({"sent": lex.text, "triples": triples})
            cat_to_sents[category] = tr_sents
            
        entry_list.append(entry)
    return entry_list
            
def save_sentences(cat_id, sents):
    #splits = [0.4, 0.3, 0.3]
#     sents = np.array(sents)
#     indices = np.random.permutation(sents.shape[0])
#     train_count = int(sents.shape[0] * splits[0])
#     val_count = int(sents.shape[0] * splits[1])
#     test_count = sents.shape[0] - train_count - val_count
    
#     train_sents = sents[indices[:train_count]]
#     val_sents= sents[indices[train_count:train_count+val_count]]
#     test_sents = sents[indices[train_count+val_count:]]

    train_subjects = set()
    test_subjects = set()
    train_sents = list()
    test_sents = list()
    for sent in sents:
        sent_subjects = {f"{cat_id}{tr[0]}" for tr in sent['triples']}
        if len(sent_subjects.intersection(test_subjects)) > 0:
            test_sents.append(sent)
            test_subjects.update(sent_subjects)
        elif len(test_sents) < len(train_subjects) * 2 and len(sent_subjects.intersection(train_subjects)) == 0:
            test_sents.append(sent)
            test_subjects.update(sent_subjects)
        else:
            train_sents.append(sent)
            train_subjects.update(sent_subjects)
            
    print(f"{cat_id}: {len(train_sents) + len(test_sents)}")
    print(f"train: {len(train_sents)}")
    print(f"test: {len(test_sents)}\n")   
    
    with open(f"data/train/{cat_id}_train.jsonl", "w") as out_file:
        idx = 1
        for s in train_sents:
            sentence = s['sent'] 
            triples = s['triples']
            data = {"id": f"{cat_id}_train_{idx}", "sent": sentence}
            data["triples"] = [{"sub": tr[0], "rel": tr[1], "obj": tr[2]} for tr in triples]
            idx += 1
            out_file.write(f"{json.dumps(data)}\n")
                
    # with open(f"data/validation/{cat_id}_validation.jsonl", "w") as out_file:
    #     idx = 1
    #     for s in val_sents:
    #         sentence = s['sent'] 
    #         triples = s['triples']
    #         data = {"id": f"{cat_id}_valid_{idx}", "sent": sentence}
    #         data["triples"] = [{"sub": tr[0], "rel": tr[1], "obj": tr[2]} for tr in triples]                
    #         idx += 1
    #         out_file.write(f"{json.dumps(data)}\n")
                
    with open(f"data/test/{cat_id}_test.jsonl", "w") as test_file, open(f"data/ground_truth/{cat_id}_ground_truth.jsonl", "w") as gt_file:
        idx = 1
        for s in test_sents:
            sentence = s['sent'] 
            triples = s['triples']
            data = {"id": f"{cat_id}_test_{idx}",  "sent": sentence}
            test_file.write(f"{json.dumps(data)}\n")
            data["triples"] = [{"sub": tr[0], "rel": tr[1], "obj": tr[2]} for tr in triples]
            gt_file.write(f"{json.dumps(data)}\n")
            idx += 1

In [171]:
file_list = sorted(list_files_recursive("webnlg"), reverse=True)

In [145]:
categories = set()
cat_to_rels = dict()
cat_to_sents = dict()

In [149]:
entry_list = list()
seen_triples = set()
for file_path in file_list:
    entry_list += process_file(file_path, categories, cat_to_rels, cat_to_sents)

print(f"processed {len(entry_list)} entries.")

processed 4860 entries.


In [198]:
for idx, cat in enumerate(categories):
    onto_data = dict()
    onto_data["title"] = f"{cat} Ontology"
    onto_data["id"] = f"ont_{idx+1}_{cat.lower()}"
    onto_data["concepts"] = [{"qid": cat, "label": cat}] + [{"qid": "", "label": ""} for i in range(len(cat_to_rels[cat]))]
    relations = list()
    for rel in cat_to_rels[cat]:
        relations.append({"pid": rel, "label": rel, "domain": f"{cat}", "range": rel})
    onto_data["relations"] = relations
    file_name = onto_data["id"] + "_ontology.json"
    with open(f"ontology/{file_name}", "w") as out_file:
        json.dump(onto_data, out_file, indent=2)
        
    save_sentences(onto_data["id"], cat_to_sents[cat])

ont_1_university: 156
train: 85
test: 71

ont_2_musicalwork: 290
train: 81
test: 209

ont_3_airport: 306
train: 227
test: 79

ont_4_building: 275
train: 172
test: 103

ont_5_athlete: 293
train: 186
test: 107

ont_6_politician: 319
train: 184
test: 135

ont_7_company: 153
train: 97
test: 56

ont_8_celestialbody: 194
train: 122
test: 72

ont_9_astronaut: 154
train: 86
test: 68

ont_10_comicscharacter: 102
train: 66
test: 36

ont_11_meanoftransportation: 314
train: 222
test: 92

ont_12_monument: 92
train: 73
test: 19

ont_13_food: 398
train: 245
test: 153

ont_14_writtenwork: 322
train: 195
test: 127

ont_15_sportsteam: 235
train: 125
test: 110

ont_16_city: 348
train: 131
test: 217

ont_17_artist: 386
train: 302
test: 84

ont_18_scientist: 259
train: 110
test: 149

ont_19_film: 264
train: 137
test: 127

