Skip to content

Commit

Permalink
Fix jax.tree_map deprecation warnings.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622995953
  • Loading branch information
OptaxDev committed Apr 9, 2024
1 parent aea1d5f commit eff8b6d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optax/tree_utils/_state_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,13 @@ def tree_map_params(

def map_params(maybe_placeholder_value, value):
if isinstance(maybe_placeholder_value, _ParamsPlaceholder):
return jax.tree_map(f, value, *rest, is_leaf=is_leaf)
return jax.tree_util.tree_map(f, value, *rest, is_leaf=is_leaf)
elif transform_non_params is not None:
return transform_non_params(value)
else:
return value

return jax.tree_map(
return jax.tree_util.tree_map(
map_params,
state_with_placeholders,
state,
Expand Down

0 comments on commit eff8b6d

Please sign in to comment.