In [21]:
import sklearn
import scipy.stats
from sklearn.model_selection import RandomizedSearchCV 

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics

import random

### Use the data extracted from GRAMMAR PATTERN 1: Verb to build a crf model

In [2]:
all_data = []
with open('data/seq_label_data_v2.txt') as file:
    for line in file:
        all_data += [eval(line)]
random.shuffle(all_data)

In [6]:
all_data[0]

[('Lead', 'NOUN', 'O'),
 ('can', 'VERB', 'O'),
 ('accumulate', 'VERB', 'V'),
 ('in', 'ADP', 'O'),
 ('the', 'DET', 'O'),
 ('body', 'NOUN', 'O'),
 ('until', 'ADP', 'O'),
 ('toxic', 'ADJ', 'O'),
 ('levels', 'NOUN', 'O'),
 ('are', 'AUX', 'O'),
 ('reached', 'VERB', 'O')]

In [3]:
%%time
split_num = int(len(all_data)*(4/5))
train_sents = all_data[:split_num]
test_sents = all_data[split_num:]
print(len(train_sents), len(test_sents))

9272 2318
CPU times: user 748 µs, sys: 0 ns, total: 748 µs
Wall time: 638 µs


### Features

In [7]:
def word2features(sent, i):
    word = sent[i][0]
    postag = sent[i][1]
    label = sent[i][2]
    
    features = {
        'word': word,
        'postag': postag,
    }
    if label == 'V':
        features['target'] = True
        
    if i == 0:
        features['BOS'] = True
    else:
        word_1 = sent[i-1][0]
        
        features.update({
            'passive': word_1 in 'be/is/are/was/were'.split('/') and postag == 'VERB' and word[-3:] != 'ing',
            '-1:word == \'and\'': word_1 == 'and',
        })
        
    if i == len(sent)-1:
        features['EOS'] = True
                
    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, postag, label in sent]

# def sent2tokens(sent):
#     return [token for token, postag, label in sent]

Extract features from the data:

In [8]:
%%time
X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

# X_all = [sent2features(s) for s in all_data]
# y_all = [sent2labels(s) for s in all_data]

CPU times: user 174 ms, sys: 29.7 ms, total: 204 ms
Wall time: 201 ms


## Training

To see all possible CRF parameters check its docstring. Here we are useing L-BFGS training algorithm (it is default) with Elastic Net (L1 + L2) regularization.

In [9]:
%%time
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs', 
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_transitions=True
)
crf.fit(X_train, y_train)

CPU times: user 33.7 s, sys: 29 ms, total: 33.7 s
Wall time: 33.7 s


## Evaluation

There is much more O entities in data set, but we're more interested in other entities. To account for this we'll use averaged F1 score computed for all labels except for O. ``sklearn-crfsuite.metrics`` package provides some useful metrics for sequence classification task, including this one.

In [10]:
labels = list(crf.classes_)
labels.remove('O')

In [11]:
y_pred = crf.predict(X_test)
metrics.flat_f1_score(y_test, y_pred, 
                      average='weighted', labels=labels)

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


0.8711856976844636

In [12]:
metrics.sequence_accuracy_score(y_test, y_pred)

0.7165660051768766

#### Test sentence

In [16]:
test_sent = [('I', 'NOUN', 'O'), ('shout', 'VERB', 'V'), ('at', 'ADP', 'O'), ('the', 'NOUN', 'O'), ('children', 'NOUN', 'O'), ('.', 'PUNCT', 'O')]
sent2features(test_sent)
test_pred = crf.predict([sent2features(test_sent)])
test_pred

[['O', 'V', 'at', 'n', 'n', 'O']]

#### Save CRF model

In [113]:
import joblib

with open('crf_model.joblib', 'wb') as fo:  
    joblib.dump(crf, fo)

Inspect per-class results in more detail:

In [17]:
# group B and I results
sorted_labels = sorted(
    labels, 
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

              precision    recall  f1-score   support

           V      1.000     1.000     1.000      2208
           n      0.885     0.875     0.880      2897
           v      0.000     0.000     0.000         0
         way      0.000     0.000     0.000         2
       about      0.000     0.000     0.000         1
         adj      0.000     0.000     0.000         0
         adv      0.667     0.182     0.286        11
    adv/prep      0.000     0.000     0.000        13
          be      0.593     0.291     0.390       110
         Ved      0.582     0.291     0.388       110
          if      0.000     0.000     0.000         0
          of      0.000     0.000     0.000         1
     against      0.000     0.000     0.000         0
          wh      0.524     0.407     0.458        27
        that      0.980     0.980     0.980        51
     through      1.000     0.167     0.286         6
        like      0.000     0.000     0.000         1
        ving      0.000    

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


## Check best estimator on our test data

As you can see, quality is improved.

In [22]:
crf = rs.best_estimator_
y_pred = crf.predict(X_test)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

NameError: name 'rs' is not defined

## Let's check what classifier learned

In [23]:
from collections import Counter

def print_transitions(trans_features):
    for (label_from, label_to), weight in trans_features:
        print("%-6s -> %-7s %0.6f" % (label_from, label_to, weight))

print("Top likely transitions:")
print_transitions(Counter(crf.transition_features_).most_common(20))

print("\nTop unlikely transitions:")
print_transitions(Counter(crf.transition_features_).most_common()[-20:])

Top likely transitions:
O      -> O       6.636823
be     -> Ved     4.397070
way    -> way     4.368860
wh     -> wh      4.151701
aux    -> be      4.135489
n      -> n       4.131335
and    -> v       3.761766
as     -> adj     3.609226
aux    -> not     3.450092
O      -> V       3.369684
as     -> if      3.140766
out    -> of      3.129034
by     -> ving    3.099795
prep/adv -> O       3.053058
to     -> vp      2.859820
O      -> wh      2.827245
way    -> prep/adv 2.827207
O      -> be      2.747605
V      -> n       2.651217
Ved    -> at      2.559702

Top unlikely transitions:
V      -> V       -2.194478
V      -> Ved     -2.246393
V      -> be      -2.355027
V      -> aux     -2.433817
not    -> V       -2.455993
O      -> of      -2.681680
n      -> prep    -2.779962
n      -> of      -2.835403
n      -> aux     -2.894727
Ved    -> V       -2.904784
n      -> Ved     -2.977809
wh     -> V       -3.002955
adv/prep -> n       -3.075018
n      -> be      -3.554158
V      -> ad

Check the state features:

In [24]:
def print_state_features(state_features):
    for (attr, label), weight in state_features:
        print("%0.6f %-8s %s" % (weight, label, attr))    

print("Top positive:")
print_state_features(Counter(crf.state_features_).most_common(30))

print("\nTop negative:")
print_state_features(Counter(crf.state_features_).most_common()[-30:])

Top positive:
17.052123 V        target
8.288044 of       word:of
7.971171 O        BOS
7.862612 wh       word:which
7.615649 as       word:as
7.558498 that     word:that
7.116449 with     word:with
7.031221 to       word:to
6.997499 for      word:for
6.543719 wh       word:who
6.397497 wh       word:what
6.363945 n        postag:NOUN
6.348635 at       word:at
6.226175 n        word:of
6.064075 amount   postag:NUM
6.030686 wh       word:how
5.770945 from     word:from
5.728208 O        word:many
5.644918 Ved      postag:VERB
5.608067 into     word:into
5.597507 on       word:on
5.088123 wh       word:whether
5.062125 in       word:in
5.023542 n        postag:PRON
4.988443 O        word:today
4.984593 way      word:way
4.751935 wh       word:where
4.657137 through  word:through
4.651979 O        postag:PUNCT
4.623456 O        EOS

Top negative:
-1.912327 O        word:economies
-1.914350 O        word:socialism
-1.948481 O        word:owned
-1.949982 O        word:not
-1.955985 O       