In [18]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import torch
import rdflib

import mowl
mowl.init_jvm('10g')
from mowl.datasets import PathDataset
from mowl.projection import OWL2VecStarProjector
from mowl.projection.edge import Edge
from mowl.walking import DeepWalk
from mowl.kge import KGEModel

from gensim.models.word2vec import LineSentence
from gensim.models import Word2Vec

from src.gnn import *

In [7]:
file_name = 'family'

**Split ontology into train/test ontologies**

In [29]:
def split_ontology(file_name, train_ratio):
    g = rdflib.Graph()
    g.parse(f'datasets/{file_name}.owl')  
    print(f'Triplets found: %d' % len(g))

    individuals = list(g.subjects(predicate=None, object=None)) 
    random.shuffle(individuals) 

    split_index = int(train_ratio * len(individuals))

    train_individuals = individuals[:split_index]
    test_individuals = individuals[split_index:]

    train_graph = rdflib.Graph()
    test_graph = rdflib.Graph()

    for individual in train_individuals:
        for s, p, o in g.triples((individual, None, None)):
            train_graph.add((s, p, o))

    for individual in test_individuals:
        for s, p, o in g.triples((individual, None, None)):
            test_graph.add((s, p, o))

    train_graph.serialize(destination=f'datasets/{file_name}_train.owl')
    print(f'Train Triplets found: %d' % len(train_graph))
    test_graph.serialize(destination=f'datasets/{file_name}_test.owl')
    print(f'Test Triplets found: %d' % len(test_graph))
    
    return train_graph, test_graph

In [30]:
train_graph, test_graph = split_ontology(file_name=file_name, train_ratio=0.8)

Triplets found: 5017
Train Triplets found: 4939
Test Triplets found: 2812


**OWL2Vec**

In [15]:
def owl2vec_fit(file_name, embed_dim, load):
    dataset = PathDataset(ontology_path=f'datasets/{file_name}_train.owl',
                          testing_path=f'datasets/{file_name}_test.owl')
    if not load:
        projector = OWL2VecStarProjector(bidirectional_taxonomy=True)
        edges = projector.project(dataset.ontology)
        walker = DeepWalk(num_walks=20,
                          walk_length=20,
                          alpha=0.1,
                          workers=4)         
        walks = walker.walk(edges)
        sentences = LineSentence(walker.outfile)
        model = Word2Vec(sentences, vector_size=embed_dim, epochs=10, window=5, min_count=1, workers=4)
        model.save(f'models/owl2vec.model')
    else:
        model = Word2Vec.load(f'models/owl2vec.model')
    return model

In [41]:
owl2vec_model = owl2vec_fit(file_name='family', embed_dim=200, load=True)

INFO:gensim.utils:loading Word2Vec object from models/owl2vec.model
INFO:gensim.utils:loading wv recursively from models/owl2vec.model.wv.* with mmap=None
INFO:gensim.utils:setting ignored attribute cum_table to None
INFO:gensim.utils:Word2Vec lifecycle event {'fname': 'models/owl2vec.model', 'datetime': '2024-03-12T11:39:15.978886', 'gensim': '4.3.1', 'python': '3.8.16 (default, Jan 17 2023, 22:25:28) [MSC v.1916 64 bit (AMD64)]', 'platform': 'Windows-10-10.0.22631-SP0', 'event': 'loaded'}


**Eval**

In [36]:
vectors = owl2vec_model.wv
words = list(owl2vec_model.wv.key_to_index)
output_owl2vec = torch.tensor(vectors[words])

In [37]:
nodes = list(set(words))
nodes_dict = {node: i for i, node in enumerate(nodes)}
nodes_dict_rev = {value: key for key, value in nodes_dict.items()}

In [38]:
i=0
edge_data = defaultdict(list)
for s, p, o in test_graph.triples((None, None, None)):
    s = s.n3()
    s = s.replace('<','')
    s = s.replace('>','')
    o = o.n3()
    o = o.replace('<','')
    o = o.replace('>','')
    try:
        src, dst = nodes_dict[s], nodes_dict[o]
        edge_data['edge_index'].append([src, dst])
    except:
        i+=1
edge_index = torch.tensor(edge_data['edge_index'])

In [39]:
hits1, hits10 = eval_hits(edge_index=edge_index,
                          tail_pred=1,
                          output=output_owl2vec,
                          max_num=edge_index.size(1))
print(f'Hits@1: {hits1:.3f}, Hits@10: {hits10:.3f}')

Hits@1: 0.000, Hits@10: 1.000
