https://github.com/TeamHG-Memex/sklearn-crfsuite/blob/master/docs/CoNLL2002.ipynb

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')

In [10]:
from itertools import chain

import nltk
import sklearn
import scipy.stats
from sklearn.metrics import make_scorer
from sklearn.cross_validation import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.grid_search import RandomizedSearchCV

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics

# Данные

In [3]:
import pandas as pd

df = pd.read_csv('../data/ner.csv', encoding = "ISO-8859-1", error_bad_lines=False)

df.head(10)

b'Skipping line 281837: expected 25 fields, saw 34\n'


Unnamed: 0.1,Unnamed: 0,lemma,next-lemma,next-next-lemma,next-next-pos,next-next-shape,next-next-word,next-pos,next-shape,next-word,...,prev-prev-lemma,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag
0,0,thousand,of,demonstr,NNS,lowercase,demonstrators,IN,lowercase,of,...,__start2__,__START2__,wildcard,__START2__,wildcard,__START1__,1.0,capitalized,Thousands,O
1,1,of,demonstr,have,VBP,lowercase,have,NNS,lowercase,demonstrators,...,__start1__,__START1__,wildcard,__START1__,capitalized,Thousands,1.0,lowercase,of,O
2,2,demonstr,have,march,VBN,lowercase,marched,VBP,lowercase,have,...,thousand,NNS,capitalized,Thousands,lowercase,of,1.0,lowercase,demonstrators,O
3,3,have,march,through,IN,lowercase,through,VBN,lowercase,marched,...,of,IN,lowercase,of,lowercase,demonstrators,1.0,lowercase,have,O
4,4,march,through,london,NNP,capitalized,London,IN,lowercase,through,...,demonstr,NNS,lowercase,demonstrators,lowercase,have,1.0,lowercase,marched,O
5,5,through,london,to,TO,lowercase,to,NNP,capitalized,London,...,have,VBP,lowercase,have,lowercase,marched,1.0,lowercase,through,O
6,6,london,to,protest,VB,lowercase,protest,TO,lowercase,to,...,march,VBN,lowercase,marched,lowercase,through,1.0,capitalized,London,B-geo
7,7,to,protest,the,DT,lowercase,the,VB,lowercase,protest,...,through,IN,lowercase,through,capitalized,London,1.0,lowercase,to,O
8,8,protest,the,war,NN,lowercase,war,DT,lowercase,the,...,london,NNP,capitalized,London,lowercase,to,1.0,lowercase,protest,O
9,9,the,war,in,IN,lowercase,in,NN,lowercase,war,...,to,TO,lowercase,to,lowercase,protest,1.0,lowercase,the,O


In [4]:
%%time
def to_sents(df):
    sents = []
    sent = []
    for index, item in df.iterrows():
        if (item.word in ['.', '?', '!', '...']):
            sent.append(tuple((str(item.word), str(item.lemma), str(item.pos), str(item.tag))))
            sents.append(sent)
            sent = []
        else:
            sent.append(tuple((str(item.word), str(item.lemma), str(item.pos), str(item.tag))))
    return(sents)

sents = to_sents(df)
print(*sents[0], sep='\n')

('Thousands', 'thousand', 'NNS', 'O')
('of', 'of', 'IN', 'O')
('demonstrators', 'demonstr', 'NNS', 'O')
('have', 'have', 'VBP', 'O')
('marched', 'march', 'VBN', 'O')
('through', 'through', 'IN', 'O')
('London', 'london', 'NNP', 'B-geo')
('to', 'to', 'TO', 'O')
('protest', 'protest', 'VB', 'O')
('the', 'the', 'DT', 'O')
('war', 'war', 'NN', 'O')
('in', 'in', 'IN', 'O')
('Iraq', 'iraq', 'NNP', 'B-geo')
('and', 'and', 'CC', 'O')
('demand', 'demand', 'VB', 'O')
('the', 'the', 'DT', 'O')
('withdrawal', 'withdraw', 'NN', 'O')
('of', 'of', 'IN', 'O')
('British', 'british', 'JJ', 'B-gpe')
('troops', 'troop', 'NNS', 'O')
('from', 'from', 'IN', 'O')
('that', 'that', 'DT', 'O')
('country', 'countri', 'NN', 'O')
('.', '.', '.', 'O')
CPU times: user 3min 48s, sys: 10.4 s, total: 3min 59s
Wall time: 6min 11s


In [5]:
len(sents)

47985

## Признаковое пространство 

In [15]:
def word2features(sent, i):
    word = sent[i][1]
    postag = sent[i][2]
    
    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,   
        'word[-3:]': word[-3:],
        'word[:3]': word[:3].lower(),
    }
    if i > 0:
        word1 = sent[i-1][1]
        postag1 = sent[i-1][2]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isupper()': word1.isupper(),
            '-1:postag': postag1,
            '-1:word[-3:]': word[-3:],
            '-1:word[:3]': word[:3].lower(),
        })
    else:
        features['BOS'] = True
        
    if i < len(sent)-1:
        word1 = sent[i+1][1]
        postag1 = sent[i+1][2]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isupper()': word1.isupper(),
            '+1:postag': postag1,
            '+1:word[-3:]': word[-3:],
            '+1:word[:3]': word[:3].lower(),
        })
    else:
        features['EOS'] = True
                
    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, lemma, postag, label in sent]

def sent2tokens(sent):
    return [token for token, lemma, postag, label in sent]

И вот что получится

In [16]:
sent2features(sents[0])[1]

{'+1:postag': 'NNS',
 '+1:word.istitle()': False,
 '+1:word.isupper()': False,
 '+1:word.lower()': 'demonstr',
 '+1:word[-3:]': 'of',
 '+1:word[:3]': 'of',
 '-1:postag': 'NNS',
 '-1:word.istitle()': False,
 '-1:word.isupper()': False,
 '-1:word.lower()': 'thousand',
 '-1:word[-3:]': 'of',
 '-1:word[:3]': 'of',
 'bias': 1.0,
 'postag': 'IN',
 'word.isdigit()': False,
 'word.istitle()': False,
 'word.isupper()': False,
 'word.lower()': 'of',
 'word[-3:]': 'of',
 'word[:3]': 'of'}

Извлекаем признаки из данных

In [17]:
%%time
X = [sent2features(s) for s in sents]
y = [sent2labels(s) for s in sents]

CPU times: user 7.25 s, sys: 1.19 s, total: 8.44 s
Wall time: 11.2 s


In [18]:
(X_train, X_test, y_train, y_test) = train_test_split(X, y, test_size=0.332, random_state=42)

## Обучение


In [19]:
%%time
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs', 
    c1=0.1, 
    c2=0.1, 
    max_iterations=100, 
    all_possible_transitions=True
)
crf.fit(X_train, y_train)

CPU times: user 3min 58s, sys: 11.9 s, total: 4min 10s
Wall time: 5min 38s


## Оценка качества


In [20]:
labels = list(crf.classes_)
labels.remove('O')
labels

['B-geo',
 'I-geo',
 'B-gpe',
 'B-tim',
 'I-tim',
 'B-per',
 'I-per',
 'B-org',
 'I-org',
 'I-gpe',
 'B-art',
 'I-art',
 'B-nat',
 'B-eve',
 'I-eve',
 'I-nat']

In [21]:
y_pred = crf.predict(X_test)
metrics.flat_f1_score(y_test, y_pred, 
                      average='micro', labels=labels)

0.87085826645113096

Результаты по каждому классу:

In [None]:
# group B and I results
sorted_labels = sorted(
    labels, 
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

             precision    recall  f1-score   support

      B-art      0.728     0.381     0.500       155
      I-art      0.692     0.233     0.348       116
      B-eve      0.698     0.649     0.673       114
      I-eve      0.701     0.565     0.626       108
      B-geo      0.872     0.908     0.890     12539
      I-geo      0.838     0.826     0.832      2507
      B-gpe      0.966     0.934     0.950      5448
      I-gpe      0.932     0.506     0.656        81
      B-nat      0.694     0.417     0.521        60
      I-nat      0.684     0.542     0.605        24
      B-org      0.821     0.764     0.791      6663
      I-org      0.832     0.836     0.834      5269
      B-per      0.876     0.860     0.868      5604
      I-per      0.876     0.917     0.896      5717
      B-tim      0.934     0.892     0.913      6723
      I-tim      0.855     0.791     0.822      2107

avg / total      0.876     0.865     0.870     53235



## Оптимизация гиперпараметров


In [None]:
%%time
# define fixed parameters and parameters to search
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs', 
    max_iterations=100, 
    all_possible_transitions=True
)
params_space = {
    'c1': scipy.stats.expon(scale=0.5),
    'c2': scipy.stats.expon(scale=0.05),
}

# use the same metric for evaluation
f1_scorer = make_scorer(metrics.flat_f1_score, 
                        average='weighted', labels=labels)

# search
rs = RandomizedSearchCV(crf, params_space, 
                        cv=3, 
                        verbose=1, 
                        n_jobs=-1, 
                        n_iter=20, 
                        scoring=f1_scorer)
rs.fit(X_train, y_train)

Fitting 3 folds for each of 20 candidates, totalling 60 fits


Best result:

In [None]:
crf = rs.best_estimator_
print('best params:', rs.best_params_)
print('best CV score:', rs.best_score_)
print('model size: {:0.2f}M'.format(rs.best_estimator_.size_ / 1000000))

### Анализ пространства параметров


In [None]:
_x = [s.parameters['c1'] for s in rs.grid_scores_]
_y = [s.parameters['c2'] for s in rs.grid_scores_]
_c = [s.mean_validation_score for s in rs.grid_scores_]

fig = plt.figure()
fig.set_size_inches(12, 12)
ax = plt.gca()
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_xlabel('C1')
ax.set_ylabel('C2')
ax.set_title("Randomized Hyperparameter Search CV Results (min={:0.3}, max={:0.3})".format(
    min(_c), max(_c)
))

ax.scatter(_x, _y, c=_c, s=60, alpha=0.9, edgecolors=[0,0,0])

print("Dark blue => {:0.4}, dark red => {:0.4}".format(min(_c), max(_c)))

## Лучшие результаты


In [None]:
crf = rs.best_estimator_
y_pred = crf.predict(X_test)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

In [None]:
metrics.flat_f1_score(y_test, y_pred, 
                      average='weighted', labels=labels)

## Что выучил классификатор

In [None]:
from collections import Counter

def print_transitions(trans_features):
    for (label_from, label_to), weight in trans_features:
        print("%-6s -> %-7s %0.6f" % (label_from, label_to, weight))

print("Top likely transitions:")
print_transitions(Counter(crf.transition_features_).most_common(20))

print("\nTop unlikely transitions:")
print_transitions(Counter(crf.transition_features_).most_common()[-20:])

In [None]:
def print_state_features(state_features):
    for (attr, label), weight in state_features:
        print("%0.6f %-8s %s" % (weight, label, attr))    

print("Top positive:")
print_state_features(Counter(crf.state_features_).most_common(30))

print("\nTop negative:")
print_state_features(Counter(crf.state_features_).most_common()[-30:])