In [1]:
import inspect
import os
import random
import time

import numpy as np
import pandas as pd
import rdflib
import torch
from rdflib import URIRef
from sklearn.metrics import precision_score, recall_score, f1_score
from torch_geometric.data import HeteroData
from torch_geometric.nn import GATConv, Linear, to_hetero
from torch_geometric.utils import negative_sampling

from src.utils import *
from src.gnn import *
from src.sparql_queries import *

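# 1. Data

Load the family ontology (`datasets/family.owl`) with rdflib, index its nodes and relations, and convert it into a PyG graph object with train/test edge splits.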

In [2]:
# Load the family ontology. No format is passed, so rdflib has to infer the
# serialization itself (from the file extension / content).
g = rdflib.Graph()
g.parse('datasets/family.owl')

print(f'Triplets found: {len(g)}')

Triplets found: 5017


In [3]:
# Index every relation and every node (subjects ∪ objects) with a dense integer id.
# `list(set(...))` would order terms by hash, and str hashes (URIRef is a str) are
# randomized per Python process, so ids would change after every kernel restart and
# the checkpoint saved under models/ would no longer match this mapping.
# Sorting by the N3 serialization makes the order reproducible.
# NOTE(review): blank-node labels are typically regenerated on each parse, so the
# positions of blank nodes may still vary between sessions.
relations = sorted(set(g.predicates()), key=lambda term: term.n3())
nodes = sorted(set(g.subjects()) | set(g.objects()), key=lambda term: term.n3())

relations_dict = {rel: i for i, rel in enumerate(relations)}
nodes_dict = {node: i for i, node in enumerate(nodes)}
# Reverse lookup: integer id -> rdflib term.
nodes_dict_rev = dict(enumerate(nodes))

In [4]:
# Seed every RNG before the first potentially stochastic step (graph encoding,
# edge split, weight initialization, training) so that Restart & Run All
# reproduces the same split, model and scores.
# NOTE(review): assumes get_data / split_edges / GNN draw from Python's `random`,
# NumPy or torch RNGs — confirm in src/.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Encode the rdflib graph as a PyG data object using the index maps built above.
data = get_data(g, nodes, nodes_dict, relations_dict)

In [5]:
# Split the edges into positive train/validation/test sets (the displayed object
# gains `train_pos_edge_index`, `val_pos_edge_index` and `test_pos_edge_index`).
# `data` is rebound to the result of splitting the current `data`, so re-running
# this cell alone splits an already-split object — re-run the previous cell first.
# NOTE(review): the displayed split has an empty validation set, and train + test
# (2228 + 247) covers only about half of the 5017 edges — confirm this is what
# split_edges is meant to do.
data = split_edges(data)

In [6]:
# Inspect the encoded graph: all edges (`edge_index`, `edge_type`) plus the positive-edge splits.
data

HeteroData(
  edge_index=[2, 5017],
  edge_type=[5017],
  node_type=[1909, 1909],
  val_pos_edge_index=[2, 0],
  test_pos_edge_index=[2, 247],
  train_pos_edge_index=[2, 2228]
)

# 2. GNN

In [7]:
# Name of the GNN variant: passed to GNN._train / GNN._eval and used as the
# checkpoint file name under models/.
GNN_variant = 'GAT'

**Train**

In [8]:
print(f'{GNN_variant}:')
# perf_counter is monotonic and intended for measuring durations (time.time can jump).
st = time.perf_counter()
model = GNN()
# NOTE(review): the meaning of the 0.5 argument is defined by GNN._train in src.gnn —
# confirm and give it a named constant.
model._train(GNN_variant, data, nodes, 0.5)
# torch.save pickles the whole GNN object and does not create missing directories.
os.makedirs('models', exist_ok=True)
torch.save(model, f'models/{GNN_variant}')
et = time.perf_counter()
elapsed_time = et - st
print(f'Run time: {elapsed_time:.0f} seconds, {elapsed_time/60:.0f} minutes')

GAT:
Epoch: 0, Loss: 0.3927
Epoch: 300, Loss: 0.1706
Run time: 11 seconds, 0 minutes


**Eval**

In [9]:
print(f'{GNN_variant}:')
# The checkpoint is a fully pickled GNN object written by the training cell above.
# torch >= 2.6 defaults to weights_only=True, which refuses to unpickle arbitrary
# objects, so opt out explicitly where the argument exists. Unpickling can execute
# code: only load checkpoints you produced yourself.
model_path = f'models/{GNN_variant}'
if 'weights_only' in inspect.signature(torch.load).parameters:
    model = torch.load(model_path, weights_only=False)
else:
    model = torch.load(model_path)
model._eval(GNN_variant, data, 0.5)

GAT:
hits@1: 0.146, hits@10: 0.826


# 3. Generate New Links

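### GNN

Compute node embeddings with the trained GNN and score every node pair by cosine similarity. The `k` lowest-scoring pairs are then added to a copy of the graph as `hasFather`/`hasMother` triples, in both directions. Finally, the two SPARQL queries are run on the original and the augmented graph for comparison.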

In [10]:
# Inference only: put the underlying network in eval mode (disables dropout, if any
# is configured) and run the forward pass without building an autograd graph.
if isinstance(model.model, torch.nn.Module):
    model.model.eval()
with torch.no_grad():
    output = model.model(model.node_embeds, data.edge_index)

In [11]:
# Pairwise cosine similarity of node embeddings: S[i, j] = <h_i, h_j> / (|h_i| |h_j|).
link_pred_scores = torch.matmul(output, output.T)
# Clamp norms away from zero so an all-zero embedding yields similarity 0 instead
# of NaN (0/0), which would otherwise corrupt the top-k selection downstream.
# Values for non-zero embeddings are unchanged.
output_norm = torch.norm(output, dim=1, keepdim=True).clamp_min(1e-12)
link_pred_scores_norm = link_pred_scores / (output_norm * output_norm.T)

In [12]:
def add_triples_gnn(nodes_dict_rev, link_pred_scores_norm, g, k, largest=False):
    """Return a copy of ``g`` with family links added for ``k`` selected node pairs.

    The ``k`` entries of the 2-D score matrix with the lowest scores (or the
    highest, if ``largest=True``) are selected. For every selected entry
    (row -> node1, column -> node2), four triples are added to the copy:
    ``hasFather`` and ``hasMother``, each in both directions.

    Args:
        nodes_dict_rev: mapping from row/column index of the score matrix to the
            rdflib node it represents.
        link_pred_scores_norm: 2-D tensor of pairwise scores (in this notebook,
            cosine similarities of node embeddings).
        g: source rdflib graph; it is not modified (``copy_graph`` makes the copy).
        k: number of matrix entries to select; clamped to the number of entries,
            and ``k <= 0`` adds nothing.
            NOTE: for a symmetric score matrix, (i, j) and (j, i) have equal scores
            and are typically both selected, so only about ``k / 2`` distinct
            unordered pairs end up being linked.
        largest: False (default, original behaviour) selects the lowest-scoring
            entries; True selects the highest-scoring ones. For a cosine-similarity
            matrix, the diagonal (self-pairs, value 1) is then among the selected.

    Returns:
        The graph returned by ``copy_graph(g)``, with the new triples added.
    """
    # Loop-invariant predicates: build them once instead of once per pair.
    has_father = URIRef('http://www.co-ode.org/roberts/family-tree.owl#hasFather')
    has_mother = URIRef('http://www.co-ode.org/roberts/family-tree.owl#hasMother')

    new_g = copy_graph(g)

    # torch.topk raises if k exceeds the number of elements (or is negative).
    k = min(int(k), link_pred_scores_norm.numel())
    if k <= 0:
        return new_g

    _, topk_indices = torch.topk(link_pred_scores_norm.flatten(), k, largest=largest)
    # Map flat indices back to (row, column) coordinates of the 2-D score matrix.
    n_cols = link_pred_scores_norm.size(1)
    row_indices = torch.div(topk_indices, n_cols, rounding_mode='floor')
    col_indices = topk_indices % n_cols

    for row, col in zip(row_indices.tolist(), col_indices.tolist()):
        node1 = nodes_dict_rev[row]
        node2 = nodes_dict_rev[col]
        for predicate in (has_father, has_mother):
            new_g.add((node1, predicate, node2))
            new_g.add((node2, predicate, node1))

    return new_g

In [13]:
# Build the augmented graph: a copy of `g` plus links for k=100 selected score-matrix entries.
new_g = add_triples_gnn(
    nodes_dict_rev=nodes_dict_rev,
    link_pred_scores_norm=link_pred_scores_norm,
    g=g,
    k=100,
)

In [14]:
# Fetch the two SPARQL queries (presumably defined in src.sparql_queries) that are
# run below against both the original and the augmented graph.
query1, query2 = get_queries()

In [15]:
# Baseline: evaluate both queries against the unmodified graph.
for q in (query1, query2):
    print_result(g, q)

0
0


In [16]:
# Same two queries, now against the graph with the injected triples.
for q in (query1, query2):
    print_result(new_g, q)

100
100


In [None]:
# TODO: must be heterogeneous! `data` is a HeteroData object, but the model is fed a
# single node-embedding table and one flat `edge_index` (with relation ids only in
# `edge_type`), i.e. it is effectively used as a homogeneous graph. Per-relation
# edge types (e.g. via `to_hetero`) are still to be implemented.