## MindSpore-Transformer-Machine Translation
### 1. Download the source code and data to the local container

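The notebook is mounted on OBS, and the running container instance cannot read or modify files on OBS directly, so the data and source code must first be copied to the container's local file system.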

In [None]:
import moxing as mox
mox.file.copy_parallel(src_url="s3://ascend-zyjs-dcyang/nlp/mt_transformer_mindspore_1.1/data/", dst_url='./data/')
mox.file.copy_parallel(src_url="s3://ascend-zyjs-dcyang/nlp/mt_transformer_mindspore_1.1/src/", dst_url='./src/')

INFO:root:Using MoXing-v2.0.1.rc0.ffd1c0c8-ffd1c0c8
INFO:root:Using OBS-Python-SDK-3.20.9.1


### 2. Import dependencies

In [2]:
import os
import numpy as np
from easydict import EasyDict as edict

import mindspore.nn as nn
from mindspore import context
import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype
from mindspore.mindrecord import FileWriter
from mindspore.common.parameter import Parameter
import mindspore.dataset.transforms.c_transforms as deC
from mindspore.common.tensor import Tensor
from mindspore.nn.optim import Adam
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.callback import Callback, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src import tokenization
from src.train_util import LossCallBack
from src.lr_schedule import create_dynamic_lr
from src.transformer_model import TransformerConfig, TransformerModel
from src.data_utils import create_training_instance, write_instance_to_file
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell

### 3. Set up the runtime environment

In [3]:
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

### 4. Define the data-processing parameters

In [4]:
data_cfg = edict({
        'input_file': './data/ch_en_all.txt',
        'vocab_file': './data/ch_en_vocab.txt',
        'train_file_mindrecord': './data/train.mindrecord',
        'eval_file_mindrecord': './data/test.mindrecord',
        'train_file_source': './data/source_train.txt',
        'eval_file_source': './data/source_test.txt',
        'num_splits':1,
        'clip_to_max_len': False,
        'max_seq_length': 40
})

### 5. Define the data-processing function

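Load the raw data, split it into training and test sets, preprocess it into the input format the model expects, and save the result in MindRecord format.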

In [5]:
def data_prepare(cfg, eval_idx):
    tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=cfg.vocab_file)

    writer_train = FileWriter(cfg.train_file_mindrecord, cfg.num_splits)
    writer_eval = FileWriter(cfg.eval_file_mindrecord, cfg.num_splits)
    data_schema = {"source_sos_ids": {"type": "int32", "shape": [-1]},
                   "source_sos_mask": {"type": "int32", "shape": [-1]},
                   "source_eos_ids": {"type": "int32", "shape": [-1]},
                   "source_eos_mask": {"type": "int32", "shape": [-1]},
                   "target_sos_ids": {"type": "int32", "shape": [-1]},
                   "target_sos_mask": {"type": "int32", "shape": [-1]},
                   "target_eos_ids": {"type": "int32", "shape": [-1]},
                   "target_eos_mask": {"type": "int32", "shape": [-1]}
                   }

    writer_train.add_schema(data_schema, "transformer train")
    writer_eval.add_schema(data_schema, "transformer eval")

    eval_idx = set(int(i) for i in eval_idx)  # set lookup is O(1); `in` on a NumPy array scans it
    index = 0
    with open(cfg.input_file, "r", encoding='utf-8') as f, \
            open(cfg.train_file_source, 'w', encoding='utf-8') as f_train, \
            open(cfg.eval_file_source, 'w', encoding='utf-8') as f_test:
        for s_line in f:
            print("finish {}/{}".format(index, 23607), end='\r')

            # Each input line is "<source sentence>\t<target sentence>"
            line = tokenization.convert_to_unicode(s_line)
            source_line, target_line = line.strip().split("\t")
            source_tokens = tokenizer.tokenize(source_line)
            target_tokens = tokenizer.tokenize(target_line)

            # Over-long pairs are skipped, or clipped when clip_to_max_len is True
            if len(source_tokens) >= (cfg.max_seq_length-1) or len(target_tokens) >= (cfg.max_seq_length-1):
                if cfg.clip_to_max_len:
                    source_tokens = source_tokens[:cfg.max_seq_length-1]
                    target_tokens = target_tokens[:cfg.max_seq_length-1]
                else:
                    continue

            index = index + 1
            instance = create_training_instance(source_tokens, target_tokens, cfg.max_seq_length)

            # Sampled indices go to the test split, everything else to training
            if index in eval_idx:
                f_test.write(s_line)
                write_instance_to_file(writer_eval, instance, tokenizer, cfg.max_seq_length)
            else:
                f_train.write(s_line)
                write_instance_to_file(writer_train, instance, tokenizer, cfg.max_seq_length)
    writer_train.commit()
    writer_eval.commit()
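`create_training_instance` and `write_instance_to_file` live in `src/data_utils.py`, which is not shown here. The sketch below illustrates what the eight schema columns plausibly hold for one sentence pair: a `*_sos_*` variant with a start token prepended, a `*_eos_*` variant with an end token appended, each right-padded to `max_seq_length` with a matching 1/0 mask. The ids are assumptions: padding 0 and EOS 2 are consistent with the sample printed in step 7, while SOS 1 is purely hypothetical, and the real layout in `src` may differ.

```python
from typing import Dict, List

PAD_ID = 0  # assumption: padding id (matches the trailing zeros in the step-7 sample)
SOS_ID = 1  # assumption: hypothetical start-of-sentence id
EOS_ID = 2  # assumption: end-of-sentence id (matches the "2" in the step-7 sample)


def pad(ids: List[int], max_seq_length: int) -> Dict[str, List[int]]:
    """Right-pad ids with PAD_ID and build a 1/0 mask marking the real positions."""
    if len(ids) > max_seq_length:
        raise ValueError("sequence longer than max_seq_length")
    padding = max_seq_length - len(ids)
    return {"ids": ids + [PAD_ID] * padding,
            "mask": [1] * len(ids) + [0] * padding}


def build_features(source_ids: List[int], target_ids: List[int],
                   max_seq_length: int) -> Dict[str, List[int]]:
    """Hypothetical re-creation of the eight MindRecord columns for one sentence pair."""
    features = {}
    for side, ids in (("source", source_ids), ("target", target_ids)):
        sos = pad([SOS_ID] + ids, max_seq_length)  # <s> w1 ... wn
        eos = pad(ids + [EOS_ID], max_seq_length)  # w1 ... wn </s>
        features[f"{side}_sos_ids"], features[f"{side}_sos_mask"] = sos["ids"], sos["mask"]
        features[f"{side}_eos_ids"], features[f"{side}_eos_mask"] = eos["ids"], eos["mask"]
    return features


if __name__ == "__main__":
    feats = build_features([3983, 3], [10, 11, 12], max_seq_length=8)
    print(feats["source_eos_ids"])   # [3983, 3, 2, 0, 0, 0, 0, 0]
    print(feats["source_eos_mask"])  # [1, 1, 1, 0, 0, 0, 0, 0]
```

Because every sequence gains one special token, token counts have to stay below `max_seq_length`, which is why `data_prepare` filters or clips long pairs before building instances.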

### 6. Process the data, randomly holding out 20% as the test set

In [5]:
sample_num = 23607
eval_idx = np.random.choice(sample_num, int(sample_num*0.2), replace=False)
data_prepare(data_cfg, eval_idx)

finish 23607/23607
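`np.random.choice(sample_num, k, replace=False)` draws `k` distinct indices, so exactly `int(0.2 * sample_num)` positions are marked for the test set. The sketch below performs the same split with NumPy's `Generator` API; the fixed seed is an addition (the notebook does not seed NumPy) that makes the split reproducible across reruns.

```python
import numpy as np


def split_indices(sample_num: int, eval_ratio: float = 0.2, seed: int = 0):
    """Return (train_idx, eval_idx) as disjoint sets covering range(sample_num)."""
    rng = np.random.default_rng(seed)  # seeded generator: same split on every run
    eval_idx = set(rng.choice(sample_num, int(sample_num * eval_ratio), replace=False).tolist())
    train_idx = set(range(sample_num)) - eval_idx
    return train_idx, eval_idx


if __name__ == "__main__":
    train_idx, eval_idx = split_indices(23607)
    print(len(train_idx), len(eval_idx))  # 18886 4721
```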

### 7. Define the data-loading function

In [6]:
def load_dataset(batch_size=1, data_file=None):
    """
    Load mindrecord dataset
    """
    ds = de.MindDataset(data_file,
                        columns_list=["source_eos_ids", "source_eos_mask",
                                      "target_sos_ids", "target_sos_mask",
                                      "target_eos_ids", "target_eos_mask"],
                        shuffle=False)
    type_cast_op = deC.TypeCast(mstype.int32)
    ds = ds.map(input_columns="source_eos_ids", operations=type_cast_op)
    ds = ds.map(input_columns="source_eos_mask", operations=type_cast_op)
    ds = ds.map(input_columns="target_sos_ids", operations=type_cast_op)
    ds = ds.map(input_columns="target_sos_mask", operations=type_cast_op)
    ds = ds.map(input_columns="target_eos_ids", operations=type_cast_op)
    ds = ds.map(input_columns="target_eos_mask", operations=type_cast_op)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    ds.channel_name = 'transformer'
    return ds

Check that the data loads correctly:

In [7]:
next(load_dataset(data_file=data_cfg.train_file_mindrecord).create_dict_iterator())['source_eos_ids'][0]

Tensor(shape=[40], dtype=Int32, value= [3983,    3,    2,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0, 
    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0])
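The printed row is `source_eos_ids` for the first training sentence: three real ids followed by zeros up to `max_seq_length = 40`. The sketch below recovers the 1/0 mask and the unpadded length from such a row, assuming (as the trailing zeros suggest) that 0 is the padding id. The notebook itself does not derive the mask this way; it reads it from the separate `source_eos_mask` column.

```python
import numpy as np

PAD_ID = 0  # assumption: padding id, inferred from the trailing zeros in the sample

# First 8 of the 40 positions of the sample printed above
row = np.array([3983, 3, 2, 0, 0, 0, 0, 0], dtype=np.int32)

mask = (row != PAD_ID).astype(np.int32)  # 1 for real tokens (including EOS), 0 for padding
length = int(mask.sum())                 # number of non-padding positions

print(mask.tolist(), length)  # [1, 1, 1, 0, 0, 0, 0, 0] 3
```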

### 8. Define the training configuration

In [8]:
train_cfg = edict({
    #-------------------------------------- network config -------------------------------------
    'transformer_network': 'large',
    'init_loss_scale_value': 1024,
    'scale_factor': 2,
    'scale_window': 2000,

    'lr_schedule': edict({
        'learning_rate': 1.0,
        'warmup_steps': 8000,
        'start_decay_step': 16000,
        'min_lr': 0.0,
    }),
    #----------------------------------- save-model config -------------------------
    'enable_save_ckpt': True,         # Enable checkpoint saving, default is True.
    'save_checkpoint_steps': 590,     # Save a checkpoint every N steps, default is 590 (one epoch on this dataset).
    'save_checkpoint_num': 2,         # Number of checkpoints to keep, default is 2.
    'save_checkpoint_path': './checkpoint',    # Checkpoint directory, default is ./checkpoint/
    'save_checkpoint_name': 'transformer-32_40',
    'checkpoint_path': '',            # Checkpoint to resume from; empty means train from scratch.

    #------------------------------- device config -----------------------------
    'enable_data_sink': False,   # Enable data sink mode, default is False.
    'device_id': 0,
    'device_num': 1,
    'distribute': False,

    # ----------------- must match the dataset -----------------------
    'seq_length': 40,
    'vocab_size': 10067,

    #--------------------------------------------------------------------------
    'data_path': "./data/train.mindrecord",   # Training data path
    'epoch_size': 15,
    'batch_size': 32,
    'max_position_embeddings': 40,
    'enable_lossscale': False,   # Use loss scaling or not, default is False.
    'do_shuffle': True           # Shuffle the dataset, default is True (note: load_dataset ignores this and uses shuffle=False).
})
# Two Transformer model sizes are supported: 'base' and 'large'.
if train_cfg.transformer_network == 'base':
    transformer_net_cfg = TransformerConfig(
        batch_size=train_cfg.batch_size,
        seq_length=train_cfg.seq_length,
        vocab_size=train_cfg.vocab_size,
        hidden_size=512,
        num_hidden_layers=6,
        num_attention_heads=8,
        intermediate_size=2048,
        hidden_act="relu",
        hidden_dropout_prob=0.2,
        attention_probs_dropout_prob=0.2,
        max_position_embeddings=train_cfg.max_position_embeddings,
        initializer_range=0.02,
        label_smoothing=0.1,
        input_mask_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16)
elif train_cfg.transformer_network == 'large':
    transformer_net_cfg = TransformerConfig(
        batch_size=train_cfg.batch_size,
        seq_length=train_cfg.seq_length,
        vocab_size=train_cfg.vocab_size,
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="relu",
        hidden_dropout_prob=0.2,
        attention_probs_dropout_prob=0.2,
        max_position_embeddings=train_cfg.max_position_embeddings,
        initializer_range=0.02,
        label_smoothing=0.1,
        input_mask_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16)
else:
    raise ValueError("train_cfg.transformer_network must be 'base' or 'large'. Change train_cfg and try again!")
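`init_loss_scale_value`, `scale_factor` and `scale_window` only take effect when `enable_lossscale` is True (it is False here). The sketch below simulates the usual dynamic loss-scaling rule those names describe: divide the scale by `scale_factor` on a step whose fp16 gradients overflow, and multiply it by `scale_factor` after `scale_window` consecutive clean steps. It is a plain-Python model for illustration, not MindSpore's `DynamicLossScaleManager` implementation; the lower bound of 1 is an assumption.

```python
from typing import Iterable, List


def simulate_loss_scale(overflows: Iterable[bool], init_scale: float = 1024.0,
                        scale_factor: float = 2.0, scale_window: int = 2000) -> List[float]:
    """Return the loss scale in effect after each training step."""
    scale, clean_steps, history = init_scale, 0, []
    for overflow in overflows:
        if overflow:
            # Overflow detected: shrink the scale (a real trainer also skips this update)
            scale = max(scale / scale_factor, 1.0)
            clean_steps = 0
        else:
            clean_steps += 1
            if clean_steps >= scale_window:
                # A full window without overflow: try a larger scale
                scale *= scale_factor
                clean_steps = 0
        history.append(scale)
    return history


if __name__ == "__main__":
    hist = simulate_loss_scale([True] + [False] * 2000 + [True])
    print(hist[0], hist[2000], hist[-1])  # 512.0 1024.0 512.0
```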

### 9. Define the training function

In [9]:
def train(cfg):
    """
    Transformer training.
    """
    
    train_dataset = load_dataset(cfg.batch_size, data_file=cfg.data_path)

    netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)

    if cfg.checkpoint_path:
        parameter_dict = load_checkpoint(cfg.checkpoint_path)
        load_param_into_net(netwithloss, parameter_dict)

    lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
                                  training_steps=train_dataset.get_dataset_size()*cfg.epoch_size,
                                  learning_rate=cfg.lr_schedule.learning_rate,
                                  warmup_steps=cfg.lr_schedule.warmup_steps,
                                  hidden_size=transformer_net_cfg.hidden_size,
                                  start_decay_step=cfg.lr_schedule.start_decay_step,
                                  min_lr=cfg.lr_schedule.min_lr), mstype.float32)
    optimizer = Adam(netwithloss.trainable_params(), lr)

    callbacks = [TimeMonitor(train_dataset.get_dataset_size()), LossCallBack()]
    if cfg.enable_save_ckpt:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                       keep_checkpoint_max=cfg.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix=cfg.save_checkpoint_name, directory=cfg.save_checkpoint_path, config=ckpt_config)
        callbacks.append(ckpoint_cb)

    if cfg.enable_lossscale:
        scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value,
                                                scale_factor=cfg.scale_factor,
                                                scale_window=cfg.scale_window)
        update_cell = scale_manager.get_update_cell()
        netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,scale_update_cell=update_cell)
    else:
        netwithgrads = TransformerTrainOneStepCell(netwithloss, optimizer=optimizer)

    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    model.train(cfg.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=cfg.enable_data_sink)
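`create_dynamic_lr` lives in `src/lr_schedule.py`, which is not shown. Judging by its schedule string `constant*rsqrt_hidden*linear_warmup*rsqrt_decay`, a reasonable reading is the inverse-square-root warmup schedule from "Attention Is All You Need": lr(step) = learning_rate · hidden_size^-0.5 · min(1, step / warmup_steps) · max(step, warmup_steps)^-0.5. The sketch implements only that reading; how `src` uses `start_decay_step` is not visible here, so it is omitted, and `min_lr` is applied as a simple floor.

```python
import numpy as np


def rsqrt_warmup_lr(training_steps: int, learning_rate: float = 1.0,
                    warmup_steps: int = 8000, hidden_size: int = 1024,
                    min_lr: float = 0.0) -> np.ndarray:
    """Per-step learning rates for steps 1..training_steps (assumed schedule)."""
    step = np.arange(1, training_steps + 1, dtype=np.float64)
    lr = (learning_rate
          * hidden_size ** -0.5                        # rsqrt_hidden
          * np.minimum(1.0, step / warmup_steps)       # linear_warmup
          * np.maximum(step, warmup_steps) ** -0.5)    # rsqrt_decay
    return np.maximum(lr, min_lr).astype(np.float32)


if __name__ == "__main__":
    lrs = rsqrt_warmup_lr(training_steps=590 * 15)  # 590 steps/epoch x 15 epochs
    print(len(lrs), int(lrs.argmax()) + 1, float(lrs.max()))
```

Under this reading, 590 × 15 = 8,850 training steps means almost the whole run sits inside the 8,000-step warmup, so the learning rate peaks only near the end of training.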

### 10. Start training

In [10]:
train(train_cfg)



time: 146254, epoch: 1, step: 1, outputs are [9.839527]
time: 146559, epoch: 1, step: 2, outputs are [10.021371]
time: 146614, epoch: 1, step: 3, outputs are [10.062596]
time: 146667, epoch: 1, step: 4, outputs are [9.945813]
time: 146720, epoch: 1, step: 5, outputs are [9.998308]
time: 146774, epoch: 1, step: 6, outputs are [9.932056]
time: 146828, epoch: 1, step: 7, outputs are [9.99134]
time: 146881, epoch: 1, step: 8, outputs are [9.963903]
time: 146935, epoch: 1, step: 9, outputs are [9.837811]
time: 146988, epoch: 1, step: 10, outputs are [10.001595]
time: 147044, epoch: 1, step: 11, outputs are [9.920187]
time: 147097, epoch: 1, step: 12, outputs are [9.840982]
time: 147151, epoch: 1, step: 13, outputs are [9.789044]
time: 147206, epoch: 1, step: 14, outputs are [9.908868]
time: 147260, epoch: 1, step: 15, outputs are [9.856711]
time: 147313, epoch: 1, step: 16, outputs are [9.787952]
time: 147367, epoch: 1, step: 17, outputs are [9.843648]
time: 147422, epoch: 1, step: 18, outp

### 11. Define the inference configuration

In [16]:
eval_cfg = edict({
    'transformer_network': 'large',
    
    'data_file': './data/test.mindrecord',
    'test_source_file':'./data/source_test.txt',
    'model_file': './checkpoint/transformer-32_40_2-15_590.ckpt' ,
    'vocab_file':'./data/ch_en_vocab.txt',
    'token_file': './token-32-40.txt',
    'pred_file':'./pred-32-40.txt',
    
    # ------------------- must match the training config and the dataset ------------------------
    'seq_length':40,
    'vocab_size':10067,

    #-------------------------------------eval config-----------------------------
    'batch_size':32,
    'max_position_embeddings':40       # must match the training config
})

# Two Transformer model sizes are supported: 'base' and 'large' (must match the trained model).
if eval_cfg.transformer_network == 'base':
    transformer_net_cfg = TransformerConfig(
        batch_size=eval_cfg.batch_size,
        seq_length=eval_cfg.seq_length,
        vocab_size=eval_cfg.vocab_size,
        hidden_size=512,
        num_hidden_layers=6,
        num_attention_heads=8,
        intermediate_size=2048,
        hidden_act="relu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=eval_cfg.max_position_embeddings,
        label_smoothing=0.1,
        input_mask_from_dataset=True,
        beam_width=4,
        max_decode_length=eval_cfg.seq_length,
        length_penalty_weight=1.0,
        dtype=mstype.float32,
        compute_type=mstype.float16)
    
elif eval_cfg.transformer_network == 'large':
    transformer_net_cfg = TransformerConfig(
        batch_size=eval_cfg.batch_size,
        seq_length=eval_cfg.seq_length,
        vocab_size=eval_cfg.vocab_size,
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="relu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=eval_cfg.max_position_embeddings,
        label_smoothing=0.1,
        input_mask_from_dataset=True,
        beam_width=4,
        max_decode_length=eval_cfg.seq_length,
        length_penalty_weight=1.0,
        dtype=mstype.float32,
        compute_type=mstype.float16)
else:
    raise ValueError("eval_cfg.transformer_network must be 'base' or 'large' and match the value used for training. Change eval_cfg and try again!")
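For inference the configuration turns on beam search (`beam_width=4`) with `length_penalty_weight=1.0`. A common way such a weight is used, and the one assumed here, is the GNMT-style normalisation lp(L) = ((5 + L) / 6)^α, where each finished hypothesis is ranked by its summed log-probability divided by lp(L); without it, beam search favours short outputs because every extra token adds a negative log-probability. The hypotheses and scores below are made-up illustration values.

```python
from typing import List, Tuple


def length_penalty(length: int, weight: float = 1.0) -> float:
    """GNMT-style length penalty ((5 + length) / 6) ** weight (assumed form)."""
    return ((5.0 + length) / 6.0) ** weight


def rerank(hypotheses: List[Tuple[List[str], float]], weight: float = 1.0):
    """Sort (tokens, summed log-prob) pairs by length-normalised score, best first."""
    return sorted(hypotheses,
                  key=lambda h: h[1] / length_penalty(len(h[0]), weight),
                  reverse=True)


if __name__ == "__main__":
    beams = [(["等", "！"], -2.0),               # shorter, higher raw score
             (["等", "一", "下", "！"], -2.5)]   # longer, lower raw score
    best_tokens, _ = rerank(beams, weight=1.0)[0]
    print(" ".join(best_tokens))  # 等 一 下 ！
```

With `weight=0.0` the penalty is 1 for every length and the shorter hypothesis wins on raw score instead.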

### 12. Define the evaluation function

In [17]:
class TransformerInferCell(nn.Cell):
    """
    Encapsulation class of transformer network infer.
    """
    def __init__(self, network):
        super(TransformerInferCell, self).__init__(auto_prefix=False)
        self.network = network

    def construct(self,
                  source_ids,
                  source_mask):
        predicted_ids = self.network(source_ids, source_mask)
        return predicted_ids

def load_weights(model_path):
    """
    Load checkpoint as parameter dict, support both npz file and mindspore checkpoint file.
    """
    if model_path.endswith(".npz"):
        ms_ckpt = np.load(model_path)
        is_npz = True
    else:
        ms_ckpt = load_checkpoint(model_path)
        is_npz = False

    weights = {}
    for msname in ms_ckpt:
        infer_name = msname
        if "tfm_decoder" in msname:
            infer_name = "tfm_decoder.decoder." + infer_name
        if is_npz:
            weights[infer_name] = ms_ckpt[msname]
        else:
            weights[infer_name] = ms_ckpt[msname].data.asnumpy()
    weights["tfm_decoder.decoder.tfm_embedding_lookup.embedding_table"] = \
        weights["tfm_embedding_lookup.embedding_table"]

    parameter_dict = {}
    for name in weights:
        parameter_dict[name] = Parameter(Tensor(weights[name]), name=name)
    return parameter_dict

def evaluate(cfg):
    """
    Transformer evaluation: translate the test set and print each source/prediction pair.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)

    tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False)
    print(cfg.model_file)
    parameter_dict = load_weights(cfg.model_file)
    load_param_into_net(tfm_model, parameter_dict)
    tfm_infer = TransformerInferCell(tfm_model)
    model = Model(tfm_infer)

    tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=cfg.vocab_file)
    dataset = load_dataset(batch_size=cfg.batch_size, data_file=cfg.data_file)
    # source_test.txt holds the raw test pairs in the same order as test.mindrecord (shuffle=False)
    with open(cfg.test_source_file, 'r', encoding='utf-8') as f2:
        for batch in dataset.create_dict_iterator():
            source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
            source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
            predicted_ids = model.predict(source_ids, source_mask)

            # ------------------- decode the predicted ids back into sentences -------------------
            batch_out = predicted_ids.asnumpy()
            if batch_out.ndim == 3:
                batch_out = batch_out[:, 0]  # beam-search output has a beam axis: keep the first beam
            for i in range(transformer_net_cfg.batch_size):
                token_ids = [int(x) for x in batch_out[i].tolist()]
                tokens = tokenizer.convert_ids_to_tokens(token_ids)
                sent = " ".join(tokens)
                # Keep only the text between <s> and </s>
                sent = sent.split("<s>")[-1]
                sent = sent.split("</s>")[0]

                label_sent = f2.readline().strip() + '\t'
                print("source: {}".format(label_sent))
                print("result: {}".format(sent.strip()))
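The decoding loop turns each row of predicted ids back into text: map ids to tokens, then keep what follows the last `<s>` and precedes the first `</s>`. The standalone sketch below applies the same rule on the token list rather than the joined string, using a hypothetical five-entry vocabulary (the notebook loads the real one from `ch_en_vocab.txt` through `tokenization.WhiteSpaceTokenizer`).

```python
from typing import Dict, List

# Hypothetical toy vocabulary; the notebook loads the real one from ch_en_vocab.txt
ID_TO_TOKEN: Dict[int, str] = {0: "<pad>", 1: "<s>", 2: "</s>", 3: "你", 4: "好"}


def ids_to_sentence(token_ids: List[int], sos: str = "<s>", eos: str = "</s>") -> str:
    """Keep tokens after the last <s> and before the first </s>, joined by spaces."""
    tokens = [ID_TO_TOKEN[i] for i in token_ids]
    if sos in tokens:
        last_sos = max(i for i, t in enumerate(tokens) if t == sos)
        tokens = tokens[last_sos + 1:]        # like sent.split("<s>")[-1]
    if eos in tokens:
        tokens = tokens[:tokens.index(eos)]   # like sent.split("</s>")[0]
    return " ".join(tokens)


if __name__ == "__main__":
    print(ids_to_sentence([1, 3, 4, 2, 0, 0]))  # 你 好
```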

### 13. Run the evaluation

In [18]:
evaluate(eval_cfg)

./checkpoint/transformer-32_40_2-15_590.ckpt
source: Hi .	你 好 。	
result: 嗨 。
source: Wait !	等 一 下 ！	
result: 等 ！
source: Got it ?	你 懂 了 吗 ？	
result: 一 下 雪 可 以 把 它 放 在 吗 ？
source: I quit .	我 退 出 。	
result: 我 不 再 戒 烟 了 。
source: Really ?	你 确 定 ？	
result: 为 什 么 滚 的 女 朋 友 在 做 饭 ？
source: Call us .	联 系 我 们 。	
result: 打 电 话 给 我 们 打 电 话 给 我 们 打 电 话 。
source: Come in .	进 来 。	
result: 进 来 之 内 进 来 的 进 来 。
source: Go away !	走 开 ！	
result: 走 开 ！
source: He came .	他 来 了 。	
result: 他 来 了 自 己 来 的 走 了 。
source: It ' s me .	是 我 。	
result: 它 是 我 对 我 来 说 是 这 件 事 了 。
source: Kiss me .	吻 我 。	
result: 把 我 的 行 为 我 放 轻 松 的 信 给 我 。
source: See you .	再 见 ！	
result: 由 于 见 儿 见 面 见 。
source: Skip it .	不 管 它 。	
result: 数 学 的 需 要 它 做 的 下 雪 。
source: Wake up !	醒 醒 ！	
result: 叫 醒 醒 醒 醒 醒 来 点 !
source: You win .	算 你 狠 。	
result: 你 赢 得 让 人 赢 。
source: Cuff him .	把 他 铐 上 。	
result: 为 了 他 的 国 家 的 印 象 ， 一 直 接 受 他 。
source: Get down !	趴 下 ！	
result: 下 来 ！
source: Good job !	做 得 好 ！	
result: 干 的 好 ！
source: I forgot .	我 忘 了 

In [None]:
evaluate(eval_cfg)

./checkpoint/transformer-32_40-15_590.ckpt
source: 	
result: 我 赢 了 ！
source: 	
result: 分 ！
source: 	
result: 他 跑 了 跑 步 知 道 他 跑 了 。
source: 	
result: 没 办 法 ！
source: 	
result: 为 什 么 我 为 什 么 为 什 么 呢 ？
source: 	
result: 汤 姆 问 问 问 问 问 问 题 ， 问 问 题 汤 姆 要 问 题 。
source: 	
result: 冷 静 保 持 冷 静 冷 静 ， 冷 静 静 醒 着 冷 静 。
source: 	
result: 和 气 点 点 。
source: 	
result: 我 们 打 电 话 给 我 们 打 电 话 给 我 们 打 电 话 。
source: 	
result: 进 来 之 前 进 来 ， 进 来 。
source: 	
result: 回 家 吧 ！
source: 	
result: 告 辞 ！
source: 	
result: 他 跑 步 行 动 地 跑 步 。
source: 	
result: 把 它 保 持 它 放 在 它 上 。
source: 	
result: 冷 静 点 醒 来 ！
source: 	
result: 仍 然 还 是 依 赖 着 安 静 的 。
source: 	
result: 驾 驶 车 在 小 的 驾 驶 车 祸 上 开 车 。
source: 	
result: 走 开 ！
source: 	
result: 下 雪 ！
source: 	
result: 滚 失 去 ！
source: 	
result: 真 正 实 实 际 上 真 相 很 抱 歉 。
source: 	
result: 抓 住 汤 姆 。
source: 	
result: 很 有 趣 的 玩 趣 趣 趣 ， 你 玩 得 很 有 趣 。
source: 	
result: 多 可 能 更 多 可 爱 啊 ！
source: 	
result: 我 很 忙 忙 忙 忙 忙 忙 忙 忙 ， 忙 忙 忙 忙 忙 。
source: 	
result: 让 我 们 走 吧 ！
source: 	
result: 相 信

In [None]:
./checkpoint/transformer-32_40-15_590.ckpt
source: Hi .	你 好 。	
result: 嗨 。
source: Wait !	等 一 下 ！	
result: 等 一 下 ！
source: Got it ?	你 懂 了 吗 ？	
result: 你 懂 了 吗 ？
source: I quit .	我 退 出 。	
result: 我 退 休 了 。
source: Really ?	你 确 定 ？	
result: 你 确 定 ？
source: Call us .	联 系 我 们 。	
result: 我 们 打 电 话 给 我 们 打 电 话 给 我 们 打 电 话 。
source: Come in .	进 来 。	
result: 进 来 之 前 进 来 ， 进 来 。
source: Go away !	走 开 ！	

In [None]:
evaluate(eval_cfg)  # 'seq_length': 40

In [None]:
# 'seq_length': 20,