This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[API] Split up Seq2SeqDecoder in Seq2SeqDecoder and Seq2SeqOneStepDecoder (#976)

* Split up Seq2SeqDecoder into Seq2SeqDecoder and Seq2SeqOneStepDecoder

In the current Gluon API, each HybridBlock has to serve one purpose and can only
define a single callable interface. The previous Seq2SeqDecoder interface required
each Seq2SeqDecoder Block to provide two functionalities (multi-step-ahead and
single-step-ahead decoding), so in practice neither of the two could be
hybridized completely. The two functionalities are therefore split into two
separate Blocks, which may share parameters (see the usage sketch after this list).

Update the NMTModel API accordingly.

Further refactor TransformerDecoder to make it completely hybridizable.
TransformerOneStepDecoder still relies on a small hack but can be hybridized
completely when we enable numpy shape semantics.

* Extend unit tests to include one-step decoding

* Improve doc
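
For orientation, here is a rough usage sketch of the split API (an editorial illustration, not part of the commit; the import path, shapes, and hyper-parameters are assumptions chosen for the example). The training decoder consumes the whole target sequence at once, while the parameter-sharing one-step decoder advances a single step at a time during inference:

```python
# Hypothetical sketch of the split-decoder API; names follow
# scripts/machine_translation/gnmt.py in this commit, shapes are illustrative.
import mxnet as mx
from gnmt import get_gnmt_encoder_decoder  # assumed import path

encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder(
    hidden_size=128, dropout=0.2, num_layers=2, num_bi_layers=1, prefix='gnmt_')
encoder.initialize()
decoder.initialize()  # one_step_ahead_decoder reuses these parameters

src_embed = mx.nd.random.uniform(shape=(4, 10, 128))   # (batch, src_len, C_in)
encoder_outputs, _ = encoder(src_embed)

# Training: multi-step decoding with teacher forcing (GNMTDecoder).
tgt_embed = mx.nd.random.uniform(shape=(4, 7, 128))    # (batch, tgt_len, C_in)
states = decoder.init_state_from_encoder(encoder_outputs)
output, states, _ = decoder(tgt_embed, states)         # (batch, tgt_len, C_out)

# Inference: single-step decoding (GNMTOneStepDecoder), one embedded token at a time.
states = one_step_ahead_decoder.init_state_from_encoder(encoder_outputs)
step_input = mx.nd.random.uniform(shape=(4, 128))      # (batch, C_in)
step_output, states, _ = one_step_ahead_decoder(step_input, states)
```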
leezu committed Oct 29, 2019
1 parent bfa5503 commit 57a45aa
Showing 11 changed files with 575 additions and 463 deletions.
10 changes: 5 additions & 5 deletions docs/examples/machine_translation/gnmt.md
@@ -337,12 +337,12 @@ feed the encoder and decoder to the `NMTModel` to construct the GNMT model.
`model.hybridize` allows computation to be done using the symbolic backend. To understand what it means to be "hybridized," please refer to [this](https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/hybrid.html) page on MXNet hybridization and its advantages.

```{.python .input}
encoder, decoder = nmt.gnmt.get_gnmt_encoder_decoder(hidden_size=num_hidden,
dropout=dropout,
num_layers=num_layers,
num_bi_layers=num_bi_layers)
encoder, decoder, one_step_ahead_decoder = nmt.gnmt.get_gnmt_encoder_decoder(
hidden_size=num_hidden, dropout=dropout, num_layers=num_layers,
num_bi_layers=num_bi_layers)
model = nlp.model.translation.NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder,
decoder=decoder, embed_size=num_hidden, prefix='gnmt_')
decoder=decoder, one_step_ahead_decoder=one_step_ahead_decoder,
embed_size=num_hidden, prefix='gnmt_')
model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
static_alloc = True
model.hybridize(static_alloc=static_alloc)
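
As a minimal, self-contained illustration of the hybridization mentioned in the doc text above (an editorial sketch, not part of the changed files): after `hybridize()` a Gluon `HybridBlock` is traced once through its `hybrid_forward` and then executed by the symbolic backend, which enables optimizations such as static memory allocation.

```python
# Minimal sketch of Gluon hybridization (illustrative, not from this commit).
import mxnet as mx
from mxnet.gluon import nn

class TinyBlock(mx.gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(TinyBlock, self).__init__(**kwargs)
        with self.name_scope():
            self.dense = nn.Dense(4)

    def hybrid_forward(self, F, x):
        # F is mx.nd in imperative mode and mx.sym once the block is hybridized
        return F.relu(self.dense(x))

net = TinyBlock()
net.initialize()
net.hybridize(static_alloc=True)   # same flag as in the GNMT/Transformer scripts
out = net(mx.nd.ones((2, 8)))      # first call builds and caches the symbolic graph
```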
267 changes: 162 additions & 105 deletions scripts/machine_translation/gnmt.py
@@ -15,14 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Encoder and decoder usded in sequence-to-sequence learning."""
__all__ = ['GNMTEncoder', 'GNMTDecoder', 'get_gnmt_encoder_decoder']
__all__ = ['GNMTEncoder', 'GNMTDecoder', 'GNMTOneStepDecoder', 'get_gnmt_encoder_decoder']

import mxnet as mx
from mxnet.base import _as_list
from mxnet.gluon import nn, rnn
from mxnet.gluon.block import HybridBlock
from gluonnlp.model.seq2seq_encoder_decoder import Seq2SeqEncoder, Seq2SeqDecoder, \
_get_attention_cell, _get_cell_type, _nested_sequence_last
Seq2SeqOneStepDecoder, _get_attention_cell, _get_cell_type, _nested_sequence_last


class GNMTEncoder(Seq2SeqEncoder):
@@ -158,48 +158,14 @@ def forward(self, inputs, states=None, valid_length=None): #pylint: disable=arg
return [outputs, new_states], []


class GNMTDecoder(HybridBlock, Seq2SeqDecoder):
"""Structure of the RNN Encoder similar to that used in the
Google Neural Machine Translation paper.
We use gnmt_v2 strategy in tensorflow/nmt
Parameters
----------
cell_type : str or type
attention_cell : AttentionCell or str
Arguments of the attention cell.
Can be 'scaled_luong', 'normed_mlp', 'dot'
num_layers : int
hidden_size : int
dropout : float
use_residual : bool
output_attention: bool
Whether to output the attention weights
i2h_weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
h2h_weight_initializer : str or Initializer
Initializer for the recurrent weights matrix, used for the linear
transformation of the recurrent state.
i2h_bias_initializer : str or Initializer
Initializer for the bias vector.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""
class _BaseGNMTDecoder(HybridBlock):
def __init__(self, cell_type='lstm', attention_cell='scaled_luong',
num_layers=2, hidden_size=128,
dropout=0.0, use_residual=True, output_attention=False,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
prefix=None, params=None):
super(GNMTDecoder, self).__init__(prefix=prefix, params=params)
super().__init__(prefix=prefix, params=params)
self._cell_type = _get_cell_type(cell_type)
self._num_layers = num_layers
self._hidden_size = hidden_size
@@ -249,59 +215,7 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
decoder_states.append(mem_masks)
return decoder_states

def decode_seq(self, inputs, states, valid_length=None):
"""Decode the decoder inputs. This function is only used for training.
Parameters
----------
inputs : NDArray, Shape (batch_size, length, C_in)
states : list of NDArrays or None
Initial states. The list of initial decoder states
valid_length : NDArray or None
Valid lengths of each sequence. This is usually used when part of sequence has
been padded. Shape (batch_size,)
Returns
-------
output : NDArray, Shape (batch_size, length, C_out)
states : list
The decoder states, includes:
- rnn_states : NDArray
- attention_vec : NDArray
- mem_value : NDArray
- mem_masks : NDArray, optional
additional_outputs : list
Either be an empty list or contains the attention weights in this step.
The attention weights will have shape (batch_size, length, mem_length) or
(batch_size, num_heads, length, mem_length)
"""
length = inputs.shape[1]
output = []
additional_outputs = []
inputs = _as_list(mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True))
rnn_states_l = []
attention_output_l = []
fixed_states = states[2:]
for i in range(length):
ele_output, states, ele_additional_outputs = self.forward(inputs[i], states)
rnn_states_l.append(states[0])
attention_output_l.append(states[1])
output.append(ele_output)
additional_outputs.extend(ele_additional_outputs)
output = mx.nd.stack(*output, axis=1)
if valid_length is not None:
states = [_nested_sequence_last(rnn_states_l, valid_length),
_nested_sequence_last(attention_output_l, valid_length)] + fixed_states
output = mx.nd.SequenceMask(output,
sequence_length=valid_length,
use_sequence_length=True,
axis=1)
if self._output_attention:
additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)]
return output, states, additional_outputs

def __call__(self, step_input, states): #pylint: disable=arguments-differ
def forward(self, step_input, states): # pylint: disable=arguments-differ
"""One-step-ahead decoding of the GNMT decoder.
Parameters
@@ -326,11 +240,7 @@ def __call__(self, step_input, states): #pylint: disable=arguments-differ
The attention weights will have shape (batch_size, 1, mem_length) or
(batch_size, num_heads, 1, mem_length)
"""
return super(GNMTDecoder, self).__call__(step_input, states)

def forward(self, step_input, states): #pylint: disable=arguments-differ, missing-docstring
step_output, new_states, step_additional_outputs =\
super(GNMTDecoder, self).forward(step_input, states)
step_output, new_states, step_additional_outputs = super().forward(step_input, states)
# In hybrid_forward, only the rnn_states and attention_vec are calculated.
# We directly append the mem_value and mem_masks in the forward() function.
# We apply this trick because the memory value/mask can be directly appended to the next
@@ -402,6 +312,148 @@ def hybrid_forward(self, F, step_input, states): #pylint: disable=arguments-dif
return rnn_out, new_states, step_additional_outputs


class GNMTOneStepDecoder(_BaseGNMTDecoder, Seq2SeqOneStepDecoder):
"""RNN Encoder similar to that used in the Google Neural Machine Translation paper.
One-step ahead decoder used during inference.
We use gnmt_v2 strategy in tensorflow/nmt
Parameters
----------
cell_type : str or type
Can be "lstm", "gru" or constructor functions that can be directly called,
like rnn.LSTMCell
attention_cell : AttentionCell or str
Arguments of the attention cell.
Can be 'scaled_luong', 'normed_mlp', 'dot'
num_layers : int
Total number of layers
hidden_size : int
Number of hidden units
dropout : float
The dropout rate
use_residual : bool
Whether to use residual connection. Residual connection will be added in the
uni-directional RNN layers
output_attention: bool
Whether to output the attention weights
i2h_weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
h2h_weight_initializer : str or Initializer
Initializer for the recurrent weights matrix, used for the linear
transformation of the recurrent state.
i2h_bias_initializer : str or Initializer
Initializer for the bias vector.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""


class GNMTDecoder(_BaseGNMTDecoder, Seq2SeqDecoder):
"""RNN Encoder similar to that used in the Google Neural Machine Translation paper.
Multi-step decoder used during training with teacher forcing.
We use gnmt_v2 strategy in tensorflow/nmt
Parameters
----------
cell_type : str or type
Can be "lstm", "gru" or constructor functions that can be directly called,
like rnn.LSTMCell
attention_cell : AttentionCell or str
Arguments of the attention cell.
Can be 'scaled_luong', 'normed_mlp', 'dot'
num_layers : int
Total number of layers
hidden_size : int
Number of hidden units
dropout : float
The dropout rate
use_residual : bool
Whether to use residual connection. Residual connection will be added in the
uni-directional RNN layers
output_attention: bool
Whether to output the attention weights
i2h_weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
h2h_weight_initializer : str or Initializer
Initializer for the recurrent weights matrix, used for the linear
transformation of the recurrent state.
i2h_bias_initializer : str or Initializer
Initializer for the bias vector.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""

def forward(self, inputs, states, valid_length=None): # pylint: disable=arguments-differ
"""Decode the decoder inputs. This function is only used for training.
Parameters
----------
inputs : NDArray, Shape (batch_size, length, C_in)
states : list of NDArrays or None
Initial states. The list of initial decoder states
valid_length : NDArray or None
Valid lengths of each sequence. This is usually used when part of sequence has
been padded. Shape (batch_size,)
Returns
-------
output : NDArray, Shape (batch_size, length, C_out)
states : list
The decoder states, includes:
- rnn_states : NDArray
- attention_vec : NDArray
- mem_value : NDArray
- mem_masks : NDArray, optional
additional_outputs : list
Either be an empty list or contains the attention weights in this step.
The attention weights will have shape (batch_size, length, mem_length) or
(batch_size, num_heads, length, mem_length)
"""
length = inputs.shape[1]
output = []
additional_outputs = []
inputs = _as_list(mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True))
rnn_states_l = []
attention_output_l = []
fixed_states = states[2:]
for i in range(length):
ele_output, states, ele_additional_outputs = super().forward(inputs[i], states)
rnn_states_l.append(states[0])
attention_output_l.append(states[1])
output.append(ele_output)
additional_outputs.extend(ele_additional_outputs)
output = mx.nd.stack(*output, axis=1)
if valid_length is not None:
states = [_nested_sequence_last(rnn_states_l, valid_length),
_nested_sequence_last(attention_output_l, valid_length)] + fixed_states
output = mx.nd.SequenceMask(output,
sequence_length=valid_length,
use_sequence_length=True,
axis=1)
if self._output_attention:
additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)]
return output, states, additional_outputs


def get_gnmt_encoder_decoder(cell_type='lstm', attention_cell='scaled_luong', num_layers=2,
num_bi_layers=1, hidden_size=128, dropout=0.0, use_residual=False,
i2h_weight_initializer=None, h2h_weight_initializer=None,
@@ -435,19 +487,24 @@ def get_gnmt_encoder_decoder(cell_type='lstm', attention_cell='scaled_luong', nu
decoder : GNMTDecoder
"""
encoder = GNMTEncoder(cell_type=cell_type, num_layers=num_layers, num_bi_layers=num_bi_layers,
hidden_size=hidden_size, dropout=dropout,
use_residual=use_residual,
hidden_size=hidden_size, dropout=dropout, use_residual=use_residual,
i2h_weight_initializer=i2h_weight_initializer,
h2h_weight_initializer=h2h_weight_initializer,
i2h_bias_initializer=i2h_bias_initializer,
h2h_bias_initializer=h2h_bias_initializer,
prefix=prefix + 'enc_', params=params)
h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'enc_',
params=params)
decoder = GNMTDecoder(cell_type=cell_type, attention_cell=attention_cell, num_layers=num_layers,
hidden_size=hidden_size, dropout=dropout,
use_residual=use_residual,
hidden_size=hidden_size, dropout=dropout, use_residual=use_residual,
i2h_weight_initializer=i2h_weight_initializer,
h2h_weight_initializer=h2h_weight_initializer,
i2h_bias_initializer=i2h_bias_initializer,
h2h_bias_initializer=h2h_bias_initializer,
prefix=prefix + 'dec_', params=params)
return encoder, decoder
h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'dec_',
params=params)
one_step_ahead_decoder = GNMTOneStepDecoder(
cell_type=cell_type, attention_cell=attention_cell, num_layers=num_layers,
hidden_size=hidden_size, dropout=dropout, use_residual=use_residual,
i2h_weight_initializer=i2h_weight_initializer,
h2h_weight_initializer=h2h_weight_initializer, i2h_bias_initializer=i2h_bias_initializer,
h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'dec_',
params=decoder.collect_params())
return encoder, decoder, one_step_ahead_decoder
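
A quick note on the parameter sharing set up above (editorial, hedged): because `GNMTOneStepDecoder` is built with the same `prefix + 'dec_'` prefix and with `params=decoder.collect_params()`, the two decoder blocks should resolve to the same underlying `Parameter` objects. That can be checked roughly as follows:

```python
# Sketch: verify that the one-step decoder shares the training decoder's weights.
# assuming: from gnmt import get_gnmt_encoder_decoder
encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder(prefix='gnmt_')
train_params = decoder.collect_params()
step_params = one_step_ahead_decoder.collect_params()
assert set(train_params.keys()) == set(step_params.keys())
for name in train_params:
    assert train_params[name] is step_params[name]  # identical Parameter objects
```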
17 changes: 7 additions & 10 deletions scripts/machine_translation/inference_transformer.py
@@ -154,17 +154,14 @@
else:
tgt_max_len = max_len[1]

encoder, decoder = get_transformer_encoder_decoder(units=args.num_units,
hidden_size=args.hidden_size,
dropout=args.dropout,
num_layers=args.num_layers,
num_heads=args.num_heads,
max_src_length=max(src_max_len, 500),
max_tgt_length=max(tgt_max_len, 500),
scaled=args.scaled)
encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder(
units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout,
num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500),
max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled)
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
share_embed=args.dataset != 'TOY', embed_size=args.num_units,
tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_')
one_step_ahead_decoder=one_step_ahead_decoder, share_embed=args.dataset != 'TOY',
embed_size=args.num_units, tie_weights=args.dataset != 'TOY',
embed_initializer=None, prefix='transformer_')

param_name = args.model_parameter
if (not os.path.exists(param_name)):
14 changes: 7 additions & 7 deletions scripts/machine_translation/train_gnmt.py
@@ -122,12 +122,12 @@
else:
ctx = mx.gpu(args.gpu)

encoder, decoder = get_gnmt_encoder_decoder(hidden_size=args.num_hidden,
dropout=args.dropout,
num_layers=args.num_layers,
num_bi_layers=args.num_bi_layers)
encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder(
hidden_size=args.num_hidden, dropout=args.dropout, num_layers=args.num_layers,
num_bi_layers=args.num_bi_layers)
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
embed_size=args.num_hidden, prefix='gnmt_')
one_step_ahead_decoder=one_step_ahead_decoder, embed_size=args.num_hidden,
prefix='gnmt_')
model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
static_alloc = True
model.hybridize(static_alloc=static_alloc)
@@ -175,8 +175,8 @@ def evaluate(data_loader):
avg_loss += loss * (tgt_seq.shape[1] - 1)
avg_loss_denom += (tgt_seq.shape[1] - 1)
# Translate
samples, _, sample_valid_length =\
translator.translate(src_seq=src_seq, src_valid_length=src_valid_length)
samples, _, sample_valid_length = translator.translate(
src_seq=src_seq, src_valid_length=src_valid_length)
max_score_sample = samples[:, 0, :].asnumpy()
sample_valid_length = sample_valid_length[:, 0].asnumpy()
for i in range(max_score_sample.shape[0]):
