DOC: one last readthrough of the new 101 tutorials
jakevdp committed Apr 17, 2024
1 parent d44b16c commit 48e8457
Showing 10 changed files with 101 additions and 84 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/advanced-autodiff.md
@@ -13,7 +13,7 @@ kernelspec:
---

(advanced-autodiff)=
# Advanced automatic differentiation 201
# Advanced automatic differentiation

In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful.

12 changes: 6 additions & 6 deletions docs/tutorials/automatic-differentiation.md
@@ -13,9 +13,9 @@ kernelspec:
---

(automatic-differentiation)=
# Automatic differentiation 101
# Automatic differentiation

In this tutorial, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general automatic differentiation (autodiff) system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:
In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general automatic differentiation (autodiff) system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:

- {ref}`automatic-differentiation-taking-gradients`
- {ref}`automatic-differentiation-linear-logistic-regression`
@@ -28,9 +28,9 @@ Make sure to also check out the {ref}`advanced-autodiff` tutorial for more advan
While understanding how automatic differentiation works "under the hood" isn't crucial for using JAX in most contexts, you are encouraged to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on.

(automatic-differentiation-taking-gradients)=
## 1.Taking gradients with `jax.grad`
## 1. Taking gradients with `jax.grad`

In JAX, you can differentiate a function with the {func}`jax.grad` transformation:
In JAX, you can differentiate a scalar-valued function with the {func}`jax.grad` transformation:

```{code-cell}
import jax
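import jax.numpy as jnp

# -- Editorial sketch; the rest of this cell is collapsed in the diff. --
# jax.grad transforms a scalar-valued function into a function that
# computes its gradient (the function below is illustrative, not the
# tutorial's own example).
def sum_of_squares(x):
    return jnp.sum(x ** 2)

grad_fn = jax.grad(sum_of_squares)
print(grad_fn(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]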
@@ -162,7 +162,7 @@ Essentially, when using the `argnums` argument, if `f` is a Python function for
(automatic-differentiation-nested-lists-tuples-and-dicts)=
## 3. Differentiating with respect to nested lists, tuples, and dicts

Due to JAX's PyTree abstraction (see {ref}`pytrees-what-is-a-pytree`), differentiating with
Due to JAX's PyTree abstraction (see {ref}`working-with-pytrees`), differentiating with
respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.

Continuing the previous example:
@@ -176,7 +176,7 @@ def loss2(params_dict):
print(grad(loss2)({'W': W, 'b': b}))
```
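
For context, here is a self-contained sketch of the pattern in the (partially collapsed) cell above — the data, the parameters `W` and `b`, and the loss definition are reconstructions for illustration, not the tutorial's exact code:

```python
import jax
import jax.numpy as jnp

# Hypothetical inputs and parameters, for illustration only.
inputs = jnp.array([[0.5, 1.5], [1.0, -1.0]])
targets = jnp.array([1.0, -1.0])
W = jnp.array([0.5, -0.5])
b = 0.5

def loss2(params_dict):
    preds = inputs @ params_dict['W'] + params_dict['b']
    return jnp.mean((preds - targets) ** 2)

# The gradient comes back as a dict with the same structure as the input.
print(jax.grad(loss2)({'W': W, 'b': b}))
```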

You can {ref}`pytrees-custom-pytree-nodes` to work with not just {func}`jax.grad` but other JAX transformations ({func}`jax.jit`, {func}`jax.vmap`, and so on).
You can create {ref}`pytrees-custom-pytree-nodes` to work with not just {func}`jax.grad` but other JAX transformations ({func}`jax.jit`, {func}`jax.vmap`, and so on).


(automatic-differentiation-evaluating-using-jax-value_and_grad)=
11 changes: 6 additions & 5 deletions docs/tutorials/automatic-vectorization.md
@@ -13,11 +13,12 @@ kernelspec:
---

(automatic-vectorization)=
# Automatic Vectorization in JAX
# Automatic vectorization

In the previous section we discussed JIT compilation via the `jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via `jax.vmap`.
In the previous section we discussed JIT compilation via the {func}`jax.jit` function.
This notebook discusses another of JAX's transforms: vectorization via {func}`jax.vmap`.

## Manual Vectorization
## Manual vectorization

Consider the following simple code that computes the convolution of two one-dimensional vectors:

@@ -72,9 +73,9 @@ def manually_vectorized_convolve(xs, ws):
manually_vectorized_convolve(xs, ws)
```

Such re-implementation is messy and error-prone; fortunately JAX provides another way.
Such re-implementation can be messy and error-prone as the complexity of a function increases; fortunately JAX provides another way.

## Automatic Vectorization
## Automatic vectorization

In JAX, the {func}`jax.vmap` transformation is designed to generate such a vectorized implementation of a function automatically:
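
The tutorial's own cells are collapsed in this diff; below is a minimal sketch of the pattern it describes, with a windowed `convolve` reconstructed here for illustration:

```python
import jax
import jax.numpy as jnp

# A windowed 1D convolution like the one described earlier in this
# tutorial (reconstructed; the original cells are collapsed in the diff).
def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

xs = jnp.stack([jnp.arange(5.0), jnp.arange(5.0)])
ws = jnp.stack([jnp.array([2.0, 3.0, 4.0])] * 2)

# jax.vmap maps convolve over the leading (batch) axis of both arguments,
# producing the same result as the manual re-implementation.
auto_batch_convolve = jax.vmap(convolve)
print(auto_batch_convolve(xs, ws))
```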

12 changes: 6 additions & 6 deletions docs/tutorials/debugging.md
@@ -13,22 +13,22 @@ kernelspec:
---

(debugging)=
# Debugging 101
# Introduction to debugging

This tutorial introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations.
This section introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations.

Let's begin with {func}`jax.debug.print`.

## JAX `debug.print` for high-level debugging
## JAX `debug.print` for high-level

**TL;DR** Here is a rule of thumb:

- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
- Use Python `print` for static values, such as dtypes and array shapes.
- Use Python {func}`print` for static values, such as dtypes and array shapes.

Recall from {ref}`jit-compilation` that when transforming a function with {func}`jax.jit`,
the Python code is executed with abstract tracers in place of your arrays. Because of this,
the Python `print` statement will only print this tracer value:
the Python {func}`print` function will only print this tracer value:

```{code-cell}
import jax
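import jax.numpy as jnp

# -- Editorial sketch; the original cell is collapsed in this diff. --
# Under jax.jit, Python's print sees an abstract tracer at trace time,
# while jax.debug.print reports the runtime value on every call.
@jax.jit
def f(x):
    print("print(x) ->", x)                     # a traced, abstract value
    jax.debug.print("debug.print(x) -> {}", x)  # the concrete runtime value
    return x * 2

f(jnp.arange(3.0))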
@@ -82,7 +82,7 @@ result = jax.lax.map(f, xs)

Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.

Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's `print`, but it's consistent if you apply {func}`jax.jit` during the call.
Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's {func}`print`, but it's consistent if you apply {func}`jax.jit` during the call.

```{code-cell}
def f(x):
2 changes: 1 addition & 1 deletion docs/tutorials/jax-primitives.md
@@ -13,7 +13,7 @@ kernelspec:
---

(jax-internals-jax-primitives)=
# JAX internals 301: JAX primitives
# JAX Internals: primitives

## Introduction to JAX primitives

2 changes: 1 addition & 1 deletion docs/tutorials/jaxpr.md
@@ -13,7 +13,7 @@ kernelspec:
---

(jax-internals-jaxpr)=
# JAX internals 301: The jaxpr language
# JAX internals: The jaxpr language

Jaxprs are JAX’s internal intermediate representation (IR) of programs. They are explicitly typed, functional, first-order, and in A-normal form (ANF).

47 changes: 26 additions & 21 deletions docs/tutorials/jit-compilation.md
@@ -29,10 +29,10 @@ compilation of a JAX Python function so it can be executed efficiently in XLA.
## How JAX transformations work

In the previous section, we discussed that JAX allows us to transform Python functions.
This is done by first converting the Python function into a simple intermediate language called jaxpr.
The transformations then work on the jaxpr representation.
JAX accomplishes this by reducing each function into a sequence of {term}`primitive` operations, each
representing one fundamental unit of computation.

We can show a representation of the jaxpr of a function by using {func}`jax.make_jaxpr`:
One way to see the sequence of primitives behind a function is using {func}`jax.make_jaxpr`:

```{code-cell}
import jax
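import jax.numpy as jnp

# Reconstructed for readability (the cell body is collapsed in this diff).
# Per the surrounding prose, log2 carries a Python side-effect
# (global_list.append) that will not be recorded in the jaxpr.
global_list = []

def log2(x):
    global_list.append(x)
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2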
@@ -51,9 +51,11 @@ print(jax.make_jaxpr(log2)(3.0))

The {ref}`understanding-jaxprs` section of the documentation provides more information on the meaning of the above output.

Importantly, note how the jaxpr does not capture the side-effect of the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free (a.k.a. functionally pure) code. If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`.
This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code.
If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).

Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour once converted to jaxpr. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JAX-transformed function to run once (during the first call), and never again. This is because of the way that JAX generates jaxpr, using a process called 'tracing'.
Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour under transformations. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JIT-compiled function to run once (during the first call), and never again, due to JAX's traced execution model.

When tracing, JAX wraps each argument by a *tracer* object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself.
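
A minimal sketch of this tracing behavior (an editorial addition, separate from the tutorial's own cells):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
    print("tracing with:", x)  # a Python side-effect: fires during tracing
    return jnp.dot(x + 1, y + 1)

x = jnp.ones((2, 2))
f(x, x)  # first call: traces the function, so the print fires once
f(x, x)  # same shapes/dtypes: the cached jaxpr is reused, no print
```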

@@ -73,7 +75,8 @@ See how the printed `x` is a `Traced` object? That's the JAX internals at work.

The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn't be relied upon. However, it's useful to understand as you can use it when debugging to print out intermediate values of a computation.

A key thing to understand is that jaxpr captures the function as executed on the parameters given to it. For example, if we have a conditional, jaxpr will only know about the branch we take:
A key thing to understand is that a jaxpr captures the function as executed on the parameters given to it.
For example, if we have a Python conditional, the jaxpr will only know about the branch we take:

```{code-cell}
def log2_if_rank_2(x):
@@ -143,8 +146,7 @@ def f(x):
else:
return 2 * x
f_jit = jax.jit(f)
f_jit(10) # Should raise an error.
jax.jit(f)(10) # Raises an error
```

```{code-cell}
@@ -158,19 +160,17 @@ def g(x, n):
i += 1
return x + i
g_jit = jax.jit(g)
g_jit(10, 20) # Should raise an error.
jax.jit(g)(10, 20) # Raises an error
```

The problem is that we tried to condition on the *value* of an input to the function being jitted. The reason we can't do this is related to the fact mentioned above that jaxpr depends on the actual values used to trace it.
The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values.
Traced values within JIT, like `x` and `n` here, can only affect control flow via their static attributes, such as
`shape` or `dtype`, and not via their values.
For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
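
For example (an editorial sketch, not the tutorial's own cell, of branching on a static attribute, which is allowed):

```python
import jax
import jax.numpy as jnp

@jax.jit
def log2_or_identity(x):
    # x.ndim is static information, so Python control flow on it is fine;
    # branching on x's *values* here would raise an error at trace time.
    if x.ndim == 2:
        return jnp.log2(x)
    return x

print(log2_or_identity(jnp.ones((2, 2))))
print(log2_or_identity(jnp.arange(3.0)))
```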

The more specific information about the values we use in the trace, the more we can use standard Python control flow to express ourselves. However, being too specific means we can't reuse the same traced function for other values. JAX solves this by tracing at different levels of abstraction for different purposes.

For {func}`jax.jit`, the default level is {class}`~jax.core.ShapedArray` -- that is, each tracer has a concrete shape (which we're allowed to condition on), but no concrete value. This allows the compiled function to work on all possible inputs with the same shape -- the standard use case in machine learning. However, because the tracers have no concrete value, if we attempt to condition on one, we get the error above.

In {func}`jax.grad`, the constraints are more relaxed, so you can do more. If you compose several transformations, however, you must satisfy the constraints of the most strict one. So, if you `jit(grad(f))`, `f` mustn't condition on value. For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).

One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is impossible. In that case, you can consider jitting only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):
One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is not possible or practical.
In that case, you can consider JIT-compiling only part of the function.
For example, if the most computationally expensive part of the function is inside the loop, we can JIT-compile just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):

```{code-cell}
# While loop conditioned on x and n with a jitted body.
@@ -188,7 +188,11 @@ def g_inner_jitted(x, n):
g_inner_jitted(10, 20)
```

If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values.
## Marking arguments as static

If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`.
The cost of this is that the resulting jaxpr and compiled artifact depends on the particular value passed, and so JAX will have to re-compile the function for every new value of the specified static input.
It is only a good strategy if the function is guaranteed to see a limited set of static values.
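
A brief sketch of this tradeoff (an editorial addition; the function here is hypothetical):

```python
import jax

def apply_n_times(x, n):
    for _ in range(n):  # `n` drives Python control flow, so it must be static
        x = x * 2
    return x

apply_jit = jax.jit(apply_n_times, static_argnames=['n'])
print(apply_jit(1.0, n=3))   # traces and compiles for n=3
print(apply_jit(2.0, n=3))   # n unchanged: reuses the cached compilation
print(apply_jit(1.0, n=4))   # new static value: triggers a re-compile
```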

```{code-cell}
f_jit_correct = jax.jit(f, static_argnums=0)
@@ -227,9 +231,10 @@ print("g:")
%timeit g(10, 20)
```

This is because {func}`jax.jit` introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you will run it numerous times. Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.
This is because {func}`jax.jit` introduces some overhead itself, and so it usually only saves time if the compiled function is nontrivial, or if you will run it numerous times.
Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.

Generally, you want to jit the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise.
Generally, you want to JIT-compile the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise.

## JIT and caching

7 changes: 3 additions & 4 deletions docs/tutorials/key-concepts.md
@@ -20,10 +20,9 @@ This section briefly introduces some key concepts of the JAX package.
(key-concepts-jax-arrays)=
## JAX arrays ({class}`jax.Array`)

- `jax.Array` is the default array implementation in JAX.
- `jax.Array` objects are never created directly, but rather using familiar
array creation APIs.
- JAX arrays may be stored on a single device, or sharded across many devices.
The default array implementation in JAX is {class}`jax.Array`. In many ways it is similar to
the {class}`numpy.ndarray` type that you may be familiar with from the NumPy package, but it
has some important differences.
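
One such difference, sketched below, is that JAX arrays are immutable (an editorial example; `.at[...].set(...)` is JAX's functional-update API):

```python
import jax.numpy as jnp

x = jnp.arange(5)      # created via a familiar NumPy-style API
# x[0] = 10            # in-place mutation is not supported and would raise
y = x.at[0].set(10)    # functional update: returns a new array instead
print(x, y)
```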

### Array creation
