Skip to content

Commit

Permalink
custom_vjp symbolic zeros support
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Mar 21, 2023
1 parent 023bfa8 commit 86907ec
Show file tree
Hide file tree
Showing 10 changed files with 348 additions and 80 deletions.
2 changes: 1 addition & 1 deletion jax/_src/ad_util.py
Expand Up @@ -95,7 +95,7 @@ def _stop_gradient_impl(x: T) -> T:


class SymbolicZero:
def __init__(self, aval: core.AbstractValue) -> None:
def __init__(self, aval: core.AbstractValue):
self.aval = aval

def __repr__(self) -> str:
Expand Down
11 changes: 7 additions & 4 deletions jax/_src/checkify.py
Expand Up @@ -915,23 +915,26 @@ def jvp(*xs):
error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule

def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
fwd_jaxpr_thunk, num_consts, bwd, out_trees):
fwd_jaxpr_thunk, num_consts, bwd, out_trees,
symbolic_zeros):
err_vals, err_tree = jtu.tree_flatten(in_err)
fun = lu.wrap_init(
functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
fun_jaxpr.consts, enabled_errors, err_tree))
fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun)

@lu.wrap_init
def fwd(*xs):
def fwd(*args):
# TODO(lenamartens, sharadmv): why not checkify here?
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk()
xs, zeros = args[::2], args[1::2]
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)
xs_without_consts = xs[num_consts:]
return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)

fwd, fwd_out_tree = flatten_fun_output(fwd)
all_outs = custom_derivatives.custom_vjp_call_p.bind(
fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees)
fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees,
symbolic_zeros=symbolic_zeros)
fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
if fst:
err_and_out_tree, _ = out_metadata
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/core.py
Expand Up @@ -512,7 +512,8 @@ def process_custom_transpose(self, prim, call, tracers, **params):
"to handle custom_transpose_call primitives")
raise NotImplementedError(msg)

def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees, symbolic_zeros):
msg = (f"{type(self)} must override process_custom_vjp_call "
"to handle custom_vjp primitives")
raise NotImplementedError(msg)
Expand Down
110 changes: 79 additions & 31 deletions jax/_src/custom_derivatives.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from functools import update_wrapper, reduce, partial
import inspect
from typing import (Callable, Generic, List, Optional, Sequence, Tuple, TypeVar, Any)
Expand All @@ -31,8 +32,8 @@
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import traceback_util
from jax._src.ad_util import (Zero, SymbolicZero, zeros_like_aval,
stop_gradient_p)
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax._src.core import raise_to_shaped
from jax._src.interpreters import ad
Expand Down Expand Up @@ -163,12 +164,13 @@ def defjvp(self,
and the second element is the tangent output. Elements of the input and
output tuples may be arrays or any nested tuples/lists/dicts thereof.
symbolic_zeros: boolean, indicating whether the rule should be passed
objects representing static symbolic zeros in its tangent tuple
argument; otherwise, only standard JAX types (e.g. array-likes) are
passed. Setting this option to True allows a JVP rule to detect whether
certain inputs are not involved in differentiation, but at the cost of
needing special handling for these objects (which e.g. can't be passed
into jax.numpy functions). Default False.
objects representing static symbolic zeros in its tangent argument in
correspondence with unperturbed values; otherwise, only standard JAX
types (e.g. array-likes) are passed. Setting this option to ``True``
allows a JVP rule to detect whether certain inputs are not involved in
differentiation, but at the cost of needing special handling for these
objects (which e.g. can't be passed into jax.numpy functions). Default
``False``.
Returns:
None.
Expand Down Expand Up @@ -485,7 +487,9 @@ def __init__(self,

def defvjp(self,
fwd: Callable[..., Tuple[ReturnValue, Any]],
bwd: Callable[..., Tuple[Any, ...]]) -> None:
bwd: Callable[..., Tuple[Any, ...]],
symbolic_zeros: bool = False,
) -> None:
"""Define a custom VJP rule for the function represented by this instance.
Args:
Expand All @@ -506,6 +510,27 @@ def defvjp(self,
function, and the tuple elements may be arrays or nested
tuples/lists/dicts thereof so as to match the structure of the primal
input arguments.
symbolic_zeros: boolean, indicating whether to indicate symbolic zeros in
the ``fwd`` and ``bwd`` rules. Setting this option to ``True`` allows
custom derivative rules to detect when certain inputs, and when certain
cotangent outputs, are not involved in differentiation. If ``True``:
* ``fwd`` must accept, for each leaf value ``x`` in the pytree
comprising an argument to the original function, a pair ``(x, zero)``,
where ``x`` is the original argument and ``zero`` is a boolean. The
``zero`` part indicates whether or not the argument is not involved in
differentiation (i.e., whether the corresponding Jacobian "column" is
zero).
* ``bwd`` will be passed objects representing static symbolic zeros in
its cotangent argument in correspondence with unperturbed values;
otherwise, only standard JAX types (e.g. array-likes) are passed.
Setting this option to ``True`` allows these rules to detect whether
certain inputs and outputs are not involved in differentiation, but at
the cost of special handling: the signature of ``fwd`` changes, and
``bwd`` receives objects that, for instance, cannot be passed to
``jax.numpy`` functions. Default ``False``.
Returns:
None.
Expand All @@ -527,14 +552,15 @@ def f_bwd(res, g):
"""
self.fwd = fwd
self.bwd = bwd
self.symbolic_zeros = symbolic_zeros

@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
if not self.fwd or not self.bwd:
msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = _resolve_kwargs(self.fun, args, kwargs)
if config.jax_enable_custom_vjp_by_custom_transpose:
if self.nondiff_argnums:
Expand All @@ -555,17 +581,30 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
else:
f_, dyn_args = lu.wrap_init(self.fun), args
fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
fwd = _project_fwd(fwd, self.symbolic_zeros)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree,
out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
*args_flat, out_trees=out_trees,
symbolic_zeros=self.symbolic_zeros)
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
return tree_unflatten(out_tree, out_flat)

@dataclasses.dataclass
class ZeroTagged:
  # Pairs one flattened primal leaf with a boolean marking whether its
  # corresponding tangent/cotangent is a symbolic zero. Instances are built in
  # `_flatten_fwd` from the interleaved (value, zero-flag) argument stream and
  # unpacked again by `_project_fwd` before calling the user's fwd rule.
  val: Any    # the primal leaf value
  zero: bool  # True if this leaf's tangent is a symbolic zero

@lu.transformation
def _project_fwd(symbolic_zeros, *args, **kwargs):
  # Adapt the user's fwd rule to the `ZeroTagged`-leaf calling convention.
  # The incoming pytrees have ZeroTagged leaves (presumably produced by
  # `_flatten_fwd` — confirm against callers). If the user opted in to
  # `symbolic_zeros`, each leaf is passed through as an `(x, zero)` pair so
  # the rule can see which inputs are unperturbed; otherwise only the bare
  # value `x.val` is forwarded, preserving the original fwd signature.
  project_leaf = ((lambda x: (x.val, x.zero)) if symbolic_zeros else
                  (lambda x: x.val))
  # Two-way generator per the lu.transformation protocol: yield the projected
  # arguments to the wrapped function, then yield its result back unchanged.
  yield (yield tree_map(project_leaf, (args, kwargs)))

def _check_for_tracers(x):
for leaf in tree_leaves(x):
if isinstance(x, core.Tracer):
Expand All @@ -579,8 +618,10 @@ def _check_for_tracers(x):
raise UnexpectedTracerError(msg)

@lu.transformation_with_aux
def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, *args):
py_args = tree_unflatten(in_tree, args)
def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type,
*args):
tagged_args = [ZeroTagged(x, z) for x, z in zip(args[::2], args[1::2])]
py_args = tree_unflatten(in_tree, tagged_args)
pair_out = yield py_args, {}
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
Expand Down Expand Up @@ -672,7 +713,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
class CustomVJPCallPrimitive(core.CallPrimitive):
initial_style: core.Primitive

def bind(self, fun, fwd, bwd, *args, out_trees):
def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
fun, env_trace_todo1 = process_env_traces(
Expand All @@ -682,7 +723,8 @@ def bind(self, fun, fwd, bwd, *args, out_trees):
tracers = map(top_trace.full_raise, args) # type: ignore
bwd_ = lambda *args: bwd(*args)
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
out_trees=out_trees)
out_trees=out_trees,
symbolic_zeros=symbolic_zeros)
fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
if fst:
return core.apply_todos(env_trace_todo, map(core.full_lower, outs))
Expand Down Expand Up @@ -749,30 +791,33 @@ def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
def _custom_vjp_call_jaxpr_jvp(
primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: Callable, out_trees: Callable, num_consts: int):
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
_, args = split_list(primals, [num_consts])
consts_dot, args_dot = split_list(tangents, [num_consts])
if any(type(t) is not Zero for t in consts_dot):
raise ad.CustomVJPException()
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() # consts can be tracers!
zeros = [type(t) is Zero for t in args_dot]
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) # consts can be tracers!
out_tree, res_tree = out_trees()
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
args_dot = map(ad.instantiate_zeros, args_dot)
# Cast float0 to zeros with the primal dtype because custom vjp rules don't
# currently handle float0s
args_dot = map(ad.replace_float0s, args, args_dot)
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out)
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
def _custom_vjp_call_jaxpr_vmap(
spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *,
fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: Callable, out_trees: Callable, num_consts: int):
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]

Expand All @@ -785,8 +830,8 @@ def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
out_dims2 = []

@pe._memoize
def batched_fwd_jaxpr_thunk():
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers
def batched_fwd_jaxpr_thunk(*zeros):
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name,
main_type)
Expand All @@ -795,17 +840,20 @@ def batched_fwd_jaxpr_thunk():

fwd_args_batched = [0 if b else not_mapped for b in args_batched]
fwd_out_dims = lambda: out_dims2[0]
batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
fwd_args_batched, main_type, spmd_axis_name)
batched_bwd = batching.batch_custom_vjp_bwd(
bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type,
spmd_axis_name)

batched_outs = custom_vjp_call_jaxpr_p.bind(
*args, fun_jaxpr=batched_fun_jaxpr,
fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
out_trees=out_trees, num_consts=num_consts)
num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros)
out_dims = out_dims2[0] if out_dims2 else out_dims1
return batched_outs, out_dims
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(_custom_vjp_call_jaxpr_vmap, None)
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \
_custom_vjp_call_jaxpr_vmap
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(
_custom_vjp_call_jaxpr_vmap, None)

xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)

Expand Down
30 changes: 21 additions & 9 deletions jax/_src/interpreters/ad.py
Expand Up @@ -27,9 +27,9 @@
from jax._src import core
from jax._src import source_info_util
from jax._src.ad_util import (
add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
zeros_like_p, Zero, replace_internal_symbolic_zeros,
replace_rule_output_symbolic_zeros)
add_jaxvals, add_jaxvals_p, replace_internal_symbolic_zeros,
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval,
zeros_like_jaxval, zeros_like_p)
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
Expand Down Expand Up @@ -387,16 +387,23 @@ def process_custom_jvp_call(self, _, __, f_jvp, tracers, *, symbolic_zeros):
def post_process_custom_jvp_call(self, out_tracers, _):
raise CustomJVPException()

def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees):
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees,
symbolic_zeros):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
tangents_in = map(instantiate_zeros, tangents_in)
res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
fwd_in = [(core.full_lower(p), type(t) is Zero)
for p, t in zip(primals_in, tangents_in)]
fwd_in = [x for pair in fwd_in for x in pair] # flatten
res_and_primals_out = fwd.call_wrapped(*fwd_in)
out_tree, res_tree = out_trees()
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
# We don't need to handle any symbolic zeros on tangents_in or
# tangents_out below, because custom_lin_p is never executed and
# doesn't correspond to any custom user rule.
tangents_in = map(instantiate_zeros, tangents_in)
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out)
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)

Expand Down Expand Up @@ -745,10 +752,15 @@ def raise_custom_vjp_error_on_jvp(*_, **__):
"function.")
custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp)

def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals,
symbolic_zeros):
res, _ = split_list(invals, [num_res])
cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
if symbolic_zeros:
cts_out = map(replace_internal_symbolic_zeros, cts_out)
else:
cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
cts_in = bwd(*res, *cts_out)
cts_in = map(replace_rule_output_symbolic_zeros, cts_in)
return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose

Expand Down

0 comments on commit 86907ec

Please sign in to comment.