<a href="https://colab.research.google.com/github/liuhuiaren0524/Bilstm-crf_ner-bert_classify/blob/main/multilabel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from bert4keras.backend import keras, set_gelu
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from keras.layers import Lambda, Dense, Dropout, Bidirectional, LSTM
from keras.callbacks import LearningRateScheduler
import keras.backend as K
from keras.optimizers import SGD, Adagrad, RMSprop, Adadelta, Adamax, Nadam
import tensorflow as tf
import os
import re
import sys
import logging
import time
import pandas as pd
import numpy as np

# Process-wide logging, GPU selection and activation setup. These statements
# run at import time, before the model is built further down the file.
LOG_FORMAT = "%(asctime)s %(name)s %(levelname)s %(pathname)s %(message)s "# Format of each log record
DATE_FORMAT = '%Y-%m-%d  %H:%M:%S %a ' # Timestamp format; take care not to mix up month and day
logging.basicConfig(level=logging.DEBUG,
                    format=LOG_FORMAT,
                    datefmt = DATE_FORMAT ,
                    # filename=r"" # With a filename set, records go to that file instead of the console
                    )

# Expose only the CUDA device with index 2 to this process.
os.environ["CUDA_VISIBLE_DEVICES"]= "2"
set_gelu('tanh')  # Switch the gelu implementation to the tanh version

from config import CONFIG
from utils import retro_dictify

# Build the label vocabulary from the rule spreadsheet.
labelpath = CONFIG.labelfile
# Forward-fill empty cells so each row carries the nearest value above it
# (presumably the sheet uses merged cells for the category columns - verify).
# DataFrame.ffill() replaces the deprecated fillna(method='ffill').
tsc_rule_df = pd.read_excel(labelpath, sheet_name='投诉分类完善').ffill()
# Strip any "(...)" suffix from the column headers.
tsc_rule_df.columns = [re.sub(r'\(.*?\)', '', t) for t in tsc_rule_df.columns]
# Keep only the last line of the category cell and drop everything from the
# first full-width "（" onwards.
tsc_rule_df['大类'] = tsc_rule_df['大类'].str.split('\n').apply(lambda x: re.sub('（.*', '', x[-1]).strip())
tcs_rule = retro_dictify(tsc_rule_df[['大类', '小类', '标签', '建单判定条件']])
# tcs_rule is walked as a three-level nested mapping; every
# (level-1, level-2, level-3) key path becomes one label.
label = [
    (bk, mk, sk)
    for bk in tcs_rule
    for mk in tcs_rule[bk]
    for sk in tcs_rule[bk][mk]
]
# Bidirectional lookup between a label tuple and its output column index.
sml_label2idx = {lb: i for i, lb in enumerate(label)}
idx2sml_label = {idx: lb for lb, idx in sml_label2idx.items()}

# Number of output units of the classifier: one per label tuple.
NUM_CLASS = len(label)
# Maximum token length passed to tokenizer.encode.
MAX_LEN = 512
# Batch size handed to the training data generator.
BATCH_SIZE = 8


# Locations of the pre-trained ALBERT config, checkpoint and vocabulary,
# taken from the project configuration object.
CONFIG_PATH = CONFIG.albert_config_path
CHECKPOINT_PATH = CONFIG.albert_checkpoint_path
DICT_PATH = CONFIG.albert_dict_path

# Build the tokenizer (shared by data_generator and bert_token below)
tokenizer = Tokenizer(DICT_PATH, do_lower_case=True)

def label2vec(labels, lb2id):
    """Convert per-sample label collections into a multi-hot matrix.

    Args:
        labels: sequence with one iterable of label keys per sample.
        lb2id: mapping from label key to column index.

    Returns:
        A float array of shape (len(labels), len(lb2id)) where entry
        [i, j] is 1 when sample i carries the label mapped to column j.
        Label keys missing from lb2id are ignored.
    """
    out = np.zeros([len(labels), len(lb2id)])
    for i, lb in enumerate(labels):
        for it in lb:
            j = lb2id.get(it)
            # Compare against None explicitly: column index 0 is falsy and
            # would otherwise be skipped.
            if j is not None:
                out[i, j] = 1
    return out


class data_generator(DataGenerator):
    """Batch generator producing tokenized inputs and multi-hot labels.

    Each element of ``data`` is unpacked as ``(text, labels)``. Relies on the
    module-level ``tokenizer`` and ``sml_label2idx`` being defined first.

    Args:
        data: the samples to iterate over.
        maxlen: maximum token length passed to ``tokenizer.encode``.
        batch_size: number of samples per yielded batch.
        buffer_size: forwarded to ``DataGenerator``; ``None`` keeps the base
            class default.
    """
    def __init__(self, data, maxlen, batch_size=32, buffer_size=None):
        # Forward the caller's buffer_size; it used to be dropped and
        # replaced by a hard-coded None.
        super(data_generator, self).__init__(data, batch_size=batch_size, buffer_size=buffer_size)
        self.maxlen = maxlen

    def __iter__(self, random=False):
        """Yield ``([token_ids, segment_ids], label_matrix)`` per batch."""
        batch_token_ids, batch_segment_ids, batch_labels = [], [], []
        for is_end, (text, labels) in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(text, maxlen=self.maxlen)  # tokenizer must be defined beforehand
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_labels.append(labels)
            # Flush on a full batch, or on the last sample so the final
            # (possibly shorter) batch is not lost.
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_labels = label2vec(batch_labels, sml_label2idx)
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids, batch_segment_ids, batch_labels = [], [], []
                
def bert_token(data, maxlen):
    """Tokenize every text in ``data`` and pad the results to equal length.

    Uses the module-level ``tokenizer``, which must already be defined.

    Args:
        data: iterable of texts.
        maxlen: maximum token length passed to ``tokenizer.encode``.

    Returns:
        ``[token_ids, segment_ids]``, each the output of ``sequence_padding``
        over all encoded texts.
    """
    encoded_pairs = [tokenizer.encode(sample, maxlen=maxlen) for sample in data]
    padded_tokens = sequence_padding([tokens for tokens, _ in encoded_pairs])
    padded_segments = sequence_padding([segments for _, segments in encoded_pairs])
    return [padded_tokens, padded_segments]

                
def multilabel_categorical_crossentropy(y_true, y_pred):
    """Cross-entropy loss for multi-label classification.

    Note: y_true and y_pred have the same shape. Every element of y_true is
        either 0 or 1; 1 marks a target class and 0 a non-target class.
    Warning: y_pred must range over all real numbers. In practice this means
        no activation on y_pred, and in particular no sigmoid or softmax.
        At prediction time, output the classes whose y_pred is greater
        than 0.
    """
    # Flip the sign of the target-class scores.
    signed = (1 - 2 * y_true) * y_pred
    # Push the entries that do not belong to each group down by 1e12.
    non_target = signed - y_true * 1e12
    target = signed - (1 - y_true) * 1e12
    # Extra zero column appended to both groups before the log-sum-exp.
    pad = K.zeros_like(signed[..., :1])
    neg_loss = K.logsumexp(K.concatenate([non_target, pad], axis=-1), axis=-1)
    pos_loss = K.logsumexp(K.concatenate([target, pad], axis=-1), axis=-1)
    return neg_loss + pos_loss


# def evaluate(data):
#     total, right = 0., 0.
#     for x_true, y_true in data:
#         y_pred = model.predict(x_true).argmax(axis=1)
#         y_true = y_true[:, 0]
#         total += len(y_true)
#         right += (y_true == y_pred).sum()
#     return right / total


# class Evaluator(keras.callbacks.Callback):
#     """评估与保存
#     """
#     def __init__(self):
#         self.best_val_acc = 0.

#     def on_epoch_end(self, epoch, logs=None):
#         val_acc = evaluate(valid_generator)
#         if val_acc > self.best_val_acc:
#             self.best_val_acc = val_acc
#             model.save_weights('best_model.weights')
#         test_acc = evaluate(test_generator)
#         print(
#             u'val_acc: %.5f, best_val_acc: %.5f, test_acc: %.5f\n' %
#             (val_acc, self.best_val_acc, test_acc)
#         )
        
        

# Load the pre-trained model
bert = build_transformer_model(config_path=CONFIG_PATH,
                               checkpoint_path=CHECKPOINT_PATH,
                               model='albert', # Comment this line out when using the original BERT
                               return_keras_model=False,
                              )
# Keep only position 0 of the sequence axis (the layer is named 'CLS-token').
output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
# One output unit per label; no activation is passed, matching the loss
# function's requirement that y_pred be unbounded raw scores.
output = Dense(units=NUM_CLASS,
               kernel_initializer=bert.initializer
              )(output)

# Module-level model used by train() and predict().
model = keras.models.Model(bert.model.input, output)


def train():
    """Compile the module-level model and build the training data generator.

    NOTE(review): this function looks unfinished and cannot run as written:
      * ``f1_score``, ``precision``, ``recall`` and ``train_data`` are not
        defined anywhere in the visible source.
      * ``main()`` calls ``train(train_path, valid_path)``, but this
        signature accepts no arguments.
      * The generator is created but no fit call follows, and nothing is
        returned.
    """
    # Compile the model
    model.compile(loss=multilabel_categorical_crossentropy,
                  optimizer=Adam(1e-4),  # Use a sufficiently small learning rate
                  metrics=[f1_score, precision, recall],
                 )
    # Wrap the dataset in a batch generator
    train_generator = data_generator(train_data,MAX_LEN, BATCH_SIZE)

    
def predict(inputfile, textcolumn, outputfile):
    """Label every row of a CSV/XLSX file and write the result to disk.

    Loads the weights from ``./ckpt/weights.hdf5`` into the module-level
    ``model``, scores the texts in ``textcolumn`` and stores the predicted
    label tuples in a new column named ``<textcolumn>_tcs_label``.

    Args:
        inputfile: path of the file to read; the extension must be ``csv``
            or ``xlsx`` (case-insensitive).
        textcolumn: name of the column holding the texts to classify.
        outputfile: path to write the result to. The output format follows
            the *input* file's extension.

    Raises:
        ValueError: if the input file extension is neither csv nor xlsx.
    """
    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    logging.info('加载训练好的模型...')
    model.load_weights('./ckpt/weights.hdf5')
    logging.info('加载模型完成.')
    logging.info('读取待预测文件...')
    # splitext + lower() so "data.CSV" and dotted directory names are handled.
    ftype = os.path.splitext(inputfile)[1].lower().lstrip('.')
    if ftype == 'csv':
        df = pd.read_csv(inputfile)
    elif ftype == 'xlsx':
        df = pd.read_excel(inputfile)
    else:
        raise ValueError('输入文件格式不正确，合理的输入文件格式为："csv","xlsx"')
    logging.info('待预测文件读取完成.')
    logging.info('开始模型预测...')
    X = bert_token(df[textcolumn].values, MAX_LEN)
    scores = sigmoid(model.predict(X))
    # For each row keep the column indices scoring above 0.5 and map them
    # back to their label tuples.
    pred = [np.argwhere(row > 0.5).flatten().tolist() for row in scores]
    pred = [[idx2sml_label.get(i) for i in row] for row in pred]
    logging.info('预测完成.')

    logging.info('把结果写入到目标文件...')
    df[textcolumn+'_tcs_label'] = pred
    # Write exactly once, in the input's format. A trailing unconditional
    # to_csv used to overwrite a freshly written .xlsx with CSV text.
    if ftype == 'csv':
        df.to_csv(outputfile, index=False, encoding='utf-8-sig')
    else:
        df.to_excel(outputfile, index=False)
    logging.info('结果写入完成.')


def main():
    """Command-line entry point: dispatch to train, test or predict.

    Usage (arguments are read from ``sys.argv``):
        train <train_path> <valid_path>
        test <test_path>
        predict <pred_path> <textcolumn>

    In predict mode the result is written next to the input file, with
    ``_label`` inserted before the file extension.

    Raises:
        ValueError: if the mode is unknown or a required argument is missing.
    """
    # Parse command-line arguments
    try:
        mode = sys.argv[1]
        if mode == 'train':
            train_path, valid_path = sys.argv[2], sys.argv[3]
        elif mode == 'test':
            test_path = sys.argv[2]
        elif mode == 'predict':
            pred_path = sys.argv[2]
            textcolumn = sys.argv[3]
        else:
            raise ValueError
    except (IndexError, ValueError) as e:
        raise ValueError('无效的命令行参数', sys.argv) from e
    # Main program
    if mode == 'train':
        logging.info('开始TCS标签分类模型训练，训练集{}，验证集{}'.format(train_path, valid_path))
        st = time.time()
        train(train_path, valid_path)
        et = time.time()
        logging.info('TCS标签分类模型训练完成，用时{:.2f}s'.format(et-st))
    elif mode == 'test':
        logging.info('开始TCS标签分类模型测试，测试集{}'.format(test_path))
        st = time.time()
        # NOTE(review): no `test` function is defined in the visible source.
        test(test_path)
        et = time.time()
        logging.info('TCS标签分类模型测试完成，用时{:.2f}s'.format(et-st))
    elif mode == 'predict':
        # Insert the suffix before the extension only. str.replace('.', ...)
        # would also rewrite dots in directory names or a leading "./".
        root, ext = os.path.splitext(pred_path)
        pred_out_path = root + '_label' + ext
        logging.info('开始TCS标签分类模型预测，预测文件{}，结果保存文件{}'.format(pred_path, pred_out_path))
        st = time.time()
        predict(pred_path, textcolumn, pred_out_path)
        et = time.time()
        logging.info('TCS标签分类模型预测完成，用时{:.2f}s'.format(et-st))


if __name__ == '__main__':
    main()