In [None]:
import math
import mindspore
import numpy as np
import mindspore.nn as nn
import mindspore.ops as P
from mindspore import Tensor, Parameter

RNNCell: 
<center>$h' = \tanh(W_{ih} x + b_{ih}  +  W_{hh} h + b_{hh})$</center>

In [None]:
class RNNCell(nn.Cell):
    """
    Single-step Elman recurrent cell with a tanh or ReLU activation.

    Computes ``h' = act(x @ W_ih + b_ih + h @ W_hh + b_hh)``; both bias terms
    are left out when ``bias`` is False.

    Args:
        input_size:  Number of features in each input vector 'x'
        hidden_size: Number of features in the hidden state 'h'
        bias: Whether the bias vectors b_ih and b_hh are created and added. Default: 'True'
        nonlinearity: Key into ``nonlinearity_dict``, either 'tanh' or 'relu'. Default: 'tanh'

    Inputs:
        input: Tensor, (batch, input_size)
        hx: Tensor, (batch, hidden_size)
    Outputs:
        h: Tensor, (batch, hidden_size)
    """
    # Activation lookup table, built once at class-definition time and shared by every cell.
    nonlinearity_dict = dict(tanh=nn.Tanh(), relu=nn.ReLU())

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = 'tanh'):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Every parameter is sampled from U(-bound, bound) with bound = 1 / sqrt(hidden_size).
        bound = 1 / math.sqrt(hidden_size)

        def uniform_param(*shape):
            # Draw a float32 array of the given shape and wrap it as a trainable Parameter.
            values = np.random.uniform(-bound, bound, shape).astype(np.float32)
            return Parameter(Tensor(values))

        # Weights are stored as (in_features, hidden_size) so construct can compute x @ W directly.
        self.weight_ih = uniform_param(input_size, hidden_size)
        self.weight_hh = uniform_param(hidden_size, hidden_size)
        if bias:
            self.bias_ih = uniform_param(hidden_size)
            self.bias_hh = uniform_param(hidden_size)

        self.nonlinearity = self.nonlinearity_dict[nonlinearity]
        self.mm = P.MatMul()

    def construct(self, input: Tensor, hx: Tensor) -> Tensor:
        # Project the input and the previous hidden state separately.
        i_gates = self.mm(input, self.weight_ih)
        h_gates = self.mm(hx, self.weight_hh)
        if self.bias:
            i_gates = i_gates + self.bias_ih
            h_gates = h_gates + self.bias_hh
        return self.nonlinearity(i_gates + h_gates)

RNN: 
<center>$h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})$</center>

In [None]:
class RNN(nn.Cell):
    """
    Single-layer, unidirectional RNN that unrolls one RNNCell over the time axis.

    Args:
        input_size: Number of features in each input vector
        hidden_size: Number of features in the hidden state
        nonlinearity: Activation forwarded to RNNCell. Default: 'tanh'
        bias: Whether RNNCell uses bias vectors. Default: 'True'
        dropout: Drop fraction; the layer is built as nn.Dropout(1 - dropout). Default: 0.
        batch_first: If 'True', the input is transposed with axes (1, 0, 2)
            before unrolling so that axis 0 indexes time steps. Default: 'False'

    Inputs:
        input: Tensor, iterated along axis 0 (after the optional transpose)
        h_0: Tensor, hidden state handed to the cell at the first step
    Outputs:
        output: hidden states of all steps stacked along a new leading axis, after dropout
        h_t: hidden state of the final step, after its own separate dropout call
    """
    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 nonlinearity: str = 'tanh',
                 bias: bool = True,
                 dropout: float = 0.,
                 batch_first: bool = False,
                 ):
        super().__init__()
        self.batch_first = batch_first

        # Recurrent cell first, dropout last: same sub-cell registration order as before.
        self.rnn_cell = RNNCell(input_size, hidden_size, bias, nonlinearity)
        self.stack = P.Stack()
        self.transpose = P.Transpose()
        self.dropout = nn.Dropout(1 - dropout)

    def construct(self, input: Tensor, h_0: Tensor):
        if self.batch_first:
            # Swap the first two axes so the loop below walks over time steps.
            input = self.transpose(input, (1, 0, 2))

        hidden = h_0
        states = []
        for step in range(input.shape[0]):
            hidden = self.rnn_cell(input[step], hidden)
            states.append(hidden)

        # Dropout is applied to the final state and to the stacked sequence independently.
        last_hidden = self.dropout(hidden)
        sequence = self.dropout(self.stack(states))
        return sequence, last_hidden

TextRNN Model:

In [None]:
def make_batch(sentences, word_dict, n_class):
    """
    Turn whitespace-separated sentences into next-word prediction samples.

    For each sentence, every word except the last becomes a one-hot row of the
    input matrix and the last word becomes the integer target (a causal
    language-model setup).

    Args:
        sentences: iterable of strings, tokenized with str.split()
        word_dict: mapping from word to integer index
        n_class: width of the one-hot encoding (size of np.eye)

    Returns:
        (input_batch, target_batch): a list of one-hot numpy arrays, one per
        sentence, and a list of target word indices.
    """
    def encode(sentence):
        tokens = sentence.split()  # space tokenizer
        context_ids = [word_dict[token] for token in tokens[:-1]]  # words 1..n-1
        target_id = word_dict[tokens[-1]]  # word n
        return np.eye(n_class)[context_ids], target_id

    input_batch, target_batch = [], []
    for features, label in map(encode, sentences):
        input_batch.append(features)
        target_batch.append(label)

    return input_batch, target_batch

In [None]:
class TextRNN(nn.Cell):
    """
    Next-word classifier: a single-layer RNN followed by a linear projection
    with a separate trainable bias.

    Args:
        n_class: vocabulary size; used both as the RNN input feature size and
            as the number of output logits
        n_hidden: number of hidden units in the RNN
        batch_size: batch size of the fixed zero initial hidden state; inputs
            given to construct are expected to use this same batch size

    Inputs:
        X: Tensor, (batch_size, n_step, n_class) -- the RNN is built with batch_first=True
    Outputs:
        Tensor, (batch_size, n_class), unnormalized class scores
    """
    def __init__(self, n_class, n_hidden, batch_size):
        super(TextRNN, self).__init__()
        self.rnn = RNN(input_size=n_class, hidden_size=n_hidden, batch_first=True)
        # The projection is built without its own bias; the bias lives in self.b instead.
        self.W = nn.Dense(n_hidden, n_class, has_bias=False)
        self.b = Parameter(Tensor(np.ones([n_class]), mindspore.float32))

        # Zero initial hidden state, fixed to the batch size given at construction.
        self.h_0 = Tensor(np.zeros((batch_size, n_hidden)).astype(np.float32))
        # Not used by construct; kept so the set of attributes is unchanged.
        self.transpose = P.Transpose()
    def construct(self, X):
        outputs, _ = self.rnn(X, self.h_0)
        # outputs : [n_step, batch_size, n_hidden]
        # the second return value (final hidden state) is [batch_size, n_hidden] and is unused here
        last_step = outputs[-1]  # [batch_size, n_hidden]
        # Add the bias parameter; previously self.b was created but never applied.
        model = self.W(last_step) + self.b  # model : [batch_size, n_class]

        return model

In [None]:
n_step = 2 # number of cells(= number of Step)
n_hidden = 5 # number of hidden units in one cell

sentences = ["i like dog", "i love coffee", "i hate milk"]

# Build the vocabulary. Sorting the de-duplicated words makes the word -> index
# mapping identical on every run; a bare set() iteration order depends on string
# hashing and can change between kernel restarts.
word_list = " ".join(sentences).split()
word_list = sorted(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}  # word -> index
number_dict = {i: w for i, w in enumerate(word_list)}  # index -> word
n_class = len(word_dict)  # vocabulary size
batch_size = len(sentences)  # all sentences form a single batch

In [None]:
# Seed the random generators before building the network so the initial weights
# (drawn via np.random in RNNCell and via MindSpore's initializer in nn.Dense)
# are the same on every Restart & Run All.
SEED = 0
np.random.seed(SEED)
mindspore.set_seed(SEED)

model = TextRNN(n_class, n_hidden, batch_size)

In [None]:
# Loss: softmax cross-entropy on integer class labels, averaged over the batch.
criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True)

# Optimizer: Adam over all trainable parameters of the model.
LEARNING_RATE = 0.001
optimizer = nn.Adam(params=model.trainable_params(), learning_rate=LEARNING_RATE)

In [None]:
# Build the raw Python batches, then wrap them as MindSpore tensors:
# float32 one-hot inputs and int32 class-index targets.
raw_inputs, raw_targets = make_batch(sentences, word_dict, n_class)
input_batch = Tensor(raw_inputs, mindspore.float32)
target_batch = Tensor(raw_targets, mindspore.int32)

In [None]:
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

NUM_EPOCHS = 5000
LOG_INTERVAL = 1000  # print the loss once every LOG_INTERVAL epochs

# Wrap model + loss, then wrap that with the optimizer into a single training step.
net_with_criterion = nn.WithLossCell(model, criterion)
train_network = nn.TrainOneStepCell(net_with_criterion, optimizer)
train_network.set_train()

# Training: each call runs one step on the whole batch and returns the loss.
for epoch in range(NUM_EPOCHS):
    loss = train_network(input_batch, target_batch)
    finished = epoch + 1
    if finished % LOG_INTERVAL != 0:
        continue
    print('Epoch:', '%04d' % finished, 'cost =', '{:.6f}'.format(loss.asnumpy()))

In [None]:
# Predict
# Switch to inference mode first: the training cell left the model in training
# mode, which would keep dropout active if a non-zero rate were configured.
model.set_train(False)

# argmax over the class axis already yields a 1-D array with one index per
# sentence, so it is iterated directly; squeezing it would collapse a
# single-sentence batch to 0-d and break the loop.
predict = model(input_batch).asnumpy().argmax(1)
print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict])