In [None]:
from itertools import chain

import nltk
import sklearn
import scipy.stats
from sklearn.metrics import make_scorer
from sklearn.model_selection import RandomizedSearchCV
from sklearn.grid_search import GridSearchCV

import scipy.stats

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics

import numpy as np

In [52]:
%%time

import pickle

# Load the pre-tokenised sentences. The file is opened in a `with` block so
# the handle is closed deterministically instead of being left to the GC.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles from a trusted source.
with open('conll2002-kaggle-2000.pickle', 'rb') as sents_file:
    sents = pickle.load(sents_file)
print(len(sents))

# Ordered (unshuffled) 80/20 split: the first 80% of sentences are used for
# training, the remainder for testing.
split = int(0.8*len(sents))
train_sents = sents[:split]
test_sents = sents[split:]

2000
CPU times: user 60 ms, sys: 0 ns, total: 60 ms
Wall time: 61.2 ms


In [53]:
# Display the first training sentence; the helpers below unpack each element
# as a (token, postag, label) triple.
train_sents[0]

[('Thousands', 'NNS', 'O'),
 ('of', 'IN', 'O'),
 ('demonstrators', 'NNS', 'O'),
 ('have', 'VBP', 'O'),
 ('marched', 'VBN', 'O'),
 ('through', 'IN', 'O'),
 ('London', 'NNP', 'B-geo'),
 ('to', 'TO', 'O'),
 ('protest', 'VB', 'O'),
 ('the', 'DT', 'O'),
 ('war', 'NN', 'O'),
 ('in', 'IN', 'O'),
 ('Iraq', 'NNP', 'B-geo'),
 ('and', 'CC', 'O'),
 ('demand', 'VB', 'O'),
 ('the', 'DT', 'O'),
 ('withdrawal', 'NN', 'O'),
 ('of', 'IN', 'O'),
 ('British', 'JJ', 'B-gpe'),
 ('troops', 'NNS', 'O'),
 ('from', 'IN', 'O'),
 ('that', 'DT', 'O'),
 ('country', 'NN', 'O'),
 ('.', '.', 'O'),
 ('Thousands', 'NNS', 'O'),
 ('of', 'IN', 'O'),
 ('demonstrators', 'NNS', 'O'),
 ('have', 'VBP', 'O'),
 ('marched', 'VBN', 'O'),
 ('through', 'IN', 'O'),
 ('London', 'NNP', 'B-geo'),
 ('to', 'TO', 'O'),
 ('protest', 'VB', 'O'),
 ('the', 'DT', 'O'),
 ('war', 'NN', 'O'),
 ('in', 'IN', 'O'),
 ('Iraq', 'NNP', 'B-geo'),
 ('and', 'CC', 'O'),
 ('demand', 'VB', 'O'),
 ('the', 'DT', 'O'),
 ('withdrawal', 'NN', 'O'),
 ('of', 'IN', 'O'

## Features

Next, define some features. In this example we use word identity, word suffix, word shape and word POS tag; also, some information from nearby words is used. 

This makes a simple baseline, but you certainly can add and remove some features to get (much?) better results - experiment with it.

sklearn-crfsuite (and python-crfsuite) supports several feature formats; here we use feature dicts.

In [54]:
from nltk.stem import WordNetLemmatizer

# Single module-level lemmatizer instance; word2features calls
# lemmatizer.lemmatize(word) to produce the 'lemma' feature.
lemmatizer = WordNetLemmatizer()

def extract_more_features(word):
    """Report whether ``word`` contains a digit and whether it contains an
    uppercase character.

    Returns a ``(has_digits, has_uppers)`` tuple of booleans; both are
    ``False`` for an empty string.
    """
    contains_digit = any(ch.isdigit() for ch in word)
    contains_upper = any(ch.isupper() for ch in word)
    return (contains_digit, contains_upper)

def _neighbor_features(sent, j, prefix):
    """Return the context features of token ``sent[j]``, keyed with ``prefix``
    (e.g. ``'-1'`` or ``'+2'``) followed by ``':'`` and the feature name."""
    neighbor_word = sent[j][0]
    neighbor_postag = sent[j][1]
    return {
        prefix + ':word.lower()': neighbor_word.lower(),
        prefix + ':word.istitle()': neighbor_word.istitle(),
        prefix + ':word.isupper()': neighbor_word.isupper(),
        prefix + ':postag': neighbor_postag,
        prefix + ':postag[:2]': neighbor_postag[:2],
    }


def word2features(sent, i, window=1):
    """Build the feature dict for the token at position ``i`` of ``sent``.

    Parameters
    ----------
    sent : sequence
        Sentence whose items carry the word at index 0 and the POS tag at
        index 1.
    i : int
        Position of the token to describe.
    window : int, optional
        Number of neighbouring tokens on each side that contribute context
        features. The default of 1 yields exactly the '-1:' / '+1:' features;
        ``window=2`` additionally yields '-2:' / '+2:' features for
        neighbours that exist, and so on.

    Returns
    -------
    dict
        Feature name -> value. ``'BOS'`` is set when there is no previous
        token and ``'EOS'`` when there is no next token.
    """
    word = sent[i][0]
    postag = sent[i][1]
    has_digits, has_uppers = extract_more_features(word)

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
        'postag[:2]': postag[:2],
        'hasDigits': has_digits,
        'hasUppers': has_uppers,
        'firstTwoChars': word[:2],
        'length': len(word),
        'lemma': lemmatizer.lemmatize(word)
    }

    # Left context: nearest neighbour first. A token with no predecessor is
    # marked as beginning-of-sentence instead.
    if i <= 0:
        features['BOS'] = True
    for offset in range(1, window + 1):
        if i >= offset:
            features.update(_neighbor_features(sent, i - offset, '-%d' % offset))

    # Right context, mirrored; a token with no successor is marked as
    # end-of-sentence.
    if i >= len(sent) - 1:
        features['EOS'] = True
    for offset in range(1, window + 1):
        if i < len(sent) - offset:
            features.update(_neighbor_features(sent, i + offset, '+%d' % offset))

    return features


def sent2features(sent):
    """Return a list with one ``word2features`` dict per position in ``sent``."""
    per_token = []
    for position in range(len(sent)):
        per_token.append(word2features(sent, position))
    return per_token

def sent2labels(sent):
    """Return the label of every ``(token, postag, label)`` triple in ``sent``."""
    collected = []
    for _token, _postag, tag in sent:
        collected.append(tag)
    return collected

def sent2tokens(sent):
    """Return the token of every ``(token, postag, label)`` triple in ``sent``."""
    collected = []
    for surface, _postag, _label in sent:
        collected.append(surface)
    return collected

This is what word2features extracts:

In [55]:
# Feature dict of the token at index 2 of the first training sentence.
sent2features(train_sents[0])[2]

{'+1:postag': 'VBP',
 '+1:postag[:2]': 'VB',
 '+1:word.istitle()': False,
 '+1:word.isupper()': False,
 '+1:word.lower()': 'have',
 '+2:postag': 'VBN',
 '+2:postag[:2]': 'VB',
 '+2:word.istitle()': False,
 '+2:word.isupper()': False,
 '+2:word.lower()': 'marched',
 '-1:postag': 'IN',
 '-1:postag[:2]': 'IN',
 '-1:word.istitle()': False,
 '-1:word.isupper()': False,
 '-1:word.lower()': 'of',
 '-2:postag': 'NNS',
 '-2:postag[:2]': 'NN',
 '-2:word.istitle()': True,
 '-2:word.isupper()': False,
 '-2:word.lower()': 'thousands',
 'bias': 1.0,
 'firstTwoChars': 'de',
 'hasDigits': False,
 'hasUppers': False,
 'lemma': 'demonstrator',
 'length': 13,
 'postag': 'NNS',
 'postag[:2]': 'NN',
 'word.isdigit()': False,
 'word.istitle()': False,
 'word.isupper()': False,
 'word.lower()': 'demonstrators',
 'word[-2:]': 'rs',
 'word[-3:]': 'ors'}

Extract features from the data:

In [56]:
%%time
# Turn each split into parallel per-sentence lists: feature dicts (X) and
# label sequences (y).
X_train = list(map(sent2features, train_sents))
y_train = list(map(sent2labels, train_sents))

X_test = list(map(sent2features, test_sents))
y_test = list(map(sent2labels, test_sents))

CPU times: user 2.62 s, sys: 24 ms, total: 2.64 s
Wall time: 2.64 s


## Training

To see all possible CRF parameters, check its docstring. Here we use the L-BFGS training algorithm (the default) with Elastic Net (L1 + L2) regularization.

In [None]:
%%time

# Sampling distributions for the two regularisation coefficients (c1 and c2)
# explored by the randomized search.
params = {'c1': scipy.stats.expon(scale=0.5), 'c2': scipy.stats.expon(scale=0.05)}

# Template estimator. The search clones it and overrides c1/c2 for every
# candidate, so this particular instance is not fitted by gs.fit().
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=50,
    all_possible_transitions=True
)

# Fixed seed so the same (c1, c2) candidates are sampled on every run.
gs = RandomizedSearchCV(crf, param_distributions=params, random_state=0)
gs_res = gs.fit(X_train, y_train)

# Rebind `crf` to the model refitted on the full training set with the best
# parameters (refit=True is the default). Without this, `crf` stays unfitted
# and the later crf.classes_ / crf.predict(...) calls fail on a fresh kernel.
crf = gs.best_estimator_

## Evaluation

There are many more O entities in the data set, but we're more interested in the other entities. To account for this we'll use an averaged F1 score computed for all labels except O. The ``sklearn_crfsuite.metrics`` package provides some useful metrics for sequence classification tasks, including this one.

In [31]:
# Labels to score: every class known to the CRF except the 'O' tag.
labels = list(crf.classes_)
labels.remove('O')
# labels
# Show the hyper-parameters selected by the randomized search.
print(gs_res.best_params_)

['B-geo',
 'B-gpe',
 'B-per',
 'I-geo',
 'B-org',
 'I-org',
 'B-tim',
 'B-art',
 'I-art',
 'I-per',
 'I-gpe',
 'I-tim',
 'B-nat',
 'B-eve',
 'I-eve',
 'I-nat']

In [33]:
# Tag the held-out sentences, then compute the micro-averaged F1 score
# restricted to the entries of `labels`.
y_pred = crf.predict(X_test)
metrics.flat_f1_score(y_test, y_pred, average='micro', labels=labels)

0.74558996471971772

Inspect per-class results in more detail:

In [34]:
# Order labels by entity type first (everything after the leading B/I letter)
# and by that letter second, so the B- and I- rows of one entity are adjacent.
sorted_labels = list(labels)
sorted_labels.sort(key=lambda name: (name[1:], name[0]))

print(
    metrics.flat_classification_report(
        y_test,
        y_pred,
        labels=sorted_labels,
        digits=3,
    )
)

             precision    recall  f1-score   support

      B-art      0.000     0.000     0.000         0
      I-art      0.000     0.000     0.000         0
      B-eve      0.400     0.182     0.250        44
      I-eve      0.000     0.000     0.000        28
      B-geo      0.740     0.770     0.755       996
      I-geo      0.667     0.667     0.667       168
      B-gpe      0.775     0.824     0.799       592
      I-gpe      0.200     0.500     0.286         8
      B-nat      0.000     0.000     0.000         8
      I-nat      0.000     0.000     0.000         0
      B-org      0.565     0.581     0.573       616
      I-org      0.743     0.739     0.741       476
      B-per      0.794     0.706     0.748       596
      I-per      0.845     0.910     0.876       720
      B-tim      0.931     0.771     0.844       704
      I-tim      0.881     0.381     0.532       252

avg / total      0.766     0.730     0.741      5208



  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


In [35]:
from collections import Counter

def print_transitions(trans_features):
    """Print one aligned ``from -> to weight`` line per item.

    ``trans_features`` is an iterable of ``((label_from, label_to), weight)``
    pairs.
    """
    for transition, weight in trans_features:
        source, target = transition
        line = "%-6s -> %-7s %0.6f" % (source, target, weight)
        print(line)

print("Top likely transitions:")
# Rank every learned transition weight once (highest first); both views
# below are slices of the same ranking.
ranked_transitions = Counter(crf.transition_features_).most_common()
print_transitions(ranked_transitions[:20])

print("\nTop unlikely transitions:")
print_transitions(ranked_transitions[-20:])

Top likely transitions:
B-eve  -> I-eve   3.690174
I-eve  -> I-eve   3.440612
I-art  -> I-art   3.436977
B-per  -> I-per   3.414614
I-org  -> I-org   3.324302
B-nat  -> I-nat   3.278195
B-org  -> I-org   3.205937
B-geo  -> I-geo   3.064063
B-art  -> I-art   3.045358
I-per  -> I-per   2.669580
B-tim  -> I-tim   2.570981
I-tim  -> I-tim   2.475548
B-gpe  -> I-gpe   2.446508
O      -> O       2.403821
I-geo  -> I-geo   2.189821
I-nat  -> I-nat   2.114438
I-gpe  -> I-gpe   1.796349
B-org  -> B-art   1.496428
O      -> B-tim   0.745189
B-gpe  -> B-org   0.706160

Top unlikely transitions:
B-geo  -> B-org   -2.972659
B-geo  -> I-gpe   -3.028510
B-org  -> I-geo   -3.049104
O      -> I-nat   -3.086647
B-tim  -> B-tim   -3.099261
O      -> I-eve   -3.110203
O      -> I-art   -3.141735
I-org  -> I-per   -3.224418
I-per  -> B-per   -3.309276
B-org  -> I-per   -3.429796
B-geo  -> I-per   -3.482263
B-per  -> B-per   -3.525805
B-geo  -> I-org   -3.563352
O      -> I-gpe   -3.789578
B-gpe  -> B-gpe  

We can see that, for example, it is very likely that the beginning of an organization name (B-org) will be followed by a token inside an organization name (I-org), but transitions to I-org from tokens with other labels are penalized.

Check the state features:

In [36]:
def print_state_features(state_features):
    """Print one ``weight label attribute`` line per item.

    ``state_features`` is an iterable of ``((attr, label), weight)`` pairs.
    """
    for feature_key, weight in state_features:
        attribute, tag = feature_key
        line = "%0.6f %-8s %s" % (weight, tag, attribute)
        print(line)

print("Top positive:")
# Rank every learned state-feature weight once (highest first); both views
# below are slices of the same ranking.
ranked_state_features = Counter(crf.state_features_).most_common()
print_state_features(ranked_state_features[:30])

print("\nTop negative:")
print_state_features(ranked_state_features[-30:])

Top positive:
4.566423 B-org    -2:word.lower():interest
4.374259 B-gpe    -2:word.lower():concern
4.216193 B-org    -2:word.lower():saturday
3.889579 B-org    +2:word.lower():wrested
3.880429 B-gpe    -2:word.lower():currently
3.871938 B-gpe    +2:word.lower():build
3.725769 O        postag[:2]:VB
3.620111 B-gpe    firstTwoChars:Sw
3.549836 B-tim    word[-3:]:ber
3.523390 B-org    -1:word.lower():observed
3.419723 B-tim    +1:word.lower():year
3.407830 I-org    +2:word.lower():khartoum
3.358308 O        word.lower():this
3.335910 B-org    -2:word.lower():tariffs
3.332753 B-org    +1:word.lower():democratic
3.318072 B-gpe    +2:word.lower():scheduled
3.300045 B-org    -2:word.lower():cultivation
3.295296 O        word[-3:]:ice
3.286607 B-geo    -2:word.lower():al-qaida
3.252897 B-tim    word[-2:]:ay
3.214567 B-tim    -1:word.lower():next
3.211005 B-per    -2:word.lower():tamils
3.208769 B-gpe    -1:word.lower():prompt
3.194266 B-tim    word[-3:]:day
3.188693 B-org    +2:word.lower():pa