## Preparing data and indexes for the OMEX BioModels experiment
- We use a file containing BioModels annotations in RDF format
- The annotations follow the OMEX standard
- The file is provided by the TR&D2 group

## CREATING EMBEDDINGS FOR ENTITIES, PREDICATES, AND ONTOLOGY CLASSES

#### Required python packages

In [None]:
import rdflib
import io
import json
from sentence_transformers import util
import torch
from tqdm import tqdm
import copy
import urllib.request   

#### Loading BioModels OMEX RDF file into RDFLib graph

In [None]:
filename = 'casbert_resources/AllBioModelsOMEXRDF_latest.rdf'
with io.open(filename, 'r', encoding='"ISO-8859-1"') as f:
    text = f.read()

In [None]:
g = rdflib.Graph()
g.parse(data=text, format='xml')

#### Extract triples of entity, predicate, and ontology classes from RDFLib graph

In [None]:
def getPathToObjs(s, g):
    """Walk the graph from subject `s` down to every leaf object.

    Starting from the (predicate, object) pairs of `s`, outgoing edges are
    followed level by level. A node without outgoing edges is a leaf and is
    recorded with the list of predicates traversed to reach it.

    Returns:
        list of ``{'p': [predicate, ...], 'o': leaf}`` dicts.

    NOTE(review): there is no cycle detection; a cycle reachable from `s`
    keeps the loop running forever.
    """
    leaves = []
    frontier = {}
    for pred, obj in g.predicate_objects(subject=s):
        frontier[obj] = [pred]
    while frontier:
        # Sweep a snapshot of the keys so the frontier can grow meanwhile.
        for node in list(frontier):
            outgoing = list(g.predicate_objects(subject=node))
            if not outgoing:
                leaves.append({'p': frontier[node], 'o': node})
            else:
                for pred, obj in outgoing:
                    frontier[obj] = frontier[node] + [pred]
            del frontier[node]
    return leaves

In [None]:
# get entities, queries, ontology classes

def _local_name(term):
    """Return the part of an RDF term's N3 form after the last '/' and '#'.

    E.g. ``<http://purl.org/dc/terms/description>`` -> ``description``.
    """
    return term.n3().rsplit('/', 1)[-1][:-1].rsplit('#', 1)[-1]

# initialisation:
entities = {}    # subject -> {'id', 'subject', 'path': {classId: [paths]}, 'object': [classIds], 'desc'}
classes = {}     # classId -> {'num': occurrences, 'onto': ontology name, 'text': []}
predicates = {}  # predicate local name -> '' (replaced by an embedding later)

# Collect every predicate used anywhere in the graph.
for pred in g.predicates():
    predicates[_local_name(pred)] = ''

# Root subjects (never used as an object) are the annotated entities. Each one is
# visited exactly once. Iterating over raw triples would re-process a subject once
# per triple it heads, inflating classes[...]['num'] and reassigning entity ids.
# NOTE(review): 'num' now counts each annotation once; the `< 100` ontology
# filter applied later may need re-checking against these counts.
rootSubjects = [s for s in dict.fromkeys(g.subjects())
                if next(iter(g.subject_predicates(object=s)), None) is None]

for s in tqdm(rootSubjects):
    try:
        subjectN3 = s.n3()
        entity = {'id': len(entities), 'subject': subjectN3[subjectN3.rfind('/')+1:-1],
                  'path': {}, 'object': [], 'desc': ''}
        for predicates_obj in getPathToObjs(s, g):
            obj = predicates_obj['o']
            # keep only ontology-class IRIs (skip DOIs and BioModels links)
            if not (obj.startswith('http') and 'doi' not in obj and 'biomodels.db' not in obj):
                continue
            ontoType, classId = obj.n3()[:-1].split('/')[-2:]
            if classId not in classes:
                classes[classId] = {'num': 0, 'onto': ontoType, 'text': []}
            classes[classId]['num'] += 1
            entity['object'] += [classId]

            # predicate path to this class, without the generic
            # 'is' / 'hasPhysicalDefinition' links
            path = [_local_name(pred) for pred in predicates_obj['p']
                    if not pred.endswith('/is') and not pred.endswith('/hasPhysicalDefinition')]
            entity['path'].setdefault(classId, []).append(path)

        entities[entity['subject']] = entity
    except Exception:
        # best-effort: subjects whose annotations cannot be parsed are skipped
        continue

#### Grouping ontology classes by ontology type

In [None]:
# get ontologies:
ontologies = {}
statOnto = {}
for classId, data in classes.items():
    if data['onto'] not in ontologies: 
        ontologies[data['onto']] = []
        statOnto[data['onto']] = 0
    ontologies[data['onto']] += [classId]
    statOnto[data['onto']] += data['num']

In [None]:
# normalised data. merge same ontology such as obo.pw and pw, obo.go and go
for onto in list(statOnto.keys()):
    if 'obo.' in onto:
        ontoName = onto[onto.find('.')+1:]
        if ontoName in ontologies:
            ontologies[ontoName] += ontologies[onto]
            statOnto[ontoName] += statOnto[onto]
            del ontologies[onto]
            del statOnto[onto]
    if '.ref' in onto:
        ontoName = onto[:onto.find('.')]
        ontologies[ontoName] += ontologies[onto]
        statOnto[ontoName] += statOnto[onto]
        del ontologies[onto]
        del statOnto[onto]
    try:
        if statOnto[onto] < 100:
            del ontologies[onto]
            del statOnto[onto]
    except:
        pass

#### Extracting ontology class terms
We extract ontology class terms from: 
Uniprot, Taxonomy, Kegg.pathway, Reactome, CHEBI, GO, FMA, PR, 
CL, EFO, OBI, PW, MAMO, BTO, NCIT, OPB

In [None]:
# get uniprot data
## uniprot

for uni in tqdm(ontologies['uniprot']):
    classId = uni.split('/')[-1]
    try:
        contents = urllib.request.urlopen("https://www.uniprot.org/uniprot/"+classId+".txt").read()
        contents = contents.decode('utf-8').split('\n')
        names = []
        for content in contents:
            if content.startswith('OS '):
                break
            if content.startswith('DE '):
                cts = ''
                if 'Full=' in content:
                    cts = content[content.index('Full=')+5:]
                if 'Short=' in content:
                    cts = content[content.index('Short=')+6:]
                if cts != '':
                    names += [cts[:cts.index('{')] if '{' in cts else cts]
        classes[classId]['text'] = names

#         break
    except:
        pass
    

In [None]:
# get taxonomy data
## taxonomy
for uni in tqdm(ontologies['taxonomy']):
    classId = uni.split('/')[-1]
    try:
        contents = urllib.request.urlopen("https://www.uniprot.org/taxonomy/"+classId+".rdf").read()
        contents = contents.decode('utf-8').split('\n')
        names = []
        for content in contents:
            if content.startswith('<commonName>') or content.startswith('<scientificName>') or content.startswith('<otherName>') or content.startswith('<synonym>'):
                names += [content[content.find('>')+1:content.rfind('<')]]
        classes[classId]['text'] = list(set(names))
    except:
        pass


In [None]:
# get kegg.pathway data
## kegg.pathway
for uni in tqdm(ontologies['kegg.pathway']):
    classId = uni.split('/')[-1]
    classes[classId]['text'] = []
    try:
        contents = urllib.request.urlopen("http://rest.kegg.jp/get/"+classId).read()
        contents = contents.decode('utf-8').split('\n')
        names = []
        for content in contents:
            if content.startswith('NAME'):
                names += [content.split('NAME')[-1].strip()]
                break
        classes[classId]['text'] = names        
    except:
        pass


In [None]:
# get reactome data
## reactome
for uni in tqdm(ontologies['reactome']):
    classId = uni.split('/')[-1]
    classes[classId]['text'] = []
    try:
        contents = urllib.request.urlopen("https://reactome.org/ContentService/data/query/"+classId).read()
        contents = json.loads(contents)
        names = contents['name']
        names += [contents['speciesName']] if 'speciesName' in contents else []
        classes[classId]['text'] = names        
    except:
        pass


In [None]:
def onto_bio_portal(linkTemplate, ontoName, lower=False, replace=('',''), preff='', limit=5):
    """Fill ``classes[classId]['text']`` with labels from the BioPortal REST API.

    Each class id of ontology `ontoName` (a key of ``ontologies``) is
    transformed in three steps: optionally lower-cased, then ``replace[0]`` is
    replaced by ``replace[1]``, then prefixed with `preff`. The result is
    substituted for ``$classId$`` in `linkTemplate`. The response's
    ``prefLabel`` followed by its ``synonym`` entries becomes the class text.

    Args:
        linkTemplate: BioPortal class URL containing the ``$classId$`` placeholder.
        ontoName: key into ``ontologies``.
        lower: lower-case the id before substitution.
        replace: ``(old, new)`` pair applied to the id.
        preff: prefix prepended to the id.
        limit: only the first `limit` classes are fetched; ``None`` fetches all.
            NOTE(review): the default of 5 keeps the original hard-coded
            ``[:5]``, which looks like a debugging leftover; confirm.

    Failures print the class id and the error; that class's text is left unchanged.
    """
    for uni in tqdm(ontologies[ontoName][:limit]):
        try:
            classId = uni.lower() if lower else uni
            classId = preff + classId.replace(replace[0], replace[1])
            link = linkTemplate.replace('$classId$', classId)
            with urllib.request.urlopen(link) as response:
                contents = json.loads(response.read())
            # synonyms are optional in the response
            synonyms = contents.get('synonym', [])
            if isinstance(synonyms, str):
                synonyms = [synonyms]
            classes[uni]['text'] = [contents['prefLabel']] + list(synonyms)
        except Exception as err:
            print(uni, err)

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/CHEBI/classes/http%3A%2F%2Fpurl.obolibrary.org%2Fobo%2F$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'chebi',False, (':','_'))

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/GO/classes/http%3A%2F%2Fpurl.obolibrary.org%2Fobo%2F$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'go',False, (':','_'))

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/FMA/classes/http%3A%2F%2Fpurl.org%2Fsig%2Font%2Ffma%2F$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'fma', True, (':',''))

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/PR/classes/http%3A%2F%2Fpurl.obolibrary.org%2Fobo%2F$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'pr',False, (':','_'))

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/CL/classes/http%3A%2F%2Fpurl.obolibrary.org%2Fobo%2F$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'cl',False, (':','_'))

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/EFO/classes/http%3A%2F%2Fwww.ebi.ac.uk%2Fefo%2F$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'efo', preff='EFO_')

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/OBI/classes/http%3A%2F%2Fpurl.obolibrary.org%2Fobo%2F$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'obi',False, (':','_'))

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/PW/classes/http%3A%2F%2Fpurl.obolibrary.org%2Fobo%2F$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'pw',False, (':','_'))

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/MAMO/classes/http%3A%2F%2Fidentifiers.org%2Fmamo%2F$classId$?apikey=fc5d5241-1e8e-4b44-b401-310ca39573f6'
onto_bio_portal(linkTemplate, 'mamo',False, (':','_'))

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/BTO/classes/http%3A%2F%2Fpurl.obolibrary.org%2Fobo%2F$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'bto',False, (':','_'))

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/NCIT/classes/http%3A%2F%2Fncicb.nci.nih.gov%2Fxml%2Fowl%2FEVS%2FThesaurus.owl%23$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'ncit',False, ('NCIT:',''))

In [None]:
linkTemplate = 'https://data.bioontology.org/ontologies/OPB/classes/http%3A%2F%2Fbhi.washington.edu%2FOPB%23$classId$?apikey=8b5b7825-538d-40e0-9e9e-5ab9274a9aeb'
onto_bio_portal(linkTemplate, 'opb')

In [None]:
# Persist the extracted classes (with their text labels) and the entities.
# All resources of this notebook live under casbert_resources/.
with open('casbert_resources/omex_classes.json', 'w', encoding='utf-8') as fp:
    json.dump(classes, fp)

with open('casbert_resources/omex_entities.json', 'w', encoding='utf-8') as fp:
    json.dump(entities, fp)

#### CREATE CLASS EMBEDDING

In [None]:
BERTModel = 'multi-qa-MiniLM-L6-cos-v1'
from sentence_transformers import SentenceTransformer, models
_model = SentenceTransformer(BERTModel)

In [None]:
for classId in tqdm(list(classes)):
    if len(classes[classId]['text']) > 0:
        classes[classId]['embedding'] = torch.mean(_model.encode(classes[classId]['text'], convert_to_tensor=True), dim=0)
    else:
        del classes[classId]

In [None]:
torch.save(predicates, 'casbert_resources/omex_predicates.pt')

#### CREATE PREDICATE EMBEDDING

In [None]:
import re
import random
def camel_case_split(identifier):
    """Split a camelCase/PascalCase identifier into its words.

    A word ends before an uppercase letter that follows a lowercase one
    ("hasPart" -> ["has", "Part"]). It also ends before the last capital of an
    uppercase run that begins a new capitalised word
    ("HTTPServer" -> ["HTTP", "Server"]).
    """
    word_pattern = r'.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)'
    return [word.group(0) for word in re.finditer(word_pattern, identifier)]

for p in tqdm(predicates):
    predicates[p] = _model.encode(' '.join(camel_case_split(p)), convert_to_tensor=True)

In [None]:
torch.save(predicates, 'casbert_resources/omex_predicates.pt')

#### CREATE ENTITY EMBEDDING

In [None]:
# Build two embeddings per entity:
#   'class'           - mean embedding of the entity's (embedded) ontology classes
#   'class_predicate' - per class: class embedding + alpha * mean predicate-path
#                       embedding, then averaged over the entity's classes
entityEmbeddings = {'entityIds':[], 'class':[], 'class_predicate':[]}
alpha = 0.22  # weight of the predicate-path embedding relative to the class embedding
for entityId, entity in tqdm(entities.items()):
    # only classes that received an embedding can contribute
    objects = [classes[obj]['embedding'] for obj in entity['object'] if obj in classes]
    if not objects:
        # nothing to embed (torch.stack([]) would raise); skip the entity
        continue
    entityEmbeddings['class'] += [torch.mean(torch.stack(objects), dim=0)]
    entityEmbeddings['entityIds'] += [entityId]
    pathEmbeddings = []
    for obj, paths in entity['path'].items():
        if obj not in classes: continue
        objEmbedding = classes[obj]['embedding']
        try:
            # mean over this class's paths of (mean over the predicates on the path)
            trackEmbedding = torch.mean(torch.stack([torch.mean(torch.stack([predicates[p] for p in path]), dim=0) for path in paths]), dim=0)
            pathEmbeddings += [objEmbedding + alpha * trackEmbedding]
        except (KeyError, RuntimeError, TypeError):
            # unknown/unembedded predicate or an empty path: use the class embedding alone
            pathEmbeddings += [objEmbedding]
    entityEmbeddings['class_predicate'] += [torch.mean(torch.stack(pathEmbeddings), dim=0)]
        

In [None]:
entityEmbeddings['class'] = torch.stack(entityEmbeddings['class'])
entityEmbeddings['class_predicate'] = torch.stack(entityEmbeddings['class_predicate'])

In [None]:
torch.save(entityEmbeddings, 'casbert_resources/omex_entities.pt')

### Performance measures: Mean Average Precision (MAP) and Mean Reciprocal Rank (MRR)

In [None]:
def averagePrecision(prediction):
    """Average precision of one ranked result list.

    Args:
        prediction: relevance flags in rank order (1 = relevant, 0 = not).

    Returns:
        The mean of precision@k over the ranks k that hold a relevant item,
        or 0 when no relevant item (1) is present.
    """
    if 1 not in prediction:
        return 0
    tot = 0
    running = 0  # running sum of flags = sum(prediction[:idx+1])
    for idx, p in enumerate(prediction):
        running += p
        if p > 0:
            tot += running/(idx+1)
    return tot/sum(prediction)

def meanAP(predictions):
    """Mean average precision over several ranked lists.

    Returns ``{'score': mean AP (0.0 for no lists), 'stat': {list index: AP}}``.
    """
    stat = {idx: averagePrecision(prediction) for idx, prediction in enumerate(predictions)}
    score = sum(stat.values())/len(stat) if stat else 0.0
    return {'score':score, 'stat':stat}

def meanRR(predictions):
    """Mean reciprocal rank of the first relevant item over several ranked lists.

    Returns ``{'score': MRR (0.0 for no lists), 'stat': {list index: RR}}``;
    a list without a relevant item contributes 0.
    """
    stat = {idx: (1/(prediction.index(1)+1) if 1 in prediction else 0)
            for idx, prediction in enumerate(predictions)}
    score = sum(stat.values())/len(stat) if stat else 0.0
    return {'score':score, 'stat':stat}

def getMAP(queries, searchFunction, indexType=None, pathType=None, topK=10, minSim=0.5):
    """Evaluate `searchFunction` on `queries`, returning MAP and MRR.

    Args:
        queries: ``{query text: {'vars': [relevant entity ids], ...}}``.
        searchFunction: called as ``searchFunction(query=..., topK=..., indexType=...,
            pathType=..., minSim=...)``; iterating its result yields entity ids in rank order.

    Returns:
        ``{'MAP': meanAP(...), 'MRR': meanRR(...)}``; each ``'stat'`` is keyed by
        the query's position in `queries`.
    """
    predictions = []
    for query, facts in tqdm(queries.items()):
        results = searchFunction(query=query, topK=topK, indexType=indexType, pathType=pathType, minSim=minSim)
        relevant = set(facts['vars'])
        predictions.append([1 if varId in relevant else 0 for varId in results])
    return {'MAP':meanAP(predictions),'MRR':meanRR(predictions)}

### CREATING THE QUERY–ENTITY SET FOR THE EXPERIMENT

In [None]:
# get queries
# get entities, queries, ontology classes

# initialisation:
queries = {}   # {'query1':{'vars':entityIds, 'score':score}, 'query2':{'vars':entityIds, 'score':score}}
qRes = []

def _clean_query_text(literal):
    """Turn an RDF literal into a one-line query string.

    Strips the quote characters around its N3 form and collapses whitespace.
    If the text contains markup, only the part between the first '>' and the
    last '<' is kept.
    """
    text = " ".join(literal.n3().strip('"').split())
    match = re.search(r'(?<=>).*(?=<)', text)
    return match.group(0) if match else text

# row of each embedded entity in entityEmbeddings (dict lookup instead of list.index)
entityPos = {entId: pos for pos, entId in enumerate(entityEmbeddings['entityIds'])}
queryEmbeddings = {}  # query text -> embedding, so each query is encoded once

# Root subjects (never used as an object) are the annotated entities; each is
# visited once rather than once per triple it heads.
rootSubjects = [s for s in dict.fromkeys(g.subjects())
                if next(iter(g.subject_predicates(object=s)), None) is None]

for s in tqdm(rootSubjects):
    entId = s.n3()[s.n3().rfind('/')+1:-1]
    for predicates_obj in getPathToObjs(s, g):
        # candidate queries: short values reached through a '...description' predicate
        if not (predicates_obj['p'][-1].endswith('description') and len(predicates_obj['o']) < 200):
            continue
        q = _clean_query_text(predicates_obj['o'])
        if q not in queries:
            queries[q] = {'vars':[], 'score':0}
        if entId in entityPos:
            if q not in queryEmbeddings:
                queryEmbeddings[q] = _model.encode(q, convert_to_tensor=True)
            # .item(): store a plain float so the scores stay JSON-serialisable
            score = util.pytorch_cos_sim(queryEmbeddings[q], entityEmbeddings['class'][entityPos[entId]]).item()
            queries[q]['vars'] += [entId]
            if score > queries[q]['score']:
                queries[q]['score'] = score
            qRes += [entId]
                


In [None]:
def getPathToObjs2(s, g):
    """Walk the graph from subject `s` to every leaf, keeping the visited edges.

    Outgoing edges are followed level by level starting at `s`. Each leaf
    (a node with no outgoing edges) is recorded with the list of
    ``(source node, predicate)`` pairs traversed to reach it.

    Returns:
        list of ``{'p': [(node, predicate), ...], 'o': leaf}`` dicts.

    NOTE(review): there is no cycle detection; a cycle reachable from `s`
    keeps the loop running forever.
    """
    leaves = []
    frontier = {}
    for pred, obj in g.predicate_objects(subject=s):
        frontier[obj] = [(s, pred)]
    while frontier:
        # Sweep a snapshot of the keys so the frontier can grow meanwhile.
        for node in list(frontier):
            outgoing = list(g.predicate_objects(subject=node))
            if not outgoing:
                leaves.append({'p': frontier[node], 'o': node})
            else:
                for pred, obj in outgoing:
                    frontier[obj] = frontier[node] + [(node, pred)]
            del frontier[node]
    return leaves

In [None]:
# get queries
# get entities, queries, ontology classes

# initialisation:
queries = {}   # {'query1':{'vars':entityIds, 'score':score}, 'query2':{'vars':entityIds, 'score':score}}
qRes = []

#get all entities / subjects, objects, tracks:
for s, p, o in tqdm(g):
    if len(list(g.subject_predicates(object=s))) == 0:
        predicates_objs = getPathToObjs2(s, g)
        for predicates_obj in predicates_objs:
            if not predicates_obj['o'].startswith('http') and len(predicates_obj['o'])<200:
                if predicates_obj['o'] in predicates_obj['p'][-1][0]:
                    continue
                
                q = predicates_obj['o'].n3()
                while q[0] == '"': q = q[1:]
                while q[-1] == '"': q = q[:-1]
                q = " ".join(q.split())    
                try:
                    q = re.search(r'(?<=>).*(?=<)' , q).group(0)
                except:
                    pass
                if q not in queries: 
                    queries[q] = {'vars':[], 'score':0}
                entId = s.n3()[s.n3().rfind('/')+1:-1]
                if entId in entityEmbeddings['entityIds']:
                    qEmbedding = _model.encode(q, convert_to_tensor=True)
                    pos = entityEmbeddings['entityIds'].index(entId)
                    score = util.pytorch_cos_sim(qEmbedding, entityEmbeddings['class'][pos]).item()
                    queries[q]['vars'] += [entId]
                    if score > queries[q]['score']:
                        queries[q]['score'] = score
                    qRes += [s.n3()[s.n3().rfind('/')+1:-1]]
                
len(queries)

In [None]:
for q in queries:
    queries[q]['vars'] = list(set(queries[q]['vars']))
len(queries)

#### Filter queries

In [None]:
## Filter: drop queries whose best entity similarity ('score') is still 0,
## i.e. no embedded entity gave the query a positive cosine similarity.

count = 0
for q in tqdm(list(queries)):
    if queries[q]['score'] != 0:
        continue
    del queries[q]
    count += 1

print('deleted queries: ', count, '; number of candidate queries: ', len(queries))

In [None]:
## - remove queries containing number only
for q in list(queries.keys()):
    test = q.replace(".", "", 1)
    if test.isdigit():
        del queries[q]

In [None]:
## - remove queries containing variable name pattern only

for q in list(queries.keys()):
                    
    status = any([q.replace("v", "", 1).isdigit(), # v1, v2, v3, ...
                 q.replace("reaction_", "", 1).isdigit(), # reaction_1, reaction_2, ...
                 q.replace("r", "", 1).isdigit(), # r1, r2, r3, ...
                 q.replace("R", "", 1).isdigit(), # R1, R2, R3, ...
                 q.replace("rel", "", 1).isdigit(), # rel1, rel2, rel3, ...
                 q.replace("re", "", 1).isdigit(), # re1, re2, re3, ...
                 q.replace("R_", "", 1).isdigit(), # R_1, R_2, R_3, ...
                 ])
    
    if status:
        del queries[q]
        
    if q.lower().startswith('compartment'):
        del queries[q]
    

In [None]:
## select queries with similarity >=0.4
for q in list(queries.keys()):
    if queries[q]['score'] < 0.5:
        del queries[q]
        
len(queries)

In [None]:
## removing queries with too many entities since it cannot represent MAP and MRR well
for q in list(queries.keys()):
    if len(queries[q]['vars']) > 50:
        del queries[q]
        
len(queries)

In [None]:
## removing queries containing only one ontology class concept
for q in tqdm(list(queries.keys())):
    if len(nlp(q).ents) == 1:
        del queries[q]
        
len(queries)

In [None]:
## For every query, find the ontology-class concepts it mentions. Then drop the
## query's entities whose number of classes is below 0.6x the number of those
## concepts. Queries left without entities are deleted.
## NOTE(review): `nlp` and `getClasses` are not defined in this part of the
## notebook; getClasses is assumed to return {classId: (..., score, ...)}
## ordered by decreasing score (hence the `break`); confirm.
for q in tqdm(list(queries)):
    info = queries[q]

    # concept span -> ids of classes matching it with similarity >= 0.6
    concepts = {}
    for span in nlp(q).ents:
        for classId, classDef in getClasses(span.text, topK=10).items():
            if classDef[1] < 0.6:
                break
            concepts.setdefault(span, []).append(classId)

    if concepts:
        info['vars'][:] = [entId for entId in info['vars']
                           if len(entities[entId]['object']) / len(concepts) >= 0.6]

    if not info['vars']:
        del queries[q]

len(queries)

In [None]:
## Remove queries for which the search ranks none of their entities
## (average precision of 0).
## NOTE(review): `entitySearch` is not defined in this part of the notebook; confirm.

predictions = {}
for query in tqdm(list(queries)):
    ranked = entitySearch(query=query)
    relevant = queries[query]['vars']
    predictions[query] = averagePrecision([int(varId in relevant) for varId in ranked])
    if not predictions[query]:
        del queries[query]

len(queries)

#### Enrich test data
In test_data.csv, each query is associated with a limited number of entities, but the query may also be relevant to other entities. Therefore, we enrich the test data with all such entities. The enriched data is then called 'silver data'.

##### Create a dictionary mapping ontology classes to variables (entities)
It is used to enrich the test data.

In [None]:
class2Vars = {}

for varId in entityEmbedding['entityIds']:
    value = entities[varId]
    for classId in value['object']:
        if classId not in class2Vars: class2Vars[classId] = []
        class2Vars[classId] += [varId]             

# enrich with similar entities
def enrichQueryWithOtherEntities(dictQueries):
    for q, v in tqdm(dictQueries.items()):
        newVarIds = []
        for varId in v['vars']:
            newVarIds += [varId]
            lstVars = [class2Vars[c] for c in entities[varId]['object']]
            otherVarIds = list(set.intersection(*map(set,lstVars)))
            newVarIds += otherVarIds

        dictQueries[q]['vars'] = list(set(dictQueries[q]['vars'] + newVarIds))

In [None]:
enrichQueryWithOtherEntities(queries)

In [None]:
# Enrich each query with entities matching its pattern of concepts. For every
# concept mentioned in the query, collect the entities annotated with any of
# its matching classes; entities found for *all* concepts are added.
import itertools
for q in tqdm(list(queries)):

    # concept span -> classes matching it with similarity >= 0.6
    # (a span whose classes all score below 0.6 still gets an empty entry)
    concepts = {}
    for span in nlp(q).ents:
        matches = getClasses(span.text, topK=10)
        if matches:
            classIds = concepts.setdefault(span, [])
            for classId, classDef in matches.items():
                if classDef[1] < 0.6:
                    break
                classIds.append(classId)

    # per concept: union of the entities of its classes
    candEntities = []
    for classIds in concepts.values():
        eIds = [class2Vars[classId] for classId in classIds if classId in class2Vars]
        candEntities.append(list(set(itertools.chain.from_iterable(eIds))))
    if candEntities:
        newEntities = list(set.intersection(*map(set, candEntities)))
        queries[q]['vars'] = list(set(queries[q]['vars'] + newEntities))

len(queries)

In [None]:
queryTest = {'noPredicate':queries, 'withPredicate':{}, 'combine':{}}

#### Generate query test with predicate

In [None]:
# Build predicate-augmented queries. For each entity of a plain query, some of
# the query's concepts get a (randomly chosen) predicate from that entity's
# paths inserted in front of them.
queryTest['withPredicate'] = {}
for q, variables in tqdm(queryTest['noPredicate'].items()):
    doc = nlp(q)
    # concept span -> candidate class ids (top 10 by similarity)
    entToClasses = {ent: list(getClasses(ent.text, topK=10).keys()) for ent in doc.ents}

    for entityId in variables['vars']:
        _classes = entities[entityId]['path']
        clsPredicate = {}  # classId -> predicate inserted into the query for it
        query = q
        # iterate spans together with their classes so the predicate is
        # inserted in front of the span it belongs to
        for ent, entClasses in entToClasses.items():
            for classId in entClasses:
                if classId in _classes and classId not in clsPredicate and bool(random.getrandbits(1)):
                    entPredicates = random.choice(_classes[classId])
                    if len(entPredicates) > 0:
                        # we select randomly a predicate from position -3 to the end;
                        # those predicates are more likely to be related to the ontology class
                        entPredicate = random.choice(entPredicates[-3:])
                        clsPredicate[classId] = entPredicate
                        textPredicate = ' '.join(camel_case_split(entPredicate)).lower()
                        textPredicate = textPredicate.replace('is ', '')
                        query = query.replace(ent.text, '%s '%textPredicate+ent.text)
                        break

        if query == q: continue
        if query not in queryTest['withPredicate']:
            queryTest['withPredicate'][query] = {'vars':[entityId], 'score':variables['score']}
        else:
            queryTest['withPredicate'][query]['vars'] += [entityId]

        # other entities of the plain query are also relevant if, for every
        # (class, predicate) pair used above, they have that class and some
        # path to it containing the predicate
        otherVarIds = [varId for varId in variables['vars']
                       if all(classId in entities[varId]['path']
                              and any(predicate in path for path in entities[varId]['path'][classId])
                              for classId, predicate in clsPredicate.items())]

        queryTest['withPredicate'][query]['vars'] += otherVarIds
        queryTest['withPredicate'][query]['vars'] = list(dict.fromkeys(queryTest['withPredicate'][query]['vars']))

#### Prepare for IR with classifier
We experiment with using predicates automatically during retrieval: a classifier decides, per query, whether to use them.
- We prepare the classifier dataset here
- Then, run the classifier training in another script to get the model
    - Follow this link: <a href="Train Query Classifier - OMEX.ipynb">Train Classifier for OMEX BioModels</a>
- Finally, using the generated model in this experiment

In [None]:
def entitySearchCombine(query, topK=20, minSim=0.5, indexType='class', pathType=None):
    """
    In this approach: combining entitySearch and entitySearchClass
    1. Query is chunked into entities and classified into biomedical phrases and predicate
    2. Connect predicate to biomedical phrase
    3. If indexType is 'class', generate vector for biomedical phrase
    4. If indexType is 'class_predicate', generate vector for biomedical phrase and predicate pair
    5. Combine vectors becoming one vector using mean function, named it as local vector
    6. Get vector of query named it as global vector
    7. Combine local vector and global vector
    8. Get similar entities using cosine similarity
    9. Return topK result in descending

    Args:
        query: free-text query.
        topK: number of entities to return (capped at the number of indexed entities).
        minSim: minimum predicate similarity for a span to count as a predicate
            (only used with indexType 'class_predicate').
        indexType: 'class' or 'class_predicate'; selects the entity embedding searched.
        pathType: unused; accepted for compatibility with getMAP's call.

    Returns:
        {entityId: [rank, cosine score, entity dict]} in descending score order.

    Raises:
        ValueError: if indexType is not 'class' or 'class_predicate'.

    NOTE(review): `nlp`, `getClasses` and `getPredicates` are not defined in this
    part of the notebook; they are assumed to return {id: (..., score, ...)}
    with the best match first.
    """
    if indexType not in ('class', 'class_predicate'):
        raise ValueError("indexType must be 'class' or 'class_predicate', got %r" % (indexType,))

    ### Get local query embedding

    doc = nlp(query)
    alpha = 1  # weight of a predicate embedding when paired with its class

    ontoClasses = []           # spans treated as ontology-class phrases
    predicatePhrases = []      # spans treated as predicates
    validClassPredicates = {}  # index into ontoClasses -> predicate spans attached to it
    offset2Class = {}          # token index -> class span containing it
    cScores = []               # best class similarity of each class span
    for ent in doc.ents:
        predicateScores = getPredicates(ent.text, topK=1)
        pScore = list(predicateScores.values())[0][1]
        classScores = getClasses(ent.text, topK=1)
        cScore = list(classScores.values())[0][1]
        if cScore >= pScore:
            cScores += [cScore]
            ontoClasses += [ent]
            for token in ent:
                offset2Class[token.i] = ent
        elif indexType == 'class_predicate' and pScore >= minSim:
            predicatePhrases += [ent]

    if len(ontoClasses) == 0:
        # no class phrase found: treat the whole query as one class phrase
        ontoClasses = [doc]
        cScores = [list(getClasses(query, topK=1).values())[0][1]]

    # check the entities describe by predicate (usually predicate's child)
    for ent in predicatePhrases:
        for token in ent:
            for child in token.children:
                if child.i in offset2Class:
                    idx = ontoClasses.index(offset2Class[child.i])
                    validClassPredicates.setdefault(idx, []).append(ent)
                    break # only consider the closest class. remove break if considering all classes

    if indexType == 'class':
        classEmbeddings = [_model.encode(ent.text, convert_to_tensor=True) for ent in ontoClasses]
    else:  # 'class_predicate'
        classEmbeddings = []
        for i, classSpan in enumerate(ontoClasses):
            classEmbedding = _model.encode(classSpan.text, convert_to_tensor=True)
            if i in validClassPredicates:
                predicateEmbeddings = [_model.encode(ent.text, convert_to_tensor=True) for ent in validClassPredicates[i]]
                pathEmbedding = alpha * torch.mean(torch.stack(predicateEmbeddings, dim=0), dim=0)
                classEmbedding = torch.mean(torch.stack([classEmbedding, pathEmbedding], dim=0), dim=0)
            classEmbeddings += [classEmbedding]
    textEmbedding = torch.mean(torch.stack(classEmbeddings, dim=0), dim=0)

    ### Get global query embedding
    textEmbeddingGlobal = _model.encode(query, convert_to_tensor=True)
    ### Combine global and local embedding (local scaled by the mean class score)
    factor = sum(cScores)/len(cScores)
    textEmbedding = torch.mean(torch.stack([textEmbeddingGlobal, factor * textEmbedding], dim=0), dim=0)

    # We use cosine-similarity and torch.topk to find the highest top_k scores
    varIds = entityEmbeddings['entityIds']
    cosScores = util.pytorch_cos_sim(textEmbedding, entityEmbeddings[indexType])[0]
    # torch.topk raises if k exceeds the number of indexed entities
    topResults = torch.topk(cosScores, k=min(topK, len(varIds)))
    results = {}
    for rank, (score, idx) in enumerate(zip(topResults[0], topResults[1])):
        varId = varIds[int(idx)]
        results[varId] = [rank, score.item(), entities[varId]]
    return results

In [None]:
queryTestCombine = {**queryTest['noPredicate'], **queryTest['withPredicate']}

resultCombineClass = getMAP(queryTestCombine, entitySearchCombine, topK=10, indexType='class')
resultCombineClassPredicate = getMAP(queryTestCombine, entitySearchCombine, topK=10, indexType='class_predicate')

In [None]:
# tag the query -1=neutral, 0=class, 1=class_predicate
result = list(zip(resultCombineClass['MAP']['stat'].values(), resultCombineClassPredicate['MAP']['stat'].values()))
# r = [max(d) for d in r]
c1, c2, c3 = 0, 0, 0
for idx, (query, value) in enumerate(queryTestCombine.items()):
    if result[idx][0] == result[idx][1]:
        value['indexType'] = -1
        c1+=1
    elif result[idx][0] > result[idx][1]:
        value['indexType'] = 0
        c2+=1
    elif result[idx][0] < result[idx][1]:
        value['indexType'] = 1
        c3+=1

print(c1,c2,c3)

In [None]:
with open('casbert_resources/omex_classifier_data.json', 'w') as fp:
    json.dump(queryTestCombine, fp)

##### Divide into train, validation, and test data (proportion 3:3:4)

In [None]:
import pandas as pd

data = {'queries':[], 'labels':[]}
for q, v in queryTestCombine.items():
    if v['indexType'] != -1:
        data['queries'] += [q]
        data['labels'] += [v['indexType']]
        
df = pd.DataFrame(data)

In [None]:
import numpy as np
df_train, df_eval, df_test = np.split(df.sample(frac=1, random_state=0), [int(.3*len(df)), int(.6*len(df))])
print(df_train.shape, df_eval.shape, df_test.shape)

In [None]:
queryTest['combine'] = {}
for q, v in queryTestCombine.items():
    if v['indexType'] == -1 or q in df_test['queries'].to_list():
        queryTest['combine'][q] = v

In [None]:
## save queryTest to file
with open('casbert_resources/omex_queryTest.json', 'w') as fp:
    json.dump(queryTest, fp)