In [1]:
import requests
import json
import sys
from os.path import join
import pickle
sys.path.append('..')
from src.utils.utils import *
from src.utils.file import *
from collections import defaultdict
import tagme
from datetime import datetime
import pandas as pd
import Levenshtein
import numpy as np
import torch
import spacy

In [2]:
data_path = '/home/rohitalyosha/Student_Job/mannheim-nel/data'
datasets = ['conll-train', 'conll-dev', 'msnbc', 'ace2004']
tagme.GCUBE_TOKEN = "88c693df-a43f-4086-b3bc-0b555bfbc9bb-843339462"

In [3]:
rd = json_load(join(data_path, 'dicts/redirects.json'))
ent2id = json_load('/home/rohitalyosha/Student_Job/mannheim-nel/data/dicts/ent_dict.json')

In [4]:
id2c = {}
id2c_conll = pickle_load(join(data_path, 'Conll', 'conll_raw_text.pickle'))
id2c['conll-train'] = id2c_conll['train']
id2c['conll-dev'] = id2c_conll['dev']
examples = {}

for d_name in datasets[2:]:
    id2c[d_name], examples[d_name] = pickle_load(join(data_path, 'datasets', f'raw_{d_name}.pickle'))

for d_name in datasets[:2]:
    _, examples[d_name] = pickle_load(join(data_path, 'Conll', f"conll-{d_name.split('-')[-1]}.pickle"))

In [5]:
gold = {dataset : {} for dataset in datasets}
for dataset, exs in examples.items():
    for ex in exs:
        c_id, (mention, ent_str, span, _) = ex
        if c_id not in gold[dataset]:
            gold[dataset][c_id] = {'mentions': [],
                          'ents': [],
                          'spans': []}
        gold[dataset][c_id]['mentions'].append(mention)
        gold[dataset][c_id]['ents'].append(ent_str)
        gold[dataset][c_id]['spans'].append(span)

In [6]:
barack_text = """Barack Hussein Obama II (/bəˈrɑːk huːˈseɪn oʊˈbɑːmə/ (About this sound listen);[1] born August 4, 1961) is an American politician who served as the 44th President of the United States from January 20, 2009, to January 20, 2017. A member of the Democratic Party, he was the first African American to be elected to the presidency and previously served as a United States Senator from Illinois (2005–2008).
Obama was born in 1961 in Honolulu, Hawaii, two years after the territory was admitted to the Union as the 50th state. Raised largely in Hawaii, he also lived for a year of his childhood in the State of Washington and four years in Indonesia. After graduating from Columbia University in 1983, he worked as a community organizer in Chicago. In 1988, he enrolled in Harvard Law School, where he was the first black president of the Harvard Law Review. After graduating, he became a civil rights attorney and a professor, teaching constitutional law at the University of Chicago Law School from 1992 to 2004. He represented the 13th district for three terms in the Illinois Senate from 1997 to 2004, when he ran for the U.S. Senate. He received national attention in 2004 with his March primary win, his well-received July Democratic National Convention keynote address, and his landslide November election to the Senate. In 2008, he was nominated for president a year after his campaign began and after a close primary campaign against Hillary Clinton. He was elected over Republican John McCain and was inaugurated on January 20, 2009. Nine months later, he was named the 2009 Nobel Peace Prize laureate, accepting the award with the caveat that he felt there were others "far more deserving of this honor than I".
During his first two years in office, Obama signed many landmark bills into law. The main reforms were the Patient Protection and Affordable Care Act (often referred to as "Obamacare", shortened as the "Affordable Care Act"), the Dodd–Frank Wall Street Reform and Consumer Protection Act, and the Don't Ask, Don't Tell Repeal Act of 2010. The American Recovery and Reinvestment Act of 2009 and Tax Relief, Unemployment Insurance Reauthorization, and Job Creation Act of 2010 served as economic stimulus amidst the Great Recession. After a lengthy debate over the national debt limit, he signed the Budget Control and the American Taxpayer Relief Acts. In foreign policy, he increased U.S. troop levels in Afghanistan, reduced nuclear weapons with the United States–Russia New START treaty, and ended military involvement in the Iraq War. He ordered military involvement in Libya in opposition to Muammar Gaddafi; Gaddafi was killed by NATO-assisted forces, and he also ordered the military operation that resulted in the deaths of Osama bin Laden and suspected Yemeni Al-Qaeda operative Anwar al-Awlaki.
"""
barack_mentions = ['President', 'United States', 'African American', 'Democratic Party']

In [7]:
def get_response_full(text):
    data_json = json.dumps({'text': text})
    response_json = requests.post("http://127.0.0.1:5000/link", data=data_json).json()
    ents = response_json['entities']
    mentions = response_json['mentions']
    spans = response_json['spans']
    
    return ents, mentions, spans

In [8]:
def get_response_mention(text, user_mentions, user_spans):
    data_json = json.dumps({'text': text, 'mentions': user_mentions, 'spans': user_spans})
    response_json = requests.post("http://127.0.0.1:5000/link", data=data_json).json()
    ents = response_json['entities']
    mentions = response_json['mentions']
    
    return ents, mentions

In [9]:
def get_full_results(num_text, dataset='dev'):
    results = {}
    times = []
    for doc_id, text in list(id2c[dataset].items())[:num_text]:
        results[doc_id] = {}

        tic = datetime.now()
        ents, mentions, spans = get_response_full(text)
        toc = datetime.now()
        times.append({'len': len(text), 'time (s)': (toc - tic).total_seconds()})
        results[doc_id]['mentions'] = mentions
        results[doc_id]['ents'] = ents
        results[doc_id]['spans'] = [tuple(span) for span in spans]
    
    return results, times

In [10]:
def get_mention_results(num_text, dataset='conll-dev'):
    results = {}
    times = []
    for doc_id, text in list(id2c[dataset].items())[:num_text]:
        if doc_id not in gold[dataset]:
            continue
        results[doc_id] = {}
        user_mentions = gold[dataset][doc_id]['mentions']
        user_spans = gold[dataset][doc_id]['spans']
        try:
            ents, mentions = get_response_mention(text, user_mentions, user_spans)
        except Exception as e:
            print(Text, user_mentions)
        results[doc_id]['mentions'] = mentions
        results[doc_id]['ents'] = ents

    return results

In [47]:
def common_idx(pred_spans, gold_spans, thresh=0.5):
    i1 = 0
    i2 = 0
    res = []
    for i1, pred_span in enumerate(pred_spans):
        for i2, gold_span in enumerate(gold_spans):
            gold_begin = gold_span[0]
            gold_end = gold_span[1]
            
            pred_begin = pred_span[0]
            pred_end = pred_span[1]
            
            len_gold = gold_end - gold_begin
            
            if thresh == 1:
                if pred_span == gold_span:
                    res.append((i1, i2))
                    
            else:

                if pred_end >= gold_begin and pred_end <= gold_end:
                    overlap = (pred_end - gold_begin) / len_gold
                    if overlap >= thresh:
                        res.append((i1, i2))
                elif gold_end >= pred_begin and gold_end <= pred_end:
                    overlap = (gold_end - pred_begin) / len_gold
                    if overlap >= thresh:
                        res.append((i1, i2))
    
    # If same mention is counted twice, only add it once
    i1_cov = set()
    i2_cov = set()
    final_res = []
    for i1, i2 in res:
        if i1 not in i1_cov and i2 not in i2_cov:
            i1_cov.add(i1)
            i2_cov.add(i2)
            final_res.append((i1, i2))

        
    return final_res

In [12]:
def eval_full(results, dataset='conll-dev', verbose=False, mention_thresh=0.5, tagme_thresh=0.1):
    total_correct = 0
    total = 0
    num_detected = 0
    match_idxss = []
    not_covered_idxss = []

    for k, preds in results.items():
        if k not in gold[dataset]:
            if verbose:
                print('not in gold', k)
            continue
        if isinstance(preds, dict):
            pred_spans = preds['spans']
            pred_titles = preds['ents']
        else:
            pred_spans = [(ann.begin, ann.end) for ann in preds.get_annotations(tagme_thresh)]
            pred_titles = [tagme.normalize_title(ann.entity_title) for ann in preds.get_annotations(tagme_thresh)]
        num_detected += len(pred_spans)
            
        correct_spans = gold[dataset][k]['spans']
        overlap = common_idx(pred_spans, correct_spans, thresh=mention_thresh)
        if verbose:
            print(f'Correct: {correct_spans}')
            print(f'Predicted: {pred_spans}')
            print(f'Overlap: {overlap}\n\n')

        match = [(gold[dataset][k]['ents'][correct_idx], pred_titles[pred_idx]) for pred_idx, correct_idx in overlap]
        match = [(rd.get(t[0], t[0]), rd.get(t[1], t[1])) for t in match]
        
        match_idxs = [correct_idx for pred_idx, correct_idx in overlap]
        match_idxss.append(match_idxs)
        not_covered_idxs = [idx for idx, _ in enumerate(gold[dataset][k]['ents']) if idx not in match_idxs]
        not_covered_idxss.append(not_covered_idxs)
        
        correct = 0
        for m in match:
            total += 1
            if m[0] == m[1]:
                correct += 1
                total_correct += 1
        local_acc = correct / len(match) if len(match) else 0
        
        if verbose:
            if local_acc < 0.2:
                print(match)
    
    return num_detected, total_correct, total, match_idxss, not_covered_idxss

## Ours

#### Eval full pipeline

In [64]:
DATASET = 'ace2004'

In [65]:
our_results, our_times = get_full_results(2000, dataset=DATASET)

In [66]:
df = pd.DataFrame(our_times)
df.describe()

Unnamed: 0,len,time (s)
count,57.0,57.0
mean,2300.052632,0.157395
std,1415.323732,0.08502
min,196.0,0.017807
25%,1210.0,0.095136
50%,2042.0,0.146881
75%,3474.0,0.215752
max,5148.0,0.353154


In [67]:
for mention_thresh in [0.1, 0.5, 0.8, 1]:
    num_detected, our_correct, our_total, _, _ = eval_full(our_results, 
                                                           dataset=DATASET, 
                                                           mention_thresh=mention_thresh,
                                                           verbose=False)
    num_mentions = 0
    for k, v in gold[DATASET].items():
        num_mentions += len(v['mentions'])
    print('Det Thresh: {}, Detection: {}, Num mentions: {}, Match: {}, Correct: {}'.format(mention_thresh,
                                                                                           num_detected,
                                                                                           num_mentions,
                                                                                           our_total,
                                                                                           our_correct))

Det Thresh: 0.1, Detection: 1105, Num mentions: 257, Match: 213, Correct: 183
Det Thresh: 0.5, Detection: 1105, Num mentions: 257, Match: 209, Correct: 181
Det Thresh: 0.8, Detection: 1105, Num mentions: 257, Match: 204, Correct: 178
Det Thresh: 1, Detection: 1105, Num mentions: 257, Match: 164, Correct: 148


#### Eval only linking

In [68]:
DATASET = 'ace2004'

In [85]:
mention_results = get_mention_results(500, dataset=DATASET)

num_correct = 0
total = 0
num_no_link = 0
no_links = []

for k, v in mention_results.items():
    if k not in gold[DATASET]:
        print(k, v)
        continue
    gold_ents = gold[DATASET][k]['ents']
    pred_ents = v['ents']
    gold_mentions = gold[DATASET][k]['mentions']
    for i, (gold_ent, pred_ent) in enumerate(zip(gold_ents, pred_ents)):
        gold_ent = rd.get(gold_ent, gold_ent)
        pred_ent = rd.get(pred_ent, pred_ent)
        total += 1
        if pred_ent == 'NO LINK FOUND':
            num_no_link += 1
            no_links.append(gold_ent)
        if gold_ent == pred_ent:
            num_correct += 1
        else:
            pass

In [86]:
print(num_correct, total, num_correct / total)

219 257 0.8521400778210116


In [46]:
pred = [(ann.begin, ann.end) for ann in tagme_results[5].get_annotations(0.2)]
correct = gold['conll-dev'][0]['spans']
for p_i, c_i in common_idx(pred, correct, thresh=0.1):
    print(pred[p_i], correct[c_i])

(48, 56) (44, 51)
(170, 176) (166, 179)


## Tagme

In [13]:
def get_tagme_results(num_text, dataset='conll-dev'):
    results = {}
    times = []
    for i, (doc_id, text) in enumerate(list(id2c[dataset].items())[:num_text]):
        text =  id2c[dataset][doc_id] 
        tic = datetime.now()
        results[doc_id] = tagme.annotate(text)
        toc = datetime.now()
        times.append({'len': len(text), 'time (s)': (toc - tic).total_seconds()})
    if i % 20 == 0:
        print(i, i / num_text)

    return results, times

In [79]:
DATASET = 'ace2004'

In [80]:
tagme_results, tagme_times = get_tagme_results(len(id2c[DATASET]), dataset=DATASET)

In [81]:
df = pd.DataFrame(tagme_times)
df.describe()

Unnamed: 0,len,time (s)
count,57.0,57.0
mean,2300.052632,1.638665
std,1415.323732,1.715196
min,196.0,0.28264
25%,1210.0,0.784536
50%,2042.0,1.124888
75%,3474.0,1.871518
max,5148.0,12.222618


In [83]:
for tag_thresh in [0.1, 0.15, 0.2, 0.3, 0.5]:
    for mention_thresh in [0.1, 0.5, 0.8, 1]:
        num_detected, tagme_correct, tagme_total, _, _ = eval_full(tagme_results,
                                                                  dataset=DATASET, 
                                                                  mention_thresh=mention_thresh,
                                                                  tagme_thresh=tag_thresh)
        num_mentions = 0
        for k, v in gold[DATASET].items():
            num_mentions += len(v['mentions'])
        print('Tag Thresh: {}, Mention Thresh: {}, Detection: {}, Num mentions: {}, Match: {}, Correct: {}'.format(tag_thresh,
                                                                                                               mention_thresh,
                                                                                                               num_detected, 
                                                                                                               num_mentions,
                                                                                                               tagme_total,
                                                                                                               tagme_correct))

Tag Thresh: 0.1, Mention Thresh: 0.1, Detection: 2808, Num mentions: 257, Match: 245, Correct: 189
Tag Thresh: 0.1, Mention Thresh: 0.5, Detection: 2808, Num mentions: 257, Match: 239, Correct: 184
Tag Thresh: 0.1, Mention Thresh: 0.8, Detection: 2808, Num mentions: 257, Match: 236, Correct: 184
Tag Thresh: 0.1, Mention Thresh: 1, Detection: 2808, Num mentions: 257, Match: 210, Correct: 174
Tag Thresh: 0.15, Mention Thresh: 0.1, Detection: 1935, Num mentions: 257, Match: 233, Correct: 186
Tag Thresh: 0.15, Mention Thresh: 0.5, Detection: 1935, Num mentions: 257, Match: 225, Correct: 181
Tag Thresh: 0.15, Mention Thresh: 0.8, Detection: 1935, Num mentions: 257, Match: 221, Correct: 180
Tag Thresh: 0.15, Mention Thresh: 1, Detection: 1935, Num mentions: 257, Match: 199, Correct: 170
Tag Thresh: 0.2, Mention Thresh: 0.1, Detection: 1368, Num mentions: 257, Match: 210, Correct: 171
Tag Thresh: 0.2, Mention Thresh: 0.5, Detection: 1368, Num mentions: 257, Match: 204, Correct: 168
Tag Thresh