Skip to content

Commit

Permalink
Clarified the type of the inputs to callback APIs
Browse files Browse the repository at this point in the history
The callback APIs were migrated to use jax.Arrays for both inputs and outputs
in JAX 0.4.27.

PiperOrigin-RevId: 634473890
  • Loading branch information
superbobry authored and jax authors committed May 16, 2024
1 parent 380503b commit 01194bd
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/_tutorials/external-callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def f(x):
result = f(2)
```

This works by passing the runtime value represented by `y` back to the host process, where the host can print the value.
This works by passing the runtime value of `y` as a CPU {class}`jax.Array` back to the host process, where the host can print it.

(external-callbacks-flavors-of-callback)=
## Flavors of callback
Expand Down
23 changes: 18 additions & 5 deletions jax/_src/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,14 @@ def pure_callback_impl(
vectorized: bool,
):
del sharding, vectorized, result_avals
cpu_device, *_ = jax.local_devices(backend="cpu")
try:
cpu_device, *_ = jax.local_devices(backend="cpu")
except RuntimeError as e:
raise RuntimeError(
"jax.pure_callback failed to find a local CPU device to place the"
" inputs on. Make sure \"cpu\" is listed in --jax_platforms or the"
" JAX_PLATFORMS environment variable."
) from e
args = jax.device_put(args, cpu_device)
with jax.default_device(cpu_device):
try:
Expand Down Expand Up @@ -262,9 +269,8 @@ def pure_callback(
For more explanation, see `External Callbacks`_.
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
should also return NumPy arrays. Execution takes place on CPU, like any
Python+NumPy function.
The input ``callback`` will be passed JAX arrays placed on a local CPU, and
it should also return JAX arrays on CPU.
The callback is treated as functionally pure, meaning it has no side-effects
and its output value depends only on its argument values. As a consequence, it
Expand Down Expand Up @@ -357,7 +363,14 @@ def io_callback_impl(
ordered: bool,
):
del result_avals, sharding, ordered
cpu_device, *_ = jax.local_devices(backend="cpu")
try:
cpu_device, *_ = jax.local_devices(backend="cpu")
except RuntimeError as e:
raise RuntimeError(
"jax.io_callback failed to find a local CPU device to place the"
" inputs on. Make sure \"cpu\" is listed in --jax_platforms or the"
" JAX_PLATFORMS environment variable."
) from e
args = jax.device_put(args, cpu_device)
with jax.default_device(cpu_device):
try:
Expand Down
11 changes: 9 additions & 2 deletions jax/_src/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,20 @@ class OrderedDebugEffect(effects.Effect):
def debug_callback_impl(*args, callback: Callable[..., Any],
effect: DebugEffect):
del effect
cpu_device, *_ = jax.local_devices(backend="cpu")
try:
cpu_device, *_ = jax.local_devices(backend="cpu")
except RuntimeError as e:
raise RuntimeError(
"jax.debug.callback failed to find a local CPU device to place the"
" inputs on. Make sure \"cpu\" is listed in --jax_platforms or the"
" JAX_PLATFORMS environment variable."
) from e
args = jax.device_put(args, cpu_device)
with jax.default_device(cpu_device):
try:
callback(*args)
except BaseException:
logger.exception("jax.debug_callback failed")
logger.exception("jax.debug.callback failed")
raise
return ()

Expand Down

0 comments on commit 01194bd

Please sign in to comment.