diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 3d66b8bb6fcb..9c083f69a31c 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -972,31 +972,33 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr, fwd_jaxpr_thunk, num_consts, bwd, out_trees, symbolic_zeros): err_vals, err_tree = jtu.tree_flatten(in_err) - fun = lu.wrap_init( + num_errs = err_tree.num_leaves + checkified_fun = lu.wrap_init( functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr, fun_jaxpr.consts, enabled_errors, err_tree)) - fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun) + checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk( + checkified_fun) @lu.wrap_init - def fwd(*args): + def checkified_fwd(*args): # TODO(lenamartens, sharadmv): why not checkify here? xs, zeros = args[::2], args[1::2] + xs, zeros = xs[num_errs:], zeros[num_errs:] fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) xs_without_consts = xs[num_consts:] return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts) - fwd, fwd_out_tree = flatten_fun_output(fwd) + bwd_ = lambda *args: (*(None,)*num_errs, *bwd(*args)) + checkified_fwd, fwd_out_tree = flatten_fun_output(checkified_fwd) all_outs = custom_derivatives.custom_vjp_call_p.bind( - fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees, + checkified_fun, checkified_fwd, bwd_, *err_vals, *in_vals, out_trees=out_trees, symbolic_zeros=symbolic_zeros) fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree) if fst: err_and_out_tree, _ = out_metadata out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs) else: - err_vals, out_vals = split_list(all_outs, [len(err_vals)]) - # forward input error to output - out_err = jtu.tree_unflatten(err_tree, err_vals) + out_err, out_vals = in_err, all_outs return out_err, out_vals error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule diff --git a/tests/checkify_test.py b/tests/checkify_test.py index f939c1e9f001..14c28c545dc5 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -831,6 +831,29 @@ def h_out(fext): h_grad = jax.grad(h_out) h_grad(0.) # doesn't crash + def test_goodfellow_custom_vjp(self): + @jax.custom_vjp + def sin(x): + return jnp.sin(x) + def sin_fwd(x): + return jnp.sin(x), 2. * x + def sin_bwd(x2, g): + return jnp.cos(x2 / 2.) * g, + sin.defvjp(sin_fwd, sin_bwd) + + def h(fext): + checkify.check(True, "") + return sin(fext) + + h = checkify.checkify(h) + + def h_out(fext): + _, out = h(fext) + return out + + h_grad = jax.grad(h_out) + h_grad(0.) # doesn't crash + def test_closed_call(self): # lots of golfing went into this test y = jnp.array([3.14])