
Incorrect handling of nan in in_top_k of Jax backend (regression from keras 3.3.3) #19995

Closed
michaelpradel opened this issue Jul 15, 2024 · 3 comments · Fixed by #20033
Labels
backend:jax keras-team-review-pending Pending review by a Keras team member. type:Bug

Comments

@michaelpradel

The in_top_k function in the JAX backend returns an unexpected result when the prediction for the target class is nan.

Example to reproduce the problem:

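import jax.numpy as jnp
from numpy import nan
from keras.src.backend.jax.math import in_top_k

# Row 0: the target class (index 1) has a nan score.
# Row 1: the target class (index 0) has the second-largest score, 0.3.
r = in_top_k(targets=jnp.array([1, 0]), predictions=jnp.array([[.1, nan, .5], [.3, .2, .5]]), k=2)
print(r)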

With Keras 3.3.3, I'm getting the expected outcome:
[False  True]

However, Keras 3.4.1 gives this:
[ True  True]

The new behavior is unexpected because a nan score shouldn't be considered a large probability in the prediction.
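A minimal sketch of how a rank-based in_top_k can produce this result, assuming (this is not confirmed in the thread) that the 3.4.1 JAX implementation ranks the target's score by counting strictly larger predictions in its row. `naive_in_top_k` is a hypothetical helper, not the Keras source. Because every comparison against nan evaluates to False, a nan target score has nothing counted above it and lands at rank 1:

```python
import jax.numpy as jnp
from numpy import nan


# Hypothetical helper (not the Keras source): rank each target's score by
# counting how many predictions in its row are strictly larger.
def naive_in_top_k(targets, predictions, k):
    # Score assigned to each target class, shape (batch, 1).
    preds_at_label = jnp.take_along_axis(
        predictions, jnp.expand_dims(targets, axis=-1), axis=-1
    )
    # Any comparison with nan is False, so a nan target score gets rank 1.
    rank = 1 + jnp.sum(jnp.greater(predictions, preds_at_label), axis=-1)
    return jnp.less_equal(rank, k)


r = naive_in_top_k(jnp.array([1, 0]), jnp.array([[.1, nan, .5], [.3, .2, .5]]), k=2)
print(r)  # both rows come out True, matching the 3.4.1 behavior reported above
```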

As a first debugging step: the change in behavior was introduced by #19814.
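One possible nan-safe variant, assuming a rank-based implementation that counts strictly larger predictions: replace a nan target score with -inf before ranking. `nan_safe_in_top_k` is hypothetical and is not claimed to be the change merged in #20033. On the example from this issue it yields the 3.3.3 result:

```python
import jax.numpy as jnp
from numpy import nan


# Hypothetical nan-safe variant (illustrative only): map a nan target score
# to -inf so that it ranks below every real-valued prediction in its row.
def nan_safe_in_top_k(targets, predictions, k):
    # Score assigned to each target class, shape (batch, 1).
    preds_at_label = jnp.take_along_axis(
        predictions, jnp.expand_dims(targets, axis=-1), axis=-1
    )
    # Demote nan target scores to -inf before counting larger entries.
    preds_at_label = jnp.where(jnp.isnan(preds_at_label), -jnp.inf, preds_at_label)
    rank = 1 + jnp.sum(jnp.greater(predictions, preds_at_label), axis=-1)
    return jnp.less_equal(rank, k)


r = nan_safe_in_top_k(jnp.array([1, 0]), jnp.array([[.1, nan, .5], [.3, .2, .5]]), k=2)
print(r)  # expected: False for row 0, True for row 1
```

With this choice, a nan in a non-target column never outranks the target, since `nan > x` is always False; only a nan target score is demoted.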

@sachinprasadhs
Collaborator

sachinprasadhs commented Jul 22, 2024

I was able to replicate the reported behavior with the JAX backend; I've attached the Gist for reference.

The TensorFlow backend works as expected in the latest Keras version.


@michaelpradel
Author

Thanks for fixing this so quickly!
