In [12]:
import mindspore
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, ms_function

In [13]:
def make_batch(seq_data, num_dic, n_step):
    """Encode (source, target) word pairs as padded seq2seq training tensors.

    Every word is right-padded with 'P' up to ``n_step`` characters. The decoder
    input is the padded target prefixed with 'S'; the labels are the padded
    target followed by 'E'.

    Args:
        seq_data: iterable of ``[source_word, target_word]`` pairs (lists or
            tuples). It is NOT modified.
        num_dic: mapping from character to class index; ``len(num_dic)`` is the
            one-hot width.
        n_step: length every word is padded to. Words already longer than this
            are left unpadded.

    Returns:
        ``(input_batch, output_batch, target_batch)``, assuming every word has at
        most ``n_step`` characters:
          - input_batch:  float32 one-hot, [batch, n_step, len(num_dic)]
          - output_batch: float32 one-hot, [batch, n_step + 1, len(num_dic)]
          - target_batch: int32 class indices (not one-hot), [batch, n_step + 1]
    """
    # Width comes from the vocabulary passed in, not from a notebook global.
    one_hot = np.eye(len(num_dic))
    input_batch, output_batch, target_batch = [], [], []

    for seq in seq_data:
        # Pad local copies so the caller's seq_data is left untouched.
        src = seq[0] + 'P' * (n_step - len(seq[0]))
        tgt = seq[1] + 'P' * (n_step - len(seq[1]))

        enc_ids = [num_dic[ch] for ch in src]
        dec_ids = [num_dic[ch] for ch in 'S' + tgt]
        tgt_ids = [num_dic[ch] for ch in tgt + 'E']

        input_batch.append(one_hot[enc_ids])
        output_batch.append(one_hot[dec_ids])
        target_batch.append(tgt_ids)  # class indices, not one-hot

    return (Tensor(input_batch, mindspore.float32),
            Tensor(output_batch, mindspore.float32),
            Tensor(target_batch, mindspore.int32))

In [14]:
# Model
class Seq2Seq(nn.Cell):
    """Character-level RNN encoder-decoder.

    The encoder reads the padded source word. Its final hidden state seeds the
    decoder, which is fed the shifted target ('S' + target). A dense layer then
    projects the decoder outputs to per-character logits.

    Args:
        n_class: vocabulary size (width of the one-hot inputs and of the logits).
        n_hidden: hidden size of both RNNs.
        dropout: forwarded to both ``nn.RNN`` cells. NOTE(review): according to
            the MindSpore ``nn.RNN`` docs, dropout is applied only between
            stacked layers, so with the default single layer it presumably has
            no effect. Confirm this.
    """

    def __init__(self, n_class, n_hidden, dropout):
        super().__init__()
        rnn_kwargs = dict(input_size=n_class, hidden_size=n_hidden, dropout=dropout)
        # Attribute names feed into the parameter names, so they are kept as-is.
        self.enc_cell = nn.RNN(**rnn_kwargs)
        self.dec_cell = nn.RNN(**rnn_kwargs)
        self.fc = nn.Dense(n_hidden, n_class)

    def construct(self, enc_input, dec_input):
        """Return time-major decoder logits: [max_len + 1, batch_size, n_class].

        Both inputs arrive batch-major and are transposed to
        [time, batch_size, n_class], the layout the RNN cells are driven with here.
        """
        # Encode: keep only the final hidden state,
        # [num_layers(=1) * num_directions(=1), batch_size, n_hidden].
        src_steps = enc_input.transpose((1, 0, 2))
        _, context = self.enc_cell(src_steps)

        # Decode, starting from the encoder's final state.
        tgt_steps = dec_input.transpose((1, 0, 2))
        dec_outputs, _ = self.dec_cell(tgt_steps, context)

        return self.fc(dec_outputs)

In [15]:
# Hyper-parameters
n_step = 5        # every word is padded to this many characters
n_hidden = 128    # hidden size of the encoder and decoder RNNs
dropout = 0.5     # passed to both nn.RNN cells

# Vocabulary: markers 'S' (decoder start), 'E' (end) and 'P' (padding), then a-z.
char_arr = list('SEPabcdefghijklmnopqrstuvwxyz')
num_dic = dict(zip(char_arr, range(len(char_arr))))

# Training pairs: [source_word, target_word]
seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'],
            ['girl', 'boy'], ['up', 'down'], ['high', 'low']]

n_class = len(num_dic)
batch_size = len(seq_data)

In [16]:
# Seed MindSpore's global RNG before building the model so that weight
# initialisation is reproducible on Restart & Run All.
mindspore.set_seed(1)
model = Seq2Seq(n_class, n_hidden, dropout)



In [17]:
# Cross-entropy over per-character class indices; Adam over all trainable weights.
criterion = nn.CrossEntropyLoss()
optimizer = nn.Adam(params=model.trainable_params(), learning_rate=1e-3)

In [18]:
# Encode the whole training set once; every training step below reuses it.
input_batch, output_batch, target_batch = make_batch(seq_data=seq_data, num_dic=num_dic, n_step=n_step)

In [19]:
def forward(enc_input, dec_input, target):
    """Loss of the global ``model`` on one batch.

    Args:
        enc_input: one-hot encoder input, batch-major.
        dec_input: one-hot decoder input ('S' + target), batch-major.
        target: class indices, [batch_size, max_len + 1] (not one-hot).

    Returns:
        Whatever the global ``criterion`` returns for the flattened logits/labels.
    """
    # The model emits time-major logits. Move batch back to the front so the
    # flattened rows line up with the flattened (batch-major) target.
    logits = model(enc_input, dec_input).transpose((1, 0, 2))
    n_classes = logits.shape[-1]
    flat_logits = logits.view(-1, n_classes)
    flat_target = target.view(-1)
    return criterion(flat_logits, flat_target)

In [20]:
# Differentiate `forward` w.r.t. the optimizer's parameters only
# (grad_position=None: no gradients for the input tensors); yields (loss, grads).
grad_fn = ops.value_and_grad(forward, grad_position=None, weights=optimizer.parameters)

In [21]:
@ms_function
def train_step(enc_input, dec_input, target):
    """Run one optimisation step and return the loss.

    The loss is computed with the parameters as they were *before* this step's
    update. NOTE(review): ``ms_function`` is deprecated in MindSpore 2.x in
    favour of ``mindspore.jit``; switch if the installed version provides it.
    """
    loss, grads = grad_fn(enc_input, dec_input, target)
    # Tie the returned loss to the optimizer call so the parameter update is
    # part of the compiled graph's result.
    return ops.depend(loss, optimizer(grads))

In [22]:
model.set_train()

n_epochs = 5000
log_every = 1000
# All three batches are batch-major:
#   input_batch  : [batch_size, n_step, n_class]      one-hot encoder input
#   output_batch : [batch_size, n_step + 1, n_class]  one-hot decoder input ('S' + target)
#   target_batch : [batch_size, n_step + 1]           class indices (target + 'E'), not one-hot
for epoch in range(n_epochs):
    loss = train_step(input_batch, output_batch, target_batch)
    step = epoch + 1
    if step % log_every == 0:
        print('Epoch: {:04d} cost = {:.6f}'.format(step, loss.asnumpy()))

Epoch: 1000 cost = 0.000974
Epoch: 2000 cost = 0.000262
Epoch: 3000 cost = 0.000112
Epoch: 4000 cost = 0.000056
Epoch: 5000 cost = 0.000030


In [23]:
model.set_train(mode=False)
# Test: leave training mode before running the translations below.
def translate(word):
    """Translate ``word`` with the trained global ``model`` using arg-max decoding.

    The target is unknown at test time, so the decoder is fed 'S' followed by
    padding. At each step the arg-max character is taken, the result is cut at
    the first 'E', and padding characters 'P' are stripped.

    Args:
        word: source word; every character must be a key of ``num_dic``.
            Presumably at most ``n_step`` characters, like the training data.

    Returns:
        The decoded string. If the model never emits 'E', the whole decoded
        sequence (minus padding) is returned instead of raising ``ValueError``.
    """
    enc_batch, dec_batch, _ = make_batch([[word, 'P' * len(word)]], num_dic, n_step)
    # output: time-major logits [max_len + 1, batch_size(=1), n_class]
    output = model(enc_batch, dec_batch)

    predict = output.asnumpy().argmax(2)  # arg-max over the n_class axis
    decoded = [char_arr[idx] for idx in predict[:, 0]]

    # Cut at the end marker; fall back to the full sequence so an under-trained
    # model or unusual input does not crash the test cell.
    end = decoded.index('E') if 'E' in decoded else len(decoded)
    return ''.join(decoded[:end]).replace('P', '')

print('test')
for word in ('man', 'mans', 'king', 'black', 'upp'):
    print(word, '->', translate(word))

test
man -> women
mans -> women
king -> queen
black -> white
upp -> down
