In [None]:
import faiss
from tqdm import tqdm
import numpy as np
import pickle as pkl
from collections import OrderedDict
from typing import Dict
from transformers import AutoModel, AutoTokenizer
from sklearn.decomposition import PCA
import sys
! cp -r /content/drive/MyDrive/sapBERT-DUT-cambridge/src /content/src
from src.model_wrapper import Model_Wrapper
import pandas as pd
import sklearn
from sklearn.model_selection import ParameterGrid
import re

In [None]:
from src.model_wrapper import (
    Model_Wrapper
)

In [None]:
def get_query_embedding(queries, tokenizer, model, bs=128, max_length=25):
    """Embed `queries` with `model` and return the first-token vector of each query.

    Args:
        queries: sliceable sequence of query strings, encoded ``bs`` at a time.
        tokenizer: tokenizer providing ``batch_encode_plus``; its returned tensors
            are moved to the GPU with ``.cuda()``.
        model: encoder; ``model(**toks)[0][:, 0, :]`` is taken as the query vector.
        bs: number of queries encoded per forward pass.
        max_length: length every query is padded/truncated to by the tokenizer.

    Returns:
        np.ndarray formed by concatenating the per-batch vectors along axis 0.
    """
    import torch  # local import: torch is not imported at the top of the notebook

    batch_reps = []
    # Inference only: no_grad stops autograd from recording a graph for every batch,
    # which otherwise wastes GPU memory.
    with torch.no_grad():
        for start in tqdm(range(0, len(queries), bs)):
            toks = tokenizer.batch_encode_plus(
                queries[start:start + bs],
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            toks_cuda = {name: tensor.cuda() for name, tensor in toks.items()}
            output = model(**toks_cuda)

            # First-token position ([CLS] for BERT-style tokenizers); output[0] is
            # presumably (batch, seq, hidden) — TODO confirm against the encoder.
            cls_rep = output[0][:, 0, :]
            batch_reps.append(cls_rep.cpu().numpy())

    return np.concatenate(batch_reps, axis=0)

def query_index(queries, tokenizer, model, index, idx2cui, cui2string, pca):
    """Return the nearest indexed entry for each query.

    Queries are embedded with ``get_query_embedding`` and projected with
    ``pca.transform``. They are then searched against the faiss ``index`` in a
    single batched call, and the position of the top hit is mapped through ``idx2cui``.

    Args:
        queries: query strings.
        tokenizer, model: passed through to ``get_query_embedding``.
        index: faiss index built over vectors of the PCA's output dimension.
        idx2cui: lookup from index position to label.
        cui2string: accepted for interface compatibility; not used here.
        pca: fitted object exposing ``transform``.

    Returns:
        list with one ``idx2cui`` entry per query.
    """
    query_embs = get_query_embedding(queries, tokenizer, model)
    # faiss expects a C-contiguous float32 matrix of shape (n_queries, dim). The
    # dimension comes from the PCA output, so nothing here assumes 256 components.
    query_embs = np.ascontiguousarray(pca.transform(query_embs), dtype=np.float32)

    # One batched search instead of a per-query loop; only the nearest neighbour
    # is consumed, so ask for k=1.
    _, neighbors = index.search(query_embs, 1)
    return [idx2cui[int(top[0])] for top in neighbors]

In [None]:
def getResources(model_directory_path):
    """Load the tokenizer, encoder, faiss index and PCA stored in one model directory.

    Args:
        model_directory_path: directory handed to ``Model_Wrapper.load_model``; it
            must also contain an ``index`` file (read with ``faiss.read_index``) and
            a pickled ``pca`` file.

    Returns:
        (tokenizer, model, index, pca)
    """
    model_wrapper = Model_Wrapper().load_model(
        path=f'{model_directory_path}',
        max_length=25,
        use_cuda=True,
    )
    tokenizer = model_wrapper.get_dense_tokenizer()
    model = model_wrapper.get_dense_encoder()
    index = faiss.read_index(f'{model_directory_path}/index')
    # pickle can execute arbitrary code on load: only point this at trusted directories.
    with open(f'{model_directory_path}/pca', "rb") as pca_file:
        pca = pkl.load(pca_file)
    return tokenizer, model, index, pca

In [None]:
def checkRelations(pred, trues, relations):
    """Return the relation linking `pred` to one of the gold labels, or False.

    ``relations`` is probed with ``"<a>|<b>"`` keys. For every gold label, both
    orderings are tried: ``"<true>|<pred>"`` first, then ``"<pred>|<true>"``.
    Previously only the single-label path tried both orderings.

    Args:
        pred: predicted label.
        trues: a single gold label, or a list/tuple of gold labels.
        relations: mapping from ``"a|b"`` strings to a relation value.

    Returns:
        The value of the first matching key (gold labels in order), else False.
    """
    candidates = trues if isinstance(trues, (list, tuple)) else [trues]
    for true in candidates:
        for key in (f"{true}|{pred}", f"{pred}|{true}"):
            if key in relations:
                return relations[key]
    return False

In [None]:
def checkRelation(preds, trues, relations):
    """Score predictions against gold labels, crediting exact matches and related labels.

    Args:
        preds: predicted labels, one per example.
        trues: per-example gold labels; each entry is a list of acceptable labels
            or a single label.
        relations: mapping passed through to ``checkRelations``.

    Returns:
        (trues_, dist_1_relas, relas), each aligned with ``preds``:
          * trues_: the prediction itself on an exact match; otherwise the first
            gold label, or None when the example has no gold labels;
          * dist_1_relas: True on an exact match or when ``checkRelations`` finds
            a relation, else False;
          * relas: the relation found ('' for exact matches and misses).
    """
    from collections.abc import Iterable

    trues_, dist_1_relas, relas = [], [], []

    for i, pred in enumerate(preds):
        current_trues = trues[i]
        # Treat a bare label (e.g. a string) as a one-element list. Otherwise `in`
        # would do substring matching and [0] would return its first character.
        if isinstance(current_trues, (str, bytes)) or not isinstance(current_trues, Iterable):
            gold = [current_trues]
        else:
            gold = list(current_trues)
        first_gold = gold[0] if gold else None

        if pred in gold:
            trues_.append(pred)
            dist_1_relas.append(True)
            relas.append('')
            continue

        # Pass the label through unchanged so checkRelations applies its own
        # list/scalar handling; only looked up when there is no exact match.
        rela = checkRelations(pred, current_trues, relations)
        trues_.append(first_gold)
        if rela:
            dist_1_relas.append(True)
            relas.append(rela)
        else:
            dist_1_relas.append(False)
            relas.append('')
    return trues_, dist_1_relas, relas

In [None]:
def predict(model_directory_path, sentences, mentions, idx2cui, cui2string):
    """Embed every mention and return its nearest indexed entry.

    Resources (tokenizer, encoder, faiss index, PCA) are loaded from
    ``model_directory_path`` via ``getResources``; the search itself is done by
    ``query_index``. ``sentences`` is accepted but not used here: only the mention
    strings are embedded.
    """
    loaded = getResources(model_directory_path=model_directory_path)
    tokenizer, model, index, pca = loaded
    return query_index(
        queries=mentions,
        tokenizer=tokenizer,
        model=model,
        index=index,
        idx2cui=idx2cui,
        cui2string=cui2string,
        pca=pca,
    )

In [None]:
def evaluate(preds, trues, relas):
    """Print and return exact-match accuracy and "1-distance" accuracy.

    Args:
        preds: predicted labels.
        trues: labels to compare against, aligned with ``preds``.
        relas: per-prediction booleans; True counts as a hit for 1-dist accuracy.

    Returns:
        (acc, acc_1dist) as floats.

    Raises:
        ValueError: if ``preds`` and ``trues`` differ in length, or if there is
            nothing to evaluate.
    """
    if len(preds) != len(trues):
        raise ValueError(f"preds ({len(preds)}) and trues ({len(trues)}) differ in length")
    if len(preds) == 0 or len(relas) == 0:
        raise ValueError("evaluate() needs at least one prediction")

    # Fraction of exact matches (what sklearn's accuracy_score computes for 1-D
    # label lists). Uses the `trues` argument, not any notebook-level global.
    acc = sum(t == p for t, p in zip(trues, preds)) / len(trues)
    acc_1dist = sum(relas) / len(relas)

    print(f"Accuracy: {acc}")
    print(f"1-dist accuracy: {acc_1dist}")

    return acc, acc_1dist

In [None]:
def saveResults(sentences, mentions, cui2string, preds, trues, relas, save=False, output_dir=None):
    """Build a per-mention results table and optionally write it to disk.

    If a row's predicted or gold label has no entry in ``cui2string``, its mention
    and labels are printed and the row is left out of the table.

    Args:
        sentences, mentions: per-example context and mention text, indexed like ``trues``.
        cui2string: mapping from label to a readable name.
        preds: predicted labels.
        trues: gold labels to report, one per row.
        relas: per-row relation value.
        save: when True, write ``predictions_mantra.csv`` and a pickle of ``preds``.
        output_dir: directory to write into. When None, the notebook-level
            ``model_directory_path`` global is used, as earlier calls expect.
    """
    rows = []
    for i, true in enumerate(trues):
        try:
            rows.append([sentences[i], mentions[i], cui2string[preds[i]], cui2string[true], preds[i], true, relas[i]])
        except KeyError:
            # Only a missing name lookup is tolerated; misaligned inputs (IndexError)
            # now surface instead of being silently swallowed.
            print(mentions[i], true, preds[i])
    df_results = pd.DataFrame(rows, columns=['sentence', 'mention', 'prediction', 'label', 'cui_prediction', 'cui_label', 'relation'])

    if save:
        if output_dir is None:
            output_dir = model_directory_path  # notebook global; kept for existing callers
        df_results.to_csv(f'{output_dir}/predictions_mantra.csv')

        with open(f'{output_dir}/preds', 'wb') as f:
            pkl.dump(preds, f)

In [None]:
# Evaluation driver: load data and lookup tables, predict, score, save.
# NOTE: pickle can execute arbitrary code on load; only open trusted files.
LABEL_COLUMN = 'cui'  # TODO(review): assumed name of the gold-label column in mantra.pkl; confirm

data = pd.read_pickle("mantra.pkl")
sentences = data['sentence'].to_list()
mentions = data['mention'].to_list()
# Gold labels per mention; checkRelation needs these and nothing else defines `trues`.
assert LABEL_COLUMN in data.columns, (
    f"gold-label column {LABEL_COLUMN!r} not found; available columns: {list(data.columns)}"
)
trues = data[LABEL_COLUMN].to_list()

with open('id2cui.pkl', 'rb') as fh:
    idx2cui = pkl.load(fh)
with open('cui_to_string', 'rb') as fh:
    cui2string = pkl.load(fh)
with open('relations', 'rb') as fh:
    relations = pkl.load(fh)
model_directory_path = '3s10ft'

preds = predict(model_directory_path, sentences, mentions, idx2cui, cui2string)
checked_trues, dist_1_relas, relas = checkRelation(preds, trues, relations)
acc, acc_1dist = evaluate(preds, checked_trues, dist_1_relas)
saveResults(sentences, mentions, cui2string, preds, checked_trues, relas, True)

