# Training a CRF model with the sklearn_crfsuite module

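sklearn-crfsuite is a lightweight CRF library built on the CRFsuite library. Its interface is compatible with scikit-learn, so it can be combined with scikit-learn tools to build an entity recognition system. Besides training and prediction for conditional random fields, sklearn-crfsuite also provides evaluation methods.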

Installation: `pip install sklearn-crfsuite`

Official documentation: https://sklearn-crfsuite.readthedocs.io/en/latest/

```python
!pip install sklearn_crfsuite
```

```
Collecting sklearn_crfsuite
  Downloading sklearn_crfsuite-0.3.6-py2.py3-none-any.whl (12 kB)
Collecting python-crfsuite>=0.8.3
  Downloading python_crfsuite-0.9.7-cp37-cp37m-win_amd64.whl (154 kB)
Installing collected packages: python-crfsuite, sklearn-crfsuite
Successfully installed python-crfsuite-0.9.7 sklearn-crfsuite-0.3.6
```


## Import the required modules

```python
import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics
```

## Read the data

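Each data file holds one character and its tag per line, separated by a tab, with a blank line between sentences. `read_data` returns a list of sentences, each a list of `(char, tag)` tuples.

```python
def read_data(file):
    """Read tab-separated `char<TAB>tag` lines; blank lines separate sentences."""
    ret_sents = []
    tmp = []
    with open(file, encoding="utf-8") as fr:  # close the file when done
        for line in fr:
            line = line.strip()
            if line == "":
                # A blank line ends the current sentence; skip repeated blank lines
                if tmp:
                    ret_sents.append(tmp)
                    tmp = []
                continue
            parts = line.split("\t")
            if len(parts) != 2:
                continue  # skip malformed lines
            tmp.append((parts[0], parts[1]))
    # Keep the last sentence if the file does not end with a blank line
    if tmp:
        ret_sents.append(tmp)
    return ret_sents

train_sents = read_data("msra/train_data")
test_sents = read_data("msra/test_data")
```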
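Before building features it can help to sanity-check what `read_data` returned: how many sentences and tokens there are, and how lopsided the tag distribution is. The sketch below is a minimal, hypothetical helper (`label_stats` and `toy_sents` are illustrative names, not part of sklearn_crfsuite) that works on any list of `(char, tag)` sentences.

```python
from collections import Counter


def label_stats(sents):
    """Summarize a corpus given as a list of sentences of (char, tag) pairs."""
    label_counts = Counter(label for sent in sents for _, label in sent)
    return {
        "sentences": len(sents),
        "tokens": sum(label_counts.values()),
        "labels": dict(label_counts),
    }


# Hypothetical toy corpus in the same shape that read_data returns
toy_sents = [
    [("我", "O"), ("爱", "O"), ("北", "B-LOC"), ("京", "I-LOC")],
    [("他", "O"), ("在", "O"), ("上", "B-LOC"), ("海", "I-LOC")],
]
stats = label_stats(toy_sents)
print(stats["sentences"], stats["tokens"], stats["labels"]["O"])  # 2 8 4
```

On the real corpus, `label_stats(train_sents)["labels"]` would show how many tokens carry each tag.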

### Define features and convert the corpus

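Each character is described by the character itself, the two characters on each side, a constant `bias` feature, and `BOS`/`EOS` flags at the sentence boundaries.

```python
def word2features(sent, i):
    """Features for the i-th character: itself plus a window of two characters on each side."""
    word = sent[i][0]

    features = {
        'bias': 1.0,
        'word': word,
    }
    # Previous character, or a BOS flag at the start of the sentence
    if i > 0:
        features['-1:word'] = sent[i-1][0]
    else:
        features['BOS'] = True

    # Character two positions back
    if i > 1:
        features['-2:word'] = sent[i-2][0]

    # Next character, or an EOS flag at the end of the sentence
    if i < len(sent)-1:
        features['+1:word'] = sent[i+1][0]
    else:
        features['EOS'] = True

    # Character two positions ahead
    if i < len(sent)-2:
        features['+2:word'] = sent[i+2][0]

    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, label in sent]

def sent2tokens(sent):
    return [token for token, label in sent]
```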
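`word2features` hard-codes a window of two characters on each side. If you want to experiment with other window sizes, the same feature dicts can be produced by a loop over offsets. The sketch below is a hypothetical generalization (`window_features` is not part of sklearn_crfsuite); with `window=2` it yields the same keys and values as `word2features`.

```python
def window_features(sent, i, window=2):
    """Features for sent[i] using `window` characters on each side (hypothetical helper)."""
    features = {"bias": 1.0, "word": sent[i][0]}
    for k in range(1, window + 1):
        if i - k >= 0:
            features[f"-{k}:word"] = sent[i - k][0]  # k positions back
        if i + k < len(sent):
            features[f"+{k}:word"] = sent[i + k][0]  # k positions ahead
    if i == 0:
        features["BOS"] = True  # first character of the sentence
    if i == len(sent) - 1:
        features["EOS"] = True  # last character of the sentence
    return features


toy_sent = [("我", "O"), ("爱", "O"), ("北", "B-LOC"), ("京", "I-LOC")]
print(window_features(toy_sent, 0))
# {'bias': 1.0, 'word': '我', '+1:word': '爱', '+2:word': '北', 'BOS': True}
```

To try a wider window, the inputs could be rebuilt with, for example, `[[window_features(s, i, window=3) for i in range(len(s))] for s in train_sents]`.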

```python
X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]
```

### Training

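```python
%%time
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',              # L-BFGS optimizer
    c1=0.1,                         # L1 regularization coefficient
    c2=0.1,                         # L2 regularization coefficient
    max_iterations=100,
    all_possible_transitions=True   # also learn weights for label transitions unseen in training
)
crf.fit(X_train, y_train)
```

```
UnicodeEncodeError: 'ascii' codec can't encode characters in position 9-11: ordinal not in range(128)
```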
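The `UnicodeEncodeError` above is most likely not caused by the Chinese training data itself. By default sklearn-crfsuite trains into a temporary model file, and the underlying python-crfsuite bindings appear to require that file path to be ASCII. On Windows the default temp directory lives under the user folder (`C:\Users\<name>\...`), so a non-ASCII user name would fail right after the 9-character prefix `C:\Users\`, which matches "position 9-11" in the message. Two workarounds, both assumptions to verify on your own machine: pass an ASCII-only path through the `model_filename` argument of `sklearn_crfsuite.CRF`, or redirect Python's temporary directory before calling `fit`, as sketched below. `use_ascii_tempdir` and the `C:\crf_tmp` folder are hypothetical.

```python
import os
import tempfile


def use_ascii_tempdir(fallback_dir):
    """Point Python's tempfile module at an ASCII-only directory if needed.

    Assumption: the error comes from the temporary model file that
    sklearn-crfsuite creates in the default temp directory.
    """
    current = tempfile.gettempdir()
    if current.isascii():
        return current  # default temp path is already ASCII-only
    if not fallback_dir.isascii():
        raise ValueError("fallback_dir must be an ASCII-only path")
    os.makedirs(fallback_dir, exist_ok=True)
    tempfile.tempdir = fallback_dir  # later tempfile calls use this directory
    return fallback_dir


# Hypothetical fallback location; any writable ASCII-only folder should do
print(use_ascii_tempdir(r"C:\crf_tmp"))
```

Run this once before `crf.fit(...)`; `str.isascii` requires Python 3.7 or later.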

```python
# Transition features and their weights
crf.transition_features_
```

```
{('O', 'O'): 3.770691,
 ('O', 'B-LOC'): 0.188121,
 ('O', 'I-LOC'): -9.183171,
 ('O', 'B-ORG'): -0.359881,
 ('O', 'I-ORG'): -9.002781,
 ('O', 'B-PER'): -0.429625,
 ('O', 'I-PER'): -8.421054,
 ('B-LOC', 'O'): -1.288199,
 ('B-LOC', 'B-LOC'): -1.279092,
 ('B-LOC', 'I-LOC'): 3.899732,
 ('B-LOC', 'B-ORG'): -6.498161,
 ('B-LOC', 'I-ORG'): -8.78877,
 ('B-LOC', 'B-PER'): -7.099309,
 ('B-LOC', 'I-PER'): -6.864922,
 ('I-LOC', 'O'): -0.476086,
 ('I-LOC', 'B-LOC'): -1.090046,
 ('I-LOC', 'I-LOC'): 3.552982,
 ('I-LOC', 'B-ORG'): -5.550098,
 ('I-LOC', 'I-ORG'): -9.640868,
 ('I-LOC', 'B-PER'): -4.283422,
 ('I-LOC', 'I-PER'): -7.681232,
 ('B-ORG', 'O'): -3.27917,
 ('B-ORG', 'B-LOC'): -5.181805,
 ('B-ORG', 'I-LOC'): -5.533095,
 ('B-ORG', 'B-ORG'): -2.446341,
 ('B-ORG', 'I-ORG'): 4.579519,
 ('B-ORG', 'B-PER'): -4.818339,
 ('B-ORG', 'I-PER'): -5.451564,
 ('I-ORG', 'O'): -0.362844,
 ('I-ORG', 'B-LOC'): -2.382928,
 ('I-ORG', 'I-LOC'): -7.110184,
 ('I-ORG', 'B-ORG'): -4.371057,
 ('I-ORG', 'I-ORG'): 5.34612,
 …
```
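A linear-chain CRF scores a whole tag sequence as the sum of the state-feature weights that fire at each position plus the transition weights between consecutive tags, and prediction picks the highest-scoring sequence. The sketch below isolates the transition part, using a few weights copied from the output above, to show why the model practically never emits `I-LOC` right after `O`. `transition_score` is a hypothetical helper, and the state features (which also contribute to the real score) are deliberately left out.

```python
# Transition weights copied from the crf.transition_features_ output above
transitions = {
    ("O", "B-LOC"): 0.188121,
    ("O", "I-LOC"): -9.183171,
    ("B-LOC", "I-LOC"): 3.899732,
    ("I-LOC", "I-LOC"): 3.552982,
    ("I-LOC", "O"): -0.476086,
}


def transition_score(tags, transitions):
    """Sum the transition weights along one tag sequence (state features ignored)."""
    return sum(transitions[(prev, cur)] for prev, cur in zip(tags, tags[1:]))


valid = ["O", "B-LOC", "I-LOC", "O"]    # a well-formed BIO span
invalid = ["O", "I-LOC", "I-LOC", "O"]  # I-LOC without a preceding B-LOC
print(round(transition_score(valid, transitions), 6))    # 3.611767
print(round(transition_score(invalid, transitions), 6))  # -6.106275
```

The gap of about 9.7 between the two paths is what the state features would have to overcome for the model to output the ill-formed `O I-LOC` sequence.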
 

```python
# State features and their weights
crf.state_features_
```

```
{('bias', 'O'): 1.092168,
 ('bias', 'B-LOC'): -0.592062,
 ('bias', 'I-LOC'): -0.173155,
 ('bias', 'B-ORG'): -0.926908,
 ('bias', 'I-ORG'): -0.928595,
 ('bias', 'B-PER'): -1.443082,
 ('bias', 'I-PER'): -0.134878,
 ('word:当', 'O'): 3.415742,
 ('word:当', 'I-LOC'): -0.362896,
 ('word:当', 'B-ORG'): -0.241426,
 ('word:当', 'I-ORG'): -0.577074,
 ('word:当', 'I-PER'): 0.155251,
 ('BOS', 'O'): 2.336037,
 ('BOS', 'B-LOC'): 1.852237,
 ('BOS', 'B-ORG'): 1.68277,
 ('BOS', 'I-ORG'): -5.110783,
 ('BOS', 'B-PER'): 1.736904,
 ('BOS', 'I-PER'): -3.953563,
 ('+1:word:希', 'O'): -1.348802,
 ('+1:word:希', 'B-LOC'): 1.815173,
 ('+1:word:希', 'I-LOC'): -0.393703,
 ('+1:word:希', 'B-ORG'): 0.002865,
 ('+1:word:希', 'I-ORG'): 0.059517,
 ('+1:word:希', 'B-PER'): 0.257977,
 ('+1:word:希', 'I-PER'): -0.051414,
 ('+2:word:望', 'O'): 0.042823,
 ('+2:word:望', 'B-LOC'): -0.840629,
 ('+2:word:望', 'I-LOC'): -0.063721,
 ('+2:word:望', 'B-ORG'): -0.377318,
 ('+2:word:望', 'I-ORG'): 0.729842,
 ('+2:word:望', 'I-PER'): 0.481718,
 …
```
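The keys in `state_features_` show how the feature dicts from `word2features` become CRFsuite attribute names: a string value is folded into the name (`'word': '当'` appears as `word:当`), while numeric or boolean values such as `'bias': 1.0` and `'BOS': True` keep their plain key. The sketch below mimics that naming for flat dicts. It is a simplified illustration, not python-crfsuite's actual code (nested dicts and lists are not handled), and `flatten_features` is a hypothetical name.

```python
def flatten_features(features):
    """Map one feature dict to {attribute name: value}, mirroring state_features_ names.

    Strings become "key:value" with value 1.0; numbers and booleans keep the key
    and use the value itself (True -> 1.0).
    """
    attrs = {}
    for key, value in features.items():
        if isinstance(value, str):
            attrs[f"{key}:{value}"] = 1.0
        else:
            attrs[key] = float(value)
    return attrs


print(flatten_features({"bias": 1.0, "word": "当", "+1:word": "希", "BOS": True}))
# {'bias': 1.0, 'word:当': 1.0, '+1:word:希': 1.0, 'BOS': 1.0}
```

Each learned state weight is then indexed by one of these attribute names paired with a label, e.g. `('word:当', 'O')`.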

### Evaluation
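Label O is by far the most frequent tag in the dataset, but the other labels are what we care about, so below we compute the F1 score averaged over every label except O (weighted by each label's support). `sklearn_crfsuite.metrics` provides the functions for this.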
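To make `average='weighted'` concrete: the sentences are first flattened into one long token list, per-label precision, recall and F1 are computed on those tokens, and the per-label F1 scores are averaged with each label's support (number of true tokens) as its weight. The plain-Python sketch below follows that recipe for illustration only; `flat_weighted_f1`, `toy_true` and `toy_pred` are hypothetical, and labels that are never predicted or never present are assumed to score 0.

```python
from collections import Counter


def flat_weighted_f1(y_true, y_pred, labels):
    """Support-weighted token-level F1 over `labels`, after flattening all sentences."""
    true_flat = [tag for seq in y_true for tag in seq]
    pred_flat = [tag for seq in y_pred for tag in seq]
    true_count = Counter(true_flat)
    pred_count = Counter(pred_flat)
    tp = Counter(t for t, p in zip(true_flat, pred_flat) if t == p)

    total_support = sum(true_count[label] for label in labels)
    if total_support == 0:
        return 0.0
    weighted_sum = 0.0
    for label in labels:
        precision = tp[label] / pred_count[label] if pred_count[label] else 0.0
        recall = tp[label] / true_count[label] if true_count[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        weighted_sum += f1 * true_count[label]  # weight = support of this label
    return weighted_sum / total_support


# Hypothetical toy predictions for two short sentences
toy_true = [["B-LOC", "I-LOC", "O"], ["B-PER", "O"]]
toy_pred = [["B-LOC", "O", "O"], ["B-PER", "B-PER"]]
print(flat_weighted_f1(toy_true, toy_pred, labels=["B-LOC", "I-LOC", "B-PER"]))
```

For the toy data this gives 5/9 ≈ 0.556: B-LOC is perfect (F1 = 1), I-LOC is never predicted (F1 = 0), and B-PER has precision 0.5 and recall 1 (F1 = 2/3), each with support 1.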

```python
labels = list(crf.classes_)
labels.remove('O')
labels
```

```
['B-LOC', 'I-LOC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']
```

```python
y_pred = crf.predict(X_test)
metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels)
```

```
0.8362230945945955
```

```python
# Group the B- and I- labels of each entity type together
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))
```

```
              precision    recall  f1-score   support

       B-LOC      0.907     0.840     0.872      2877
       I-LOC      0.882     0.827     0.854      4394
       B-ORG      0.775     0.719     0.746      1331
       I-ORG      0.805     0.784     0.794      5670
       B-PER      0.932     0.770     0.843      1973
       I-PER      0.890     0.867     0.878      3851

   micro avg      0.862     0.812     0.836     20096
   macro avg      0.865     0.801     0.831     20096
weighted avg      0.864     0.812     0.836     20096
```
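The sort key `(name[1:], name[0])` is what places each B- label next to its I- partner: `name[1:]` is the entity type with its leading hyphen (for example `'-LOC'`), and `name[0]` breaks ties so that B comes before I. The sketch below starts from a hypothetical alphabetical ordering (`toy_labels` is illustrative) to make the regrouping visible.

```python
toy_labels = ["B-LOC", "B-ORG", "B-PER", "I-LOC", "I-ORG", "I-PER"]

# Sort by entity type first, then by the B/I prefix letter
grouped = sorted(toy_labels, key=lambda name: (name[1:], name[0]))
print(grouped)
# ['B-LOC', 'I-LOC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']
```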



ELI5 is a Python package that can inspect the weights of a `sklearn_crfsuite.CRF` model.

Installation: `pip install eli5`

```python
!pip install eli5
```

```
Collecting eli5
  Downloading eli5-0.11.0-py2.py3-none-any.whl (106 kB)
Installing collected packages: eli5
Successfully installed eli5-0.11.0
```


```python
import eli5
eli5.show_weights(crf, top=20)
```

To keep the output readable, we inspect only a subset of the labels.

```python
eli5.show_weights(crf, top=10, targets=['O', 'B-LOC', 'B-ORG'])
```

| From \ To | O | B-LOC | B-ORG |
|---|---|---|---|
| O | 3.771 | 0.188 | -0.36 |
| B-LOC | -1.288 | -1.279 | -6.498 |
| B-ORG | -3.279 | -5.182 | -2.446 |

y=O top features:

| Weight | Feature |
|---|---|
| +9.722 | word:、 |
| +9.559 | word:， |
| +7.435 | word:在 |
| +7.263 | word:到 |
| +7.142 | word:是 |
| +7.001 | word:等 |
| +6.875 | word:某 |
| +6.674 | word:从 |
| … 8528 more positive … | |
| … 5779 more negative … | |

y=B-LOC top features:

| Weight | Feature |
|---|---|
| +9.690 | word:淮 |
| +8.211 | word:陕 |
| +8.117 | word:湘 |
| +6.231 | word:蜀 |
| +6.182 | word:浙 |
| +5.992 | word:俄 |
| +5.931 | word:漯 |
| +5.733 | word:闽 |
| +5.662 | word:柬 |
| … 3728 more positive … | |

y=B-ORG top features:

| Weight | Feature |
|---|---|
| +5.115 | +1:word:汽 |
| +4.682 | word:陕 |
| +4.457 | +1:word:两 |
| +4.327 | -2:word:佰 |
| +4.162 | +1:word:K |
| +4.059 | word:央 |
| +4.020 | word:参 |
| +3.914 | word:扬 |
| … 2891 more positive … | |
| … 1636 more negative … | |
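The transition table that eli5 prints matches `crf.transition_features_` shown earlier, and its per-label lists should likewise come from the weights exposed by `crf.state_features_`. If eli5 is unavailable, a rough plain-Python stand-in for the per-label "top features" view is to filter that dict by label and sort by weight. `top_state_features` and `sample` are hypothetical; the sample entries are copied from the `crf.state_features_` output above. Unlike eli5, this sketch lists only the highest weights and does not summarize the remaining positive and negative features.

```python
def top_state_features(state_features, target, top=10):
    """Return the `top` highest-weighted (feature, weight) pairs for one label.

    `state_features` has the shape of crf.state_features_: {(feature, label): weight}.
    """
    pairs = [(feat, w) for (feat, label), w in state_features.items() if label == target]
    pairs.sort(key=lambda fw: fw[1], reverse=True)
    return pairs[:top]


# A few entries copied from the crf.state_features_ output above
sample = {
    ("bias", "O"): 1.092168,
    ("word:当", "O"): 3.415742,
    ("BOS", "O"): 2.336037,
    ("bias", "B-LOC"): -0.592062,
    ("BOS", "B-LOC"): 1.852237,
    ("+1:word:希", "B-LOC"): 1.815173,
}
print(top_state_features(sample, "O", top=2))
# [('word:当', 3.415742), ('BOS', 2.336037)]
```

On the trained model, `top_state_features(crf.state_features_, 'B-LOC', top=10)` would give a plain-list counterpart of the B-LOC table.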
