In [1]:
import nltk
import sklearn_crfsuite

from copy import deepcopy
from collections import defaultdict

from sklearn_crfsuite.metrics import flat_classification_report

from ner_evaluation.ner_eval import collect_named_entities
from ner_evaluation.ner_eval import compute_metrics
from ner_evaluation.ner_eval import compute_precision_recall_wrapper

## Train a CRF on the CoNLL 2002 NER Spanish data

In [2]:
# Spanish CoNLL 2002 splits: 'esp.train' to fit the CRF, 'esp.testb' for the final evaluation.
# Each sentence is a list of (word, postag, iob_label) triples (see sent2labels below).
train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))
test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))
# Fail early with a clear message instead of a confusing error during feature extraction/training.
assert train_sents and test_sents, "CoNLL 2002 Spanish splits are empty - is the nltk 'conll2002' corpus installed?"

In [3]:
def word2features(sent, i):
    """Build the CRF feature dict for token ``i`` of ``sent``.

    ``sent`` is a sequence of tuples whose first two fields are the word and
    its POS tag. The returned dict describes the token itself plus its
    immediate left (``-1:`` keys) and right (``+1:`` keys) neighbours. When a
    neighbour does not exist, ``BOS`` (beginning of sentence) or ``EOS`` (end
    of sentence) is set to True instead.
    """
    def neighbour(prefix, j):
        # The same five lexical/POS features for either neighbour, keyed with a position prefix.
        nb_word, nb_tag = sent[j][0], sent[j][1]
        return {
            prefix + 'word.lower()': nb_word.lower(),
            prefix + 'word.istitle()': nb_word.istitle(),
            prefix + 'word.isupper()': nb_word.isupper(),
            prefix + 'postag': nb_tag,
            prefix + 'postag[:2]': nb_tag[:2],
        }

    word, postag = sent[i][0], sent[i][1]
    feats = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
        'postag[:2]': postag[:2],
    }

    if i <= 0:
        feats['BOS'] = True
    else:
        feats.update(neighbour('-1:', i - 1))

    if i >= len(sent) - 1:
        feats['EOS'] = True
    else:
        feats.update(neighbour('+1:', i + 1))

    return feats


def sent2features(sent):
    """Return one ``word2features`` dict per token of ``sent``, in token order."""
    features = []
    for position in range(len(sent)):
        features.append(word2features(sent, position))
    return features


def sent2labels(sent):
    """Return the label (third field) of every ``(token, postag, label)`` triple in ``sent``."""
    labels = []
    for _word, _tag, iob in sent:
        labels.append(iob)
    return labels


def sent2tokens(sent):
    """Return the surface token (first field) of every ``(token, postag, label)`` triple in ``sent``."""
    return list(word for word, _tag, _iob in sent)

## Feature Extraction

In [4]:
%%time
# Featurise each split once; the target sequences are the per-token labels.
X_train = list(map(sent2features, train_sents))
y_train = list(map(sent2labels, train_sents))

X_test = list(map(sent2features, test_sents))
y_test = list(map(sent2labels, test_sents))

CPU times: user 761 ms, sys: 48.4 ms, total: 809 ms
Wall time: 809 ms


## Training

In [5]:
%%time
# L-BFGS-trained linear-chain CRF; c1/c2 are the regularisation coefficients
# (see the sklearn-crfsuite CRF docs), training is capped at 100 iterations.
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs', max_iterations=100,
    c1=0.1, c2=0.1,
    all_possible_transitions=True,
)
crf.fit(X_train, y_train)

CPU times: user 28.7 s, sys: 36.4 ms, total: 28.7 s
Wall time: 28.7 s


## Token-level performance per label

In [25]:
y_pred = crf.predict(X_test)

# Token-level report over entity tags only: the outside tag 'O' is excluded from scoring.
labels = list(crf.classes_)
labels.remove('O')

# Order by (entity type, B/I prefix) so e.g. B-LOC and I-LOC appear as adjacent rows.
sorted_labels = sorted(labels, key=lambda tag: (tag[1:], tag[0]))
print(
    flat_classification_report(
        y_test,
        y_pred,
        labels=sorted_labels,
        digits=3,
    )
)

              precision    recall  f1-score   support

       B-LOC      0.810     0.784     0.797      1084
       I-LOC      0.690     0.637     0.662       325
      B-MISC      0.731     0.569     0.640       339
      I-MISC      0.699     0.589     0.639       557
       B-ORG      0.807     0.832     0.820      1400
       I-ORG      0.852     0.786     0.818      1104
       B-PER      0.850     0.884     0.867       735
       I-PER      0.893     0.943     0.917       634

   micro avg      0.813     0.787     0.799      6178
   macro avg      0.791     0.753     0.770      6178
weighted avg      0.809     0.787     0.796      6178



## Entity-level performance (full named entities)

In [26]:
# Gold label sequence of every test sentence (third field of each (word, postag, label) triple).
# Built with sent2labels, like y_test above, so the two hold the same label lists.
test_sents_labels = [sent2labels(sent) for sent in test_sents]

In [27]:
index = 2
# Entity spans (type + start/end token offsets) of sentence `index`: gold vs. CRF prediction.
true, pred = (
    collect_named_entities(test_sents_labels[index]),
    collect_named_entities(y_pred[index]),
)

In [28]:
# Gold entities of sentence `index` (type plus start/end token offsets).
true

[Entity(e_type='MISC', start_offset=12, end_offset=12),
 Entity(e_type='LOC', start_offset=15, end_offset=15),
 Entity(e_type='PER', start_offset=37, end_offset=39),
 Entity(e_type='ORG', start_offset=45, end_offset=46)]

In [29]:
# Predicted entities of sentence `index`; the span at offsets 45-46 is typed LOC here but ORG in the gold list.
pred

[Entity(e_type='MISC', start_offset=12, end_offset=12),
 Entity(e_type='LOC', start_offset=15, end_offset=15),
 Entity(e_type='PER', start_offset=37, end_offset=39),
 Entity(e_type='LOC', start_offset=45, end_offset=46)]

In [30]:
# NOTE(review): this cell recomputes exactly what In [27] computed (same index, same inputs);
# it can be deleted without changing any output.
index = 2
true, pred = (
    collect_named_entities(test_sents_labels[index]),
    collect_named_entities(y_pred[index]),
)

In [31]:
# Gold entities of sentence `index` (same as the earlier display).
true

[Entity(e_type='MISC', start_offset=12, end_offset=12),
 Entity(e_type='LOC', start_offset=15, end_offset=15),
 Entity(e_type='PER', start_offset=37, end_offset=39),
 Entity(e_type='ORG', start_offset=45, end_offset=46)]

In [32]:
# Predicted entities of sentence `index` (same as the earlier display).
pred

[Entity(e_type='MISC', start_offset=12, end_offset=12),
 Entity(e_type='LOC', start_offset=15, end_offset=15),
 Entity(e_type='PER', start_offset=37, end_offset=39),
 Entity(e_type='LOC', start_offset=45, end_offset=46)]

In [33]:
# Score sentence `index` under the strict / ent_type / partial / exact schemes.
# Returns a pair: (overall counts, counts per entity type). Precision/recall show up
# as 0 in this raw output; they are filled in later with compute_precision_recall_wrapper.
# The LOC-vs-ORG type error on offsets 45-46 is counted as 'incorrect' under strict/ent_type.
compute_metrics(true, pred, ['LOC', 'MISC', 'PER', 'ORG'])

({'strict': {'correct': 3,
   'incorrect': 1,
   'partial': 0,
   'missed': 0,
   'spurious': 0,
   'precision': 0,
   'recall': 0,
   'actual': 4,
   'possible': 4},
  'ent_type': {'correct': 3,
   'incorrect': 1,
   'partial': 0,
   'missed': 0,
   'spurious': 0,
   'precision': 0,
   'recall': 0,
   'actual': 4,
   'possible': 4},
  'partial': {'correct': 4,
   'incorrect': 0,
   'partial': 0,
   'missed': 0,
   'spurious': 0,
   'precision': 0,
   'recall': 0,
   'actual': 4,
   'possible': 4},
  'exact': {'correct': 4,
   'incorrect': 0,
   'partial': 0,
   'missed': 0,
   'spurious': 0,
   'precision': 0,
   'recall': 0,
   'actual': 4,
   'possible': 4}},
 {'LOC': {'strict': {'correct': 1,
    'incorrect': 0,
    'partial': 0,
    'missed': 0,
    'spurious': 0,
    'precision': 0,
    'recall': 0,
    'actual': 1,
    'possible': 1},
   'ent_type': {'correct': 1,
    'incorrect': 0,
    'partial': 0,
    'missed': 0,
    'spurious': 0,
    'precision': 0,
    'recall': 0,
    '

In [34]:
# Presumably sentence indices earmarked for manual inspection (index 2 is the one examined above).
# NOTE(review): not referenced by any visible cell - use it or delete it.
to_test = [2, 4, 12, 14]

In [35]:
index = 2
# Group the gold and predicted entities of sentence `index` by entity type.
# The loop variable is deliberately not `true`/`pred`: those names hold the entity
# lists built in In [27], and rebinding them to a single Entity would break
# re-running the compute_metrics cell.
true_named_entities_type = defaultdict(list)
pred_named_entities_type = defaultdict(list)

for entity in collect_named_entities(test_sents_labels[index]):
    true_named_entities_type[entity.e_type].append(entity)

for entity in collect_named_entities(y_pred[index]):
    pred_named_entities_type[entity.e_type].append(entity)

In [36]:
# Gold entities of sentence `index`, grouped by entity type.
true_named_entities_type

defaultdict(list,
            {'MISC': [Entity(e_type='MISC', start_offset=12, end_offset=12)],
             'LOC': [Entity(e_type='LOC', start_offset=15, end_offset=15)],
             'PER': [Entity(e_type='PER', start_offset=37, end_offset=39)],
             'ORG': [Entity(e_type='ORG', start_offset=45, end_offset=46)]})

In [37]:
# Predicted entities grouped by type: there is no ORG group, because the gold ORG span at 45-46 was predicted as LOC.
pred_named_entities_type

defaultdict(list,
            {'MISC': [Entity(e_type='MISC', start_offset=12, end_offset=12)],
             'LOC': [Entity(e_type='LOC', start_offset=15, end_offset=15),
              Entity(e_type='LOC', start_offset=45, end_offset=46)],
             'PER': [Entity(e_type='PER', start_offset=37, end_offset=39)]})

In [38]:
# Gold LOC entities only.
true_named_entities_type['LOC']

[Entity(e_type='LOC', start_offset=15, end_offset=15)]

In [39]:
# Predicted LOC entities only (includes the mistyped span at 45-46).
pred_named_entities_type['LOC']

[Entity(e_type='LOC', start_offset=15, end_offset=15),
 Entity(e_type='LOC', start_offset=45, end_offset=46)]

In [40]:
# Score only the LOC entities of sentence `index`. The predicted LOC at offsets 45-46 has
# no counterpart in the gold LOC list (gold labels it ORG), so here it is counted as
# 'spurious' rather than 'incorrect' as in the full-sentence call above.
compute_metrics(true_named_entities_type['LOC'], pred_named_entities_type['LOC'], ['LOC', 'MISC', 'PER', 'ORG'])

({'strict': {'correct': 1,
   'incorrect': 0,
   'partial': 0,
   'missed': 0,
   'spurious': 1,
   'precision': 0,
   'recall': 0,
   'actual': 2,
   'possible': 1},
  'ent_type': {'correct': 1,
   'incorrect': 0,
   'partial': 0,
   'missed': 0,
   'spurious': 1,
   'precision': 0,
   'recall': 0,
   'actual': 2,
   'possible': 1},
  'partial': {'correct': 1,
   'incorrect': 0,
   'partial': 0,
   'missed': 0,
   'spurious': 1,
   'precision': 0,
   'recall': 0,
   'actual': 2,
   'possible': 1},
  'exact': {'correct': 1,
   'incorrect': 0,
   'partial': 0,
   'missed': 0,
   'spurious': 1,
   'precision': 0,
   'recall': 0,
   'actual': 2,
   'possible': 1}},
 {'LOC': {'strict': {'correct': 1,
    'incorrect': 0,
    'partial': 0,
    'missed': 0,
    'spurious': 1,
    'precision': 0,
    'recall': 0,
    'actual': 2,
    'possible': 1},
   'ent_type': {'correct': 1,
    'incorrect': 0,
    'partial': 0,
    'missed': 0,
    'spurious': 1,
    'precision': 0,
    'recall': 0,
    '

## Results over all test sentences

In [42]:
# Aggregate entity-level metrics over every test sentence, overall and per entity type.

# Entity types passed to compute_metrics; also the key order of the per-type results
# (the same order the Evaluator cross-check further down reports).
entity_types = ['LOC', 'MISC', 'PER', 'ORG']

metrics_results = {'correct': 0, 'incorrect': 0, 'partial': 0,
                   'missed': 0, 'spurious': 0, 'possible': 0, 'actual': 0, 'precision': 0, 'recall': 0}

# Only the raw counts are additive across sentences. Precision and recall are ratios,
# so they are derived once from the final totals instead of being summed per sentence.
count_keys = [key for key in metrics_results if key not in ('precision', 'recall')]

# overall results, one counter dict per evaluation schema
results = {schema: deepcopy(metrics_results) for schema in ['strict', 'ent_type', 'partial', 'exact']}

# results aggregated by entity type
evaluation_agg_entities_type = {e_type: deepcopy(results) for e_type in entity_types}

for true_ents, pred_ents in zip(test_sents_labels, y_pred):

    # compute results for one sentence
    tmp_results, tmp_agg_results = compute_metrics(
        collect_named_entities(true_ents), collect_named_entities(pred_ents), entity_types
    )

    # aggregate overall counts
    for eval_schema in results:
        for metric in count_keys:
            results[eval_schema][metric] += tmp_results[eval_schema][metric]

    # aggregate counts by entity type
    for e_type in entity_types:
        for eval_schema in tmp_agg_results[e_type]:
            for metric in count_keys:
                evaluation_agg_entities_type[e_type][eval_schema][metric] += tmp_agg_results[e_type][eval_schema][metric]

# Precision/recall from the final totals, computed once here rather than after every sentence.
results = compute_precision_recall_wrapper(results)
evaluation_agg_entities_type = {
    e_type: compute_precision_recall_wrapper(agg) for e_type, agg in evaluation_agg_entities_type.items()
}

# NOTE(review): in the per-type outputs, 'spurious' is 139 for every entity type, equal to the
# overall total. Spurious predictions do not appear to be split by type inside compute_metrics.
# This inflates per-type 'actual' (e.g. PER: 651 + 67 + 139 = 857) and so deflates per-type
# precision. Verify against ner_eval before relying on per-type precision.
                
    

In [43]:
# Corpus-level counts plus precision/recall for each evaluation schema, from the manual loop above.
results

{'ent_type': {'correct': 2860,
  'incorrect': 523,
  'partial': 0,
  'missed': 176,
  'spurious': 139,
  'possible': 3559,
  'actual': 3522,
  'precision': 0.8120386144236229,
  'recall': 0.8035965158752458},
 'partial': {'correct': 3278,
  'incorrect': 0,
  'partial': 105,
  'missed': 176,
  'spurious': 139,
  'possible': 3559,
  'actual': 3522,
  'precision': 0.9456274843838728,
  'recall': 0.9357965720708064},
 'strict': {'correct': 2783,
  'incorrect': 600,
  'partial': 0,
  'missed': 176,
  'spurious': 139,
  'possible': 3559,
  'actual': 3522,
  'precision': 0.7901760363429869,
  'recall': 0.78196122506322},
 'exact': {'correct': 3278,
  'incorrect': 105,
  'partial': 0,
  'missed': 176,
  'spurious': 139,
  'possible': 3559,
  'actual': 3522,
  'precision': 0.9307211811470755,
  'recall': 0.9210452374262433}}

In [44]:
# Per-entity-type counts plus precision/recall from the manual loop.
# NOTE(review): 'spurious' equals the overall total (139) for each type shown, so per-type
# 'actual' includes every spurious prediction regardless of type - verify before trusting precision.
evaluation_agg_entities_type

{'PER': {'ent_type': {'correct': 651,
   'incorrect': 67,
   'partial': 0,
   'missed': 17,
   'spurious': 139,
   'possible': 735,
   'actual': 857,
   'precision': 0.7596266044340724,
   'recall': 0.8857142857142857},
  'partial': {'correct': 711,
   'incorrect': 0,
   'partial': 7,
   'missed': 17,
   'spurious': 139,
   'possible': 735,
   'actual': 857,
   'precision': 0.8337222870478413,
   'recall': 0.972108843537415},
  'strict': {'correct': 646,
   'incorrect': 72,
   'partial': 0,
   'missed': 17,
   'spurious': 139,
   'possible': 735,
   'actual': 857,
   'precision': 0.7537922987164527,
   'recall': 0.8789115646258503},
  'exact': {'correct': 711,
   'incorrect': 7,
   'partial': 0,
   'missed': 17,
   'spurious': 139,
   'possible': 735,
   'actual': 857,
   'precision': 0.8296382730455076,
   'recall': 0.9673469387755103}},
 'LOC': {'ent_type': {'correct': 855,
   'incorrect': 180,
   'partial': 0,
   'missed': 49,
   'spurious': 139,
   'possible': 1084,
   'actual': 11

In [45]:
from ner_evaluation.ner_eval import Evaluator

In [46]:
# Cross-check: run the library's Evaluator on the same gold tags, predicted tags and entity types.
evaluator = Evaluator(test_sents_labels, y_pred, ['LOC', 'MISC', 'PER', 'ORG'])

In [47]:
# NOTE: rebinds `results`, which previously held the manual loop's output; the values shown
# below match that output, so the manual aggregation and the Evaluator agree.
results, results_agg = evaluator.evaluate()

2019-03-12 12:00:31 root INFO: Imported 1517 predictions for 1517 true examples


In [48]:
# Evaluator's overall results - identical to the manual aggregation's corpus-level results.
results

{'ent_type': {'correct': 2860,
  'incorrect': 523,
  'partial': 0,
  'missed': 176,
  'spurious': 139,
  'possible': 3559,
  'actual': 3522,
  'precision': 0.8120386144236229,
  'recall': 0.8035965158752458},
 'partial': {'correct': 3278,
  'incorrect': 0,
  'partial': 105,
  'missed': 176,
  'spurious': 139,
  'possible': 3559,
  'actual': 3522,
  'precision': 0.9456274843838728,
  'recall': 0.9357965720708064},
 'strict': {'correct': 2783,
  'incorrect': 600,
  'partial': 0,
  'missed': 176,
  'spurious': 139,
  'possible': 3559,
  'actual': 3522,
  'precision': 0.7901760363429869,
  'recall': 0.78196122506322},
 'exact': {'correct': 3278,
  'incorrect': 105,
  'partial': 0,
  'missed': 176,
  'spurious': 139,
  'possible': 3559,
  'actual': 3522,
  'precision': 0.9307211811470755,
  'recall': 0.9210452374262433}}

In [49]:
# Evaluator's per-entity-type results (compare with evaluation_agg_entities_type).
results_agg

{'LOC': {'ent_type': {'correct': 855,
   'incorrect': 180,
   'partial': 0,
   'missed': 49,
   'spurious': 139,
   'possible': 1084,
   'actual': 1174,
   'precision': 0.7282793867120954,
   'recall': 0.7887453874538746},
  'partial': {'correct': 1016,
   'incorrect': 0,
   'partial': 19,
   'missed': 49,
   'spurious': 139,
   'possible': 1084,
   'actual': 1174,
   'precision': 0.8735093696763203,
   'recall': 0.9460332103321033},
  'strict': {'correct': 844,
   'incorrect': 191,
   'partial': 0,
   'missed': 49,
   'spurious': 139,
   'possible': 1084,
   'actual': 1174,
   'precision': 0.7189097103918228,
   'recall': 0.7785977859778598},
  'exact': {'correct': 1016,
   'incorrect': 19,
   'partial': 0,
   'missed': 49,
   'spurious': 139,
   'possible': 1084,
   'actual': 1174,
   'precision': 0.8654173764906303,
   'recall': 0.9372693726937269}},
 'MISC': {'ent_type': {'correct': 200,
   'incorrect': 89,
   'partial': 0,
   'missed': 51,
   'spurious': 139,
   'possible': 340,
 