Skip to content

Commit

Permalink
jax.vmap: improve docs & error for structured in_axes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 15, 2023
1 parent f2c89a4 commit 0bcd64a
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
28 changes: 17 additions & 11 deletions jax/_src/api.py
Expand Up @@ -1089,8 +1089,8 @@ def vmap(fun: F,
Args:
fun: Function to be mapped over additional axes.
in_axes: An integer, None, or (nested) standard Python container
(tuple/list/dict) thereof specifying which input array axes to map over.
in_axes: An integer, None, or sequence of values specifying which input
array axes to map over.
If each positional argument to ``fun`` is an array, then ``in_axes`` can
be an integer, a None, or a tuple of integers and Nones with length equal
Expand All @@ -1101,11 +1101,12 @@ def vmap(fun: F,
range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
dimensions (axes) of the corresponding input array.
If the positional arguments to ``fun`` are container (pytree) types, the
corresponding element of ``in_axes`` can itself be a matching container,
so that distinct array axes can be mapped for different container
elements. ``in_axes`` must be a container tree prefix of the positional
argument tuple passed to ``fun``. See this link for more detail:
If the positional arguments to ``fun`` are container (pytree) types, ``in_axes``
must be a sequence with length equal to the number of positional arguments to
``fun``, and for each argument the corresponding element of ``in_axes`` can
be a container with a matching pytree structure specifying the mapping of its
container elements. In other words, ``in_axes`` must be a container tree prefix
of the positional argument tuple passed to ``fun``. See this link for more detail:
https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees
Either ``axis_size`` must be provided explicitly, or at least one
Expand Down Expand Up @@ -1233,18 +1234,23 @@ def vmap(fun: F,
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axes = tuple(in_axes)

if not all(type(l) is int or type(l) in batching.spec_types
for l in tree_leaves(in_axes)):
if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}):
raise TypeError("vmap in_axes must be an int, None, or a tuple of entries corresponding "
f"to the positional arguments passed to the function, but got {in_axes}.")
if not all(type(l) in {int, *batching.spec_types} for l in tree_leaves(in_axes)):
raise TypeError("vmap in_axes must be an int, None, or (nested) container "
f"with those types as leaves, but got {in_axes}.")
if not all(type(l) is int or type(l) in batching.spec_types
for l in tree_leaves(out_axes)):
if not all(type(l) in {int, *batching.spec_types} for l in tree_leaves(out_axes)):
raise TypeError("vmap out_axes must be an int, None, or (nested) container "
f"with those types as leaves, but got {out_axes}.")

@wraps(fun, docstr=docstr)
@api_boundary
def vmap_f(*args, **kwargs):
if isinstance(in_axes, tuple) and len(in_axes) != len(args):
raise ValueError("vmap in_axes must be an int, None, or a tuple of entries corresponding "
"to the positional arguments passed to the function, "
f"but got {len(in_axes)=}, {len(args)=}")
args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
f = lu.wrap_init(fun)
flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
Expand Down
1 change: 0 additions & 1 deletion jax/_src/api_util.py
Expand Up @@ -404,7 +404,6 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
# the given treedef, build a complete axis spec tree with the same structure
# and return the flattened result
# TODO(mattjj,phawkins): improve this implementation

proxy = object()
dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
axes = []
Expand Down
20 changes: 18 additions & 2 deletions tests/api_test.py
Expand Up @@ -2981,15 +2981,31 @@ def f(dct, x, y):
out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y)
self.assertAllClose(out1, out2)

def test_vmap_in_axes_non_tuple_error(self):
# https://github.com/google/jax/issues/18548
with self.assertRaisesRegex(
TypeError,
re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding "
"to the positional arguments passed to the function, but got {'a': 0}.")):
jax.vmap(lambda x: x['a'], in_axes={'a': 0})

def test_vmap_in_axes_wrong_length_tuple_error(self):
# https://github.com/google/jax/issues/18548
with self.assertRaisesRegex(
ValueError,
re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding to the "
"positional arguments passed to the function, but got len(in_axes)=2, len(args)=1")):
jax.vmap(lambda x: x['a'], in_axes=(0, {'a': 0}))({'a': jnp.zeros((3, 3))})

def test_vmap_in_axes_tree_prefix_error(self):
# https://github.com/google/jax/issues/795
value_tree = jnp.ones(3)
self.assertRaisesRegex(
ValueError,
"vmap in_axes specification must be a tree prefix of the corresponding "
r"value, got specification \(0, 0\) for value tree "
r"value, got specification \(\[0\],\) for value tree "
+ re.escape(f"{tree_util.tree_structure((value_tree,))}."),
lambda: api.vmap(lambda x: x, in_axes=(0, 0))(value_tree)
lambda: api.vmap(lambda x: x, in_axes=([0],))(value_tree)
)

def test_vmap_in_axes_leaf_types(self):
Expand Down

0 comments on commit 0bcd64a

Please sign in to comment.