Skip to content

Commit

Permalink
improve error when f_vjp gets more than one argument
Browse files Browse the repository at this point in the history
fixes #13099
  • Loading branch information
mattjj committed Nov 3, 2022
1 parent dba9fc0 commit 4033007
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
31 changes: 26 additions & 5 deletions jax/_src/api.py
Expand Up @@ -2531,8 +2531,30 @@ def fun(*tangents):

return apply_flat_fun(fun, io_tree, *py_args)

def _vjp_pullback_wrapper(cotangent_dtypes, cotangent_shapes,
io_tree, fun, py_args):
def _vjp_pullback_wrapper(name, cotangent_dtypes, cotangent_shapes, io_tree,
fun, *py_args_):
if len(py_args_) != 1:
msg = (f"The function returned by `jax.vjp` applied to {name} was called "
f"with {len(py_args_)} arguments, but functions returned by "
"`jax.vjp` must be called with a single argument corresponding to "
f"the single value returned by {name} (even if that returned "
"value is a tuple or other container).\n"
"\n"
"For example, if we have:\n"
"\n"
" def f(x):\n"
" return (x, x)\n"
" _, f_vjp = jax.vjp(f, 1.0)\n"
"\n"
"the function `f` returns a single tuple as output, and so we call "
"`f_vjp` with a single tuple as its argument:\n"
"\n"
" x_bar, = f_vjp((2.0, 2.0))\n"
"\n"
"If we instead call `f_vjp(2.0, 2.0)`, with the values 'splatted "
"out' as arguments rather than in a tuple, this error can arise.")
raise TypeError(msg)
py_args, = py_args_
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten(py_args)
if in_tree != in_tree_expected:
Expand Down Expand Up @@ -2637,9 +2659,8 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()):
ct_shapes = [np.shape(x) for x in out_primal]
# Ensure that vjp_py is a PyTree so that we can pass it from the forward to the
# backward pass in a custom VJP.
vjp_py = Partial(partial(_vjp_pullback_wrapper,
ct_dtypes, ct_shapes,
(out_tree, in_tree)),
vjp_py = Partial(partial(_vjp_pullback_wrapper, fun.__name__,
ct_dtypes, ct_shapes, (out_tree, in_tree)),
out_vjp)
if not has_aux:
return out_primal_py, vjp_py
Expand Down
8 changes: 8 additions & 0 deletions tests/api_test.py
Expand Up @@ -4095,6 +4095,14 @@ def h(x):

b(8) # don't crash

def test_vjp_multiple_arguments_error_message(self):
# https://github.com/google/jax/issues/13099
def foo(x):
return (x, x)
_, f_vjp = jax.vjp(foo, 1.0)
with self.assertRaisesRegex(TypeError, "applied to foo"):
f_vjp(1.0, 1.0)


@jtu.with_config(jax_experimental_subjaxpr_lowering_cache=True)
class SubcallTraceCacheTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 4033007

Please sign in to comment.