Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix usage of jax.tree.map in gradient descent loop #1809

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

ocadni
Copy link

@ocadni ocadni commented May 17, 2024

This pull request fixes an issue in the gradient descent loop where jax.tree.map was incorrectly used instead of jax.tree_util.tree_map. The incorrect usage led to errors during parameter updates. The corrected code ensures proper parameter updates by correctly applying the gradient descent step using jax.tree_util.tree_map.

Changes Made:

  • Replaced jax.tree.map with jax.tree_util.tree_map to correctly apply the gradient descent step on every leaf of the dictionaries containing the set of parameters.

Code Changes:

# Original incorrect code
new_pars = jax.tree.map(lambda x, y: x - 0.05 * y, vstate.parameters, E_grad)

# Updated correct code
new_pars = jax.tree_util.tree_map(lambda x, y: x - 0.05 * y, vstate.parameters, E_grad)

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@PhilipVinc
Copy link
Member

OHHHHH! Indaco is in 'da town! Am-az-ing! Welcome 'da family!

@PhilipVinc
Copy link
Member

But in fact, jax is shifting away from jax.tree_util.tree_map towards jax.tree.map. Of course, jax.tree.map is only present on more recent versions, where jax.tree_map was deprecated.

They plan to move all the 'user facing' components of jax.tree_util.* to jax.tree.* in the future.

I suppose you have an 'older' (>2 months) version of jax installed.

NetKet internally uses jax.tree_util.tree_map to work with both old and new versions of jax, but when I updated the tutorials to remove jax.tree_map I had decided to use the 'new' jax.tree.map.

I'm fine using jax.tree_util.tree_map for now... Though at some point netket will require a more recent jax version and then we could go back to jax.tree.mapon the tutorials as well...

@ocadni
Copy link
Author

ocadni commented May 17, 2024

But in fact, jax is shifting away from jax.tree_util.tree_map towards jax.tree.map. Of course, jax.tree.map is only present on more recent versions, where jax.tree_map was deprecated.

They plan to move all the 'user facing' components of jax.tree_util.* to jax.tree.* in the future.

I suppose you have an 'older' (>2 months) version of jax installed.

NetKet internally uses jax.tree_util.tree_map to work with both old and new versions of jax, but when I updated the tutorials to remove jax.tree_map I had decided to use the 'new' jax.tree.map.

I'm fine using jax.tree_util.tree_map for now... Though at some point netket will require a more recent jax version and then we could go back to jax.tree.mapon the tutorials as well...

Ah ok!
I tried running the code using Google Colab and encountered the same issue, which seems to be related to the JAX version installed there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants