In [1]:
import json
import os
import pickle
import torch
import numpy as np

In [2]:
# Read JSON
def read_json(fname):
    """Load and return the JSON document stored in the file ``fname``.

    Args:
        fname: Path of the JSON file to read.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8 (the standard JSON interchange encoding) so the
    # result does not depend on the platform's locale default.
    with open(fname, "r", encoding="utf-8") as handle:
        return json.load(handle)

In [3]:
# Read pickle
def read_pkl(fname):
    """Unpickle and return the object stored in the file ``fname``.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(fname, 'rb') as stream:
        return pickle.load(stream)

In [4]:
# Read torch
def read_torch(fname):
    """Return the object that was serialized to ``fname`` with ``torch.save``."""
    return torch.load(fname)

In [5]:
# Load the processed MedMentions test split; the cells below treat it as a
# sequence of per-mention dicts (keys such as 'mention_id' and 'label_cuis').
test_data = read_pkl('data/medmentions/test_processed_data.pickle')

In [6]:
# Load the entity dictionary; the cells below index it by position and read
# the 'cui', 'title' and 'tokens' fields of its entries.
dictionary = read_pkl('data/medmentions/dictionary.pickle')

In [27]:
# Map every entry's CUI to its position in `dictionary`
# (if a CUI occurs more than once, the last position wins).
dict_id_to_idx = {entry['cui']: pos for pos, entry in enumerate(dictionary)}

In [8]:
# List the candidate files available in data/medmentions/top64_cands
# (IPython shell escape; relies on an `ls` command being available).
!ls data/medmentions/top64_cands

1nn.t7      arbo.t7     knn.t7      union.t7
1rand.t7    in_batch.t7 tfidf.json


In [9]:
# List the result files available in data/medmentions/results
# (IPython shell escape; relies on an `ls` command being available).
!ls data/medmentions/results

bi_1nn.json              bi_in_batch.json         cross_arbo.json
bi_1rand.json            bi_knn.json              cross_arbo_nonblink.json
bi_arbo_directed.json    cross_1nn.json           cross_in_batch.json
bi_arbo_undirected.json  cross_1rand.json         cross_knn.json


In [31]:
# tfidf.json is read line by line: each line is one JSON object carrying a
# 'mention_id' and that mention's 'tfidf_candidates' (a list of CUIs).
tfidf = {}
with open('data/medmentions/top64_cands/tfidf.json', "r") as handle:
    for raw_line in handle:
        record = json.loads(raw_line)
        tfidf[record['mention_id']] = record['tfidf_candidates']

# Re-express the TF-IDF candidates per test mention:
#   'candidates'[idx] -> dictionary positions of the candidate CUIs
#   'labels'[idx]     -> position of the mention's first gold CUI inside its
#                        candidate list, or -1 when it is not among them
existing_cands = {'mode': 'test', 'candidates': {}, 'labels': []}
for idx, men in enumerate(test_data):
    existing_cands['labels'].append(-1)
    cand_idxs = []
    # Mentions without a TF-IDF entry simply get an empty candidate list.
    for rank, cui in enumerate(tfidf.get(men['mention_id'], [])):
        cand_idxs.append(dict_id_to_idx[cui])
        if cui == men['label_cuis'][0]:
            existing_cands['labels'][-1] = rank
    existing_cands['candidates'][idx] = cand_idxs

In [33]:
# Disabled alternative candidate source:
# existing_cands = read_torch('data/medmentions/top64_cands/in_batch.t7')

# Show the positions of the mentions whose label is -1, i.e. whose gold
# entity is missing from the existing candidate list.
labels = np.array(existing_cands['labels'])
np.nonzero(labels == -1)[0]

array([   19,    21,    27, ..., 39027, 39032, 39034])

In [34]:
# Disabled alternative candidate source:
# existing_cands = read_torch('data/medmentions/top64_cands/in_batch.t7')

# existing_idxs: test-set positions where the existing candidates miss the
# gold entity (label == -1); reused by the comparison cells below.
exist_labels = np.array(existing_cands['labels'])
existing_idxs = np.nonzero(exist_labels == -1)[0]
print(existing_idxs)

[   19    21    27 ... 39027 39032 39034]


In [35]:
# Candidates stored in arbo.t7, split by label value.
# NOTE(review): labels presumably give the rank of the gold entity among the
# candidates (0 = top prediction, -1 = not retrieved) — confirm against the
# code that wrote arbo.t7.
our_cands = read_torch('data/medmentions/top64_cands/arbo.t7')
our_labels = np.array(our_cands['labels'])
our_idxs = np.nonzero(our_labels != -1)[0]                 # label other than -1
our_idxs_bi_recallhits = np.nonzero(our_labels == 0)[0]    # label exactly 0
our_idxs_bi_recallnonhits = np.nonzero(our_labels > 0)[0]  # label above 0
for _idx_group in (our_idxs, our_idxs_bi_recallhits, our_idxs_bi_recallnonhits):
    print(len(_idx_group), " - ", _idx_group)

37346  -  [    0     1     2 ... 39035 39036 39037]
28230  -  [    0     1     2 ... 39033 39035 39036]
9116  -  [   19    21    27 ... 39029 39034 39037]


In [37]:
# Mentions where
#   - the existing candidates miss the gold entity (existing label == -1), and
#   - our bi-encoder candidates carry label 0 (linked correctly).

exi_fail_ourbi_top1_idxs = np.intersect1d(existing_idxs, our_idxs_bi_recallhits)
print(len(exi_fail_ourbi_top1_idxs))
print(exi_fail_ourbi_top1_idxs)

1385
[   42   102   212 ... 38977 38987 38994]


In [38]:
# Mentions where
#   - the existing candidates miss the gold entity (existing label == -1), and
#   - our bi-encoder candidates carry a label above 0 (recalled, not top-ranked).
# The cross-encoder success filter is applied to this set in a later cell.

exi_fail_ourbi_succ_idxs = np.intersect1d(existing_idxs, our_idxs_bi_recallnonhits)
print(len(exi_fail_ourbi_succ_idxs), exi_fail_ourbi_succ_idxs, sep="\n")

3043
[   19    21    27 ... 38973 38995 39034]


In [39]:
# Cross-encoder results for the arbo run; keep the ids of every mention
# listed under its 'success' key for fast membership tests.
cross_res = read_json('data/medmentions/results/cross_arbo.json')
success_men_ids = {hit['mention_id'] for hit in cross_res['success']}

In [40]:
# Test-set positions of the mentions listed as cross-encoder successes.
cross_succ_idxs = [
    pos
    for pos, mention in enumerate(test_data)
    if mention['mention_id'] in success_men_ids
]

In [41]:
# Keep only the mentions from exi_fail_ourbi_succ_idxs that are also
# cross-encoder successes.
final_idxs = np.intersect1d(exi_fail_ourbi_succ_idxs, cross_succ_idxs)
print(f"{len(final_idxs)}  -  {final_idxs}")

675  -  [   19    21   163   164   277   280   303   589   758   788   799   802
   823   834   837   842   851   945   956   989  1081  1085  1107  1195
  1204  1214  1332  1340  1441  1488  1713  1777  1779  1785  1789  1821
  1822  1823  1824  1831  1841  1842  1879  1904  1926  1932  2032  2125
  2184  2203  2314  2325  2352  2493  2505  2507  2654  2661  2719  2759
  2822  2987  3121  3467  3609  3913  3943  3959  3980  4071  4173  4240
  4277  4320  4341  4345  4392  4442  4445  4446  4520  4600  4644  4722
  4723  4724  4825  4852  4900  5254  5255  5347  5378  5393  5402  5411
  5721  5808  5813  5822  5976  5987  5991  5992  5994  5996  6021  6211
  6226  6374  6416  6436  6603  6622  6673  7179  7196  7344  7419  7435
  7473  7484  7580  7719  7891  7894  7895  7902  8126  8172  8433  8435
  8438  8445  8451  8455  8456  8460  8462  8464  8465  8574  8643  8648
  8657  8663  8715  8767  8784  8791  8838  8917  8922  8923  9446  9493
  9619  9712  9789  9811  9829  9870  9935 

In [51]:
# Build one comparison record per mention in final_idxs, putting side by side:
#   existing_* : first entry of the existing candidate list (None when empty)
#   our_bi_*   : first entry of our bi-encoder candidate list
#   our_cross_*: gold CUI / name from the cross-encoder success record, with
#                the description taken from the mention's first gold
#                dictionary entry
menid_to_idx = {hit['mention_id']: pos for pos, hit in enumerate(cross_res['success'])}

res = []
for i in final_idxs:
    mention = test_data[i]
    men_id = mention['mention_id']

    # Top-ranked existing candidate, when the mention has any.
    existing_cui = existing_cui_name = existing_cui_desc = None
    existing_list = existing_cands['candidates'][i]
    if len(existing_list) != 0:
        existing_entry = dictionary[existing_list[0]]
        existing_cui = existing_entry['cui']
        existing_cui_name = existing_entry['title']
        existing_cui_desc = ' '.join(existing_entry['tokens'])

    bi_entry = dictionary[our_cands['candidates'][i][0]]  # Our biencoder model's best prediction
    gold_entry = dictionary[mention['label_idxs'][0]]
    cross_hit = cross_res['success'][menid_to_idx[men_id]]

    res.append({
        'mention_id': men_id,
        'test_set_idx': i,
        'type': mention['type'],
        'name': mention['mention_name'],
        'context': ' '.join(mention['context']['tokens']),
        'existing_cui': existing_cui,
        'existing_cui_name': existing_cui_name,
        'existing_cui_desc': existing_cui_desc,
        'our_bi_cui': bi_entry['cui'],
        'our_bi_cui_name': bi_entry['title'],
        'our_bi_cui_desc': ' '.join(bi_entry['tokens']),
        'our_cross_cui': cross_hit['mention_gold_cui'],
        'our_cross_cui_name': cross_hit['mention_gold_cui_name'],
        'our_cross_cui_desc': ' '.join(gold_entry['tokens']),
    })

In [52]:
# Shuffle the comparison records in place so that the ones inspected below
# are a random sample of `res`.
# A dedicated, seeded generator makes the order reproducible on
# "Restart & Run All" and leaves NumPy's global random state untouched.
# Re-running only this cell shuffles the already-shuffled list again.
SHUFFLE_SEED = 0
np.random.RandomState(SHUFFLE_SEED).shuffle(res)

In [110]:
# Display the raw test-set records of the mentions selected in final_idxs.
[test_data[i] for i in final_idxs]

[{'mention_id': '0195C8FCA206B09C',
  'mention_name': 'his Deck',
  'context': {'tokens': ['[CLS]',
    'for',
    'three',
    'minutes',
    'during',
    'which',
    'maintenance',
    'was',
    'carried',
    'out',
    '.',
    'in',
    'the',
    'days',
    'coming',
    'up',
    'to',
    'the',
    'scheduled',
    'maintenance',
    ',',
    'yu',
    '##sei',
    'worked',
    'on',
    'his',
    'duel',
    'runner',
    'and',
    'made',
    'sure',
    'it',
    'could',
    'cover',
    'the',
    'necessary',
    'distance',
    'in',
    'three',
    'minutes',
    '.',
    'before',
    'yu',
    '##sei',
    'left',
    ',',
    'rally',
    'gave',
    'him',
    'the',
    'card',
    '"',
    'turbo',
    'booster',
    '"',
    ',',
    'which',
    'yu',
    '##sei',
    'temporarily',
    'added',
    'to',
    '[unused1]',
    'his',
    'deck',
    '[unused2]',
    '.',
    'on',
    'the',
    'day',
    ',',
    'sector',
    'security',
    'detected

In [53]:
# Display the shuffled comparison records (renders the whole list).
res

[{'mention_id': '27335087.30',
  'test_set_idx': 2759,
  'type': 'T017',
  'name': 'CD3 ( + ) T cells',
  'context': '[CLS] flow c ##yt ##ometer . Cell ultra ##structure was observed under a transmission electron micro ##scope . The mitochondrial membrane potential ( Δ ##ψ ) was examined with J ##C - 1 dye . In H ##22 tumor - bearing mice , CD ##4 ( + ) T cells , CD ##8 ( + ) T cells , [unused1] CD ##3 ( + ) T cells [unused2] , and natural killer cells in peripheral blood were evaluated c ##yt ##ometric ##ally . Inter ##le ##uki ##n ( Inter ##le ##uki ##n ) - 2 and tumor ne ##c ##rosis factor ( tumor ne ##c ##rosis factor ) - α levels were measured using radio ##im ##mu ##no ##ass ##ay . The m ##RNA levels of Ba [SEP]',
  'existing_cui': 'C1277786',
  'existing_cui_name': 'CD3 T-cell count procedure',
  'existing_cui_desc': '[CLS] CD ##3 T - cell count procedure [unused3] ( Health Care Act ##ivity , procedure : CD ##3 T - cell count ) [SEP]',
  'our_bi_cui': 'C1563926',
  'our_bi_cui_n

In [54]:
# NOTE(review): hard-coded literal with no stated meaning; presumably
# test-set indices picked for manual inspection — confirm and name it.
[13478, 14309, 34150, 18011, 31580]

[13478, 14309, 34150, 18011, 31580]