In [1]:
from kashgari.embeddings import BERTEmbedding

# Maximum sequence length (in tokens) given to the BERT embedding; the model
# summary printed during training shows inputs of shape (None, 100).
MAX_SEQ_LEN = 100
embedding = BERTEmbedding('bert-base-chinese', MAX_SEQ_LEN)

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
import sys

# Make the project's `src` directory importable so `config` resolves.
sys.path.extend(['../src'])
from config import Config

config = Config()
config.train_path = '../source/data/ner_data/ner_train.txt'
# Fail fast if the training file is missing. The previous version left this
# handle open for the kernel's lifetime without ever reading it; close it now.
with open(config.train_path, encoding='UTF-8') as f:
    pass

In [3]:
config.vocab_path = '../src/model/vocab.txt'

def build_data():
    """Read the tab-separated NER training file into sentence samples and a vocabulary.

    Each line of ``config.train_path`` is split on tabs. The first field is taken
    as the character and the last field as its tag; lines whose first field is
    empty are skipped. Characters are grouped into samples that end right after
    a sentence terminator (。 ? ! ！ ？). Characters after the last terminator
    form a final sample instead of being dropped.

    The vocabulary (every character seen, plus 'UNK') is written one token per
    line to ``config.vocab_path``, in sorted order.

    Returns:
        datas: list of ``[chars, tags]`` pairs (two parallel lists of strings).
        word_dict: token -> index mapping, in the same order as the vocab file.
    """
    terminators = {'。', '?', '!', '！', '？'}
    datas = []
    sample_x = []
    sample_y = []
    vocabs = {'UNK'}
    # `with` closes the training file; the old bare `open()` in the for-loop leaked it.
    with open(config.train_path, encoding='UTF-8') as train_file:
        for raw_line in train_file:
            fields = raw_line.rstrip().split('\t')
            char = fields[0]
            if not char:  # blank line (split() never yields an empty list)
                continue
            cate = fields[-1]
            sample_x.append(char)
            sample_y.append(cate)
            vocabs.add(char)
            if char in terminators:
                datas.append([sample_x, sample_y])
                sample_x = []
                sample_y = []
    # Keep a trailing sample that did not end with a terminator.
    if sample_x:
        datas.append([sample_x, sample_y])
    # Sort so indices and vocab.txt are stable: iteration order of a set of str
    # changes between interpreter runs because of hash randomization.
    vocab_list = sorted(vocabs)
    word_dict = {wd: index for index, wd in enumerate(vocab_list)}
    with open(config.vocab_path, 'w', encoding='UTF-8') as vocab_file:
        vocab_file.write('\n'.join(vocab_list))
    return datas, word_dict

In [4]:
# Load sentence-level [chars, tags] samples and the char -> index vocabulary
# (build_data also writes the vocabulary to config.vocab_path).
datas, word_dict = build_data()

In [6]:
import random

# Tag -> id mapping; presumably matches the tags in the training file (TODO confirm).
# Not referenced by the cells shown here.
class_dict = {
    'O': 0,
    'B-TREATMENT': 1,
    'I-TREATMENT': 2,
    'B-BODY': 3,
    'I-BODY': 4,
    'B-SIGNS': 5,
    'I-SIGNS': 6,
    'B-CHECK': 7,
    'I-CHECK': 8,
    'B-DISEASE': 9,
    'I-DISEASE': 10,
}

# Seeded, private RNG so the shuffle (and hence the train/validation split) is
# reproducible without touching the global `random` state.
# Note: this shuffles `datas` in place, so re-running this cell alone reshuffles
# the already-shuffled list. Re-run the build_data() cell first.
SEED = 42
random.Random(SEED).shuffle(datas)

# Per-sample character sequences (model inputs) and tag sequences (targets).
x = [list(chars) for chars, _ in datas]
y = [list(labels) for _, labels in datas]

In [7]:
validation_split = 0.2

# Hold out the last `validation_split` fraction of the (already shuffled)
# samples for validation and train on the rest. The previous slicing trained on
# only 20% of the data, validated on ~80%, and skipped one sample via `+1`.
n_valid = int(len(x) * validation_split)
split_index = len(x) - n_valid

x_train, y_train = x[:split_index], y[:split_index]
x_valid, y_valid = x[split_index:], y[split_index:]

assert len(x_train) + len(x_valid) == len(x), "split must not drop samples"
assert len(x_train) == len(y_train) and len(x_valid) == len(y_valid)

In [8]:
from kashgari.tasks.seq_labeling import BLSTMCRFModel

# `BLSTMModel` and `CNNLSTMModel` are alternative choices for this task.

EPOCHS = 5
BATCH_SIZE = 500

model = BLSTMCRFModel(embedding)
model.fit(
    x_train,
    y_train,
    x_validate=x_valid,
    y_validate=y_valid,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Input-Token (InputLayer)        (None, 100)          0                                            
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, 100)          0                                            
__________________________________________________________________________________________________
Embedding-Token (TokenEmbedding [(None, 100, 768), ( 16226304    Input-Token[0][0]                
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, 100, 768)     1536        Input-Segment[0][0]              
__________________________________________________________________________________________________
Embedding-

__________________________________________________________________________________________________
Encoder-4-FeedForward (FeedForw (None, 100, 768)     4722432     Encoder-4-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-4-FeedForward-Dropout ( (None, 100, 768)     0           Encoder-4-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-4-FeedForward-Add (Add) (None, 100, 768)     0           Encoder-4-MultiHeadSelfAttention-
                                                                 Encoder-4-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-4-FeedForward-Norm (Lay (None, 100, 768)     1536        Encoder-4-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-5-

Encoder-9-MultiHeadSelfAttentio (None, 100, 768)     2362368     Encoder-8-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 100, 768)     0           Encoder-9-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 100, 768)     0           Encoder-8-FeedForward-Norm[0][0] 
                                                                 Encoder-9-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 100, 768)     1536        Encoder-9-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-9-FeedForward (FeedForw (None, 100, 768)     4722432     Encoder-9-MultiHeadSelfAttention-
__________

Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [18]:
def build_input(text):
    """Split raw text into sentences, each given as a list of single characters.

    A sentence ends right after any of the terminators 。 ? ! ！ ？. Characters
    after the last terminator form a final sentence, so concatenating the
    result always reproduces ``text`` exactly.

    Args:
        text: raw input string.

    Returns:
        List of lists of characters; ``[]`` for an empty string.
    """
    terminators = {'。', '?', '!', '！', '？'}
    datas = []
    x = []
    for char in text:
        x.append(char)
        if char in terminators:
            datas.append(x)
            x = []
    # Flush the tail. The old check `text.index(char) == len(text)-1` looked at
    # the FIRST occurrence of the last character, so the tail was silently lost
    # whenever that character also appeared earlier in the text.
    if x:
        datas.append(x)
    return datas

In [28]:
text = '1.患者老年女性，88岁；2.既往体健，否认药物过敏史。3.患者缘于5小时前不慎摔伤，伤及右髋部。伤后患者自感伤处疼痛，呼我院120接来我院，查左髋部X光片示：左侧粗隆间骨折。给予补液等对症治疗。患者病情平稳，以左侧粗隆间骨折介绍入院。患者自入院以来，无发热，无头晕头痛，无恶心呕吐，无胸闷心悸，饮食可，小便正常，未排大便。4.查体：T36.1C，P87次/分，R18次/分，BP150/93mmHg,心肺查体未见明显异常，专科情况：右下肢短缩畸形约2cm，右髋部外旋内收畸形，右髋部压痛明显，叩击痛阳性,右髋关节活动受限。右足背动脉波动好，足趾感觉运动正常。5.辅助检查：本院右髋关节正位片：右侧股骨粗隆间骨折。'

def predict(text, model_path='../src/model/bert_model_20.h5'):
    """Tag every character of `text` with a saved BLSTM-CRF model and print the pairs.

    Args:
        text: raw input string; it is split into sentences with `build_input`.
        model_path: saved model file to load. The default is the path that was
            previously hard-coded, so existing calls behave the same.

    Prints a list of (char, tag) tuples and returns None.
    """
    # Use the model we just loaded. The old code loaded it, discarded it, and
    # fell back to the global training-time `model`.
    new_model = BLSTMCRFModel.load_model(model_path)
    x_test = build_input(text)
    result = new_model.predict(x_test)
    chars = list(text)
    # Concatenate the per-sentence tag lists. The old loop referenced an
    # undefined `result_` and overwrote `tags` on every iteration.
    # NOTE(review): assumes predict returns one tag list per input sentence of
    # the same length. Sentences longer than the embedding's sequence length
    # may come back truncated, which would misalign the zip below. TODO confirm.
    tags = [tag for sentence_tags in result for tag in sentence_tags]
    res = list(zip(chars, tags))
    print(res)