In [None]:
import mowl
# Start mowl's JVM bridge first: later cells import Java classes (org.semanticweb.owlapi...)
# through it. "10g" is presumably the JVM max heap size — TODO confirm against mowl.init_jvm.
mowl.init_jvm("10g")

In [None]:
from mowl.datasets.base import PathDataset
# OWL files of the two ontologies being aligned (source = mouse, target = human).
# Paths are relative to the notebook's working directory.
source_owl = "../data/mouse.owl"
target_owl = "../data/human.owl"


In [None]:
from mowl.reasoning.normalize import GCI2
from org.semanticweb.owlapi.apibinding import OWLManager
from org.semanticweb.owlapi.model import IRI
from org.semanticweb.owlapi.search import EntitySearcher

# OWLAPI data factory used by `createExAxiom` to build classes, properties and axioms.
manager = OWLManager()
factory = manager.getOWLDataFactory()

# IRI namespace for the synthetic entities (label classes and the `has_annot_label` property).
prefix = "http://mowl/"

def createExAxiom(subc, prop, filler):
    """Build ``subc SubClassOf (has_annot_label some <prefix><filler>)`` as a GCI2.

    Args:
        subc: IRI (string-convertible) of the annotated class.
        prop: Accepted for call compatibility but NOT used. Every axiom uses the
            fixed property ``<prefix>has_annot_label``.
        filler: Local name appended to ``prefix`` to form the filler class IRI.

    Returns:
        The OWL ``SubClassOf`` axiom wrapped in mowl's ``GCI2``.
    """
    # `prop` is ignored on purpose-or-not: the relation index built in
    # ELBoxModel.load_data only registers "http://mowl/has_annot_label".
    annotated_cls = factory.getOWLClass(IRI.create(f"{subc}"))
    label_cls = factory.getOWLClass(IRI.create(f"{prefix}{filler}"))
    has_label = factory.getOWLObjectProperty(IRI.create(f"{prefix}has_annot_label"))

    existential = factory.getOWLObjectSomeValuesFrom(has_label, label_cls)
    return GCI2(factory.getOWLSubClassOfAxiom(annotated_cls, existential))
           
def format_value(value):
    """Turn an annotation label into an IRI-friendly local name.

    Lower-cases the text and replaces each space character with an underscore,
    e.g. ``"Left Lung"`` -> ``"left_lung"``. Other whitespace is left untouched.
    """
    return value.lower().replace(" ", "_")

def getAnnotationsAsAxioms(ont_path_file):
    """Convert the literal label annotations of every class into GCI2 axioms.

    For each class in the ontology's signature, each annotation whose property
    ``isLabel()`` and whose value is a literal becomes
    ``class SubClassOf has_annot_label some <normalized label>`` (see
    ``createExAxiom`` / ``format_value``).

    Args:
        ont_path_file: Path to the OWL file, loaded through ``PathDataset``.

    Returns:
        List of ``GCI2`` axioms, grouped by class in signature iteration order.
    """
    ontology = PathDataset(ont_path_file).ontology
    label_axioms = []

    for owl_class in ontology.getClassesInSignature():
        class_iri = owl_class.toStringID()
        for annotation in EntitySearcher.getAnnotations(owl_class, ontology):
            annotation_value = annotation.getValue()
            # Only literal-valued label annotations are turned into axioms.
            if not (annotation_value.isLiteral() and annotation.getProperty().isLabel()):
                continue
            label_text = str(annotation_value.asLiteral().get().getLiteral())
            label_axioms.append(
                createExAxiom(class_iri,
                              annotation.getProperty().toString(),
                              format_value(label_text)))

    return label_axioms

In [None]:
import torch.nn as nn
import torch as th
from mowl.reasoning.normalize import ELNormalizer
from normalizer import Normalizer
from mowl.datasets.el import GCI0Dataset, GCI1Dataset, GCI2Dataset, GCI3Dataset
from mowl.models.elboxembeddings.module import ELBoxModule
from tqdm import trange
import numpy as np

class ELBoxBasedModule(nn.Module):
    """Two ELBox embedding modules (source and target ontology) plus an alignment head.

    Args:
        num_classes_src: Number of class ids for the source ELBox module.
        num_rels_src: Number of relation ids for the source ELBox module.
        num_classes_tar: Number of class ids for the target ELBox module.
        num_rels_tar: Number of relation ids for the target ELBox module.
        emb_size: Embedding dimension shared by both ELBox modules.
    """

    def __init__(self, num_classes_src, num_rels_src, num_classes_tar, num_rels_tar, emb_size):
        super().__init__()
        self.num_classes_src = num_classes_src
        self.num_classes_tar = num_classes_tar
        self.num_rels_src = num_rels_src
        self.num_rels_tar = num_rels_tar
        self.emb_size = emb_size

        self.elbox_src = ELBoxModule(
            self.num_classes_src,
            self.num_rels_src,
            embed_dim = emb_size
        )

        self.elbox_tar = ELBoxModule(
            self.num_classes_tar,
            self.num_rels_tar,
            embed_dim = emb_size
        )

        # Alignment head: maps a concatenated (source, target) embedding pair of width
        # 2*emb_size to emb_size values in (0, 1) (final Sigmoid).
        self.fc = nn.Sequential(
            nn.Linear(2*emb_size, emb_size),
            nn.ReLU(),
            nn.Linear(emb_size, emb_size),
            nn.Sigmoid()
        )

    def forward(self, data, model = None, gci_name = None, align=False, neg = False):
        """Dispatch ``data`` to one of the sub-networks.

        Args:
            data: Input batch. Index tensors for the ELBox modules; concatenated
                embeddings (last dim ``2*emb_size``) when ``align`` is True.
            model: "src" or "tar" to run the corresponding ELBox module.
            gci_name: GCI type name forwarded to the ELBox module.
            align: When ``model`` is neither "src" nor "tar", run the alignment head.
            neg: Forwarded to the ELBox module (negative-sample scoring).

        Returns:
            The selected sub-network's output.

        Raises:
            ValueError: If no sub-network is selected (``model`` not "src"/"tar"
                and ``align`` is False), instead of silently returning ``None``.
        """
        if model == "src":
            return self.elbox_src(data, gci_name, neg = neg)
        if model == "tar":
            return self.elbox_tar(data, gci_name, neg = neg)
        if align:
            return self.fc(data)
        raise ValueError(
            f"forward() needs model='src', model='tar' or align=True; "
            f"got model={model!r}, align={align!r}"
        )
        
class ELBoxModel():
    """Joint ELBox embedding of two ontologies, tied together through shared labels.

    Each ontology is normalized into EL GCIs and enriched with synthetic
    ``class SubClassOf has_annot_label some <label>`` axioms (from
    ``getAnnotationsAsAxioms``). The result is embedded by its own ELBox module.
    Label classes that occur in both ontologies are fed, as concatenated
    (source, target) embedding pairs, through the module's alignment head.

    Args:
        source_ont_path: Path to the source OWL file.
        target_ont_path: Path to the target OWL file.
        emb_size: Embedding dimension.
        device: Torch device for the module and every training tensor.
        lr: Adam learning rate.
        epochs: Number of training iterations. Each iteration samples one batch
            per GCI dataset.
    """

    # Relation IRI that `createExAxiom` uses for every label axiom.
    ANNOT_REL = "http://mowl/has_annot_label"

    def __init__(self, source_ont_path, target_ont_path, emb_size = 50, device = "cpu", lr = 0.001, epochs = 1000):
        self.device = device
        self.lr = lr
        self.epochs = epochs
        self.source_ont_path = source_ont_path
        self.target_ont_path = target_ont_path

        self.source_ds = PathDataset(source_ont_path)
        self.target_ds = PathDataset(target_ont_path)

        self.model_filepath = "elboxbased.th"

        self.load_data()
        # Size each embedding table from the id map that indexes it, so every id is in
        # range. Put the module on the same device as the tensors built in load_data.
        self.module = ELBoxBasedModule(
            len(self.class_to_id_src), len(self.rel_to_id_src),
            len(self.class_to_id_tar), len(self.rel_to_id_tar),
            emb_size,
        ).to(self.device)

    @staticmethod
    def _build_class_index(ont_classes, label_classes):
        """Map class IRIs to dense ids: ontology classes first, then label classes.

        Label classes are sorted so their ids are reproducible across kernel restarts.
        Set iteration order of strings is not reproducible. A duplicate keeps its first
        position, so ids stay contiguous and the dict's key order equals id order.
        """
        vocabulary = dict.fromkeys(list(ont_classes) + sorted(label_classes, key=str))
        return {cls: idx for idx, cls in enumerate(vocabulary)}

    def _build_gci_datasets(self, gcis, class_to_id, rel_to_id):
        """Wrap each normalized GCI list in the matching mowl dataset (missing key -> empty).

        Each ``*_bot`` dataset is built from its own ``*_bot`` axiom list. Feeding it the
        ordinary list would train e.g. every ``C and D SubClassOf E`` axiom as disjointness.
        """
        # NOTE(review): bottom axioms presumably reference owl:Nothing, which must then be
        # present in `class_to_id` — confirm for the ontologies used here.
        dataset_types = [
            ("gci0", GCI0Dataset),
            ("gci1", GCI1Dataset),
            ("gci2", GCI2Dataset),
            ("gci3", GCI3Dataset),
            ("gci0_bot", GCI0Dataset),
            ("gci1_bot", GCI1Dataset),
            ("gci3_bot", GCI3Dataset),
        ]
        return {
            name: dataset_cls(gcis.get(name, []), class_to_id, rel_to_id, device = self.device)
            for name, dataset_cls in dataset_types
        }

    def load_data(self):
        """Build the GCI datasets, id maps and shared-label index tensors for both ontologies.

        Sets the following attributes:

        - ``src_set`` / ``tar_set``: label-class IRIs of each ontology.
        - ``common_annots``: their intersection.
        - ``source_el_gcis`` / ``target_el_gcis``: normalized GCIs, label axioms included.
        - ``class_to_id_*`` / ``rel_to_id_*``: id maps.
        - ``*_el_datasets``: dataset dicts keyed by GCI name.
        - ``src_annot_idxs`` / ``tar_annot_idxs``: row-aligned ids of every shared label
          class in the two embedding spaces.
        """
        annots_ax_src = getAnnotationsAsAxioms(self.source_ont_path)
        annots_ax_tar = getAnnotationsAsAxioms(self.target_ont_path)

        self.src_set = {x.filler for x in annots_ax_src}
        self.tar_set = {x.filler for x in annots_ax_tar}
        self.common_annots = self.src_set & self.tar_set

        normalizer = Normalizer()
        self.source_el_gcis = normalizer.normalize(self.source_ds.ontology)
        self.target_el_gcis = normalizer.normalize(self.target_ds.ontology)

        # Label axioms have the shape C SubClassOf (R some D), i.e. GCI2.
        self.source_el_gcis["gci2"] += annots_ax_src
        self.target_el_gcis["gci2"] += annots_ax_tar

        # Source
        self.class_to_id_src = self._build_class_index(self.source_ds.classes, self.src_set)
        self.rel_to_id_src = {v: k for k, v in enumerate(self.source_ds.object_properties)}
        self.rel_to_id_src[self.ANNOT_REL] = len(self.rel_to_id_src)
        self.source_el_datasets = self._build_gci_datasets(
            self.source_el_gcis, self.class_to_id_src, self.rel_to_id_src)

        # Target
        self.class_to_id_tar = self._build_class_index(self.target_ds.classes, self.tar_set)
        self.rel_to_id_tar = {v: k for k, v in enumerate(self.target_ds.object_properties)}
        self.rel_to_id_tar[self.ANNOT_REL] = len(self.rel_to_id_tar)
        self.target_el_datasets = self._build_gci_datasets(
            self.target_el_gcis, self.class_to_id_tar, self.rel_to_id_tar)

        # Pair up the ids of every label class present in both ontologies. Sorting keeps
        # the pairing order reproducible. dtype=long keeps an empty list usable as an
        # embedding index.
        shared = sorted(self.common_annots, key=str)
        self.src_annot_idxs = th.tensor([self.class_to_id_src[a] for a in shared],
                                        dtype = th.long, device = self.device)
        self.tar_annot_idxs = th.tensor([self.class_to_id_tar[a] for a in shared],
                                        dtype = th.long, device = self.device)

    def get_embeddings(self):
        """Return ``(src_embs, tar_embs)``: dicts mapping class IRI -> numpy embedding row.

        The id maps' key order equals their ids, so zipping the keys with the embedding
        rows pairs each class with its own row. Label classes are included.
        """
        src_weights = self.module.elbox_src.class_embed.weight.detach().cpu().numpy()
        tar_weights = self.module.elbox_tar.class_embed.weight.detach().cpu().numpy()

        src_embs = dict(zip(self.class_to_id_src.keys(), src_weights))
        tar_embs = dict(zip(self.class_to_id_tar.keys(), tar_weights))
        return src_embs, tar_embs

    def _el_loss(self, datasets, model_name, num_classes, criterion, batch_size):
        """Sum positive and corrupted-negative losses over one ontology's GCI datasets.

        Empty datasets and ``gci0_bot`` are skipped. For every other dataset, a batch is
        sampled with replacement. Positive scores are pushed towards 0. Negatives are
        the same batch with its last column replaced by random class ids, and their
        scores are pushed towards 1.
        """
        total = th.zeros((), device = self.device)
        for gci_name, gci_dataset in datasets.items():
            if len(gci_dataset) == 0 or gci_name == "gci0_bot":
                continue
            # Positive scores
            rand_index = np.random.choice(len(gci_dataset), size = batch_size)
            data = gci_dataset[rand_index]
            dst = self.module(data, model = model_name, gci_name = gci_name)
            total = total + criterion(dst, th.zeros_like(dst))

            # Negative scores
            neg_ids = np.random.choice(num_classes, size = batch_size, replace = True)
            neg_ids = th.tensor(neg_ids, device = self.device)
            neg_data = th.cat([data[:, :-1], neg_ids.unsqueeze(1)], dim = 1)
            dst = self.module(neg_data, model = model_name, gci_name = gci_name, neg = True)
            total = total + criterion(dst, th.ones_like(dst))
        return total

    def train(self, batch_size=512, checkpoint=100):
        """Train both ELBox modules and the alignment head.

        Args:
            batch_size: Number of axioms sampled (with replacement) per GCI dataset
                per epoch.
            checkpoint: Every ``checkpoint`` epochs, the losses are printed. The module's
                state dict is written to ``self.model_filepath`` if the total loss is
                the lowest seen at a checkpoint so far.
        """
        el_box_criterion = nn.MSELoss()
        alignment_criterion = nn.BCELoss()
        optimizer = th.optim.Adam(self.module.parameters(), lr=self.lr)

        best_loss = float('inf')

        for epoch in trange(self.epochs):
            self.module.train()

            # Source space first, then target space (same sampling order as before).
            src_loss = self._el_loss(self.source_el_datasets, "src",
                                     len(self.class_to_id_src), el_box_criterion, batch_size)
            tar_loss = self._el_loss(self.target_el_datasets, "tar",
                                     len(self.class_to_id_tar), el_box_criterion, batch_size)
            el_box_loss = src_loss + tar_loss

            # Alignment: every shared label pair should be classified as "aligned" (1).
            if len(self.src_annot_idxs) > 0:
                src_embs = self.module.elbox_src.class_embed(self.src_annot_idxs)
                tar_embs = self.module.elbox_tar.class_embed(self.tar_annot_idxs)
                embs = th.cat([src_embs, tar_embs], dim = 1)
                align_out = self.module(embs, align = True)
                align_loss = alignment_criterion(align_out, th.ones_like(align_out))
            else:
                # No shared labels: nothing to align (a mean BCE over an empty batch is NaN).
                align_loss = th.zeros((), device = self.device)

            # Out-of-place sum so `el_box_loss` keeps reporting only the ELBox part.
            loss = el_box_loss + align_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (epoch+1) % checkpoint == 0:
                loss_value = loss.item()
                if loss_value < best_loss:
                    best_loss = loss_value
                    print("Saving model..")
                    th.save(self.module.state_dict(), self.model_filepath)
                print(f'Epoch {epoch}: ElBox loss: {el_box_loss.item()} Align loss: {align_loss.item()}')

        
        

In [None]:
# Loads both ontologies, adds the label axioms and builds the GCI datasets (ELBoxModel.load_data).
model = ELBoxModel(source_owl, target_owl)

In [None]:
# Runs `model.epochs` training epochs; periodically saves the module's state dict to
# `model.model_filepath`.
model.train()

In [None]:
# Dicts mapping class IRI -> embedding vector for each ontology (synthetic label classes included).
src_embs, tar_embs = model.get_embeddings()

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Only real ontology classes are alignment candidates. The embedding dicts also hold the
# synthetic label classes added during training, so filter them out first.
# Sets make the membership tests O(1).
source_ont_classes = set(PathDataset(source_owl).classes)
target_ont_classes = set(PathDataset(target_owl).classes)

src_classes = [c for c in src_embs if c in source_ont_classes]
tar_classes = [c for c in tar_embs if c in target_ont_classes]

# Row/column index of each candidate class in `scores`.
src_to_id = {c: i for i, c in enumerate(src_classes)}
tar_to_id = {c: i for i, c in enumerate(tar_classes)}

# Similarity only between candidate classes. Each entry depends on its own pair alone,
# so dropping the label rows does not change any score.
scores = cosine_similarity(
    np.array([src_embs[c] for c in src_classes]),
    np.array([tar_embs[c] for c in tar_classes]),
)  # in range [-1, 1]
scores = (scores + 1)/2  # in range [0,1]

with open("elbox_scores.tsv", "w") as f:
    f.write("First_Ontology_Class\tSecond_Ontology_Class\tScore\tRelation\n")
    for src_idx, src_cls in enumerate(tqdm(src_classes, total = len(src_classes))):
        for tar_idx, tar_cls in enumerate(tar_classes):
            f.write(f"{src_cls}\t{tar_cls}\t{scores[src_idx, tar_idx]}\t=\n")

In [None]:
from evaluate import evaluate
# Local evaluation helper; presumably returns (average precision, AUC) for the scores
# file written above — TODO confirm in evaluate.py.
avg_prec, auc = evaluate("elbox_scores.tsv")

In [None]:
# Report the alignment metrics returned by `evaluate`.
print(avg_prec, auc)