# Preparing data and indexes for the experiments on the PMR repository
    We use a file of annotated entities extracted from the PMR repository.

In [None]:
# required modules
import pandas as pd
from tqdm import tqdm
import json
import torch
from sentence_transformers import util

## Create dictionaries of ontology classes, predicates, and entities

In [None]:
# Load the variable annotations that were already extracted from the
# repositories (the original note also mentions components and CellML models;
# only the variable file is read in this cell).
# The cells below iterate `variables['data']` as a mapping of id -> record.

with open('casbert_resources/pmr_list_of_variable.json', 'r') as fp:
    variables = json.load(fp)

#### Extract predicates

In [None]:
# Collect the distinct predicates found in the RDF triples of the repository.
# The triples also carry ontology classes, references and textual information,
# but only the element at index 1 of each triple (the predicate) is kept here.

leave = {'file':[], 'http':[], 'other':[]}  # not filled in by this cell

# set of predicates
predicates = {
    triple[1]
    for data in variables['data'].values()
    if 'rdf' in data
    for triple in data['rdf']
}

#### Extract relationships between ontology classes and entities

In [None]:
# Map every ontology class id to the ids of the variables annotated with it.
class2Vars = {}

for varId, value in variables['data'].items():
    if 'rdfLeaves' not in value:
        continue
    for leaf in value['rdfLeaves']:
        leaf = str(leaf).strip()
        # reduce the leaf to a compact class id (text after the last '/', '#' or ':')
        if leaf.startswith('http'):
            localName = leaf.rsplit('/', 1)[-1].rsplit('#', 1)[-1]
            classId = localName.replace('_', ':')
        elif leaf.startswith('urn:miriam'):
            classId = leaf.rsplit(':', 1)[-1].replace('%3A', ':')
        else:
            # neither a URL nor a MIRIAM URN: not handled as an ontology class
            continue
        class2Vars.setdefault(classId, []).append(varId)
                

#### Preparing embedder

In [None]:
# Initialisation of the sentence embedder shared by all cells below

# BERT model used
modelName = 'sentence-transformers/multi-qa-MiniLM-L6-cos-v1'
seqLength = 128  # handed to the transformer module as max_seq_length

from sentence_transformers import SentenceTransformer, models

# The encoder is a transformer module followed by a pooling module; the pooling
# module is created with its default settings and sized from the transformer.
word_embedding_model = models.Transformer(modelName, max_seq_length=seqLength)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
sbert = SentenceTransformer(modules=[word_embedding_model, pooling_model])

#### Load ontology terms from available files
Ontologies:
  - OPB
  - CL
  - FMA
  - GO
  - CHEBI
  - PR

In [None]:
# load ontology dictionaries
import gzip, pickle

# NOTE(review): unpickling can run arbitrary code; only load an ontoDf.gz that
# comes from a trusted source.
# The context manager closes the archive even when unpickling fails.
with gzip.GzipFile('casbert_resources/ontoDf.gz', 'rb') as file:
    ontologies = pickle.load(file)
ontologies.head(1)

#### Create dict of ontology class embeddings

In [None]:
# Ids of the ontology classes that get an embedding. The position of an id in
# this list is the row used to look it up in the tensors of ontoEmbedding.
classId_List = []
# Embeddings per feature; every entry starts as an empty list.
# Each key is listed once (a repeated 'name_synonym_def' key would be redundant).
ontoEmbedding = {'name':[], 'synonym':[], 'name_synonym':[], 'def':[], 'name_synonym_def':[]}
# Length of one embedding vector; used to build zero placeholders for missing text.
embedShape = 384

def change2embedding(txt):
    """
    Encode `txt` with the module-level `sbert` model and return the result as a tensor.
    """
    embedding = sbert.encode(txt, convert_to_tensor=True)
    return embedding

def from_synonym(syn):
    """
    Return a single embedding for a '|'-separated synonym string.

    Every non-empty entry is reduced to a label:
      - the text between the first and the last double quote, if the entry
        contains a double-quoted part;
      - otherwise the text between the first and the last single quote, if the
        entry contains a single-quoted part;
      - otherwise the entry unchanged.
    The labels are encoded in one batch and averaged.

    syn: synonym string; a non-string value (e.g. a missing cell) is accepted
    returns: mean embedding of the labels, or a zero vector of length
             `embedShape` when `syn` is not a string or yields no label
    Errors raised by the encoder itself are not swallowed.
    """
    if not isinstance(syn, str):
        return torch.zeros(embedShape)
    labels = []
    for txt in syn.split('|'):
        if len(txt) == 0:
            continue
        if txt.find('"') < txt.rfind('"'):
            labels.append(txt[txt.find('"')+1:txt.rfind('"')])
        elif txt.find("'") < txt.rfind("'"):
            # The test must be '<': with '==' an entry without any quote
            # (find == rfind == -1) lost its last character, and an entry with
            # a single apostrophe became an empty label.
            labels.append(txt[txt.find("'")+1:txt.rfind("'")])
        else:
            labels.append(txt)
    if not labels:
        return torch.zeros(embedShape)
    return torch.mean(change2embedding(labels), 0)

# Embed name, synonyms and definition of every annotated class that is present
# in the ontology table; classes missing from the table are skipped.
for classId in tqdm(class2Vars):
    if classId not in ontologies.index:
        continue
    ocls = ontologies.loc[classId]
    classId_List.append(classId)
    ontoEmbedding['name'].append(change2embedding(ocls['name']))
    ontoEmbedding['synonym'].append(from_synonym(ocls['synonym']))
    definition = ocls['def']
    # a definition that is not a string gets a zero placeholder vector
    if isinstance(definition, str):
        defEmbedding = change2embedding(definition)
    else:
        defEmbedding = torch.zeros(embedShape)
    ontoEmbedding['def'].append(defEmbedding)

In [None]:
# Replace each list of per-class vectors by one tensor (one row per class,
# rows in the order of classId_List).
for feature in ('name', 'synonym', 'def'):
    ontoEmbedding[feature] = torch.stack(ontoEmbedding[feature], dim=0)

In [None]:
# now combine some embeddings

# Each combined feature is the element-wise mean of its parts, where NaN
# elements are left out of both the sum and the count.
# NOTE(review): zero placeholder vectors are not NaN, so they still count
# towards the mean -- confirm that this is intended.
combinations = (
    ('name_synonym', ('name', 'synonym')),
    ('name_synonym_def', ('name', 'synonym', 'def')),
)
for combined, parts in combinations:
    tensor = torch.stack([ontoEmbedding[part] for part in parts], dim=0)
    validCount = (~torch.isnan(tensor)).count_nonzero(dim=0)
    ontoEmbedding[combined] = torch.div(torch.nansum(tensor, dim=0), validCount)

In [None]:
# OntoEmbedding function
from sentence_transformers import util
    
# returning text
def getClassText(classId, features=('name',)):
    """
    Return the requested text fields of an ontology class.

    classId: an id of a class such as 'CHEBI:29101'
    features: iterable of field names of `ontologies`: name, synonym, parent, def
    returns: dict of feature -> value; empty when classId is not in classId_List
    """
    if classId not in classId_List:
        return {}
    # fetch the row once instead of once per requested feature
    row = ontologies.loc[classId]
    return {feature: row[feature] for feature in features}

# returning embedding
def getClassEmbedding(classId, feature='name_synonym'):
    """
    Return the stored embedding of an ontology class.

    classId: an id of a class such as 'CHEBI:29101'; a value containing 'http'
             is first reduced to its compact id (text after the last '/' and
             '#', with '_' replaced by ':')
    feature: a key of ontoEmbedding: name, synonym, def, name_synonym, name_synonym_def
    returns: the embedding row of the class, or None when the class is unknown
    raises: KeyError for a known class when `feature` is not a key of ontoEmbedding
    """
    if 'http' in classId:
        classId = classId.rsplit('/')[-1].split('#')[-1].replace('_',':')
    # a single scan of the list gives both the membership test and the row
    try:
        position = classId_List.index(classId)
    except ValueError:
        return None
    return ontoEmbedding[feature][position]

# returning ontology classes
def getClasses(text, feature='name_synonym', topK = 20):
    """
    Return the ontology classes most similar to `text`.

    text: free text to match against the class embeddings
    feature: a key of ontoEmbedding: name, synonym, def, name_synonym, name_synonym_def
    topK: maximum number of classes to return; capped at the number of
          embedded classes
    returns: dict of classId -> (rank, cosine score, class name) in descending
             score order, or None when `feature` is not a key of ontoEmbedding
    """
    if feature not in ontoEmbedding:
        return None
    textEmbedding = sbert.encode(text, convert_to_tensor=True)
    # We use cosine-similarity and torch.topk to find the highest top_k scores
    cosScores = util.pytorch_cos_sim(textEmbedding, ontoEmbedding[feature])[0]
    # torch.topk raises when k is larger than the number of candidates
    topResults = torch.topk(cosScores, k=min(topK, cosScores.shape[0]))
    classes = {}
    for rank, (score, idx) in enumerate(zip(topResults[0], topResults[1])):
        classId = classId_List[idx.item()]
        classes[classId] = (rank, score.item(), ontologies.loc[classId]['name'])
    return classes

In [None]:
# Persist the class embeddings (all features) for later use.
# NOTE(review): classId_List, which maps tensor rows back to class ids, is not
# part of this file -- confirm it is stored or rebuilt elsewhere.
torch.save(ontoEmbedding, 'casbert_resources/pmr_classes.pt')

#### Create dictionary of predicate embedding

In [None]:
def _camelCaseSplitter(identifier):
    """
    Split a camelCase identifier into its words and return them joined by spaces.
    A run of capitals followed by a capitalised word is kept together
    ('HTTPResponse' -> 'HTTP Response').
    """
    from re import finditer
    pattern = '.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)'
    words = []
    for match in finditer(pattern, identifier):
        words.append(match.group(0))
    return ' '.join(words)

In [None]:
# Build a readable text for every predicate and embed all texts in one batch.
textPredicates = []
purePredicates = []
for p in predicates:
    if p.startswith('http'):
        # keep the local name of the URI and split its camelCase words
        localName = p.split('#')[-1].split('/')[-1]
        predicate = _camelCaseSplitter(localName)
    else:
        predicate = p
    # skip predicates whose text is blank or shorter than three characters
    if len(predicate.split()) == 0 or len(predicate) <= 2:
        continue
    textPredicates.append(predicate)
    purePredicates.append(p)
predicateVectors = sbert.encode(textPredicates, convert_to_tensor=True, show_progress_bar=True)
# original predicate string -> embedding of its readable text
predicateEmbedding = dict(zip(purePredicates, predicateVectors))

In [None]:
def getPredicateEmbedding(predicates):
    """
    Return the embedding of one predicate or the mean embedding of several.

    predicates: a string or a list of predicates. If it is a list, the mean of
                the embeddings is returned. Predicates without a stored
                embedding are ignored.
    returns: a tensor, or None when none of the predicates has an embedding
    """
    # isinstance also recognises str subclasses (e.g. numpy.str_), which an
    # exact type comparison would iterate character by character
    if isinstance(predicates, str):
        predicates = [predicates]
    embeddings = [predicateEmbedding[predicate] for predicate in predicates if predicate in predicateEmbedding]
    if not embeddings:
        return None
    return torch.mean(torch.stack(embeddings), dim=0)

def getPredicates(text, topK = 20):
    """
    Return the predicates most similar to `text`.

    text: string about predicate
    topK: maximum number of predicates to return; capped at the number of
          embedded predicates
    returns: a dictionary of predicate -> (rank, cosine score), ordered by
             descending similarity
    """
    textEmbedding = sbert.encode(text, convert_to_tensor=True)
    # keys and vectors are materialised once; both follow the dict's order
    predicateIds = list(predicateEmbedding.keys())
    predicateMatrix = torch.stack(list(predicateEmbedding.values()))
    # We use cosine-similarity and torch.topk to find the highest top_k scores
    cosScores = util.pytorch_cos_sim(textEmbedding, predicateMatrix)[0]
    # torch.topk raises when k is larger than the number of candidates
    topResults = torch.topk(cosScores, k=min(topK, len(predicateIds)))
    predicates = {}
    for rank, (score, idx) in enumerate(zip(topResults[0], topResults[1])):
        predicates[predicateIds[idx.item()]] = (rank, score.item())
    return predicates

In [None]:
# Persist the mapping of predicate -> embedding for later use.
torch.save(predicateEmbedding, 'casbert_resources/pmr_predicates.pt')

#### Create dictionary of entity embeddings
This step includes:
- creating entity embeddings
- extracting query-entity dataset for testing candidate

In [None]:
### get entity representation (ontology class and predicates - ontology classes)


def getPredicatePaths(obj=None, prevPred=None, arrRdf=None):
    """
    Collect the predicate paths that lead to `obj` in a table of RDF triples.

    Starting at `obj`, every triple whose element at index 2 equals the current
    node is followed backwards to its element at index 0; the element at
    index 1 (the predicate) is put in front of the path collected so far. A
    path ends at a node that is not the index-2 element of any triple.

    obj: the node to start from (a leaf of the annotation)
    prevPred: predicates already collected; they stay at the end of every path
    arrRdf: 2-D numpy array of triples, one triple per row
    returns: a list of paths, each a list of predicates; [prevPred] when there
             is nothing to follow (no obj, no triples, or no matching triple)
    """
    import numpy as np

    # a fresh list per call: a shared default list could be changed by a caller
    prevPred = [] if prevPred is None else prevPred
    if obj is None or arrRdf is None:
        return [prevPred]
    arrRdf = np.asarray(arrRdf)
    if arrRdf.ndim != 2:
        # e.g. an empty list of triples turns into a 1-D array
        return [prevPred]

    def walk(node, suffix, onPath):
        rows = arrRdf[np.where(arrRdf[:,2] == node)]
        if len(rows) == 0:
            return [suffix]
        paths = []
        for row in rows:
            subject, extended = row[0], [row[1]] + suffix
            if subject in onPath:
                # cyclic triples: close the path here instead of recursing forever
                paths.append(extended)
            else:
                paths += walk(subject, extended, onPath | {subject})
        return paths

    return walk(obj, prevPred, {obj})

import numpy as np

test = {}  ## this is for dataset: query text -> {varId: ontology classes of the entity}

# varId to content (row position in the embedding tensors, ontology classes and predicate paths)
entityKeys = {}
entityIds = []
entityEmbedding = {'class':[], 'class_predicate':[]}

# ontology class - predicate to varIds
classPredicateMap = {'multi':{}, 'single':{}}
# the embedding of ontology class - predicate
classPredicateEmbedding = {'multi':[], 'single':[]}

# weight of the predicate-path embedding relative to the class embedding
alpha = 0.22


# loop for each variable
for varId, value in variables['data'].items():

    embeddings = []
    pathClassEmbeddings = []
    terms = {}
    candQuery = []
    # Triples of THIS variable, built on first use. Resetting it per variable
    # keeps a textual leaf from being matched against the triples of an
    # earlier variable (or failing with a NameError on the very first leaf).
    arrRdf = None
    if 'rdfLeaves' in value:
        for leaf in value['rdfLeaves']:
            leaf = str(leaf).strip()
            if leaf.startswith('http') or leaf.startswith('urn:miriam'):
                # get ontology class para
                if leaf.startswith('http'):
                    classId = leaf.rsplit('/',1)[-1].rsplit('#',1)[-1].replace('_',':')
                else:
                    classId = leaf.rsplit(':',1)[-1].replace('%3A',':')
                # get ontology class embedding
                embedding = getClassEmbedding(classId)
                if embedding is None:
                    # class without an embedding: nothing to index for this leaf
                    continue
                embeddings += [embedding]
                # get predicates path embedding
                if arrRdf is None:
                    arrRdf = np.array(value['rdf'])
                paths = getPredicatePaths(obj=leaf, arrRdf=arrRdf)
                pathTexts = [[p.rsplit('/')[-1].rsplit('#')[-1] for p in path] for path in paths]
                # one lookup per path, reused for the entity and for the pair index
                pathVectors = [getPredicateEmbedding(path) for path in paths]
                pathEmbeddings = [vector for vector in pathVectors if vector is not None]
                if len(pathEmbeddings) > 0:
                    # weighted average of the class and its mean path embedding
                    pathEmbedding = alpha * torch.mean(torch.stack(pathEmbeddings, dim=0), dim=0)
                    pathClassEmbedding = torch.sum(torch.stack([embedding, pathEmbedding], dim=0), dim=0)/(1+alpha)
                    pathClassEmbeddings += [pathClassEmbedding]
                else:
                    pathClassEmbeddings += [embedding]

                # setting for the use of unique classPredicateEmbedding
                for path, pEmbedding in zip(paths, pathVectors):
                    key = leaf + '-'.join([p.rsplit('/')[-1].rsplit('#')[-1] for p in path])
                    #index pair of path and ontology class
                    if pEmbedding is not None:
                        tmpEmbedding = torch.mean(torch.stack([embedding, pEmbedding], dim=0), dim=0)
                    else:
                        tmpEmbedding = embedding
                    if key not in classPredicateMap['multi']:
                        # the embedding is stored once, for the first variable using the pair
                        classPredicateEmbedding['multi'] += [tmpEmbedding]
                        classPredicateMap['multi'][key] = [varId]
                    else:
                        classPredicateMap['multi'][key] += [varId]

                    # index single predicate and ontology class
                    for predicate in path:
                        key = leaf + predicate
                        pEmbedding = getPredicateEmbedding([predicate])
                        if pEmbedding is not None:
                            tmpEmbedding = torch.mean(torch.stack([embedding, pEmbedding], dim=0), dim=0)
                        else:
                            tmpEmbedding = embedding
                        if key not in classPredicateMap['single']:
                            classPredicateEmbedding['single'] += [tmpEmbedding]
                            classPredicateMap['single'][key] = [varId]
                        else:
                            classPredicateMap['single'][key] += [varId]

                # set term
                terms[classId] = {'name':getClassText(classId, features=['name'])['name'], 'path':pathTexts}
            elif not leaf.startswith('file://'):
                # textual leaf: it is a candidate query when a predicate path
                # leading to it mentions 'description'
                if arrRdf is None:
                    if not value.get('rdf'):
                        # this variable has no triples to inspect
                        continue
                    arrRdf = np.array(value['rdf'])
                paths = getPredicatePaths(obj=leaf, arrRdf=arrRdf)
                for path in paths:
                    if 'description' in ''.join(path):
                        candQuery += [leaf]
                        break

        # only variables with at least one embedded ontology class become entities
        if len(embeddings) > 0:
            entityKeys[varId] = {'pos':len(entityKeys), 'classes':terms}
            entityEmbedding['class'] += [torch.mean(torch.stack(embeddings, dim=0), dim=0)]
            entityEmbedding['class_predicate'] += [torch.mean(torch.stack(pathClassEmbeddings, dim=0), dim=0)]
            entityIds += [varId]
            for q in candQuery:
                if q not in test:
                    test[q] = {}
                test[q][varId] = terms


for k, v in entityEmbedding.items():
    entityEmbedding[k] = torch.stack(v, dim=0)

classPredicateEmbedding['multi'] = torch.stack(classPredicateEmbedding['multi'], dim=0)
classPredicateEmbedding['single'] = torch.stack(classPredicateEmbedding['single'], dim=0)

entityEmbedding['class'].shape



In [None]:
# Persist the entity embeddings ('class' and 'class_predicate') for later use.
torch.save(entityEmbedding, 'casbert_resources/pmr_entities.pt')

## Performance measure using Mean average Precision

In [None]:
def averagePrecision(prediction):
    """
    Return the average precision of one ranked result list.

    prediction: relevance flags in rank order (1 = relevant, 0 = not relevant)
    returns: the mean of precision@k over the ranks k that hold a relevant
             item, or 0 when the list holds no 1. The mean is taken over the
             relevant items that appear in the list.
    """
    if 1 not in prediction:
        return 0
    hits = 0  # running sum of the flags seen so far
    tot = 0
    for rank, p in enumerate(prediction, start=1):
        hits += p
        if p > 0:
            tot += hits/rank
    return tot/hits

def meanAP(predictions):
    """
    Return the mean average precision over several ranked result lists.

    predictions: sequence of relevance-flag lists, one per query
    returns: {'score': mean of the per-query average precision,
              'stat': {query position: average precision}};
             the score is 0.0 when there are no predictions
    """
    stat = {idx: averagePrecision(prediction) for idx, prediction in enumerate(predictions)}
    # an empty evaluation set scores 0.0 instead of dividing by zero
    score = sum(stat.values())/len(stat) if stat else 0.0
    return {'score':score, 'stat':stat}

def meanRR(predictions):
    """
    Return the mean reciprocal rank over several ranked result lists.

    predictions: sequence of relevance-flag lists, one per query
    returns: {'score': mean of the per-query reciprocal rank,
              'stat': {query position: reciprocal rank}};
             a list without a 1 contributes 0, and the score is 0.0 when
             there are no predictions
    """
    stat = {}
    for idx, prediction in enumerate(predictions):
        # reciprocal of the 1-based rank of the first relevant item
        stat[idx] = 1/(prediction.index(1)+1) if 1 in prediction else 0
    # an empty evaluation set scores 0.0 instead of dividing by zero
    score = sum(stat.values())/len(stat) if stat else 0.0
    return {'score':score, 'stat':stat}

def getMAP(queries, searchFunction, indexType=None, pathType=None, topK=10, minSim=0.5):
    """
    Run `searchFunction` for every query and score the rankings.

    queries: dict of query text -> facts, where facts['vars'] holds the
             relevant variable ids
    searchFunction: called with the keywords query, topK, indexType, pathType
                    and minSim; iterating its result yields variable ids in
                    rank order
    returns: {'MAP': result of meanAP, 'MRR': result of meanRR}
    """
    predictions = []
    for query, facts in tqdm(queries.items()):
        results = searchFunction(query=query, topK=topK, indexType=indexType, pathType=pathType, minSim=minSim)
        # 1 for a retrieved variable that is relevant to the query, else 0
        flags = [int(varId in facts['vars']) for varId in results]
        predictions.append(flags)
    return {'MAP':meanAP(predictions), 'MRR':meanRR(predictions)}
    

In [None]:

def getMAPs(queries, functions, indexTypes=('class', 'class_predicate'), pathTypes=('single', 'multi'), topK=10, minSim=0.5):
    """
    Print the index/path type pairs and the names of the search functions.

    NOTE(review): this only prints; queries, topK and minSim are not used and
    nothing is evaluated or returned -- it looks unfinished.

    indexTypes, pathTypes: paired position by position (zip); the defaults are
                           tuples so they cannot be modified between calls
    functions: iterable of callables; their __name__ is printed
    """
    for idx, path in zip(indexTypes, pathTypes):
        print(idx, path)
    for function in functions:
        print(function.__name__)

## Preparing Data Test

In [None]:
# Grade every query-entity pair by the cosine similarity between the query
# and the entity's class embedding:
#      0 if similarity > 0.7
#     -1 if 0.5 <= similarity <= 0.7
#     -2 if similarity < 0.5
# The graded pairs are saved to csv and then inspected manually by experts.


import pandas as pd
from tqdm import tqdm

data = []
for query, candidates in tqdm(test.items()):
    queryEmb = sbert.encode(query, convert_to_tensor=True)
    if len(query) == 0:
        # empty queries produce no rows
        continue
    for varId, classes in candidates.items():
        entityEmb = entityEmbedding['class'][entityKeys[varId]['pos']]
        cosScores = util.pytorch_cos_sim(entityEmb, queryEmb)
        if cosScores < 0.5:
            check = -2
        elif cosScores <= 0.7:
            check = -1
        else:
            check = 0
        data.append([query, check, classes, varId, cosScores.item()])

df = pd.DataFrame(data, columns=['query', 'check', 'classes', 'varID', 'score'])

In [None]:
# Write the automatically graded pairs; this raw file is the input of the manual inspection.
df.to_csv('casbert_resources/pmr_test_data_raw.csv', index=False)

#### Manual inspection is painful, but we did it
The manually validated data is available as "pmr_test_data.csv".

Columns: query, check, classes, varID, score, query_modif

In [None]:
# Load the validated test data; from here on `df` is this table.
df = pd.read_csv('casbert_resources/pmr_test_data.csv')
df['query'].unique().shape  # number of distinct queries
df.head(1)

#### Enrich test data
In test_data.csv, each query is associated with a limited number of entities. The query may also be associated with other entities. Therefore, we need to enrich the test data with every possible entity. This enriched data is then named 'silver data'.

In [None]:
import re
import random
def camel_case_split(identifier):
    """
    Return the words of a camelCase identifier as a list.
    A run of capitals followed by a capitalised word stays together
    ('HTTPResponse' -> ['HTTP', 'Response']); an empty string gives [].
    """
    pattern = '.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)'
    # the pattern has no capturing group, so findall yields the whole matches
    return re.findall(pattern, identifier)

In [None]:
import ast

# Keep a queryTest that already exists (e.g. when this cell is run again),
# but always restart its 'noPredicate' group from empty.
if 'queryTest' not in locals():
    queryTest = {}    
queryTest['noPredicate'] = {}

def generateQueryTest(dfTest, name = None):
    """
    Build one query set without predicates and store it in
    queryTest['noPredicate'].

    dfTest: rows of the validated test data (columns query, check, classes,
            varID, score, query_modif)
    name: key of the new set; by default the number of existing sets as a string

    For every row the query is row['query'] when check is 0, otherwise
    row['query_modif']. Its relevant variables are the row's variable plus all
    variables annotated with every ontology class of the row (looked up in
    class2Vars). 'score' is taken from the first row seen for a query.
    """
    qTestIdx = str(len(queryTest['noPredicate'])) if name is None else name
    group = queryTest['noPredicate'][qTestIdx] = {}
    for index, row in dfTest.iterrows():
        query = row['query'] if row['check']==0 else row['query_modif']
        classes = ast.literal_eval(row['classes']).keys()
        if query not in group:
            group[query] = {'vars':[row['varID']], 'score':row['score']}
        else:
            group[query]['vars'] += [row['varID']]
        lstVars = [class2Vars[c] for c in classes]
        # set.intersection needs at least one set; a row without classes adds nothing
        otherVarIds = set.intersection(*map(set,lstVars)) if lstVars else set()
        group[query]['vars'] = list(set(group[query]['vars'] + list(otherVarIds)))

In [None]:
# Three query sets are generated; since no name is passed they are stored
# under the keys '0', '1' and '2' of queryTest['noPredicate'].
# this is for the best test data
generateQueryTest(df[(df.check==0)])
# this is for the first and second best test data
generateQueryTest(df[(df.check==0) | (df.check==-1)]) 
# this is for the first, second, and third best test data
generateQueryTest(df)

In [None]:
# generate query test with predicate

def generateQueryPredicateTest(dfTest, name = None):
    """
    Build one query set whose queries may mention a predicate and store it in
    queryTest['withPredicate'].

    dfTest: rows of the validated test data (columns query, check, classes,
            varID, score, query_modif)
    name: key of the new set; by default the number of existing sets as a string

    For every row, each entity found by `nlp` in the query is matched to
    ontology classes. For a matching class of the row's variable a coin flip
    decides whether a predicate, picked at random from one of the class's
    predicate paths, is written in front of the entity text. The relevant
    variables are the row's variable plus the variables that carry every class
    of the row and, for each inserted predicate, have that predicate in a path
    of the same class.
    The outcome depends on the state of the `random` module.
    """
    qTestIdx = str(len(queryTest['withPredicate'])) if name is None else name
    group = queryTest['withPredicate'][qTestIdx] = {}
    for index, row in tqdm(dfTest.iterrows(), total=dfTest.shape[0]):
        query = row['query'] if row['check']==0 else row['query_modif']
        doc = nlp(query)
        entityClasses = entityKeys[row['varID']]['classes']
        classes = list(entityClasses.keys())

        clsPredicate = {}
        for ent in doc.ents:
            entClasses = list(getClasses(ent.text, topK=10).keys())
            for classId in entClasses:
                if classId in classes and classId not in clsPredicate and bool(random.getrandbits(1)):
                    # work on a copy: the chosen path is stored in entityKeys
                    # and must not be modified
                    entPredicates = list(random.choice(entityClasses[classId]['path']))
                    if 'is' in entPredicates: entPredicates.remove('is')
                    if len(entPredicates) > 0:
                        # we select randomly a predicate from position -3 to the end
                        # the reason is that those predicates are more likely to be relevant to the ontology class
                        entPredicate = random.choice(entPredicates[-3:])
                        clsPredicate[classId] = entPredicate
                        textPredicate = ' '.join(camel_case_split(entPredicate)).lower()
                        textPredicate = textPredicate.replace('is ', '')
                        query = query.replace(ent.text, '%s '%textPredicate+ent.text)
                        break

        if query not in group:
            group[query] = {'vars':[row['varID']], 'score':row['score']}
        else:
            group[query]['vars'] += [row['varID']]

        lstVars = [class2Vars[c] for c in classes]
        # set.intersection needs at least one set
        candidateVarIds = set.intersection(*map(set,lstVars)) if lstVars else set()

        # Keep a candidate only when every inserted predicate occurs in one of
        # its paths for the same class. 'path' is a list of paths, so each path
        # is searched; the result is a new list rather than one edited while
        # being iterated, which would skip candidates.
        otherVarIds = []
        for varId in candidateVarIds:
            varClasses = entityKeys.get(varId, {}).get('classes', {})
            hasAllPredicates = all(
                any(predicate in path for path in varClasses.get(classId, {}).get('path', []))
                for classId, predicate in clsPredicate.items()
            )
            if hasAllPredicates:
                otherVarIds.append(varId)

        group[query]['vars'] = list(set(group[query]['vars'] + otherVarIds))

In [None]:
# Some functions use scispacy to locate the phrases or concepts of a query.
# For example, the query 'Calcium reverse membrane potential.'
# is identified as having 2 concepts: 'calcium' and 'reverse membrane potential'.
# For accurate identification we use 'en_core_sci_scibert', which requires a GPU for faster performance.
# `nlp` is called by functions defined in earlier cells, so this cell has to
# run before those functions are called.

import en_core_sci_scibert
nlp = en_core_sci_scibert.load()


In [None]:
# Start the 'withPredicate' group from empty; the three sets generated below
# are stored under the keys '0', '1' and '2' since no name is passed.
queryTest['withPredicate'] = {}
# this is for the best test data
generateQueryPredicateTest(df[(df.check==0)])
# this is for the first and second best test data
generateQueryPredicateTest(df[(df.check==0) | (df.check==-1)]) 
# this is for the first, second, and third best test data
generateQueryPredicateTest(df[(df.check==0) | (df.check==-1) | (df.check==-2)])

In [None]:
# generate combine test data
import random
# Merge the largest no-predicate set with a random sample of the largest
# with-predicate set. A query present in both keeps the with-predicate entry.
# NOTE(review): the sample is not seeded, so the combined set differs per run.
withPredicateQueries = list(queryTest['withPredicate']['2'])
# random.sample raises ValueError when asked for more items than are available
sampleSize = min(len(queryTest['noPredicate']['2']), len(withPredicateQueries))
listQueryP = random.sample(withPredicateQueries, sampleSize)
queryTestCombine = {**queryTest['noPredicate']['2'], **{key:queryTest['withPredicate']['2'][key] for key in listQueryP}}
queryTest['combine'] = queryTestCombine

In [None]:
# the size of the initial query tests (number of distinct queries per set)
print('no predicate', len(queryTest['noPredicate']['2']))
print('with predicate', len(queryTest['withPredicate']['2']))
print('combination', len(queryTest['combine']))

### Prepare for IR with classifier (for retrieval purpose)
We are experimenting with the automatic use of predicates in our retrieval. A classifier decides whether they are used.
- We prepare the classifier dataset here
- Then, run the classifier training in another script to get the model
    - Follow this link: <a href="Train Query Classifier - PMR.ipynb">Train Classifier for PMR</a>
- Finally, using the generated model in this experiment

In [None]:
def entitySearchCombine(query, topK=20, minSim=0.5, indexType='class', pathType=None):
    """
    In this approach: combining entitySearch and entitySearchClass
    1. Query is chunked into entities and classified into biomedical phrases and predicate
    2. Connect predicate to biomedical phrase
    3. If indexType is 'class', generate vector for biomedical phrase
    4. If indexType is 'class_predicate', generate vector for biomedical phrase and predicate pair
    5. Combine vectors becoming one vector using mean function, named it as local vector
    6. Get vector of query named it as global vector
    7. Combine local vector and global vector
    8. Get similar entities using cosine similarity
    9. Return topK result in descending

    query: the query text
    topK: maximum number of entities to return; capped at the number of indexed entities
    minSim: minimum predicate similarity for a phrase to be used as a predicate
            (only with indexType 'class_predicate')
    indexType: 'class' or 'class_predicate'; selects the entity index
    pathType: accepted for a uniform search-function signature; not used here
    returns: dict of varId -> [rank, cosine score, entityKeys entry], best first
    raises: ValueError for an unsupported indexType
    """
    if indexType not in ('class', 'class_predicate'):
        # fail early with a clear message instead of an unbound local further down
        raise ValueError("indexType must be 'class' or 'class_predicate', got %r" % (indexType,))

    ### Get local query embedding

    doc = nlp(query)

    ontoClasses = []
    predicates = []
    validClassPredicates = {}
    offset2Class = {}
    cScores = []
    for ent in doc.ents:
        # a phrase counts as a class unless it is more similar to a predicate
        predicateScores = getPredicates(ent.text, topK=1)
        pScore = list(predicateScores.values())[0][1]
        classScores = getClasses(ent.text, topK=1)
        cScore = list(classScores.values())[0][1]
        if cScore >= pScore:
            cScores += [cScore]
            ontoClasses += [ent]
            for token in ent:
                offset2Class[token.i] = ent
        elif indexType =='class_predicate' and pScore >= minSim:
            predicates += [ent]

    # no phrase classified as a class: use the whole query as one
    if len(ontoClasses) == 0:
        ontoClasses = [doc]
        cScores = [list(getClasses(query, topK=1).values())[0][1]]

    # check the entities describe by predicate (usually predicate's child)
    for ent in predicates:
        for token in ent:
            for child in token.children:
                if child.i in offset2Class:
                    idx = ontoClasses.index(offset2Class[child.i])
                    validClassPredicates.setdefault(idx, []).append(ent)
                    break # only consider the closest class. remove break if considering all classes

    if indexType=='class':
        classEmbeddings = [sbert.encode(ent.text, convert_to_tensor=True) for ent in ontoClasses]
    else:
        classEmbeddings = []
        for i, ontoClass in enumerate(ontoClasses):
            classEmbedding = sbert.encode(ontoClass.text, convert_to_tensor=True)
            if i in validClassPredicates:
                predicateEmbeddings = [sbert.encode(ent.text, convert_to_tensor=True) for ent in validClassPredicates[i]]
                pathEmbedding = alpha * torch.mean(torch.stack(predicateEmbeddings, dim=0), dim=0)
                # NOTE(review): the entity index divides class + alpha*path by
                # (1+alpha), while this takes the plain mean of the two vectors;
                # confirm the different scaling is intended.
                classEmbedding = torch.mean(torch.stack([classEmbedding, pathEmbedding], dim=0), dim=0)
            classEmbeddings += [classEmbedding]
    textEmbedding = torch.mean(torch.stack(classEmbeddings, dim=0), dim=0)

    ### Get global query embedding
    textEmbeddingGlobal = sbert.encode(query, convert_to_tensor=True)
    ### Combine global and local embedding
    # the local vector is weighted by the mean class similarity of its phrases
    factor = sum(cScores)/len(cScores)
    textEmbedding = torch.mean(torch.stack([textEmbeddingGlobal, factor * textEmbedding], dim=0), dim=0)

    # We use cosine-similarity and torch.topk to find the highest top_k scores
    cosScores = util.pytorch_cos_sim(textEmbedding, entityEmbedding[indexType])[0]
    # torch.topk raises when k is larger than the number of candidates
    topResults = torch.topk(cosScores, k=min(topK, cosScores.shape[0]))
    entities = {}
    varIds = list(entityKeys.keys())
    for rank, (score, idx) in enumerate(zip(topResults[0], topResults[1])):
        varId = varIds[idx.item()]
        entities[varId] = [rank, score.item(), entityKeys[varId]]
    return entities

In [None]:
# Evaluate the combined query set twice: once against the class-only index and
# once against the class+predicate index. The per-query scores are compared below.
resultCombineClass = getMAP(queryTestCombine, entitySearchCombine, topK=10, indexType='class')
resultCombineClassPredicate = getMAP(queryTestCombine, entitySearchCombine, topK=10, indexType='class_predicate')

In [None]:
# tag the query -1=neutral, 0=class, 1=class_predicate
# The tag names the index with the higher average precision for the query.
# Per-query scores are matched to queries by position: the evaluation walked
# queryTestCombine in the same order.
result = list(zip(resultCombineClass['MAP']['stat'].values(), resultCombineClassPredicate['MAP']['stat'].values()))
for position, entry in enumerate(queryTestCombine.values()):
    apClass, apClassPredicate = result[position]
    if apClass == apClassPredicate:
        entry['indexType'] = -1
    elif apClass > apClassPredicate:
        entry['indexType'] = 0
    elif apClass < apClassPredicate:
        entry['indexType'] = 1


In [None]:
# Save the tagged queries; this is the dataset prepared for the query classifier.
with open('casbert_resources/pmr_classifier_data.json', 'w') as fp:
    json.dump(queryTestCombine, fp)

#### Divide into train, validation, and test data (proportion 4:3:3)

In [None]:
# Queries tagged neutral (-1) carry no preference and are left out of the classifier data.
labelled = [(q, v['indexType']) for q, v in queryTestCombine.items() if v['indexType'] != -1]
data = {
    'queries': [q for q, _ in labelled],
    'labels': [label for _, label in labelled],
}

df = pd.DataFrame(data)

In [None]:
import numpy as np
# Shuffle reproducibly, then cut by row position. Positional slicing gives the
# same three frames as np.split without relying on pandas' deprecated support
# for calling np.split on a DataFrame.
# NOTE(review): the cut points yield train/eval/test = 30%/30%/40% of the rows,
# while the heading of this section says 4:3:3 -- confirm which one is intended.
shuffled = df.sample(frac=1, random_state=0)
firstCut, secondCut = int(.3*len(df)), int(.6*len(df))
df_train = shuffled.iloc[:firstCut]
df_eval = shuffled.iloc[firstCut:secondCut]
df_test = shuffled.iloc[secondCut:]
print(df_train.shape, df_eval.shape, df_test.shape)

In [None]:
# The final combined test set keeps the neutral queries and the queries of the
# classifier's test split, i.e. nothing the classifier was trained or tuned on.
# The split's queries are collected once in a set for constant-time lookups.
testQueries = set(df_test['queries'])
queryTest['combine'] = {}
for q, v in queryTestCombine.items():
    if v['indexType'] == -1 or q in testQueries:
        queryTest['combine'][q] = v

In [None]:
# save queryTest (the 'noPredicate', 'withPredicate' and 'combine' groups) to file
with open('casbert_resources/pmr_query_test.json', 'w') as fp:
    json.dump(queryTest, fp)