In [3]:
import sys
import numpy as np
import pandas as pd
import glob
import ujson

from tqdm.auto import tqdm

sys.path.insert(0, '../')
# from eval_models import evaluate_linking, format_pred_summary

In [4]:
# Dataset identifiers used as keys for the per-model path dictionaries below.
dataset_names = [
    'bc5cdr',
    'medmentions_full',
    'medmentions_st21pv',
    'gnormplus',
    'nlmchem',
    'nlm_gene',
]

# One JSON path per dataset for each model; the dataset name is substituted
# into a fixed path template.
sapbert_dict = dict(
    (name, f'/efs/davidkartchner/output/sapbert/{name}/0.json')
    for name in dataset_names
)

krissbert_dict = dict(
    (name, f'/efs/davidkartchner/krissbert/usage/model_output/{name}.json')
    for name in dataset_names
)
krissbert_dict

{'bc5cdr': '/efs/davidkartchner/krissbert/usage/model_output/bc5cdr.json',
 'medmentions_full': '/efs/davidkartchner/krissbert/usage/model_output/medmentions_full.json',
 'medmentions_st21pv': '/efs/davidkartchner/krissbert/usage/model_output/medmentions_st21pv.json',
 'gnormplus': '/efs/davidkartchner/krissbert/usage/model_output/gnormplus.json',
 'nlmchem': '/efs/davidkartchner/krissbert/usage/model_output/nlmchem.json',
 'nlm_gene': '/efs/davidkartchner/krissbert/usage/model_output/nlm_gene.json'}

In [5]:

# Nested mapping: model name -> dataset name -> path of that model's JSON
# output file. Each model has a single path template into which the dataset
# name is substituted; insertion order is sapbert first, then krissbert.
outputs_dict = {
    model_name: {
        dataset: path_template.format(dataset=dataset)
        for dataset in (
            'bc5cdr',
            'medmentions_full',
            'medmentions_st21pv',
            'gnormplus',
            'nlmchem',
            'nlm_gene',
        )
    }
    for model_name, path_template in (
        ('sapbert', '/efs/davidkartchner/sapbert/output/sapbert/{dataset}/2022-08-12/0.json'),
        ('krissbert', '/efs/davidkartchner/krissbert/usage/model_output/{dataset}.json'),
    )
}

In [25]:
def hit(gold_cuis, candidates, k, model='sapbert', mode='relaxed'):
    """Return True if any of the top-``k`` candidates matches the gold CUIs.

    Parameters
    ----------
    gold_cuis : container
        Gold identifiers for one mention; only membership tests (``in``) are
        performed on it.
    candidates : sequence
        Ranked candidates, best first. For ``model='krissbert'`` every entry
        is itself a sized collection of identifiers (a candidate group); for
        any other ``model`` every entry is a single identifier.
    k : int
        Number of leading candidates considered (``candidates[:k]``).
    model : str
        ``'krissbert'`` selects group-wise matching; every other value uses
        plain per-identifier matching.
    mode : str
        Only consulted for ``model='krissbert'``. ``'strict'`` counts a group
        as a hit when *all* of its identifiers are gold; any other value
        counts a group as a hit when *at least one* identifier is gold.

    Returns
    -------
    bool
        True when at least one of the top-``k`` candidates is a hit.
    """
    top_k = candidates[:k]

    if model == 'krissbert':
        if mode == 'strict':
            # all() is vacuously True for an empty group, which would score an
            # empty prediction as a correct one; require a non-empty group.
            return any(
                len(group) > 0 and all(x in gold_cuis for x in group)
                for group in top_k
            )
        return any(
            any(x in gold_cuis for x in group)
            for group in top_k
        )

    return any(x in gold_cuis for x in top_k)

def evaluate_preds(preds, k=1, model='sapbert', mode='relaxed'):
    """Compute the fraction of predictions with a gold hit in the top ``k``.

    Parameters
    ----------
    preds : iterable of mapping
        One entry per mention. Each entry is read through its ``'candidates'``
        and ``'cuis'`` keys. For ``model='sapbert'`` every candidate is a
        mapping with a ``'db_id'`` string, of which only the part after the
        last ``':'`` is compared; for ``model='krissbert'`` the candidates are
        passed to ``hit`` unchanged.
    k : int
        Number of top-ranked candidates considered per mention.
    model : str
        ``'sapbert'`` or ``'krissbert'``. Entries evaluated under any other
        value are counted in the denominator but never contribute a hit.
    mode : str
        Forwarded to ``hit``.

    Returns
    -------
    float
        ``hits / total``, or ``0.0`` when ``preds`` is empty (previously this
        raised ``ZeroDivisionError``).
    """
    total = 0
    hits = 0
    for p in preds:
        candidates = p['candidates']
        total += 1
        if model == 'sapbert':
            # Keep only the identifier after the last ':' of each db_id.
            candidates = [x['db_id'].split(':')[-1] for x in candidates]
        elif model != 'krissbert':
            # Unrecognised model: nothing to score for this entry.
            continue
        hits += hit(p['cuis'], candidates, k, model=model, mode=mode)

    # Guard the empty case so an empty prediction file yields 0.0 instead of
    # aborting the whole evaluation run.
    if total == 0:
        return 0.0
    return hits / total

In [26]:
# sapbert_bc5cdr = ujson.load(open('/efs/davidkartchner/sapbert/output/sapbert/bc5cdr/2022-08-12/0.json','r'))
# sapbert_bc5cdr

In [27]:
# for model, output_files in outputs_dict.items():
#     if model == 'sapbert':
#         continue
#     for dataset, file in output_files.items():
#         preds = ujson.load(open(file,'r'))
#         accuracy = evaluate_preds(preds, 1, model=model)
#         print(model, dataset, accuracy)
#         print()

In [29]:
# Recall@k for every (model, dataset) pair listed in outputs_dict.
# Each row appended to eval_results is [dataset, model label, recall@k1, ...].
eval_results = []
ks = [1,2,4,8,16,32,64,100]
for model, output_files in tqdm(outputs_dict.items()):
    # `mode` only affects scoring of krissbert predictions, so krissbert gets
    # one row per mode and every other model gets a single row.
    modes = ['strict', 'relaxed'] if model == 'krissbert' else ['strict']
    for dataset, file in tqdm(output_files.items()):
        # Read each prediction file once (not once per mode) and close the
        # handle deterministically.
        with open(file, 'r') as f:
            preds = ujson.load(f)
        for mode in modes:
            accs = [evaluate_preds(preds, k, model=model, mode=mode) for k in ks]
            label = f'{model}_{mode}' if model == 'krissbert' else model
            eval_results.append([dataset, label] + accs)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

In [40]:
# Assemble the raw results table: one row per (dataset, model) with one
# recall column per value in `ks`.
metric_columns = [f"recall@{k}" for k in ks]
results_df = pd.DataFrame(eval_results, columns=['dataset', 'model'] + metric_columns)

# Display names for the dataset identifiers.
dataset_to_pretty_name = {
    "medmentions_full": "MedMentions Full",
    "medmentions_st21pv": "MedMentions ST21PV",
    "bc5cdr": "BC5CDR",
    "gnormplus": "GNormPlus",
    "ncbi_disease": "NCBI Disease",
    "nlmchem": "NLM Chem",
    "craft": "CRAFT",
    "bc6id": "BC6ID",
    "bc3gm": "BC3GM",
    "plantnorm": "PlantNorm",
    "nlm_gene": "NLM Gene",
}

# Prettify dataset and model labels, then title-case every column header.
results_df = (
    results_df
    .assign(
        dataset=lambda df: df['dataset'].map(dataset_to_pretty_name),
        model=lambda df: df['model'].map(lambda x: x.title().replace('_', ' ')),
    )
    .rename(columns=str.title)
)

# Sort and round once; emit the LaTeX source, then show the same table.
results_table = results_df.sort_values(by=['Dataset','Model']).round(3)
print(results_table.to_latex(index=False, escape=False, bold_rows=True))
results_table

\begin{tabular}{llrrrrrrrr}
\toprule
           Dataset &             Model &  Recall@1 &  Recall@2 &  Recall@4 &  Recall@8 &  Recall@16 &  Recall@32 &  Recall@64 &  Recall@100 \\
\midrule
            BC5CDR & Krissbert Relaxed &     0.716 &     0.736 &     0.749 &     0.757 &      0.763 &      0.767 &      0.769 &       0.770 \\
            BC5CDR &  Krissbert Strict &     0.715 &     0.735 &     0.748 &     0.755 &      0.761 &      0.765 &      0.767 &       0.768 \\
            BC5CDR &           Sapbert &     0.852 &     0.863 &     0.881 &     0.896 &      0.916 &      0.929 &      0.936 &       0.936 \\
         GNormPlus & Krissbert Relaxed &     0.077 &     0.081 &     0.086 &     0.087 &      0.087 &      0.087 &      0.087 &       0.087 \\
         GNormPlus &  Krissbert Strict &     0.072 &     0.075 &     0.080 &     0.081 &      0.081 &      0.081 &      0.081 &       0.081 \\
         GNormPlus &           Sapbert &     0.199 &     0.304 &     0.478 &     0.581 &      0.

Unnamed: 0,Dataset,Model,Recall@1,Recall@2,Recall@4,Recall@8,Recall@16,Recall@32,Recall@64,Recall@100
7,BC5CDR,Krissbert Relaxed,0.716,0.736,0.749,0.757,0.763,0.767,0.769,0.77
6,BC5CDR,Krissbert Strict,0.715,0.735,0.748,0.755,0.761,0.765,0.767,0.768
0,BC5CDR,Sapbert,0.852,0.863,0.881,0.896,0.916,0.929,0.936,0.936
13,GNormPlus,Krissbert Relaxed,0.077,0.081,0.086,0.087,0.087,0.087,0.087,0.087
12,GNormPlus,Krissbert Strict,0.072,0.075,0.08,0.081,0.081,0.081,0.081,0.081
3,GNormPlus,Sapbert,0.199,0.304,0.478,0.581,0.624,0.635,0.636,0.636
9,MedMentions Full,Krissbert Relaxed,0.583,0.672,0.732,0.762,0.782,0.796,0.807,0.812
8,MedMentions Full,Krissbert Strict,0.583,0.672,0.732,0.762,0.782,0.796,0.807,0.812
1,MedMentions Full,Sapbert,0.599,0.681,0.732,0.764,0.793,0.814,0.825,0.825
11,MedMentions ST21PV,Krissbert Relaxed,0.548,0.625,0.677,0.705,0.724,0.737,0.747,0.753


In [16]:
# Recall@k for every (model, dataset) pair, always scored with mode='relaxed'.
# Each row appended to eval_results is [dataset, model, recall@k1, ...].
eval_results = []
ks = [1,2,4,8,16,32,64,100]
for model, output_files in tqdm(outputs_dict.items()):
    for dataset, file in tqdm(output_files.items()):
        # Use a context manager so the prediction file is closed after parsing.
        with open(file, 'r') as f:
            preds = ujson.load(f)
        accs = [evaluate_preds(preds, k, model=model, mode='relaxed') for k in ks]
        eval_results.append([dataset, model] + accs)


results_df = pd.DataFrame(eval_results, columns=['dataset', 'model'] + [f"recall@{k}" for k in ks])
results_df.sort_values(by=['dataset','model']).round(4)

Unnamed: 0,dataset,model,recall@1,recall@2,recall@4,recall@8,recall@16,recall@32,recall@64,recall@100
6,bc5cdr,krissbert,0.715,0.7352,0.7477,0.7547,0.7606,0.7647,0.7667,0.7677
0,bc5cdr,sapbert,0.852,0.8628,0.8808,0.8963,0.9161,0.9291,0.9359,0.9359
9,gnormplus,krissbert,0.0723,0.0751,0.08,0.081,0.081,0.0813,0.0813,0.0813
3,gnormplus,sapbert,0.1994,0.3044,0.4782,0.5814,0.6241,0.6351,0.6359,0.6359
7,medmentions_full,krissbert,0.5828,0.672,0.7316,0.7622,0.7818,0.7956,0.8066,0.8121
1,medmentions_full,sapbert,0.5993,0.6806,0.7316,0.7641,0.793,0.8138,0.8252,0.8252
8,medmentions_st21pv,krissbert,0.5479,0.6251,0.6769,0.7054,0.724,0.7373,0.7475,0.7527
2,medmentions_st21pv,sapbert,0.5939,0.6682,0.7276,0.763,0.7981,0.8246,0.8383,0.8383
11,nlm_gene,krissbert,0.2682,0.3898,0.4559,0.4784,0.4936,0.4987,0.5049,0.5107
5,nlm_gene,sapbert,0.1383,0.2711,0.5205,0.7851,0.8544,0.8922,0.9111,0.9111
