In [1]:
!pip install scikit-learn==1.5.0 nltk sklearn_crfsuite

Defaulting to user installation because normal site-packages is not writeable
Collecting scikit-learn==1.5.0
  Obtaining dependency information for scikit-learn==1.5.0 from https://files.pythonhosted.org/packages/46/c0/63d3a8da39a2ee051df229111aa93f6dca2b56f8080abd34993938166455/scikit_learn-1.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Downloading scikit_learn-1.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting sklearn_crfsuite
  Obtaining dependency information for sklearn_crfsuite from https://files.pythonhosted.org/packages/b2/11/a8370dd6fce65f8f4e74a0adffae72be9db5799d8ed8ddbf84415356a764/sklearn_crfsuite-0.5.0-py2.py3-none-any.whl.metadata
  Downloading sklearn_crfsuite-0.5.0-py2.py3-none-any.whl.metadata (4.9 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn==1.5.0)
  Obtaining dependency information for threadpoolctl>=3.1.0 from https://files.pythonhosted.org/packages/4b/2c/ffbf7a134b9ab11a67b0cf0726453

In [20]:
import sys
sys.path.append("./NER-Evaluation")

import nltk
import sklearn_crfsuite

from copy import deepcopy
from collections import defaultdict

from sklearn_crfsuite.metrics import flat_classification_report

from ner_evaluation.ner_eval import compute_metrics
from ner_evaluation.ner_eval import compute_precision_recall_wrapper

import pandas as pd
import json

In [7]:
from collections import namedtuple
# One entity mention: its label plus the start and end offsets of its span.
Entity = namedtuple("Entity", ["e_type", "start_offset", "end_offset"])

### Implement our own version of collect_named_entities

**Use code from coref_reformat.ipynb from crosslingual_coref**

In [31]:
# Build faa = {c5_id: {0: word0, 1: word1, ..., n: wordn}} from the word tokenization in faa.conll

faa = {}

with open('../../data/FAA_data/faa.conll') as f:
    text = f.read()

docs = text.split('#begin document ')

for doc in docs:
    # Only chunks that open with a '(faa/' header are treated as documents.
    if not doc.startswith('(faa/'):
        continue
    c5_id = doc.split('_')[1][:15]
    # Token rows are the lines after the header that contain 'faa';
    # the word is the 4th whitespace-separated field, stored uppercased.
    token_rows = [row for row in doc.split('\n')[1:] if 'faa' in row]
    faa[c5_id] = {position: row.split()[3].upper() for position, row in enumerate(token_rows)}

In [25]:
faa['19750419011349A'].values()

dict_values(['TOW', 'PLANE', 'BECAME', 'AIRBORNE', 'THEN', 'SETTLED', '/.', 'STUDENT', 'THOUGHT', 'TOW', 'IN', 'TROUBLE', '&', 'RELEASED', '/.', 'HIT', 'TREE', '/.'])

In [145]:
# Fix known error: replace the tokens parsed from faa.conll for this document
# with a hand-written token list (word index -> token).
faa['19980620030289I'] = dict(enumerate([
    'MR.', 'KADERA', 'THEN', 'ATTEMPTED', 'TO', 'LAND', 'IN', 'A', 'FIELD',
    'BUT', 'WAS', 'FORCED', 'TO', 'LAND', 'ON', 'HIGHWAY', '93', '.',
    'THREE', 'MILES', 'EAST', 'OF', 'SUNMER', ',', 'IOWA',
]))

In [203]:
def get_spans(mentions, words):
    ''' Locate each mention in a tokenized document and return its word-index span.

    Mentions are searched for in the order given: after a mention is found, the
    search for the next one resumes right after it, so the same occurrence
    cannot be counted twice.

    Input:
    - mentions: ['MENTION1','MENTION2',...]
    - words: {0: 'This', 1: 'is', 2: 'a', 3: 'sentence', 4: '.', ...}
             (dict mapping word index -> token, e.g. faa[doc_id]; NOT its .values())
    Output: [[startidx_mention1, endidx_mention1], [startidx_mention2, endidx_mention2], ...]
        where endidx is exclusive (start + number of words in the mention).
        Sentinels for a mention that could not be located:
        - [-1, -1]: the mention is empty, its text does not occur in the document,
                    or its first word does not occur at/after the resume position
        - [-2, -1]: the first word was seen at/after the resume position, but the
                    full mention did not match there
    '''

    mention_spans = []

    resume_idx = 0

    # Loop-invariant: the document text only needs to be joined once.
    doc_text = ' '.join(words.values())

    for mention in mentions:

        mention_span = [-1, -1]
        mention_tokens = mention.split()

        # An empty/whitespace-only mention has no first word to look for.
        if mention_tokens and mention in doc_text:
            n_tokens = len(mention_tokens)

            for iword, word in words.items():
                if iword < resume_idx or word != mention_tokens[0]:
                    continue

                # Compare every word of the mention, not just the first and last.
                # words.get() keeps a candidate that runs past the end of the
                # document from raising KeyError; it simply fails to match.
                if all(words.get(iword + offset) == token
                       for offset, token in enumerate(mention_tokens)):
                    # End index is exclusive so that partial-match evaluation works correctly.
                    mention_span = [iword, iword + n_tokens]
                    # Continue after this mention so it cannot be counted twice.
                    resume_idx = iword + n_tokens
                    break

                # First word matched but the mention did not: flag and keep looking.
                mention_span[0] = -2

        mention_spans.append(mention_span)

    return mention_spans

In [204]:
get_spans(['TOW PLANE','STUDENT'],faa['19750419011349A']) # example

[[0, 2], [7, 8]]

**With `get_spans()` in place, it is very simple to write our own version of `collect_named_entities()`.**

In [205]:
def collect_named_entities(entities, labels, tokens):
    """
    Builds one Entity named-tuple per entity mention, pairing the mention's label
    with the start and end offsets that get_spans() returns for it.

    Parameters:
    - entities: ["ENT1","ENT2"...] All entity mentions for a doc
    - labels: ["LABEL1","LABEL2"...] Labels for the doc, positionally aligned with entities
    - tokens: {0: 'TOW', 1: 'PLANE', 2: 'BECAME', ...} Tokenized doc as a dict of
      word index -> token, i.e. faa[doc_id]. It is passed unchanged to get_spans().

    Returns: a list of Entity named-tuples, in the same order as entities
    """

    return [
        Entity(labels[position], span_start, span_end)
        for position, (span_start, span_end) in enumerate(get_spans(entities, tokens))
    ]

In [206]:
collect_named_entities(['TOW PLANE','STUDENT'],['VEHICLE','PER'],faa['19750419011349A']) #example

[Entity(e_type='VEHICLE', start_offset=0, end_offset=2),
 Entity(e_type='PER', start_offset=7, end_offset=8)]

### Get Predicted and Gold Data

In [410]:
abbrevs = {'FACILITY':'FAC','ORGANIZATION':'ORG','PERSON':'PER','LOCATION':'LOC'}

In [789]:
# Load the predictions and shorten any label that has an entry in `abbrevs`;
# labels without an entry are kept unchanged.
result_df = pd.read_csv('../../data/results/nltk/nltk_ner_uppercased.csv')
result_df['labels'] = result_df['labels'].apply(lambda label: abbrevs.get(label, label))
result_df.head()

Unnamed: 0,index,c5_unique_id,c119_text,entities,POS tags,labels
0,2318,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,ACFT,NNP,ORG
1,2318,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,WAS,NNP,ORG
2,2318,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,IT,NNP,ORG
3,2318,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,LOST,NNP,ORG
4,2318,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,RAN,NNP,ORG


In [701]:
gold_df = pd.read_csv('../../gold_standard/processed/ner.csv')
# The aviation mentions-only gold standard has no label column:
# give every row the dummy label 'ORG' so it has the same shape as the predictions.
if 'labels' not in gold_df:
    gold_df['labels'] = 'ORG'
gold_df.head()

Unnamed: 0,id,sample,entities,labels
0,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,ACFT,ORG
1,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,DITCH,ORG
2,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,TREE,ORG
3,19990213001379A,ACFT WAS TAXIING FOR TAKE OFF WHEN IT LOST CON...,LOST CONTROL,ORG
4,19800217031649I,"AFTER TAKEOFF, ENGINE QUIT. WING FUEL TANK SUM...",TAKEOFF,ORG


In [598]:
#gold_df['entities'].iat[106] = 'BARTLESVILLE' # typo in ACE-2005 Gold Df

In [604]:
#gold_df['entities'].iat[86] = 'BARTLESVILLE' # typo in OntoNotes Gold Df

In [702]:
# Typos in our Gold Df, patched by row position.
# Assign through a single .iat call on the frame itself: the chained form
# gold_df['entities'].iat[i] = ... writes into an intermediate Series, which
# does not update gold_df when pandas Copy-on-Write is in effect.
entities_col = gold_df.columns.get_loc('entities')
gold_df.iat[209, entities_col] = 'ATTITUDE'
gold_df.iat[222, entities_col] = 'ACCELARATION'
gold_df.iat[439, entities_col] = 'MR. KADERA'

In [792]:
id_col = 'c5_unique_id'
len(result_df[id_col].unique())

100

In [603]:
len(gold_df['id'].unique())

100

**Create all_true_ents and all_pred_ents lists**

The example in `example-full-named-entity-evaluation.ipynb` uses `test_sents_labels` and `y_pred`, which are lists of tokens passed to `collect_named_entities` to get the true and predicted entities, respectively. Our `collect_named_entities()` takes more inputs, however, so it is easier to run it ahead of time and have the lists of true and predicted entities ready to go.

In [793]:
# Build, per gold document, the list of true and predicted Entity tuples.
all_true_ents = []
all_pred_ents = []

# Drop incomplete rows once up front instead of re-running dropna() for every document.
gold_rows = gold_df.dropna()
result_rows = result_df.dropna()

for doc_id in gold_df['id'].unique():
    true_rows = gold_rows[gold_rows['id'] == doc_id]
    pred_rows = result_rows[result_rows[id_col] == doc_id]

    true_ents = collect_named_entities(true_rows['entities'].to_list(), true_rows['labels'].to_list(), faa[doc_id])
    pred_ents = collect_named_entities(pred_rows['entities'].to_list(), pred_rows['labels'].to_list(), faa[doc_id])

    # A negative end offset means get_spans() could not locate the mention;
    # print those documents so they can be fixed by hand.
    if any(ent[2] < 0 for ent in true_ents):
        print("\nTrue: ",doc_id)
        print(true_ents)
    #if True in [ent[2] < 0 for ent in pred_ents]:
    #    print("\nPred: ",doc_id)
    #    print(pred_ents)

    all_true_ents.append(true_ents)
    all_pred_ents.append(pred_ents)


True:  19990213001379A
[Entity(e_type='ORG', start_offset=0, end_offset=1), Entity(e_type='ORG', start_offset=14, end_offset=15), Entity(e_type='ORG', start_offset=19, end_offset=20), Entity(e_type='ORG', start_offset=-1, end_offset=-1)]

True:  19971226042729I
[Entity(e_type='ORG', start_offset=0, end_offset=2), Entity(e_type='ORG', start_offset=-1, end_offset=-1), Entity(e_type='ORG', start_offset=5, end_offset=6), Entity(e_type='ORG', start_offset=7, end_offset=8), Entity(e_type='ORG', start_offset=11, end_offset=13), Entity(e_type='ORG', start_offset=21, end_offset=22)]

True:  20050713014239I
[Entity(e_type='ORG', start_offset=-1, end_offset=-1), Entity(e_type='ORG', start_offset=10, end_offset=12), Entity(e_type='ORG', start_offset=-2, end_offset=-1), Entity(e_type='ORG', start_offset=17, end_offset=18), Entity(e_type='ORG', start_offset=19, end_offset=20), Entity(e_type='ORG', start_offset=22, end_offset=23), Entity(e_type='ORG', start_offset=25, end_offset=27)]

True:  1987031

**Now, fix errors**

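**TODO:** this is a temporary fix and needs attention in future. The overrides below are hard-coded by position in `all_true_ents` / `all_pred_ents`, so they only apply to the specific gold-standard and prediction files they were written for.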

In [744]:
doc_id = '19980620030289I'

In [745]:
print(faa[doc_id])

{0: 'MR.', 1: 'KADERA', 2: 'THEN', 3: 'ATTEMPTED', 4: 'TO', 5: 'LAND', 6: 'IN', 7: 'A', 8: 'FIELD', 9: 'BUT', 10: 'WAS', 11: 'FORCED', 12: 'TO', 13: 'LAND', 14: 'ON', 15: 'HIGHWAY', 16: '93', 17: '.', 18: 'THREE', 19: 'MILES', 20: 'EAST', 21: 'OF', 22: 'SUNMER', 23: ',', 24: 'IOWA'}


In [746]:
#result_df[result_df[id_col]==doc_id]

In [747]:
gold_df[gold_df['id']==doc_id]

Unnamed: 0,id,sample,entities,labels
439,19980620030289I,MR. KADERA THEN ATTEMPTED TO LAND IN A FIELD B...,MR. KADERA,ORG
440,19980620030289I,MR. KADERA THEN ATTEMPTED TO LAND IN A FIELD B...,FIELD,ORG
441,19980620030289I,MR. KADERA THEN ATTEMPTED TO LAND IN A FIELD B...,HIGHWAY 93,ORG
442,19980620030289I,MR. KADERA THEN ATTEMPTED TO LAND IN A FIELD B...,THREE,ORG
443,19980620030289I,MR. KADERA THEN ATTEMPTED TO LAND IN A FIELD B...,EAST,ORG
444,19980620030289I,MR. KADERA THEN ATTEMPTED TO LAND IN A FIELD B...,"SUNMER, IOWA",ORG


In [742]:
all_true_ents.index([Entity(e_type='ORG', start_offset=0, end_offset=2), Entity(e_type='ORG', start_offset=8, end_offset=9), Entity(e_type='ORG', start_offset=15, end_offset=17), Entity(e_type='ORG', start_offset=18, end_offset=19), Entity(e_type='ORG', start_offset=20, end_offset=21), Entity(e_type='ORG', start_offset=-1, end_offset=-1)])

89

In [794]:
all_true_ents[0] = [Entity(e_type='ORG', start_offset=0, end_offset=1), Entity(e_type='ORG', start_offset=14, end_offset=15), Entity(e_type='ORG', start_offset=19, end_offset=20), Entity(e_type='ORG', start_offset=8, end_offset=10)]
all_true_ents[13] = [Entity(e_type='ORG', start_offset=0, end_offset=2), Entity(e_type='ORG', start_offset=1, end_offset=3), Entity(e_type='ORG', start_offset=5, end_offset=6), Entity(e_type='ORG', start_offset=7, end_offset=8), Entity(e_type='ORG', start_offset=11, end_offset=13), Entity(e_type='ORG', start_offset=21, end_offset=22)]
all_true_ents[24] = [Entity(e_type='ORG', start_offset=4, end_offset=8), Entity(e_type='ORG', start_offset=10, end_offset=12), Entity(e_type='ORG', start_offset=14, end_offset=17), Entity(e_type='ORG', start_offset=17, end_offset=18), Entity(e_type='ORG', start_offset=19, end_offset=20), Entity(e_type='ORG', start_offset=22, end_offset=23), Entity(e_type='ORG', start_offset=25, end_offset=27)]
all_true_ents[36] = [Entity(e_type='ORG', start_offset=6, end_offset=7), Entity(e_type='ORG', start_offset=11, end_offset=12), Entity(e_type='ORG', start_offset=0, end_offset=2)]
all_true_ents[40] = [Entity(e_type='ORG', start_offset=3, end_offset=5), Entity(e_type='ORG', start_offset=8, end_offset=9), Entity(e_type='ORG', start_offset=10, end_offset=12), Entity(e_type='ORG', start_offset=13, end_offset=15), Entity(e_type='ORG', start_offset=16, end_offset=19)]
all_true_ents[54] = [Entity(e_type='ORG', start_offset=3, end_offset=4), Entity(e_type='ORG', start_offset=6, end_offset=8), Entity(e_type='ORG', start_offset=12, end_offset=14), Entity(e_type='ORG', start_offset=15, end_offset=16), Entity(e_type='ORG', start_offset=21, end_offset=22), Entity(e_type='ORG', start_offset=19, end_offset=20)]
all_true_ents[65] = [Entity(e_type='ORG', start_offset=3, end_offset=7), Entity(e_type='ORG', start_offset=8, end_offset=9), Entity(e_type='ORG', start_offset=13, end_offset=14), Entity(e_type='ORG', start_offset=17, end_offset=21), Entity(e_type='ORG', start_offset=23, end_offset=24)]
all_true_ents[70] = [Entity(e_type='ORG', start_offset=4, end_offset=5), Entity(e_type='ORG', start_offset=6, end_offset=9), Entity(e_type='ORG', start_offset=10, end_offset=11), Entity(e_type='ORG', start_offset=12, end_offset=15), Entity(e_type='ORG', start_offset=16, end_offset=20)]
all_true_ents[76] = [Entity(e_type='ORG', start_offset=5, end_offset=6), Entity(e_type='ORG', start_offset=7, end_offset=12), Entity(e_type='ORG', start_offset=13, end_offset=17), Entity(e_type='ORG', start_offset=19, end_offset=22), Entity(e_type='ORG', start_offset=24, end_offset=25), Entity(e_type='ORG', start_offset=26, end_offset=27), Entity(e_type='ORG', start_offset=28, end_offset=29)]
all_true_ents[80] = [Entity(e_type='ORG', start_offset=1, end_offset=3), Entity(e_type='ORG', start_offset=5, end_offset=6), Entity(e_type='ORG', start_offset=8, end_offset=10), Entity(e_type='ORG', start_offset=11, end_offset=13), Entity(e_type='ORG', start_offset=14, end_offset=15), Entity(e_type='ORG', start_offset=19, end_offset=21), Entity(e_type='ORG', start_offset=24, end_offset=25)]
all_true_ents[82] = [Entity(e_type='ORG', start_offset=3, end_offset=4), Entity(e_type='ORG', start_offset=5, end_offset=6), Entity(e_type='ORG', start_offset=10, end_offset=14), Entity(e_type='ORG', start_offset=15, end_offset=18), Entity(e_type='ORG', start_offset=21, end_offset=22), Entity(e_type='ORG', start_offset=22, end_offset=23)]
all_true_ents[89] =[Entity(e_type='ORG', start_offset=0, end_offset=2), Entity(e_type='ORG', start_offset=8, end_offset=9), Entity(e_type='ORG', start_offset=15, end_offset=17), Entity(e_type='ORG', start_offset=18, end_offset=19), Entity(e_type='ORG', start_offset=20, end_offset=21), Entity(e_type='ORG', start_offset=22, end_offset=25)]
all_true_ents[90] = [Entity(e_type='ORG', start_offset=4, end_offset=8), Entity(e_type='ORG', start_offset=10, end_offset=15), Entity(e_type='ORG', start_offset=18, end_offset=21), Entity(e_type='ORG', start_offset=22, end_offset=23)]
all_true_ents[93] = [Entity(e_type='ORG', start_offset=4, end_offset=8), Entity(e_type='ORG', start_offset=10, end_offset=14), Entity(e_type='ORG', start_offset=16, end_offset=18), Entity(e_type='ORG', start_offset=19, end_offset=20), Entity(e_type='ORG', start_offset=23, end_offset=25), Entity(e_type='ORG', start_offset=25, end_offset=26)]

In [508]:
# ACE ENTS (NLTK RESTRICTED)
#all_true_ents[5] = [Entity(e_type='PER', start_offset=3, end_offset=7), Entity(e_type='PER', start_offset=10, end_offset=21), Entity(e_type='ORG', start_offset=15, end_offset=16)]
#all_true_ents[31] = [Entity(e_type='PER', start_offset=11, end_offset=12), Entity(e_type='LOC', start_offset=9, end_offset=10)]
#all_true_ents[33] = [Entity(e_type='PER', start_offset=0, end_offset=1), Entity(e_type='LOC', start_offset=3, end_offset=9), Entity(e_type='FAC', start_offset=8, end_offset=9)]
#all_true_ents[40] = [Entity(e_type='PER', start_offset=10, end_offset=12), Entity(e_type='FAC', start_offset=13, end_offset=14), Entity(e_type='GPE', start_offset=16, end_offset=19)]
#all_true_ents[49] = [Entity(e_type='FAC', start_offset=3, end_offset=7), Entity(e_type='LOC', start_offset=15, end_offset=18), Entity(e_type='FAC', start_offset=17, end_offset=18)]
#all_true_ents[51] = [Entity(e_type='PER', start_offset=2, end_offset=9), Entity(e_type='ORG', start_offset=5, end_offset=6), Entity(e_type='FAC', start_offset=14, end_offset=17)]
#all_true_ents[52] = [Entity(e_type='LOC', start_offset=16, end_offset=20), Entity(e_type='FAC', start_offset=19, end_offset=20)]
#all_true_ents[70] = [Entity(e_type='FAC', start_offset=6, end_offset=12), Entity(e_type='FAC', start_offset=10, end_offset=11), Entity(e_type='GPE', start_offset=12, end_offset=13), Entity(e_type='GPE', start_offset=14, end_offset=15)]
#all_true_ents[76] = [Entity(e_type='GPE', start_offset=7, end_offset=9), Entity(e_type='GPE', start_offset=10, end_offset=11), Entity(e_type='GPE', start_offset=13, end_offset=15), Entity(e_type='GPE', start_offset=16, end_offset=17), Entity(e_type='LOC', start_offset=27, end_offset=29), Entity(e_type='GPE', start_offset=28, end_offset=29)]
#all_true_ents[78] = [Entity(e_type='PER', start_offset=3, end_offset=10), Entity(e_type='PER', start_offset=8, end_offset=9)]
#all_true_ents[89] = [Entity(e_type='PER', start_offset=0, end_offset=2), Entity(e_type='FAC', start_offset=15, end_offset=17), Entity(e_type='LOC', start_offset=18, end_offset=25), Entity(e_type='GPE', start_offset=22, end_offset=23), Entity(e_type='GPE', start_offset=24, end_offset=25)]

In [734]:
# SPACY SM ENTS
all_pred_ents[24] = [Entity(e_type='DATE', start_offset=4, end_offset=8), Entity(e_type='DATE', start_offset=10, end_offset=11), Entity(e_type='ORG', start_offset=19, end_offset=20)]
all_pred_ents[40] = [Entity(e_type='PERSON', start_offset=16, end_offset=18)]
all_pred_ents[65] = [Entity(e_type='DATE', start_offset=3, end_offset=7), Entity(e_type='ORG', start_offset=8, end_offset=9)]
all_pred_ents[90] = [Entity(e_type='DATE', start_offset=4, end_offset=8), Entity(e_type='TIME', start_offset=10, end_offset=14), Entity(e_type='ORG', start_offset=17, end_offset=18), Entity(e_type='PERSON', start_offset=24, end_offset=25)]
all_pred_ents[93] = [Entity(e_type='DATE', start_offset=4, end_offset=8), Entity(e_type='TIME', start_offset=10, end_offset=12), Entity(e_type='DATE', start_offset=17, end_offset=18)]

In [766]:
# SPACY LG ENTS
all_pred_ents[10] = [Entity(e_type='ORG', start_offset=13, end_offset=16)]
all_pred_ents[24] = [Entity(e_type='DATE', start_offset=4, end_offset=8), Entity(e_type='ORG', start_offset=14, end_offset=18), Entity(e_type='ORG', start_offset=19, end_offset=20)]
all_pred_ents[65] = [Entity(e_type='DATE', start_offset=3, end_offset=7), Entity(e_type='ORG', start_offset=8, end_offset=9), Entity(e_type='ORG', start_offset=12, end_offset=14), Entity(e_type='ORG', start_offset=23, end_offset=24)]
all_pred_ents[90] = [Entity(e_type='ORG', start_offset=1, end_offset=2), Entity(e_type='DATE', start_offset=4, end_offset=8), Entity(e_type='ORG', start_offset=17, end_offset=18), Entity(e_type='PERSON', start_offset=22, end_offset=23), Entity(e_type='ORG', start_offset=24, end_offset=25)]
all_pred_ents[93] = [Entity(e_type='ORG', start_offset=1, end_offset=2), Entity(e_type='DATE', start_offset=4, end_offset=8), Entity(e_type='ORG', start_offset=11, end_offset=14), Entity(e_type='DATE', start_offset=17, end_offset=18), Entity(e_type='PERSON', start_offset=19, end_offset=20)]

In [776]:
# STANZA ENTS
all_pred_ents[18] = [Entity(e_type='CARDINAL', start_offset=1, end_offset=2), Entity(e_type='ORG', start_offset=6, end_offset=7)]
all_pred_ents[24] = [Entity(e_type='DATE', start_offset=4, end_offset=11), Entity(e_type='TIME', start_offset=11, end_offset=12), Entity(e_type='ORG', start_offset=13, end_offset=18), Entity(e_type='ORG', start_offset=19, end_offset=20)]
all_pred_ents[54] = [Entity(e_type='CARDINAL', start_offset=1, end_offset=2), Entity(e_type='PRODUCT', start_offset=6, end_offset=8)]
all_pred_ents[65] = [Entity(e_type='DATE', start_offset=3, end_offset=7)]
all_pred_ents[90] = [Entity(e_type='CARDINAL', start_offset=1, end_offset=2), Entity(e_type='DATE', start_offset=4, end_offset=11)]
all_pred_ents[93] = [Entity(e_type='CARDINAL', start_offset=1, end_offset=2), Entity(e_type='DATE', start_offset=4, end_offset=11)]

In [583]:
# ON ENTS
all_true_ents[24] = [Entity(e_type='DATE', start_offset=4, end_offset=8), Entity(e_type='TIME', start_offset=10, end_offset=12),Entity(e_type='ORG', start_offset=14, end_offset=17),Entity(e_type='PRODUCT', start_offset=17, end_offset=18), Entity(e_type='PRODUCT', start_offset=19, end_offset=20), Entity(e_type='ORG', start_offset=25, end_offset=27)]
all_true_ents[40] = [Entity(e_type='PER', start_offset=10, end_offset=12), Entity(e_type='GPE', start_offset=16, end_offset=19)]
all_true_ents[65] = [Entity(e_type='DATE', start_offset=3, end_offset=7)]
all_true_ents[90] = [Entity(e_type='DATE', start_offset=4, end_offset=8), Entity(e_type='TIME', start_offset=10, end_offset=15), Entity(e_type='ORG', start_offset=18, end_offset=19), Entity(e_type='PRODUCT', start_offset=19, end_offset=21)]
all_true_ents[93] = [Entity(e_type='DATE', start_offset=4, end_offset=8), Entity(e_type='TIME', start_offset=10, end_offset=14), Entity(e_type='ORG', start_offset=16, end_offset=17), Entity(e_type='PRODUCT', start_offset=17, end_offset=18), Entity(e_type='ORG', start_offset=23, end_offset=25)]

In [785]:
# FLAIR ENTS
#all_true_ents[40] = [Entity(e_type='PER', start_offset=10, end_offset=12), Entity(e_type='LOC', start_offset=16, end_offset=19)]
all_pred_ents[40] = [Entity(e_type='PER', start_offset=11, end_offset=12), Entity(e_type='LOC', start_offset=16, end_offset=19)]

In [795]:
# Check: print every document that still contains an entity with a negative end offset.
for banner, ents_per_doc in (("\nTrue: ", all_true_ents), ("\nPred: ", all_pred_ents)):
    for doc_ents in ents_per_doc:
        if any(ent[2] < 0 for ent in doc_ents):
            print(banner)
            print(doc_ents)

### Apply Example

In [750]:
conll_tags = ['PER', 'ORG', 'MISC', 'LOC']
ace_tags = ['PER','ORG','LOC','FAC','GPE'] # RESTRICTED SET
on_tags = ['PER','ORG','LOC','FAC','GPE','PRODUCT','NORP','QUANTITY','EVENT','WORK_OF_ART','CARDINAL','DATE','PERCENT','TIME','ORDINAL','MONEY','LAW','LANGUAGE']

In [796]:
# Aggregate the per-document output of compute_metrics over all documents,
# both overall (`results`) and per entity type (`evaluation_agg_entities_type`).

# Entity types passed to compute_metrics and used as keys of the per-type aggregate.
tags = ace_tags

# Template of counters kept for each evaluation schema.
metrics_results = {'correct': 0, 'incorrect': 0, 'partial': 0,
                   'missed': 0, 'spurious': 0, 'possible': 0, 'actual': 0, 'precision': 0, 'recall': 0}

# overall results
# (one independent copy of the template per evaluation schema)
results = {'strict': deepcopy(metrics_results),
           'ent_type': deepcopy(metrics_results),
           'partial':deepcopy(metrics_results),
           'exact':deepcopy(metrics_results)
          }


# results aggregated by entity type
evaluation_agg_entities_type = {e: deepcopy(results) for e in tags}

for true_ents, pred_ents in zip(all_true_ents, all_pred_ents):
    
    # compute results for one message
    tmp_results, tmp_agg_results = compute_metrics(
        true_ents, pred_ents,  tags
    )
    
    #print(tmp_results)

    # aggregate overall results
    # NOTE(review): 'precision' and 'recall' are summed here like the counts and are
    # then reassigned by compute_precision_recall_wrapper below; presumably the wrapper
    # recomputes them from the counts — confirm against ner_evaluation.
    for eval_schema in results.keys():
        for metric in metrics_results.keys():
            results[eval_schema][metric] += tmp_results[eval_schema][metric]
            
    # Calculate global precision and recall
    # NOTE(review): this runs once per document, but only the value left after the
    # last document is read later; it could presumably be moved after the loop —
    # confirm the wrapper depends only on the accumulated counts.
        
    results = compute_precision_recall_wrapper(results)


    # aggregate results by entity type
 
    for e_type in tags:

        for eval_schema in tmp_agg_results[e_type]:

            for metric in tmp_agg_results[e_type][eval_schema]:
                
                evaluation_agg_entities_type[e_type][eval_schema][metric] += tmp_agg_results[e_type][eval_schema][metric]
                
        # Calculate precision recall at the individual entity level
                
        evaluation_agg_entities_type[e_type] = compute_precision_recall_wrapper(evaluation_agg_entities_type[e_type])

In [759]:
def print_results_labeled(tool_name, results):
    """Print a one-row table of F1 scores for the four evaluation schemas.

    Parameters:
    - tool_name: label printed in the first column
    - results: dict with keys 'strict', 'exact', 'partial' and 'ent_type'; each
      value is a dict holding at least 'precision' and 'recall'

    Returns: None (the table is printed).

    F1 is 2*P*R/(P+R). It is reported as 0.0 when P + R is 0, instead of
    raising ZeroDivisionError.
    """

    scores = {'exact':0.0,'strict':0.0,'partial':0.0,'ent_type':0.0}
    for score in scores:
        prec = results[score]['precision']
        rec = results[score]['recall']
        # Guard the harmonic mean: with precision == recall == 0 the denominator is 0.
        scores[score] = 2*prec*rec/(prec+rec) if (prec+rec) else 0.0

    print('|                                         | Strict  | Exact  | Partial  | Type    |')
    print('|-----------------------------------------|---------|--------|----------|---------|')
    print(f"| {tool_name:40}| {scores['strict']:.4}  | {scores['exact']:.4} | {scores['partial']:.4}   | {scores['ent_type']:.4}  |")

In [760]:
def print_results_unlabeled(tool_name, results):
    """Print a one-row table of precision, recall and F1 for the 'partial' (weak)
    and 'exact' (strong) evaluation schemas.

    Parameters:
    - tool_name: label printed in the first column
    - results: dict with keys 'exact' and 'partial'; each value is a dict holding
      at least 'precision' and 'recall'

    Returns: None (the table is printed).

    F1 is 2*P*R/(P+R). It is reported as 0.0 when P + R is 0, instead of
    raising ZeroDivisionError.
    """
    scores = {'exact':0.0,'partial':0.0}
    for score in scores:
        # float(): the '.4' format spec below raises ValueError for int values
        # (e.g. a plain 0), so normalise before formatting.
        prec = float(results[score]['precision'])
        rec = float(results[score]['recall'])
        # Guard the harmonic mean: with precision == recall == 0 the denominator is 0.
        f1 = 2*prec*rec/(prec+rec) if (prec+rec) else 0.0
        scores[score] = {'prec':prec, 'rec':rec, 'f1':f1}

    print('|                                         | Precision (Weak) | Recall (Weak) | F1 (Weak)     | Precision (Strong) | Recall (Strong) | F1 (Strong) |')
    print('|-----------------------------------------|-----------|---------|---------|------------------|---------------|-----------|')
    print(f"| {tool_name:40}| {scores['partial']['prec']:.4}  | {scores['partial']['rec']:.4} | {scores['partial']['f1']:.4}   | {scores['exact']['prec']:.4}  | {scores['exact']['rec']:.4}  | {scores['exact']['f1']:.4}  |")

In [797]:
print_results_unlabeled('nltk (uppercased)',results)

|                                         | Precision (Weak) | Recall (Weak) | F1 (Weak)     | Precision (Strong) | Recall (Strong) | F1 (Strong) |
|-----------------------------------------|-----------|---------|---------|------------------|---------------|-----------|
| nltk (uppercased)                       | 0.3996  | 0.3533 | 0.375   | 0.286  | 0.2529  | 0.2684  |
