In [1]:
from itertools import chain
import nltk
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelBinarizer
import sklearn
import pycrfsuite

print(sklearn.__version__)

0.16.1


In [2]:
nltk.corpus.conll2002.fileids()

['esp.testa', 'esp.testb', 'esp.train', 'ned.testa', 'ned.testb', 'ned.train']

In [3]:
%%time
# Load the 'ned' training file and the 'ned.testa' file of the CoNLL-2002
# corpus; list() materializes each result so later cells can index and
# iterate over it repeatedly.
# NOTE(review): 'ned.testb' is also listed by fileids() but is not used in the
# cells shown here -- confirm 'ned.testa' is the intended evaluation split.
train_sents = list(nltk.corpus.conll2002.iob_sents('ned.train'))
test_sents = list(nltk.corpus.conll2002.iob_sents('ned.testa'))

CPU times: user 5.9 s, sys: 145 ms, total: 6.05 s
Wall time: 6.14 s


In [4]:
train_sents[0]

[('De', 'Art', 'O'),
 ('tekst', 'N', 'O'),
 ('van', 'Prep', 'O'),
 ('het', 'Art', 'O'),
 ('arrest', 'N', 'O'),
 ('is', 'V', 'O'),
 ('nog', 'Adv', 'O'),
 ('niet', 'Adv', 'O'),
 ('schriftelijk', 'Adj', 'O'),
 ('beschikbaar', 'Adj', 'O'),
 ('maar', 'Conj', 'O'),
 ('het', 'Art', 'O'),
 ('bericht', 'N', 'O'),
 ('werd', 'V', 'O'),
 ('alvast', 'Adv', 'O'),
 ('bekendgemaakt', 'V', 'O'),
 ('door', 'Prep', 'O'),
 ('een', 'Art', 'O'),
 ('communicatiebureau', 'N', 'O'),
 ('dat', 'Conj', 'O'),
 ('Floralux', 'N', 'B-ORG'),
 ('inhuurde', 'V', 'O'),
 ('.', 'Punc', 'O')]

In [5]:
test_sents[0]

[('Dat', 'Pron', 'O'),
 ('is', 'V', 'O'),
 ('verder', 'Adj', 'O'),
 ('opgelaaid', 'N', 'O'),
 ('door', 'Prep', 'O'),
 ('windsnelheden', 'N', 'O'),
 ('die', 'Pron', 'O'),
 ('oplopen', 'V', 'O'),
 ('tot', 'Prep', 'O'),
 ('35', 'Num', 'O'),
 ('kilometer', 'N', 'O'),
 ('per', 'Prep', 'O'),
 ('uur', 'N', 'O'),
 ('.', 'Punc', 'O')]

In [6]:
def word2features(sent, i):
    """Return the list of string features for the token at position ``i``.

    Each element of ``sent`` is indexed as ``[0]`` for the word and ``[1]``
    for its part-of-speech tag. The result is the concatenation, in order, of:

    * features of the token itself (bias, suffixes, casing, digit flag,
      POS tag and its first two characters);
    * features of the previous token, or the single marker ``'BOS'`` when
      ``i`` is not greater than 0;
    * features of the next token, or the single marker ``'EOS'`` when
      ``i`` is the last index of ``sent``.
    """
    token, pos = sent[i][0], sent[i][1]

    # Features describing the current token only.
    current = [
        'bias',
        'word[-3:]=' + token[-3:],
        'word[-2:]=' + token[-2:],
        'word.isupper={}'.format(token.isupper()),
        'word.istitle={}'.format(token.istitle()),
        'word.isdigit={}'.format(token.isdigit()),
        'postag=' + pos,
        'postag[:2]=' + pos[:2],
    ]

    # Left context: either the beginning-of-sentence marker or the neighbour.
    if i <= 0:
        before = ['BOS']
    else:
        prev_token, prev_pos = sent[i - 1][0], sent[i - 1][1]
        before = [
            '-1:word.lower=' + prev_token.lower(),
            '-1:word.istitle={}'.format(prev_token.istitle()),
            '-1:word.isupper={}'.format(prev_token.isupper()),
            '-1:postag=' + prev_pos,
            '-1:postag[:2]=' + prev_pos[:2],
        ]

    # Right context: either the end-of-sentence marker or the neighbour.
    if i >= len(sent) - 1:
        after = ['EOS']
    else:
        next_token, next_pos = sent[i + 1][0], sent[i + 1][1]
        after = [
            '+1:word.lower=' + next_token.lower(),
            '+1:word.istitle=%s' % next_token.istitle(),
            '+1:word.isupper=%s' % next_token.isupper(),
            '+1:postag=' + next_pos,
            '+1:postag[:2]=' + next_pos[:2],
        ]

    return current + before + after

def sent2features(sent):
    """Return one ``word2features`` result per position of ``sent``, in order."""
    per_token = []
    for position in range(len(sent)):
        per_token.append(word2features(sent, position))
    return per_token

def sent2labels(sent):
    """Return the third field (the label) of every 3-item entry in ``sent``."""
    labels = []
    for _token, _postag, tag in sent:
        labels.append(tag)
    return labels

def sent2tokens(sent):
    """Return the first field (the token) of every 3-item entry in ``sent``."""
    tokens = []
    for word, _postag, _label in sent:
        tokens.append(word)
    return tokens

In [7]:
sent2features(train_sents[0])[0]

['bias',
 'word[-3:]=De',
 'word[-2:]=De',
 'word.isupper=False',
 'word.istitle=True',
 'word.isdigit=False',
 'postag=Art',
 'postag[:2]=Ar',
 'BOS',
 '+1:word.lower=tekst',
 '+1:word.istitle=False',
 '+1:word.isupper=False',
 '+1:postag=N',
 '+1:postag[:2]=N']

In [8]:
%%time
# Turn every sentence into two parallel sequences: per-token feature lists (x)
# and per-token labels (y), for both the training and the test sentences.
x_train = list(map(sent2features, train_sents))
y_train = list(map(sent2labels, train_sents))

x_test = list(map(sent2features, test_sents))
y_test = list(map(sent2labels, test_sents))

CPU times: user 4.5 s, sys: 228 ms, total: 4.73 s
Wall time: 4.76 s


In [9]:
%%time
# Create the trainer with verbose output turned off.
trainer = pycrfsuite.Trainer(verbose=False)

# Register each training sentence as a (feature sequence, label sequence) pair.
for xseq, yseq in zip(x_train, y_train):
    trainer.append(xseq, yseq)

CPU times: user 4.67 s, sys: 55.8 ms, total: 4.72 s
Wall time: 4.75 s


In [10]:
# Training hyperparameters; trainer.params() shows the names that can be set.
trainer.set_params({
    'c1': 1.0,   # coefficient for L1 penalty
    'c2': 1e-3,  # coefficient for L2 penalty
    'max_iterations': 50,  # stop earlier

    # include transitions that are possible, but not observed
    'feature.possible_transitions': True
})

In [11]:
trainer.params()

['feature.minfreq',
 'feature.possible_states',
 'feature.possible_transitions',
 'c1',
 'c2',
 'max_iterations',
 'num_memories',
 'epsilon',
 'period',
 'delta',
 'linesearch',
 'max_linesearch']

In [12]:
%%time
# Train on the appended sequences and write the model to this file; the tagger
# created further down opens the same path.
trainer.train('conll2002-ned.crfsuite')

CPU times: user 31.8 s, sys: 80.1 ms, total: 31.9 s
Wall time: 31.9 s


In [13]:
!ls -lh ./conll2002-ned.crfsuite

-rw-rw-r-- 1 curtis curtis 265K Jun  2 09:35 ./conll2002-ned.crfsuite


In [14]:
# Load the trained model file for prediction. open() returns a
# contextlib.closing object, which is what this cell echoes as its output.
tagger = pycrfsuite.Tagger()
tagger.open('conll2002-ned.crfsuite')

<contextlib.closing at 0x7fdad4a30a20>

In [15]:
# Spot-check one test sentence: print its tokens, then the predicted and the
# annotated label sequences on aligned lines so they can be compared by eye.
example_sent = test_sents[2]
print(' '.join(sent2tokens(example_sent)), end = '\n\n')

print("Predicted:", ' '.join(tagger.tag(sent2features(example_sent))))
print("Correct:  ", ' '.join(sent2labels(example_sent)))

Ook in Californië , in Sierra Nevada , woeden al een week lang hevige bosbranden .

Predicted: O O B-LOC O O B-LOC I-LOC O O O O O O O O O
Correct:   O O B-LOC O O B-LOC I-LOC O O O O O O O O O


In [16]:
def bio_classification_report(y_true, y_pred):
    """Return a per-tag classification report for sequence labels.

    ``y_true`` and ``y_pred`` are iterables of label sequences. Both are
    flattened into a single list of labels and binarized with a
    ``LabelBinarizer`` fitted on the flattened true labels. The report covers
    every fitted class except ``'O'``, ordered by the part of the tag after
    the first ``'-'`` and then by the part before it.

    NOTE(review): with exactly two distinct true labels, LabelBinarizer is
    documented to produce a single column, so the class-to-column lookup
    below would presumably not line up -- confirm before using on such data.
    """
    binarizer = LabelBinarizer()

    flat_true = list(chain.from_iterable(y_true))
    true_matrix = binarizer.fit_transform(flat_true)

    flat_pred = list(chain.from_iterable(y_pred))
    pred_matrix = binarizer.transform(flat_pred)

    def entity_then_prefix(tag):
        # 'B-LOC' -> ['LOC', 'B'], so tags group by entity type first.
        return tag.split('-', 1)[::-1]

    reported_tags = sorted(set(binarizer.classes_) - {'O'}, key=entity_then_prefix)

    # Position of each fitted class in binarizer.classes_.
    column_of = {}
    for column, class_name in enumerate(binarizer.classes_):
        column_of[class_name] = column

    reported_columns = [column_of[tag] for tag in reported_tags]

    return classification_report(
        true_matrix,
        pred_matrix,
        labels=reported_columns,
        target_names=reported_tags,
    )

In [17]:
%%time
# Predict one label sequence per test sentence from its feature sequence.
y_pred = [tagger.tag(xseq) for xseq in x_test]

CPU times: user 890 ms, sys: 7.8 ms, total: 898 ms
Wall time: 958 ms


In [18]:
print(bio_classification_report(y_test, y_pred))

             precision    recall  f1-score   support

      B-LOC       0.72      0.72      0.72       479
      I-LOC       0.42      0.38      0.40        64
     B-MISC       0.76      0.60      0.67       748
     I-MISC       0.60      0.51      0.55       215
      B-ORG       0.84      0.53      0.65       686
      I-ORG       0.83      0.57      0.68       396
      B-PER       0.67      0.80      0.73       703
      I-PER       0.75      0.94      0.84       423

avg / total       0.74      0.67      0.69      3714



In [19]:
from collections import Counter
# Model details object; the cells below read info.transitions and
# info.state_features from it.
info = tagger.info()

def print_transitions(trans_features):
    """Print one aligned ``from -> to weight`` line per transition.

    ``trans_features`` is an iterable of ``((label_from, label_to), weight)``
    items; they are printed in the order given, with the weight shown to six
    decimal places.
    """
    for label_pair, score in trans_features:
        source_label, target_label = label_pair
        line = "%-6s -> %-7s %0.6f" % (source_label, target_label, score)
        print(line)

# Rank every learned transition once by weight (highest first), then show the
# 15 entries at each end of that ranking.
ranked_transitions = Counter(info.transitions).most_common()

print("Top likely transitions:")
print_transitions(ranked_transitions[:15])

print("\nTop unlikely transitions:")
print_transitions(ranked_transitions[-15:])

Top likely transitions:
B-PER  -> I-PER   5.160822
B-LOC  -> I-LOC   5.028159
B-ORG  -> I-ORG   4.599620
I-MISC -> I-MISC  4.233843
I-LOC  -> I-LOC   4.138612
I-ORG  -> I-ORG   4.014301
B-MISC -> I-MISC  3.813706
I-PER  -> I-PER   2.877162
O      -> O       1.237149
O      -> B-PER   0.841806
O      -> B-MISC  0.810526
O      -> B-LOC   0.719871
O      -> B-ORG   0.415048
I-MISC -> B-ORG   0.286818
B-MISC -> B-PER   0.178145

Top unlikely transitions:
I-LOC  -> B-PER   -2.212303
B-LOC  -> I-PER   -2.317872
B-PER  -> B-ORG   -2.399059
B-LOC  -> I-MISC  -2.487489
B-LOC  -> B-ORG   -2.573519
I-PER  -> B-PER   -2.598128
B-MISC -> I-LOC   -2.627749
I-MISC -> B-LOC   -2.730945
I-MISC -> I-PER   -2.781056
B-MISC -> I-PER   -3.399730
B-MISC -> I-ORG   -3.689969
O      -> I-LOC   -6.559237
O      -> I-ORG   -6.761947
O      -> I-MISC  -7.033212
O      -> I-PER   -7.323753


In [20]:
def print_state_features(state_features):
    """Print one ``weight label attribute`` line per state feature.

    ``state_features`` is an iterable of ``((attr, label), weight)`` items;
    they are printed in the order given, with the weight shown to six decimal
    places and the label left-aligned in a six-character column.
    """
    for attr_and_label, score in state_features:
        attribute, tag = attr_and_label
        line = "%0.6f %-6s %s" % (score, tag, attribute)
        print(line)

# Rank every learned state feature once by weight (highest first), then show
# the 20 entries at each end of that ranking.
ranked_state_features = Counter(info.state_features).most_common()

print("Top positive:")
print_state_features(ranked_state_features[:20])

print("\nTop negative:")
print_state_features(ranked_state_features[-20:])

Top positive:
5.447402 O      -1:word.lower=+
5.001577 I-MISC +1:word.lower=ned
4.910584 B-ORG  +1:word.lower=volkskrant
4.852691 I-MISC -1:word.lower=tbs
4.818328 B-ORG  word[-3:]=Mix
4.788142 B-MISC word[-3:]='er
4.570704 B-PER  +1:word.lower=gucht
4.458404 B-LOC  +1:word.lower=toneelhuis
4.394968 I-PER  +1:word.lower=reeth
4.355056 B-ORG  word[-3:]=mar
4.342996 B-ORG  -1:word.lower=holding
4.341810 B-ORG  word[-3:]=sys
4.339052 B-MISC +1:word.lower=bijsluiter
4.319794 B-PER  word[-3:]=par
4.305789 B-LOC  word[-2:]=ië
4.283598 B-ORG  word[-3:]=vas
4.178316 O      word[-3:]=oto
4.142664 B-ORG  word[-3:]=d's
4.137002 B-ORG  word[-3:]=ray
4.098414 B-MISC -1:word.lower=tentoonstelling

Top negative:
-2.319940 B-PER  word[-3:]=ken
-2.328662 O      -1:word.lower=zogenaamde
-2.386422 B-LOC  word.istitle=False
-2.400493 O      +1:word.lower=toont
-2.418656 O      +1:word.lower=(
-2.434077 O      +1:word.lower=vormgeving
-2.436619 B-PER  -1:word.lower=in
-2.480851 I-LOC  postag[:2]=Ad
-2.6054