# External Callbacks in JAX

This guide is a work-in-progress outlining the uses of JAX's callback functions, which allow JAX code to run Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation.

*TODO(jakevdp, sharadmv): fill-in some simple examples of {func}`jax.pure_callback`, {func}`jax.debug.callback`, {func}`jax.debug.print`, and others.*
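As a first illustration of the idea described above, here is a minimal sketch of a host callback running under `jit`. The names `host_sin` and `f` are hypothetical and chosen only for this example. It assumes that `jax.debug.print` and `jax.pure_callback` are available in your installed JAX version with the signatures used here.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Runs on the host with concrete array values, not tracers.
    # (host_sin is a hypothetical helper for this sketch.)
    return np.sin(np.asarray(x)).astype(x.dtype)

@jax.jit
def f(x):
    # jax.debug.print formats runtime values, even inside jit.
    jax.debug.print("x = {x}", x=x)
    # pure_callback needs the output shape and dtype declared up front.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_type, x)

x = jnp.arange(3.0)
result = f(x)
print(result)
```

The callback must be a pure function of its inputs, since JAX is free to elide or repeat calls to it.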

## Example: `pure_callback` with `custom_jvp`

One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).
Suppose we want to create a JAX-compatible wrapper for a SciPy or NumPy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers.

Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`.
We can start by defining a straightforward `pure_callback`:

In [None]:
import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # We use vectorized=True because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`:

In [None]:
from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)

In [None]:
print(j1(z))



[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]


Here is the same result with `jit`:

In [None]:
print(jax.jit(j1)(z))

[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]


And here is the same result again with `vmap`:

In [None]:
print(jax.vmap(j1)(z))

[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]


However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function:

In [None]:
jax.grad(j1)(z)

ValueError: ignored

Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:

$$
\frac{d J_\nu(z)}{dz} = \begin{cases}
-J_1(z), & \nu = 0 \\
[J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2, & \nu \ne 0
\end{cases}
$$

The gradient with respect to $\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.
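Before wiring this rule into JAX, it can be reassuring to check it numerically. The sketch below uses only the Python standard library. The helpers `bessel_j`, `d_bessel_j`, and `central_difference` are hypothetical names introduced for this check. It evaluates $J_n(z)$ from its power series, which is adequate only for moderate $|z|$, and compares the recurrence against a central finite difference.

```python
import math

def bessel_j(n, z, terms=40):
    """Power-series approximation of J_n(z) for integer n (moderate |z| only)."""
    if n < 0:
        # For integer order, J_{-n}(z) = (-1)^n J_n(z).
        return (-1) ** (-n) * bessel_j(-n, z, terms)
    return sum(
        (-1) ** k / (math.factorial(k) * math.factorial(k + n))
        * (z / 2) ** (2 * k + n)
        for k in range(terms)
    )

def d_bessel_j(n, z):
    """Derivative of J_n with respect to z, using the recurrence above."""
    if n == 0:
        return -bessel_j(1, z)
    return 0.5 * (bessel_j(n - 1, z) - bessel_j(n + 1, z))

def central_difference(n, z, h=1e-5):
    """Finite-difference estimate of dJ_n/dz, used as an independent check."""
    return (bessel_j(n, z + h) - bessel_j(n, z - h)) / (2 * h)

for n in (0, 1, 2):
    print(n, d_bessel_j(n, 2.0), central_difference(n, 2.0))
```

The two columns should agree to within the finite-difference error for each order.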

We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function:

In [None]:
jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz

Now computing the gradient of our function will work correctly:

In [None]:
j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))

-0.06447162


Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free:

In [None]:
jax.hessian(j1)(2.0)

DeviceArray(-0.4003078, dtype=float32, weak_type=True)
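
As a sanity check on this value, the derivative recurrence can be applied twice by hand. This is a worked sketch, not part of the original example. The value $J_1(2) \approx 0.5767248$ is taken from the output printed earlier, and $J_3(2) \approx 0.1289432$ is a standard tabulated value assumed here.

$$
\begin{aligned}
J_1''(z) &= \tfrac{1}{2}\left[J_0'(z) - J_2'(z)\right]
          = \tfrac{1}{2}\left[-J_1(z) - \tfrac{1}{2}\left(J_1(z) - J_3(z)\right)\right] \\
         &= -\tfrac{3}{4} J_1(z) + \tfrac{1}{4} J_3(z), \\
J_1''(2) &\approx -\tfrac{3}{4}(0.5767248) + \tfrac{1}{4}(0.1289432) \approx -0.4003078 .
\end{aligned}
$$

This agrees with the second derivative computed by `jax.hessian`.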

Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device.
When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.
However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX's capabilities.