In [3]:
from mxnet import nd
from mxnet.gluon import rnn, nn
import d2l
import mxnet as mx

In [2]:
class Seq2SeqAttentionDecoder(d2l.Decoder):
    """Sequence-to-sequence LSTM decoder that attends over the encoder outputs.

    For every target time step, the top-layer LSTM hidden state is used as the
    query to ``d2l.MLPAttention`` over the encoder outputs. The resulting
    context is concatenated with the current token embedding and fed through
    the LSTM. A final dense layer (``flatten=False``) projects every LSTM output
    to ``vocab_size`` values.

    Args:
        vocab_size: target vocabulary size (embedding rows and dense units).
        embed_size: dimensionality of the token embeddings.
        num_hiddens: LSTM hidden size; also the first argument given to the
            attention cell.
        num_layers: number of stacked LSTM layers.
        dropout: dropout rate handed to both the attention cell and the LSTM.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): Gluon auto-generates parameter names from creation
        # order, so keep these four children created in this exact sequence.
        self.attention_cell = d2l.MLPAttention(num_hiddens, dropout)
        self.embedding = nn.Embedding(input_dim=vocab_size,
                                      output_dim=embed_size)
        self.rnn = rnn.LSTM(hidden_size=num_hiddens, num_layers=num_layers,
                            dropout=dropout)
        self.dense = nn.Dense(units=vocab_size, flatten=False)

    def init_state(self, enc_outputs, enc_valid_len, *args):
        """Turn the encoder's results into the initial decoder state.

        Args:
            enc_outputs: ``(outputs, hidden_state)`` pair from the encoder.
                ``outputs`` gets axes 0 and 1 swapped so it becomes
                (batch_size, seq_len, hidden_size) for the attention cell.
            enc_valid_len: valid source lengths, forwarded to the attention
                cell on every decoding step.
            *args: ignored.

        Returns:
            Tuple ``(batch_major_outputs, hidden_state, enc_valid_len)``.
        """
        enc_seq, enc_hidden = enc_outputs
        batch_major = enc_seq.swapaxes(0, 1)
        return batch_major, enc_hidden, enc_valid_len

    def forward(self, X, state):
        """Run the decoder over every time step of ``X``.

        Args:
            X: target-side token indices; presumably (batch_size, num_steps)
                -- TODO confirm against the data iterator.
            state: ``(enc_outputs, hidden_state, enc_valid_len)`` as built by
                ``init_state`` or returned by a previous call.

        Returns:
            ``(outputs, new_state)`` where ``outputs`` is the dense-layer
            output with axes 0 and 1 swapped back, and ``new_state`` is the
            list ``[enc_outputs, hidden_state, enc_valid_len]`` carrying the
            updated LSTM state.
        """
        memory, lstm_state, valid_len = state
        # Embed, then swap axes 0 and 1 so time becomes the leading axis we
        # iterate over (assuming X is batch-major).
        embedded_steps = self.embedding(X).swapaxes(0, 1)
        per_step = []
        for step_embed in embedded_steps:
            # lstm_state[0] is the LSTM hidden tensor h; [-1] picks the top
            # layer. A length-1 axis is added to form a single query.
            query = lstm_state[0][-1].expand_dims(axis=1)
            context = self.attention_cell(query, memory, memory, valid_len)
            step_input = nd.concat(context, step_embed.expand_dims(axis=1),
                                   dim=-1)
            # The LSTM uses the default time-major layout, so put the single
            # step on axis 0 before feeding it.
            step_out, lstm_state = self.rnn(step_input.swapaxes(0, 1),
                                            lstm_state)
            per_step.append(step_out)
        stacked = nd.concat(*per_step, dim=0)
        return self.dense(stacked).swapaxes(0, 1), [memory, lstm_state, valid_len]

### Load data

In [4]:
# Model hyper-parameters.
embed_size, num_hiddens, num_layers, dropout = 32, 32, 1, 0.0
# Data settings. num_examples is a count (presumably the number of sentence
# pairs to load), so it is an int rather than the float literal 1e4.
batch_size, num_examples, max_len = 32, 10000, 20
# Optimisation settings; ctx selects the device (CPU here).
lr, num_epochs, ctx = 0.005, 20, mx.cpu()

src_vocab, tgt_vocab, train_iter = d2l.load_data_nmt(batch_size, max_len, num_examples)

Downloading fra-eng.zip from http://www.manythings.org/anki/fra-eng.zip...


In [17]:
# Vocabulary sizes as (source, target); they size the encoder/decoder embeddings.
tuple(len(vocab) for vocab in (src_vocab, tgt_vocab))

(1143, 1389)

### Train

In [11]:
# Build the attention-based encoder-decoder and train it.
# The encoder is created before the decoder (Gluon names parameters by creation order).
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens,
                             num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    vocab_size=len(tgt_vocab), embed_size=embed_size,
    num_hiddens=num_hiddens, num_layers=num_layers, dropout=dropout)
model = d2l.EncoderDecoder(encoder, decoder)

d2l.train_ch7(model, train_iter, lr, num_epochs, ctx)

epoch 5, loss 0.113, time 361.4 sec
epoch 10, loss 0.078, time 297.9 sec
epoch 15, loss 0.061, time 289.5 sec
epoch 20, loss 0.050, time 283.8 sec


### Validate: translate a sample sentence

In [12]:
sentence = "I am hungry !"
# Translate one sentence with the trained model and show "source -> target".
translation = d2l.translate_ch7(model, sentence, src_vocab, tgt_vocab,
                                max_len, ctx)
' -> '.join([sentence, translation])

"I am hungry ! -> j'ai faim !"