Skip to content
This repository has been archived by the owner on Feb 9, 2021. It is now read-only.

Commit

Permalink
[MXNET-766] add dynamic_unroll RNN for HybridBlock (apache#11948)
Browse files Browse the repository at this point in the history
* add contrib unroll.

* reenable some tests.

* fix a bug.

* fix lint.

* fix a bug.

* support diff layouts.

* update doc.

* use a diff default layout.

* remove _contrib_format_sequence.

* fix lint.

* rename.
  • Loading branch information
zheng-da authored and Gordon Reid committed Feb 19, 2019
1 parent 1befc5a commit 544419c
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 1 deletion.
115 changes: 115 additions & 0 deletions python/mxnet/gluon/contrib/rnn/rnn_cell.py
Expand Up @@ -22,6 +22,7 @@
from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
from ... import tensor_types
from ....base import _as_list

class VariationalDropoutCell(ModifierCell):
"""
Expand Down Expand Up @@ -320,3 +321,117 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,

return next_r, [next_r, next_c]
# pylint: enable= arguments-differ


def dynamic_unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
layout='TNC', valid_length=None):
"""Unrolls an RNN cell across time steps.
Currently, 'TNC' is a preferred layout. unroll on the input of this layout
runs much faster.
Parameters
----------
cell : an object whose base class is RNNCell.
The RNN cell to run on the input sequence.
inputs : Symbol
It should have shape (batch_size, length, ...) if `layout` is 'NTC',
or (length, batch_size, ...) if `layout` is 'TNC'.
begin_state : nested list of Symbol
The initial states of the RNN sequence.
drop_inputs : float, default 0.
The dropout rate for inputs. Won't apply dropout if it equals 0.
drop_outputs : float, default 0.
The dropout rate for outputs. Won't apply dropout if it equals 0.
layout : str, optional
`layout` of input symbol. Only used if inputs
is a single Symbol.
valid_length : Symbol, NDArray or None
`valid_length` specifies the length of the sequences in the batch without padding.
This option is especially useful for building sequence-to-sequence models where
the input and output sequences would potentially be padded.
If `valid_length` is None, all sequences are assumed to have the same length.
If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
The ith element will be the length of the ith sequence in the batch.
The last valid state will be return and the padded outputs will be masked with 0.
Note that `valid_length` must be smaller or equal to `length`.
Returns
-------
outputs : Symbol
the output of the RNN from this unrolling.
states : list of Symbol
The new state of this RNN after this unrolling.
The type of this symbol is same as the output of `begin_state`.
Examples
--------
>>> seq_len = 3
>>> batch_size = 2
>>> input_size = 5
>>> cell = mx.gluon.rnn.LSTMCell(input_size, prefix='rnn_')
>>> cell.initialize(ctx=mx.cpu())
>>> rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
>>> state_shape = (batch_size, input_size)
>>> states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(2)]
>>> valid_length = mx.nd.array([2, 3])
>>> output, states = mx.gluon.contrib.rnn.rnn_cell.dynamic_unroll(cell, rnn_data, states,
valid_length=valid_length,
layout='TNC')
>>> print(output)
[[[ 0.00767238 0.00023103 0.03973929 -0.00925503 -0.05660512]
[ 0.00881535 0.05428379 -0.02493718 -0.01834097 0.02189514]]
[[-0.00676967 0.01447039 0.01287002 -0.00574152 -0.05734247]
[ 0.01568508 0.02650866 -0.04270559 -0.04328435 0.00904011]]
[[ 0. 0. 0. 0. 0. ]
[ 0.01055336 0.02734251 -0.03153727 -0.03742751 -0.01378113]]]
<NDArray 3x2x5 @cpu(0)>
"""

# Merge is always True, so we don't need length.
inputs, axis, F, _ = _format_sequence(0, inputs, layout, True)
if axis != 0:
axes = list(range(len(layout)))
tmp = axes[0]
axes[0] = axes[axis]
axes[axis] = tmp
inputs = F.transpose(inputs, axes=axes)
states = begin_state

if drop_inputs:
inputs = F.Dropout(inputs, p=drop_inputs, axes=(axis,))

if valid_length is None:
def loop_body(inputs, states):
return cell(inputs, states)
else:
zeros = []
for s in states:
zeros.append(F.zeros_like(s))
states = list(_as_list(states))
states.append(F.zeros((1)))
def loop_body(inputs, states):
cell_states = states[:-1]
iter_no = states[-1]
out, new_states = cell(inputs, cell_states)
for i, state in enumerate(cell_states):
new_states[i] = F.where(F.broadcast_greater(valid_length, iter_no),
new_states[i], state)
new_states.append(iter_no + 1)
return out, new_states

outputs, states = F.contrib.foreach(loop_body, inputs, states)
if drop_outputs:
outputs = F.Dropout(outputs, p=drop_outputs, axes=(axis,))
if valid_length is not None:
if axis != 0:
outputs = F.transpose(outputs, axes)
outputs = F.SequenceMask(outputs, sequence_length=valid_length,
use_sequence_length=True, axis=axis)
# the last state is the iteration number. We don't need it.
return outputs, states[:-1]
else:
if axis != 0:
outputs = F.transpose(outputs, axes)
return outputs, states
95 changes: 94 additions & 1 deletion tests/python/unittest/test_gluon_contrib.py
Expand Up @@ -17,12 +17,14 @@

from __future__ import print_function
import mxnet as mx
import copy
from mxnet import gluon
from mxnet.gluon import contrib
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import (
Concurrent, HybridConcurrent, Identity, SparseEmbedding, PixelShuffle1D,
PixelShuffle2D, PixelShuffle3D)
from mxnet.test_utils import almost_equal
from mxnet.test_utils import almost_equal, default_context, assert_almost_equal
from common import setup_module, with_seed, teardown
import numpy as np
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -313,6 +315,97 @@ def test_sampler():
assert list(interval_sampler) == [0, 3, 6, 9]


class TestRNNLayer(gluon.HybridBlock):
def __init__(self, cell_type, hidden_size, layout, prefix=None, params=None):
super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
self.cell = cell_type(hidden_size, prefix='rnn_')
self.layout = layout

def hybrid_forward(self, F, inputs, states, valid_length):
if isinstance(valid_length, list) and len(valid_length) == 0:
valid_length = None
return contrib.rnn.rnn_cell.dynamic_unroll(self.cell, inputs, states,
valid_length=valid_length,
layout=self.layout)

def check_unroll(cell_type, num_states, layout):
batch_size = 20
input_size = 50
hidden_size = 30
seq_len = 10
if layout == 'TNC':
rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
elif layout == 'NTC':
rnn_data = mx.nd.normal(loc=0, scale=1, shape=(batch_size, seq_len, input_size))
else:
print("Wrong layout")
return
valid_length = mx.nd.round(mx.nd.random.uniform(low=1, high=10, shape=(batch_size)))
state_shape = (batch_size, hidden_size)
states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(num_states)]

cell = cell_type(hidden_size, prefix='rnn_')
cell.initialize(ctx=default_context())
if layout == 'TNC':
cell(rnn_data[0], states)
else:
cell(rnn_data[:,0,:], states)
params1 = cell.collect_params()
orig_params1 = copy.deepcopy(params1)

trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
with mx.autograd.record():
res1, states1 = cell.unroll(seq_len, rnn_data, states, valid_length=valid_length,
layout=layout, merge_outputs=True)
res1.backward()
trainer.step(batch_size)

configs = [
lambda layer: None,
lambda layer: layer.hybridize(),
lambda layer: layer.hybridize({'inline_limit': 0}),
lambda layer: layer.hybridize({'static_alloc': True}),
lambda layer: layer.hybridize({'static_alloc': True, 'static_shape': True}) ]
# We can't pass None to a hybrid block, but it accepts an empty list.
# so we use an empty list to represent valid_length if it's None.
if valid_length is None:
valid_length = []
for config in configs:
layer = TestRNNLayer(cell_type, hidden_size, layout)
layer.initialize(ctx=default_context())
config(layer)
res2, states2 = layer(rnn_data, states, valid_length)
params2 = layer.collect_params()
for key, val in orig_params1.items():
params2[key].set_data(copy.deepcopy(val.data()))

trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
with mx.autograd.record():
res2, states2 = layer(rnn_data, states, valid_length)
assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
assert len(states1) == len(states2)
for i in range(len(states1)):
assert_almost_equal(states1[i].asnumpy(), states2[i].asnumpy(),
rtol=0.001, atol=0.0001)
res2.backward()
trainer.step(batch_size)

for key, val in params1.items():
weight1 = val.data()
weight2 = params2[key].data()
assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(),
rtol=0.001, atol=0.0001)


@with_seed()
def test_contrib_unroll():
cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
(gluon.rnn.GRUCell, 1)]
for cell_type, num_states in cell_types:
check_unroll(cell_type, num_states, 'TNC')
check_unroll(cell_type, num_states, 'NTC')


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 544419c

Please sign in to comment.