diff --git a/jax/_src/api.py b/jax/_src/api.py index 5fb5b636a275..a4e30643b5ea 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1089,8 +1089,8 @@ def vmap(fun: F, Args: fun: Function to be mapped over additional axes. - in_axes: An integer, None, or (nested) standard Python container - (tuple/list/dict) thereof specifying which input array axes to map over. + in_axes: An integer, None, or sequence of values specifying which input + array axes to map over. If each positional argument to ``fun`` is an array, then ``in_axes`` can be an integer, a None, or a tuple of integers and Nones with length equal @@ -1101,11 +1101,12 @@ def vmap(fun: F, range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of dimensions (axes) of the corresponding input array. - If the positional arguments to ``fun`` are container (pytree) types, the - corresponding element of ``in_axes`` can itself be a matching container, - so that distinct array axes can be mapped for different container - elements. ``in_axes`` must be a container tree prefix of the positional - argument tuple passed to ``fun``. See this link for more detail: + If the positional arguments to ``fun`` are container (pytree) types, ``in_axes`` + must be a sequence with length equal to the number of positional arguments to + ``fun``, and for each argument the corresponding element of ``in_axes`` can + be a container with a matching pytree structure specifying the mapping of its + container elements. In other words, ``in_axes`` must be a container tree prefix + of the positional argument tuple passed to ``fun``. See this link for more detail: https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees Either ``axis_size`` must be provided explicitly, or at least one @@ -1233,18 +1234,23 @@ def vmap(fun: F, # rather than raising an error. https://github.com/google/jax/issues/2367 in_axes = tuple(in_axes) - if not all(type(l) is int or type(l) in batching.spec_types - for l in tree_leaves(in_axes)): + if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}): + raise TypeError("vmap in_axes must be an int, None, or a tuple of entries corresponding " + f"to the positional arguments passed to the function, but got {in_axes}.") + if not all(type(l) in {int, *batching.spec_types} for l in tree_leaves(in_axes)): raise TypeError("vmap in_axes must be an int, None, or (nested) container " f"with those types as leaves, but got {in_axes}.") - if not all(type(l) is int or type(l) in batching.spec_types - for l in tree_leaves(out_axes)): + if not all(type(l) in {int, *batching.spec_types} for l in tree_leaves(out_axes)): raise TypeError("vmap out_axes must be an int, None, or (nested) container " f"with those types as leaves, but got {out_axes}.") @wraps(fun, docstr=docstr) @api_boundary def vmap_f(*args, **kwargs): + if isinstance(in_axes, tuple) and len(in_axes) != len(args): + raise ValueError("vmap in_axes must be an int, None, or a tuple of entries corresponding " + "to the positional arguments passed to the function, " + f"but got {len(in_axes)=}, {len(args)=}") args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable) f = lu.wrap_init(fun) flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 9c098c99b80b..5f5e631005a9 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -404,7 +404,6 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False): # the given treedef, build a complete axis spec tree with the same structure # and return the flattened result # TODO(mattjj,phawkins): improve this implementation - proxy = object() dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves) axes = [] diff --git a/tests/api_test.py b/tests/api_test.py index 5bf712d9f9ec..6fb47b568f39 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2981,15 +2981,31 @@ def f(dct, x, y): out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y) self.assertAllClose(out1, out2) + def test_vmap_in_axes_non_tuple_error(self): + # https://github.com/google/jax/issues/18548 + with self.assertRaisesRegex( + TypeError, + re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding " + "to the positional arguments passed to the function, but got {'a': 0}.")): + jax.vmap(lambda x: x['a'], in_axes={'a': 0}) + + def test_vmap_in_axes_wrong_length_tuple_error(self): + # https://github.com/google/jax/issues/18548 + with self.assertRaisesRegex( + ValueError, + re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding to the " + "positional arguments passed to the function, but got len(in_axes)=2, len(args)=1")): + jax.vmap(lambda x: x['a'], in_axes=(0, {'a': 0}))({'a': jnp.zeros((3, 3))}) + def test_vmap_in_axes_tree_prefix_error(self): # https://github.com/google/jax/issues/795 value_tree = jnp.ones(3) self.assertRaisesRegex( ValueError, "vmap in_axes specification must be a tree prefix of the corresponding " - r"value, got specification \(0, 0\) for value tree " + r"value, got specification \(\[0\],\) for value tree " + re.escape(f"{tree_util.tree_structure((value_tree,))}."), - lambda: api.vmap(lambda x: x, in_axes=(0, 0))(value_tree) + lambda: api.vmap(lambda x: x, in_axes=([0],))(value_tree) ) def test_vmap_in_axes_leaf_types(self):