Skip to content

Commit

Permalink
Expose pure callback and enable rank polymorphic callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Aug 17, 2022
1 parent 8068e46 commit 393bca1
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 60 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -14,6 +14,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports
the `concrete` option, following the previous version's deprecation; see
[JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
* Changes
* Added {func}`jax.pure_callback` that enables calling back to pure Python functions from functions compiled with `jax.jit` or `jax.pmap`.

## jax 0.3.16
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
Expand Down
1 change: 1 addition & 0 deletions jax/__init__.py
Expand Up @@ -100,6 +100,7 @@
pmap as pmap,
process_count as process_count,
process_index as process_index,
pure_callback as pure_callback,
pxla, # TODO(phawkins): update users to avoid this.
remat as remat,
ShapedArray as ShapedArray,
Expand Down
45 changes: 45 additions & 0 deletions jax/_src/api.py
Expand Up @@ -44,6 +44,7 @@
treedef_is_leaf, treedef_children,
Partial, PyTreeDef, all_leaves, treedef_tuple)

from jax._src import callback as jcb
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
Expand Down Expand Up @@ -3230,6 +3231,50 @@ def try_to_block(x):
return x
return jax.tree_util.tree_map(try_to_block, x)

def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
*args: Any, **kwargs: Any):
"""Calls a pure Python callback function from staged out JAX programs.
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
should also return NumPy arrays. Execution takes place on the CPU host.
The callback is treated as "pure" meaning it can be called multiple times when
transformed (for example in a ``vmap`` or ``pmap``), and it can also
potentially be removed from JAX programs via dead-code elimination. Pure
callbacks can also be reordered if data-dependence allows.
When ``pmap``-ed, the pure callback will be called several times (one on each axis
of the map). When `vmap`-ed the behavior will depend on the value of the
``rank_polymorphic`` keyword argument. If the callback is indicated as rank
polymorphic, the callback will be called directly on batched inputs (where the
batch axis is the leading dimension). Additionally, the callbacks should
return outputs that also have a leading batch axis. If not rank polymorphic,
``callback`` will be mapped sequentially across the batched axis.
Args:
callback: A Python callable. The callable will be passed in NumPy arrays and
should return a PyTree of NumPy arrays that matches
``result_shape_dtypes``.
result_shape_dtypes: A PyTree of Python objects that have ``shape`` and
``dtype`` properties that correspond to the shape and dtypes of the
outputs of ``callback``.
*args: The positional arguments to the callback. Must be PyTrees of JAX
types.
rank_polymorphic: A boolean that indicates whether or not ``callback`` is
rank polymorphic, meaning it can handle arrays with additional leading
dimensions. If ``rank_polymorphic`` is `True`, when the callback is mapped
via `jax.vmap`, it will be called directly on inputs with leading batch
dimensions instead of executing ``callback`` on each mapped input
individually. The callback should also return outputs batched across the
leading axis.
**kwargs: The keyword arguments to the callback. Must be PyTrees of JAX
types.
Returns:
The value of ``callback(*args, **kwargs)``.
"""
return jcb.pure_callback(callback, result_shape_dtypes, *args, **kwargs)

def clear_backends():
"""
Expand Down
74 changes: 29 additions & 45 deletions jax/_src/callback.py
Expand Up @@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for JAX callbacks."""
from __future__ import annotations

import functools

from typing import Any, Callable
from typing import Any, Callable, Sequence

from jax import core
from jax import lax
from jax import tree_util
from jax._src import lib as jaxlib
from jax._src import util
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
import jax.numpy as jnp


# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
Expand All @@ -35,15 +35,16 @@


@pure_callback_p.def_impl
def pure_callback_impl(*args, result_avals, callback: Callable[..., Any]):
del result_avals
def pure_callback_impl(*args, result_avals, callback: Callable[..., Any],
rank_polymorphic: bool):
del rank_polymorphic, result_avals
return callback(*args)


@pure_callback_p.def_abstract_eval
def pure_callback_abstract_eval(*avals, callback: Callable[..., Any],
result_avals):
del avals, callback
result_avals, rank_polymorphic: bool):
del avals, callback, rank_polymorphic
return result_avals


Expand All @@ -66,13 +67,26 @@ def pure_callback_transpose_rule(*args, **kwargs):
ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule


def pure_callback_batching_rule(args, dims, *, callback, **params):
def pure_callback_batching_rule(args, dims, *, callback, rank_polymorphic: bool,
result_avals: Sequence[core.ShapedArray]):
axis_size = next(a.shape[0] for a, d in zip(args, dims)
if d is not batching.not_mapped)
new_args = []
for arg, dim in zip(args, dims):
new_args.append(jnp.rollaxis(arg, dim))
outvals = lax.map(
functools.partial(pure_callback_p.bind, callback=callback, **params),
*new_args)
new_args.append(batching.moveaxis(arg, dim, 0))
if rank_polymorphic:
result_avals = tuple(
core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore
for aval in result_avals)
outvals = pure_callback_p.bind(
*new_args, callback=callback, rank_polymorphic=rank_polymorphic,
result_avals=result_avals)
else:
from jax._src.lax.control_flow import map as lax_map
outvals = lax_map(
functools.partial(pure_callback_p.bind, callback=callback,
rank_polymorphic=rank_polymorphic, result_avals=result_avals),
*new_args)
return tuple(outvals), (0,) * len(outvals)


Expand Down Expand Up @@ -101,38 +115,7 @@ def _callback(*flat_args):


def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
*args: Any, **kwargs: Any):
"""Calls a pure Python callback function from staged out JAX programs.
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
should also return NumPy arrays. Execution takes place on the CPU host.
The callback is treated as "pure" meaning it can be called multiple times when
transformed (for example in a ``vmap`` or ``pmap``), and it can also
potentially be removed from JAX programs via dead-code elimination. Pure
callbacks can also be reordered if data-dependence allows.
When both `pmap` and `vmap`-ed, the pure callback will be called several times
(one on each axis of the map). In the `pmap` case, these calls will happen
across several threads whereas in the `vmap` case, they will happen serially.
Args:
callback: A Python callable. The callable will be passed in NumPy arrays and
should return a PyTree of NumPy arrays that matches
``result_shape_dtypes``.
result_shape_dtypes: A PyTree of Python objects that have ``shape`` and
``dtype`` properties that correspond to the shape and dtypes of the
outputs of ``callback``.
*args: The positional arguments to the callback. Must be PyTrees of JAX
types.
**kwargs: The keyword arguments to the callback. Must be PyTrees of JAX
types.
Returns:
The value of ``callback(*args, **kwargs)``.
"""

*args: Any, rank_polymorphic: bool = False, **kwargs: Any):
def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
return tree_util.tree_leaves(callback(*args, **kwargs))
Expand All @@ -142,5 +125,6 @@ def _flat_callback(*flat_args):
lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)
flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
out_flat = pure_callback_p.bind(
*flat_args, callback=_flat_callback, result_avals=flat_result_avals)
*flat_args, callback=_flat_callback,
result_avals=tuple(flat_result_avals), rank_polymorphic=rank_polymorphic)
return tree_util.tree_unflatten(out_tree, out_flat)
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -1154,6 +1154,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"getslice",
"full_to_shard",
"shard_to_full",
"pure_callback",

# Not high priority?
"after_all",
Expand Down

0 comments on commit 393bca1

Please sign in to comment.