# Named Entity Recognition with CRF

In this notebook, we build a named entity recognition model with [sklearn-crfsuite](https://sklearn-crfsuite.readthedocs.io/en/latest/). The dataset is CoNLL-2003 (English).

## Setup

### Installing packages

In [1]:
!pip install sklearn-crfsuite==0.3.6 seqeval==1.2.2 eli5==0.11.0 scikit-learn==0.23.2



### Imports

In [2]:
import eli5
import scipy.stats  # import the submodule explicitly; scipy.stats.expon is used for the search space
from seqeval.metrics import classification_report, f1_score
from sklearn_crfsuite import CRF
from sklearn.metrics import make_scorer
from sklearn.model_selection import RandomizedSearchCV

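### Loading the data

First, upload the CoNLL dataset. The `Data` folder sits at the same level as this notebook and contains a `conll2003/en` folder; upload the training, validation, and test files from there. If you are not on Colab, pass the correct file paths when loading the dataset.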
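
Outside Colab, the files may sit either next to the notebook or under `Data/conll2003/en`. The helper below is a minimal sketch for locating them; `find_dataset_file` and its `search_dirs` default are hypothetical names introduced here for illustration, not part of the original notebook.

```python
from pathlib import Path


def find_dataset_file(name, search_dirs=(".", "Data/conll2003/en")):
    """Return the first existing path for `name` among `search_dirs`.

    Both the function and the default directories are assumptions made
    for this sketch; adjust them to your own layout.
    """
    for directory in search_dirs:
        candidate = Path(directory) / name
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(f"{name} not found in any of {list(search_dirs)}")
```

With this in place, a call such as `load_conll(find_dataset_file('train.txt'))` would work in both environments, assuming the files exist in one of the listed directories.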

In [3]:
from google.colab import files
uploaded = files.upload()

Saving test.txt to test (1).txt
Saving train.txt to train (1).txt
Saving valid.txt to valid (1).txt


In [4]:
!head train.txt

-DOCSTART- -X- -X- O

EU NNP B-NP B-ORG
rejects VBZ B-VP O
German JJ B-NP B-MISC
call NN I-NP O
to TO B-VP O
boycott VB I-VP O
British JJ B-NP B-MISC
lamb NN I-NP O


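Next, load the dataset. Each non-empty line holds a word, its part-of-speech tag, a chunk tag, and the named entity tag, separated by spaces. Blank lines separate sentences, and `-DOCSTART-` lines mark document boundaries.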

In [5]:
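def load_conll(file_path):
    """Read a CoNLL file into a list of sentences of (word, pos, tag) tuples."""
    sents = []
    sent = []
    with open(file_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Skip document boundary markers.
            if line.startswith('-DOCSTART'):
                continue
            if line:
                # Columns: word, POS tag, chunk tag (unused), NER tag.
                word, pos, _, tag = line.split()
                sent.append((word, pos, tag))
            elif sent:
                # A blank line ends the current sentence.
                sents.append(sent)
                sent = []
    # Keep the last sentence when the file does not end with a blank line.
    if sent:
        sents.append(sent)
    return sents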

In [6]:
train_sents = load_conll('train.txt')
valid_sents = load_conll('valid.txt')
test_sents = load_conll('test.txt')

In [7]:
train_sents[0]

[('EU', 'NNP', 'B-ORG'),
 ('rejects', 'VBZ', 'O'),
 ('German', 'JJ', 'B-MISC'),
 ('call', 'NN', 'O'),
 ('to', 'TO', 'O'),
 ('boycott', 'VB', 'O'),
 ('British', 'JJ', 'B-MISC'),
 ('lamb', 'NN', 'O'),
 ('.', '.', 'O')]

In [8]:
print(len(train_sents))
print(len(valid_sents))
print(len(test_sents))

14041
3250
3453


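## Preprocessing

Now that the data is loaded, we define the features. For each token, we use the following features of the token itself and of the two tokens on either side. The digit check is computed only for the current token, and `BOS`/`EOS` flags mark positions where the window runs past the start or end of the sentence.

- The lowercased word
- Whether the word is all uppercase
- Whether the word is title-cased (`str.istitle()`: an initial capital followed by lowercase letters)
- Whether the word consists only of digits
- The part-of-speech tag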
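
The string predicates behind these features can be checked on a few tokens before they are wired into the full feature function. This is a small illustrative sketch; `shape_features` is a hypothetical helper that is not used elsewhere in the notebook.

```python
def shape_features(word):
    """Return the word-shape features of a single token (illustrative helper)."""
    return {
        "word.lower()": word.lower(),
        "word.isupper()": word.isupper(),
        "word.istitle()": word.istitle(),
        "word.isdigit()": word.isdigit(),
    }


# 'EU' is all uppercase, so it is not title-cased even though it starts with a capital.
for token in ["EU", "German", "lamb", "1996"]:
    print(token, shape_features(token))
```

Note that `str.istitle()` is stricter than "starts with a capital letter", which is why an acronym such as `EU` sets `word.isupper()` but not `word.istitle()`.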

In [9]:
def word2features(sent, i):
    """Build the feature dict for token i from a window of two tokens on each side."""
    word = sent[i][0]
    postag = sent[i][1]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
    }
    if i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.isupper()': word1.isupper(),
            '-1:word.istitle()': word1.istitle(),
            '-1:postag': postag1,
        })
    else:
        features['BOS'] = True

    if i > 1:
        word2 = sent[i-2][0]
        postag2 = sent[i-2][1]
        features.update({
            '-2:word.lower()': word2.lower(),
            '-2:word.isupper()': word2.isupper(),
            '-2:word.istitle()': word2.istitle(),
            '-2:postag': postag2,
        })
    else:
        features['-2:BOS'] = True

    if i < len(sent)-1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.isupper()': word1.isupper(),
            '+1:word.istitle()': word1.istitle(),
            '+1:postag': postag1,
        })
    else:
        features['EOS'] = True

    if i < len(sent) - 2:
        word2 = sent[i+2][0]
        postag2 = sent[i+2][1]
        features.update({
            '+2:word.lower()': word2.lower(),
            '+2:word.isupper()': word2.isupper(),
            '+2:word.istitle()': word2.istitle(),
            '+2:postag': postag2,
        })
    else:
        features['+2:EOS'] = True

    return features

def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, postag, label in sent]

def sent2tokens(sent):
    return [token for token, postag, label in sent]

In [10]:
sent2features(train_sents[0])[0]

{'+1:postag': 'VBZ',
 '+1:word.istitle()': False,
 '+1:word.isupper()': False,
 '+1:word.lower()': 'rejects',
 '+2:postag': 'JJ',
 '+2:word.istitle()': True,
 '+2:word.isupper()': False,
 '+2:word.lower()': 'german',
 '-2:BOS': True,
 'BOS': True,
 'bias': 1.0,
 'postag': 'NNP',
 'word.isdigit()': False,
 'word.istitle()': False,
 'word.isupper()': True,
 'word.lower()': 'eu'}

In [11]:
X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_valid = [sent2features(s) for s in valid_sents]
y_valid = [sent2labels(s) for s in valid_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

## Training the model

In [12]:
%%time
model = CRF(
    algorithm='lbfgs',
    max_iterations=100,
    all_possible_transitions=False
)
model.fit(X_train + X_valid, y_train + y_valid)

CPU times: user 46.8 s, sys: 240 ms, total: 47 s
Wall time: 46.9 s


## Evaluating the model

In [13]:
y_pred = model.predict(X_test)
print(classification_report(y_test, y_pred, digits=4))

              precision    recall  f1-score   support

         LOC     0.8385    0.7842    0.8104      1668
        MISC     0.7602    0.6638    0.7087       702
         ORG     0.7256    0.7038    0.7145      1661
         PER     0.8041    0.8683    0.8350      1617

   micro avg     0.7861    0.7697    0.7778      5648
   macro avg     0.7821    0.7550    0.7672      5648
weighted avg     0.7857    0.7697    0.7766      5648
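
Each `f1-score` for a class, and for the `micro avg` row, is the harmonic mean of the precision and recall in the same row. The sketch below recomputes two of the reported values from the rounded numbers above, so small differences in the last digit are expected. `f1_from` is a hypothetical helper, not a seqeval function.

```python
def f1_from(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported f1-score) copied from the report above.
rows = {
    "LOC": (0.8385, 0.7842, 0.8104),
    "micro avg": (0.7861, 0.7697, 0.7778),
}
for name, (p, r, reported) in rows.items():
    print(f"{name}: recomputed={f1_from(p, r):.4f} reported={reported:.4f}")
```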



## Inspecting the weights

Let's use `show_weights` from [eli5](https://eli5.readthedocs.io/en/latest/index.html) to look at the weights of the state features and the transition features.


In [14]:
eli5.show_weights(model, top=30)



From \ To,O,B-LOC,I-LOC,B-MISC,I-MISC,B-ORG,I-ORG,B-PER,I-PER
O,4.001,3.893,0.0,4.013,0.0,4.057,0.0,5.837,0.0
B-LOC,0.363,0.002,7.573,0.492,0.0,0.34,0.0,-1.584,0.0
I-LOC,-0.504,-0.29,7.328,-0.13,0.0,-0.778,0.0,0.0,0.0
B-MISC,-0.039,-0.237,0.0,-0.698,7.347,0.438,0.0,0.77,0.0
I-MISC,-0.736,-0.93,0.0,-0.724,7.4,-0.492,0.0,-0.606,0.0
B-ORG,0.273,-1.175,0.0,-0.273,0.0,0.337,9.411,-1.614,0.0
I-ORG,-0.039,-1.556,0.0,-0.896,0.0,-1.305,9.568,-0.956,0.0
B-PER,0.729,-0.343,0.0,-0.802,0.0,0.0,0.0,0.0,9.652
I-PER,-0.099,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.222

y=O top features

Weight?,Feature
+6.131,word.lower():september
+5.809,word.lower():july
+5.755,word.lower():june
+5.480,bias
+5.463,word.lower():thursday
+5.393,word.lower():friday
+5.288,word.lower():august
+5.217,word.lower():monday
+5.169,word.lower():tuesday
+5.137,word.lower():wednesday

y=B-LOC top features

Weight?,Feature
+5.063,word.lower():england
+4.327,word.lower():iraq
+4.210,word.lower():pakistan
+4.202,word.lower():u.s.
+4.062,word.lower():germany
+4.046,word.lower():britain
+3.951,word.lower():israel
+3.681,word.lower():russia
+3.591,word.lower():china
+3.553,word.lower():india

y=I-LOC top features

Weight?,Feature
+2.290,word.lower():lanka
+2.195,-1:word.lower():san
+2.174,word.lower():republic
+2.112,-1:word.lower():new
+1.925,word.lower():korea
+1.817,word.lower():kong
+1.761,word.lower():city
+1.679,-1:word.lower():hong
+1.663,word.lower():states
+1.630,word.lower():zealand

y=B-MISC top features

Weight?,Feature
+4.325,postag:$
+4.021,word.lower():gmt
+3.536,word.lower():german
+3.227,word.lower():english
+3.224,word.lower():moslem
+3.123,word.lower():dutch
+3.068,word.lower():american
+2.824,word.lower():australian
+2.816,word.lower():sudanese
+2.760,word.lower():british

y=I-MISC top features

Weight?,Feature
+3.321,word.lower():open
+3.011,word.lower():cup
+2.996,word.lower():division
+2.192,word.lower():league
+2.066,-1:word.lower():grand
+1.964,word.lower():day
+1.872,-1:word.lower():south
+1.805,postag:CD
+1.672,-1:postag:DT
+1.671,-1:word.lower():world

y=B-ORG top features

Weight?,Feature
+4.238,-1:word.lower():v
+3.732,word.lower():reuters
+3.186,word.lower():u.n.
+3.117,word.lower():ajax
+3.068,+1:word.lower():v
+2.838,word.lower():interfax
+2.721,+1:word.lower():21
+2.680,word.lower():osce
+2.588,word.lower():cofinec
+2.496,word.lower():rtrs

y=I-ORG top features

Weight?,Feature
+2.958,word.lower():newsroom
+2.699,word.lower():inc
+2.491,word.lower():corp
+2.289,+1:word.lower():4
+2.052,-1:word.lower():lloyd
+2.028,word.lower():co
+1.831,+1:word.lower():0
+1.724,word.lower():&
+1.718,-1:word.lower():&
+1.681,-1:word.lower():boatmen

y=B-PER top features

Weight?,Feature
+4.613,word.lower():clinton
+3.518,word.lower():yeltsin
+3.452,word.lower():lebed
+3.412,word.lower():dole
+3.272,-1:word.lower():b
+3.170,BOS
+2.856,word.lower():arafat
+2.786,word.lower():inzamam-ul-haq
+2.744,-1:word.lower():president
+2.533,word.lower():dutroux

y=I-PER top features

Weight?,Feature
+1.758,-1:word.lower():van
+1.483,-1:word.isupper()
+1.455,word.istitle()
+1.376,+2:word.lower():u.s.
+1.353,word.lower():van
+1.353,-1:word.lower():john
+1.325,-1:word.lower():)
+1.325,-1:postag:)
+1.272,-1:word.lower():mark
+1.262,+2:word.lower():out


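The transition matrix shows that the model has learned that I-X follows B-X: for example, B-LOC -> I-LOC has weight 7.573 and B-PER -> I-PER has weight 9.652. It has also learned that certain transitions are unlikely, such as B-ORG -> B-PER (-1.614). A transition such as B-LOC -> I-ORG, which should almost never occur, shows a weight of 0.0 rather than a negative value. This is probably because, with `all_possible_transitions=False`, CRFsuite generates no feature for transitions that never appear in the training data.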
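
The same information can be read programmatically by ranking transitions by weight. The sketch below uses a small hand-copied subset of the matrix above so that it runs on its own; `rank_transitions` is a hypothetical helper. With a fitted sklearn-crfsuite model, the dictionary would come from `model.transition_features_`, which maps `(label_from, label_to)` pairs to weights.

```python
from collections import Counter


def rank_transitions(transition_weights, n=3):
    """Return the n highest-weight and n lowest-weight transitions."""
    ranked = Counter(transition_weights).most_common()  # sorted by weight, descending
    return ranked[:n], ranked[-n:]


# Subset of the transition matrix shown above, keyed by (from, to).
toy_weights = {
    ("B-PER", "I-PER"): 9.652,
    ("I-ORG", "I-ORG"): 9.568,
    ("B-ORG", "I-ORG"): 9.411,
    ("B-LOC", "I-LOC"): 7.573,
    ("O", "O"): 4.001,
    ("I-ORG", "B-LOC"): -1.556,
    ("B-LOC", "B-PER"): -1.584,
    ("B-ORG", "B-PER"): -1.614,
}

likely, unlikely = rank_transitions(toy_weights)
for (label_from, label_to), weight in likely + unlikely:
    print(f"{label_from:6} -> {label_to:6} {weight:+.3f}")
```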

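## Hyperparameter optimization

To improve performance, let's tune the hyperparameters. Here we use random search with 3-fold cross-validation to search over the regularization parameters `c1` (L1) and `c2` (L2). Note that the model in this section also sets `all_possible_transitions=True`, unlike the baseline above, so any improvement is not due to the regularization settings alone.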
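
To get a feel for what the random search explores, the sketch below draws candidate `(c1, c2)` pairs from exponential distributions with the same scales as the search space in the next cell (0.5 and 0.05). It uses the standard library's `random.expovariate` as a stand-in for `scipy.stats.expon`, so the drawn values will not match those of `RandomizedSearchCV`; `sample_candidates` is a hypothetical helper.

```python
import random


def sample_candidates(n_iter, c1_scale=0.5, c2_scale=0.05, seed=0):
    """Draw n_iter (c1, c2) pairs from exponential distributions.

    expovariate takes the rate (1 / scale), whereas scipy.stats.expon
    is parameterized by the scale (the mean) directly.
    """
    rng = random.Random(seed)
    return [
        {
            "c1": rng.expovariate(1.0 / c1_scale),
            "c2": rng.expovariate(1.0 / c2_scale),
        }
        for _ in range(n_iter)
    ]


candidates = sample_candidates(n_iter=30)
n_folds = 3
# Each candidate is trained once per fold.
print("total fits:", len(candidates) * n_folds)
print("first candidate:", candidates[0])
```

Because the exponential distribution concentrates its mass near zero, most candidates use weak regularization, with occasional larger values.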

In [15]:
%%time
model = CRF(
    algorithm='lbfgs',
    max_iterations=100,
    all_possible_transitions=True
)

params_space = {
    'c1': scipy.stats.expon(scale=0.5),
    'c2': scipy.stats.expon(scale=0.05),
}

f1_scorer = make_scorer(f1_score)

rs = RandomizedSearchCV(
    model,
    params_space,
    cv=3,
    verbose=1,
    n_jobs=-1,
    n_iter=30,
    scoring=f1_scorer
)
rs.fit(X_train + X_valid, y_train + y_valid)

Fitting 3 folds for each of 30 candidates, totalling 90 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 2 concurrent workers.
[Parallel(n_jobs=-1)]: Done  46 tasks      | elapsed: 37.9min
[Parallel(n_jobs=-1)]: Done  90 out of  90 | elapsed: 73.3min finished


CPU times: user 1h 3min 9s, sys: 13.1 s, total: 1h 3min 22s
Wall time: 1h 14min 1s


In [16]:
print('best params:', rs.best_params_)
print('best CV score:', rs.best_score_)
print('model size: {:0.2f}M'.format(rs.best_estimator_.size_ / 1000000))

best params: {'c1': 0.028100798861946483, 'c2': 0.002451948365671623}
best CV score: 0.8532850512553268
model size: 1.61M


Let's evaluate the model with the best parameters on the test data.

In [17]:
crf = rs.best_estimator_
y_pred = crf.predict(X_test)
print(classification_report(y_test, y_pred, digits=4))

              precision    recall  f1-score   support

         LOC     0.8767    0.8483    0.8623      1668
        MISC     0.7868    0.7151    0.7493       702
         ORG     0.7815    0.7363    0.7582      1661
         PER     0.8395    0.8862    0.8622      1617

   micro avg     0.8278    0.8097    0.8187      5648
   macro avg     0.8211    0.7965    0.8080      5648
weighted avg     0.8269    0.8097    0.8176      5648

