In [1]:
import os
import numpy as np

from src.config import cfg
from src.dataset import create_dataset
from src.seq2seq import Seq2Seq, InferCell
from src.seq2seq import Seq2Seq, WithLossCell

from mindspore import Tensor, nn, Model, context
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor

MindSpore version 1.1.1 and "topi" wheel package version 0.6.0 does not match, reference to the match info on: https://www.mindspore.cn/install




In [2]:
# Run in GRAPH_MODE on an Ascend device; intermediate graph dumps are disabled.
# NOTE(review): device_id is hard-coded to card 5 — change it to a card that is
# actually free on the host before re-running.
context.set_context(
    device_target='Ascend',
    device_id=5,
    mode=context.GRAPH_MODE,
    save_graphs=False,
)

### train

In [3]:
# Training pipeline: read from cfg.dataset_path and batch with cfg.batch_size.
# `is_training` is left at create_dataset's default (the eval cell passes False
# explicitly, so the default is presumably True — confirm in src/dataset).
ds_train = create_dataset(
    cfg.dataset_path,
    cfg.batch_size,
)

In [4]:
# Training graph: the Seq2Seq network wrapped by WithLossCell (which presumably
# makes the forward pass return the training loss — confirm in src/seq2seq).
network = WithLossCell(Seq2Seq(cfg), cfg)

# Adam over every trainable parameter of the wrapped network; beta2=0.98 rather
# than the common 0.999 default.
optimizer = nn.Adam(
    network.trainable_params(),
    learning_rate=cfg.learning_rate,
    beta1=0.9,
    beta2=0.98,
)

model = Model(network, optimizer=optimizer)

In [5]:
# Training callbacks, in the order they are registered:
#   time_cb    - per-epoch timing, sized by the number of batches per epoch
#   ckpoint_cb - checkpoints named with prefix "gru", written under cfg.ckpt_save_path
#   loss_cb    - loss logging
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(
    save_checkpoint_steps=cfg.save_checkpoint_steps,
    keep_checkpoint_max=cfg.keep_checkpoint_max,
)
ckpoint_cb = ModelCheckpoint(prefix="gru", directory=cfg.ckpt_save_path, config=config_ck)
loss_cb = LossMonitor()
callbacks = [time_cb, ckpoint_cb, loss_cb]

# Dataset sink mode feeds batches to the device in bulk. In the log below, the
# first epoch is much slower (graph compilation), and later epochs take ~90 ms/step.
model.train(cfg.num_epochs, ds_train, callbacks=callbacks, dataset_sink_mode=True)


epoch: 1 step: 125, loss is 2.5471632
epoch time: 72593.791 ms, per step time: 580.750 ms
epoch: 2 step: 125, loss is 2.5645504
epoch time: 11230.366 ms, per step time: 89.843 ms
epoch: 3 step: 125, loss is 2.3836899
epoch time: 11235.888 ms, per step time: 89.887 ms
epoch: 4 step: 125, loss is 2.279439
epoch time: 11229.956 ms, per step time: 89.840 ms
epoch: 5 step: 125, loss is 1.5323433
epoch time: 11232.835 ms, per step time: 89.863 ms
epoch: 6 step: 125, loss is 1.3322783
epoch time: 11236.202 ms, per step time: 89.890 ms
epoch: 7 step: 125, loss is 0.8172446
epoch time: 11236.513 ms, per step time: 89.892 ms
epoch: 8 step: 125, loss is 0.6874578
epoch time: 11227.472 ms, per step time: 89.820 ms
epoch: 9 step: 125, loss is 0.46486482
epoch time: 11228.080 ms, per step time: 89.825 ms
epoch: 10 step: 125, loss is 0.39268598
epoch time: 11235.647 ms, per step time: 89.885 ms
epoch: 11 step: 125, loss is 0.22333553
epoch time: 11271.179 ms, per step time: 90.169 ms
epoch: 12 step: 

### eval

In [6]:
# Single-device evaluation settings.
# NOTE(review): rank / device_num are not read by any visible cell; they are kept
# only in case later cells depend on them.
rank, device_num = 0, 1

# Evaluation pipeline: same data root, eval batch size, is_training=False.
ds_eval = create_dataset(cfg.dataset_path, cfg.eval_batch_size, is_training=False)

In [7]:
# Inference graph: Seq2Seq built with is_train=False, wrapped by InferCell and
# switched out of training mode before the weights are restored.
network = InferCell(Seq2Seq(cfg, is_train=False), cfg)
network.set_train(False)

# NOTE(review): weights come from cfg.checkpoint_path. The training cell writes
# checkpoints with prefix "gru" under cfg.ckpt_save_path — confirm the two point
# at the intended file.
parameter_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(network, parameter_dict)

# NOTE(review): `model` is not used by the decoding loop in the visible cells,
# which calls `network` directly.
model = Model(network)

In [27]:
def _read_vocab(file_name):
    """Load a vocabulary file from cfg.dataset_path as a list of lines.

    Each token's list index is used as its id by the decoding loop. The text is
    split on newline characters only, so a trailing newline yields a final empty
    entry. That entry is harmless for id lookup and is kept so ids are unchanged.

    Args:
        file_name: name of the vocabulary file inside cfg.dataset_path.

    Returns:
        list[str]: the vocabulary, one entry per line.
    """
    path = os.path.join(cfg.dataset_path, file_name)
    with open(path, 'r', encoding='utf-8') as f:
        return f.read().split('\n')


# Separate helpers/names instead of reusing `data` for raw file text, which the
# decoding loop later rebinds to a batch dict.
en_vocab = _read_vocab("en_vocab.txt")
ch_vocab = _read_vocab("ch_vocab.txt")

def _ids_to_text(ids, vocab, sep='', skip=()):
    """Turn a sequence of token ids into text by vocabulary lookup.

    Decoding stops at the first id equal to 0 (treated in this notebook as the
    end/padding id). Ids listed in `skip` are dropped. `sep` is appended after
    every emitted token, so a non-empty `sep` leaves a trailing separator.

    Args:
        ids: iterable of integer token ids (e.g. a row from `Tensor.asnumpy()`).
        vocab: list mapping id -> token string.
        sep: string appended after each token ('' for Chinese, ' ' for English).
        skip: ids to omit without stopping.

    Returns:
        str: the decoded text.
    """
    pieces = []
    for token_id in ids:
        if token_id == 0:
            break
        if token_id in skip:
            continue
        pieces.append(vocab[token_id] + sep)
    return ''.join(pieces)


for batch in ds_eval.create_dict_iterator():
    # Only the first sample of each batch is decoded and printed.
    en_data = _ids_to_text(batch['encoder_data'][0].asnumpy(), en_vocab, sep=' ')
    # Id 1 is skipped in the reference text only (presumably a start-of-sentence
    # marker — confirm against the dataset preprocessing).
    ch_data = _ids_to_text(batch['decoder_data'][0].asnumpy(), ch_vocab, skip=(1,))
    # NOTE(review): the decoder input is passed to the inference cell as well;
    # confirm InferCell ignores it (no teacher forcing), or the predictions below
    # are optimistic.
    output = network(batch['encoder_data'], batch['decoder_data'])
    print('English:', en_data)
    print('expect Chinese:', ch_data)
    out = _ids_to_text(output[0].asnumpy(), ch_vocab)
    print('predict Chinese:', out)
    print(' ')

English: who likes beans ? 
expect Chinese: 谁喜欢豆子？
predict Chinese: 谁喜欢豆子？
 
English: who built it ? 
expect Chinese: 这是谁建的？
predict Chinese: 这是谁建的？
 
English: tom is very quiet . 
expect Chinese: 汤姆很安静。
predict Chinese: 汤姆人很好。
 
English: are you finished ? 
expect Chinese: 你结束了吗？
predict Chinese: 你结束了吗？
 
English: i don t get it . 
expect Chinese: 我不懂。
predict Chinese: 我无所谓。
 
English: i understand . 
expect Chinese: 我明白了。
predict Chinese: 我明白了。
 
English: you made me laugh . 
expect Chinese: 我被你逗乐了。
predict Chinese: 我被你逗乐了。
 
English: excuse me . 
expect Chinese: 对不起。
predict Chinese: 对不起。
 
English: it s business . 
expect Chinese: 公事公办。
predict Chinese: 公事公办。
 
English: she is graceful . 
expect Chinese: 她举止优雅。
predict Chinese: 她举止优雅。
 
English: he s not home . 
expect Chinese: 他不在家。
predict Chinese: 他不在家。
 
English: it s very big . 
expect Chinese: 它很大。
predict Chinese: 它很大。
 
English: what s that ? 
expect Chinese: 那是什么？
predict Chinese: 那是什么？
 
English: tom hit a triple . 
expec