From 14407b9a06fded75a0747035dc0661c7874e4e89 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 11 Dec 2023 21:10:29 +0000 Subject: [PATCH] Upgrade JAX debugging doc --- docs/tutorials/advanced-autodiff.md | 2 +- docs/tutorials/advanced-debugging.md | 16 ++ docs/tutorials/automatic-differentiation.md | 2 +- docs/tutorials/debugging.md | 165 +++++++++++++++++++- docs/tutorials/external-callbacks.md | 42 ++--- docs/tutorials/index.rst | 4 +- docs/tutorials/thinking-in-jax.md | 16 ++ 7 files changed, 219 insertions(+), 28 deletions(-) create mode 100644 docs/tutorials/advanced-debugging.md create mode 100644 docs/tutorials/thinking-in-jax.md diff --git a/docs/tutorials/advanced-autodiff.md b/docs/tutorials/advanced-autodiff.md index 7dd09576b042..f01176503a0e 100644 --- a/docs/tutorials/advanced-autodiff.md +++ b/docs/tutorials/advanced-autodiff.md @@ -11,5 +11,5 @@ For the time being, you may find some related content in the old documentation: - {doc}`../notebooks/Custom_derivative_rules_for_Python_code`. ``` -(defining-custom-derivative-rules)= +(advanced-autodiff-custom-derivative-rules)= ## Defining custom derivative rules diff --git a/docs/tutorials/advanced-debugging.md b/docs/tutorials/advanced-debugging.md new file mode 100644 index 000000000000..316d4470f85e --- /dev/null +++ b/docs/tutorials/advanced-debugging.md @@ -0,0 +1,16 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.15.2 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +(advanced-debugging)= +# Advanced debugging diff --git a/docs/tutorials/automatic-differentiation.md b/docs/tutorials/automatic-differentiation.md index 49486060fe2b..d37074ad8464 100644 --- a/docs/tutorials/automatic-differentiation.md +++ b/docs/tutorials/automatic-differentiation.md @@ -213,4 +213,4 @@ check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives ## Next steps -The {ref}`advanced-autodiff` tutorial provides more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as {ref}`defining-custom-derivative-rules`, depend on understanding advanced automatic differentiation, so do check out that section in the {ref}`advanced-autodiff` tutorial if you are interested. +The {ref}`advanced-autodiff` tutorial provides more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as {ref}`advanced-autodiff-custom-derivative-rules`, depend on understanding advanced automatic differentiation, so do check out that section in the {ref}`advanced-autodiff` tutorial if you are interested. diff --git a/docs/tutorials/debugging.md b/docs/tutorials/debugging.md index c2f5465305b4..3f2757b80661 100644 --- a/docs/tutorials/debugging.md +++ b/docs/tutorials/debugging.md @@ -1,9 +1,162 @@ -# Debugging +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.15.2 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- -```{note} -This is a placeholder for a section in the new {ref}`jax-tutorials`. +(debugging)= +# Debugging 101 -For the time being, you may find some related content in the old documentation: -- {doc}`../debugging/index` -- {doc}`../errors` +This tutorial introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations. + +Let's begin with {func}`jax.debug.print`. + +## JAX `debug.print` for high-level debugging + +**TL;DR** Here is a rule of thumb: + +- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others. +- Use Python `print` for static values, such as dtypes and array shapes. + +With some JAX transformations, such as {func}`jax.grad` and {func}`jax.vmap`, you can use Python’s built-in `print` function to print out numerical values. However, with {func}`jax.jit` for example, you need to use {func}`jax.debug.print`, because those transformations delay numerical evaluation. + +Below is a basic example with {func}`jax.jit`: + +```{code-cell} +import jax +import jax.numpy as jnp + +@jax.jit +def f(x): + jax.debug.print("This is `jax.debug.print` of x {x}", x=x) + y = jnp.sin(x) + jax.debug.print("This is `jax.debug.print` of y {y} 🤯", y=y) + return y + +f(2.) +``` + +{func}`jax.debug.print` can reveal the information about how computations are evaluated. + +Here's an example with {func}`jax.vmap`: + +```{code-cell} +def f(x): + jax.debug.print("This is `jax.debug.print` of x: {}", x) + y = jnp.sin(x) + jax.debug.print("This is `jax.debug.print` of y: {}", y) + return y + +xs = jnp.arange(3.) + +jax.vmap(f)(xs) +``` + +Here's an example with {func}`jax.lax.map`: + +```{code-cell} +jax.lax.map(f, xs) +``` + +Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect. + +Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's `print`, but it's consistent if you apply {func}`jax.jit` during the call. + +```{code-cell} +def f(x): + jax.debug.print("This is `jax.debug.print` of x: {}", x) + return x ** 2 + +jax.grad(f)(1.) ``` + +Sometimes, when the arguments don't depend on one another, calls to {func}`jax.debug.print` may print them in a different order when staged out with a JAX transformation. If you need the original order, such as `x: ...` first and then `y: ...` second, add the `ordered=True` parameter. + +For example: + +```{code-cell} +@jax.jit +def f(x, y): + jax.debug.print("This is `jax.debug.print of x: {}", x, ordered=True) + jax.debug.print("This is `jax.debug.print of y: {}", y, ordered=True) + return x + y +``` + +To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`. + + +## JAX `debug.breakpoint` for `pdb`-like debugging + +**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. + +To pause your compiled JAX program during certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack. + +To print all available commands during a `breakpoint` debugging session, use the `help` command. (Full debugger commands, the Sharp Bits, its strengths and limitations are covered in {ref}`advanced-debugging`.) + +Example: + +```{code-cell} +:tags: [raises-exception] + +def breakpoint_if_nonfinite(x): + is_finite = jnp.isfinite(x).all() + def true_fn(x): + pass + def false_fn(x): + jax.debug.breakpoint() + jax.lax.cond(is_finite, true_fn, false_fn, x) + +@jax.jit +def f(x, y): + z = x / y + breakpoint_if_nonfinite(z) + return z +f(2., 0.) # ==> Pauses during execution +``` + +![JAX debugger](../_static/debugger.gif) + +## JAX `debug.callback` for more control during debugging + +As mentioned in the beginning, {func}`jax.debug.print` is a small wrapper around {func}`jax.debug.callback`. The {func}`jax.debug.callback` method allows you to have greater control over string formatting and the debugging output, like printing or plotting. It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in {ref]`external-callbacks` for more information). + +For example: + +```{code-cell} +def log_value(x): + print("log:", x) + +@jax.jit +def f(x): + jax.debug.callback(log_value, x) + return x + +f(1.0); +``` + +This callback is compatible with {func}`jax.vmap` and {func}`jax.grad`: + +```{code-cell} +x = jnp.arange(5.0) +jax.vmap(f)(x); +``` + +```{code-cell} +jax.grad(f)(1.0); +``` + +This can make {func}`jax.debug.callback` useful for general-purpose debugging. + +You can learn more about different flavors of JAX callbacks in {ref}`external-callbacks-flavors-of-callback` and {ref}`external-callbacks-exploring-debug-callback`. + +## Next steps + +Check out the {ref}`advanced-debugging` to learn more about debugging in JAX. diff --git a/docs/tutorials/external-callbacks.md b/docs/tutorials/external-callbacks.md index 811c4356a719..4564e6aaff6d 100644 --- a/docs/tutorials/external-callbacks.md +++ b/docs/tutorials/external-callbacks.md @@ -15,7 +15,7 @@ kernelspec: (external-callbacks)= # External callbacks -This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, or another transformation. +This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`. ## Why callbacks? @@ -35,9 +35,9 @@ def f(x): result = f(2) ``` -What is printed is not the runtime value, but the trace-time abstract value (if you're not famililar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)). +What is printed is not the runtime value, but the trace-time abstract value (if you're not familiar with *tracing* in JAX, a good primer can be found in {ref}`thinking-in-jax`. -To print the value at runtime we need a callback, for example {func}`jax.debug.print`: +To print the value at runtime, you need a callback, for example {func}`jax.debug.print` (you can learn more about debugging in {ref}`debugging`): ```{code-cell} @jax.jit @@ -51,15 +51,16 @@ result = f(2) This works by passing the runtime value represented by `y` back to the host process, where the host can print the value. -## Flavors of Callback +(external-callbacks-flavors-of-callback)= +## Flavors of callback In earlier versions of JAX, there was only one kind of callback available, implemented in {func}`jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations: -- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect. +- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effects. - {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk. - {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler. -(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`). +(The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`). From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. @@ -232,7 +233,7 @@ jax.grad(f)(1.0); Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation. - +(external-callbacks-exploring-debug-callback)= ### Exploring `debug.callback` Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program. @@ -270,11 +271,12 @@ This can make `debug.callback` more useful for general-purpose debugging than ei ## Example: `pure_callback` with `custom_jvp` -One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`). -Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the {mod}`jax.scipy` or {mod}`jax.numpy` wrappers. +One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp`. (Refer to {ref}`advanced-autodiff-custom-derivative-rules` for more details on {func}`jax.custom_jvp`). + +Suppose you want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the {mod}`jax.scipy` or {mod}`jax.numpy` wrappers. Here, we'll consider creating a wrapper for the Bessel function of the first kind, available in {mod}`scipy.special.jv`. -We can start by defining a straightforward {func}`~jax.pure_callback`: +You can start by defining a straightforward {func}`~jax.pure_callback`: ```{code-cell} import jax @@ -300,7 +302,7 @@ def jv(v, z): shape=jnp.broadcast_shapes(v.shape, z.shape), dtype=z.dtype) - # We use vectorize=True because scipy.special.jv handles broadcasted inputs. + # You use vectorize=True because scipy.special.jv handles broadcasted inputs. return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True) ``` @@ -328,7 +330,7 @@ And here is the same result again with {func}`~jax.vmap`: print(jax.vmap(j1)(z)) ``` -However, if we call {func}`~jax.grad`, we see an error because there is no autodiff rule defined for this function: +However, if you call {func}`~jax.grad`, you will get an error because there is no autodiff rule defined for this function: ```{code-cell} :tags: [raises-exception] @@ -336,7 +338,7 @@ However, if we call {func}`~jax.grad`, we see an error because there is no autod jax.grad(j1)(z) ``` -Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`: +Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), you find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`: $$ d J_\nu(z) = \left\{ @@ -346,9 +348,9 @@ d J_\nu(z) = \left\{ \end{eqnarray}\right. $$ -The gradient with respect to $\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example. +The gradient with respect to $\nu$ is more complicated, but since we've restricted the `v` argument to integer types you don't need to worry about its gradient for the sake of this example. -We can use {func}`jax.custom_jvp` to define this automatic differentiation rule for our callback function: +You can use {func}`jax.custom_jvp` to define this automatic differentiation rule for your callback function: ```{code-cell} jv = jax.custom_jvp(jv) @@ -362,19 +364,21 @@ def _jv_jvp(primals, tangents): return jv(v, z), z_dot * djv_dz ``` -Now computing the gradient of our function will work correctly: +Now computing the gradient of your function will work correctly: ```{code-cell} j1 = partial(jv, 1) print(jax.grad(j1)(2.0)) ``` -Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free: +Further, since we've defined your gradient in terms of `jv` itself, JAX's architecture means that you get second-order and higher derivatives for free: ```{code-cell} jax.hessian(j1)(2.0) ``` -Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of {func}`scipy.special.jv` from the host back to the device. +Keep in mind that although this all works correctly with JAX, each call to your callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of {func}`scipy.special.jv` from the host back to the device. + When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called. -However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern is a relatively straightforward way extend JAX's capabilities. + +However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX's capabilities. diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 30fcd7dba73a..56bd9a4fe1f6 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -20,9 +20,11 @@ JAX 101 installation quickstart jax-as-accelerated-numpy + thinking-in-jax jit-compilation automatic-vectorization automatic-differentiation + debugging random-numbers working-with-pytrees single-host-sharding @@ -38,8 +40,8 @@ JAX 201 parallelism advanced-autodiff + advanced-debugging external-callbacks - debugging profiling-and-performance diff --git a/docs/tutorials/thinking-in-jax.md b/docs/tutorials/thinking-in-jax.md new file mode 100644 index 000000000000..41fd4261f887 --- /dev/null +++ b/docs/tutorials/thinking-in-jax.md @@ -0,0 +1,16 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.15.2 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +(thinking-in-jax)= +# Thinking in JAX