Skip to content

Commit

Permalink
fix checkify + custom_vjp after symbolic zeros change
Browse files Browse the repository at this point in the history
Co-authored-by: Lena Martens <lenamartens@google.com>
  • Loading branch information
mattjj and LenaMartens committed Jun 8, 2023
1 parent 01ed663 commit 01fa7e0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
18 changes: 10 additions & 8 deletions jax/_src/checkify.py
Expand Up @@ -972,31 +972,33 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
fwd_jaxpr_thunk, num_consts, bwd, out_trees,
symbolic_zeros):
err_vals, err_tree = jtu.tree_flatten(in_err)
fun = lu.wrap_init(
num_errs = err_tree.num_leaves
checkified_fun = lu.wrap_init(
functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
fun_jaxpr.consts, enabled_errors, err_tree))
fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun)
checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk(
checkified_fun)

@lu.wrap_init
def fwd(*args):
def checkified_fwd(*args):
# TODO(lenamartens, sharadmv): why not checkify here?
xs, zeros = args[::2], args[1::2]
xs, zeros = xs[num_errs:], zeros[num_errs:]
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)
xs_without_consts = xs[num_consts:]
return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)

fwd, fwd_out_tree = flatten_fun_output(fwd)
bwd_ = lambda *args: (*(None,)*num_errs, *bwd(*args))
checkified_fwd, fwd_out_tree = flatten_fun_output(checkified_fwd)
all_outs = custom_derivatives.custom_vjp_call_p.bind(
fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees,
checkified_fun, checkified_fwd, bwd_, *err_vals, *in_vals, out_trees=out_trees,
symbolic_zeros=symbolic_zeros)
fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
if fst:
err_and_out_tree, _ = out_metadata
out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
else:
err_vals, out_vals = split_list(all_outs, [len(err_vals)])
# forward input error to output
out_err = jtu.tree_unflatten(err_tree, err_vals)
out_err, out_vals = in_err, all_outs
return out_err, out_vals
error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule

Expand Down
23 changes: 23 additions & 0 deletions tests/checkify_test.py
Expand Up @@ -831,6 +831,29 @@ def h_out(fext):
h_grad = jax.grad(h_out)
h_grad(0.) # doesn't crash

def test_goodfellow_custom_vjp(self):
@jax.custom_vjp
def sin(x):
return jnp.sin(x)
def sin_fwd(x):
return jnp.sin(x), 2. * x
def sin_bwd(x2, g):
return jnp.cos(x2 / 2.) * g,
sin.defvjp(sin_fwd, sin_bwd)

def h(fext):
checkify.check(True, "")
return sin(fext)

h = checkify.checkify(h)

def h_out(fext):
_, out = h(fext)
return out

h_grad = jax.grad(h_out)
h_grad(0.) # doesn't crash

def test_closed_call(self):
# lots of golfing went into this test
y = jnp.array([3.14])
Expand Down

0 comments on commit 01fa7e0

Please sign in to comment.