Merge pull request #19930 from jakevdp:dep-tree_map
PiperOrigin-RevId: 609508069
jax authors committed Feb 22, 2024
2 parents 051ebf0 + e59a050 commit be002b5
Showing 11 changed files with 52 additions and 38 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ Remember to align the itemized text with the first line of an item within a list
* JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`.
* Added {mod}`jax.tree` module, with a more convenient interface for referencing functions
in {mod}`jax.tree_util`.
+ * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
+   compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
* {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts
`inner_treedef=None`, in which case the inner treedef will be automatically inferred.

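A quick, hedged sketch of the migration described in the changelog entry above (assuming JAX v0.4.25 or newer for the `jax.tree` module; `jax.tree_util.tree_map` works on any JAX version):

```python
# Minimal migration sketch for the jax.tree_map deprecation noted above.
# Assumes JAX v0.4.25+ for jax.tree.map; jax.tree_util.tree_map works everywhere.
import jax
import jax.numpy as jnp

tree = {"a": jnp.arange(3.0), "b": (jnp.ones(2), jnp.zeros(1))}

out_new = jax.tree.map(lambda x: x + 1, tree)               # preferred spelling
out_compat = jax.tree_util.tree_map(lambda x: x + 1, tree)  # works on older releases
# jax.tree_map(lambda x: x + 1, tree) still runs after this change, but now emits
# a DeprecationWarning pointing at the two spellings above.
```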
2 changes: 1 addition & 1 deletion benchmarks/api_benchmark.py
@@ -583,7 +583,7 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
out_axis_resources = jax.sharding.NamedSharding(mesh, spec)

f = pjit_lib.pjit(
- lambda x: jax.tree_map(lambda x: x + 1, x),
+ lambda x: jax.tree.map(lambda x: x + 1, x),
in_shardings=in_axis_resources,
out_shardings=out_axis_resources,
)
24 changes: 18 additions & 6 deletions jax/__init__.py
@@ -136,7 +136,7 @@
)

from jax._src.tree_util import (
- tree_map as tree_map,
+ tree_map as _deprecated_tree_map,
treedef_is_leaf as _deprecated_treedef_is_leaf,
tree_flatten as _deprecated_tree_flatten,
tree_leaves as _deprecated_tree_leaves,
@@ -188,32 +188,44 @@
_deprecated_treedef_is_leaf
),
"tree_flatten": (
"jax.tree_flatten is deprecated: use jax.tree.flatten.",
"jax.tree_flatten is deprecated: use jax.tree.flatten (jax v0.4.25 or newer) "
"or jax.tree_util.tree_flatten (any JAX version).",
_deprecated_tree_flatten
),
"tree_leaves": (
"jax.tree_leaves is deprecated: use jax.tree.leaves.",
"jax.tree_leaves is deprecated: use jax.tree.leaves (jax v0.4.25 or newer) "
"or jax.tree_util.tree_leaves (any JAX version).",
_deprecated_tree_leaves
),
"tree_structure": (
"jax.tree_structure is deprecated: use jax.tree.structure.",
"jax.tree_structure is deprecated: use jax.tree.structure (jax v0.4.25 or newer) "
"or jax.tree_util.tree_structure (any JAX version).",
_deprecated_tree_structure
),
"tree_transpose": (
"jax.tree_transpose is deprecated: use jax.tree.transpose.",
"jax.tree_transpose is deprecated: use jax.tree.transpose (jax v0.4.25 or newer) "
"or jax.tree_util.tree_transpose (any JAX version).",
_deprecated_tree_transpose
),
"tree_unflatten": (
"jax.tree_unflatten is deprecated: use jax.tree.unflatten.",
"jax.tree_unflatten is deprecated: use jax.tree.unflatten (jax v0.4.25 or newer) "
"or jax.tree_util.tree_unflatten (any JAX version).",
_deprecated_tree_unflatten
),
+ # Added Feb 22, 2024
+ "tree_map": (
+   "jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) "
+   "or jax.tree_util.tree_map (any JAX version).",
+   _deprecated_tree_map
+ ),
}

import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
from jax._src.tree_util import tree_flatten as tree_flatten
from jax._src.tree_util import tree_leaves as tree_leaves
+ from jax._src.tree_util import tree_map as tree_map
from jax._src.tree_util import tree_structure as tree_structure
from jax._src.tree_util import tree_transpose as tree_transpose
from jax._src.tree_util import tree_unflatten as tree_unflatten
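The `_deprecations` table above is the kind of mapping a module-level `__getattr__` (PEP 562) typically serves. The sketch below is an illustrative, self-contained version of that pattern, not JAX's actual implementation; the `old_name` entry and placeholder function are hypothetical.

```python
# Illustrative sketch (assumption): serving deprecated names from a table like
# the _deprecations dict above via a module-level __getattr__ (PEP 562).
import warnings

def _old_name_impl(*args, **kwargs):  # hypothetical stand-in for _deprecated_tree_map
    raise NotImplementedError

_deprecations = {
    "old_name": (
        "mymodule.old_name is deprecated: use mymodule.new.name instead.",
        _old_name_impl,
    ),
}

def __getattr__(name):
    if name in _deprecations:
        message, fn = _deprecations[name]
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return fn
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```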
4 changes: 2 additions & 2 deletions jax/experimental/array_serialization/serialization.py
@@ -546,7 +546,7 @@ async def _run_serializer():

def serialize_with_paths(self, arrays: Sequence[jax.Array],
paths: Sequence[str], *, on_commit_callback):
- tspecs = jax.tree_map(get_tensorstore_spec, paths)
+ tspecs = jax.tree.map(get_tensorstore_spec, paths)
self.serialize(arrays, tspecs, on_commit_callback=on_commit_callback)

def deserialize(self, shardings: Sequence[sharding.Sharding],
@@ -564,6 +564,6 @@ def deserialize_with_paths(
global_shapes: Sequence[array.Shape] | None = None,
dtypes: Sequence[typing.DTypeLike] | None = None,
concurrent_gb: int = 32):
- tspecs = jax.tree_map(get_tensorstore_spec, paths)
+ tspecs = jax.tree.map(get_tensorstore_spec, paths)
return self.deserialize(shardings, tspecs, global_shapes, dtypes,
concurrent_gb)
14 changes: 7 additions & 7 deletions jax/experimental/custom_partitioning.py
@@ -266,20 +266,20 @@ def f(*args):
def propagate_user_sharding(mesh, user_shape):
'''Update the sharding of the op from a user's shape.sharding.'''
- user_sharding = jax.tree_map(lambda x: x.sharding, user_shape)
+ user_sharding = jax.tree.map(lambda x: x.sharding, user_shape)
def partition(mesh, arg_shapes, result_shape):
def lower_fn(*args):
... builds computation on per-device shapes ...
- result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
- arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+ result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
+ arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
# result_sharding and arg_shardings may optionally be modified and the
# partitioner will insert collectives to reshape.
return mesh, lower_fn, result_sharding, arg_shardings
def infer_sharding_from_operands(mesh, arg_shapes, shape):
'''Compute the result sharding from the sharding of the operands.'''
- arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+ arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
@@ -337,14 +337,14 @@ def supported_sharding(sharding, shape):
return NamedSharding(sharding.mesh, P(*names))
def partition(mesh, arg_shapes, result_shape):
- result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
- arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+ result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
+ arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
return mesh, fft, \
supported_sharding(arg_shardings[0], arg_shapes[0]), \
(supported_sharding(arg_shardings[0], arg_shapes[0]),)
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
- arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+ arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
return supported_sharding(arg_shardings[0], arg_shapes[0])
@custom_partitioning
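The docstring above maps `lambda x: x.sharding` over pytrees of shapes; the same `jax.tree.map` call works on any pytree whose leaves expose a `.sharding` attribute. A small hedged sketch (the array names and shapes are illustrative):

```python
# Minimal sketch: jax.tree.map pulls the .sharding attribute off every leaf,
# mirroring the docstring's use over arg_shapes / result_shape above.
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((8, 8)), "b": jnp.zeros((8,))}   # illustrative pytree
shardings = jax.tree.map(lambda x: x.sharding, params)   # pytree of Sharding objects
print(shardings)
```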
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/models_test_main.py
@@ -215,7 +215,7 @@ def _format(e):
jax_result = harness.apply_with_vars(*xs)
try:
tf_result = apply_tf(*xs)
- jax.tree_map(np_assert_allclose, jax_result, tf_result)
+ jax.tree.map(np_assert_allclose, jax_result, tf_result)
print("=== Numerical comparison OK!")
except AssertionError as e:
error_msg = "Numerical comparison error:\n" + _format(e)
8 changes: 4 additions & 4 deletions jax/experimental/multihost_utils.py
@@ -40,7 +40,7 @@


def _psum(x: Any) -> Any:
- return jax.tree_map(partial(jnp.sum, axis=0), x)
+ return jax.tree.map(partial(jnp.sum, axis=0), x)


def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any:
@@ -76,10 +76,10 @@ def pre_jit(x):
def post_jit(x):
return np.asarray(x.addressable_data(0))

- in_tree = jax.tree_map(pre_jit, in_tree)
+ in_tree = jax.tree.map(pre_jit, in_tree)
out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding(
global_mesh, P()))(in_tree)
- return jax.tree_map(post_jit, out_tree)
+ return jax.tree.map(post_jit, out_tree)


def sync_global_devices(name: str):
@@ -148,7 +148,7 @@ def process_allgather(in_tree: Any, tiled: bool = False) -> Any:

def _pjit(inp):
return _handle_array_process_allgather(inp, tiled)
- return jax.tree_map(_pjit, in_tree)
+ return jax.tree.map(_pjit, in_tree)


def assert_equal(in_tree, fail_message: str = ''):
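For context, a standalone hedged sketch of what `_psum` in the diff above computes: `jax.tree.map` applies `jnp.sum(..., axis=0)` to every leaf of the input pytree (the leaf names and shapes here are illustrative):

```python
# Standalone sketch of _psum's behavior: sum each pytree leaf over its leading axis.
from functools import partial

import jax
import jax.numpy as jnp

in_tree = {"grads": jnp.ones((4, 3)), "loss": jnp.arange(4.0)}  # illustrative leaves
summed = jax.tree.map(partial(jnp.sum, axis=0), in_tree)
# summed["grads"].shape == (3,); summed["loss"].shape == ()
```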
2 changes: 1 addition & 1 deletion tests/api_test.py
@@ -523,7 +523,7 @@ def f(inp1):

def test_donate_args_info_aot(self):
def fn(x, y):
- return jax.tree_map(lambda i: i * 2, x), y * 2
+ return jax.tree.map(lambda i: i * 2, x), y * 2

x = jax.device_put({"A": np.array(1.0), "B": np.array(2.0)},
jax.devices()[0])
2 changes: 1 addition & 1 deletion tests/lax_numpy_ufuncs_test.py
@@ -74,7 +74,7 @@ def scalar_sub(x, y):
def cast_outputs(fun):
def wrapped(*args, **kwargs):
dtype = np.asarray(args[0]).dtype
- return jax.tree_map(lambda x: np.asarray(x, dtype=dtype), fun(*args, **kwargs))
+ return jax.tree.map(lambda x: np.asarray(x, dtype=dtype), fun(*args, **kwargs))
return wrapped


20 changes: 10 additions & 10 deletions tests/pjit_test.py
@@ -376,7 +376,7 @@ def testBufferDonationWithPyTreeKwargs(self):

@partial(pjit, out_shardings=s, donate_argnames='inp2')
def f(inp1, inp2, inp3):
- return jax.tree_map(lambda x, y, z: x + y + z, inp1, inp2, inp3)
+ return jax.tree.map(lambda x, y, z: x + y + z, inp1, inp2, inp3)

x = np.ones((2, 5)) * 4
x_tree = jax.device_put({"a": {"b": x}, "c": x}, s)
@@ -389,10 +389,10 @@ def f(inp1, inp2, inp3):

expected = x + y + z
out = f(x_tree, inp2=y_tree, inp3=z_tree)
- jax.tree_map(lambda o: self.assertAllClose(o, expected), out)
- jax.tree_map(self.assertNotDeleted, x_tree)
- jax.tree_map(self.assertDeleted, y_tree)
- jax.tree_map(self.assertNotDeleted, z_tree)
+ jax.tree.map(lambda o: self.assertAllClose(o, expected), out)
+ jax.tree.map(self.assertNotDeleted, x_tree)
+ jax.tree.map(self.assertDeleted, y_tree)
+ jax.tree.map(self.assertNotDeleted, z_tree)

@unittest.skipIf(xla_extension_version < 220, 'jaxlib version too old')
@jtu.run_on_devices('tpu', 'cpu', 'gpu')
@@ -422,9 +422,9 @@ def f(inp1, inp2, inp3):
z_tree = jax.device_put({'a': {'b': z}, 'c': z}, s)

out = f(x_tree, y_tree, z_tree)
- jax.tree_map(self.assertNotDeleted, x_tree)
- jax.tree_map(self.assertDeleted, y_tree)
- jax.tree_map(self.assertDeleted, z_tree)
+ jax.tree.map(self.assertNotDeleted, x_tree)
+ jax.tree.map(self.assertDeleted, y_tree)
+ jax.tree.map(self.assertDeleted, z_tree)

@unittest.skipIf(xla_extension_version < 220, 'jaxlib version too old')
@jtu.run_on_devices('tpu')
@@ -1334,7 +1334,7 @@ def test_custom_partitioner(self):
self.skip_if_custom_partitioning_not_supported()

def partition(precision, mesh, arg_shapes, result_shape):
- arg_shardings = jax.tree_map(lambda s: s.sharding, arg_shapes)
+ arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)
result_sharding = result_shape[0].sharding
self.assertEqual(arg_shardings[0], result_sharding)
self.assertEqual(P('x', None), result_sharding.spec)
@@ -1351,7 +1351,7 @@ def lower_fn(x, y):
return mesh, lower_fn, (result_sharding, result_sharding), arg_shardings

def infer_sharding_from_operands(precision, mesh, arg_shapes, result_shape):
- arg_shardings = jax.tree_map(lambda s: s.sharding, arg_shapes)
+ arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)
x_shard, y_shard = arg_shardings
x_shape, y_shape = arg_shapes
x_names = tuple(x_shard.spec) + tuple(
10 changes: 5 additions & 5 deletions tests/shard_map_test.py
@@ -1606,16 +1606,16 @@ def shmap_reference(
f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
) -> Callable:
def f_shmapped(*args):
- outs = jax.tree_map(lambda y: jnp.zeros(y.shape, y.dtype), out_types)
+ outs = jax.tree.map(lambda y: jnp.zeros(y.shape, y.dtype), out_types)
getters = [make_indexer(mesh, s, x) for s, x in zip(in_specs, args)]
- putters = jax.tree_map(partial(make_indexer, mesh), out_specs, outs)
+ putters = jax.tree.map(partial(make_indexer, mesh), out_specs, outs)
for idx in it.product(*map(range, mesh.shape.values())):
args_shards = [x[indexer(idx)] for x, indexer in zip(args, getters)]
assert all(x.shape == r.shape for x, r in zip(args_shards, body_in_types))
out_shards = f(*args_shards)
- assert jax.tree_util.tree_all(jax.tree_map(lambda y, r: y.shape == r.shape,
+ assert jax.tree_util.tree_all(jax.tree.map(lambda y, r: y.shape == r.shape,
out_shards, body_out_types))
- outs = jax.tree_map(lambda y, out, indexer: out.at[indexer(idx)].set(y),
+ outs = jax.tree.map(lambda y, out, indexer: out.at[indexer(idx)].set(y),
out_shards, outs, putters)
return outs
return f_shmapped
@@ -1662,7 +1662,7 @@ def sample_shmap() -> Chooser:
for ty in in_types]
out_reps = spec.out_rep(*map(partial(unmentioned, mesh), in_specs))
out_specs = yield from make_out_specs(mesh, body_out_types, out_reps)
- out_types = jax.tree_map(partial(dilate, mesh), out_specs, body_out_types)
+ out_types = jax.tree.map(partial(dilate, mesh), out_specs, body_out_types)
ref = partial(shmap_reference, body_in_types, body_out_types, out_types)
in_str = '(' + ','.join(jax.core.ShapedArray(t.shape, t.dtype).str_short()
for t in in_types) + ')'
