In [17]:
# !pip install spacy -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install gensim -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install scikit-learn==0.24.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install tensorflow==1.12.3 -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install keras==2.2.4 -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install cnradical==0.1.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install jieba==0.42.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install numpy==1.16.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install pandas==0.25.3 -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install tqdm==4.39.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install keras-self-attention==0.49.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

In [18]:
# 导入所需文件
import numpy as np
from sklearn.model_selection import ShuffleSplit
from data_utils import ENTITIES, Documents, Dataset, SentenceExtractor, make_predictions
from data_utils import Evaluator
from gensim.models import Word2Vec

In [19]:
# Training-data location and the entity <-> index lookup tables
data_dir = "./data/train"
# Entity labels are numbered starting at 1 (index 0 is not assigned to any entity here).
ent2idx = {entity: idx for idx, entity in enumerate(ENTITIES, start=1)}
# Reverse lookup: index -> entity label.
idx2ent = {idx: entity for entity, idx in ent2idx.items()}

In [20]:
# Load the documents and hold out 20 of them as a test set using one seeded shuffle split
docs = Documents(data_dir=data_dir)
rs = ShuffleSplit(n_splits=1, test_size=20, random_state=2018)
# n_splits=1, so the first (and only) split is the one we use.
train_doc_ids, test_doc_ids = next(rs.split(docs))
train_docs = docs[train_doc_ids]
test_docs = docs[test_doc_ids]

In [21]:
# Model hyper-parameters and dataset construction
num_cates = max(ent2idx.values()) + 1  # one more than the largest entity index
sent_len, sent_pad = 64, 10
vocab_size = 3000  # value handed to build_vocab_dict; replaced below by the real vocabulary size
emb_size = 100

# Cut the documents into sentence windows (train first, then test).
sent_extrator = SentenceExtractor(window_size=sent_len, pad_size=sent_pad)
train_sents, test_sents = (sent_extrator(doc_set) for doc_set in (train_docs, test_docs))

# The vocabulary is built from the training sentences only and reused for the test set.
train_data = Dataset(train_sents, cate2idx=ent2idx)
train_data.build_vocab_dict(vocab_size=vocab_size)
test_data = Dataset(test_sents, word2idx=train_data.word2idx, cate2idx=ent2idx)

# From here on vocab_size is the number of entries actually present in word2idx.
vocab_size = len(train_data.word2idx)

In [22]:
# Train word2vec on the documents and copy the vectors into an embedding matrix.
# Each document's text is split with list(), i.e. into single characters if it is a str.
w2v_train_sents = [list(doc.text) for doc in docs]
w2v_model = Word2Vec(w2v_train_sents, vector_size=emb_size)

# One row per vocabulary index; rows whose token word2vec does not know stay all-zero.
w2v_embeddings = np.zeros((vocab_size, emb_size))
for token, row in train_data.word2idx.items():
    if token not in w2v_model.wv:
        continue
    w2v_embeddings[row] = w2v_model.wv[token]

In [23]:
# Install keras-contrib from source (run the following in a shell):
# git clone https://github.com/keras-team/keras-contrib.git
# cd keras-contrib
# pip install .

In [24]:
# 构建概率图模型——条件随机场
import keras
from keras.layers import Input, Embedding
from keras_contrib.layers import CRF
from keras.models import Model
def build_crf_model(num_cates, seq_len, vocab_size, model_opts=None):
    """Build and compile an Embedding -> CRF sequence-labelling model.

    Args:
        num_cates: Number of output categories passed to the CRF layer.
        seq_len: Length of each input sequence (the Input shape is ``(seq_len,)``).
        vocab_size: Number of rows in the embedding table.
        model_opts: Optional dict overriding the defaults listed below.
            ``None`` (the default) means "no overrides".

            - ``'emb_size'`` (256): embedding dimension.
            - ``'emb_matrix'`` (None): initial embedding weights; when given
              it is loaded into the Embedding layer via ``weights=``.
            - ``'emb_trainable'`` (True): passed as ``trainable=`` to the
              Embedding layer. Only applied when ``'emb_matrix'`` is given;
              without a matrix the layer uses Keras' default.
            - ``'optimizer'`` (None): optimizer handed to ``compile``; a fresh
              ``keras.optimizers.Adam()`` is created when it is unset.

    Returns:
        The compiled Keras ``Model`` mapping the integer input sequence to the
        CRF output.
    """
    opts = {
        'emb_size': 256,
        'emb_trainable': True,
        'emb_matrix': None,
        'optimizer': None,
    }
    # A None default (instead of a shared mutable dict()) keeps calls independent
    # and lets callers pass None explicitly.
    opts.update(model_opts or {})
    # Only construct the default optimizer when the caller did not supply one.
    if opts['optimizer'] is None:
        opts['optimizer'] = keras.optimizers.Adam()

    input_seq = Input(shape=(seq_len,), dtype='int32')
    if opts['emb_matrix'] is not None:
        embedding = Embedding(vocab_size, opts['emb_size'],
                              weights=[opts['emb_matrix']],
                              trainable=opts['emb_trainable'])
    else:
        embedding = Embedding(vocab_size, opts['emb_size'])
    x = embedding(input_seq)

    # sparse_target=True: labels are supplied as integer indices (see train_y's trailing
    # dimension of 1 in this notebook) rather than one-hot vectors.
    crf = CRF(num_cates, sparse_target=True)
    output = crf(x)

    model = Model(input_seq, output)
    # The CRF layer provides its own loss and accuracy functions.
    model.compile(opts['optimizer'], loss=crf.loss_function, metrics=[crf.accuracy])
    return model

In [25]:
# Instantiate the CRF model on top of the frozen word2vec embeddings
# Input length = sentence window plus sent_pad on both sides.
seq_len = sent_len + 2 * sent_pad
model = build_crf_model(
    num_cates,
    seq_len=seq_len,
    vocab_size=vocab_size,
    model_opts={
        'emb_matrix': w2v_embeddings,
        # Reuse the emb_size setting (instead of a hard-coded 100) so the layer
        # width always matches the width w2v_embeddings was built with.
        'emb_size': emb_size,
        'emb_trainable': False,
    },
)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 84)                0         
_________________________________________________________________
embedding_2 (Embedding)      (None, 84, 100)           215400    
_________________________________________________________________
crf_2 (CRF)                  (None, 84, 16)            1904      
Total params: 217,304
Trainable params: 1,904
Non-trainable params: 215,400
_________________________________________________________________


In [26]:
# Materialise the training arrays and print their shapes
train_X, train_y = train_data[:]
for label, arr in (('train_X.shape', train_X), ('train_y.shape', train_y)):
    print(label, arr.shape)

train_X.shape (2622, 84)
train_y.shape (2622, 84, 1)


In [27]:
# Train the CRF model on the training arrays
model.fit(
    x=train_X,
    y=train_y,
    batch_size=64,
    epochs=10,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x176196954e0>

In [28]:
# Predict on the test sentences and convert the raw output with make_predictions
test_X, _ = test_data[:]  # only the inputs are needed for prediction
preds = model.predict(
    x=test_X,
    batch_size=64,
    verbose=True,
)
pred_docs = make_predictions(preds, test_data, sent_pad, docs, idx2ent)



In [29]:
# Score the predictions against test_docs and print F1, precision and recall
f_score, precision, recall = Evaluator.f1_score(test_docs, pred_docs)
for label, value in (('f_score: ', f_score),
                     ('precision: ', precision),
                     ('recall: ', recall)):
    print(label, value)

f_score:  0.44766793126534254
precision:  0.5110177047509872
recall:  0.3982924650054601


In [None]:
# Display the test document belonging to the fourth key of pred_docs
sample_doc_id = [*pred_docs.keys()][3]
test_docs[sample_doc_id]

In [32]:
# Display the embedding row stored at index 1 of w2v_embeddings
w2v_embeddings[1]

array([ 0.19941245, -0.26504546,  0.15430012, -0.15829962, -0.10389724,
        0.19118346, -0.11953872, -0.76132911, -0.26291174, -0.38873449,
        0.01347194, -0.2559436 ,  0.15380643,  0.26135439, -0.25552344,
        0.29242265,  0.09382474,  0.22370352,  0.02202507, -0.07514565,
        0.82636875, -0.52704293,  0.01793255, -0.79269636,  0.35236117,
       -0.02104079, -0.04055708, -0.18536669,  0.12498295, -0.03320423,
       -0.1734913 ,  0.11356194,  0.38451925,  0.07955495,  0.49044865,
       -0.40859443, -0.1679301 , -0.34706986, -0.37893501, -0.35810554,
        0.13111269,  0.12016071,  0.08605853, -0.15480822,  0.22341111,
       -0.37819988, -0.01085286,  0.05575052, -0.53235483, -0.24436882,
        0.11880824,  0.01603009, -0.42114711,  0.04289829,  0.0078109 ,
       -0.18823852, -0.0396415 , -0.12771091,  0.02638644, -0.2572166 ,
       -0.15825768,  0.8555066 , -0.02068132, -0.23514207, -0.28015685,
        0.65361744,  0.06349469,  0.34111962, -0.79231268,  0.38