<a href="https://colab.research.google.com/github/liuhuiaren0524/albert_ner/blob/main/albert_ner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#作用：加载NER的预训练模型
import os
import random
import numpy as np
# import pandas as pd
from bert4keras.backend import keras, K   #bert4keras基于keras的预训练模型加载框架，支持多种预训练模型（BERT、ALBERT、RoBERTa、ALBERT、NEZHA、GPT2、T5等）
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open, ViterbiDecoder
from bert4keras.layers import ConditionalRandomField
from keras.layers import Dense
from keras.models import Model
from tqdm import tqdm

# Restrict CUDA to the GPU with index 2 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# to_array is defined explicitly below because of a bert4keras version
# mismatch; the import it stands in for is kept here for reference:
# from bert4keras.snippets import to_array

def to_array(*args):
    """Convert every positional argument to a numpy array.

    Returns the array itself when exactly one argument is given; otherwise
    returns a list holding one array per argument (an empty list when called
    with no arguments).
    """
    arrays = list(map(np.array, args))
    if len(arrays) != 1:
        return arrays
    return arrays[0]

# Hyperparameters
# NOTE(review): maxlen is not referenced in the visible code (recognize()
# truncates at a hard-coded 512) — presumably used by training code elsewhere.
maxlen = 256
epochs = 10
batch_size = 32
bert_layers = 12
# NOTE(review): the name is misspelled ("learing"); left unchanged because
# code outside this view may reference it.
learing_rate = 1e-5  # the smaller bert_layers is, the larger the learning rate should be
crf_lr_multiplier = 1000  # enlarge the CRF layer's learning rate when necessary


# Pre-trained checkpoint configuration
# NOTE(review): these paths name a BERT-base directory while the model below is
# built with model='albert' — confirm the files are actually ALBERT weights.
config_path = '/home/BERT/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/home/BERT/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/home/BERT/chinese_L-12_H-768_A-12/vocab.txt'


# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
# Category mapping
labels = ['CAR'] # list of entity labels
id2label = dict(enumerate(labels))
label2id = {j: i for i, j in id2label.items()}
# Two tag ids per entity label plus one: recognize() treats 0 as "outside",
# odd ids as the start of an entity and even ids as its continuation.
num_labels = len(labels) * 2 + 1



# ALBERT variant (active). The encoder is built with model='albert' and its
# output is read from the 'Transformer-FeedForward-Norm' layer at call index
# bert_layers - 1 (presumably because ALBERT reuses one layer for every depth
# — confirm against bert4keras). For a BERT-style checkpoint, use the
# alternative kept in the string literal below instead.
model = build_transformer_model(
    config_path,
    checkpoint_path,
    model='albert',
)
output_layer = 'Transformer-FeedForward-Norm'
output = model.get_layer(output_layer).get_output_at(bert_layers - 1)

# BERT-style alternative, disabled by being wrapped in a bare string literal:
# it selects the per-depth layer 'Transformer-<bert_layers - 1>-FeedForward-Norm'.
"""
model = build_transformer_model(
    config_path,
    checkpoint_path,
)

output_layer = 'Transformer-%s-FeedForward-Norm' % (bert_layers - 1)
output = model.get_layer(output_layer).output
"""
# Project the encoder output to num_labels scores, then apply the CRF layer.
output = Dense(num_labels)(output)
CRF = ConditionalRandomField(lr_multiplier=crf_lr_multiplier)
output = CRF(output)

# Rebind `model` to the full network (encoder + Dense + CRF); recognize()
# calls predict() on this object.
model = Model(model.input, output)
# model.summary()


class NamedEntityRecognizer(ViterbiDecoder):
    """Named-entity recognizer that decodes the model's per-token scores.

    Uses the module-level ``tokenizer``, ``model``, ``id2label`` and
    ``to_array``.
    """

    def recognize(self, text, location=False):
        """Extract the entities found in ``text``.

        Args:
            text: Raw input string.
            location: When true, each entity is rendered as
                ``"<surface>(<start>:<end>)"`` using character offsets into
                ``text`` (end exclusive); otherwise only the surface string
                is returned.

        Returns:
            A list of ``(entity, label)`` tuples in order of appearance.
        """
        tokens = tokenizer.tokenize(text)
        # Cap the input at 512 tokens by repeatedly dropping the
        # second-to-last token, so the final token always survives.
        while len(tokens) > 512:
            del tokens[-2]
        mapping = tokenizer.rematch(text, tokens)
        token_ids = tokenizer.tokens_to_ids(tokens)
        segment_ids = [0] * len(token_ids)
        token_ids, segment_ids = to_array([token_ids], [segment_ids])
        scores = model.predict([token_ids, segment_ids])[0]
        tag_ids = self.decode(scores)

        # Tag scheme: 0 = outside, odd id = start of an entity,
        # positive even id = continuation of the entity being built.
        spans = []
        inside = False
        for pos, tag in enumerate(tag_ids):
            if tag > 0 and tag % 2 == 1:
                inside = True
                spans.append([[pos], id2label[(tag - 1) // 2]])
            elif tag > 0 and inside:
                spans[-1][0].append(pos)
            else:
                # Outside tag, or a continuation with no open entity.
                inside = False

        # Map token positions back to character offsets in the original text.
        found = []
        for positions, label in spans:
            begin = mapping[positions[0]][0]
            end = mapping[positions[-1]][-1] + 1
            surface = text[begin:end]
            if location:
                surface = '{}({}:{})'.format(surface, begin, end)
            found.append((surface, label))
        return found


# Module-level recognizer used by evaluate(). The CRF transition matrix is
# evaluated to a concrete value once, at this point.
# NOTE(review): if the CRF weights are trained or loaded after this line, the
# decoder's transition matrix presumably has to be refreshed — confirm.
# starts=[0] / ends=[0] presumably restrict the first and last positions to
# tag 0 — confirm against ViterbiDecoder.
NER = NamedEntityRecognizer(trans=K.eval(CRF.trans), starts=[0], ends=[0])


def evaluate(data, location=False):
    """Score the module-level ``NER`` recognizer against labelled samples.

    Args:
        data: Iterable of samples. Each sample is a sequence of
            ``(segment_text, label)`` pairs whose texts concatenate to the
            full sentence; segments labelled ``'O'`` are not entities.
        location: When true, entities are compared together with their
            character span, rendered as ``"<text>(<start>:<end>)"``;
            otherwise only ``(text, label)`` pairs are compared.

    Returns:
        Tuple ``(f1, precision, recall)`` accumulated over all samples. When
        ``data`` is empty all three values evaluate to 1.0, because only the
        smoothing seeds take part in the divisions.

    Side effects:
        Prints the text, predicted set and gold set of the last sample, if
        at least one sample was processed.
    """
    # Tiny non-zero seeds keep the divisions below defined when a count is 0.
    matched, predicted, gold = 1e-10, 1e-10, 1e-10
    last_sample = None
    for d in tqdm(data):
        text = ''.join(seg[0] for seg in d)
        pred_set = set(NER.recognize(text, location=location))
        if location:
            gold_set = set()
            offset = 0
            for w, l in d:
                end = offset + len(w)
                if l != 'O':
                    gold_set.add(('{}({}:{})'.format(w, offset, end), l))
                # Every segment, entity or not, advances the character offset.
                offset = end
        else:
            gold_set = {tuple(seg) for seg in d if seg[1] != 'O'}
        matched += len(pred_set & gold_set)
        predicted += len(pred_set)
        gold += len(gold_set)
        last_sample = (text, pred_set, gold_set)
    # With empty data there is no sample to show; printing unconditionally
    # would fail on unbound locals.
    if last_sample is not None:
        text, pred_set, gold_set = last_sample
        print('Text: ', text)
        print('Pred: ', pred_set)
        print('True: ', gold_set)
    f1 = 2 * matched / (predicted + gold)
    precision = matched / predicted
    recall = matched / gold
    return f1, precision, recall
