Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.
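
For context, a minimal before/after sketch of the migration (the `params` pytree below is illustrative, not taken from this commit):

    import jax
    import jax.numpy as jnp

    params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}

    # Deprecated top-level alias being replaced in this commit:
    #   zeros = jax.tree_map(jnp.zeros_like, params)
    # Long-standing equivalent in jax.tree_util:
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    # Shorter alias, available since JAX 0.4.25:
    zeros = jax.tree.map(jnp.zeros_like, params)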

PiperOrigin-RevId: 634413399
Jake VanderPlas authored and OptaxDev committed May 16, 2024
1 parent 8a3ee74 commit 702e6d2
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions optax/contrib/_sophia.py
@@ -73,8 +73,8 @@ def scale_by_sophia(
   def init_fn(params: base.Params):
     return SophiaState(
         step=jnp.array(0, dtype=jnp.int64),
-        gradient_avg=jax.tree_map(jnp.zeros_like, params),
-        hessian=jax.tree_map(jnp.zeros_like, params),
+        gradient_avg=jax.tree.map(jnp.zeros_like, params),
+        hessian=jax.tree.map(jnp.zeros_like, params),
     )

   def update_fn(
@@ -85,7 +85,7 @@ def update_fn(
     del params

     # Update exponential average of gradients.
-    gradient_avg = jax.tree_map(
+    gradient_avg = jax.tree.map(
         lambda ga, gr: ga * b1 + gr * (1 - b1),
         state.gradient_avg,
         updates,
@@ -94,15 +94,15 @@
     # Update Hessian diagonal estimate, potentially every nth step.
     hessian = jax.lax.cond(
         state.step % update_hessian_every == 0,
-        lambda: jax.tree_map(
+        lambda: jax.tree.map(
             lambda he, gr: he * b2 + gr**2 * (1 - b2),
             state.hessian,
             updates,
         ),
         lambda: state.hessian,
     )

-    updates = jax.tree_map(
+    updates = jax.tree.map(
         lambda grad_av, he: jnp.clip(
             grad_av / (rho * batch_size * he + 1e-15), -1, 1
         ),
2 changes: 1 addition & 1 deletion optax/tree_utils/_state_utils.py
@@ -128,7 +128,7 @@ def tree_map_params(
       that will be passed to f.
     transform_non_params: An optional function that will be called on all
       non-parameter fields within the optimizer state.
-    is_leaf: Passed through to `jax.tree_map`. This makes it possible to ignore
+    is_leaf: Passed through to `jax.tree.map`. This makes it possible to ignore
       parts of the parameter tree e.g. when the gradient transformations modify
       the shape of the original pytree, such as for ``optax.masked``.
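
The `is_leaf` argument mentioned above controls where `jax.tree.map` stops descending into the pytree. A small sketch under assumed inputs (the tree and predicate are illustrative, not part of optax):

    import jax

    tree = {"w": 1.0, "frozen": None}

    # Treating None as a leaf means the mapped function receives it directly,
    # instead of jax.tree.map descending into it as an empty subtree.
    out = jax.tree.map(
        lambda x: x if x is None else x * 2,
        tree,
        is_leaf=lambda x: x is None,
    )
    # out == {"w": 2.0, "frozen": None}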
