diff --git a/optax/contrib/_sophia.py b/optax/contrib/_sophia.py
index 1707b296..8a1edf96 100644
--- a/optax/contrib/_sophia.py
+++ b/optax/contrib/_sophia.py
@@ -73,8 +73,8 @@ def scale_by_sophia(
   def init_fn(params: base.Params):
     return SophiaState(
         step=jnp.array(0, dtype=jnp.int64),
-        gradient_avg=jax.tree_map(jnp.zeros_like, params),
-        hessian=jax.tree_map(jnp.zeros_like, params),
+        gradient_avg=jax.tree.map(jnp.zeros_like, params),
+        hessian=jax.tree.map(jnp.zeros_like, params),
     )

   def update_fn(
@@ -85,7 +85,7 @@ def update_fn(
     del params

     # Update exponential average of gradients.
-    gradient_avg = jax.tree_map(
+    gradient_avg = jax.tree.map(
         lambda ga, gr: ga * b1 + gr * (1 - b1),
         state.gradient_avg,
         updates,
@@ -94,7 +94,7 @@ def update_fn(
     # Update Hessian diagonal estimate, potentially every nth step.
     hessian = jax.lax.cond(
         state.step % update_hessian_every == 0,
-        lambda: jax.tree_map(
+        lambda: jax.tree.map(
             lambda he, gr: he * b2 + gr**2 * (1 - b2),
             state.hessian,
             updates,
@@ -102,7 +102,7 @@ def update_fn(
         lambda: state.hessian,
     )

-    updates = jax.tree_map(
+    updates = jax.tree.map(
         lambda grad_av, he: jnp.clip(
             grad_av / (rho * batch_size * he + 1e-15), -1, 1
         ),
diff --git a/optax/tree_utils/_state_utils.py b/optax/tree_utils/_state_utils.py
index b9e7e437..0bf42d0e 100644
--- a/optax/tree_utils/_state_utils.py
+++ b/optax/tree_utils/_state_utils.py
@@ -128,7 +128,7 @@ def tree_map_params(
       that will be passed to f.
     transform_non_params: An optional function that will be called on all
       non-parameter fields within the optimizer state.
-    is_leaf: Passed through to `jax.tree_map`. This makes it possible to ignore
+    is_leaf: Passed through to `jax.tree.map`. This makes it possible to ignore
       parts of the parameter tree e.g. when the gradient transformations modify
       the shape of the original pytree, such as for ``optax.masked``.
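
Note: the substantive change is the migration from the deprecated `jax.tree_map`
alias to `jax.tree.map`; the Sophia update logic itself is unchanged. A minimal
sketch of the leaf-wise mapping pattern used above (the pytree values and the
coefficient `b1` here are hypothetical, chosen only for illustration):

    import jax
    import jax.numpy as jnp

    b1 = 0.9  # hypothetical EMA coefficient, mirroring `b1` in the diff

    # A small pytree standing in for `params` / the optimizer state leaves.
    params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros((3,))}

    # `jax.tree.map` applies a function leaf-wise over one or more pytrees;
    # it is the current name for the deprecated top-level `jax.tree_map`.
    grad_avg = jax.tree.map(jnp.zeros_like, params)
    grads = jax.tree.map(jnp.ones_like, params)  # stand-in gradients

    # Multi-tree mapping, mirroring the exponential moving average update
    # of the gradients in `update_fn`.
    grad_avg = jax.tree.map(
        lambda ga, gr: ga * b1 + gr * (1 - b1), grad_avg, grads
    )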