In [1]:
import nltk
import sklearn_crfsuite

from copy import deepcopy
from collections import defaultdict

from sklearn_crfsuite import metrics
from ner_evaluation import collect_named_entities
from ner_evaluation import compute_metrics
from ner_evaluation import compute_metrics_by_type

## Train a CRF on the CoNLL 2002 NER Spanish data

In [2]:
# Query the file ids of the CoNLL 2002 corpus reader. The return value is not
# assigned, and it is not shown either because this is not the last expression
# of the cell.
nltk.corpus.conll2002.fileids()
# Materialise the IOB-tagged sentences of the Spanish training file and of the
# 'esp.testb' file as lists. The helpers below unpack every token of a sentence
# as a (token, postag, label) triple.
train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))
test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))

In [3]:
def word2features(sent, i):
    """Return the CRF feature dict for the token at position ``i`` of ``sent``.

    Every item of ``sent`` is read positionally: ``[0]`` is the word and
    ``[1]`` is its POS tag. The result always contains the token's own
    features. When a previous / next token exists its features are added
    under the ``-1:`` / ``+1:`` prefixes; otherwise the ``BOS`` / ``EOS``
    flag is set to ``True`` instead.
    """
    token, tag = sent[i][0], sent[i][1]

    # Features of the token itself.
    features = {'bias': 1.0}
    features['word.lower()'] = token.lower()
    features['word[-3:]'] = token[-3:]
    features['word[-2:]'] = token[-2:]
    features['word.isupper()'] = token.isupper()
    features['word.istitle()'] = token.istitle()
    features['word.isdigit()'] = token.isdigit()
    features['postag'] = tag
    features['postag[:2]'] = tag[:2]

    # Left context, or the beginning-of-sentence marker.
    if i <= 0:
        features['BOS'] = True
    else:
        prev_token, prev_tag = sent[i - 1][0], sent[i - 1][1]
        features['-1:word.lower()'] = prev_token.lower()
        features['-1:word.istitle()'] = prev_token.istitle()
        features['-1:word.isupper()'] = prev_token.isupper()
        features['-1:postag'] = prev_tag
        features['-1:postag[:2]'] = prev_tag[:2]

    # Right context, or the end-of-sentence marker.
    if i >= len(sent) - 1:
        features['EOS'] = True
    else:
        next_token, next_tag = sent[i + 1][0], sent[i + 1][1]
        features['+1:word.lower()'] = next_token.lower()
        features['+1:word.istitle()'] = next_token.istitle()
        features['+1:word.isupper()'] = next_token.isupper()
        features['+1:postag'] = next_tag
        features['+1:postag[:2]'] = next_tag[:2]

    return features


def sent2features(sent):
    """Return one ``word2features`` dict per position of ``sent``, in order."""
    per_token = []
    for position in range(len(sent)):
        per_token.append(word2features(sent, position))
    return per_token

def sent2labels(sent):
    """Return the third field (the label) of every 3-item entry of ``sent``."""
    picked = []
    for _word, _pos, iob_label in sent:
        picked.append(iob_label)
    return picked

def sent2tokens(sent):
    """Return the first field (the token) of every 3-item entry of ``sent``."""
    picked = []
    for word, _pos, _iob_label in sent:
        picked.append(word)
    return picked

## Feature Extraction

In [4]:
%%time
X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

CPU times: user 1.09 s, sys: 88.7 ms, total: 1.18 s
Wall time: 1.18 s


## Training

In [5]:
%%time
# Configure the CRF estimator and fit it on the training features / labels.
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',  # training algorithm handed to the estimator
    c1=0.1,  # L1 regularisation coefficient (per the sklearn-crfsuite docs)
    c2=0.1,  # L2 regularisation coefficient (per the sklearn-crfsuite docs)
    max_iterations=100,  # upper bound on optimiser iterations
    all_possible_transitions=True  # per the docs: also generate transition features for label pairs not seen in training
)
crf.fit(X_train, y_train)

CPU times: user 34.7 s, sys: 141 ms, total: 34.8 s
Wall time: 34.8 s


## Performance per label type per token

In [6]:
y_pred = crf.predict(X_test)

# Score entity labels only: start from every class the model knows and drop 'O'.
labels = list(crf.classes_)
labels.remove('O')

# Order by everything after the first character (the entity type), then by the
# first character (B / I), so the B- and I- rows of one type sit together.
sorted_labels = sorted(labels, key=lambda tag: (tag[1:], tag[0]))

print(
    sklearn_crfsuite.metrics.flat_classification_report(
        y_test,
        y_pred,
        labels=sorted_labels,
        digits=3,
    )
)

             precision    recall  f1-score   support

      B-LOC      0.810     0.784     0.797      1084
      I-LOC      0.690     0.637     0.662       325
     B-MISC      0.731     0.569     0.640       339
     I-MISC      0.699     0.589     0.639       557
      B-ORG      0.807     0.832     0.820      1400
      I-ORG      0.852     0.786     0.818      1104
      B-PER      0.850     0.884     0.867       735
      I-PER      0.893     0.943     0.917       634

avg / total      0.809     0.787     0.796      6178



## Performance over full named entities

In [7]:
# Gold label sequence of every test sentence: the third field (index 2) of each
# token. Built with a comprehension so that no loop variable is rebound
# mid-iteration or left behind in the notebook namespace.
test_sents_labels = [[token[2] for token in sentence] for sentence in test_sents]

In [8]:
# Inspect a single test sentence: collect the entities found in its gold label
# sequence (`true`) and in its predicted label sequence (`pred`). As the
# outputs of the next cells show, each entity is an
# Entity(e_type, start_offset, end_offset) record.
index = 2
true = collect_named_entities(test_sents_labels[index])
pred = collect_named_entities(y_pred[index])

In [9]:
true

[Entity(e_type='MISC', start_offset=12, end_offset=12),
 Entity(e_type='LOC', start_offset=15, end_offset=15),
 Entity(e_type='PER', start_offset=37, end_offset=39),
 Entity(e_type='ORG', start_offset=45, end_offset=46)]

In [10]:
pred

[Entity(e_type='MISC', start_offset=12, end_offset=12),
 Entity(e_type='LOC', start_offset=15, end_offset=15),
 Entity(e_type='PER', start_offset=37, end_offset=39),
 Entity(e_type='LOC', start_offset=45, end_offset=46)]

In [11]:
# Compare gold and predicted entities of the selected sentence. The result (see
# output) has one entry per evaluation schema -- 'strict', 'exact_matching',
# 'partial_matching', 'ent_type' -- each holding counts plus precision/recall.
compute_metrics(true, pred)

{'ent_type': {'actual': 4,
  'correct': 3,
  'incorrect': 1,
  'missed': 0,
  'partial': 0,
  'possible': 4,
  'precision': 0.75,
  'recall': 0.75,
  'spurius': 0},
 'exact_matching': {'actual': 4,
  'correct': 4,
  'incorrect': 0,
  'missed': 0,
  'partial': 0,
  'possible': 4,
  'precision': 1.0,
  'recall': 1.0,
  'spurius': 0},
 'partial_matching': {'actual': 4,
  'correct': 4,
  'incorrect': 0,
  'missed': 0,
  'partial': 0,
  'possible': 4,
  'precision': 1.0,
  'recall': 1.0,
  'spurius': 0},
 'strict': {'actual': 4,
  'correct': 3,
  'incorrect': 1,
  'missed': 0,
  'partial': 0,
  'possible': 4,
  'precision': 0.75,
  'recall': 0.75,
  'spurius': 0}}

In [12]:
# NOTE(review): `to_test` is not referenced by any of the visible cells -- the
# cells below hard-code `index = 2`. Presumably a list of sentence indices to
# inspect; confirm it is still needed or remove it.
to_test = [2,4,12,14]

In [13]:
index = 2
true_named_entities_type = defaultdict(list)
pred_named_entities_type = defaultdict(list)

# Group the gold and the predicted entities of sentence `index` by entity type.
# The loop targets are deliberately not called `true` / `pred`: a `for` target
# stays bound after the loop ends, so reusing those names would overwrite the
# entity lists stored under them earlier in the notebook with a single Entity,
# and re-running `compute_metrics(true, pred)` would then use the wrong inputs.
for true_entity in collect_named_entities(test_sents_labels[index]):
    true_named_entities_type[true_entity.e_type].append(true_entity)

for pred_entity in collect_named_entities(y_pred[index]):
    pred_named_entities_type[pred_entity.e_type].append(pred_entity)

In [14]:
true_named_entities_type

defaultdict(list,
            {'LOC': [Entity(e_type='LOC', start_offset=15, end_offset=15)],
             'MISC': [Entity(e_type='MISC', start_offset=12, end_offset=12)],
             'ORG': [Entity(e_type='ORG', start_offset=45, end_offset=46)],
             'PER': [Entity(e_type='PER', start_offset=37, end_offset=39)]})

In [15]:
pred_named_entities_type

defaultdict(list,
            {'LOC': [Entity(e_type='LOC', start_offset=15, end_offset=15),
              Entity(e_type='LOC', start_offset=45, end_offset=46)],
             'MISC': [Entity(e_type='MISC', start_offset=12, end_offset=12)],
             'PER': [Entity(e_type='PER', start_offset=37, end_offset=39)]})

In [16]:
true_named_entities_type['LOC']

[Entity(e_type='LOC', start_offset=15, end_offset=15)]

In [17]:
pred_named_entities_type['LOC']

[Entity(e_type='LOC', start_offset=15, end_offset=15),
 Entity(e_type='LOC', start_offset=45, end_offset=46)]

In [18]:
# Score the LOC entities in isolation. The prediction contains a second LOC at
# offsets 45-46 whose gold entity is an ORG; because the gold ORG was filtered
# out here, that prediction has no counterpart and is counted as 'spurius'
# (see output) rather than as a type error.
compute_metrics(true_named_entities_type['LOC'], pred_named_entities_type['LOC'])

{'ent_type': {'actual': 2,
  'correct': 1,
  'incorrect': 0,
  'missed': 0,
  'partial': 0,
  'possible': 1,
  'precision': 0.5,
  'recall': 1.0,
  'spurius': 1},
 'exact_matching': {'actual': 2,
  'correct': 1,
  'incorrect': 0,
  'missed': 0,
  'partial': 0,
  'possible': 1,
  'precision': 0.5,
  'recall': 1.0,
  'spurius': 1},
 'partial_matching': {'actual': 2,
  'correct': 1,
  'incorrect': 0,
  'missed': 0,
  'partial': 0,
  'possible': 1,
  'precision': 0.5,
  'recall': 1.0,
  'spurius': 1},
 'strict': {'actual': 2,
  'correct': 1,
  'incorrect': 0,
  'missed': 0,
  'partial': 0,
  'possible': 1,
  'precision': 0.5,
  'recall': 1.0,
  'spurius': 1}}

## Results over all test sentences

In [19]:
# Zeroed counters; one independent copy is made per evaluation schema below.
# NOTE(review): this rebinds `metrics`, which the import cell bound to the
# `sklearn_crfsuite.metrics` module. The name is kept because other cells may
# rely on it -- confirm, and consider renaming it.
metrics = {'correct': 0, 'incorrect': 0, 'partial': 0, 'missed': 0, 'spurius': 0, 'possible': 0, 'actual': 0}
results = {'strict': deepcopy(metrics),
           'exact_matching': deepcopy(metrics),
           'partial_matching': deepcopy(metrics),
           'ent_type': deepcopy(metrics)
          }

# Sum the per-sentence counts over the whole test set. Only the keys of
# `metrics` are accumulated, so the per-sentence 'precision' / 'recall' values
# returned by compute_metrics are intentionally not summed.
# The loop targets are not called `true` / `pred`, so the entity lists stored
# under those names earlier in the notebook are not overwritten by the last
# label sequences seen here.
for true_labels, pred_labels in zip(test_sents_labels, y_pred):
    tmp_results = compute_metrics(collect_named_entities(true_labels),
                                  collect_named_entities(pred_labels))
    for eval_schema in results:
        for metric in metrics:
            results[eval_schema][metric] += tmp_results[eval_schema][metric]

In [20]:
results

{'ent_type': {'actual': 3518,
  'correct': 2909,
  'incorrect': 564,
  'missed': 106,
  'partial': 0,
  'possible': 3579,
  'spurius': 45},
 'exact_matching': {'actual': 3518,
  'correct': 3274,
  'incorrect': 199,
  'missed': 106,
  'partial': 0,
  'possible': 3579,
  'spurius': 45},
 'partial_matching': {'actual': 3518,
  'correct': 3274,
  'incorrect': 0,
  'missed': 106,
  'partial': 199,
  'possible': 3579,
  'spurius': 45},
 'strict': {'actual': 3518,
  'correct': 2779,
  'incorrect': 694,
  'missed': 106,
  'partial': 0,
  'possible': 3579,
  'spurius': 45}}

## Results over all test sentences, by entity type

In [118]:
# Entity types to report on separately.
entity_types = ['LOC', 'PER', 'MISC', 'ORG']
# NOTE(review): `compute_results` is neither defined nor imported in the visible
# cells (the import cell brings in `compute_metrics_by_type`, which is never
# used), and the execution count jumps from 20 to 118. This cell presumably
# depends on a definition from a deleted or out-of-order cell and would raise
# NameError under Restart & Run All -- confirm and restore the definition.
# The outputs below show the result indexed as all_results[schema][entity_type].
all_results = compute_results(test_sents_labels, y_pred, entity_types)

In [119]:
all_results['ent_type']

{'LOC': {'correct': 863,
  'incorrect': 0,
  'missed': 124,
  'partial': 0,
  'spurius': 66},
 'MISC': {'correct': 212,
  'incorrect': 0,
  'missed': 43,
  'partial': 0,
  'spurius': 7},
 'ORG': {'correct': 1183,
  'incorrect': 0,
  'missed': 166,
  'partial': 0,
  'spurius': 153},
 'PER': {'correct': 657,
  'incorrect': 0,
  'missed': 46,
  'partial': 0,
  'spurius': 17}}

In [120]:
all_results['strict']

{'LOC': {'correct': 840,
  'incorrect': 23,
  'missed': 124,
  'partial': 0,
  'spurius': 66},
 'MISC': {'correct': 173,
  'incorrect': 39,
  'missed': 43,
  'partial': 0,
  'spurius': 7},
 'ORG': {'correct': 1120,
  'incorrect': 63,
  'missed': 166,
  'partial': 0,
  'spurius': 153},
 'PER': {'correct': 646,
  'incorrect': 11,
  'missed': 46,
  'partial': 0,
  'spurius': 17}}

In [121]:
all_results['exact_matching']

{'LOC': {'correct': 840,
  'incorrect': 23,
  'missed': 124,
  'partial': 0,
  'spurius': 66},
 'MISC': {'correct': 173,
  'incorrect': 39,
  'missed': 43,
  'partial': 0,
  'spurius': 7},
 'ORG': {'correct': 1120,
  'incorrect': 63,
  'missed': 166,
  'partial': 0,
  'spurius': 153},
 'PER': {'correct': 646,
  'incorrect': 11,
  'missed': 46,
  'partial': 0,
  'spurius': 17}}

In [122]:
all_results['partial_matching']

{'LOC': {'correct': 840,
  'incorrect': 0,
  'missed': 124,
  'partial': 23,
  'spurius': 66},
 'MISC': {'correct': 173,
  'incorrect': 0,
  'missed': 43,
  'partial': 39,
  'spurius': 7},
 'ORG': {'correct': 1120,
  'incorrect': 0,
  'missed': 166,
  'partial': 63,
  'spurius': 153},
 'PER': {'correct': 646,
  'incorrect': 0,
  'missed': 46,
  'partial': 11,
  'spurius': 17}}