
Dropout layer doesn't work with jax backend when using jit + StatelessScope in training mode #248

Closed
tirthasheshpatel opened this issue Jun 2, 2023 · 0 comments · Fixed by #249

@tirthasheshpatel · Contributor

The `Dropout` layer raises `'list' object has no attribute 'dtype'` when it is called with `training=True` inside a `jax.jit`-compiled function with `StatelessScope` enabled. It looks like the seed's state value isn't being initialized properly.

Here's a minimal reproducible example (MRE):

```python
import keras_core
from keras_core import backend
from keras_core.operations import numpy as knp
from jax import numpy as jnp
import jax

x = knp.array([0.1, 0.2, 0.3])

@jax.jit
def train_step(x):
    # Call the layer statelessly inside a jitted function, in training mode.
    with backend.StatelessScope():
        x = keras_core.layers.Dropout(rate=0.1)(x, training=True)
    return x

# Show the full, unfiltered traceback.
keras_core.utils.traceback_utils.disable_traceback_filtering()
x = train_step(x)  # fails with: 'list' object has no attribute 'dtype'
assert isinstance(x, jnp.ndarray)
```
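For comparison, here is a minimal pure-JAX sketch of the pattern a stateless, jit-compatible dropout relies on: the random state is an explicit JAX array (a PRNG key) passed into the compiled function, rather than a value stored on the layer. It does not use Keras Core and is not the fix from #249. `dropout` and `train_step` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

def dropout(x, key, rate):
    """Hypothetical helper (not Keras Core's implementation): inverted dropout.

    Each element is zeroed with probability `rate`; survivors are scaled by
    1 / (1 - rate) so the expected value of the output matches the input.
    """
    keep_prob = 1.0 - rate
    # Boolean mask where each entry is True with probability keep_prob.
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    return jnp.where(mask, x / keep_prob, jnp.zeros_like(x))

@jax.jit
def train_step(x, key):
    # The PRNG key is an explicit array argument, so the jitted function
    # carries no hidden state and traces cleanly.
    return dropout(x, key, rate=0.1)

x = jnp.array([0.1, 0.2, 0.3])
key = jax.random.PRNGKey(0)
y = train_step(x, key)
```

Reusing the same key reproduces the same mask; to get a fresh mask on each step, derive a new key first, e.g. with `jax.random.split(key)`.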