Skip to content

Commit

Permalink
Issue a warning where code relies on a bug where treedef.flatten_up_t…
Browse files Browse the repository at this point in the history
…o(...) was overly permissive for None treedefs.

For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.

In a future release of JAX, this behavior will become an error.

PiperOrigin-RevId: 641690427
  • Loading branch information
hawkinsp authored and jax authors committed Jun 9, 2024
1 parent 14d87d3 commit a8246ea
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ Remember to align the itemized text with the first line of an item within a list
(https://github.com/openxla/xla/pull/13301).
* Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396).

* Deprecations
* `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will
raise an error in a future version of jax. `None` is only a tree-prefix of
itself. To preserve the current behavior, you can ask `jax.tree.map` to
treat `None` as a leaf value by writing:
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.

## jax 0.4.28 (May 9, 2024)

* Bug fixes
Expand Down

0 comments on commit a8246ea

Please sign in to comment.