Skip to content

Commit

Permalink
Copybara import of the project:
Browse files Browse the repository at this point in the history
--
887b7ce by Matthew Johnson <mattjj@google.com>:

remove custom_jvp_call_jaxpr_p and its rules

They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!

PiperOrigin-RevId: 468373797
  • Loading branch information
jax authors committed Aug 18, 2022
1 parent af7d1c4 commit fe665b3
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 80 deletions.
174 changes: 116 additions & 58 deletions jax/_src/custom_derivatives.py
Expand Up @@ -15,8 +15,8 @@
from functools import update_wrapper, reduce, partial
import inspect
import operator as op
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set,
Any)
from typing import (Any, Callable, Generic, List, Optional, Sequence, Set,
Tuple, TypeVar)

from jax import core
from jax import linear_util as lu
Expand Down Expand Up @@ -276,6 +276,8 @@ def _flatten_jvp(in_tree, *args):
yield primals_out + tangents_out, out_tree

class CustomJVPCallPrimitive(core.CallPrimitive):
initial_style: core.Primitive

def bind(self, fun, jvp, *args):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
Expand All @@ -295,23 +297,6 @@ def impl(self, fun, _, *args):
def post_process(self, trace, out_tracers, jvp_was_run: bool):
return trace.post_process_custom_jvp_call(out_tracers, jvp_was_run)

def get_bind_params(self, params):
new_params = dict(params)
call_jaxpr = new_params.pop('call_jaxpr')
num_consts = new_params.pop('num_consts')
jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk')
fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr))

@lu.wrap_init
def jvp(*xs):
jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk()
n, ragged = divmod(len(xs), 2)
assert not ragged
primals, tangents = xs[num_consts:n], xs[n+num_consts:]
return core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *tangents)

return [fun, jvp], new_params

@lu.transformation_with_aux
def process_env_traces(primitive, level: int, jvp_was_run: bool, *args):
outs = yield args, {}
Expand All @@ -329,23 +314,6 @@ def process_env_traces(primitive, level: int, jvp_was_run: bool, *args):
todo.append(cur_todo)
yield outs, tuple(todo) # Ensure the aux output is immutable

def get_bind_params(self, params):
new_params = dict(params)
call_jaxpr = new_params.pop('call_jaxpr')
num_consts = new_params.pop('num_consts')
jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk')
fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr))

@lu.wrap_init
def jvp(*xs):
jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk()
n, ragged = divmod(len(xs), 2)
assert not ragged
primals, tangents = xs[num_consts:n], xs[n+num_consts:]
return core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *tangents)

return [fun, jvp], new_params

def _apply_todos(todos, outs):
todos_list = list(todos)
while todos_list:
Expand All @@ -356,35 +324,122 @@ def _apply_todos(todos, outs):
allowed_effects: Set[core.Effect] = set()
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts):
# TODO(mattjj): could do more checking here...
del in_avals, jvp_jaxpr_thunk, num_consts
disallowed_effects = {eff for eff in call_jaxpr.effects if eff not in

def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params):
  """Evaluate the primal function ``fun_jaxpr`` on ``args``.

  Every keyword parameter other than ``fun_jaxpr`` is accepted and discarded,
  because only the primal function is executed here.
  """
  del params  # unused: nothing besides the primal jaxpr is needed
  primal_fun = core.jaxpr_as_fun(fun_jaxpr)
  return primal_fun(*args)

def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params):
del args, params
disallowed_effects = {eff for eff in fun_jaxpr.effects if eff not in
allowed_effects}
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `custom_jvp`: {disallowed_effects}')
return call_jaxpr.out_avals, call_jaxpr.effects
core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck
return fun_jaxpr.out_avals, fun_jaxpr.effects

def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
num_consts):
del jvp_jaxpr_thunk, num_consts
args_ = map(mlir.wrap_singleton_ir_values, args)
consts = mlir._ir_consts(call_jaxpr.consts)
out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
ctx.tokens_in, consts, *args_)
ctx.set_tokens_out(tokens)
return out
mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation)
custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr')
custom_jvp_call_jaxpr_p.multiple_results = True
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
custom_jvp_call_jaxpr_p.def_effectful_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p

mlir.register_lowering(custom_jvp_call_jaxpr_p, mlir.lower_fun(
_custom_jvp_call_jaxpr_impl, multiple_results=True))


def _custom_jvp_call_jaxpr_jvp(
    primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
    jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int):
  """JVP rule that evaluates the user-supplied jvp jaxpr.

  The first ``num_consts`` entries of ``primals``/``tangents`` belong to
  constants; the remaining entries are the actual arguments. The jaxpr obtained
  from ``jvp_jaxpr_thunk`` is evaluated on the argument primals followed by the
  argument tangents, and its flat output is split in half into
  ``(primals_out, tangents_out)``.

  Raises:
    ad.CustomJVPException: if any tangent belonging to the constants is not a
      ``Zero``.
  """
  const_tangents, arg_tangents = split_list(tangents, [num_consts])
  _, arg_primals = split_list(primals, [num_consts])
  for ct in const_tangents:
    if type(ct) is not Zero:
      raise ad.CustomJVPException()
  rule_jaxpr, rule_consts = jvp_jaxpr_thunk()  # consts can be tracers!
  arg_tangents = map(ad.instantiate_zeros, arg_tangents)
  # Cast float0 to zeros with the primal dtype because custom jvp rules don't
  # currently handle float0s
  arg_tangents = map(ad.replace_float0s, arg_primals, arg_tangents)
  flat_outs = core.eval_jaxpr(rule_jaxpr, rule_consts,
                              *arg_primals, *arg_tangents)
  # The rule's outputs are laid out as [*primals_out, *tangents_out].
  num_primals_out = len(flat_outs) // 2
  primals_out, tangents_out = split_list(flat_outs, [num_primals_out])
  tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
  if config.jax_enable_checks:
    assert all(map(core.typecheck, fun_jaxpr.out_avals, primals_out))
  return primals_out, tangents_out
ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp

def _custom_jvp_call_jaxpr_vmap(
    axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
    jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int):
  """Batching rule: re-bind the primitive with batched primal and jvp jaxprs.

  Returns ``(batched_outs, out_dims)`` where each entry of ``out_dims`` is
  either ``0`` or ``not_mapped``.
  """
  # Move every mapped axis to position 0; unmapped values and values already
  # mapped along axis 0 are passed through untouched.
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]
  num_out = len(fun_jaxpr.out_avals)

  in_batched = [d is not not_mapped for d in in_dims]
  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
      fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
  # Output dims implied by the primal function alone; used only if the jvp
  # thunk below is never run.
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []  # mutable cell updated by batched_jvp_jaxpr_thunk

  @pe._memoize
  def batched_jvp_jaxpr_thunk():
    jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
    # Drop the batching flags of the leading num_consts inputs; the jvp jaxpr
    # is batched over the remaining arguments only.
    _, args_batched = split_list(in_batched, [num_consts])
    # `args_batched * 2` repeats the flags: once for the primal inputs and once
    # for the tangent inputs. This first call is used only to learn which
    # outputs come out batched.
    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, axis_size, args_batched * 2, False,
                                          axis_name, main_type)
    primals_batched, tangents_batched = split_list(all_batched, [num_out])
    # An output is treated as batched if either its primal or its tangent is.
    out_batched = map(op.or_, primals_batched, tangents_batched)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    # Second call: pass the combined flags for both the primal and the tangent
    # outputs so each pair agrees.
    batched_jvp_jaxpr, _ = batching.batch_jaxpr(
        jvp_jaxpr, axis_size, args_batched * 2, out_batched * 2,
        axis_name, main_type)
    return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

  batched_outs = custom_jvp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk, num_consts=num_consts)
  # If the thunk ran during the bind above it recorded its own output dims in
  # out_dims2; prefer those, otherwise fall back to the primal function's dims.
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap

xla.register_initial_style_primitive(custom_jvp_call_jaxpr_p)

# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_ can appear in jaxprs to be transposed. Since it's already
# been linearized, we can drop the jvp rule.
def _custom_jvp_call_transpose(params, jaxpr, args, ct, _, reduce_axes):
del params
return ad.backward_pass(jaxpr.jaxpr, reduce_axes, None, jaxpr.consts, args, ct)
ad.primitive_transposes[custom_jvp_call_p] = _custom_jvp_call_transpose
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
# already been linearized, we can drop the jvp rule.
def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
                                     jvp_jaxpr_thunk, num_consts):
  """Transpose rule: run the backward pass over the primal jaxpr only.

  ``jvp_jaxpr_thunk`` and ``num_consts`` are accepted and discarded; the custom
  jvp rule is not consulted when transposing.
  """
  del jvp_jaxpr_thunk, num_consts
  primal_jaxpr = fun_jaxpr.jaxpr
  primal_consts = fun_jaxpr.consts
  return ad.backward_pass(
      primal_jaxpr, reduce_axes, False, primal_consts, args, cts)
ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose

def custom_jvp_jaxpr_custom_partial_eval_rule(
    saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
    eqn: core.JaxprEqn
  ) -> Tuple[Optional[core.JaxprEqn], core.JaxprEqn, List[bool], List[bool], List[core.Var]]:
  """Partial-eval rule that never splits the equation.

  Args:
    saveable: policy callable; not consulted by this rule.
    unks_in: one flag per equation input, true where the input is unknown.
    inst_in: one flag per equation input, true where the input is already
      instantiated.
    eqn: the equation being partially evaluated.

  Returns:
    A 5-tuple ``(known_eqn, staged_eqn, unks_out, inst_out, new_inst)``. The
    staged equation is always ``eqn`` itself and every output is marked
    instantiated. If any input is unknown, all outputs are unknown and
    ``known_eqn`` is ``None``; otherwise all outputs are known and
    ``known_eqn`` is ``eqn`` as well. ``new_inst`` lists the ``core.Var``
    inputs that were not yet instantiated.
  """
  # It doesn't make sense to unzip (i.e. break up) a custom_jvp function into
  # constituent parts, so we always perform full remat. An alternative would be
  # to allow the policy function to decide whether the value of a
  # custom_jvp-decorated function's application should be saved or not.
  # TODO(mattjj,jekbradbury): the user writing the custom_jvp-decorated function
  # probably has a better idea for what to do under remat (e.g. if the function
  # contains dots or not), so we should allow for more expressive interaction
  # (e.g. allow the policy to depend on which custom_jvp-decorated function is
  # being applied, or annotating the behavior where custom_vjp is called.)
  num_outs = len(eqn.outvars)
  inst_out = [True] * num_outs
  new_inst = [v for v, already_inst in zip(eqn.invars, inst_in)
              if type(v) is core.Var and not already_inst]
  has_unknown_input = any(unks_in)
  unks_out = [has_unknown_input] * num_outs
  # With an unknown input nothing can run at "known" time, hence no known eqn.
  known_eqn = None if has_unknown_input else eqn
  return known_eqn, eqn, unks_out, inst_out, new_inst
pe.partial_eval_jaxpr_custom_rules[custom_jvp_call_jaxpr_p] = \
custom_jvp_jaxpr_custom_partial_eval_rule # type: ignore


### VJPs
Expand Down Expand Up @@ -721,6 +776,9 @@ def batched_fwd_jaxpr_thunk():
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
mlir.register_lowering(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)

pe.partial_eval_jaxpr_custom_rules[custom_vjp_call_jaxpr_p] = \
custom_jvp_jaxpr_custom_partial_eval_rule # type: ignore


def custom_gradient(fun):
"""Convenience function for defining custom VJP rules (aka custom gradients).
Expand Down
3 changes: 1 addition & 2 deletions jax/core.py
Expand Up @@ -500,8 +500,7 @@ def escaped_tracer_error(tracer, detail=None):
num_frames = FLAGS.jax_tracer_error_num_traceback_frames
msg = ('Encountered an unexpected tracer. A function transformed by JAX '
'had a side effect, allowing for a reference to an intermediate value '
f'with type {tracer.aval.str_short()} wrapped in a '
f'{type(tracer).__name__} to escape the scope of the transformation.\n'
f'with shape {tracer.shape} and dtype {tracer.dtype} to escape.\n'
'JAX transformations require that functions explicitly return their '
'outputs, and disallow saving intermediate values to global state.')
dbg = getattr(tracer, '_debug_info', None)
Expand Down
1 change: 1 addition & 0 deletions jax/custom_derivatives.py
Expand Up @@ -20,6 +20,7 @@
custom_gradient as custom_gradient,
custom_jvp as custom_jvp,
custom_jvp_call_p as custom_jvp_call_p,
custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p,
custom_vjp as custom_vjp,
custom_vjp_call_p as custom_vjp_call_p,
custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p,
Expand Down
6 changes: 5 additions & 1 deletion jax/experimental/callback.py
Expand Up @@ -264,7 +264,9 @@ def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers,
vals = [t.val for t in tracers]

new_closed_jaxpr = callback_jaxpr(fun_jaxpr, trace.callback, strip_calls=trace.strip_calls)
if primitive == cd.custom_vjp_call_jaxpr_p:
if primitive == cd.custom_jvp_call_jaxpr_p:
thunk_name = 'jvp_jaxpr_thunk'
elif primitive == cd.custom_vjp_call_jaxpr_p:
thunk_name = 'fwd_jaxpr_thunk'
params['bwd'] = callback_subtrace(params['bwd'], main)
else:
Expand All @@ -285,5 +287,7 @@ def new_thunk():
num_consts=new_num_consts, **params)
return safe_map(trace.pure, out)

custom_callback_rules[cd.custom_jvp_call_jaxpr_p] = partial(
_custom_derivative_call_jaxpr_callback_rule, cd.custom_jvp_call_jaxpr_p)
custom_callback_rules[cd.custom_vjp_call_jaxpr_p] = partial(
_custom_derivative_call_jaxpr_callback_rule, cd.custom_vjp_call_jaxpr_p)
6 changes: 3 additions & 3 deletions jax/experimental/host_callback.py
Expand Up @@ -1615,8 +1615,8 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
# cased to just pass-through the token
in_axes=eqn.params["in_axes"] + (None, None),
out_axes=eqn.params["out_axes"] + (0, 0))))
elif eqn.primitive is custom_derivatives.custom_jvp_call_p:
fun_jaxpr = eqn.params["call_jaxpr"]
elif eqn.primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
fun_jaxpr = eqn.params["fun_jaxpr"]

def unreachable_thunk():
assert False, "Should not be reached"
Expand All @@ -1627,7 +1627,7 @@ def unreachable_thunk():
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
call_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
jvp_jaxpr_thunk=unreachable_thunk
)))
elif eqn.primitive is custom_derivatives.custom_vjp_call_jaxpr_p:
Expand Down
7 changes: 3 additions & 4 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -2786,15 +2786,14 @@ def _tridiagonal_solve(*args: TfVal, _in_avals, _out_aval, **params):

tf_impl_with_avals[lax.linalg.tridiagonal_solve_p] = _tridiagonal_solve

def _custom_jvp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr,
def _custom_jvp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable,
num_consts: int) -> Sequence[TfVal]:
# TODO(necula): ensure that there is no AD transformation in scope
del jvp_jaxpr_thunk, num_consts
return _interpret_jaxpr(call_jaxpr, *args, extra_name_stack="custom_jvp")
return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_jvp")


tf_impl[custom_derivatives.custom_jvp_call_p] = _custom_jvp_call
tf_impl[custom_derivatives.custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr


def _custom_vjp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr,
Expand Down
8 changes: 8 additions & 0 deletions jax/experimental/jet.py
Expand Up @@ -61,6 +61,7 @@
import jax
from jax import core
from jax import lax
from jax.custom_derivatives import custom_jvp_call_jaxpr_p
from jax.interpreters import xla
import jax.linear_util as lu
import jax.numpy as jnp
Expand Down Expand Up @@ -681,6 +682,13 @@ def select_min_and_avg_eq(x_i, y_i):
return primal_out, series_out
jet_rules[lax.min_p] = _lax_min_taylor_rule

def _custom_jvp_call_jaxpr_rule(primals_in, series_in, *, fun_jaxpr,
                                jvp_jaxpr_thunk, num_consts=0):
  """Jet rule for ``custom_jvp_call_jaxpr_p``: apply jet to the primal function.

  The custom jvp rule is ignored; ``jet`` is applied directly to the function
  described by ``fun_jaxpr`` with all of ``primals_in`` / ``series_in``.

  Args:
    primals_in: primal inputs of the equation.
    series_in: input series, one per primal input.
    fun_jaxpr: closed jaxpr of the primal function.
    jvp_jaxpr_thunk: thunk producing the custom jvp jaxpr; unused.
    num_consts: number of leading inputs that are constants; unused. It is
      accepted because equations for this primitive are created with a
      ``num_consts`` parameter, so a rule that receives the equation's params
      as keywords would otherwise fail with a ``TypeError``.
      NOTE(review): how jet dispatches to its rules is not visible here --
      confirm that params are forwarded as keyword arguments.

  Returns:
    Whatever ``jet`` returns for the primal function.
  """
  # TODO(mattjj): do something better than ignoring custom jvp rules for jet?
  del jvp_jaxpr_thunk, num_consts
  return jet(core.jaxpr_as_fun(fun_jaxpr), primals_in, series_in)
jet_rules[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_rule

def _scatter_add_rule(primals_in, series_in, *, update_jaxpr, update_consts,
dimension_numbers, indices_are_sorted, unique_indices,
mode):
Expand Down
49 changes: 41 additions & 8 deletions jax/interpreters/partial_eval.py
Expand Up @@ -463,12 +463,45 @@ def out_axes_transform(out_axes):
def _current_truncated_name_stack(self):
return source_info_util.current_name_stack()[len(self.name_stack):]

def process_custom_jvp_call(self, prim, fun, jvp, tracers):
# We assume partial evaluation is only performed to build linear functions,
# and hence we don't need to keep the custom JVP rule around anymore.
del jvp
assert not all(t.is_known() for t in tracers)
return fun.call_wrapped(*tracers)
def process_custom_jvp_call(self, prim, f, jvp, tracers):
  """Stage a custom-jvp call into an equation using ``prim.initial_style``.

  The primal function ``f`` is traced to a jaxpr, and the custom jvp rule is
  kept as a memoized thunk in the equation's ``jvp_jaxpr_thunk`` parameter.
  Returns the known outputs merged with fresh unknown output tracers.
  """
  # TODO(mattjj): after old remat is deleted, make this method trivial:
  # https://github.com/google/jax/pull/9137/files#diff-440d9df723b313bb263bc7704103cad1dcc886ff6553aa78c30188b0b323b686R319-R323
  # Because we instantiate all tracers, in_knowns is all False.
  tracers = map(self.instantiate_const_abstracted, tracers)
  # Unpacking into `()` requires the third result to be empty.
  in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers])
  f = trace_to_subjaxpr_nounits(f, self.main, True)
  f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
  # No positional values are passed to bind here.
  out_flat = prim.bind(f, jvp)
  out_knowns, out_avals, jaxpr, env = aux()
  # The trailing len(jaxpr.constvars) entries of out_flat are the residuals.
  out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
  res_tracers = map(self.new_instantiated_const, res)
  env_tracers = map(self.full_raise, env)
  out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
                 for a in out_avals]
  # Constvars are converted and the closed jaxpr carries no consts of its own;
  # the residuals are fed to the equation as leading inputs below.
  closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

  @_memoize
  def jvp_jaxpr_thunk():
    # The names bound in here (aux, out_flat, jaxpr, env, ...) are local to the
    # thunk and do not rebind the enclosing method's variables.
    jvp_ = trace_to_subjaxpr_nounits(jvp, self.main, True)
    # The `* 2` repeats the flags/avals: once for primals, once for tangents.
    jvp_, aux = partial_eval_wrapper_nounits(
        jvp_, tuple(in_knowns) * 2, tuple(in_avals) * 2)
    with core.new_sublevel():
      out_flat = jvp_.call_wrapped()
    out_knowns, out_avals, jaxpr, env = aux()
    _, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
    # Return the jaxpr together with its consts: residuals first, then env.
    return converted_jaxpr, (*res, *env)

  name_stack = self._current_truncated_name_stack()
  source = source_info_util.current().replace(name_stack=name_stack)
  # Equation inputs are ordered residuals, env values, then the original
  # tracers; num_consts counts the residual and env inputs.
  eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers),
                       out_tracers, prim.initial_style,
                       dict(fun_jaxpr=closed_jaxpr,
                            jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                            num_consts=len(res)+len(env)),
                       jaxpr.effects, source)
  for t in out_tracers: t.recipe = eqn
  return merge_lists(out_knowns, out_tracers, out_consts)

def post_process_custom_jvp_call(self, out_tracers, _):
# This path should only be reachable if we expose a partial eval API
Expand Down Expand Up @@ -1763,8 +1796,8 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers):
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
dict(call_jaxpr=closed_fun_jaxpr,
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
dict(fun_jaxpr=closed_fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
num_consts=len(consts)),
fun_jaxpr.effects,
Expand Down

0 comments on commit fe665b3

Please sign in to comment.