In [1]:
import os
import random
import time

import numpy as np
import pandas as pd
import rdflib
import torch
from sklearn.metrics import precision_score, recall_score, f1_score
from tqdm import tqdm

from src.utils import *
from src.gnn import *
from src.sparql_queries import *

# 1. Data

In [2]:
# Load the family-tree ontology into an in-memory rdflib graph and report its size.
g = rdflib.Graph()
g.parse('datasets/family.owl')

print(f'Triplets found: {len(g)}')

Triplets found: 5017


In [3]:
# Integer indices for every relation (predicate) and every node (subject or
# object) of the graph, plus the reverse index -> term lookups.
# Sorting fixes the order: iterating a set of rdflib terms follows their string
# hashes, which Python randomizes per process, so `list(set(...))` assigned
# different indices after each kernel restart (which can invalidate a model
# reloaded from disk in a later session).
# NOTE(review): blank nodes are typically relabelled on every parse, so their
# indices can still differ between sessions.
def _term_sort_key(term):
    # Group by term kind first so URIRefs, BNodes and Literals never interleave.
    return (type(term).__name__, str(term))

relations = sorted(set(g.predicates()), key=_term_sort_key)
nodes = sorted(set(g.subjects()) | set(g.objects()), key=_term_sort_key)

relations_dict = {rel: i for i, rel in enumerate(relations)}
nodes_dict = {node: i for i, node in enumerate(nodes)}

nodes_dict_rev = dict(enumerate(nodes))
relations_dict_rev = dict(enumerate(relations))

In [4]:
# Seed every RNG the pipeline may draw from, so the edge split, the model
# initialisation and the random-link baseline are reproducible under
# Restart & Run All. (Assumes split_edges / GNN / negative_sampling use these
# generators — TODO confirm against src.utils / src.gnn.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Convert the rdflib graph to tensor form, then split its edges into train/val/test.
data = split_edges(get_data(g, nodes_dict, relations_dict))

In [5]:
# Show the sizes of the edge splits.
# NOTE(review): the validation split below is empty (val_pos_edge_index=[2, 0]),
# so any evaluation done during training presumably has to use the test edges —
# TODO confirm which split GNN._eval scores.
data

HeteroData(
  edge_index=[2, 5017],
  edge_type=[5017],
  val_pos_edge_index=[2, 0],
  val_edge_type=[0],
  test_pos_edge_index=[2, 1003],
  test_edge_type=[1003],
  train_pos_edge_index=[2, 4014],
  train_edge_type=[4014]
)

# 2. GNN

**Train**

In [6]:
# Train the R-GCN link predictor, logging loss and Hits@1/Hits@10 periodically,
# then save the whole model object to disk.
# NOTE(review): the logged loss stays at ~0.693 (= ln 2) for all epochs, which is
# what a binary cross-entropy model predicting 0.5 everywhere would report — the
# model may not be learning; TODO confirm the loss used in GNN._train.
N_EPOCHS = 300
EVAL_EVERY = 100

start = time.perf_counter()  # monotonic clock, suited to measuring durations
model = GNN()

for epoch in range(N_EPOCHS + 1):
    loss = model._train(data, len(nodes), len(relations))
    if epoch % EVAL_EVERY == 0:
        hits1, hits10 = model._eval(data)
        print(f'Epoch: {epoch}, Loss: {loss:.4f}, Hits@1: {hits1:.3f}, Hits@10: {hits10:.3f}')

# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)
torch.save(model, 'models/RGCN')
elapsed_time = time.perf_counter() - start
print(f'Run time: {elapsed_time:.0f} seconds, {elapsed_time/60:.0f} minutes')

Epoch: 0, Loss: 0.6932, Hits@1: 0.931, Hits@10: 0.962
Epoch: 100, Loss: 0.6931, Hits@1: 0.931, Hits@10: 0.966
Epoch: 200, Loss: 0.6932, Hits@1: 0.932, Hits@10: 0.960
Epoch: 300, Loss: 0.6930, Hits@1: 0.935, Hits@10: 0.968
Run time: 445 seconds, 7 minutes


**Eval**

In [7]:
# Reload the trained model from disk and re-evaluate it.
# The file holds a whole pickled model object (written with torch.save(model, ...)),
# not a state_dict, so it must be loaded with weights_only=False: torch>=2.6
# defaults to weights_only=True and refuses to unpickle it. Only load files
# produced by this notebook — unpickling can execute arbitrary code.
model = torch.load('models/RGCN', weights_only=False)
hits1, hits10 = model._eval(data)
print(f'Hits@1: {hits1:.3f}, Hits@10: {hits10:.3f}')

Hits@1: 0.935, Hits@10: 0.968


# 3. Generate New Links

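### GNN: for each relation, add the candidate links with the lowest predicted score (cosine similarity of the two nodes' embeddings) to the ontology, i.e. the links the model finds least plausible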

In [6]:
def add_triples_gnn(g, data, k):
    """Return a copy of ``g`` augmented with the links the GNN scores lowest.

    For every relation type, node embeddings are computed from that
    relation's edges only (``model.model.encode``), the cosine similarity of
    every node pair is formed, and the ``2 * k`` lowest-scoring pairs are
    selected. Only pairs with ``row < col`` are kept (the similarity matrix is
    symmetric), so on the order of ``k`` (at most ``2 * k``) new triples are
    added per relation type.

    Relies on notebook globals ``model``, ``relations``, ``nodes_dict_rev``
    and ``relations_dict_rev``.

    Args:
        g: source rdflib graph; it is copied via ``copy_graph``, not modified.
        data: graph data exposing ``edge_index`` (2 x E) and ``edge_type`` (E,).
        k: target number of new triples per relation type.

    Returns:
        The augmented copy of ``g``.
    """
    new_g_gnn = copy_graph(g)
    for etype in tqdm(range(len(relations))):
        mask = data.edge_type == etype
        # Slice directly so the indices keep data.edge_index's dtype/device;
        # torch.tensor([[], []]) would produce a float tensor for an empty mask.
        edge_index = data.edge_index[:, mask]
        edge_type = data.edge_type[mask]

        # Inference only: avoid building an autograd graph for the N x N scores.
        with torch.no_grad():
            output = model.model.encode(edge_index, edge_type)

            # Pairwise cosine similarity; clamp the norms so an all-zero
            # embedding cannot turn a 0 / 0 into NaN.
            output_norm = torch.norm(output, dim=1, keepdim=True).clamp_min(1e-12)
            link_pred_scores_norm = torch.matmul(output, output.T) / (output_norm * output_norm.T)
            # Existing edges of this relation get the maximum score so they are never selected.
            link_pred_scores_norm[edge_index[0, :], edge_index[1, :]] = 1

        # Indices of the 2k smallest scores; torch.topk raises if asked for more
        # elements than the matrix holds, so clamp the count.
        flat_scores = link_pred_scores_norm.flatten()
        n_candidates = min(2 * k, flat_scores.numel())
        _, topk_indices = torch.topk(flat_scores, n_candidates, largest=False)
        n_cols = link_pred_scores_norm.size(1)
        row_indices = topk_indices // n_cols
        col_indices = topk_indices % n_cols

        # Keep only the strict upper triangle (row < col): drops the mirrored
        # copy of each symmetric pair as well as self-pairs on the diagonal.
        upper = row_indices < col_indices
        row_indices = row_indices[upper]
        col_indices = col_indices[upper]

        # Map indices back to rdflib terms and add the generated triples.
        node1_lst = [nodes_dict_rev[key] for key in row_indices.tolist()]
        node2_lst = [nodes_dict_rev[key] for key in col_indices.tolist()]
        edge_type_uri = relations_dict_rev[etype]
        new_g_gnn = add_links(new_g_gnn, node1_lst, node2_lst, edge_type_uri)

    return new_g_gnn

### Random baseline: for each relation, add randomly sampled node pairs (negative samples of the existing edges) to the ontology

In [7]:
def add_triples_random(g, data, k):
    """Return a copy of ``g`` augmented with randomly sampled links (baseline).

    For every relation type, ``k`` node pairs are requested from
    ``negative_sampling`` over ``data.edge_index`` (presumably pairs that are
    not existing edges — TODO confirm the sampler's contract) and added as
    triples of that relation.

    Relies on notebook globals ``relations``, ``nodes_dict_rev`` and
    ``relations_dict_rev``.

    Args:
        g: source rdflib graph; it is copied via ``copy_graph``, not modified.
        data: graph data exposing ``edge_index`` (2 x E).
        k: number of pairs requested per relation type.

    Returns:
        The augmented copy of ``g``.
    """
    new_g_random = copy_graph(g)
    for etype in tqdm(range(len(relations))):
        # NOTE(review): sampling is done against the edges of *all* relations,
        # whereas add_triples_gnn only masks edges of `etype`; confirm this
        # asymmetry between the two strategies is intended.
        neg_edge_index = negative_sampling(data.edge_index, num_neg_samples=k)

        # Map indices back to rdflib terms and add the generated triples.
        node1_lst = [nodes_dict_rev[key] for key in neg_edge_index[0, :].tolist()]
        node2_lst = [nodes_dict_rev[key] for key in neg_edge_index[1, :].tolist()]
        edge_type_uri = relations_dict_rev[etype]
        new_g_random = add_links(new_g_random, node1_lst, node2_lst, edge_type_uri)

    return new_g_random

# 4. Experiments

In [8]:
# Fetch the SPARQL queries used to count contradictions and bind each to its own name.
contradiction_queries = get_queries()
query1, query2, query3 = contradiction_queries

In [None]:
# Add up to k candidate triples per relation type with each strategy, then run
# every contradiction query against both augmented graphs.
k = 100000
# The saved file is a whole pickled model object, so weights_only=False is
# required (torch>=2.6 defaults to True). add_triples_gnn reads this global.
model = torch.load('models/RGCN', weights_only=False)

new_g_gnn = add_triples_gnn(g, data, k)
new_g_random = add_triples_random(g, data, k)

for augmented_graph in (new_g_gnn, new_g_random):
    print(f'Triplets found: {len(augmented_graph)}')
    print('Contradictions:')
    for q in (query1, query2, query3):
        print_result(augmented_graph, q)

###########################################################################################################

query = """
PREFIX owl: <http://www.w3.org/2002/07/owl#>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX fo: <http://www.co-ode.org/roberts/family-tree.owl#>
SELECT ?personA ?personB WHERE {
 
 ?personA rdf:type owl:NamedIndividual .
 ?personB rdf:type owl:NamedIndividual .

 { ?personA fo:hasSister ?personB }
 UNION
 { ?personA fo:hasBrother ?personB } .

 { ?personA fo:hasMother ?personB }
 UNION
 { ?personA fo:hasFather ?personB } .

 FILTER (?personA != ?personB)
}
"""

qres = new_g_gnn.query(query)
for row in qres:
    print(f"{row.personA}, {row.personB}")

# Sanity check for the sibling/parent contradiction query: inject one known
# contradiction into a *copy* of the ontology (adding it to `g` itself would
# leak into every experiment re-run later in this kernel) and list the matches.
# TODO(review): extend the check to hasMalePartner / hasFemalePartner (these
# names previously appeared here as a bare, undefined-name expression).
fo = rdflib.Namespace('http://www.co-ode.org/roberts/family-tree.owl#')
james_bright_1809 = fo['james_bright_1809']
elisa_amelia_hewett_1858 = fo['elisa_amelia_hewett_1858']
hasSister = fo['hasSister']
hasMother = fo['hasMother']

g_check = copy_graph(g)
g_check.add((james_bright_1809, hasSister, elisa_amelia_hewett_1858))
g_check.add((james_bright_1809, hasMother, elisa_amelia_hewett_1858))

# The query selects ?personA and ?personB (there is no ?contradictions variable).
qres = g_check.query(query)
for row in qres:
    print(f"{row.personA}, {row.personB}")