diff --git a/jax/_src/api.py b/jax/_src/api.py index c59ae055c695..51f1ae81ebf2 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2531,8 +2531,30 @@ def fun(*tangents): return apply_flat_fun(fun, io_tree, *py_args) -def _vjp_pullback_wrapper(cotangent_dtypes, cotangent_shapes, - io_tree, fun, py_args): +def _vjp_pullback_wrapper(name, cotangent_dtypes, cotangent_shapes, io_tree, + fun, *py_args_): + if len(py_args_) != 1: + msg = (f"The function returned by `jax.vjp` applied to {name} was called " + f"with {len(py_args_)} arguments, but functions returned by " + "`jax.vjp` must be called with a single argument corresponding to " + f"the single value returned by {name} (even if that returned " + "value is a tuple or other container).\n" + "\n" + "For example, if we have:\n" + "\n" + " def f(x):\n" + " return (x, x)\n" + " _, f_vjp = jax.vjp(f, 1.0)\n" + "\n" + "the function `f` returns a single tuple as output, and so we call " + "`f_vjp` with a single tuple as its argument:\n" + "\n" + " x_bar, = f_vjp((2.0, 2.0))\n" + "\n" + "If we instead call `f_vjp(2.0, 2.0)`, with the values 'splatted " + "out' as arguments rather than in a tuple, this error can arise.") + raise TypeError(msg) + py_args, = py_args_ in_tree_expected, out_tree = io_tree args, in_tree = tree_flatten(py_args) if in_tree != in_tree_expected: @@ -2637,9 +2659,8 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()): ct_shapes = [np.shape(x) for x in out_primal] # Ensure that vjp_py is a PyTree so that we can pass it from the forward to the # backward pass in a custom VJP. - vjp_py = Partial(partial(_vjp_pullback_wrapper, - ct_dtypes, ct_shapes, - (out_tree, in_tree)), + vjp_py = Partial(partial(_vjp_pullback_wrapper, fun.__name__, + ct_dtypes, ct_shapes, (out_tree, in_tree)), out_vjp) if not has_aux: return out_primal_py, vjp_py diff --git a/tests/api_test.py b/tests/api_test.py index 3fc7ad6a0973..1dc048d83a44 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4095,6 +4095,14 @@ def h(x): b(8) # don't crash + def test_vjp_multiple_arguments_error_message(self): + # https://github.com/google/jax/issues/13099 + def foo(x): + return (x, x) + _, f_vjp = jax.vjp(foo, 1.0) + with self.assertRaisesRegex(TypeError, "applied to foo"): + f_vjp(1.0, 1.0) + @jtu.with_config(jax_experimental_subjaxpr_lowering_cache=True) class SubcallTraceCacheTest(jtu.JaxTestCase):