In [1]:
import json
import os
import pickle
import torch
import numpy as np

In [2]:
# Read JSON
def read_json(fname):
    """Parse the JSON file at ``fname`` and return the decoded object.

    Args:
        fname: Path of the JSON file to read.

    Returns:
        The Python object produced by decoding the file's JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default text encoding.
    with open(fname, "r", encoding="utf-8") as handle:
        return json.load(handle)

In [3]:
# Read pickle
def read_pkl(fname):
    """Unpickle and return the object stored in the file at ``fname``.

    Note: ``pickle.load`` can execute arbitrary code while deserializing,
    so this should only be pointed at files from a trusted source.
    """
    with open(fname, mode='rb') as stream:
        return pickle.load(stream)

In [4]:
# Read torch
def read_torch(fname):
    """Return whatever ``torch.load`` deserializes from the file at ``fname``."""
    return torch.load(fname)

In [5]:
# Load the pickled, preprocessed test data; the cells below iterate over it
# as records exposing 'mention_name' and 'label_idxs' keys.
test_data = read_pkl('../data/zeshel/test_processed_data.pickle')

In [8]:
from collections import defaultdict

# For every mention string, collect the distinct values found at
# label_idxs[0] across the test records (presumably the gold entity
# index of each record -- TODO confirm).
menlabel = defaultdict(set)
for example in test_data:
    seen_labels = menlabel[example['mention_name']]
    seen_labels.add(example['label_idxs'][0])

In [10]:
# Keep the mention strings that map to at least 10 distinct labels;
# these are treated as the "ambiguous" mentions in the cells that follow.
ambig_list = [
    mention
    for mention, labels in menlabel.items()
    if len(labels) >= 10
]

In [13]:
# Sanity check: no mention string appears more than once in ambig_list.
assert len(ambig_list) == len(set(ambig_list))
# Bare expression so the notebook displays the selected mentions.
ambig_list

['the previous episode',
 'his father',
 'the previous Ride',
 'Duel continues from previous episode .',
 'the next chapter',
 'the next episode',
 'the next Rank',
 'Duel continues in the next chapter . . .',
 'a Duel',
 'father',
 'the previous Rank',
 'ship',
 'the next Scale',
 'the previous chapter',
 'Duel concludes next episode .',
 'the previous Scale',
 'planet',
 'the next Ride',
 'his ship',
 'Duel continues next episode .',
 'Duel continued from previous episode .',
 'shuttlecraft',
 'TO BE CONTINUED . . .',
 'mother']

In [14]:
# Number of test records whose mention string is one of the ambiguous mentions.
ambig_count = sum(
    1 for example in test_data if example['mention_name'] in ambig_list
)
print(f"Total ambiguous mentions in test: {ambig_count}")

Total ambiguous mentions in test: 655


In [15]:
# Load the cross_* result files, keyed by a short mode name.
results = {
    mode: read_json(path)
    for mode, path in (
        ('arbo', '../data/zeshel/results/cross_arbo.json'),
        ('1rand', '../data/zeshel/results/cross_1rand.json'),
        ('1nn', '../data/zeshel/results/cross_1nn.json'),
        ('knn', '../data/zeshel/results/cross_knn.json'),
        ('in_batch', '../data/zeshel/results/cross_in_batch.json'),
    )
}

In [16]:
# Accuracy (%) of each result set restricted to the ambiguous mentions.
# The denominator is ambig_count (counted over test_data), not the number of
# ambiguous mentions present in the result file, so any ambiguous mention
# missing from the 'success' list lowers the score.
# NOTE(review): the original also tallied 'failure' entries but never used the
# tally; confirm that dividing by ambig_count rather than succ + fail is intended.
ambig_set = set(ambig_list)  # set gives constant-time membership tests
acc = {}
for mode, outcome in results.items():
    succ = sum(1 for o in outcome['success'] if o['mention_name'] in ambig_set)
    acc[mode] = round((succ / ambig_count) * 100, 2)
acc

{'arbo': 14.81, '1rand': 19.54, '1nn': 17.25, 'knn': 16.34, 'in_batch': 20.61}

In [25]:
# IPython shell escape: list the files present in the results directory.
!ls ../data/zeshel/results

bi_1nn.json                cross_in_batch.json
bi_1rand.json              cross_knn.json
bi_arbo.json               oracle_cross_1nn.json
bi_in_batch.json           oracle_cross_1rand.json
bi_knn.json                oracle_cross_arbo.json
cross_1nn.json             oracle_cross_in_batch.json
cross_1rand.json           oracle_cross_knn.json
cross_arbo.json


In [26]:
# Oracle results: load the oracle_cross_* files, keyed by a short mode name.
oresults = {
    mode: read_json(path)
    for mode, path in (
        ('arbo', '../data/zeshel/results/oracle_cross_arbo.json'),
        ('1rand', '../data/zeshel/results/oracle_cross_1rand.json'),
        ('1nn', '../data/zeshel/results/oracle_cross_1nn.json'),
        ('knn', '../data/zeshel/results/oracle_cross_knn.json'),
        ('in_batch', '../data/zeshel/results/oracle_cross_in_batch.json'),
    )
}

In [27]:
# Accuracy (%) of each oracle result set restricted to the ambiguous mentions.
# The denominator is ambig_count (counted over test_data), not the number of
# ambiguous mentions present in the result file, so any ambiguous mention
# missing from the 'success' list lowers the score.
# NOTE(review): the original also tallied 'failure' entries but never used the
# tally; confirm that dividing by ambig_count rather than succ + fail is intended.
ambig_set = set(ambig_list)  # set gives constant-time membership tests
oacc = {}
for mode, outcome in oresults.items():
    succ = sum(1 for o in outcome['success'] if o['mention_name'] in ambig_set)
    oacc[mode] = round((succ / ambig_count) * 100, 2)
oacc

{'arbo': 16.49, '1rand': 25.04, '1nn': 22.29, 'knn': 19.08, 'in_batch': 22.44}

In [6]:
# Load the pickled dictionary file (presumably the entity dictionary -- TODO confirm).
edict = read_pkl("../data/zeshel/dictionary.pickle")

In [7]:
# Display how many entries the loaded object contains.
len(edict)

492321