In [1]:
import os
import collections
from easydict import EasyDict as edict

import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.dataset import create_squad_dataset
from src.bert_for_finetune import BertSquadCell, BertSquad
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate

# Current working directory at the time this cell runs. The train/eval driver cell
# searches this directory for the newest fine-tuned checkpoint when
# "save_finetune_checkpoint_path" is left empty.
_cur_dir = os.getcwd()

MindSpore version 1.1.1 and "topi" wheel package version 0.6.0 does not match, reference to the match info on: https://www.mindspore.cn/install




In [2]:
def _build_optimizer(network, total_steps):
    """Create the optimizer named by ``optimizer_cfg.optimizer``.

    Args:
        network: Cell whose ``trainable_params()`` are optimized.
        total_steps: Total number of training steps (steps per epoch * epochs);
            used as the decay length of the learning-rate schedule, with the
            first 10% of it used as warm-up.

    Returns:
        An ``AdamWeightDecay``, ``Lamb`` or ``Momentum`` optimizer instance.

    Raises:
        ValueError: If ``optimizer_cfg.optimizer`` is not one of the three
            supported names.
    """
    name = optimizer_cfg.optimizer

    # Momentum uses a constant learning rate taken straight from the config.
    if name == 'Momentum':
        return Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
                        momentum=optimizer_cfg.Momentum.momentum)
    if name not in ('AdamWeightDecay', 'Lamb'):
        raise ValueError("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # AdamWeightDecay and Lamb share the same schedule shape; only the config section differs.
    cfg = getattr(optimizer_cfg, name)
    lr_schedule = BertLearningRate(learning_rate=cfg.learning_rate,
                                   end_learning_rate=cfg.end_learning_rate,
                                   warmup_steps=int(total_steps * 0.1),
                                   decay_steps=total_steps,
                                   power=cfg.power)
    if name == 'Lamb':
        return Lamb(network.trainable_params(), learning_rate=lr_schedule)

    # AdamWeightDecay: parameters accepted by decay_filter get the configured
    # weight decay, every other parameter gets none.
    decay_filter = cfg.decay_filter
    decay_params, other_params = [], []
    for param in network.trainable_params():
        (decay_params if decay_filter(param) else other_params).append(param)
    group_params = [{'params': decay_params, 'weight_decay': cfg.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0}]
    return AdamWeightDecay(group_params, lr_schedule, eps=cfg.eps)


def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """Fine-tune ``network`` on ``dataset`` starting from a pretrained checkpoint.

    Args:
        dataset: Training dataset; must provide ``get_dataset_size()`` and be
            accepted by ``Model.train``.
        network: Network with loss to fine-tune.
        load_checkpoint_path: Path of the pretrained checkpoint to load. Required.
        save_checkpoint_path: Directory passed to ``ModelCheckpoint``. An empty
            string passes ``directory=None``.
        epoch_num: Number of epochs to train.

    Raises:
        ValueError: If ``load_checkpoint_path`` is empty, or if the configured
            optimizer is not supported.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    optimizer = _build_optimizer(network, steps_per_epoch * epoch_num)

    # checkpoint saving: save_checkpoint_steps is one epoch worth of steps, keep_checkpoint_max=1
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="squad",
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)

    # load the pretrained weights into the network
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    # wrap the network with the optimizer and dynamic loss scaling, then train
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertSquadCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(steps_per_epoch), LossCallBack(steps_per_epoch), ckpoint_cb]
    model.train(epoch_num, dataset, callbacks=callbacks)

In [3]:
def do_eval(dataset=None, load_checkpoint_path="", eval_batch_size=1):
    """Run the fine-tuned SQuAD network over ``dataset`` and collect raw logits.

    Args:
        dataset: Evaluation dataset; must provide ``create_dict_iterator`` and
            yield the columns "input_ids", "input_mask", "segment_ids" and
            "unique_ids".
        load_checkpoint_path: Path of the fine-tuned checkpoint to load. Required.
        eval_batch_size: Kept for backward compatibility. The number of rows per
            batch is now read from the model output itself, so a batch that is
            shorter or longer than this value no longer raises ``IndexError``
            or drops predictions.

    Returns:
        list of ``RawResult(unique_id, start_logits, end_logits)`` namedtuples,
        one per evaluated row; ``unique_id`` is an ``int`` and both logits
        fields are flat lists of ``float``.

    Raises:
        ValueError: If ``load_checkpoint_path`` is empty.
    """
    if load_checkpoint_path == "":
        raise ValueError("Finetune model missed, evaluation task must load finetune model!")
    net = BertSquad(bert_net_cfg, False, 2)
    net.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(net, param_dict)
    model = Model(net)

    RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
    columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"]

    # Constant dummy tensors for the label-related positional inputs of predict().
    # They never change between batches, so build them once.
    # NOTE(review): presumably ignored by the inference graph - confirm in BertSquad.
    start_positions = Tensor([1], mstype.float32)
    end_positions = Tensor([1], mstype.float32)
    is_impossible = Tensor([1], mstype.float32)

    output = []
    for data in dataset.create_dict_iterator(num_epochs=1):
        input_ids, input_mask, segment_ids, unique_ids = [data[name] for name in columns_list]
        logits = model.predict(input_ids, input_mask, segment_ids, start_positions,
                               end_positions, unique_ids, is_impossible)
        ids = logits[0].asnumpy()
        start = logits[1].asnumpy()
        end = logits[2].asnumpy()

        # Iterate over the rows actually returned rather than trusting eval_batch_size.
        for row in range(len(ids)):
            output.append(RawResult(
                unique_id=int(ids[row]),
                start_logits=[float(x) for x in start[row].flat],
                end_logits=[float(x) for x in end[row].flat]))
    return output

In [6]:
# Run configuration for this notebook. Boolean-like switches are kept as the
# strings "true"/"false" because the driver cell compares them via .lower() == "true".
args_opt = edict(
    # execution target and which stages to run
    device_target="Ascend",
    do_train="true",
    do_eval="true",
    # training / evaluation hyper-parameters
    epoch_num=3,
    num_class=2,
    train_data_shuffle="false",
    eval_data_shuffle="false",
    train_batch_size=32,
    eval_batch_size=1,
    # input and output locations (an empty string means "not set")
    vocab_file_path="./squad/vocab_bert_large_en.txt",
    save_finetune_checkpoint_path="",
    load_pretrain_checkpoint_path="./squad/bert_converted.ckpt",
    load_finetune_checkpoint_path="./squad-3_2745.ckpt",
    train_data_file_path="./squad/train.tf_record",
    eval_json_path="./squad/dev-v1.1.json",
    schema_file_path="",
)

In [7]:
# Select the execution backend. Only Ascend and GPU are accepted; both run in graph mode.
target = args_opt.device_target
if target not in ("Ascend", "GPU"):
    raise Exception("Target error, GPU or Ascend is supported.")

if target == "GPU":
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # On GPU the network config is forced to float32, with a warning when that changes it.
    if bert_net_cfg.compute_type != mstype.float32:
        logger.warning('GPU only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32
else:
    # Ascend run is pinned to device 1.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=1)

# Training network (with loss). Built unconditionally, as before, even for eval-only runs.
netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)

# Checkpoint that evaluation will load. Starts as the configured path and is replaced
# by the newest checkpoint produced in this run when training is enabled. Previously the
# freshly located checkpoint was computed but never used, so evaluation always loaded
# the configured path even right after training.
load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path

if args_opt.do_train.lower() == "true":
    ds = create_squad_dataset(batch_size=args_opt.train_batch_size, repeat_count=1,
                              data_file_path=args_opt.train_data_file_path,
                              schema_file_path=args_opt.schema_file_path,
                              do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
    do_train(ds, netwithloss, args_opt.load_pretrain_checkpoint_path,
             args_opt.save_finetune_checkpoint_path, args_opt.epoch_num)
    if args_opt.do_eval.lower() == "true":
        # Look for the checkpoint just written: in the save directory if one was
        # configured, otherwise in the working directory captured at start-up.
        if args_opt.save_finetune_checkpoint_path == "":
            load_finetune_checkpoint_dir = _cur_dir
        else:
            load_finetune_checkpoint_dir = make_directory(args_opt.save_finetune_checkpoint_path)
        load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir,
                                                       ds.get_dataset_size(), args_opt.epoch_num, "squad")

if args_opt.do_eval.lower() == "true":
    # Evaluation-only helpers are imported lazily so a train-only run does not need them.
    from src import tokenization
    from src.create_squad_data import read_squad_examples, convert_examples_to_features
    from src.squad_get_predictions import write_predictions
    from src.squad_postprocess import SQuad_postprocess
    tokenizer = tokenization.FullTokenizer(vocab_file=args_opt.vocab_file_path, do_lower_case=True)
    eval_examples = read_squad_examples(args_opt.eval_json_path, False)
    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=bert_net_cfg.seq_length,
        doc_stride=128,
        max_query_length=64,
        is_training=False,
        output_fn=None,
        vocab_file=args_opt.vocab_file_path)
    # The in-memory features are handed to the dataset builder in place of a file path.
    ds = create_squad_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
                              data_file_path=eval_features,
                              schema_file_path=args_opt.schema_file_path, is_training=False,
                              do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
    outputs = do_eval(ds, load_finetune_checkpoint_path, args_opt.eval_batch_size)
    # NOTE(review): 20, 30, True are presumably n_best_size, max_answer_length and
    # do_lower_case - confirm against src.squad_get_predictions.write_predictions.
    all_predictions = write_predictions(eval_examples, eval_features, outputs, 20, 30, True)
    SQuad_postprocess(args_opt.eval_json_path, all_predictions, output_metrics="output.json")



{"exact_match": 80.51087984862819, "f1": 87.94542868231342}
