In [1]:
import re
from pyhanlp import *

def get_sents(path):
    #'../data/train.txt';'../data/dev.txt'
    sentences = []
    sentence = []
    cnt = 0
    split_pattern = re.compile(r',|\.|;|，|。|；|\?|\!|\.\.\.\.\.\.|……')
    with open(path,'r',encoding = 'utf8') as f:
        for line in f.readlines():#每行为一个字符和其tag，中间用tab隔开
            line = line.strip().split('\t')
            if(not line or len(line) < 2): continue
            if line[1] == 'OO':
                line[1] = 'O'
            word_unit = [line[0],line[1]]
            if split_pattern.match(word_unit[0]):
                sentence.append(word_unit)
                sent = ''.join((word_unit[0] for word_unit in sentence))
                nature_list = []
                for term in HanLP.segment(sent):
                    for i in range(len(term.word)):# 分词
                        nature = '{}'.format(term.nature)
                        nature_list.append(nature)
                for idx,word_unit in enumerate(sentence):
                    word_unit.insert(1,nature_list[idx]) # insert损失一些性能
                sentences.append(sentence.copy())
                sentence.clear()
            else:
                sentence.append(word_unit)
        if(len(sentence)):
            sent = ''.join((word_unit[0] for word_unit in sentence))
            nature_list = []
            for term in HanLP.segment(sent):
                for i in range(len(term.word)):  # 分词
                    nature = '{}'.format(term.nature)
                    nature_list.append(nature)
            for idx, word_unit in enumerate(sentence):
                word_unit.insert(1, nature_list[idx])  # insert损失一些性能
            sentences.append(sentence.copy())
            sentence.clear()
    return sentences

In [2]:
train_sentences = get_sents('../data/train.txt')
train_sents = train_sentences
test_sents = get_sents('../data/dev.txt')

In [3]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import sklearn_crfsuite
from itertools import chain
import sklearn
import scipy.stats
from sklearn.metrics import make_scorer
from sklearn.cross_validation import cross_val_score
from sklearn.grid_search import RandomizedSearchCV
import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics



In [7]:
def word2features(sent, i):
    word = sent[i][0]
    postag = sent[i][1]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word.isupper()': word.isupper(), #对于中文数据集来说，这两条用处不大
        'word.isdigit()': word.isdigit(),
        'postag': postag, 
#         'postag[:2]': postag[:2], #感觉这条特征没意义
    }
    if i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.isupper()': word1.isupper(),
             '-1:word.isdigit()': word.isdigit(),
            '-1:postag': postag1,
#             '-1:postag[:2]': postag1[:2],
        })
    else:
        features['BOS'] = True

    if i < len(sent)-1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.isupper()': word1.isupper(),
            '+1:postag': postag1,
#             '+1:postag[:2]': postag1[:2],
        })
    else:
        features['EOS'] = True

    return features


def sent2features(sent): #一个句子对应了一个序列
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
#     return [label for token,label in sent]
    return [label for token, postag, label in sent]

def sent2tokens(sent):
    return [token for token,label in sent]

In [8]:
%%time
X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

CPU times: user 15.9 s, sys: 4.32 s, total: 20.2 s
Wall time: 23.7 s


In [9]:
%%time
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_transitions=True
)
crf.fit(X_train, y_train)

CPU times: user 5min 14s, sys: 16.5 s, total: 5min 30s
Wall time: 6min 16s


In [10]:
labels = list(crf.classes_)
labels.remove('O')
labels

['B-LOC', 'I-LOC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']

In [11]:
y_pred = crf.predict(X_test)

In [12]:
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

             precision    recall  f1-score   support

      B-LOC      0.921     0.754     0.829      2877
      I-LOC      0.883     0.734     0.801      4394
      B-ORG      0.889     0.566     0.691      1331
      I-ORG      0.884     0.653     0.751      5670
      B-PER      0.947     0.646     0.768      1973
      I-PER      0.836     0.913     0.873      3851

avg / total      0.886     0.729     0.794     20096



In [13]:
from seqeval.metrics import classification_report
from functools import reduce
y_test = reduce(lambda x,y : x + y, y_test)
y_pred = reduce(lambda x,y : x + y, y_pred)
print(classification_report(y_test,y_pred,digits=4))

             precision    recall  f1-score   support

        LOC     0.8457    0.7070    0.7702      2877
        PER     0.8376    0.5753    0.6821      1973
        ORG     0.8406    0.5349    0.6538      1331

avg / total     0.8420    0.6279    0.7170      6181



In [42]:
X_test[0]

[{'bias': 1.0,
  'word.lower()': '中',
  'word.isupper()': False,
  'word.isdigit()': False,
  'postag': 'nt',
  'BOS': True,
  '+1:word.lower()': '共',
  '+1:word.isupper()': False,
  '+1:postag': 'nt'},
 {'bias': 1.0,
  'word.lower()': '共',
  'word.isupper()': False,
  'word.isdigit()': False,
  'postag': 'nt',
  '-1:word.lower()': '中',
  '-1:word.isupper()': False,
  '-1:word.isdigit()': False,
  '-1:postag': 'nt',
  '+1:word.lower()': '中',
  '+1:word.isupper()': False,
  '+1:postag': 'nt'},
 {'bias': 1.0,
  'word.lower()': '中',
  'word.isupper()': False,
  'word.isdigit()': False,
  'postag': 'nt',
  '-1:word.lower()': '共',
  '-1:word.isupper()': False,
  '-1:word.isdigit()': False,
  '-1:postag': 'nt',
  '+1:word.lower()': '央',
  '+1:word.isupper()': False,
  '+1:postag': 'nt'},
 {'bias': 1.0,
  'word.lower()': '央',
  'word.isupper()': False,
  'word.isdigit()': False,
  'postag': 'nt',
  '-1:word.lower()': '中',
  '-1:word.isupper()': False,
  '-1:word.isdigit()': False,
  '-1:posta