Skip to content

Commit

Permalink
Merge pull request #10048 from froystig:custom-vjp-by-jvp
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 437891536
  • Loading branch information
jax authors committed Mar 29, 2022
2 parents 4761b1e + 941b221 commit ab8dd4e
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 27 deletions.
29 changes: 24 additions & 5 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ def int_env(varname: str, default: int) -> int:
return int(os.getenv(varname, str(default)))


UPGRADE_BOOL_HELP = (
" This will be enabled by default in future versions of JAX, at which "
"point all uses of the flag will be considered deprecated (following "
"https://jax.readthedocs.io/en/latest/api_compatibility.html).")

UPGRADE_BOOL_EXTRA_DESC = " (transient)"


class Config:
_HAS_DYNAMIC_ATTRIBUTES = True

Expand Down Expand Up @@ -181,6 +189,7 @@ def define_bool_state(
self, name: str, default: bool, help: str, *,
update_global_hook: Optional[Callable[[bool], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None,
upgrade: bool = False,
extra_description: str = ""):
"""Set up thread-local state and return a contextmanager for managing it.
Expand All @@ -203,6 +212,9 @@ def define_bool_state(
update_thread_local_hook: a optional callback that is called with the
updated value of the thread-local state when it is altered or set
initially.
upgrade: optional indicator that this flag controls a canonical feature
upgrade, so that it is `True` for the incoming functionality, `False`
for the outgoing functionality to be deprecated.
extra_description: string, optional: extra information to add to the
summary description.
Expand All @@ -228,8 +240,12 @@ def define_bool_state(
The value of the thread-local state or flag can be accessed via
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
an error.
"""
name = name.lower()
if upgrade:
help += ' ' + UPGRADE_BOOL_HELP
extra_description += UPGRADE_BOOL_EXTRA_DESC
self.DEFINE_bool(name, bool_env(name.upper(), default), help,
update_hook=update_global_hook)
self._contextmanager_flags.add(name)
Expand Down Expand Up @@ -548,15 +564,12 @@ def update_thread_local_jit_state(**kw):
'computations. Logging is performed with `absl.logging` at WARNING '
'level.'))


enable_custom_prng = config.define_bool_state(
name='jax_enable_custom_prng',
default=False,
help=('Enables an internal upgrade that allows one to define custom '
'pseudo-random number generator implementations. This will '
'be enabled by default in future versions of JAX, at which point '
'disabling it will be considered deprecated. In a version '
'after that the flag will be removed altogether.'),
extra_description=" (transient)")
'pseudo-random number generator implementations.'))

default_prng_impl = config.define_enum_state(
name='jax_default_prng_impl',
Expand All @@ -565,6 +578,12 @@ def update_thread_local_jit_state(**kw):
help=('Select the default PRNG implementation, used when one is not '
'explicitly provided at seeding time.'))

enable_custom_vjp_by_custom_transpose = config.define_bool_state(
name='jax_enable_custom_vjp_by_custom_transpose',
default=False,
help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))

hlo_source_file_canonicalization_regex = config.define_string_state(
name='jax_hlo_source_file_canonicalization_regex',
default=None,
Expand Down
121 changes: 99 additions & 22 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from jax import core
from jax import linear_util as lu
from jax.custom_transpose import custom_transpose
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
tree_multimap, treedef_is_leaf, treedef_tuple,
register_pytree_node_class)
Expand Down Expand Up @@ -536,29 +537,35 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
msg = "No VJP defined for custom_vjp function {} using defvjp."
raise AttributeError(msg.format(self.__name__))
args = _resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
for i in self.nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums = set(self.nondiff_argnums)
dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args,
require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args,
require_static_args_hashable=False)
bwd = _add_args(lu.wrap_init(self.bwd), static_args)
if config.jax_enable_custom_vjp_by_custom_transpose:
if self.nondiff_argnums:
raise NotImplementedError(
'nondiff_argnums not implemented for new custom_vjp')
return custom_vjp_by_custom_transpose(self.fun, self.fwd, self.bwd)(*args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat,
out_trees=out_trees)
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
out_tree = aux if fst else aux[0]
return tree_unflatten(out_tree, out_flat)
if self.nondiff_argnums:
for i in self.nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums = set(self.nondiff_argnums)
dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
args, require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args,
require_static_args_hashable=False)
bwd = _add_args(lu.wrap_init(self.bwd), static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
out_tree = aux if fst else aux[0]
return tree_unflatten(out_tree, out_flat)

@partial(partial, tree_map)
def _check_for_tracers(x):
Expand Down Expand Up @@ -1160,3 +1167,73 @@ def _linear_call_abstract_eval(*args, **kwargs):
initial_style=True)
mlir.register_lowering(linear_call_p, mlir.lower_fun(
_linear_call_impl, multiple_results=True))


# A stageable primitive that fails when evaluated
unreachable_p: core.Primitive = core.Primitive('unreachable')
unreachable_p.multiple_results = True

def unreachable_impl(*_, out_avals, exc_type, message):
del out_avals
raise exc_type(message)

# Evaluation raises an exception
unreachable_p.def_impl(unreachable_impl)

# Translation raises an exception
# TODO(frostig,mattjj): We have no good way to translate a function
# that errs. Since translation over-approximates concrete evaluation,
# we err on translation for the time being.
xla.register_translation(unreachable_p, unreachable_impl)

# Abstract evaluation proceeds without issue, to allow for staging
unreachable_p.def_abstract_eval(lambda *_, out_avals, **__: out_avals)

def unreachable(*args, out_avals=None, exc_type=TypeError,
message='unreachable'):
"""Fail when evaluated concretely (but allow for staging).
This function allows one to assert an impossibility of
evaluation. It can be used to guarantee that evaluation does not
"reach" a certain point in the sense that it does not execute, but
it can nonetheless be staged out by JAX without error.
Args:
*args: The arbitrary pytree of arguments to the function.
out_avals: Optional specification of the output types of this
function invocation from the point of view of staging. If
``None``, these are chosen as equal to types of input arguments.
exc_type: Optional constructor for the Python exception raised if
evaluated.
message: Optional string message for the Python exception raised
if evaluated.
"""
if out_avals is None:
out_avals = tree_map(core.get_aval, args)

args_flat, in_tree = tree_flatten(args)
out_avals_flat, out_tree = tree_flatten(out_avals)
out = unreachable_p.bind(*args_flat, out_avals=out_avals_flat,
exc_type=exc_type, message=message)
return tree_unflatten(out_tree, out)


disallow_jvp = partial(
unreachable,
exc_type=TypeError,
message="can't apply forward-mode autodiff (jvp) to a custom_vjp function.")


def custom_vjp_by_custom_transpose(fun, fwd, bwd):
fun = custom_jvp(fun)

@fun.defjvp
def jvp(primals, tangents):
outs, residuals = fwd(*primals)
tan_out_types = tree_map(lambda o: core.get_aval(o).at_least_vspace(), outs)
tan_fn = custom_transpose(partial(disallow_jvp, out_avals=tan_out_types))
tan_fn.def_transpose(bwd)
return outs, tan_fn(tan_out_types, residuals, tangents)

return fun
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"schur",
"name",
"optimization_barrier",
"unreachable",

# Not high priority?
"after_all",
Expand Down

0 comments on commit ab8dd4e

Please sign in to comment.