From a8246ea67f1a333f2f0b6b0975c6691defb24395 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 9 Jun 2024 09:17:42 -0700 Subject: [PATCH] Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs. For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case. In a future release of JAX, this behavior will become an error. PiperOrigin-RevId: 641690427 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ecad73d27bb0..e7823a3ee0aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,13 @@ Remember to align the itemized text with the first line of an item within a list (https://github.com/openxla/xla/pull/13301). * Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396). +* Deprecations + * `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will + raise an error in a future version of jax. `None` is only a tree-prefix of + itself. To preserve the current behavior, you can ask `jax.tree.map` to + treat `None` as a leaf value by writing: + `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`. + ## jax 0.4.28 (May 9, 2024) * Bug fixes