In [83]:
import logging
import os
from glob import glob

import pandas as pd
import rdflib
from owlready2 import get_ontology
from rdflib import URIRef, OWL, Literal, RDF, RDFS, BNode
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from src.utils import *
from src.noise import *

In [84]:
# Dataset to process; every input/output path in this notebook is derived from it.
# Another dataset this notebook has been used with: 'family'.
dataset_name = "OWL2DL-1"

In [85]:
# Load the source ontology of the selected dataset with owlready2.
# NOTE(review): `ontology` is not referenced by any of the cells shown here — confirm it is still needed.
ontology = get_ontology(f'datasets/{dataset_name}.owl').load()

**Create test and validation sets**

In [86]:
def get_graph_type(s: str) -> str:
    """Return the graph type encoded in a graph file path.

    The file name is taken after the last path separator (either a forward slash or
    a backslash, so POSIX and Windows paths behave the same), split on '_', and every
    part except the last one is re-joined with '_'.

    Example: 'dir/Person_Student_12.ttl' -> 'Person_Student'.
    A file name without any '_' yields ''.
    """
    # Normalise Windows separators first; splitting on the backslash alone left the
    # whole directory prefix in the result for paths produced by glob on POSIX.
    file_name = s.replace('\\', '/').rsplit('/', 1)[-1]
    # Drop the trailing '_<suffix>' part (e.g. an index plus extension).
    return "_".join(file_name.split('_')[:-1])

In [87]:
def get_files_df(INPUT_GRAPHS_FOLDER, INFERENCE_GRAPHS_FOLDER):
    """Pair every input graph file with the inference file of the same name.

    Args:
        INPUT_GRAPHS_FOLDER: directory whose entries are the input graph files.
        INFERENCE_GRAPHS_FOLDER: directory expected to hold, under the same file name,
            the inference graph of each input file. Existence is not checked here.

    Returns:
        DataFrame with columns 'input_graph_file', 'inference_file' and 'graph_type',
        one row per entry of INPUT_GRAPHS_FOLDER in sorted path order. When the folder
        is empty the frame is empty but still has these three columns.
    """
    # Log the folder instead of reading the notebook-global `dataset_name`, so the
    # function does not depend on hidden kernel state.
    logging.info("Creating dataframe for input/inference pairs from %s", INPUT_GRAPHS_FOLDER)
    rdf_files = []
    # os.path.join works whether or not the folder arguments end with a separator.
    for input_graph_path in tqdm(sorted(glob(os.path.join(INPUT_GRAPHS_FOLDER, "*")))):
        input_graph_file = os.path.basename(input_graph_path)
        rdf_files.append({
            "input_graph_file": input_graph_path,
            "inference_file": os.path.join(INFERENCE_GRAPHS_FOLDER, input_graph_file),
            "graph_type": get_graph_type(input_graph_path),
        })
    # Explicit columns keep the schema stable even when no file was found.
    return pd.DataFrame(rdf_files, columns=["input_graph_file", "inference_file", "graph_type"])

In [None]:
# One row per input graph, paired with its inference graph and graph type.
files_df = get_files_df(
    INPUT_GRAPHS_FOLDER=f'datasets/{dataset_name}_input_graphs_filtered_1hop/',
    INFERENCE_GRAPHS_FOLDER=f'datasets/{dataset_name}_inferred_graphs_filtered/',
)

In [89]:
# Remove graph types that have a single file: the stratified split below needs at
# least two files per graph type. The value_counts() Series is thresholded directly,
# so this does not depend on the name pandas gives the counts column
# ('count' since pandas 2.0, the source column name before that).
graph_type_counts = files_df['graph_type'].value_counts()
df_count = pd.DataFrame(graph_type_counts)  # kept for inspection, as before
graph_type_2_keep = graph_type_counts[graph_type_counts > 1].index
files_df = files_df[files_df['graph_type'].isin(graph_type_2_keep)]

In [90]:
def test_val_split(df, test_percent, stratify, seed):
    """Split `df` into a test part and a validation part.

    Args:
        df: DataFrame to split.
        test_percent: share of rows assigned to the returned test set, passed to
            `train_test_split` as `test_size` (a float fraction or an int row count).
        stratify: name of the column whose value distribution is preserved in both
            parts, or None for an unstratified split.
        seed: `random_state` for `train_test_split`, making the split reproducible.

    Returns:
        (df_test, df_val), where df_test holds the `test_percent` share of the rows.
    """
    stratify_labels = None if stratify is None else df[stratify]
    # train_test_split returns (complement, test_size part). Unpacking it as
    # (test, val) used to make `test_percent` size the *validation* set instead.
    df_val, df_test = train_test_split(
        df, test_size=test_percent, random_state=seed, stratify=stratify_labels
    )
    return df_test, df_val

In [91]:
# Half of the files go to each split, stratified by graph type, with a fixed seed.
rdf_data_test, rdf_data_val = test_val_split(
    df=files_df, test_percent=0.5, stratify="graph_type", seed=1
)

In [92]:
def merge_nt_files(nt_files, output_file):
    """Parse several RDF files and write the union of their triples to one Turtle file.

    Files whose path contains 'TBOX' are parsed without an explicit format (rdflib
    chooses the parser). All other files are parsed as Turtle, which also accepts
    N-Triples. A file that cannot be read or parsed is reported and skipped, so the
    merged output may be missing some inputs.

    Args:
        nt_files: iterable of file paths (str or path-like, e.g. a pandas Series).
        output_file: destination path of the merged graph, written as Turtle.
    """
    merged_graph = rdflib.Graph()
    n_failed = 0

    for nt_file in nt_files:
        try:
            # Parse into a fresh graph so a file that fails mid-parse adds no partial triples.
            graph = rdflib.Graph()
            if 'TBOX' in str(nt_file):
                graph.parse(nt_file)
            else:
                graph.parse(nt_file, format="turtle")
            merged_graph += graph
        except Exception as e:
            # Deliberate best effort: one unreadable file must not abort the whole merge.
            n_failed += 1
            print(f"Warning: Could not parse {nt_file} - {e}")

    if n_failed:
        print(f"Warning: {n_failed} file(s) could not be parsed and were skipped")

    # The merged files are read back with format='turtle', so serialize as Turtle
    # explicitly instead of relying on rdflib's version-dependent default format.
    merged_graph.serialize(destination=output_file, format="turtle")
    print(f"Merged file created at {output_file}")

In [None]:
# Merge the inference graphs of each split into a single "complete" graph file.
for inference_files, merged_path in [
    (rdf_data_test['inference_file'], f'datasets/{dataset_name}_test_complete.owl'),
    (rdf_data_val['inference_file'], f'datasets/{dataset_name}_val_complete.owl'),
]:
    merge_nt_files(inference_files, merged_path)

**Manage duplicates and drop blank nodes (BNodes)**

In [None]:
# Training graph: the original ontology (no format given, rdflib picks the parser).
G_train = rdflib.Graph().parse(f'datasets/{dataset_name}.owl')
print(f'# triples in G_train: {len(G_train)}')

# Held-out graphs: the merged inference graphs of each split, parsed as Turtle.
G_test = rdflib.Graph().parse(f'datasets/{dataset_name}_test_complete.owl', format='turtle')
print(f'# triples in G_test: {len(G_test)}')

G_val = rdflib.Graph().parse(f'datasets/{dataset_name}_val_complete.owl', format='turtle')
print(f'# triples in G_val: {len(G_val)}')

# Terminological part of the ontology (no format given, rdflib picks the parser).
G_tbox = rdflib.Graph().parse(f'datasets/{dataset_name}_TBOX.owl')
print(f'# triples in G_tbox: {len(G_tbox)}')

In [95]:
# Copy into both held-out graphs every training triple whose object is
# owl:NamedIndividual (whatever its predicate).
for individual_declaration in G_train.triples((None, None, OWL.NamedIndividual)):
    for held_out_graph in (G_test, G_val):
        held_out_graph.add(individual_declaration)

In [96]:
# Both held-out graphs also receive the full TBox (in-place union).
for held_out_graph in (G_test, G_val):
    held_out_graph += G_tbox

In [97]:
def remove_bnodes(graph):
    """Return a new graph with only the triples of `graph` that contain no blank node.

    A triple is dropped when its subject, predicate or object is a BNode. The input
    graph is not modified, and its namespace bindings are not copied.
    """
    bnode_free = rdflib.Graph()
    for triple in graph:
        if any(isinstance(term, BNode) for term in triple):
            continue
        bnode_free.add(triple)
    return bnode_free

In [98]:
# Blank-node-free copies of the three graphs.
filtered_G_train, filtered_G_test, filtered_G_val = (
    remove_bnodes(graph) for graph in (G_train, G_test, G_val)
)

In [99]:
def get_entities(graph):
    """Collect the declared classes, the URI predicates and the declared named individuals.

    Returns:
        A tuple (classes, relations, individuals) of sets:
        - classes: every node x such that (x, rdf:type, owl:Class) is in `graph`;
        - relations: every predicate of `graph` that is a URIRef;
        - individuals: every node x such that (x, rdf:type, owl:NamedIndividual) is in `graph`.
    """
    # A node qualifies only through its own rdf:type declaration, and that declaration
    # triple has the node as subject, so querying the declarations directly gives the
    # same sets as scanning every triple's subject and object.
    classes = set(graph.subjects(RDF.type, OWL.Class))
    individuals = set(graph.subjects(RDF.type, OWL.NamedIndividual))
    relations = {predicate for predicate in graph.predicates() if isinstance(predicate, URIRef)}
    return classes, relations, individuals

In [100]:
def remove_missing_entities(train_graph, graph_to_modify):
    """Remove, in place, the triples of `graph_to_modify` that mention entities unknown to training.

    Entities come from `get_entities` on both graphs. A triple is removed when its
    subject or object is a class or named individual found in `graph_to_modify` but
    not in `train_graph`, or when its predicate is a relation not used in `train_graph`.

    Returns:
        `graph_to_modify` itself (the same, mutated object).
    """
    known_classes, known_relations, known_individuals = get_entities(train_graph)
    found_classes, found_relations, found_individuals = get_entities(graph_to_modify)

    unknown_nodes = (found_classes - known_classes) | (found_individuals - known_individuals)
    unknown_relations = found_relations - known_relations

    # Collect first, then remove: the graph must not change while it is being iterated.
    doomed = [
        (subj, pred, obj)
        for subj, pred, obj in graph_to_modify
        if subj in unknown_nodes or obj in unknown_nodes or pred in unknown_relations
    ]
    for triple in doomed:
        graph_to_modify.remove(triple)

    return graph_to_modify

In [101]:
# Held-out graphs can mention classes, individuals or relations produced only by
# inference; drop every triple that refers to one the training graph does not know.
filtered_G_test, filtered_G_val = (
    remove_missing_entities(filtered_G_train, held_out)
    for held_out in (filtered_G_test, filtered_G_val)
)

In [None]:
for message in (
    f'# triples in G_train: {len(filtered_G_train)}',
    f'# triples in G_test: {len(filtered_G_test)}',
    f'# triples in G_val: {len(filtered_G_val)}',
):
    print(message)

In [None]:
# Unsupervised setting: the validation triples are folded into the training graph.
# Validation triples that also occur in the test graph are excluded first, so that
# the validation and test sets do not overlap.
test_triples = set(filtered_G_test)
new_filtered_G_val = rdflib.Graph()
for val_triple in filtered_G_val:
    if val_triple not in test_triples:
        new_filtered_G_val.add(val_triple)

filtered_G_train += new_filtered_G_val
print(f'# triples in G_train: {len(filtered_G_train)}')
print(f'# triples in G_test: {len(filtered_G_test)}')
print(f'# triples in G_val: {len(new_filtered_G_val)}')

In [None]:
# Write the final train / test / val graphs (serialization format is rdflib's default).
for output_graph, destination in (
    (filtered_G_train, f'datasets/{dataset_name}_train.owl'),
    (filtered_G_test, f'datasets/{dataset_name}_test.owl'),
    (new_filtered_G_val, f'datasets/{dataset_name}_val.owl'),
):
    output_graph.serialize(destination)