diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 20eafa767385..e93132a7806f 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -560,7 +560,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree, out_type) - flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees) + flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat, out_trees=out_trees) _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees) @@ -680,7 +680,7 @@ def bind(self, fun, fwd, bwd, *args, out_trees): fwd, env_trace_todo2 = process_env_traces_fwd( fwd, top_trace and top_trace.level, out_trees) tracers = map(top_trace.full_raise, args) # type: ignore - bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args)) + bwd_ = lambda *args: bwd(*args) outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers, out_trees=out_trees) fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2) @@ -749,7 +749,7 @@ def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): def _custom_vjp_call_jaxpr_jvp( primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], - bwd: lu.WrappedFun, out_trees: Callable, num_consts: int): + bwd: Callable, out_trees: Callable, num_consts: int): _, args = split_list(primals, [num_consts]) consts_dot, args_dot = split_list(tangents, [num_consts]) if any(type(t) is not Zero for t in consts_dot): @@ -772,7 +772,7 @@ def _custom_vjp_call_jaxpr_jvp( def _custom_vjp_call_jaxpr_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr, fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], - bwd: lu.WrappedFun, out_trees: Callable, num_consts: int): + bwd: Callable, out_trees: Callable, num_consts: int): args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 08b3c5dae7d4..bdcf3b961e51 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -748,7 +748,7 @@ def raise_custom_vjp_error_on_jvp(*_, **__): def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals): res, _ = split_list(invals, [num_res]) cts_out = map(instantiate_zeros_aval, out_avals, cts_out) - cts_in = bwd.call_wrapped(*res, *cts_out) + cts_in = bwd(*res, *cts_out) return [None] * num_res + list(cts_in) primitive_transposes[custom_lin_p] = _custom_lin_transpose diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index cd6f86eeec05..019400aa948b 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -776,11 +776,16 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals): out_tangent_bds, out_dims, out_tangents) yield out_primals + out_tangents, out_dims * 2 -def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name): - bwd, out_dims_thunk = batch_subtrace(bwd) - bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type, - spmd_axis_name) - return _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests) +def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, + main_type, spmd_axis_name): + def new_bwd(*args): + bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd)) + bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims, main_type, + spmd_axis_name) + bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, + out_dim_dests) + return bwd_.call_wrapped(*args) + return new_bwd @lu.transformation def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals): diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 7f9eb2b89e83..74b11b2303c0 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -308,6 +308,9 @@ def __eq__(self, other): return self.fun == other.fun return self.fun == other + def __repr__(self): + return f'_HashableCallableShim({repr(self.fun)})' + class Partial(functools.partial): """A version of functools.partial that works in pytrees. diff --git a/tests/api_test.py b/tests/api_test.py index eea297779d28..dd51eb2fe73f 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -8385,6 +8385,18 @@ def f_bwd(_, g): jax.grad(f)(A([1.])) # doesn't crash + def test_vmap_vjp_called_twice(self): + # https://github.com/google/jax/pull/14728 + @jax.custom_vjp + def f(x): + return x + f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,)) + + _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.])) + f_vjp(jnp.array([3.])) + f_vjp(jnp.array([3.])) # doesn't crash + + def transpose_unary(f, x_example): def transposed(y): x, = api.linear_transpose(f, x_example)(y)