Skip to content

Commit

Permalink
[custom_vjp] bwd function should not be WrappedFun, may run multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 2, 2023
1 parent abc6c9b commit bf07395
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 10 deletions.
8 changes: 4 additions & 4 deletions jax/_src/custom_derivatives.py
Expand Up @@ -560,7 +560,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree,
out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
Expand Down Expand Up @@ -680,7 +680,7 @@ def bind(self, fun, fwd, bwd, *args, out_trees):
fwd, env_trace_todo2 = process_env_traces_fwd(
fwd, top_trace and top_trace.level, out_trees)
tracers = map(top_trace.full_raise, args) # type: ignore
bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
bwd_ = lambda *args: bwd(*args)
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
out_trees=out_trees)
fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
Expand Down Expand Up @@ -749,7 +749,7 @@ def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
def _custom_vjp_call_jaxpr_jvp(
primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
bwd: Callable, out_trees: Callable, num_consts: int):
_, args = split_list(primals, [num_consts])
consts_dot, args_dot = split_list(tangents, [num_consts])
if any(type(t) is not Zero for t in consts_dot):
Expand All @@ -772,7 +772,7 @@ def _custom_vjp_call_jaxpr_jvp(
def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
bwd: Callable, out_trees: Callable, num_consts: int):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/ad.py
Expand Up @@ -748,7 +748,7 @@ def raise_custom_vjp_error_on_jvp(*_, **__):
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
  """Transpose rule for custom_lin_p: run the user's bwd on residuals+cotangents.

  Args:
    cts_out: output cotangents to transpose back through the custom vjp.
    *invals: inputs to the linear primitive; the first ``num_res`` entries are
      residuals saved by the fwd pass, the rest are the linearized inputs.
    num_res: number of leading residual values in ``invals``.
    bwd: the user's backward function — a plain callable, not an
      ``lu.WrappedFun``, so it can safely be invoked more than once.
    out_avals: abstract values used to instantiate symbolic zeros in
      ``cts_out`` before handing them to user code.

  Returns:
    ``[None] * num_res`` (residuals get no cotangent) followed by the input
    cotangents produced by ``bwd``.
  """
  res, _ = split_list(invals, [num_res])
  # User code must never see symbolic Zeros; instantiate them as real arrays.
  cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
  # NOTE: the stripped diff showed both the old `bwd.call_wrapped(...)` line
  # and this one; only the plain call survives — calling twice would discard
  # the first result and double the work.
  cts_in = bwd(*res, *cts_out)
  return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose

Expand Down
15 changes: 10 additions & 5 deletions jax/_src/interpreters/batching.py
Expand Up @@ -776,11 +776,16 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
out_tangent_bds, out_dims, out_tangents)
yield out_primals + out_tangents, out_dims * 2

def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests,
                         main_type, spmd_axis_name):
  """Wrap a custom_vjp bwd function so it applies correctly under vmap.

  Returns a plain callable rather than an ``lu.WrappedFun``: a WrappedFun is
  single-use, but a vjp function produced under vmap may be called multiple
  times, so the batching transformation stack is rebuilt freshly inside
  ``new_bwd`` on every invocation.

  Args:
    bwd: the user's backward function (plain callable).
    axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name:
      batching parameters threaded through to the batching transformations.

  Returns:
    A callable with the same signature as ``bwd``, batched along the given
    axis, safe to call repeatedly.
  """
  def new_bwd(*args):
    # Re-apply the transformations per call; the WrappedFun built here is
    # consumed by this single call_wrapped and must not be cached.
    bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd))
    bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims, main_type,
                        spmd_axis_name)
    bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk,
                               out_dim_dests)
    return bwd_.call_wrapped(*args)
  return new_bwd

@lu.transformation
def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/tree_util.py
Expand Up @@ -308,6 +308,9 @@ def __eq__(self, other):
return self.fun == other.fun
return self.fun == other

def __repr__(self):
  # Delegate to the wrapped callable's repr so shimmed functions stay
  # identifiable in debug output.
  return f'_HashableCallableShim({self.fun!r})'


class Partial(functools.partial):
"""A version of functools.partial that works in pytrees.
Expand Down
12 changes: 12 additions & 0 deletions tests/api_test.py
Expand Up @@ -8385,6 +8385,18 @@ def f_bwd(_, g):

jax.grad(f)(A([1.])) # doesn't crash

def test_vmap_vjp_called_twice(self):
# https://github.com/google/jax/pull/14728
@jax.custom_vjp
def f(x):
return x
f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,))

_, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.]))
f_vjp(jnp.array([3.]))
f_vjp(jnp.array([3.])) # doesn't crash


def transpose_unary(f, x_example):
def transposed(y):
x, = api.linear_transpose(f, x_example)(y)
Expand Down

0 comments on commit bf07395

Please sign in to comment.