
dtype of RNN cell's state is changed to tf.float32 during reset_states #649

Closed
TillHa opened this issue Aug 13, 2021 · 3 comments


TillHa commented Aug 13, 2021

  • Have I written custom code (as opposed to using a stock example script provided in Keras): yes
  • OS Platform and Distribution: both win10 and CentOS Linux
  • TensorFlow installed from: pip
  • TensorFlow version: 2.6.0
  • Python version: 3.8
  • Exact command to reproduce: tf.keras.layers.RNN(cell)

Describe the problem.

I have implemented a recurrent cell that is wrapped in a tf.keras.layers.RNN. The cell's state has dtype tf.complex64 rather than tf.float32. However, every time layer.reset_states() is invoked, the state's dtype is changed to tf.float32. As a result, a ValueError is raised during the initial symbolic call. See the attached stack trace.

Describe the current behavior.
The program crashes while constructing the RNN layer. See the attached stack trace.

Contributing.

I suspect the cause is lines 933 and 934 of the reset_states method of class RNN in keras/layers/recurrent.py:

      flat_states_variables = tf.nest.map_structure(
          backend.variable, flat_init_state_values)

Here, the initial state values are stored in flat_init_state_values, and backend.variable is called on each of them. However, no dtype argument is passed to backend.variable, so every state falls back to the default float type, tf.float32.
I would recommend the following patch, which solves the issue for me:

      flat_states_variables = tf.nest.map_structure(
          lambda var: backend.variable(var, var.dtype), flat_init_state_values)
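A minimal NumPy analogue of the current call and the patched call may help show the difference. This is only a sketch of the dtype behavior, not the Keras implementation: `make_variable` and `DEFAULT_FLOATX` are hypothetical stand-ins for `backend.variable` and its default float type, and `map_flat` is a simplified, flat-list version of `tf.nest.map_structure`. The state shape `(5,)` mirrors the repro's `get_initial_state`.

```python
import warnings

import numpy as np

# Hypothetical stand-in for the backend's default float type (tf.float32 in the issue).
DEFAULT_FLOATX = np.float32


def make_variable(value, dtype=None):
    """Hypothetical analogue of backend.variable: copies `value`, casting it
    to DEFAULT_FLOATX when no dtype is supplied."""
    if dtype is None:
        dtype = DEFAULT_FLOATX
    return np.asarray(value).astype(dtype)


def map_flat(fn, flat_values):
    """Simplified tf.nest.map_structure for an already-flattened list."""
    return [fn(v) for v in flat_values]


# One complex64 state, same shape as the repro's get_initial_state output.
flat_init_state_values = [np.zeros((5,), dtype=np.complex64)]

with warnings.catch_warnings():
    # NumPy warns that casting complex to real discards the imaginary part.
    warnings.simplefilter("ignore")
    # Current pattern: no dtype is forwarded, so the state becomes float32.
    buggy = map_flat(make_variable, flat_init_state_values)

# Patched pattern: each state keeps its own dtype.
fixed = map_flat(lambda var: make_variable(var, var.dtype), flat_init_state_values)

print(buggy[0].dtype, fixed[0].dtype)  # float32 complex64
```

In this analogue, states that already have the default float type come out the same either way, so forwarding `var.dtype` only changes the result for states with a non-default dtype.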

I also tried running the example after replacing the files affected by the latest commit regarding mixed precision. Unfortunately, that did not solve the issue for me.

Standalone code to reproduce the issue.

Currently, the example fails at the construction of the RNN layer.

import tensorflow as tf

class RecurrentCell(tf.keras.layers.Layer):
    def __init__(self, state_size):
        super(RecurrentCell, self).__init__()
        self.state_size = state_size

    def build(self, input_shape):
        super(RecurrentCell, self).build(input_shape)

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        # explicit initialization with tf.complex64
        return tf.zeros((self.state_size, ), dtype=tf.complex64)

    @tf.function
    def call(self, inputs, states):
        # toy example
        x = inputs
        xfd = tf.signal.rfft(x)[..., :self.state_size]
        yfd = tf.multiply(xfd, states)
        return tf.signal.irfft(yfd), states


recCell = RecurrentCell(state_size=5)

inp = tf.keras.Input(shape=(None, 8),
                     batch_size=32)
out = tf.keras.layers.RNN(recCell,  # crashes
                          return_sequences=True,
                          stateful=True,
                          return_state=False)(inp)
model = tf.keras.Model(inputs=[inp], outputs=[out])

y = model.predict(tf.random.normal((32, 16, 8)))
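To illustrate why the state has to stay complex, the cell's per-step arithmetic can be mirrored in NumPy. This is a hedged sketch, not the cell itself: it uses an all-ones state (an assumption chosen so the round trip should reproduce the input) and a random batch of one time step. An 8-sample real frame has 8 // 2 + 1 = 5 rfft bins, which matches state_size=5. In TensorFlow, tf.multiply does not promote between complex64 and float32, so a state that had been turned into float32 would be rejected at the multiply step.

```python
import numpy as np

frame_len = 8    # last input dimension in the repro: tf.keras.Input(shape=(None, 8))
state_size = 5   # rfft of 8 real samples yields 8 // 2 + 1 = 5 bins

rng = np.random.default_rng(0)
x = rng.normal(size=(32, frame_len)).astype(np.float32)  # one time step, batch of 32

# All-ones complex state, same shape as the repro's get_initial_state output (5,).
state = np.ones((state_size,), dtype=np.complex64)

xfd = np.fft.rfft(x)[..., :state_size]  # complex spectrum, shape (32, 5)
yfd = xfd * state                       # per-bin complex multiply, broadcast over the batch
y = np.fft.irfft(yfd, n=frame_len)      # back to 8 real samples, shape (32, 8)

print(xfd.shape, y.shape)  # (32, 5) (32, 8)
```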

Source code / Logs
Stacktrace:
stacktrace.txt

copybara-service bot referenced this issue in keras-team/keras Jul 24, 2023
The backend.variable was assuming float32 when dtype is not provided. The RNN init state should pass the init state dtype to the backend.variable.

See https://github.com/keras-team/keras/issues/15164 for more details.

PiperOrigin-RevId: 550600164
copybara-service bot referenced this issue in keras-team/keras Jul 24, 2023
The backend.variable was assuming float32 when dtype is not provided. The RNN init state should pass the init state dtype to the backend.variable.

See https://github.com/keras-team/keras/issues/15164 for more details.

PiperOrigin-RevId: 550619673
@tilakrayal
Collaborator

@TillHa,
I ran the code above on tf-nightly (2.15.0-dev20230922) without any issues or errors. Please find the gist here. Thank you!

@tilakrayal tilakrayal assigned tilakrayal and unassigned qlzh727 Sep 22, 2023
@fchollet fchollet transferred this issue from keras-team/keras Sep 22, 2023

github-actions bot commented Oct 9, 2023

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale label Oct 9, 2023
@github-actions

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.


5 participants