# pip

In [None]:
!pip install mindspore==2.0.0
!pip install -r requirements.txt

# download punkt

In [None]:
import nltk
nltk.download('punkt')

# 如果无法下载，执行下列命令

# tokenizer

In [3]:
import os
from nltk.tokenize import word_tokenize

# Corpora to tokenize; each name is a sub-folder of both src_folder and out_folder.
data_list = ['news_crawl', 'ggw_data']
src_folder = "./dataset/"
out_folder = "./tokenized_corpus/"

# Create one output sub-folder per corpus. Deriving the folders from data_list
# keeps them in sync with the processing loop, and exist_ok makes the cell safe
# to re-run.
for corpus_name in data_list:
    os.makedirs(os.path.join(out_folder, corpus_name), exist_ok=True)

def create_tokenized_sentences(file_path, tokenized_file, max_tokens=175):
    """Tokenize a text file line by line and write the result to a new file.

    Each input line is split with NLTK's ``word_tokenize``; the tokens are
    re-joined with single spaces and written one sentence per line. Lines
    that yield more than ``max_tokens`` tokens are dropped.

    Args:
        file_path: Path of the text file to read, one sentence per line.
        tokenized_file: Path of the output file; overwritten if it exists.
        max_tokens: Largest token count a sentence may have and still be kept.
            Defaults to 175, the limit that used to be hard-coded here.
    """
    print(f" | Processing {file_path}.")
    tokenized_sen = []
    # NOTE(review): both files are opened with the platform default encoding.
    with open(file_path, "r") as file:
        for sen in file:
            tokens = [t for t in word_tokenize(sen) if t != " "]
            if len(tokens) <= max_tokens:
                tokenized_sen.append(" ".join(tokens) + "\n")

    # The output is written only after the whole input has been read.
    with open(tokenized_file, "w") as file:
        file.writelines(tokenized_sen)
    print(f" | Wrote to {tokenized_file}.")

# Walk every corpus folder and tokenize each ``.txt`` file found in it; the
# result goes to the matching output folder with a ``_tokenized.txt`` suffix.
for item in data_list:
    folder_path = os.path.join(src_folder, item)
    output_path = os.path.join(out_folder, item)
    for file in os.listdir(folder_path):
        if file.endswith(".txt"):
            tokenized_name = file.replace(".txt", "_tokenized.txt")
            file_path = os.path.join(folder_path, file)
            tokenized_file = os.path.join(output_path, tokenized_name)
            create_tokenized_sentences(file_path, tokenized_file)

 | Processing ./dataset/news_crawl/news2012_short.txt.
 | Wrote to ./tokenized_corpus/news_crawl/news2012_short_tokenized.txt.
 | Processing ./dataset/ggw_data/test.tgt.txt.
 | Wrote to ./tokenized_corpus/ggw_data/test.tgt_tokenized.txt.
 | Processing ./dataset/ggw_data/train_short.tgt.txt.
 | Wrote to ./tokenized_corpus/ggw_data/train_short.tgt_tokenized.txt.
 | Processing ./dataset/ggw_data/test.src.txt.
 | Wrote to ./tokenized_corpus/ggw_data/test.src_tokenized.txt.
 | Processing ./dataset/ggw_data/train_short.src.txt.
 | Wrote to ./tokenized_corpus/ggw_data/train_short.src_tokenized.txt.


# learn bpe

In [4]:
import subprocess

# Make sure the folder that will receive the BPE codes exists.
os.makedirs('./vocab/', exist_ok=True)

# Feed all Gigaword and News Crawl ``.txt`` files under ./dataset to
# ``subword-nmt learn-bpe`` (-s 46000) and store the codes under ./vocab.
# The command relies on shell globbing and a pipe, hence shell=True below.
commands = (
    "cat ./dataset/ggw_data/*.txt ./dataset/news_crawl/*.txt"
    " | subword-nmt learn-bpe -s 46000 -o ./vocab/all.bpe.codes"
)

# Last expression of the cell: the shell exit status becomes the cell output.
subprocess.call(commands, shell=True)

0

# vocab

In [5]:
import os
import subprocess
from src.utils import Dictionary

# Tokenized News Crawl text (input) and the folder that receives its BPE-encoded copy.
source_folder = os.path.abspath("./tokenized_corpus/news_crawl/")
output_folder = os.path.abspath("./tokenized_corpus/news_crawl/bpe/")
# BPE codes file handed to `subword-nmt apply-bpe -c`.
codes = os.path.abspath("./vocab/all.bpe.codes")
# Where the vocabulary object is persisted at the end of this cell.
vocab_path = "./vocab/all_en.dict.bin"

# os.mkdir creates a single level only, so source_folder must already exist.
if not os.path.isdir(output_folder):
    os.mkdir(output_folder)

# Command prefixes for the subword-nmt CLI; arguments are appended per call.
ENCODER = "subword-nmt apply-bpe -c"
LEARN_DICT = "subword-nmt get-vocab -i"
def bpe_encode(codes_path, src_path, output_path, dict_path):
    """Apply BPE to a text file, then extract the vocabulary of the encoded text.

    Args:
        codes_path: BPE codes file given to ``subword-nmt apply-bpe -c``.
        src_path: Text file to encode.
        output_path: Where the BPE-encoded text is written.
        dict_path: Where ``subword-nmt get-vocab`` writes the vocabulary.

    Raises:
        subprocess.CalledProcessError: If either ``subword-nmt`` command exits
            with a non-zero status.
    """
    # Encoding.
    print(" | Applying BPE encoding.")
    encode_cmd = ENCODER.split() + [codes_path, "-i", src_path, "-o", output_path]
    # check_call instead of call: a failed encoding must not be silently
    # followed by a vocabulary extraction over missing or partial output.
    subprocess.check_call(encode_cmd)

    # Learn vocab.
    print(" | Fetching vocabulary from single file.")
    vocab_cmd = LEARN_DICT.split() + [output_path, "-o", dict_path]
    subprocess.check_call(vocab_cmd)

# BPE-encode every tokenized ``.txt`` file and remember where the vocabulary
# extracted from it was written.
available_dict = []
for file in os.listdir(source_folder):
    if not file.endswith(".txt"):
        continue
    src_path = os.path.join(source_folder, file)
    output_path = os.path.join(output_folder, file.replace(".txt", "_bpe.txt"))
    dict_path = os.path.join(output_folder, file.replace(".txt", ".dict"))
    available_dict.append(dict_path)
    bpe_encode(codes, src_path, output_path, dict_path)

# Load the vocabulary files produced by bpe_encode (per the original note,
# each line has the form [word, freq]).
vocab = Dictionary.load_from_text(available_dict)
# Persist the vocabulary object as a binary file.
vocab.persistence(vocab_path)
print(f" | Vocabulary Size: {len(vocab)}")

 | Applying BPE encoding.
 | Fetching vocabulary from single file.
 | Vocabulary Size: 36469


# create dataset nc

In [6]:
"""Create News Crawl Pre-Training Dataset."""
import os
from src.dataset import MonoLingualDataLoader
from src.language_model import LooseMaskedLanguageModel
from src.utils import Dictionary

input_folder_path = './dataset/news_crawl/'  # Raw corpus folder
output_folder_path = './train_data/news_crawl/'  # Dataset output path

# makedirs creates './train_data' and the sub-folder in one call and tolerates
# either of them already existing. The previous os.mkdir('./train_data') raised
# FileExistsError whenever only the parent folder was present.
os.makedirs(output_folder_path, exist_ok=True)

vocab_path = './vocab/all_en.dict.bin'  # Existing vocabulary file
vocab = Dictionary.load_from_persisted_dict(vocab_path)

def create_pre_train(text_file, output_folder, vocab, max_sen_len):
    """Write a MindRecord pre-training dataset for one text file.

    The text is loaded through ``MonoLingualDataLoader`` configured with a
    ``LooseMaskedLanguageModel`` (mask_ratio=0.5) and a minimum sentence
    length of 10, then written to
    ``<output_folder>/<name>_len_<max_sen_len>.mindrecord``.

    Args:
        text_file: Path of the corpus ``.txt`` file.
        output_folder: Folder that receives the ``.mindrecord`` file.
        vocab: Dictionary object handed to the loader.
        max_sen_len: Forwarded to the loader as ``max_sen_len`` and embedded
            in the output file name.
    """
    loader = MonoLingualDataLoader(
        src_filepath=text_file,
        lang="en", dictionary=vocab,
        language_model=LooseMaskedLanguageModel(mask_ratio=0.5, mask_all_prob=None),
        max_sen_len=max_sen_len, min_sen_len=10
    )

    src_file_name = os.path.basename(text_file)

    # Honor the output_folder argument; the module-level output_folder_path
    # was used here before, so the parameter had no effect.
    file_name = os.path.join(
        output_folder,
        src_file_name.replace('.txt', f'_len_{max_sen_len}.mindrecord')
    )

    # Remove a previous record and its '.db' companion before writing
    # (presumably the writer refuses to overwrite - TODO confirm).
    for stale_path in (file_name, file_name + '.db'):
        if os.path.exists(stale_path):
            os.remove(stale_path)

    loader.write_to_mindrecord(path=file_name)

# Convert every ``.txt`` corpus file into a pre-training record, using a
# maximum sentence length of 32.
for file in os.listdir(input_folder_path):
    if not file.endswith(".txt"):
        continue
    text_file = os.path.join(input_folder_path, file)
    create_pre_train(text_file, output_folder_path, vocab, 32)

print(f" | Generate Dataset for Pre-training is done.")
print(f" | Vocabulary size: {vocab.size}.")

 | Processing corpus ./dataset/news_crawl/news2012_short.txt.
 | Shortest len = 1.
 | Longest  len = 1090.
 | Total    sen = 999745.
| Wrote to /home/ma-user/work/MASS-ms200/train_data/news_crawl/news2012_short_len_32.mindrecord.
 | Generate Dataset for Pre-training is done.
 | Vocabulary size: 36469.


# create dataset ggw

In [7]:
"""Generate Gigaword dataset."""
import os
from src.dataset import BiLingualDataLoader
from src.language_model import NoiseChannelLanguageModel
from src.utils import Dictionary

input_folder_path = './dataset/ggw_data/'
output_folder_path = './train_data/ggw_data/'

# makedirs also creates './train_data' when it is missing; os.mkdir on the
# nested path failed with FileNotFoundError in that case.
os.makedirs(output_folder_path, exist_ok=True)

vocab_path = './vocab/all_en.dict.bin'
vocab = Dictionary.load_from_persisted_dict(vocab_path)

# ---- Training split --------------------------------------------------------
train_path = os.path.join(output_folder_path, "gigaword_train_dataset.mindrecord")
train = BiLingualDataLoader(
    src_filepath=os.path.join(input_folder_path, "train_short.src.txt"),
    tgt_filepath=os.path.join(input_folder_path, "train_short.tgt.txt"),
    src_dict=vocab, tgt_dict=vocab,
    src_lang="en", tgt_lang="en",
    language_model=NoiseChannelLanguageModel(add_noise_prob=0),
    max_sen_len=32
)

# Drop a previous record and its '.db' companion before writing a new one.
for stale_path in (train_path, train_path + '.db'):
    if os.path.exists(stale_path):
        os.remove(stale_path)

train.write_to_mindrecord(path=train_path)

# ---- Test split ------------------------------------------------------------
test_path = os.path.join(output_folder_path, "gigaword_test_dataset.mindrecord")
test = BiLingualDataLoader(
    src_filepath=os.path.join(input_folder_path, "test.src.txt"),
    tgt_filepath=os.path.join(input_folder_path, "test.tgt.txt"),
    src_dict=vocab, tgt_dict=vocab,
    src_lang="en", tgt_lang="en",
    language_model=NoiseChannelLanguageModel(add_noise_prob=0),
    max_sen_len=32
)

for stale_path in (test_path, test_path + '.db'):
    if os.path.exists(stale_path):
        os.remove(stale_path)

test.write_to_mindrecord(path=test_path)

print(f" | Generate Dataset for fine-tuneing is done.")
print(f" | Vocabulary size: {vocab.size}.")

 | Processing corpus ./dataset/ggw_data/train_short.src.txt.
 | Processing corpus ./dataset/ggw_data/train_short.tgt.txt.
 | Shortest len = 4.
 | Longest  len = 90.
 | Total    sen = 194181.
 | Total token num=8755897, 10.717634069930243% replaced by <unk>.
| Wrote to /home/ma-user/work/MASS-ms200/train_data/ggw_data/gigaword_train_dataset.mindrecord.
 | Processing corpus ./dataset/ggw_data/test.src.txt.
 | Processing corpus ./dataset/ggw_data/test.tgt.txt.
 | Shortest len = 2.
 | Longest  len = 73.
 | Total    sen = 1081.
 | Total token num=46933, 10.191123516502248% replaced by <unk>.
| Wrote to /home/ma-user/work/MASS-ms200/train_data/ggw_data/gigaword_test_dataset.mindrecord.
 | Generate Dataset for fine-tuneing is done.
 | Vocabulary size: 36469.


# load data

In [2]:
"""Dataset loader to feed into model."""
import mindspore as ms
import mindspore.dataset as ds

def load_dataset(input_file, batch_size: int, epoch_count: int,
                 sink_mode: bool, sink_step: int = 1, rank_size: int = 1, rank_id: int = 0, shuffle=True):
    """Build a batched, repeated dataset pipeline from a MindRecord file.

    Args:
        input_file: MindRecord file passed to ``ds.MindDataset``.
        batch_size: Batch size; incomplete trailing batches are dropped.
        epoch_count: Number of times the batched dataset is repeated.
        sink_mode: Accepted but not used inside this function.
        sink_step: Accepted but not used inside this function.
        rank_size: Passed to ``MindDataset`` as ``num_shards``.
        rank_id: Passed to ``MindDataset`` as ``shard_id``.
        shuffle: Passed to ``MindDataset`` as ``shuffle``.

    Returns:
        The dataset, with int32 columns renamed to the ``source_*`` /
        ``target_*`` names and ``channel_name`` set to ``'transformer'``.
    """
    print(f" | Loading {input_file}.")

    # Stored column name -> name the column carries in the returned dataset.
    column_names = {
        "src": "source_eos_ids",
        "src_padding": "source_eos_mask",
        "prev_opt": "target_sos_ids",
        "prev_padding": "target_sos_mask",
        "target": "target_eos_ids",
        "tgt_padding": "target_eos_mask",
    }

    data_set = ds.MindDataset(
        input_file,
        columns_list=list(column_names),
        shuffle=shuffle, num_shards=rank_size, shard_id=rank_id)

    # Cast every column to int32, one map operation per column.
    to_int32 = ds.transforms.transforms.TypeCast(ms.int32)
    for stored_name in column_names:
        data_set = data_set.map(operations=to_int32, input_columns=stored_name)

    # Size is reported before batching and repeating.
    ori_dataset_size = data_set.get_dataset_size()
    print(f" | Dataset size: {ori_dataset_size}.")

    data_set = data_set.rename(
        input_columns=list(column_names.keys()),
        output_columns=list(column_names.values()),
    )

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)
    data_set = data_set.repeat(epoch_count)

    data_set.channel_name = 'transformer'
    return data_set

# pretraining

In [9]:
"""定义训练与推理的参数"""
import mindspore.common.dtype as mstype
# Settings for the pre-training stage, kept as plain class attributes.
# The training cell instantiates this class and overrides a few fields
# (epochs, vocab_size) on the instance.
class config():
    enable_modelarts = True # Whether training on ModelArts (original note says the default is False)
    device_target = "Ascend"
    output_path = "./output/"
    save_checkpoint_path = "./output/checkpoint/"
    checkpoint_file_path = ""
    # Training options
    epochs = 20
    batch_size = 192
    dtype = mstype.float32 # only float16 and float32 are supported
    compute_type = mstype.float16 # only float16 and float32 are supported
    pre_train_dataset = "./train_data/news_crawl/news2012_short_len_32.mindrecord"
    # ./train_data/news_crawl/news2012_len_32.tfrecord-001-of-001
    fine_tune_dataset = "./train_data/ggw_data/gigaword_train_dataset.mindrecord"
    test_dataset = "./train_data/ggw_data/gigaword_test_dataset.mindrecord"
    dataset_sink_mode = False
    dataset_sink_step = 100
    random_seed = 100
    save_graphs = False
    # Model options
    seq_length = 32 # original note: 64
    vocab_size = 0 # placeholder; replaced with vocab.size before the network is built
    hidden_size = 1024
    num_hidden_layers = 6
    num_attention_heads = 8
    intermediate_size = 4096
    hidden_act = "relu"
    hidden_dropout_prob = 0.2
    attention_dropout_prob = 0.2
    max_position_embeddings = 32 # original note: 64
    initializer_range = 0.02
    label_smoothing = 0.1
    beam_width = 4
    length_penalty_weight = 1.0
    max_decode_length = 32 # original note: 64
    # Dynamic loss-scale settings (fed to DynamicLossScaleManager)
    init_loss_scale = 65536
    loss_scale_factor = 2
    scale_window = 200
    # Learning-rate schedule (fed to polynomial_decay_scheduler)
    lr = 0.0001
    poly_lr_scheduler_power = 0.5
    decay_steps = 10000
    decay_start_step = 12000
    warmup_steps = 4000
    min_lr = 0.000001
    # Checkpointing (fed to CheckpointConfig / ModelCheckpoint)
    save_ckpt_steps = 10000
    keep_ckpt_max = 50
    ckpt_prefix = "pt"
    metric = "rouge"
    vocab = "./vocab/all_en.dict.bin"
    output = "./output/infer.bin"

In [3]:
import os
import pickle
import numpy as np

from mindspore.common.tensor import Tensor
from mindspore.nn import Momentum
from mindspore.nn.optim import Adam, Lamb
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore import context, Parameter
from mindspore.communication import management as MultiAscend
from mindspore.train.serialization import load_checkpoint
from mindspore.common import set_seed

from src.transformer import TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
from src.utils import LossCallBack
from src.utils import one_weight, zero_weight, weight_variable
from src.utils import square_root_schedule
from src.utils.lr_scheduler import polynomial_decay_scheduler, BertLearningRate

In [11]:
# Pre-training driver: build the dataset, network, optimizer and callbacks, then train.
# NOTE: this rebinds `config` from the class to an instance, so the cell cannot be
# re-run without first re-running the cell that defines the class.
config = config()
config.epochs = 1 # for test

# Dictionary is not imported in this cell; it comes from an earlier cell's
# `from src.utils import Dictionary`.
vocab_path = './vocab/all_en.dict.bin'
vocab = Dictionary.load_from_persisted_dict(vocab_path)
config.vocab_size = vocab.size

print(" | Starting training on single device.")
pre_train_dataset = load_dataset(input_file=config.pre_train_dataset,
                                 batch_size=config.batch_size,
                                 epoch_count=1,
                                 sink_mode=config.dataset_sink_mode,
                                 sink_step=config.dataset_sink_step)

# Build the network together with its loss.
net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
net_with_loss.init_parameters_data()

# Re-initialize every trainable parameter according to its name suffix:
# '.gamma' -> one_weight, '.beta' / '.bias' -> zero_weight, anything else -> weight_variable.
for param in net_with_loss.trainable_params():
    name = param.name
    value = param.data
    if isinstance(value, Tensor):
        if name.endswith(".gamma"):
            param.set_data(one_weight(value.asnumpy().shape))
        elif name.endswith(".beta") or name.endswith(".bias"):
            param.set_data(zero_weight(value.asnumpy().shape))
        else:
            param.set_data(weight_variable(value.asnumpy().shape))

# Total number of updates: epochs times the size reported by the loaded dataset.
update_steps = config.epochs * pre_train_dataset.get_dataset_size()

# Define the decaying learning rate.
lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
           min_lr=config.min_lr,
           decay_steps=config.decay_steps,
           total_update_num=update_steps,
           warmup_steps=config.warmup_steps,
           power=config.poly_lr_scheduler_power), dtype=mstype.float32)
# Define the optimizer.
optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98)

# loss scale (mode = dynamic)
scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
                                        scale_factor=config.loss_scale_factor,
                                        scale_window=config.scale_window)
# Define the backward (training) network.
net_with_grads = TransformerTrainOneStepWithLossScaleCell(network=net_with_loss, optimizer=optimizer,
                                                          scale_update_cell=scale_manager.get_update_cell())
net_with_grads.set_train(True)

# Initialize the model.
model = Model(net_with_grads)

time_cb = TimeMonitor(data_size=pre_train_dataset.get_dataset_size())
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
                               keep_checkpoint_max=config.keep_ckpt_max)

callbacks = []
callbacks.append(time_cb)
ckpt_save_dir = config.save_checkpoint_path

# Checkpoints go to <save_checkpoint_path>/ckpt_pt with the configured prefix.
ckpt_callback = ModelCheckpoint(
    prefix=config.ckpt_prefix,
    directory=os.path.join(ckpt_save_dir, 'ckpt_pt'),
    config=ckpt_config)
loss_monitor = LossCallBack(rank_id=0)
callbacks.append(loss_monitor)
callbacks.append(ckpt_callback)

print(" | Start pre-training job.")
model.train(config.epochs, pre_train_dataset,
            callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
            sink_size=config.dataset_sink_step)

 | Starting training on single device.
 | Loading ./train_data/news_crawl/news2012_short_len_32.mindrecord.
 | Dataset size: 999745.




 | Start pre-training job.
Train epoch time: 4965862.593 ms, per step time: 953.690 ms


# fine tuning
# 修改 config.checkpoint_file_path = './output/checkpoint/ckpt_pt/pt-xxxx.ckpt'

In [16]:
"""定义训练与推理的参数"""
import mindspore.common.dtype as mstype
# Settings for the fine-tuning stage, kept as plain class attributes.
# The fine-tuning cell instantiates this class and overrides a few fields
# (checkpoint_file_path, epochs, vocab_size) on the instance.
class config():
    enable_modelarts = True # Whether training on ModelArts (original note says the default is False)
    device_target = "Ascend"
    output_path = "./output/"
    save_checkpoint_path = "./output/checkpoint/"
    checkpoint_file_path = ""
    # Training options
    epochs = 20
    batch_size = 192
    dtype = mstype.float32 # only float16 and float32 are supported
    compute_type = mstype.float16 # only float16 and float32 are supported
    pre_train_dataset = "./train_data/news_crawl/news2012_short_len_32.mindrecord"
    # ./train_data/news_crawl/news2012_len_32.tfrecord-001-of-001
    fine_tune_dataset = "./train_data/ggw_data/gigaword_train_dataset.mindrecord"
    test_dataset = "./train_data/ggw_data/gigaword_test_dataset.mindrecord"
    dataset_sink_mode = False
    dataset_sink_step = 100
    random_seed = 100
    save_graphs = False
    # Model options
    seq_length = 32 # original note: 64
    vocab_size = 0 # placeholder; replaced with vocab.size before the network is built
    hidden_size = 1024
    num_hidden_layers = 6
    num_attention_heads = 8
    intermediate_size = 4096
    hidden_act = "relu"
    hidden_dropout_prob = 0.2
    attention_dropout_prob = 0.2
    max_position_embeddings = 32 # original note: 64
    initializer_range = 0.02
    label_smoothing = 0.1
    beam_width = 4
    length_penalty_weight = 1.0
    max_decode_length = 32 # original note: 64
    # Dynamic loss-scale settings (fed to DynamicLossScaleManager)
    init_loss_scale = 65536
    loss_scale_factor = 2
    scale_window = 200
    # Learning-rate schedule (fed to polynomial_decay_scheduler)
    lr = 0.0001
    poly_lr_scheduler_power = 0.5
    decay_steps = 10000
    decay_start_step = 12000
    warmup_steps = 4000
    min_lr = 0.000001
    # Checkpointing (fed to CheckpointConfig / ModelCheckpoint)
    save_ckpt_steps = 10000
    keep_ckpt_max = 50
    ckpt_prefix = "ft"
    metric = "rouge"
    vocab = "./vocab/all_en.dict.bin"
    output = "./output/infer.bin"

In [17]:
# Fine-tuning driver: load the pre-trained checkpoint into the network, then train
# on the Gigaword dataset.
# NOTE: this rebinds `config` from the class to an instance, so the cell cannot be
# re-run without first re-running the cell that defines the class.
config = config()
config.checkpoint_file_path = './output/checkpoint/ckpt_pt/pt-1_5207.ckpt' # point this at the checkpoint produced by pre-training
config.epochs = 1 # for test

vocab_path = './vocab/all_en.dict.bin'
vocab = Dictionary.load_from_persisted_dict(vocab_path)
config.vocab_size = vocab.size

print(" | Starting training on single device.")
fine_tune_dataset = load_dataset(input_file=config.fine_tune_dataset,
                                 batch_size=config.batch_size,
                                 epoch_count=1,
                                 sink_mode=config.dataset_sink_mode,
                                 sink_step=config.dataset_sink_step)

# Build the network together with its loss.
net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
net_with_loss.init_parameters_data()

# Load the weights stored in the existing checkpoint file.
# Every trainable parameter name is looked up directly, so a name missing from
# the checkpoint raises KeyError.
weights = load_checkpoint(config.checkpoint_file_path)
for param in net_with_loss.trainable_params():
    weights_name = param.name
    if isinstance(weights[weights_name], Parameter):
        param.set_data(weights[weights_name].data)
    elif isinstance(weights[weights_name], Tensor):
        param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
    elif isinstance(weights[weights_name], np.ndarray):
        param.set_data(Tensor(weights[weights_name], config.dtype))
    else:
        param.set_data(weights[weights_name])

# Total number of updates: epochs times the size reported by the loaded dataset.
update_steps = config.epochs * fine_tune_dataset.get_dataset_size()

# Define the decaying learning rate.
lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
           min_lr=config.min_lr,
           decay_steps=config.decay_steps,
           total_update_num=update_steps,
           warmup_steps=config.warmup_steps,
           power=config.poly_lr_scheduler_power), dtype=mstype.float32)
# Define the optimizer.
optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98)

# loss scale (mode = dynamic)
scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
                                        scale_factor=config.loss_scale_factor,
                                        scale_window=config.scale_window)
# Define the backward (training) network.
net_with_grads = TransformerTrainOneStepWithLossScaleCell(network=net_with_loss, optimizer=optimizer,
                                                          scale_update_cell=scale_manager.get_update_cell())
net_with_grads.set_train(True)

# Initialize the model.
model = Model(net_with_grads)

time_cb = TimeMonitor(data_size=fine_tune_dataset.get_dataset_size())
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
                               keep_checkpoint_max=config.keep_ckpt_max)

callbacks = []
callbacks.append(time_cb)
ckpt_save_dir = config.save_checkpoint_path

# Checkpoints go to <save_checkpoint_path>/ckpt_ft with the configured prefix.
ckpt_callback = ModelCheckpoint(
    prefix=config.ckpt_prefix,
    directory=os.path.join(ckpt_save_dir, 'ckpt_ft'),
    config=ckpt_config)
loss_monitor = LossCallBack(rank_id=0)
callbacks.append(loss_monitor)
callbacks.append(ckpt_callback)

print(" | Start fine-tuning job.")
model.train(config.epochs, fine_tune_dataset,
            callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
            sink_size=config.dataset_sink_step)

 | Starting training on single device.
 | Loading ./train_data/ggw_data/gigaword_train_dataset.mindrecord.
 | Dataset size: 194181.




 | Start fine-tuning job.
Train epoch time: 991487.919 ms, per step time: 980.700 ms


# test
# 修改 config.checkpoint_file_path = "./output/checkpoint/ckpt_ft/ft-xxxx.ckpt"

In [4]:
"""定义训练与推理的参数"""
import mindspore.common.dtype as mstype
# Settings for the inference/test stage, kept as plain class attributes.
# The test cell instantiates this class and overrides a few fields
# (checkpoint_file_path, vocab_size) on the instance.
class config():
    enable_modelarts = True # Whether training on ModelArts (original note says the default is False)
    device_target = "Ascend"
    output_path = "./output/"
    save_checkpoint_path = "./output/checkpoint/"
    checkpoint_file_path = ""
    # Training options
    epochs = 20
    batch_size = 192
    dtype = mstype.float32 # only float16 and float32 are supported
    compute_type = mstype.float16 # only float16 and float32 are supported
    pre_train_dataset = "./train_data/news_crawl/news2012_short_len_32.mindrecord"
    # ./train_data/news_crawl/news2012_len_32.tfrecord-001-of-001
    fine_tune_dataset = "./train_data/ggw_data/gigaword_train_dataset.mindrecord"
    test_dataset = "./train_data/ggw_data/gigaword_test_dataset.mindrecord"
    dataset_sink_mode = False
    dataset_sink_step = 100
    random_seed = 100
    save_graphs = False
    # Model options
    seq_length = 32 # original note: 64
    vocab_size = 0 # placeholder; replaced with vocab.size before the model is built
    hidden_size = 1024
    num_hidden_layers = 6
    num_attention_heads = 8
    intermediate_size = 4096
    hidden_act = "relu"
    hidden_dropout_prob = 0.2
    attention_dropout_prob = 0.2
    max_position_embeddings = 32 # original note: 64
    initializer_range = 0.02
    label_smoothing = 0.1
    beam_width = 4
    length_penalty_weight = 1.0
    max_decode_length = 32 # original note: 64
    # Loss-scale settings (not referenced by the test cell)
    init_loss_scale = 65536
    loss_scale_factor = 2
    scale_window = 200
    # Learning-rate settings (not referenced by the test cell)
    lr = 0.0001
    poly_lr_scheduler_power = 0.5
    decay_steps = 10000
    decay_start_step = 12000
    warmup_steps = 4000
    min_lr = 0.000001
    # Checkpoint settings
    save_ckpt_steps = 10000
    keep_ckpt_max = 50
    ckpt_prefix = "ft"
    metric = "rouge"
    vocab = "./vocab/all_en.dict.bin"
    output = "./output/infer.bin" # the test cell pickles its predictions to this path

In [5]:
import os
import pickle
import time

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.utils import Dictionary
from src.transformer.transformer_for_infer import TransformerInferModel
from src.transformer.transformer_for_train import TransformerTraining
from src.utils.load_weights import load_infer_weights
from src.utils.rouge_score import rouge

class TransformerInferCell(nn.Cell):
    """Thin Cell wrapper that forwards inference inputs to the wrapped network."""

    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self.network = network

    def construct(self, source_ids, source_mask):
        """Run the wrapped network and return ``(predicted_ids, predicted_probs)``."""
        ids, probs = self.network(source_ids, source_mask)
        return ids, probs

def get_rouge_score(result, vocab):
    """Decode every sample to text, print it, and return the ROUGE score.

    Args:
        result: Iterable of dicts holding token-id sequences under the keys
            ``'source'``, ``'target'`` and ``'prediction'``.
        vocab: Object indexable by token id that yields the token string.

    Returns:
        Whatever ``rouge(predictions, targets)`` returns for the decoded texts.
    """
    def to_text(token_ids):
        # Look up each id in the vocabulary and join the tokens with spaces.
        return ' '.join(vocab[t] for t in token_ids)

    predictions, targets = [], []
    for sample in result:
        prediction_text = to_text(sample['prediction'])
        target_text = to_text(sample['target'])
        predictions.append(prediction_text)
        targets.append(target_text)
        print(f" | source: {to_text(sample['source'])}")
        print(f" | prediction: {prediction_text}")
        print(f" | target: {target_text}")

    return rouge(predictions, targets)

# Inference driver: load the fine-tuned weights, predict on the test set,
# pickle the results and report ROUGE.
# NOTE: this rebinds `config` from the class to an instance, so the cell cannot be
# re-run without first re-running the cell that defines the class.
config = config()
config.epochs = 1  # for test (was the misspelled, never-read attribute `epoch`)
# Change this to the path of your fine-tuned checkpoint file.
config.checkpoint_file_path = "./output/checkpoint/ckpt_ft/ft-1_1011.ckpt"

vocab_path = './vocab/all_en.dict.bin'
vocab = Dictionary.load_from_persisted_dict(vocab_path)
config.vocab_size = vocab.size

eval_dataset = load_dataset(input_file=config.test_dataset,
                            batch_size=config.batch_size,
                            epoch_count=1,
                            sink_mode=config.dataset_sink_mode,
                            shuffle=False)

tfm_model = TransformerInferModel(config=config, use_one_hot_embeddings=False)
tfm_model.init_parameters_data()

params = tfm_model.trainable_params()
weights = load_infer_weights(config)

# Log every parameter name and copy in the loaded value when one exists.
# The log file is opened once instead of once per parameter, and the
# redundant assert after the membership test is gone.
with open("weight_after_deal.txt", "a+") as f:
    for param in params:
        value = param.data
        weights_name = param.name
        f.write(weights_name + "\n")
        if isinstance(value, Tensor) and weights_name in weights:
            param.set_data(Tensor(weights[weights_name], mstype.float32))

print(" | Load weights successfully.")

tfm_infer = TransformerInferCell(tfm_model)
model = Model(tfm_infer)

predictions = []
probs = []
source_sentences = []
target_sentences = []
for batch in eval_dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
    source_sentences.append(batch["source_eos_ids"])
    target_sentences.append(batch["target_eos_ids"])

    source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
    source_mask = Tensor(batch["source_eos_mask"], mstype.int32)

    start_time = time.time()
    predicted_ids, entire_probs = model.predict(source_ids, source_mask)
    print(f" | Batch size: {config.batch_size}, "
          f"Time cost: {time.time() - start_time}.")

    predictions.append(predicted_ids.asnumpy())
    probs.append(entire_probs.asnumpy())

output = []
for inputs, ref, batch_out, batch_probs in zip(source_sentences,
                                               target_sentences,
                                               predictions,
                                               probs):
    # When predictions carry an extra axis, keep index 0 along it
    # (presumably the top beam - TODO confirm). Done once per batch.
    if batch_out.ndim == 3:
        batch_out = batch_out[:, 0]

    # Iterate over the rows actually present instead of config.batch_size, so
    # a batch shorter than the configured size cannot raise IndexError.
    for i in range(len(inputs)):
        example = {
            "source": inputs[i].tolist(),
            "target": ref[i].tolist(),
            "prediction": batch_out[i].tolist(),
            "prediction_prob": batch_probs[i].tolist()
        }
        output.append(example)

# Persist the raw predictions (pickle protocol 1) before scoring.
with open(config.output, "wb") as f:
    pickle.dump(output, f, 1)

score = get_rouge_score(output, vocab)
print(score)

 | Loading ./train_data/ggw_data/gigaword_test_dataset.mindrecord.
 | Dataset size: 1081.
 | Load weights successfully.
 | Batch size: 192, Time cost: 1652.7519772052765.
 | Batch size: 192, Time cost: 28.484652042388916.
 | Batch size: 192, Time cost: 6.254255533218384.
 | Batch size: 192, Time cost: 6.127770900726318.
 | Batch size: 192, Time cost: 6.0402281284332275.
 | source: japan 's nec corp. and UNK computer corp. of the united states said wednesday they had agreed to join forces in <unk> sales . </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
 | prediction: <s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s>
 | target: nec UNK in computer sales tie-up </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
 | source: the sri <unk> government on wednesday announced the cl