
[WIP] Tentative RNN Interface #4618

Merged: 9 commits merged into apache:master on Feb 4, 2017
Conversation

piiswrong (Contributor)

def __call__(self, inputs, states, params, prefix=''):
    W = params.get('%si2h_weight'%prefix)
    B = params.get('%si2h_bias'%prefix)
    U = params.get('%sh2h_weight'%prefix)
sxjscience (Member) commented Jan 10, 2017

Would it be better to put the params + prefix in the __init__ of the RNNCell, so we don't need to manually set the prefix when calling the one-step RNN?

piiswrong (Contributor, Author) commented Jan 10, 2017

No, because when you put cells into a StackedRNNCell it needs to use a different prefix for each layer of the stack.
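To make the collision concrete, here is a toy sketch; Params is a hypothetical stand-in for this PR's RNNParams, assumed to cache one Variable per name:

import mxnet as mx

class Params(object):
    """Toy stand-in for RNNParams: caches one Variable per name."""
    def __init__(self):
        self._cache = {}

    def get(self, name):
        if name not in self._cache:
            self._cache[name] = mx.symbol.Variable(name)
        return self._cache[name]

params = Params()
w0 = params.get('l0_i2h_weight')   # distinct per-layer prefixes...
w1 = params.get('l1_i2h_weight')   # ...give distinct parameters
assert w0 is not w1
a = params.get('i2h_weight')       # one shared prefix would make every
b = params.get('i2h_weight')       # layer resolve to the same weight
assert a is b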

sxjscience (Member) commented Jan 10, 2017

OK. Previously I expected the usage of the RNNCell to be like the following:

rnn1 = RNNCell(num_hidden=10, activation="tanh", name="rnn1")
state = rnn1.begin_state()
for i in range(10):
    out, state = rnn1(inputs=dat[i], states=state)

We first define an RNNCell and then repeatedly apply it to make sure that the parameters are shared.

piiswrong (Contributor, Author)

That works too, though we need to be clear that cells should not be copied to form stacks.
I did this because TF did it this way.
I don't have strong feelings one way or the other.

Member

Should we support variable scope like TF? I find that they manage the weights/biases using the scope name.

piiswrong (Contributor, Author)

I think prefix + param is good enough. It aligns better with MXNet's name matching.

sxjscience (Member) commented Jan 11, 2017

I feel that we do not need both the params and the prefix to determine the weights and biases; we can derive the names purely from the prefix. If we keep the current design, we will need to create the RNNParams every time we use an RNN. Could we design it as an inner registry?

Nevertheless, I think that it's acceptable to keep an additional parameter.

def output_shape(self):
    return (0, self._num_hidden)

def __call__(self, inputs, states, params, prefix=''):
Contributor

Suggest default prefix be "rnn"

leopd (Contributor) left a comment

Seems like a good start. I'd like to see how the cuDNN symbol gets used in this framework, and also what the application-level code looks like for problems like sequence classification, or sequence-to-sequence, all with and without attention. And then maybe we build utility classes for constructing this kind of network as well. But this level of interface probably needs to underlie those application-level modules.

from .. import ndarray
from ..base import numeric_types, string_types

class RNNParams(object):
Contributor

Suggest comment:
"""This class holds the learnable parameters (weights, biases) for the RNN.
New params are added as needed with get().
"""

def begin_state(self, prefix='', init_sym=symbol.zeros, **kwargs):
    """initial state"""
    state_shape = self.state_shape
    def recursive(shape, c):
Contributor

I don't understand what this is doing. Can you add a comment? It looks like the state_shape property is overloaded to support multiple different types. We should document what the options are and why.

    self._counter = 0

@property
def state_shape(self):
Contributor

Suggested comment:
"""LSTM has two internal states, typically called "cell" and "hidden". Here we're setting them both to the same size."""



class StackedRNNCell(BaseRNNCell):
    """Stacked multple rnn cels"""
Contributor

multple -> multiple

This takes a single RNN cell and returns a simple stack of it? No: this takes a list of RNN cells and puts them together to make a stack by wiring the outputs to the inputs. We should say that.

Also, that seems like a lot of work for somebody to do just to use a stacked LSTM. They should be able to do this more easily, e.g. with a single constructor. This class seems more like an "RNNCellStacker" than a "StackedRNNCell".
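As a rough illustration of that wiring (a hypothetical helper, not the PR's implementation), assuming the two-argument __call__(inputs, states) form used later in this PR:

def stack_step(cells, inputs, states):
    """One time step through a stack: the output of layer i becomes
    the input of layer i + 1, and each cell keeps its own state."""
    next_states = []
    x = inputs
    for cell, state in zip(cells, states):
        x, state = cell(x, state)
        next_states.append(state)
    return x, next_states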

@@ -48,6 +48,10 @@ def __repr__(self):
        return '<%s %s>' % (self.__class__.__name__,
                            'Grouped' if name is None else name)

    def __iter__(self):
Contributor

This is a pretty fundamental change. Maybe this should be a named method like "outputs"?

piiswrong (Contributor, Author)

It's a bug fix. Previously `for x in symbol` would fail.
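For instance, assuming __iter__ yields each output of the symbol, something like the following should now work:

import mxnet as mx

group = mx.symbol.Group([mx.symbol.Variable('a'), mx.symbol.Variable('b')])
for out in group:   # previously raised; now yields each output in turn
    print(out.name)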

@@ -67,26 +67,72 @@ struct InferTypeError {
    : msg(msg), index(index) {}
};

/*! \brief check if shape is empty or contains unkown (0) dim. */
Contributor

unkown -> unknown

sxjscience (Member)

From my point of view, one way to support the cuDNN optimization is to enable multi-step forwarding in the RNNCell (e.g., add an argument to __call__ called step_num or seq_len) and add a flag in __init__ that determines whether to use cuDNN. However, we will also need another operator to convert the biases and weights into the packed "parameter" of mx.sym.RNN.

https://github.com/ML-HK/mxnet/blob/master/python/mxnet/recurrent.py#L90-L120
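For illustration, a hedged sketch of what such a conversion might look like; pack_rnn_params is hypothetical, and the real ordering must follow cuDNN's internal parameter layout:

import mxnet as mx

def pack_rnn_params(weights, biases):
    """Hypothetical sketch: flatten per-layer, per-gate weight matrices
    and bias vectors into the single 1-D blob that the cuDNN-backed
    mx.sym.RNN operator consumes. The concatenation order here is
    arbitrary; cuDNN prescribes its own layout, which is exactly what
    makes a unified interface hard (see the discussion below)."""
    flat = [mx.symbol.Reshape(w, shape=(-1,)) for w in weights]
    flat += [mx.symbol.Reshape(b, shape=(-1,)) for b in biases]
    return mx.symbol.Concat(*flat, dim=0)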

piiswrong (Contributor, Author)

cuDNN will be used as a function, fused_rnn.

zarandioon left a comment

A few minor comments.



def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
    lines = open(fname).readlines()


with open(fname, 'r') as f:
    lines = f.readlines()

# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")


Is this needed?

provide_data=[(self.data_name, data.shape)],
provide_label=[(self.label_name, label.shape)])


Contributor

Clean up empty lines?

pluskid (Contributor) commented Jan 26, 2017

Overall, it looks great. Several comments:

  1. The handling of init states is much better than the previous ad hoc way. However, I guess the user needs to re-construct the symbol, instead of just loading it from JSON, in order to do inference, because the init states will be fed as inputs instead of mx.zeros. Is there some way to make this easier?

  2. The current interface handles time steps as a list, using SliceChannel to split and later re-merge (see the sketch after this list). I'm wondering if there is a performance issue here (e.g. for applications with ~2k sequence length). Would it be better to use a fully packed tensor, making one of the dimensions the time axis (i.e. time-major / batch-major data layout)? One benefit is that the cuDNN cell uses this layout, so it might be easier to wrap. Another benefit is that, symbolically, an RNN with seq-len 10 and seq-len 20 will be exactly the same, just bound with different input shapes.

  3. Currently there seems to be a simple one-layer RNN cell and then a sequential RNN cell that can contain a stack of RNN cells. I'm wondering whether we need to expose so many different concepts to the user. With commonly used wrappers, in both the tensorflow interface and the theano interface, the RNN cell is a single function that can be used to create multiple layers with a single call. Plus, it can be mixed with other operators without needing to explicitly call unroll. Our own cuDNN RNN Cell demo is an example that I think keeps the syntax close to calling any other mxnet operator.
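For concreteness, a minimal sketch of the two layouts in point 2 (shapes and sequence length illustrative):

import mxnet as mx

# List-of-steps layout (this PR): split a packed input into per-step
# symbols with SliceChannel, process each step, then re-merge.
data = mx.symbol.Variable('data')   # assume shape (batch, seq_len=20, dim)
steps = mx.symbol.SliceChannel(data, num_outputs=20, axis=1, squeeze_axis=1)
# Packed layout (what cuDNN consumes): keep the single tensor with time on
# one axis; the symbol is then identical for seq-len 10 and 20, and only
# the shape it is bound with differs.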

"""shape(s) of output"""
return (0, self._num_hidden)

def __call__(self, inputs, states):
Member

@piiswrong @pluskid Should we also add a "seq_len" or "input_length" argument to this __call__ function? Like the "input_length" in Keras https://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py#L115-L123.

piiswrong (Contributor, Author)

You can do it with rnn_unroll.
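For context, a minimal sketch of what an unroll helper does conceptually (names hypothetical; rnn_unroll's actual signature isn't shown in this PR):

def unroll(cell, inputs, begin_states):
    """Apply `cell` once per time step. Parameters are shared across
    steps because the same cell, and thus the same param symbols,
    is reused."""
    states = begin_states
    outputs = []
    for x in inputs:   # `inputs` is a list of per-step symbols
        out, states = cell(x, states)
        outputs.append(out)
    return outputs, states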

piiswrong (Contributor, Author) commented Jan 26, 2017

  1. Why would the init state be non-zero for inference? @pluskid
  2. The current interface allows more flexibility. You can easily unroll it with rnn_unroll. The cuDNN RNN will be called FusedRNN. I haven't decided if it should be a Cell or a function; a function is probably more natural in this case.
  3. This can be a higher-level interface.

mz24cn commented Jan 28, 2017

pluskid (Contributor) commented Jan 28, 2017

@piiswrong

  1. I'm talking about step-by-step inference, where one needs to do sampling at each time step and feed that as input to the next time step. The state needs to be forwarded, too. Another situation is implementing truncated BPTT, where the backward pass is truncated to a fixed number of steps but the forward state is kept. See our speech demo for an existing implementation of BPTT.

  2. So the cuDNN backend will be wrapped in a different operator?

piiswrong (Contributor, Author)

@pluskid

  1. I see. You can do that by calling begin_state(init_sym=sym.Variable); I'll document this (see the sketch after this list).
  2. Yes.
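A sketch of that pattern, assuming begin_state forwards init_sym to construct each state symbol (RNNCell is the cell from this PR; the cell and data names are illustrative):

import mxnet as mx

cell = RNNCell(num_hidden=10, activation="tanh", name="rnn1")
# States become free Variables rather than zeros, so at inference time
# they can be bound to the previous step's evaluated output states:
states = cell.begin_state(init_sym=mx.symbol.Variable)
out, next_states = cell(inputs=mx.symbol.Variable('data'), states=states)
# Feed the values of `next_states` back in as the state inputs at t + 1.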

sxjscience (Member)

The PR looks good to me.

pluskid (Contributor) commented Jan 30, 2017

I think it is better to have a unified RNN interface for cuDNN and non-cuDNN. But if that is making things too complicated, this looks good to me, too.

piiswrong (Contributor, Author)

@pluskid That's very hard, given that the cuDNN RNN packs its weights.

sxjscience (Member)

I agree with Eric on the cuDNN part. It's really difficult to handle the weights when there is more than one layer. Also, is it possible to detect RNN-like patterns in the constructed symbol and use cuDNN-RNN as a kind of kernel fusion?

piiswrong merged commit 8b9d909 into apache:master on Feb 4, 2017

ndiscard = 0
self.data = [[] for _ in buckets]
for i in xrange(len(sentences)):
Contributor

Note: this makes the otherwise Python 3 compatible code incompatible with Python 3.


Change xrange to range for Python 3.
