Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TypeError in contrib.reduce_on_plateau() when x64 is enabled #697

Conversation

stefanocortinovis
Copy link
Contributor

The current implementation of reduce_on_plateau raises a TypeError when it is used with 64 bit floats enabled.

The error can be reproduced by running the following MRE:

from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
from optax import contrib

updates = {"params": jnp.array(1.0)}
transform = contrib.reduce_on_plateau()
transform_state = transform.init(updates["params"])

updates, transform_state = transform.update(updates=updates, state=transform_state, loss=1.0)  # raises TypeError

This is due to the fact that, when 64 bit floats are enabled, the two branches of the jax.lax.cond at line 114 of reduce_on_plateau.py, in_cooldown and not_in_cooldown return types (int64, float32, int32) and (int32, float32, int64).

The issue can be fixed by explicitly casting the two int64 to int32.

@stefanocortinovis stefanocortinovis changed the title Fix TypeError in contrib.reduce_on_plateau() when x64 is enabled Fix TypeError in contrib.reduce_on_plateau() when x64 is enabled Jan 9, 2024
@fabianp
Copy link
Member

fabianp commented Jan 11, 2024

Thanks for the contribution! This looks good to me, but could you add a test so we're sure we don't break this in the future? Thanks again!

@stefanocortinovis
Copy link
Contributor Author

I added tests for the x64 fix as parameterised test cases of the ones that were already in reduce_on_plateau_test.py. While I was at it, I've also slightly refactored the test file. Let me know if you'd like me to do something else!

Copy link
Member

@fabianp fabianp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing, thanks!

@copybara-service copybara-service bot merged commit c0787bf into google-deepmind:master Jan 14, 2024
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants