Plans for RNN #46

Closed
sbodenstein opened this issue Apr 7, 2017 · 23 comments
Labels: enhancement (A feature or an optimization request)

Comments

@sbodenstein

Are there any plans to add RNN layers (compatible with the cuDNN RNN layers)? This would be exceptionally useful, given the wide usage of RNNs.

@emfomenk commented Apr 7, 2017

Hi @sbodenstein,

The work is in progress.
Hopefully RNN support will be available soon.

@vpirogov added the "enhancement" label on Apr 7, 2017
@sbodenstein (Author)

Fantastic, this will be super useful!

@fightbob commented Jun 7, 2017

Hi @emfomenk, when will the RNN feature be released?

@sbodenstein (Author)

@emfomenk: is the planned RNN API going to be compatible with the cuDNN version?

@emfomenk commented Jun 7, 2017

Hi @fightbob and @sbodenstein,

RNN support is slightly postponed -- some other urgent work came up...
Unfortunately there is no ETA at the moment :( I will ping the team for the latest status and get back to you.

Yes, the API is going to be very close to the cuDNN one.
We ran into some problems with how it would look in the C++ API (there are too many constructors for the different configurations), but the C API should be pretty straightforward.

@taliesinb

Great. If you can drop any details about how the API might differ before you actually ship, that would be helpful to us for planning purposes.

@emfomenk commented Jun 7, 2017

Hi @taliesinb,

None of the following is finalized yet,
but it may give a clue as to how things will look...

mkldnn_types.h:

/** A descriptor of an RNN operation. */
typedef struct {
    /** The kind of primitive. Used for self identifying the primitive
     * descriptor. Must be #mkldnn_rnn. */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training,
     * #mkldnn_forward_inference, #mkldnn_backward_data,
     * and #mkldnn_backward_weights. */
    mkldnn_prop_kind_t prop_kind;
    /** The kind of the RNN algorithm. Possible values:
     * #mkldnn_rnn_relu, #mkldnn_rnn_tanh, #mkldnn_rnn_lstm, #mkldnn_rnn_gru. */
    mkldnn_alg_kind_t alg_kind;
    /** The direction of the RNN. Possible values:
     * #mkldnn_rnn_unidirectional, #mkldnn_rnn_bidirectional.*/
    mkldnn_rnn_direction_t direction;
    /** The input mode of the RNN. Possible values:
     * #mkldnn_rnn_linear_input, #mkldnn_rnn_skip_input.*/
    mkldnn_rnn_input_mode_t input_mode;
    /** The number of hidden states in one cell. */
    size_t num_states;
    /** The number of layers in the entire RNN network. */
    size_t num_layers;
    /** The length of the sequences in the entire RNN network. */
    size_t num_seqs;
    /** Whether to output the final state and cell for the entire RNN
     * network. */
    int state_outputs;
    /** Input(x) memory descriptor. [seq, batch, input_size] */
    mkldnn_memory_desc_t x_desc;
    /** State input(hx) memory descriptor. [layer, batch, hidden_size] */
    mkldnn_memory_desc_t hx_desc;
    /** Output(y) memory descriptor. [seq, batch, hidden_size] */
    mkldnn_memory_desc_t y_desc;
    /** Weights memory descriptor. */
    mkldnn_memory_desc_t weights_desc;

    // @TODO check if we need dropout descriptor
} mkldnn_rnn_desc_t;

mkldnn.h:

/** @addtogroup c_api_rnn RNN (Including vanilla RNN, LSTM, GRU)
 * A primitive to compute RNN.
 * @{ */

/** Initializes an rnn descriptor @p rnn_desc for forward propagation using
 * @p prop_kind (possible values are #mkldnn_forward_training or
 * #mkldnn_forward_inference), @p alg_kind (possible values are
 * #mkldnn_rnn_relu, #mkldnn_rnn_tanh, #mkldnn_rnn_lstm or #mkldnn_rnn_gru),
 * @p direction (possible values are #mkldnn_rnn_unidirectional or
 * #mkldnn_rnn_bidirectional), @p input_mode for the input mode,
 * @p num_states for the number of hidden states, @p num_layers
 * for the number of stacked layers, @p num_seqs for the length of the
 * sequences, and memory descriptors */
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(
        mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
        mkldnn_alg_kind_t alg_kind, mkldnn_rnn_direction_t direction,
        mkldnn_rnn_input_mode_t input_mode, size_t num_states,
        size_t num_layers, size_t num_seqs, int state_outputs,
        const mkldnn_memory_desc_t *x_desc,
        const mkldnn_memory_desc_t *hx_desc,
        const mkldnn_memory_desc_t *y_desc,
        const mkldnn_memory_desc_t *weights_desc);

/** Initializes an rnn descriptor @p rnn_desc for backward propagation using
 * @p prop_kind (possible values are #mkldnn_backward_data or
 * #mkldnn_backward_weights), @p alg_kind (possible values are
 * #mkldnn_rnn_relu, #mkldnn_rnn_tanh, #mkldnn_rnn_lstm or #mkldnn_rnn_gru),
 * @p direction (possible values are #mkldnn_rnn_unidirectional or
 * #mkldnn_rnn_bidirectional), @p input_mode for the input mode,
 * @p num_states for the number of hidden states, @p num_layers
 * for the number of stacked layers, @p num_seqs for the length of the
 * sequences, and memory descriptors */
mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(
        mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
        mkldnn_alg_kind_t alg_kind, mkldnn_rnn_direction_t direction,
        mkldnn_rnn_input_mode_t input_mode, size_t num_states,
        size_t num_layers, size_t num_seqs, int state_outputs,
        const mkldnn_memory_desc_t *x_desc,
        const mkldnn_memory_desc_t *hx_desc,
        const mkldnn_memory_desc_t *y_desc,
        const mkldnn_memory_desc_t *weights_desc);

/** @} */
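
For orientation, here is a minimal sketch of how a forward LSTM descriptor might be created against this draft. The sizes, the use of mkldnn_any memory formats, and the placeholder weights size are illustrative assumptions; the RNN-specific names come from the draft above and may change before release.

#include <stddef.h>
#include "mkldnn.h"

int main(void) {
    /* Illustrative sizes: 50 time steps, batch of 32, 512-wide input/state. */
    int seq = 50, batch = 32, input_size = 512, hidden_size = 512;

    int x_dims[3]  = {seq, batch, input_size};   /* [seq, batch, input_size]    */
    int hx_dims[3] = {1, batch, hidden_size};    /* [layer, batch, hidden_size] */
    int y_dims[3]  = {seq, batch, hidden_size};  /* [seq, batch, hidden_size]   */

    mkldnn_memory_desc_t x_desc, hx_desc, y_desc, weights_desc;
    mkldnn_memory_desc_init(&x_desc, 3, x_dims, mkldnn_f32, mkldnn_any);
    mkldnn_memory_desc_init(&hx_desc, 3, hx_dims, mkldnn_f32, mkldnn_any);
    mkldnn_memory_desc_init(&y_desc, 3, y_dims, mkldnn_f32, mkldnn_any);

    /* The draft does not specify the weights layout; the element count below
     * (W, R, and two bias vectors for one LSTM layer) is a guess. */
    int w_dims[1] = {4 * hidden_size * (input_size + hidden_size + 2)};
    mkldnn_memory_desc_init(&weights_desc, 1, w_dims, mkldnn_f32, mkldnn_any);

    mkldnn_rnn_desc_t rnn_desc;
    mkldnn_status_t s = mkldnn_rnn_forward_desc_init(&rnn_desc,
            mkldnn_forward_training, mkldnn_rnn_lstm,
            mkldnn_rnn_unidirectional, mkldnn_rnn_linear_input,
            /*num_states=*/2, /*num_layers=*/1, /*num_seqs=*/seq,
            /*state_outputs=*/1, &x_desc, &hx_desc, &y_desc, &weights_desc);
    return s == mkldnn_success ? 0 : 1;
}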

@taliesinb

@emfomenk Thanks so much, that's very useful to know!

@piiswrong

Any updates on this?

@sbodenstein (Author)

@emfomenk: will there be support for variable-length sequences (i.e. a batch of sequences with different lengths)? cuDNN has support for this, but I don't see it in the above design.

@taliesinb

Specifically, the concern is that because NVIDIA's design outputs only the final cell state (and not the sequence of cell states), you cannot add variable-length support after the fact: the final cell states for inputs shorter than the full sequence length will be invalid. So we simply can't use the optimization at all for variable-length problems unless variable-length support is baked into the design.

@ykim362 commented Nov 20, 2017

@taliesinb @sbodenstein The current design can output all the hidden outputs (h) from the last stacked layer and the cell state (c) at the last time step. But do you need all the cell states in the middle of the sequences?

@taliesinb

@ykim362 yes, but you have a choice. If the RNN layer wants to support variable-length operation†, it can either:

  1. provide the entire history of cell states, so that the correct last time step can be selected for each batch element, e.g. by using MXNet's SequenceLast layer; or

  2. accept another input that contains the sequence lengths, and then expect the sequences to be densely packed (this is what cuDNN does); this way the cell state output is already correct.

† To be clear about what I mean by variable-length operation: I'm referring to the case where a batch contains multiple unequal sequence lengths -- and most sequence problems are like this. Older frameworks just pad the shorter sequences with zeros and expect the network to learn to deal with the zeros, but this fundamentally changes the problem. By far the better approach is to pad with junk and carefully make sure that you take the 'correct' outputs and states from just before the junk, using pick operations etc. We want to make sure that the MKL implementation makes this possible. Option 1 does the pick externally; option 2 does the pick internally.
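
To make option 1 concrete, here is a plain-C sketch of the external pick it implies (illustrative code, not tied to any API above): given the full history of states y with shape [seq, batch, hidden] and the per-example lengths, select each example's last valid state.

#include <stddef.h>
#include <string.h>

/* y:       [seq, batch, hidden], full history of states
 * lengths: [batch], valid length of each sequence (>= 1)
 * last:    [batch, hidden], receives each example's last valid state */
void gather_last_states(const float *y, const int *lengths,
                        int batch, int hidden, float *last) {
    for (int n = 0; n < batch; ++n) {
        int t = lengths[n] - 1; /* index of the last valid time step */
        const float *src = y + ((size_t)t * batch + n) * hidden;
        memcpy(last + (size_t)n * hidden, src, hidden * sizeof(float));
    }
}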

@mgouicem (Contributor)

@taliesinb @sbodenstein Thanks for the comments! I think I am missing something...

I understand that cuDNN's current interface enables option 1 (parameters yDesc and y in the cuDNN doc).

For option 2, I did not find any documentation related to that. The only element I see for accommodating variable length in the cuDNN API is that the inputs for each time step can have a different minibatch size (in decreasing order). I guess this assumes that the user has to sort the sequences in the minibatch first (e.g. input from the longest sequence first in each minibatch), but there is not much detail in their doc. Could you please elaborate on the use case?

@sbodenstein (Author)

> I understand that cuDNN's current interface enables option 1 (parameters yDesc and y in the cuDNN doc).

This is correct for GRU and the standard RNN, but untrue for LSTM, which has a second state (the cell state) that is not returned in y; for LSTM you have to use option 2 in cuDNN. Also, for bidirectional RNNs, only option 2 will work.

> I guess this assumes that the user has to sort the sequences in the minibatch first

The framework/user does indeed have to sort the sequences by length and pack them. This is annoying, and it would be good if the Intel version could avoid it. Concretely, the sort-and-pack step cuDNN expects looks roughly like the sketch below.
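
A plain-C sketch of that sort-and-pack step (names illustrative, not any library's API; assumes lengths are already sorted in decreasing order):

#include <stddef.h>
#include <string.h>

/* padded:  [batch, max_len, feat], sequences padded to max_len,
 *          sorted so lengths[0] >= lengths[1] >= ...
 * packed:  time-major output; step t holds only the examples with
 *          lengths[n] > t, so the per-step minibatch shrinks */
void pack_sorted_sequences(const float *padded, const int *lengths,
                           int batch, int max_len, int feat, float *packed) {
    size_t dst = 0;
    for (int t = 0; t < max_len; ++t) {
        for (int n = 0; n < batch && lengths[n] > t; ++n) {
            const float *src = padded + ((size_t)n * max_len + t) * feat;
            memcpy(packed + dst, src, (size_t)feat * sizeof(float));
            dst += feat;
        }
    }
}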

> Could you please elaborate on the use case?

Frameworks that support variable-length RNNs require this (e.g. PyTorch, pytorch/pytorch#873), and we wish to add this support to MXNet as well. I'm including @apaszke and @jekbradbury, as this discussion about the MKL RNN design seems very relevant for variable-length RNNs in PyTorch as well (I think PyTorch will also want to use this MKL RNN implementation).

@taliesinb

> The framework/user does indeed have to sort the sequences by length and pack them. This is annoying, and it would be good if the Intel version could avoid it.

Or, better yet, provide it as an optional feature.

@sbodenstein (Author)

@mgouicem: for us, the cleanest approach to supporting variable-length sequences is a bit different from the cuDNN approach:

  • Add an extra argument to mkldnn_rnn_forward_desc_init that accepts a list of sequence lengths at runtime (a hypothetical signature is sketched after this list).
  • The data input can then either keep the usual shape {batch, max seq len, feature size}, with values beyond each provided sequence length interpreted as padding,
  • or be, as in cuDNN, effectively a packed piece of memory.
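
A hypothetical signature for the first bullet; the _v2 name and the trailing seq_lengths_desc parameter are purely illustrative extensions of the draft above, not an actual proposal:

/* Draft forward-descriptor initializer extended with per-example sequence
 * lengths (a [batch]-sized vector of int32). Hypothetical. */
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init_v2(
        mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
        mkldnn_alg_kind_t alg_kind, mkldnn_rnn_direction_t direction,
        mkldnn_rnn_input_mode_t input_mode, size_t num_states,
        size_t num_layers, size_t num_seqs, int state_outputs,
        const mkldnn_memory_desc_t *x_desc,
        const mkldnn_memory_desc_t *hx_desc,
        const mkldnn_memory_desc_t *y_desc,
        const mkldnn_memory_desc_t *weights_desc,
        const mkldnn_memory_desc_t *seq_lengths_desc);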

@apaszke commented Nov 28, 2017

I think at this point we're pretty much stuck with packed-sequence/padded inputs in PyTorch, so it would be cool if you supported something similar. The cuDNN API is quite good, except for its weight format management. Please, unless it is absolutely necessary, don't require frameworks to give you the weights as a single chunk of memory, and if this is needed, then at least define the format openly. Right now cuDNN's answer is "use our API to query where to put each weight", which is terribly inconvenient.
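
To illustrate what an openly defined format could look like: a hypothetical documented layout for one unidirectional LSTM layer, with offsets a framework can compute itself instead of querying. The gate order and single-bias convention here are arbitrary choices for the example.

#include <stddef.h>

/* Hypothetical open layout for one LSTM layer's weights, stored back to
 * back as floats: W [4*hidden x input], then R [4*hidden x hidden], then
 * bias [4*hidden], all row-major with a fixed gate order (i, f, g, o). */
typedef struct {
    size_t w_off; /* offset (in floats) of the input-to-hidden matrix  */
    size_t r_off; /* offset of the hidden-to-hidden matrix             */
    size_t b_off; /* offset of the bias vector                         */
    size_t total; /* total number of floats for the layer              */
} lstm_weight_offsets_t;

lstm_weight_offsets_t lstm_weight_offsets(size_t input, size_t hidden) {
    lstm_weight_offsets_t off;
    off.w_off = 0;
    off.r_off = off.w_off + 4 * hidden * input;
    off.b_off = off.r_off + 4 * hidden * hidden;
    off.total = off.b_off + 4 * hidden;
    return off;
}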

@taliesinb commented Nov 29, 2017

> Right now cuDNN's answer is "use our API to query where to put each weight", which is terribly inconvenient.

Yeah, it really sucks; it made compilation much more complicated for us. And the practice of not publicly defining properties, sizes, etc. and putting them behind an API makes scratch memory much harder to share across buckets, because the workspace size cannot be determined without querying CUDA at compile time, which MXNet does not support.

EDIT: clarify my complaint.

@mgouicem (Contributor)

Thank you for the clarification and the input. We will take that into account when designing our API.

@xhzhao commented Nov 30, 2017

@apaszke @taliesinb Totally agree with you about the cuDNN weight format; I think a clearly defined weight format is very important for frameworks and users.

@BenjaminJurke

Is there any update on the timeline for the release of the RNN primitives at this point? Just curious, but very much looking forward to it.

@mgouicem (Contributor)

@BenjaminJurke, unfortunately there is no precise timeline for the feature yet, but we are working on it.
