In [1]:
import os
import numpy as np

from mindspore import Tensor, nn, Model, context
from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint

from src.preprocess import convert_to_mindrecord
from src.dataset import create_dataset
from src.seq2seq import Seq2Seq, WithLossCell, InferCell
from src.config import cfg

MindSpore version 1.1.1 and "topi" wheel package version 0.6.0 does not match, reference to the match info on: https://www.mindspore.cn/install




In [2]:
# Configure MindSpore to run in GRAPH_MODE on an Ascend device.
# The device id was hard-coded to 4; it still defaults to 4 but can now be
# overridden with the DEVICE_ID environment variable, so the notebook can be
# pointed at another card without editing code.
DEVICE_ID = int(os.environ.get('DEVICE_ID', 4))
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target='Ascend', device_id=DEVICE_ID)

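### Train

Train the GRU-based Seq2Seq English→Chinese model; checkpoints are written under `cfg.ckpt_save_path` with the prefix `gru`.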

In [3]:
# Training pipeline: built by the project's create_dataset helper from
# cfg.dataset_path and batched with cfg.batch_size.
ds_train = create_dataset(
    cfg.dataset_path,
    cfg.batch_size,
)

In [4]:
# Training graph: the Seq2Seq network wrapped with its loss (WithLossCell),
# optimised with Adam using explicitly chosen beta values.
network = WithLossCell(Seq2Seq(cfg), cfg)
optimizer = nn.Adam(
    network.trainable_params(),
    learning_rate=cfg.learning_rate,
    beta1=0.9,
    beta2=0.98,
)
model = Model(network, optimizer=optimizer)

In [5]:
# Callbacks run in list order: timing, checkpointing, then loss reporting.
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())

# Save a checkpoint every cfg.save_checkpoint_steps steps, keeping at most
# cfg.keep_checkpoint_max files, named with the "gru" prefix.
config_ck = CheckpointConfig(
    save_checkpoint_steps=cfg.save_checkpoint_steps,
    keep_checkpoint_max=cfg.keep_checkpoint_max,
)
ckpoint_cb = ModelCheckpoint(prefix="gru", directory=cfg.ckpt_save_path, config=config_ck)

loss_cb = LossMonitor()
callbacks = [time_cb, ckpoint_cb, loss_cb]

model.train(
    cfg.num_epochs,
    ds_train,
    callbacks=callbacks,
    dataset_sink_mode=True,
)

epoch: 1 step: 125, loss is 2.804515
epoch time: 32208.512 ms, per step time: 257.668 ms
epoch: 2 step: 125, loss is 1.963039
epoch time: 11227.136 ms, per step time: 89.817 ms
epoch: 3 step: 125, loss is 1.8751457
epoch time: 11207.574 ms, per step time: 89.661 ms
epoch: 4 step: 125, loss is 2.0917926
epoch time: 11235.453 ms, per step time: 89.884 ms
epoch: 5 step: 125, loss is 1.5626856
epoch time: 11257.191 ms, per step time: 90.058 ms
epoch: 6 step: 125, loss is 1.0996865
epoch time: 11264.321 ms, per step time: 90.115 ms
epoch: 7 step: 125, loss is 0.9826399
epoch time: 11222.325 ms, per step time: 89.779 ms
epoch: 8 step: 125, loss is 0.61559135
epoch time: 11283.613 ms, per step time: 90.269 ms
epoch: 9 step: 125, loss is 0.34942892
epoch time: 11223.944 ms, per step time: 89.792 ms
epoch: 10 step: 125, loss is 0.32617155
epoch time: 11203.418 ms, per step time: 89.627 ms
epoch: 11 step: 125, loss is 0.25858104
epoch time: 11256.972 ms, per step time: 90.056 ms
epoch: 12 step: 

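### Eval

Rebuild the network in inference mode, load weights from `cfg.checkpoint_path`, and compare the predicted Chinese translation with the reference for the first sample of each evaluation batch.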

In [6]:
# Single-device evaluation settings.
# NOTE(review): rank and device_num are not read by any of the cells shown
# here — confirm whether they are still needed.
rank, device_num = 0, 1

# Evaluation pipeline (is_training=False), batched with cfg.eval_batch_size.
ds_eval = create_dataset(cfg.dataset_path, cfg.eval_batch_size, is_training=False)

In [7]:
# Inference graph: Seq2Seq built with is_train=False and wrapped in InferCell,
# switched to non-training mode before the weights are restored.
# NOTE(review): training above saves checkpoints under cfg.ckpt_save_path,
# whereas this cell reads cfg.checkpoint_path — confirm it points at the
# intended checkpoint file.
network = InferCell(Seq2Seq(cfg, is_train=False), cfg)
network.set_train(False)

parameter_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(network, parameter_dict)

model = Model(network)

In [9]:
def load_vocab(path):
    """Read a vocabulary file where line i holds the token for id i.

    Splits on '\\n' only (not str.splitlines) so that no other character can
    shift the line-index -> token-id mapping.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return f.read().split('\n')


def decode_ids(ids, vocab, sep='', skip_ids=()):
    """Convert a sequence of token ids into text.

    Decoding stops at the first id 0 (padding / end of sequence).

    Args:
        ids: iterable of integer token ids (e.g. one row of a numpy array).
        vocab: list where ``vocab[i]`` is the token string for id ``i``.
        sep: string placed between consecutive tokens.
        skip_ids: ids dropped without stopping. The decoder targets use id 1,
            presumably a start-of-sentence marker — TODO confirm.

    Returns:
        The decoded string.
    """
    tokens = []
    for token_id in ids:
        if token_id == 0:
            break
        if token_id in skip_ids:
            continue
        tokens.append(vocab[token_id])
    return sep.join(tokens)


en_vocab = load_vocab(os.path.join(cfg.dataset_path, "en_vocab.txt"))
ch_vocab = load_vocab(os.path.join(cfg.dataset_path, "ch_vocab.txt"))

# Only the first sample of each batch is decoded and printed.
for batch in ds_eval.create_dict_iterator():
    en_data = decode_ids(batch['encoder_data'][0].asnumpy(), en_vocab, sep=' ')
    ch_data = decode_ids(batch['decoder_data'][0].asnumpy(), ch_vocab, skip_ids=(1,))
    output = network(batch['encoder_data'], batch['decoder_data'])
    print('English:', en_data)
    print('expect Chinese:', ch_data)
    out = decode_ids(output[0].asnumpy(), ch_vocab)
    print('predict Chinese:', out)
    print(' ')

English: do you like snow ? 
expect Chinese: 你喜欢雪吗？
predict Chinese: 你喜欢雪吗？
 
English: i can see tom . 
expect Chinese: 我看得见汤姆。
predict Chinese: 我看得见汤姆。
 
English: stay sharp . 
expect Chinese: 保持警惕。
predict Chinese: 保持警惕。
 
English: stop meddling . 
expect Chinese: 别再插手。
predict Chinese: 别再插手。
 
English: tom is a magician . 
expect Chinese: 汤姆是魔法师。
predict Chinese: 汤姆是魔法师。
 
English: i am very sad . 
expect Chinese: 我很难过。
predict Chinese: 我很难过。
 
English: i m very happy . 
expect Chinese: 我很快乐。
predict Chinese: 我很快乐。
 
English: don t let tom die . 
expect Chinese: 别让汤姆死了。
predict Chinese: 别让汤姆死了。
 
English: time flies . 
expect Chinese: 时光飞逝。
predict Chinese: 时光飞逝。
 
English: let s turn back . 
expect Chinese: 我们掉头吧！
predict Chinese: 我们掉头吧！
 
English: he caught a cold . 
expect Chinese: 他感冒了。
predict Chinese: 他着凉了。
 
English: is it all there ? 
expect Chinese: 全都在那里吗？
predict Chinese: 全都在那里吗？
 
English: take me home . 
expect Chinese: 带我回家。
predict Chinese: 带我回家。
 
English: anyone can