Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 633992657
  • Loading branch information
Jake VanderPlas authored and fedjax authors committed May 15, 2024
1 parent f4c1f00 commit 1069448
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion fedjax/core/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _l2_regularize(params: Params, weight: float,
params_weights: Optional[Params]) -> float:
"""Returns L2 regularization weight."""
if center_params is not None:
params = jax.tree_map(lambda a, b: a - b, params, center_params)
params = jax.tree.map(lambda a, b: a - b, params, center_params)
leaves = jax.tree_util.tree_leaves(params)
if params_weights is not None:
pw_leaves = jax.tree_util.tree_leaves(params_weights)
Expand Down
4 changes: 2 additions & 2 deletions fedjax/core/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
@jax.jit
def tree_weight(pytree: PyTree, weight: float) -> PyTree:
"""Weights tree leaves by weight."""
return jax.tree_map(lambda l: l * weight, pytree)
return jax.tree.map(lambda l: l * weight, pytree)


def tree_inverse_weight(pytree: PyTree, weight: float) -> PyTree:
Expand All @@ -41,7 +41,7 @@ def tree_inverse_weight(pytree: PyTree, weight: float) -> PyTree:
@jax.jit
def tree_zeros_like(pytree: PyTree) -> PyTree:
"""Creates a tree with zeros with same structure as the input."""
return jax.tree_map(jnp.zeros_like, pytree)
return jax.tree.map(jnp.zeros_like, pytree)


@jax.jit
Expand Down

0 comments on commit 1069448

Please sign in to comment.