Fix Specifying Initial States of RNN Layers #5795
Conversation
Fix when initial_states is a tensor
The proper behavior in that case would be to check if all tensors in the list are Keras tensors. If it's mixed, raise an exception. If they're all Keras tensors, add them to the inputs. Else, use

I am confused by what you are referring to as "numerical values". There are no numerical values involved (e.g. Numpy tensors), only symbolic tensors (e.g. TF tensors). Can you clarify?

In the API we should use
It is unclear to me why we are checking if they are Keras tensors in the first place. If
My apologies, I misspoke. In the current code, if
I am fine with using `initial_state` everywhere in the API. However, I feel that we should change the code in

A non-Keras tensor set as initial state will generally not be a constant. It will simply be a non-Keras tensor, dependent or not on the underlying model's inputs.
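The all-or-nothing check proposed above can be sketched as follows. This is a hedged illustration, not the actual Keras implementation; it assumes (as mentioned later in this thread) that Keras tensors are distinguished by carrying a `_keras_history` attribute, and the helper name `validate_initial_state` is hypothetical.

```python
# Sketch of the proposed check: initial states must be either all
# Keras tensors or all non-Keras tensors; a mixed list is an error.

def is_keras_tensor(x):
    # Keras tensors carry a `_keras_history` attribute; plain backend
    # tensors do not.
    return hasattr(x, '_keras_history')

def validate_initial_state(initial_state):
    """Return True if all states are Keras tensors, False if none are.

    Raises ValueError for a mixed list, as proposed in the discussion.
    """
    if not isinstance(initial_state, (list, tuple)):
        initial_state = [initial_state]
    flags = [is_keras_tensor(x) for x in initial_state]
    if all(flags):
        return True
    if not any(flags):
        return False
    raise ValueError('initial_state cannot mix Keras tensors '
                     'and non-Keras tensors')
```

The caller can then either extend the graph inputs (all Keras tensors) or fall back to keyword arguments (all non-Keras tensors).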
keras/layers/recurrent.py (outdated diff):

```python
else:
    kwargs['initial_state'] = initial_state

# We need to build the layer so that state_spec exists.
```
All of this should be delegated to the parent's `__call__`.
Currently, `state_spec` is defined in `Recurrent.build`. When `initial_state` is passed, we need to build the layer so that we can use `state_spec`.
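The build-before-use constraint described here can be sketched with a toy class. This is a simplified illustration with hypothetical names (`LayerSketch`), not the real Keras `Recurrent` class; the point is only that anything needing `state_spec` must trigger `build` first.

```python
class LayerSketch:
    """Toy layer showing why build must run before state_spec is used.

    Hypothetical stand-in, not the actual Keras implementation.
    """
    def __init__(self, units):
        self.units = units
        self.built = False
        self.state_spec = None

    def build(self, input_shape):
        # state_spec only exists after build, mirroring the situation
        # described in the discussion (state_spec set in Recurrent.build).
        batch_size = input_shape[0]
        self.state_spec = ('shape', (batch_size, self.units))
        self.built = True

    def __call__(self, input_shape, initial_state=None):
        # When an initial state is passed, build eagerly so that
        # state_spec can be consulted.
        if initial_state is not None and not self.built:
            self.build(input_shape)
        return self.state_spec
```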
keras/layers/recurrent.py (outdated diff):

```python
# Compute the full inputs, including state
if not isinstance(initial_state, (list, tuple)):
    initial_state = [initial_state]
inputs = [inputs] + list(initial_state)
```
There is no longer any check that the initial states are Keras tensors, which will cause the model construction to fail when using non-Keras tensors.
For every other type of layer, model construction fails when using non-Keras tensors. Why are we making a special exception for the initial states of RNNs?
We're not making an exception; RNN layers will behave like every other layer. We're just building an API. The API is that `RNN(inputs, initial_state=x)` should work as a way to set initial state tensors, independently of the value of `x` (Keras tensor or not). It's actually very simple to set up, via a switch between inputs and layer keyword arguments.
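The switch between inputs and keyword arguments described above could look roughly like this. It is a simplified sketch, not the real `Recurrent.__call__`; the dispatcher name `call_with_initial_state` is hypothetical, and detecting Keras tensors via `_keras_history` is an assumption taken from this thread.

```python
def call_with_initial_state(layer_call, inputs, initial_state=None,
                            **kwargs):
    """Dispatch initial_state into either the graph inputs (Keras
    tensors) or the layer's keyword arguments (plain tensors).

    Simplified sketch of the API behavior discussed in this PR.
    """
    if initial_state is None:
        return layer_call(inputs, **kwargs)
    if not isinstance(initial_state, (list, tuple)):
        initial_state = [initial_state]
    if all(hasattr(x, '_keras_history') for x in initial_state):
        # Keras tensors become additional graph inputs, so the
        # functional API can track them in the computational graph.
        full_inputs = [inputs] + list(initial_state)
        return layer_call(full_inputs, **kwargs)
    # Non-Keras tensors are passed through as a keyword argument.
    kwargs['initial_state'] = initial_state
    return layer_call(inputs, **kwargs)
```

Either way the layer receives the states; the only difference is whether they participate in graph construction.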
The latest commit moves the definition of

Any volunteers to review this PR?
Any updates on this? Setting the initial state seems to be an important component of any viable seq2seq model.

Looks good. Please merge if no other issues.
Just a heads up: this breaks in a seq2seq model with masking on the TensorFlow backend. This is because

Something similar will happen whenever you pass a tf tensor as the initial state and either the initial state or the RNN has a mask (so it would also affect image captioning, RNN VAEs, etc.). In gratuitous detail:
@fchollet I think it would be nice to have Keras topology handle optional inputs. A first step would be to make it so that the input spec for layers with multiple inputs is not a list, but a single

@Joshua-Chin I have sent you a pull request to fix the masking issue mentioned by @AMabona.
@farizrahman4u Currently only this layer will require optional inputs, so an initial ad-hoc system is fine in this case. Later we can use what we learned from this ad-hoc implementation to write a more general system that can apply to all layers.

Could someone please review this PR?
@fchollet Should `reset_states` be renamed to `reset_state` (with backward compatibility) for consistency?

@farizrahman4u @fchollet What's the status of the review?
I think we are good. Fix the typo, though.
Looks good to me
keras/layers/recurrent.py (outdated diff):

```python
' non-Keras tensors')

if is_keras_tensor:
```
Unnecessary blank line
I have begun to review the changes made in da21c15. There appear to be a number of issues:

- If `initial_states` is a list of tensors, it will not have the attribute `_keras_history`, and will be treated as a numerical value.
- If `initial_states` is passed as a keyword argument to `call`, it will be ignored / overwritten.
- `state_spec` may not be defined by the time it is used (`state_spec` is defined in `build`).
- In `reset_states`, there is a check if `input_spec` is not `None`. However, `input_spec` is never `None`, because it is defined in `__init__` to be `InputSpec(ndim=3)`.
- The code is inconsistent between `initial_state` and `initial_states`.
- The signature of `reset_states(state_values)` is inconsistent with the signature of `set_weights(weights)`.
- `test_specify_initial_states` does not check if `initial_states` is part of the computational graph.

This commit does the following to fix those issues:

- When `initial_states` is passed to `__call__`, it is always treated as a tensor / list of tensors. If a user wants to specify the state numerically, they should pass a numpy array to `reset_states`.
- The layer is built in `__call__` before using `state_spec`.
- The check that `input_spec` is not `None` is removed.
- `initial_states` is used throughout the code and documentation.
- `reset_states(state_values)` is changed to `reset_states(states)`.
- `test_specify_states` explicitly checks if `initial_states` is added to the computational graph.

#5738
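The `reset_states(states)` contract adopted here (numpy values rather than tensors, mirroring `set_weights(weights)`) might look roughly like this sketch. The class name `RecurrentSketch` and the internal state layout are hypothetical stand-ins, not the actual Keras code.

```python
import numpy as np

class RecurrentSketch:
    """Minimal stand-in for a stateful RNN layer.

    Illustrates the reset_states(states) signature discussed in this
    PR, analogous to set_weights(weights). Not the real Keras class.
    """
    def __init__(self, batch_size, units):
        # One hidden-state array per state slot (e.g. LSTM has two).
        self.states = [np.zeros((batch_size, units))]

    def reset_states(self, states=None):
        # With no argument, zero out the states (the old behavior).
        if states is None:
            self.states = [np.zeros_like(s) for s in self.states]
            return
        if not isinstance(states, (list, tuple)):
            states = [states]
        if len(states) != len(self.states):
            raise ValueError('expected %d state arrays, got %d'
                             % (len(self.states), len(states)))
        for i, value in enumerate(states):
            if value.shape != self.states[i].shape:
                raise ValueError('state %d has shape %s, expected %s'
                                 % (i, value.shape, self.states[i].shape))
            self.states[i] = value.astype(self.states[i].dtype)
```

Passing numpy arrays here keeps the numeric path separate from the symbolic `initial_state` argument to `__call__`, which is the split the fix list above describes.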