dtype of RNN cell's state is changed to tf.float32 during reset_states #649
copybara-service bot referenced this issue in keras-team/keras on Jul 24, 2023:
The backend.variable was assuming float32 when dtype is not provided. The RNN init state should pass the init state dtype to the backend.variable. See https://github.com/keras-team/keras/issues/15164 for more details. PiperOrigin-RevId: 550600164
copybara-service bot referenced this issue in keras-team/keras on Jul 24, 2023:
The backend.variable was assuming float32 when dtype is not provided. The RNN init state should pass the init state dtype to the backend.variable. See https://github.com/keras-team/keras/issues/15164 for more details. PiperOrigin-RevId: 550619673
tilakrayal added the stat:awaiting response from contributor label and removed the stat:awaiting keras-eng label on Sep 22, 2023.
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.
provided in Keras): yes
Describe the problem.
I have implemented a recurrent cell that is meant to be wrapped in a `tf.keras.layers.RNN`. The cell has a state whose data type is not `tf.float32` but `tf.complex64`. However, each time `layer.reset_states()` is invoked, the data type of the state is changed to `tf.float32`. As a result, a `ValueError` is thrown during the initial symbolic call. See the attached stack trace.
Describe the current behavior.
The program crashes during construction of the RNN layer. See the attached stack trace.
Contributing.
I assume one reason for this issue is lines 933 and 934 of the function `reset_states` in class `RNN` in the file `keras/layers/recurrent.py`. Here, the initialized state values are stored in `flat_init_state_values`, and `backend.variable` is called on each of the states. However, no `dtype` argument is passed to `backend.variable`. As a consequence, it defaults to `tf.float32` for all states. I would recommend the following patch, which solves the issue for me.
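To make the mechanism concrete, here is a toy, pure-Python model of the failure; it is not Keras code. `make_state_variable`, `reset_states_buggy` and `reset_states_fixed` are hypothetical names, and Python's `float` and `complex` stand in for `tf.float32` and `tf.complex64`. The buggy variant mirrors the behaviour described above (no dtype is forwarded, so a real default is applied), while the fixed variant forwards the dtype of the initial state values. In `tf.keras` terms this roughly corresponds to calling something like `backend.variable(v, dtype=v.dtype)` for each flattened initial state; that is a reading of the commit message quoted earlier, not the actual patch.

```python
from typing import Callable, List, Optional, Sequence, Union

Number = Union[float, complex]


def make_state_variable(
    values: Sequence[Number],
    dtype: Optional[Callable[[Number], Number]] = None,
) -> List[Number]:
    """Hypothetical stand-in for backend.variable (not the Keras API).

    When no dtype is given it falls back to a real default (Python's float,
    standing in for tf.float32), which is the behaviour blamed in this issue.
    """
    cast = float if dtype is None else dtype
    return [cast(v) for v in values]


def reset_states_buggy(init_values: Sequence[Number]) -> List[Number]:
    # No dtype is forwarded, so every state gets the real default.
    return make_state_variable(init_values)


def reset_states_fixed(init_values: Sequence[Number]) -> List[Number]:
    # Forward the dtype of the initial state values themselves
    # (assumes all values in one state share a type).
    state_dtype = type(init_values[0])
    return make_state_variable(init_values, dtype=state_dtype)


if __name__ == "__main__":
    complex_zeros = [0j, 0j]  # zero-filled complex initial state

    print(reset_states_fixed(complex_zeros))  # prints [0j, 0j]
    try:
        reset_states_buggy(complex_zeros)
    except TypeError as err:
        print("buggy version failed:", err)
```

In this toy model the mismatch surfaces immediately as a `TypeError`, because Python refuses to convert a complex number to `float`; in the issue, the wrong state dtype surfaced as a `ValueError` during the initial symbolic call.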
I also tried running the example after replacing the files affected by the latest commit regarding mixed precision. Unfortunately, this did not solve the issue for me.
Standalone code to reproduce the issue.
Currently, the example fails at the construction of the RNN layer.
Source code / Logs
Stacktrace:
stacktrace.txt