## MindSpore-BERT-NER

首先解决 ipykernel 版本问题：执行 `pip install ipykernel --upgrade` 升级 ipykernel。

In [21]:
# import moxing as mox
# mox.file.copy_parallel(src_url="obs://nlp-workspace/3.ner/src/", dst_url='./src/') 
# mox.file.copy_parallel(src_url="obs://nlp-workspace/3.ner/data/", dst_url='./data/')
# mox.file.copy_parallel(src_url="obs://nlp-workspace/3.ner/pre_model/", dst_url='./pre_model/')

安装 scikit-learn（下文的 `from sklearn.metrics import classification_report` 需要它）：`conda install -n mindspore scikit-learn`

### 2. 导入依赖库

In [4]:
import os
import argparse
import numpy as np
import json
from sklearn.metrics import classification_report # 需放在前面导入

import mindspore.nn as nn
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore.common.tensor import Tensor
import mindspore.dataset as de
from mindspore.ops import operations as P
import mindspore.dataset.transforms.c_transforms as C
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import TruncatedNormal

from src import tokenization
from src.CRF import CRF
from src.CRF import postprocess
from src.cluener_evaluation import process_one_example_p, label_generation
from src.utils import BertLearningRate
from src.bert_for_finetune import BertFinetuneCell
from src.config import optimizer_cfg
from src.bert_model import BertConfig, BertModel

In [3]:
# Execute in static-graph mode on the CPU backend.
# To target an Ascend NPU instead, change device_target to "Ascend".
context.set_context(
    device_target="CPU",
    mode=context.GRAPH_MODE,
)

### 3. 定义参数配置

train file、eval file、checkpoint 以及预训练模型（pre_model）一开始都没有提供，需要自己从网上下载。实验文档在这方面的帮助有限。

In [5]:
# Experiment-level settings shared by the training, loading and evaluation cells.
cfg = edict(
    # --- task / data -------------------------------------------------------
    is_train=True,
    num_labels=41,  # 'O' plus S/B/M/E tags for each of the 10 entity types
    schema_file="./data/clue_ner/schema.json",
    ckpt_prefix="bert-ner-crf",  # alternatives: "bert-ner", "bert-ner-crf"
    train_file="./data/clue_ner/train.tf_record",
    eval_file="./data/clue_ner/dev.tf_record",
    # train_file=None,
    # eval_file=None,
    use_crf=True,

    # --- optimisation / checkpoints ----------------------------------------
    epoch_num=5,
    batch_size=16,
    ckpt_dir="ckpt",
    pre_training_ckpt="./pre_model/bert_base.ckpt",
    # pre_training_ckpt=None,

    # Previously used checkpoint: "./ckpt/bert-ner-crf-5_671.ckpt"
    # NOTE(review): the file name below suggests a BERT+BiLSTM+CRF model;
    # confirm that its parameters actually match the BertNER network.
    finetune_ckpt="./ckpt/bertbilstmcrf_ascend_v170_ner_official_nlp_f1acc99.30.ckpt",

    # --- vocabulary / outputs ----------------------------------------------
    label2id_file="./data/clue_ner/label2id.json",
    vocab_file="./data/vocab.txt",
    eval_out_file="ner_crf_result.txt",  # alternatives: "ner_result.txt", "ner_crf_result.txt"
)

# Hyper-parameters handed to BertModel (12 layers, 12 heads, hidden size 768).
bert_net_cfg = BertConfig(
    # input / embedding sizes
    seq_length=128,
    vocab_size=21128,
    type_vocab_size=2,
    max_position_embeddings=512,
    use_relative_positions=False,
    # transformer stack
    num_hidden_layers=12,
    num_attention_heads=12,
    hidden_size=768,
    intermediate_size=3072,
    hidden_act="gelu",
    # regularisation / initialisation
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    initializer_range=0.02,
    # numeric precision: float32 config dtype, float16 compute type
    dtype=mstype.float32,
    compute_type=mstype.float16,
)

### 4. 定义数据集加载函数

In [6]:
def get_dataset(data_file, schema_file, batch_size):
    '''
    Build a shuffled, batched dataset from one TFRecord file.

    Args:
        data_file: path of the TFRecord file to read.
        schema_file: schema file handed to TFRecordDataset.
        batch_size: number of samples per batch; a trailing incomplete
            batch is dropped.

    Returns:
        The dataset with the four feature columns cast to int32,
        shuffled with a 960-sample buffer and batched.
    '''
    records = de.TFRecordDataset(
        [data_file],
        schema_file,
        columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"],
    )

    # Cast every feature column to int32, in the same order as before.
    to_int32 = C.TypeCast(mstype.int32)
    for column in ("segment_ids", "input_mask", "input_ids", "label_ids"):
        records = records.map(input_columns=column, operations=to_int32)

    # Shuffle first, then group into fixed-size batches.
    shuffled = records.shuffle(buffer_size=960)
    return shuffled.batch(batch_size, drop_remainder=True)

数据集测试

数据文件下载之前，这一步无法运行，可以先跳过。

请先下载并解压 https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com/NLP/NLP.zip ，然后再运行下面的单元格。

In [7]:
# Sanity check of the data pipeline: pull one batch (batch_size=1) from the
# training set and display the token ids of its first sample.
sample_iterator = get_dataset(cfg.train_file, cfg.schema_file, batch_size=1).create_dict_iterator()
first_batch = next(sample_iterator)
first_batch['input_ids'][0]



Tensor(shape=[128], dtype=Int32, value= [ 101, 2512, 5865, 1291,  680, 1169, 4275, 3175, 2199, 5468, 1394, 1403,  677, 3862, 2356, 5018,  671,  704, 5277,  782, 3696, 3791, 7368, 2990, 
 6629, 6401, 6390, 8024, 3418, 2945,  100, 4510, 7723,  100, 4638, 4157, 1140, 4372, 8024,  102,    0,    0,    0,    0,    0,    0,    0,    0, 
    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0, 
    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0, 
    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0, 
    0,    0,    0,    0,    0,    0,    0,    0])

### 5. 定义BertNER模型

In [8]:
class BertNER(nn.Cell):
    """
    BERT encoder plus a dense tagging head for sequence labeling.

    The network wraps its own loss cell: a CRF when ``use_crf`` is True,
    otherwise a cross-entropy calculation. ``construct`` returns whatever
    that loss cell returns.

    Args:
        config: BertConfig describing the encoder.
        batch_size: batch size used to reshape logits for the CRF.
        is_training: forwarded to BertModel and to the loss cell.
        num_labels: output size of the dense tagging layer.
        use_crf: select the CRF loss cell instead of cross entropy.
        tag_to_index: tag -> index mapping; required when ``use_crf`` is True.
        dropout_prob: dropout probability applied to the encoder output.
        use_one_hot_embeddings: forwarded to BertModel.

    Raises:
        Exception: if ``use_crf`` is True and ``tag_to_index`` is empty/None.
    """
    def __init__(self, config, batch_size, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0,
                 use_one_hot_embeddings=False):
        super(BertNER, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        # NOTE(review): assumes nn.Dropout takes a keep probability here
        # (hence 1 - dropout_prob) - confirm against the installed MindSpore.
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.reshape = P.Reshape()
        # Flatten (batch, seq) so the dense layer sees one row per token.
        self.shape = (-1, config.hidden_size)
        self.use_crf = use_crf
        # Shape the CRF expects: logits regrouped per sequence.
        self.origin_shape = (batch_size, config.seq_length, self.num_labels)
        if use_crf:
            if not tag_to_index:
                raise Exception("The dict for tag-index mapping should be provided for CRF.")
            self.loss = CRF(tag_to_index, batch_size, config.seq_length, is_training)
        else:
            # CrossEntropyCalculation is not imported in the notebook's import
            # cell, so this branch used to fail with NameError.
            # NOTE(review): assumed to live in src.utils next to
            # BertLearningRate - confirm.
            from src.utils import CrossEntropyCalculation
            self.loss = CrossEntropyCalculation(is_training)

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        """
        Run the encoder and tagging head, then hand the result to the loss cell.

        Returns:
            The output of ``self.loss``: called with logits reshaped to
            ``origin_shape`` on the CRF path, or with log-softmax scores and
            ``num_labels`` on the cross-entropy path.
        """
        sequence_output, _, _ = \
            self.bert(input_ids, token_type_id, input_mask)
        seq = self.dropout(sequence_output)
        seq = self.reshape(seq, self.shape)
        logits = self.dense_1(seq)
        # The dense layer runs in compute_type; cast back to the config dtype.
        logits = self.cast(logits, self.dtype)

        if self.use_crf:
            return_value = self.reshape(logits, self.origin_shape)
            loss = self.loss(return_value, label_ids)
        else:
            return_value = self.log_softmax(logits)
            loss = self.loss(return_value, label_ids, self.num_labels)
        return loss

### 6. 加载词汇-id映射表

In [9]:
# Load the label -> id mapping. The context manager closes the file handle,
# which the previous open(...).read() one-liner left open.
with open(cfg.label2id_file, encoding="utf-8") as label_file:
    tag_to_index = json.load(label_file)

if cfg.use_crf:
    print(tag_to_index)
    # Append the two extra CRF tags after the existing labels, so the
    # label count grows by two.
    max_val = len(tag_to_index)
    tag_to_index["<START>"] = max_val
    tag_to_index["<STOP>"] = max_val + 1
    number_labels = len(tag_to_index)
else:
    number_labels = cfg.num_labels

{'O': 0, 'S_address': 1, 'B_address': 2, 'M_address': 3, 'E_address': 4, 'S_book': 5, 'B_book': 6, 'M_book': 7, 'E_book': 8, 'S_company': 9, 'B_company': 10, 'M_company': 11, 'E_company': 12, 'S_game': 13, 'B_game': 14, 'M_game': 15, 'E_game': 16, 'S_government': 17, 'B_government': 18, 'M_government': 19, 'E_government': 20, 'S_movie': 21, 'B_movie': 22, 'M_movie': 23, 'E_movie': 24, 'S_name': 25, 'B_name': 26, 'M_name': 27, 'E_name': 28, 'S_organization': 29, 'B_organization': 30, 'M_organization': 31, 'E_organization': 32, 'S_position': 33, 'B_position': 34, 'M_position': 35, 'E_position': 36, 'S_scene': 37, 'B_scene': 38, 'M_scene': 39, 'E_scene': 40}


### 7. 定义训练函数

In [10]:
def train():
    '''
    Fine-tune BertNER on cfg.train_file and save checkpoints to cfg.ckpt_dir.

    Reads every setting from the module-level ``cfg``, ``bert_net_cfg``,
    ``optimizer_cfg``, ``number_labels`` and ``tag_to_index``. Pre-trained
    weights are loaded from ``cfg.pre_training_ckpt`` when it is set;
    otherwise a warning is logged and training starts from the freshly
    initialized network.

    NOTE(review): the run recorded in this notebook failed inside
    BertFinetuneCell (src/bert_for_finetune.py:216, ``if overflow:``) with a
    "Type Join Failed" error; that code lives outside this function.
    '''
    # BertNER train for sequence labeling
    netwithloss = BertNER(bert_net_cfg, cfg.batch_size, True, num_labels=number_labels,
                          use_crf=cfg.use_crf,
                          tag_to_index=tag_to_index, dropout_prob=0.1)

    dataset = get_dataset(data_file=cfg.train_file, schema_file=cfg.schema_file, batch_size=cfg.batch_size)
    # Computed once and reused for the schedule, checkpointing and timing.
    steps_per_epoch = dataset.get_dataset_size()
    print('steps_per_epoch:',steps_per_epoch)
    total_steps = steps_per_epoch * cfg.epoch_num

    # optimizer: warm up over the first 10% of all steps, then decay
    adam_cfg = optimizer_cfg.AdamWeightDecay
    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(total_steps * 0.1),
                                   decay_steps=total_steps,
                                   power=adam_cfg.power)
    params = netwithloss.trainable_params()
    # Weight decay only applies to the parameters selected by decay_filter.
    decay_params = [p for p in params if adam_cfg.decay_filter(p)]
    other_params = [p for p in params if not adam_cfg.decay_filter(p)]
    group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0}]
    optimizer = AdamWeightDecay(group_params, lr_schedule, eps=adam_cfg.eps)

    # checkpoint saving: one checkpoint per epoch, keep only the latest
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix, directory=cfg.ckpt_dir, config=ckpt_config)

    # load pre-trained checkpoint into network (skipped when not configured)
    if cfg.pre_training_ckpt:
        param_dict = load_checkpoint(cfg.pre_training_ckpt)
        load_param_into_net(netwithloss, param_dict)
    else:
        logger.warning("cfg.pre_training_ckpt is not set; fine-tuning starts from "
                       "randomly initialized weights.")

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(steps_per_epoch), LossMonitor(), ckpoint_cb]
    model.train(cfg.epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=True)

### 8. 启动训练

开始训练后才发现 MindSpore 需要安装 GPU 或 NPU 版本，尽管文档里说 CPU 版本也可以；在 CPU 版本下，代码里用到的 NPUAlloc 相关 API 无法使用。

pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/1.7.1/MindSpore/cpu/x86_64/mindspore-1.7.1-cp37-cp37m-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple 

只好到这个源去找 gpu 版本

pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/1.7.1/MindSpore/gpu/x86_64/cuda-11.1/mindspore_gpu-1.7.1-cp37-cp37m-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

结果发现只支持 cuda 11.1

到这里我们已经意识到 docker 上跑才是最好的，可惜已经迟了。

因此这里不再继续训练；下面保留的是在当前环境中运行 `train()` 得到的报错输出。

In [11]:
# Honor the cfg.is_train switch, which was defined but never consulted, so
# training can be skipped (e.g. when only evaluating an existing checkpoint)
# without editing this cell.
if cfg.is_train:
    train()



steps_per_epoch: 671




TypeError: Cannot join the return values of different branches, perhaps you need to make them equal.
Type Join Failed: dtype1 = Bool, dtype2 = Anything.
For more details, please refer to https://www.mindspore.cn/search?inputValue=Type%20Join%20Failed

Inner Message:
The abstract type of the return value of the current branch is AbstractTuple{element[0]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[1]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[2]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[3]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[4]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[5]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[6]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[7]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[8]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[9]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[10]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[11]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[12]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[13]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[14]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[15]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[16]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[17]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[18]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[19]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[20]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[21]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[22]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[23]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[24]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[25]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), 
element[26]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[27]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[28]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[29]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[30]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[31]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[32]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[33]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[34]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[35]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[36]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[37]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[38]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[39]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[40]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[41]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[42]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[43]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[44]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[45]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[46]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[47]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[48]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[49]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[50]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[51]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[52]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), 
element[53]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[54]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[55]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[56]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[57]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[58]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[59]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[60]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[61]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[62]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[63]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[64]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[65]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[66]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[67]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[68]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[69]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[70]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[71]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[72]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[73]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[74]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[75]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[76]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[77]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[78]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[79]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), 
element[80]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[81]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[82]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[83]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[84]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[85]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[86]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[87]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[88]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[89]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[90]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[91]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[92]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[93]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[94]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[95]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[96]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[97]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[98]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[99]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[100]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[101]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[102]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[103]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[104]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[105]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[106]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: 
NoShape), element[107]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[108]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[109]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[110]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[111]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[112]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[113]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[114]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[115]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[116]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[117]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[118]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[119]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[120]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[121]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[122]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[123]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[124]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[125]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[126]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[127]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[128]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[129]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[130]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[131]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[132]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[133]: AbstractScalar(Type: 
Bool, Value: AnyValue, Shape: NoShape), element[134]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[135]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[136]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[137]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[138]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[139]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[140]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[141]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[142]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[143]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[144]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[145]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[146]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[147]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[148]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[149]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[150]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[151]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[152]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[153]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[154]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[155]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[156]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[157]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[158]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[159]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), 
element[160]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[161]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[162]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[163]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[164]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[165]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[166]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[167]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[168]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[169]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[170]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[171]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[172]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[173]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[174]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[175]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[176]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[177]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[178]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[179]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[180]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[181]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[182]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[183]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[184]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[185]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[186]: AbstractScalar(Type: Bool, Value: 
AnyValue, Shape: NoShape), element[187]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[188]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[189]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[190]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[191]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[192]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[193]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[194]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[195]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[196]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[197]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[198]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[199]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[200]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), element[201]: AbstractScalar(Type: Bool, Value: AnyValue, Shape: NoShape), sequence_nodes: {@hyper_map.50:optim_result{[0]: ValueNode<Primitive> MakeTuple, [1]: optim_result, [2]: optim_result, [3]: optim_result, [4]: optim_result, [5]: optim_result, [6]: optim_result, [7]: optim_result, [8]: optim_result, [9]: optim_result, [10]: optim_result, [11]: optim_result, [12]: optim_result, [13]: optim_result, [14]: optim_result, [15]: optim_result, [16]: optim_result, [17]: optim_result, [18]: optim_result, [19]: optim_result, [20]: optim_result, [21]: optim_result, [22]: optim_result, [23]: optim_result, [24]: optim_result, [25]: optim_result, [26]: optim_result, [27]: optim_result, [28]: optim_result, [29]: optim_result, [30]: optim_result, [31]: optim_result, [32]: optim_result, [33]: optim_result, [34]: optim_result, [35]: optim_result, [36]: optim_result, [37]: optim_result, [38]: optim_result, [39]: 
optim_result, [40]: optim_result, [41]: optim_result, [42]: optim_result, [43]: optim_result, [44]: optim_result, [45]: optim_result, [46]: optim_result, [47]: optim_result, [48]: optim_result, [49]: optim_result, [50]: optim_result, [51]: optim_result, [52]: optim_result, [53]: optim_result, [54]: optim_result, [55]: optim_result, [56]: optim_result, [57]: optim_result, [58]: optim_result, [59]: optim_result, [60]: optim_result, [61]: optim_result, [62]: optim_result, [63]: optim_result, [64]: optim_result, [65]: optim_result, [66]: optim_result, [67]: optim_result, [68]: optim_result, [69]: optim_result, [70]: optim_result, [71]: optim_result, [72]: optim_result, [73]: optim_result, [74]: optim_result, [75]: optim_result, [76]: optim_result, [77]: optim_result, [78]: optim_result, [79]: optim_result, [80]: optim_result, [81]: optim_result, [82]: optim_result, [83]: optim_result, [84]: optim_result, [85]: optim_result, [86]: optim_result, [87]: optim_result, [88]: optim_result, [89]: optim_result, [90]: optim_result, [91]: optim_result, [92]: optim_result, [93]: optim_result, [94]: optim_result, [95]: optim_result, [96]: optim_result, [97]: optim_result, [98]: optim_result, [99]: optim_result, [100]: optim_result, [101]: optim_result, [102]: optim_result, [103]: optim_result, [104]: optim_result, [105]: optim_result, [106]: optim_result, [107]: optim_result, [108]: optim_result, [109]: optim_result, [110]: optim_result, [111]: optim_result, [112]: optim_result, [113]: optim_result, [114]: optim_result, [115]: optim_result, [116]: optim_result, [117]: optim_result, [118]: optim_result, [119]: optim_result, [120]: optim_result, [121]: optim_result, [122]: optim_result, [123]: optim_result, [124]: optim_result, [125]: optim_result, [126]: optim_result, [127]: optim_result, [128]: optim_result, [129]: optim_result, [130]: optim_result, [131]: optim_result, [132]: optim_result, [133]: optim_result, [134]: optim_result, [135]: optim_result, [136]: optim_result, [137]: 
optim_result, [138]: optim_result, [139]: optim_result, [140]: optim_result, [141]: optim_result, [142]: optim_result, [143]: optim_result, [144]: optim_result, [145]: optim_result, [146]: optim_result, [147]: optim_result, [148]: optim_result, [149]: optim_result, [150]: optim_result, [151]: optim_result, [152]: optim_result, [153]: optim_result, [154]: optim_result, [155]: optim_result, [156]: optim_result, [157]: optim_result, [158]: optim_result, [159]: optim_result, [160]: optim_result, [161]: optim_result, [162]: optim_result, [163]: optim_result, [164]: optim_result, [165]: optim_result, [166]: optim_result, [167]: optim_result, [168]: optim_result, [169]: optim_result, [170]: optim_result, [171]: optim_result, [172]: optim_result, [173]: optim_result, [174]: optim_result, [175]: optim_result, [176]: optim_result, [177]: optim_result, [178]: optim_result, [179]: optim_result, [180]: optim_result, [181]: optim_result, [182]: optim_result, [183]: optim_result, [184]: optim_result, [185]: optim_result, [186]: optim_result, [187]: optim_result, [188]: optim_result, [189]: optim_result, [190]: optim_result, [191]: optim_result, [192]: optim_result, [193]: optim_result, [194]: optim_result, [195]: optim_result, [196]: optim_result, [197]: optim_result, [198]: optim_result, [199]: optim_result, [200]: optim_result, [201]: optim_result, [202]: optim_result}, elements_use_flags: {ptr: 0x7f697caf65a0, value: [const vector][0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}}}, and that of the previous branch is AbstractScalar(Type: Bool, Value: false, Shape: NoShape).
The node is @↓↓↓↓↓↓Default.51:[CNode]52{[0]: @↓↓↓↓↓↓Default.51:[CNode]53{[0]: ValueNode<Primitive> Switch, [1]: [CNode]116, [2]: ValueNode<FuncGraph> ✓↓↓↓↓↓↓Default.43, [3]: ValueNode<FuncGraph> ✗↓↓↓↓↓↓Default.44}}, true branch: ✓↓↓↓↓↓↓Default.43, false branch: ✗↓↓↓↓↓↓Default.44

----------------------------------------------------
- The Traceback of Net Construct Code:
----------------------------------------------------
The function call stack (See file '/home/fish/Documents/GitHub/nlpfinal/ner/rank_0/om/analyze_fail.dat' for more details. Get instructions about `analyze_fail.dat` at https://www.mindspore.cn/search?inputValue=analyze_fail.dat):
# 0 In file /home/fish/Documents/GitHub/nlpfinal/ner/src/bert_for_finetune.py:216
        if overflow:

----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:847 ProcessEvalResults


### 9. 加载离线模型

In [12]:
# Rebuild the network for inference: batch size 1 and is_training=False.
netwithloss = BertNER(bert_net_cfg, 1, False, num_labels=number_labels,
                     use_crf=cfg.use_crf,
                     tag_to_index=tag_to_index)
netwithloss.set_train(False)

param_dict = load_checkpoint(cfg.finetune_ckpt)
# load_param_into_net reports which network parameters it could not fill from
# the checkpoint. Surface that instead of discarding it, because a checkpoint
# that does not match the network leaves those layers randomly initialized.
param_not_load = load_param_into_net(netwithloss, param_dict)
if isinstance(param_not_load, tuple):
    # NOTE(review): some MindSpore releases return
    # (param_not_load, ckpt_not_load) - keep only the first element.
    param_not_load = param_not_load[0]
if param_not_load:
    logger.warning("%d network parameter(s) were not loaded from %s: %s",
                   len(param_not_load), cfg.finetune_ckpt, param_not_load)
model = Model(netwithloss)

tokenizer_ = tokenization.FullTokenizer(vocab_file=cfg.vocab_file)



### 10. 定义测试集评估函数

In [13]:
def eval():
    '''
    Evaluate the loaded model on ``cfg.eval_file`` and print a per-tag report.

    The evaluation dataset is read with batch size 1. For every sample the gold
    tag ids and the predicted tag ids are collected position by position, and
    ``sklearn.metrics.classification_report`` is printed for the tag ids
    ``1 .. cfg.num_labels - 1`` (tag id 0 is left out of the report).

    Returns:
        None. The report is written to stdout.

    Note:
        This function shadows the built-in ``eval``; the name is kept because
        other cells of the notebook call it.
    '''

    dataset = get_dataset(cfg.eval_file, cfg.schema_file, 1)
    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]

    y_true, y_pred = [], []
    for data in dataset.create_dict_iterator():
        input_ids, input_mask, token_type_id, label_ids = (
            Tensor(data[name]) for name in columns_list
        )
        logits = model.predict(input_ids, input_mask, token_type_id, label_ids)

        if cfg.use_crf:
            # With CRF the network returns the pieces needed to decode the
            # best tag path; postprocess yields one id sequence per sample.
            backpointers, best_tag_id = logits
            for path in postprocess(backpointers, best_tag_id):
                y_pred.extend(path)
        else:
            # Flatten before extending: this works whether argmax gives a 1-D
            # array (one id per token) or a 2-D (sample, token) array. Calling
            # extend() on each element of a 1-D array would raise TypeError.
            predicted_ids = np.argmax(logits.asnumpy(), axis=-1)
            y_pred.extend(predicted_ids.reshape(-1))

        # Row-major flattening keeps the same order as extending row by row.
        y_true.extend(label_ids.asnumpy().reshape(-1))

    # Derive the reported ids from the configuration instead of a hard-coded 41,
    # and look names up by tag index so that labels and target_names stay
    # aligned regardless of the key order of tag_to_index.
    report_labels = list(range(1, cfg.num_labels))
    index_to_tag = {index: tag for tag, index in tag_to_index.items()}
    target_names = [index_to_tag[index] for index in report_labels]

    print(classification_report(y_true, y_pred, labels=report_labels, target_names=target_names))

### 11. 启动测试集评估

In [17]:
# Run the evaluation defined above; it prints sklearn's classification report.
eval()

                precision    recall  f1-score   support

     S_address       0.00      0.00      0.00         0
     B_address       0.01      0.01      0.01       373
     M_address       0.00      0.00      0.00       956
     E_address       0.00      0.01      0.00       373
        S_book       0.00      0.00      0.00         0
        B_book       0.00      0.05      0.00       154
        M_book       0.00      0.00      0.00       723
        E_book       0.00      0.00      0.00       154
     S_company       0.00      0.00      0.00         0
     B_company       0.10      0.03      0.04       378
     M_company       0.00      0.02      0.00       937
     E_company       0.00      0.24      0.01       378
        S_game       0.00      0.00      0.00         0
        B_game       0.00      0.00      0.00       295
        M_game       0.25      0.00      0.00      1067
        E_game       0.00      0.00      0.00       295
  S_government       0.00      0.00      0.00  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### 12. 定义在线推理函数

In [14]:
def inference(text):
    """
    Run the loaded NER model on a single piece of text.

    Args:
        text: the raw input string to tag.

    Returns:
        Whatever ``label_generation`` builds from the text, the predicted tag
        ids and ``tag_to_index``.
    """
    features = process_one_example_p(tokenizer_, cfg.vocab_file, text, max_seq_len=bert_net_cfg.seq_length)
    # The three features are converted to int32 tensors in the order
    # input_ids, input_mask, token_type_id.
    input_ids, input_mask, token_type_id = (
        Tensor(np.array(field), mstype.int32) for field in features
    )

    # Tensor(1) fills the label_ids slot of the network call.
    # NOTE(review): presumably a placeholder that is ignored at inference time -- confirm.
    prediction = model.predict(input_ids, input_mask, token_type_id, Tensor(1))

    if cfg.use_crf:
        backpointers, best_tag_id = prediction
        decoded_paths = postprocess(backpointers, best_tag_id)
        # Concatenate every decoded path into one flat list of tag ids.
        tag_ids = [tag for path in decoded_paths for tag in path]
    else:
        # Highest-scoring tag along the last axis of the model output.
        tag_ids = list(np.argmax(prediction.asnumpy(), axis=-1))

    return label_generation(text=text, probs=tag_ids, tag_to_index=tag_to_index)

### 13. 在线推理测试

In [15]:
# Try online inference on one sentence; the returned dict is shown as the cell output.
inference("温格的球队终于又踢了一场经典的比赛，2比1战胜曼联之后枪手仍然留在了夺冠集团之内，")

{'government': {'温': [[0, 0]]},
 'company': {'格的': [[1, 2]],
  '球队': [[3, 4]],
  '终于': [[5, 6]],
  '又踢': [[7, 8]],
  '了一': [[9, 10]],
  '场经': [[11, 12]],
  '典的': [[13, 14]],
  '比赛': [[15, 16]],
  '，2': [[17, 18]],
  '比1': [[19, 20]],
  '战胜': [[21, 22]],
  '曼联': [[23, 24]],
  '之后': [[25, 26]],
  '枪手': [[27, 28]],
  '仍然': [[29, 30]],
  '留在': [[31, 32]],
  '了夺': [[33, 34]],
  '冠集': [[35, 36]],
  '团之': [[37, 38]]},
 'movie': {'内，': [[39, 40]]}}

In [16]:
# Second online-inference example on a different sentence.
inference("郑阿姨就赶到文汇路排队拿钱，希望能将缴纳的一万余元学费拿回来，顺便找校方或者教委要个说法。")

{'book': {'阿': [[1, 1]], '。': [[44, 44]]},
 'government': {'姨': [[2, 2]], '者': [[37, 37]]},
 'company': {'就赶': [[3, 4]],
  '到文': [[5, 6]],
  '汇路': [[7, 8]],
  '排队': [[9, 10]],
  '拿钱': [[11, 12]],
  '，希': [[13, 14]],
  '望能': [[15, 16]],
  '将缴': [[17, 18]],
  '纳的': [[19, 20]],
  '一万': [[21, 22]],
  '余元': [[23, 24]],
  '学费': [[25, 26]],
  '拿回': [[27, 28]],
  '来，': [[29, 30]],
  '顺便': [[31, 32]],
  '找校': [[33, 34]],
  '方或': [[35, 36]]},
 'position': {'教': [[38, 38]], '委要个说法': [[39, 43]]}}