In [1]:
import json
import os
import pickle
import torch
import numpy as np

In [2]:
# Read JSON
def read_json(fname):
    """Load and return the JSON document stored in the file ``fname``.

    Args:
        fname: Path of the JSON file to read.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8. Without an encoding argument open() falls
    # back to the platform locale, which can garble or reject non-ASCII text
    # on systems whose default encoding is not UTF-8.
    with open(fname, "r", encoding="utf-8") as handle:
        # json.load parses straight from the file object, so no intermediate
        # string variable is needed.
        return json.load(handle)

In [3]:
# Read pickle
def read_pkl(fname):
    """Deserialize and return the object pickled in the file ``fname``.

    Args:
        fname: Path of the pickle file to read.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.
    """
    # NOTE: unpickling can execute arbitrary code, so only point this at
    # files from a trusted source.
    with open(fname, mode='rb') as pkl_file:
        return pickle.load(pkl_file)

In [4]:
# Read torch 
def read_torch(fname):
    """Return the object that ``torch.load`` deserializes from ``fname``.

    Args:
        fname: Path (or file-like object) handed unchanged to ``torch.load``.

    Returns:
        The loaded object, exactly as produced by ``torch.load`` with its
        default arguments.
    """
    return torch.load(fname)

In [5]:
# Load the pickled object at the path below (judging by the file name, the
# processed MedMentions test split).
test_data_path = '../data/medmentions/test_processed_data.pickle'
test_data = read_pkl(test_data_path)

In [7]:
from collections import defaultdict

# For every mention string, collect the distinct values found in the first
# slot of 'label_idxs' across all records of test_data.
menlabel = defaultdict(set)
for example in test_data:
    first_label = example['label_idxs'][0]
    menlabel[example['mention_name']].add(first_label)

# Keep the mentions that were seen with at least 10 distinct first labels.
ambig_list = [name for name, labels in menlabel.items() if len(labels) >= 10]

# Dict keys are unique, so this check holds by construction; it is kept as a
# guard in case the selection logic above changes.
assert len(ambig_list) == len(set(ambig_list))
ambig_list

[]

{872673, 872948, 944893, 953135, 964325, 976724}

In [12]:
# Hand-written list of mention strings to treat as ambiguous. Assigning it
# here replaces whatever value `ambig_list` held before this cell ran.
ambig_list = [
    'activation', 'activity', 'a', 'b',
    'cardiac', 'cells', 'clinical', 'compounds', 'cr',
    'development', 'disease',
    'function', 'fusion',
    'inhibition', 'injuries', 'injury',
    'liver',
    'management', 'methods', 'mice', 'model',
    'pa', 'production', 'protein',
    'regulation', 'report', 'responses', 'response', 'r',
    'screening', 'stress', 'studies', 'study',
    'treatment',
]

In [44]:
# Number of records in test_data whose mention string appears in ambig_list.
ambig_count = sum(1 for o in test_data if o['mention_name'] in ambig_list)
print(f"Total ambiguous mentions in test: {ambig_count}")

Total ambiguous mentions in test: 1247


In [45]:
# One result file per mode, keyed by mode name. Insertion order is the
# order in which the modes are iterated later on.
results = {
    'arbo': read_json('../data/medmentions/results/cross_arbo.json'),
    '1rand': read_json('../data/medmentions/results/cross_1rand.json'),
    '1nn': read_json('../data/medmentions/results/cross_1nn.json'),
    'knn': read_json('../data/medmentions/results/cross_knn.json'),
    'in_batch': read_json('../data/medmentions/results/cross_in_batch.json'),
}

In [46]:
# Reference figure printed for comparison: the ambiguous-mention accuracy
# attributed to Angell et al. (hard-coded, not computed in this notebook).
angell_ambig_acc = 73.03
print(f"Angell et al.: Ambiguous accuracy = {angell_ambig_acc}")

Angell et al.: Ambiguous accuracy = 73.03


In [47]:
# Per-mode accuracy on ambiguous mentions, as a percentage rounded to two
# decimals.
acc = {}

for mode, outcome in results.items():
    # Tally the 'failure' and 'success' entries whose mention is ambiguous.
    fail = sum(1 for o in outcome['failure'] if o['mention_name'] in ambig_list)
    succ = sum(1 for o in outcome['success'] if o['mention_name'] in ambig_list)
    # The denominator is ambig_count (computed from the test data), not
    # succ + fail; `fail` is tallied but does not enter the formula.
    acc[mode] = round((succ / ambig_count) * 100, 2)
acc

{'arbo': 76.82, '1rand': 75.7, '1nn': 74.58, 'knn': 72.41, 'in_batch': 76.02}

In [48]:
# Oracle results: one result file per mode, keyed by mode name.
oresults = {
    'arbo': read_json('../data/medmentions/results/oracle_cross_arbo.json'),
    '1rand': read_json('../data/medmentions/results/oracle_cross_1rand.json'),
    '1nn': read_json('../data/medmentions/results/oracle_cross_1nn.json'),
    'knn': read_json('../data/medmentions/results/oracle_cross_knn.json'),
    'in_batch': read_json('../data/medmentions/results/oracle_cross_in_batch.json'),
}

In [50]:
# Per-mode accuracy on ambiguous mentions for the oracle result files, as a
# percentage rounded to two decimals.
oacc = {}

for mode, outcome in oresults.items():
    # Tally the 'failure' and 'success' entries whose mention is ambiguous.
    fail = sum(1 for o in outcome['failure'] if o['mention_name'] in ambig_list)
    succ = sum(1 for o in outcome['success'] if o['mention_name'] in ambig_list)
    # The denominator is ambig_count (computed from the test data), not
    # succ + fail; `fail` is tallied but does not enter the formula.
    oacc[mode] = round((succ / ambig_count) * 100, 2)
oacc

{'arbo': 76.82, '1rand': 75.7, '1nn': 74.34, 'knn': 10.26, 'in_batch': 12.51}