Upgrade JAX debugging doc
8bitmp3 committed Dec 13, 2023
1 parent 384e29e commit 14407b9
Showing 7 changed files with 219 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/advanced-autodiff.md
Expand Up @@ -11,5 +11,5 @@ For the time being, you may find some related content in the old documentation:
- {doc}`../notebooks/Custom_derivative_rules_for_Python_code`.
```

(defining-custom-derivative-rules)=
(advanced-autodiff-custom-derivative-rules)=
## Defining custom derivative rules
16 changes: 16 additions & 0 deletions docs/tutorials/advanced-debugging.md
@@ -0,0 +1,16 @@
---
jupytext:
formats: md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3
language: python
name: python3
---

(advanced-debugging)=
# Advanced debugging
2 changes: 1 addition & 1 deletion docs/tutorials/automatic-differentiation.md
Expand Up @@ -213,4 +213,4 @@ check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives

## Next steps

The {ref}`advanced-autodiff` tutorial provides more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as {ref}`defining-custom-derivative-rules`, depend on understanding advanced automatic differentiation, so do check out that section in the {ref}`advanced-autodiff` tutorial if you are interested.
The {ref}`advanced-autodiff` tutorial provides more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as {ref}`advanced-autodiff-custom-derivative-rules`, depend on understanding advanced automatic differentiation, so do check out that section in the {ref}`advanced-autodiff` tutorial if you are interested.
165 changes: 159 additions & 6 deletions docs/tutorials/debugging.md
@@ -1,9 +1,162 @@
# Debugging
---
jupytext:
formats: md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3
language: python
name: python3
---

```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials`.

For the time being, you may find some related content in the old documentation:
- {doc}`../debugging/index`
- {doc}`../errors`
```

(debugging)=
# Debugging 101
This tutorial introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations.

Let's begin with {func}`jax.debug.print`.

## JAX `debug.print` for high-level debugging

**TL;DR** Here is a rule of thumb:

- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
- Use Python `print` for static values, such as dtypes and array shapes.
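As a quick illustration of this rule of thumb, here is a small sketch (the function `f` below is a hypothetical example, not part of the JAX API) contrasting Python's `print` for static values with {func}`jax.debug.print` for traced ones:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Static values (shapes, dtypes) are known at trace time, so Python's
    # built-in print works for them:
    print("shape:", x.shape, "dtype:", x.dtype)
    # The traced value itself is abstract at trace time; use jax.debug.print
    # to see the runtime value:
    jax.debug.print("runtime value of x: {}", x)
    return x * 2

f(jnp.arange(3.0))
```

Note that the Python `print` runs once, at trace time, while `jax.debug.print` runs every time the compiled function executes.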

With some JAX transformations, such as {func}`jax.grad` and {func}`jax.vmap`, you can use Python’s built-in `print` function to print out numerical values. However, with {func}`jax.jit` for example, you need to use {func}`jax.debug.print`, because those transformations delay numerical evaluation.

Below is a basic example with {func}`jax.jit`:

```{code-cell}
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    jax.debug.print("This is `jax.debug.print` of x {x}", x=x)
    y = jnp.sin(x)
    jax.debug.print("This is `jax.debug.print` of y {y} 🤯", y=y)
    return y

f(2.)
```

{func}`jax.debug.print` can reveal information about how computations are evaluated.

Here's an example with {func}`jax.vmap`:

```{code-cell}
def f(x):
    jax.debug.print("This is `jax.debug.print` of x: {}", x)
    y = jnp.sin(x)
    jax.debug.print("This is `jax.debug.print` of y: {}", y)
    return y

xs = jnp.arange(3.)

jax.vmap(f)(xs)
```

Here's an example with {func}`jax.lax.map`:

```{code-cell}
jax.lax.map(f, xs)
```

Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.

Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints during the forward pass. In this case, the behavior is similar to Python's `print`, but it remains consistent even if you apply {func}`jax.jit` to the call.

```{code-cell}
def f(x):
    jax.debug.print("This is `jax.debug.print` of x: {}", x)
    return x ** 2

jax.grad(f)(1.)
```

Sometimes, when the arguments don't depend on one another, calls to {func}`jax.debug.print` may print them in a different order when staged out with a JAX transformation. If you need the original order, such as `x: ...` first and then `y: ...` second, add the `ordered=True` parameter.

For example:

```{code-cell}
@jax.jit
def f(x, y):
    jax.debug.print("This is `jax.debug.print` of x: {}", x, ordered=True)
    jax.debug.print("This is `jax.debug.print` of y: {}", y, ordered=True)
    return x + y

f(1, 2)
```

To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`.


## JAX `debug.breakpoint` for `pdb`-like debugging

**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values.

To pause your compiled JAX program at certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python's `pdb`, and it allows you to inspect values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack.

To print all available commands during a `breakpoint` debugging session, use the `help` command. (The full list of debugger commands, along with the sharp bits, strengths, and limitations, is covered in {ref}`advanced-debugging`.)

Example:

```{code-cell}
:tags: [raises-exception]
def breakpoint_if_nonfinite(x):
    is_finite = jnp.isfinite(x).all()
    def true_fn(x):
        pass
    def false_fn(x):
        jax.debug.breakpoint()
    jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
    z = x / y
    breakpoint_if_nonfinite(z)
    return z

f(2., 0.)  # ==> Pauses during execution
```

![JAX debugger](../_static/debugger.gif)

## JAX `debug.callback` for more control during debugging

As mentioned earlier, {func}`jax.debug.print` is a small wrapper around {func}`jax.debug.callback`. The {func}`jax.debug.callback` method allows you greater control over string formatting and the debugging output, such as printing or plotting. It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in {ref}`external-callbacks` for more information).

For example:

```{code-cell}
def log_value(x):
    print("log:", x)

@jax.jit
def f(x):
    jax.debug.callback(log_value, x)
    return x

f(1.0);
```

This callback is compatible with {func}`jax.vmap` and {func}`jax.grad`:

```{code-cell}
x = jnp.arange(5.0)
jax.vmap(f)(x);
```

```{code-cell}
jax.grad(f)(1.0);
```

This can make {func}`jax.debug.callback` useful for general-purpose debugging.
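For instance, because {func}`jax.debug.callback` runs an arbitrary Python function on the host, you can use it to collect runtime values into a regular Python list for later inspection. (The `record` helper and `collected` list below are hypothetical, for illustration only; `jax.effects_barrier` is used to wait for any pending callbacks.)

```python
import jax
import jax.numpy as jnp

collected = []

def record(x):
    # Runs on the host, so any Python side effect is allowed.
    collected.append(float(x))

@jax.jit
def step(x):
    jax.debug.callback(record, x)
    return x + 1

for i in range(3):
    step(jnp.float32(i))

jax.effects_barrier()  # make sure all pending callbacks have run
print(collected)
```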

You can learn more about different flavors of JAX callbacks in {ref}`external-callbacks-flavors-of-callback` and {ref}`external-callbacks-exploring-debug-callback`.

## Next steps

Check out {ref}`advanced-debugging` to learn more about debugging in JAX.
42 changes: 23 additions & 19 deletions docs/tutorials/external-callbacks.md
Expand Up @@ -15,7 +15,7 @@ kernelspec:
(external-callbacks)=
# External callbacks

This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, or another transformation.
This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, and {func}`~jax.grad`.

## Why callbacks?

Expand All @@ -35,9 +35,9 @@ def f(x):
result = f(2)
```

What is printed is not the runtime value, but the trace-time abstract value (if you're not famililar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)).
What is printed is not the runtime value, but the trace-time abstract value (if you're not familiar with *tracing* in JAX, a good primer can be found in {ref}`thinking-in-jax`).

To print the value at runtime we need a callback, for example {func}`jax.debug.print`:
To print the value at runtime, you need a callback, for example {func}`jax.debug.print` (you can learn more about debugging in {ref}`debugging`):

```{code-cell}
@jax.jit
Expand All @@ -51,15 +51,16 @@ result = f(2)

This works by passing the runtime value represented by `y` back to the host process, where the host can print the value.

## Flavors of Callback
(external-callbacks-flavors-of-callback)=
## Flavors of callback

In earlier versions of JAX, there was only one kind of callback available, implemented in {func}`jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:

- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect.
- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effects.
- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk.
- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.

(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`).
(The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`).
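As a minimal sketch of the `io_callback` flavor (the `host_log` function and `log` list below are hypothetical, for illustration only — `io_callback` is the one flavor that both permits side effects and returns a value to the program):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

log = []

def host_log(x):
    # Impure: records the value on the host and passes it through unchanged.
    log.append(np.asarray(x))
    return x

@jax.jit
def f(x):
    # The second argument describes the shape/dtype of the callback's result.
    x = io_callback(host_log, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return x * 2

f(jnp.float32(3.0))
```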

From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.

Expand Down Expand Up @@ -232,7 +233,7 @@ jax.grad(f)(1.0);
Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation.



(external-callbacks-exploring-debug-callback)=
### Exploring `debug.callback`

Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program.
Expand Down Expand Up @@ -270,11 +271,12 @@ This can make `debug.callback` more useful for general-purpose debugging than ei

## Example: `pure_callback` with `custom_jvp`

One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).
Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the {mod}`jax.scipy` or {mod}`jax.numpy` wrappers.
One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp`. (Refer to {ref}`advanced-autodiff-custom-derivative-rules` for more details on {func}`jax.custom_jvp`).

Suppose you want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the {mod}`jax.scipy` or {mod}`jax.numpy` wrappers.

Here, we'll consider creating a wrapper for the Bessel function of the first kind, available in {mod}`scipy.special.jv`.
We can start by defining a straightforward {func}`~jax.pure_callback`:
You can start by defining a straightforward {func}`~jax.pure_callback`:

```{code-cell}
import jax
Expand All @@ -300,7 +302,7 @@ def jv(v, z):
shape=jnp.broadcast_shapes(v.shape, z.shape),
dtype=z.dtype)
# We use vectorize=True because scipy.special.jv handles broadcasted inputs.
# You use vectorize=True because scipy.special.jv handles broadcasted inputs.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
```

Expand Down Expand Up @@ -328,15 +330,15 @@ And here is the same result again with {func}`~jax.vmap`:
print(jax.vmap(j1)(z))
```

However, if we call {func}`~jax.grad`, we see an error because there is no autodiff rule defined for this function:
However, if you call {func}`~jax.grad`, you will get an error because there is no autodiff rule defined for this function:

```{code-cell}
:tags: [raises-exception]
jax.grad(j1)(z)
```

Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:
Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), you find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:

$$
d J_\nu(z) = \left\{
Expand All @@ -346,9 +348,9 @@ d J_\nu(z) = \left\{
\end{eqnarray}\right.
$$

The gradient with respect to $\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.
The gradient with respect to $\nu$ is more complicated, but since you've restricted the `v` argument to integer types you don't need to worry about its gradient for the sake of this example.

We can use {func}`jax.custom_jvp` to define this automatic differentiation rule for our callback function:
You can use {func}`jax.custom_jvp` to define this automatic differentiation rule for your callback function:

```{code-cell}
jv = jax.custom_jvp(jv)
Expand All @@ -362,19 +364,21 @@ def _jv_jvp(primals, tangents):
return jv(v, z), z_dot * djv_dz
```

Now computing the gradient of our function will work correctly:
Now computing the gradient of your function will work correctly:

```{code-cell}
j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
```

Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free:
Further, since you've defined your gradient in terms of `jv` itself, JAX's architecture means that you get second-order and higher derivatives for free:

```{code-cell}
jax.hessian(j1)(2.0)
```

Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of {func}`scipy.special.jv` from the host back to the device.
Keep in mind that although this all works correctly with JAX, each call to your callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of {func}`scipy.special.jv` from the host back to the device.

When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.
However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern is a relatively straightforward way extend JAX's capabilities.

However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX's capabilities.
4 changes: 3 additions & 1 deletion docs/tutorials/index.rst
Expand Up @@ -20,9 +20,11 @@ JAX 101
installation
quickstart
jax-as-accelerated-numpy
thinking-in-jax
jit-compilation
automatic-vectorization
automatic-differentiation
debugging
random-numbers
working-with-pytrees
single-host-sharding
Expand All @@ -38,8 +40,8 @@ JAX 201

parallelism
advanced-autodiff
advanced-debugging
external-callbacks
debugging
profiling-and-performance


Expand Down
16 changes: 16 additions & 0 deletions docs/tutorials/thinking-in-jax.md
@@ -0,0 +1,16 @@
---
jupytext:
formats: md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3
language: python
name: python3
---

(thinking-in-jax)=
# Thinking in JAX
