Skip to content

Commit

Permalink
[host_callback] Remove old backwards compatibility flag jax_host_call…
Browse files Browse the repository at this point in the history
…back_ad_transforms.

This flag was added in #8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557520668
  • Loading branch information
gnecula authored and jax authors committed Aug 16, 2023
1 parent 14fa067 commit ad15a38
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 390 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -32,6 +32,8 @@ Remember to align the itemized text with the first line of an item within a list
`True`.
* `jax.jaxpr_util` has been removed from the public JAX namespace.
* `JAX_USE_PJRT_C_API_ON_TPU` no longer has an effect (i.e. it always defaults to true).
* The backwards compatibility flag `--jax_host_callback_ad_transforms`
introduced in December 2021, has been removed.

* Deprecations:
* Several `jax.numpy` APIs have been deprecated following
Expand Down
174 changes: 5 additions & 169 deletions jax/experimental/host_callback.py
Expand Up @@ -555,16 +555,6 @@ def power3_with_cotangents(x):
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
)
)
_HOST_CALLBACK_AD_TRANSFORMS = config.DEFINE_bool(
'jax_host_callback_ad_transforms',
config.bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
help=(
'Enable support for jvp/vjp for the host_callback primitives. Default is '
'False, which means that host_callback operates only on primals. '
'The flag exists only temporarily, for backward compatibility.'
)
)


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -650,14 +640,9 @@ def id_tap(tap_func,
"pre-apply keyword arguments, either by using a closure or by passing "
"``functools.partial(tap_func, **kwargs)``.")
raise TypeError(msg)
if _HOST_CALLBACK_AD_TRANSFORMS.value:
warnings.warn('The flag jax_host_callback_ad_transforms is for temporary '
'backwards compatibility mode. This flag, and the behavior '
'it enabled will be removed soon.',
FutureWarning)

if result is not None:
flat_results, result_treedef = tree_util.tree_flatten(result)
flat_results, _ = tree_util.tree_flatten(result)
for r in flat_results:
dispatch.check_arg(r)

Expand All @@ -670,18 +655,7 @@ def id_tap(tap_func,
device_index=device_index)

if result is not None:
# Return the results, but add a dependency on the call, to ensure it
# is kept in the graph.
if _HOST_CALLBACK_AD_TRANSFORMS.value:
call_flat_results, _ = tree_util.tree_flatten(call_res)
if call_flat_results:
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
for r in flat_results]
else:
call_flat_results = flat_results
return result_treedef.unflatten(call_flat_results)
else:
return result
return result
else:
return call_res

Expand Down Expand Up @@ -918,60 +892,6 @@ def pp_val(arg) -> pp.Doc:
def _values_to_avals(vals) -> Sequence[core.ShapedArray]:
return tuple(core.raise_to_shaped(core.get_aval(v)) for v in vals)

### The id_tap_dep primitive
# The id_tap_dep_p primitive is used to create a dependency of the result of
# id_tap on the actual tap operation. This is only needed when the
# id_tap function is used with the `result` parameter. This primitive acts
# as the identity operator on the first argument.
#
# For example, given `id_tap(f, (a, b), result=(r, s)`, we convert this to
#
# a1, b1 = outside_call_p(f, a, b)
# r1 = id_tap_dep_p(r, a1)
# s1 = id_tap_dep_p(s, a1)
#
# There are always two arguments and the result is equal to the first.
id_tap_dep_p = core.Primitive("id_tap_dep")
id_tap_dep_p.multiple_results = False
id_tap_dep_p.def_impl(lambda r, _: r)
xla.register_translation(id_tap_dep_p,
lambda ctx, avals_in, avals_out, a_res, a_tap: [a_res])
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)

def _id_tap_dep_jvp_rule(primals, tangents):
if _HOST_CALLBACK_AD_TRANSFORMS.value:
assert False
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
return (id_tap_dep_p.bind(primals[0], primals[1]),
id_tap_dep_p.bind(tangents_instantiated[0], tangents_instantiated[1]))

ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule

def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
if _HOST_CALLBACK_AD_TRANSFORMS.value:
assert False
if ad.is_undefined_primal(arg_res):
ct_res = _instantiate_zeros(cts, arg_res)
else:
ct_res = None
if ad.is_undefined_primal(arg_tap):
ct_tap = ad.Zero(arg_tap.aval)
else:
ct_tap = None
return (ct_res, ct_tap)

ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule


def _id_tap_dep_batching_rule(batched_args, batch_dims):
if _HOST_CALLBACK_AD_TRANSFORMS.value:
assert False
arg_res, arg_tap = batched_args
return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0]


batching.primitive_batchers[id_tap_dep_p] = _id_tap_dep_batching_rule

### The outside_call primitive
"""
This primitive is used to implement the `call` and `id_tap` functions.
Expand Down Expand Up @@ -1439,82 +1359,12 @@ def _outside_call_jvp_rule(primals, tangents, **params):
assert "has_token" not in params
if not params["identity"]:
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
if _HOST_CALLBACK_AD_TRANSFORMS.value:
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))

arg_treedef = params["arg_treedef"]
# The argument to the jvp tap is a pair of the tapped primals and tangents
jvp_flat_args, jvp_arg_treedef = api.tree_flatten(
(arg_treedef.unflatten(primals),
arg_treedef.unflatten(tangents_instantiated)))
out_all = outside_call_p.bind(
*jvp_flat_args,
**dict(_add_transform(params, "jvp"),
arg_treedef=jvp_arg_treedef,
))
out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)])
return tuple(out_primals_tapped), tuple(out_tangents_tapped)
else:
out_primals_tapped = outside_call_p.bind(*primals, **params)
return tuple(out_primals_tapped), tangents
out_primals_tapped = outside_call_p.bind(*primals, **params)
return tuple(out_primals_tapped), tangents


ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule


def _outside_call_partial_eval_rule(trace, *args, **params):
# partial eval is used after jvp and before transpose.
if not _HOST_CALLBACK_AD_TRANSFORMS.value:
# TODO: just remote the partial eval rule
return trace.default_process_primitive(outside_call_p, args, params)
transforms = params.get("transforms", ())
if not transforms or transforms[-1] != ("jvp",):
# We are not in the process of computing VJP
return trace.default_process_primitive(outside_call_p, args, params)

# The args have been prepared by the id_tap_jvp_rule: primals, tangents. The
# result is a pair of the primal outputs and output tangents.
# One invariant that JAX requires is that if the primals arguments are known
# then the primal outputs must be known. So, if the primal arguments are known
# and some of the tangents are unknown, then we must split the tap into
# one for the primals (thus the output will be considered known), and a
# separate tap for the tangents.
assert "has_token" not in params
if not params["identity"]:
raise NotImplementedError("differentiation rules are implemented only for id_tap, not for call.")

assert len(args) % 2 == 0
nr_primals = len(args) // 2
primals, tangents = util.split_list(args, [nr_primals])
all_primals_known = all(p.is_known() for p in primals)
some_tangents_unknown = any(not t.is_known() for t in tangents)

if not (all_primals_known and some_tangents_unknown):
return trace.default_process_primitive(outside_call_p, args, params)

prims, _ = params["arg_treedef"].unflatten(args)
_, primals_treedef = api.tree_flatten(prims)

outs_known = trace.default_process_primitive(
outside_call_p, primals,
dict(params,
arg_treedef=primals_treedef,
transforms=transforms[:-1]))
# Now compute the unknowns using the whole tap, and merge them with the tapped ones
outs_all_unknown = trace.default_process_primitive(outside_call_p, args, params)
outs_primals_unknown, outs_tangents_unknown = util.split_list(
outs_all_unknown, [nr_primals])
outs_combined = (
[pe.JaxprTracer(trace, pe.PartialVal.known(primal_known),
primal_unknown.recipe)
for primal_known, primal_unknown in util.safe_zip(outs_known, outs_primals_unknown)] +
outs_tangents_unknown)
return tuple(outs_combined)


pe.custom_partial_eval_rules[outside_call_p] = _outside_call_partial_eval_rule


def _outside_call_transpose_rule(cts, *args, **params):
if not params["identity"]:
raise NotImplementedError("differentiation rules are implemented only for id_tap, not for call.")
Expand All @@ -1531,21 +1381,7 @@ def _outside_call_transpose_rule(cts, *args, **params):
*cts_instantiated,
**_add_transform(params, "transpose"))

if not _HOST_CALLBACK_AD_TRANSFORMS.value:
assert False

assert len(args) % 2 == 0
nr_primals = len(args) // 2

args_unflat, tan_unflat = params["arg_treedef"].unflatten(args)
_, vjp_arg_treedef = api.tree_flatten(args_unflat)
# We want to tap the cts_tapped_tangents
cts_primals, cts_tangents = util.split_list(cts_instantiated, [nr_primals])
cts_tangents_through_tap = outside_call_p.bind(
*cts_tangents,
**dict(_add_transform(params, "transpose"),
arg_treedef=vjp_arg_treedef))
return cts_primals + cts_tangents_through_tap
assert False


ad.primitive_transposes[outside_call_p] = _outside_call_transpose_rule
Expand Down

0 comments on commit ad15a38

Please sign in to comment.