In [1]:
import requests
import json
import sys
from os.path import join
import pickle
sys.path.append('/home/rogupta/mannheim-nel/')
from collections import defaultdict
import tagme
from datetime import datetime
import pandas as pd
import numpy as np
import torch
import spacy
import os

In [2]:
def pickle_load(path):
    """Return the object unpickled from the file at ``path``.

    Raises AssertionError when ``path`` does not exist.
    NOTE: unpickling executes arbitrary code; only use on trusted files.
    """
    assert os.path.exists(path)
    with open(path, 'rb') as handle:
        return pickle.load(handle)

def json_load(path):
    """Return the parsed JSON content of the file at ``path``.

    Raises AssertionError when ``path`` does not exist.
    """
    assert os.path.exists(path)
    with open(path, 'r') as handle:
        return json.load(handle)

In [83]:
# --- Notebook configuration -------------------------------------------------
# Root directory holding the dictionaries and pickled datasets loaded below.
data_path = '/work/rogupta/mannheim-nel-data/'
# Dataset names; the loading cell treats the first two as CoNLL splits and the
# remaining ones as `raw_<name>.pickle` files.
datasets = ['conll-train', 'conll-dev', 'msnbc', 'ace2004']
# TagMe API token. Prefer supplying it through the GCUBE_TOKEN environment
# variable so the credential does not have to live in the notebook; the literal
# is kept only as a fallback so existing runs keep working.
# NOTE(review): a token committed to a notebook should be rotated.
tagme.GCUBE_TOKEN = os.environ.get("GCUBE_TOKEN", "88c693df-a43f-4086-b3bc-0b555bfbc9bb-843339462")
# Despite the name this is "host:port"; it is interpolated into http://{PORT}/link.
PORT = "127.0.0.1:5000"
# Dataset evaluated by the driver cells below.
DATASET = 'msnbc'

In [4]:
# Redirect table; the eval functions use it as rd.get(title, title) to
# normalise entity titles before comparing gold and predicted entities.
rd = json_load(join(data_path, 'dicts/redirects.json'))
# NOTE(review): ent2id is read from the same redirects.json as `rd`, so both
# names hold identical content — presumably a different file was intended;
# confirm. No cell visible in this notebook reads ent2id.
ent2id = json_load(join(data_path, 'dicts/redirects.json'))

In [5]:
# id2c[dataset] maps a document id to its raw text (it is iterated that way by
# the get_*_results functions); examples[dataset] holds the annotated examples
# that the gold-building cell unpacks.
id2c = {}
examples = {}

# Both CoNLL splits share a single raw-text pickle keyed by split name.
id2c_conll = pickle_load(join(data_path, 'Conll', 'conll_raw_text.pickle'))
id2c['conll-train'], id2c['conll-dev'] = id2c_conll['train'], id2c_conll['dev']

# Non-CoNLL datasets: one pickle holding a (raw text, examples) pair.
for d_name in datasets[2:]:
    raw_path = join(data_path, 'datasets', f'raw_{d_name}.pickle')
    id2c[d_name], examples[d_name] = pickle_load(raw_path)

# CoNLL splits: only the second element of the pickled pair (the examples) is kept.
for d_name in datasets[:2]:
    conll_path = join(data_path, 'Conll', f"conll-{d_name.split('-')[-1]}.pickle")
    _, examples[d_name] = pickle_load(conll_path)

In [6]:
# gold[dataset][doc_id] -> parallel lists of gold mentions, entity strings and
# spans, in the order the examples appear.
gold = {name: {} for name in datasets}
for dataset, exs in examples.items():
    for c_id, (mention, ent_str, span, _) in exs:
        doc_gold = gold[dataset].setdefault(c_id, {'mentions': [],
                                                   'ents': [],
                                                   'spans': []})
        doc_gold['mentions'].append(mention)
        doc_gold['ents'].append(ent_str)
        doc_gold['spans'].append(span)

In [7]:
# Sample passage and a handful of mentions taken from it, kept for ad-hoc
# experiments; no other cell visible in this notebook references these names.
barack_text = """Barack Hussein Obama II (/bəˈrɑːk huːˈseɪn oʊˈbɑːmə/ (About this sound listen);[1] born August 4, 1961) is an American politician who served as the 44th President of the United States from January 20, 2009, to January 20, 2017. A member of the Democratic Party, he was the first African American to be elected to the presidency and previously served as a United States Senator from Illinois (2005–2008).
Obama was born in 1961 in Honolulu, Hawaii, two years after the territory was admitted to the Union as the 50th state. Raised largely in Hawaii, he also lived for a year of his childhood in the State of Washington and four years in Indonesia. After graduating from Columbia University in 1983, he worked as a community organizer in Chicago. In 1988, he enrolled in Harvard Law School, where he was the first black president of the Harvard Law Review. After graduating, he became a civil rights attorney and a professor, teaching constitutional law at the University of Chicago Law School from 1992 to 2004. He represented the 13th district for three terms in the Illinois Senate from 1997 to 2004, when he ran for the U.S. Senate. He received national attention in 2004 with his March primary win, his well-received July Democratic National Convention keynote address, and his landslide November election to the Senate. In 2008, he was nominated for president a year after his campaign began and after a close primary campaign against Hillary Clinton. He was elected over Republican John McCain and was inaugurated on January 20, 2009. Nine months later, he was named the 2009 Nobel Peace Prize laureate, accepting the award with the caveat that he felt there were others "far more deserving of this honor than I".
During his first two years in office, Obama signed many landmark bills into law. The main reforms were the Patient Protection and Affordable Care Act (often referred to as "Obamacare", shortened as the "Affordable Care Act"), the Dodd–Frank Wall Street Reform and Consumer Protection Act, and the Don't Ask, Don't Tell Repeal Act of 2010. The American Recovery and Reinvestment Act of 2009 and Tax Relief, Unemployment Insurance Reauthorization, and Job Creation Act of 2010 served as economic stimulus amidst the Great Recession. After a lengthy debate over the national debt limit, he signed the Budget Control and the American Taxpayer Relief Acts. In foreign policy, he increased U.S. troop levels in Afghanistan, reduced nuclear weapons with the United States–Russia New START treaty, and ended military involvement in the Iraq War. He ordered military involvement in Libya in opposition to Muammar Gaddafi; Gaddafi was killed by NATO-assisted forces, and he also ordered the military operation that resulted in the deaths of Osama bin Laden and suspected Yemeni Al-Qaeda operative Anwar al-Awlaki.
"""
barack_mentions = ['President', 'United States', 'African American', 'Democratic Party']

In [8]:
def get_response_full(text, max_cands=100):
    """POST ``text`` to the ``/link`` endpoint and return its annotations.

    The request body is a JSON object with the keys ``text`` and ``max_cands``.

    Returns:
        (entities, mentions, spans) exactly as found in the JSON reply.
    """
    payload = json.dumps({'text': text,
                          'max_cands': max_cands})
    reply = requests.post(f"http://{PORT}/link", data=payload).json()
    return reply['entities'], reply['mentions'], reply['spans']

In [9]:
def get_response_mention(text, user_mentions, user_spans, max_cands=100):
    """POST ``text`` plus caller-supplied mentions/spans to the ``/link`` endpoint.

    The request body is a JSON object with the keys ``text``, ``mentions``,
    ``spans`` and ``max_cands``.

    Returns:
        (entities, mentions) exactly as found in the JSON reply.
    """
    payload = json.dumps({'text': text,
                          'mentions': user_mentions,
                          'spans': user_spans,
                          'max_cands': max_cands})
    reply = requests.post(f"http://{PORT}/link", data=payload).json()
    return reply['entities'], reply['mentions']

In [10]:
def get_full_results(num_text, dataset='conll-dev', max_cands=100):
    """Run the full linking service over the first ``num_text`` documents of a dataset.

    Args:
        num_text: maximum number of documents to process (slice bound on the
            documents of ``id2c[dataset]``).
        dataset: key into ``id2c``. The default is ``'conll-dev'`` to match the
            other helpers; the former default ``'dev'`` is not a key this
            notebook ever stores in ``id2c``.
        max_cands: forwarded to ``get_response_full`` (it used to be accepted
            but silently ignored).

    Returns:
        (results, times) where ``results[doc_id]`` has the keys ``mentions``,
        ``ents`` and ``spans`` (spans converted to tuples), and ``times`` is a
        list of ``{'len': ..., 'time (s)': ...}`` records, one per document.
    """
    results = {}
    times = []
    for doc_id, text in list(id2c[dataset].items())[:num_text]:
        tic = datetime.now()
        ents, mentions, spans = get_response_full(text, max_cands=max_cands)
        toc = datetime.now()
        times.append({'len': len(text), 'time (s)': (toc - tic).total_seconds()})

        results[doc_id] = {
            'mentions': mentions,
            'ents': ents,
            # Tuples so spans compare equal to tuple spans in exact-match evaluation.
            'spans': [tuple(span) for span in spans],
        }

    return results, times

In [11]:
def get_mention_results(num_text, dataset='conll-dev', max_cands=100):
    """Link the gold mentions of the first ``num_text`` documents of a dataset.

    Documents without an entry in ``gold[dataset]`` are skipped. If the request
    for a document fails, the failure is printed and the document is left out
    of the result, so stale values from an earlier document are never stored.

    Args:
        num_text: maximum number of documents to consider.
        dataset: key into ``id2c`` and ``gold``.
        max_cands: forwarded to ``get_response_mention``.

    Returns:
        dict mapping doc_id to ``{'mentions': ..., 'ents': ...}``.
    """
    results = {}
    for doc_id, text in list(id2c[dataset].items())[:num_text]:
        if doc_id not in gold[dataset]:
            continue
        user_mentions = gold[dataset][doc_id]['mentions']
        user_spans = gold[dataset][doc_id]['spans']
        try:
            ents, mentions = get_response_mention(text, user_mentions, user_spans,
                                                  max_cands=max_cands)
        except Exception as e:
            # Best effort: report the failing document and move on.
            print(f'Linking failed for document {doc_id!r}: {e!r}')
            print(text, user_mentions)
            continue
        results[doc_id] = {'mentions': mentions, 'ents': ents}

    return results

In [61]:
def common_idx(pred_spans, gold_spans, thresh=0.5):
    """Pair predicted spans with gold spans that count as the same mention.

    Args:
        pred_spans: sequence of (begin, end) predicted spans.
        gold_spans: sequence of (begin, end) gold spans.
        thresh: matching strictness.
            * ``thresh == 1``: only spans with identical contents match. Spans
              are compared as tuples, so a list span equals the same tuple span.
            * otherwise: a span fully contained in the other always matches; a
              partial overlap matches when ``overlap / shorter span length`` is
              at least ``thresh``.

    Returns:
        List of ``(pred_idx, gold_idx)`` pairs in scan order (predictions outer,
        gold inner), where every prediction and every gold span appears in at
        most one pair (first pair found wins).
    """
    candidates = []
    for pred_idx, pred_span in enumerate(pred_spans):
        for gold_idx, gold_span in enumerate(gold_spans):
            if thresh == 1:
                if tuple(pred_span) == tuple(gold_span):
                    candidates.append((pred_idx, gold_idx))
                continue

            pred_begin, pred_end = pred_span[0], pred_span[1]
            gold_begin, gold_end = gold_span[0], gold_span[1]
            # Partial overlaps are normalised by the shorter of the two spans.
            min_l = min(gold_end - gold_begin, pred_end - pred_begin)

            if pred_begin < gold_begin < pred_end < gold_end:
                # Prediction overlaps the left edge of the gold span.
                matched = (pred_end - gold_begin) / min_l >= thresh
            elif gold_begin < pred_begin < gold_end < pred_end:
                # Prediction overlaps the right edge of the gold span.
                matched = (gold_end - pred_begin) / min_l >= thresh
            elif pred_begin >= gold_begin and pred_end <= gold_end:
                # Prediction lies inside the gold span.
                matched = True
            elif gold_begin >= pred_begin and gold_end <= pred_end:
                # Gold span lies inside the prediction.
                matched = True
            else:
                matched = False

            if matched:
                candidates.append((pred_idx, gold_idx))

    # If same mention is counted twice, only add it once
    used_pred = set()
    used_gold = set()
    final_res = []
    for pred_idx, gold_idx in candidates:
        if pred_idx in used_pred or gold_idx in used_gold:
            continue
        used_pred.add(pred_idx)
        used_gold.add(gold_idx)
        final_res.append((pred_idx, gold_idx))

    return final_res

In [62]:
def eval_full(results, dataset='conll-dev', verbose=False, mention_thresh=0.5, tagme_thresh=0.1):
    """Score end-to-end predictions (detection + linking) against ``gold[dataset]``.

    Args:
        results: dict doc_id -> predictions. A prediction is either a dict with
            the keys ``spans`` and ``ents``, or an object exposing
            ``get_annotations(threshold)`` (the TagMe branch).
        dataset: key into ``gold``.
        verbose: print per-document diagnostics.
        mention_thresh: passed to ``common_idx`` as ``thresh``.
        tagme_thresh: passed to ``get_annotations`` for TagMe responses.

    Returns:
        (num_detected, total_correct, total, match_idxss, not_covered_idxss):
        number of predicted spans, number of matched pairs whose redirect-
        normalised titles agree, number of matched pairs, and per-document
        lists of matched / unmatched gold indices. Documents missing from
        ``gold[dataset]`` are skipped entirely.
    """
    num_detected = 0
    total = 0
    total_correct = 0
    match_idxss = []
    not_covered_idxss = []

    for doc_id, preds in results.items():
        if doc_id not in gold[dataset]:
            if verbose:
                print('not in gold', doc_id)
            continue

        if isinstance(preds, dict):
            pred_spans, pred_titles = preds['spans'], preds['ents']
        else:
            pred_spans = [(ann.begin, ann.end)
                          for ann in preds.get_annotations(tagme_thresh)]
            pred_titles = [tagme.normalize_title(ann.entity_title)
                           for ann in preds.get_annotations(tagme_thresh)]
        num_detected += len(pred_spans)

        doc_gold = gold[dataset][doc_id]
        correct_spans = doc_gold['spans']
        overlap = common_idx(pred_spans, correct_spans, thresh=mention_thresh)
        if verbose:
            print(f'Correct: {correct_spans}')
            print(f'Predicted: {pred_spans}')
            print(f'Overlap: {overlap}\n\n')

        # (gold title, predicted title) per matched pair, both passed through
        # the redirect table.
        gold_ents = doc_gold['ents']
        match = []
        for pred_idx, correct_idx in overlap:
            gold_title = gold_ents[correct_idx]
            pred_title = pred_titles[pred_idx]
            match.append((rd.get(gold_title, gold_title), rd.get(pred_title, pred_title)))

        match_idxs = [correct_idx for _, correct_idx in overlap]
        match_idxss.append(match_idxs)
        not_covered_idxss.append([idx for idx in range(len(gold_ents)) if idx not in match_idxs])

        correct = sum(1 for gold_title, pred_title in match if gold_title == pred_title)
        total += len(match)
        total_correct += correct

        local_acc = correct / len(match) if match else 0
        if verbose and local_acc < 0.2:
            print(match)

    return num_detected, total_correct, total, match_idxss, not_covered_idxss

In [63]:
def eval_mention(results, dataset='conll-dev'):
    """Score linking-only predictions (gold mentions given) against ``gold[dataset]``.

    Uses the ``results`` and ``dataset`` arguments; the previous version ignored
    both and read the notebook globals ``mention_results`` and ``DATASET``.

    Args:
        results: dict doc_id -> dict with an ``ents`` list, aligned with the
            gold mentions of that document.
        dataset: key into ``gold``.

    Returns:
        (num_correct, total, num_no_link): number of predictions whose
        redirect-normalised title equals the gold one, number of compared
        mentions, and number of predictions equal to ``'NO LINK FOUND'``.
        Documents missing from ``gold[dataset]`` are printed and skipped.
    """
    num_correct = 0
    total = 0
    num_no_link = 0

    for doc_id, preds in results.items():
        if doc_id not in gold[dataset]:
            print(doc_id, preds)
            continue
        doc_gold = gold[dataset][doc_id]
        # zip stops at the shortest list, so a short reply is only scored on
        # the mentions it actually covers.
        for _mention, gold_ent, pred_ent in zip(doc_gold['mentions'], doc_gold['ents'], preds['ents']):
            gold_ent = rd.get(gold_ent, gold_ent)
            pred_ent = rd.get(pred_ent, pred_ent)
            total += 1
            if pred_ent == 'NO LINK FOUND':
                num_no_link += 1
            if gold_ent == pred_ent:
                num_correct += 1

    return num_correct, total, num_no_link

## Our system (mannheim-nel)

#### Evaluate the full pipeline (mention detection + linking)

In [85]:
# Send up to 2000 documents of DATASET through the /link service (text only)
# and keep both the predictions and the per-document request timings.
our_results, our_times = get_full_results(2000, dataset=DATASET, max_cands=100)

In [86]:
# Summary statistics of document length ('len') and request latency ('time (s)').
df = pd.DataFrame(our_times)
df.describe()

Unnamed: 0,len,time (s)
count,20.0,20.0
mean,3380.15,1.717229
std,1426.553715,0.775217
min,941.0,0.664579
25%,2220.5,0.88627
50%,3714.0,1.93945
75%,4500.0,2.373915
max,5821.0,3.049635


In [87]:
# Sweep the span-overlap threshold and report precision / recall / F1 of the
# end-to-end pipeline. The gold mention count does not depend on the threshold.
num_mentions = sum(len(doc_gold['mentions']) for doc_gold in gold[DATASET].values())
report_line = 'Det Thresh: {}, Detection: {}, Num mentions: {}, Match: {}, Correct: {}, P: {:.3f}, R: {:.3f}, f: {:.3f}'

for mention_thresh in [0.1, 0.5, 0.8, 0.99, 1]:
    num_detected, our_correct, our_total, match, not_covered = eval_full(
        our_results,
        dataset=DATASET,
        mention_thresh=mention_thresh,
        verbose=False,
    )
    p = our_correct / num_detected  # correct links over all predicted spans
    r = our_correct / num_mentions  # correct links over all gold mentions
    f = 2 * p * r / (p + r)
    print(report_line.format(float(mention_thresh),
                             num_detected,
                             num_mentions,
                             our_total,
                             our_correct,
                             p,
                             r,
                             f))

Det Thresh: 0.1, Detection: 793, Num mentions: 656, Match: 591, Correct: 463, P: 0.584, R: 0.706, f: 0.639
Det Thresh: 0.5, Detection: 793, Num mentions: 656, Match: 591, Correct: 463, P: 0.584, R: 0.706, f: 0.639
Det Thresh: 0.8, Detection: 793, Num mentions: 656, Match: 591, Correct: 463, P: 0.584, R: 0.706, f: 0.639
Det Thresh: 0.99, Detection: 793, Num mentions: 656, Match: 590, Correct: 463, P: 0.584, R: 0.706, f: 0.639
Det Thresh: 1.0, Detection: 793, Num mentions: 656, Match: 461, Correct: 394, P: 0.497, R: 0.601, f: 0.544


In [88]:
# Linking accuracy on matched mentions for the last threshold of the sweep
# above (mention_thresh == 1, i.e. exact span match): Correct / Match.
# Computed from the sweep's variables instead of numbers copied from its output.
our_correct / our_total

0.8546637744034707

#### Evaluate linking only (gold mentions and spans supplied)

In [18]:
# Link the gold mentions of up to 20000 documents of DATASET.
# NOTE(review): the recorded output below reports 4825 mentions while the
# full-pipeline cells report 656, so this cell was presumably last executed
# with a different DATASET value — re-run top to bottom to confirm.
mention_results = get_mention_results(20000, dataset=DATASET, max_cands=100)

In [19]:
# Prints: correct links, evaluated mentions, 'NO LINK FOUND' predictions, accuracy.
num_correct, total, num_no_link = eval_mention(mention_results, dataset=DATASET)
print(num_correct, total, num_no_link, num_correct / total)

4253 4825 84 0.8814507772020725


In [69]:
# Number of gold mentions linked incorrectly in the linking-only evaluation,
# computed from its results instead of numbers copied from the printed output.
total - num_correct

572

## Tagme

In [65]:
def get_tagme_results(num_text, dataset='conll-dev'):
    """Annotate the first ``num_text`` documents of a dataset with TagMe.

    Progress is printed every 20 documents. The progress check used to sit
    outside the loop, so it ran once at the end and raised on an empty
    dataset because the loop index was never bound.

    Args:
        num_text: maximum number of documents to annotate.
        dataset: key into ``id2c``.

    Returns:
        (results, times) where ``results[doc_id]`` is whatever
        ``tagme.annotate`` returned for that document and ``times`` is a list
        of ``{'len': ..., 'time (s)': ...}`` records, one per document.
    """
    results = {}
    times = []
    for i, (doc_id, text) in enumerate(list(id2c[dataset].items())[:num_text]):
        tic = datetime.now()
        results[doc_id] = tagme.annotate(text)
        toc = datetime.now()
        times.append({'len': len(text), 'time (s)': (toc - tic).total_seconds()})
        if i % 20 == 0:
            print(i, i / num_text)

    return results, times

In [66]:
# Annotate every document of DATASET with the TagMe web service (network calls).
tagme_results, tagme_times = get_tagme_results(len(id2c[DATASET]), dataset=DATASET)

In [67]:
# Summary statistics of document length ('len') and TagMe latency ('time (s)').
# Note: this rebinds `df`, which earlier held the timings of our own system.
df = pd.DataFrame(tagme_times)
df.describe()

Unnamed: 0,len,time (s)
count,217.0,217.0
mean,1295.382488,1.67894
std,1043.489772,3.29638
min,165.0,0.257416
25%,545.0,0.52765
50%,990.0,0.922129
75%,1728.0,1.786551
max,6198.0,43.165836


In [77]:
# Grid over the TagMe annotation threshold and the span-overlap threshold; one
# record per combination is collected in `res` and echoed as a text line.
res = []
# The gold mention count does not depend on either threshold.
num_mentions = sum(len(doc_gold['mentions']) for doc_gold in gold[DATASET].values())
summary_line = 'Tag Thresh: {}, Mention Thresh: {}, Detection: {}, Num mentions: {}, Match: {}, Correct: {}'

for tag_thresh in [0.1, 0.15, 0.2, 0.3, 0.5]:
    for mention_thresh in [0.1, 0.5, 0.8, 1]:
        num_detected, tagme_correct, tagme_total, _, _ = eval_full(
            tagme_results,
            dataset=DATASET,
            mention_thresh=mention_thresh,
            tagme_thresh=tag_thresh,
        )
        # NOTE(review): 'Tag Thesh' is misspelled but is kept as-is because it
        # becomes a column name of the exported table.
        row = {'Tag Thesh': tag_thresh,
               'Mention Thresh': mention_thresh,
               'Detection': num_detected,
               'Num mentions': num_mentions,
               'Match': tagme_total,
               'Correct': tagme_correct}
        res.append(row)
        print(summary_line.format(tag_thresh,
                                  mention_thresh,
                                  num_detected,
                                  num_mentions,
                                  tagme_total,
                                  tagme_correct))

Tag Thresh: 0.1, Mention Thresh: 0.1, Detection: 11118, Num mentions: 4825, Match: 4433, Correct: 3090
Tag Thresh: 0.1, Mention Thresh: 0.5, Detection: 11118, Num mentions: 4825, Match: 4433, Correct: 3090
Tag Thresh: 0.1, Mention Thresh: 0.8, Detection: 11118, Num mentions: 4825, Match: 4433, Correct: 3092
Tag Thresh: 0.1, Mention Thresh: 1, Detection: 11118, Num mentions: 4825, Match: 4206, Correct: 3066
Tag Thresh: 0.15, Mention Thresh: 0.1, Detection: 8562, Num mentions: 4825, Match: 4154, Correct: 2960
Tag Thresh: 0.15, Mention Thresh: 0.5, Detection: 8562, Num mentions: 4825, Match: 4154, Correct: 2960
Tag Thresh: 0.15, Mention Thresh: 0.8, Detection: 8562, Num mentions: 4825, Match: 4154, Correct: 2961
Tag Thresh: 0.15, Mention Thresh: 1, Detection: 8562, Num mentions: 4825, Match: 3944, Correct: 2932
Tag Thresh: 0.2, Mention Thresh: 0.1, Detection: 6766, Num mentions: 4825, Match: 3860, Correct: 2818
Tag Thresh: 0.2, Mention Thresh: 0.5, Detection: 6766, Num mentions: 4825, Mat

In [78]:
# One row per (TagMe threshold, mention threshold) combination collected in `res`.
df_tagme = pd.DataFrame(res)

In [79]:
# Persist the TagMe sweep; the relative path resolves against the kernel's working directory.
df_tagme.to_csv('tagme.csv')