Skip to content

Commit

Permalink
[JAX] Update users of jax.tree.map() to be more careful about how the…
Browse files Browse the repository at this point in the history
…y handle Nones.

Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.

Fix user code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.

PiperOrigin-RevId: 642455356
  • Loading branch information
hawkinsp authored and OptaxDev committed Jun 12, 2024
1 parent 0287b95 commit 9547a04
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,25 @@ def tree_clip(
def tree_update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order`-th moment."""
return jtu.tree_map(
lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)
lambda g, t: (
(1 - decay) * (g**order) + decay * t if g is not None else None
),
updates,
moments,
is_leaf=lambda x: x is None,
)


def tree_update_infinity_moment(updates, moments, decay, eps):
"""Compute the exponential moving average of the infinity norm."""
return jtu.tree_map(
lambda g, t: jnp.maximum(jnp.abs(g) + eps, decay * t), updates, moments)
lambda g, t: (
jnp.maximum(jnp.abs(g) + eps, decay * t) if g is not None else g
),
updates,
moments,
is_leaf=lambda x: x is None,
)


def tree_update_moment_per_elem_norm(updates, moments, decay, order):
Expand All @@ -300,7 +312,13 @@ def orderth_norm(g):
return numerics.abs_sq(g) ** half_order

return jtu.tree_map(
lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments)
lambda g, t: (
(1 - decay) * orderth_norm(g) + decay * t if g is not None else None
),
updates,
moments,
is_leaf=lambda x: x is None,
)


@functools.partial(jax.jit, inline=True)
Expand Down

0 comments on commit 9547a04

Please sign in to comment.