From 7edc80d63561ed614ae78ed6cfbeb52ad87d88fd Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Wed, 13 Dec 2023 22:49:54 +0000 Subject: [PATCH] Update thinking-in-jax working-with-pytrees --- docs/tutorials/thinking-in-jax.md | 25 +++++++++++++++---------- docs/tutorials/working-with-pytrees.md | 22 +++++++++++----------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/docs/tutorials/thinking-in-jax.md b/docs/tutorials/thinking-in-jax.md index 3ada021b56d4..2a0e0ee5010c 100644 --- a/docs/tutorials/thinking-in-jax.md +++ b/docs/tutorials/thinking-in-jax.md @@ -45,7 +45,7 @@ y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp) plt.plot(x_jnp, y_jnp); ``` -The code blocks are identical aside from replacing `np` with `jnp`, and the results are the same. JAX arrays can often be used directly in place of NumPy arrays for things like plotting. +The code blocks are identical aside from replacing NumPy (`np`) with JAX NumPy (`jnp`), and the results are the same. JAX arrays can often be used directly in place of NumPy arrays for things like plotting. The arrays themselves are implemented as different Python types: @@ -99,7 +99,7 @@ print(y) - `jax.Array` is the default array implementation in JAX. - The JAX array is a unified distributed datatype for representing arrays, even with physical storage spanning multiple devices -- Automatic parallelization: You can operate over sharded `jax.Array`s without copying data onto a device using the `jax.jit` transformation. You can also replicate a `jax.Array` to every device on a mesh. +- Automatic parallelization: You can operate over sharded `jax.Array`s without copying data onto a device using the {func}`jax.jit` transformation. You can also replicate a `jax.Array` to every device on a mesh. Consider this simple example: @@ -127,7 +127,7 @@ The `jax.Array` type also helps make parallelism a core feature of JAX. 
JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — they are called JAX pytrees (also known as nests, or just trees). In the context of machine learning, a pytree can contain model parameters, dataset entries, and reinforcement learning agent observations. -Below is an example of a simple pytree. In JAX, you can use `jax.tree_*`, to extract the flattened leaves from the trees, as demonstrated here: +Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree_util.tree_leaves` to extract the flattened leaves from the trees, as demonstrated here: ```{code-cell} example_trees = [ @@ -153,8 +153,8 @@ You can learn more in the {ref}`working-with-pytrees` tutorial. **Key concepts:** -- `jax.numpy` is a high-level wrapper that provides a familiar interface. -- `jax.lax` is a lower-level API that is stricter and often more powerful. +- {mod}`jax.numpy` is a high-level wrapper that provides a familiar interface. +- {mod}`jax.lax` is a lower-level API that is stricter and often more powerful. - All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) — the Accelerated Linear Algebra compiler. If you look at the source of {mod}`jax.numpy`, you'll see that all the operations are eventually expressed in terms of functions defined in {mod}`jax.lax`. You can think of {mod}`jax.lax` as a stricter, but often more powerful, API for working with multi-dimensional arrays. @@ -218,7 +218,7 @@ Every JAX operation is eventually expressed in terms of these fundamental XLA op The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently. 
-For example, consider this function that normalizes the rows of a 2D matrix, expressed in terms of `jax.numpy` operations: +For example, consider this function that normalizes the rows of a 2D matrix, expressed in terms of {mod}`jax.numpy` operations: ```{code-cell} import jax.numpy as jnp @@ -281,7 +281,7 @@ This is because the function generates an array whose shape is not known at comp - Variables that you don't want to be traced can be marked as *static* -To use `jax.jit` effectively, it is useful to understand how it works. Let's put a few `print()` statements within a JIT-compiled function and then call the function: +To use {func}`jax.jit` effectively, it is useful to understand how it works. Let's put a few `print()` statements within a JIT-compiled function and then call the function: ```{code-cell} @jit @@ -300,7 +300,7 @@ f(x, y) Notice that the print statements execute, but rather than printing the data you passed to the function, though, it prints *tracer* objects that stand-in for them. -These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code. +These tracer objects are what {func}`jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code. 
When you call the compiled function again on matching inputs, no recompilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python: @@ -310,7 +310,7 @@ y2 = np.random.randn(4) f(x2, y2) ``` -The extracted sequence of operations is encoded in a JAX expression, or *jaxpr* for short. You can view the jaxpr using the `jax.make_jaxpr` transformation: +The extracted sequence of operations is encoded in a JAX expression, or *jaxpr* for short. You can view the jaxpr using the {func}`jax.make_jaxpr` transformation: ```python from jax import make_jaxpr @@ -395,7 +395,12 @@ f(x) Notice that although `x` is traced, `x.shape` is a static value. However, when you use {func}`jnp.array` and {func}`jnp.prod` on this static value, it becomes a traced value, at which point it cannot be used in a function like `reshape()` that requires a static input (recall: array shapes must be static). -A useful pattern is to use `numpy` for operations that should be static (i.e. done at compile-time), and use `jax.numpy` for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this: +A useful pattern is to: + +- Use NumPy (`numpy`) for operations that should be static (i.e., done at compile-time); and +- Use JAX NumPy (`jax.numpy`) for operations that should be traced (i.e., compiled and executed at run-time). 
+ +For this function, it might look like this: ```{code-cell} from jax import jit diff --git a/docs/tutorials/working-with-pytrees.md b/docs/tutorials/working-with-pytrees.md index 646944124cf5..3f05ec74ccaf 100644 --- a/docs/tutorials/working-with-pytrees.md +++ b/docs/tutorials/working-with-pytrees.md @@ -155,7 +155,7 @@ def update(params, x, y): (pytrees-custom-pytree-nodes)= ## Custom pytree nodes -This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {meth}`jax.tree_util.register_pytree_node` with {func}`jax.tree_map`. +This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {func}`jax.tree_util.register_pytree_node` with {func}`jax.tree_map`. Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else as pytree leaves. This is because if you define your own container class, it will be considered to be a pytree leaf unless you _register_ it with JAX. This is also the case even if your container class has trees inside it. For example: @@ -186,7 +186,7 @@ except TypeError as e: As a solution, JAX allows to extend the set of types to be considered internal pytree nodes through a global registry of types. Additionally, the values of registered types are traversed recursively. -First, register a new type using {meth}`jax.tree_util.register_pytree_node`: +First, register a new type using {func}`jax.tree_util.register_pytree_node`: ```{code-cell} from jax.tree_util import register_pytree_node @@ -269,11 +269,11 @@ Notice that the `name` field now appears as a leaf, because all tuple elements a (pytree-and-jax-transformations)= ## Pytree and JAX's transformations -Many JAX functions, like {meth}`jax.lax.scan`, operate over pytrees of arrays. 
In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays. +Many JAX functions, like {func}`jax.lax.scan`, operate over pytrees of arrays. In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays. -Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (such as the `in_axes` and `out_axes` arguments to {func}`jax,vmap`). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees. +Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (such as the `in_axes` and `out_axes` arguments to {func}`jax.vmap`). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees. -For example, if you pass the following input to {func}`jax,vmap` (note that the input arguments to a function are considered a tuple): +For example, if you pass the following input to {func}`jax.vmap` (note that the input arguments to a function are considered a tuple): ``` (a1, {"k1": a2, "k2": a3}) @@ -287,7 +287,7 @@ then you can use the following `in_axes` pytree to specify that only the `k2` ar The optional parameter pytree structure must match that of the main input pytree. 
However, the optional parameters can optionally be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree. -For example, if you have the same {func}`jax,vmap` input as above, but wish to only map over the dictionary argument, you can use: +For example, if you have the same {func}`jax.vmap` input as above, but wish to only map over the dictionary argument, you can use: ``` (None, 0) # equivalent to (None, {"k1": 0, "k2": 0}) @@ -299,7 +299,7 @@ Alternatively, if you want every argument to be mapped, you can write a single l 0 ``` -This happens to be the default `in_axes` value for {func}`jax,vmap`. +This happens to be the default `in_axes` value for {func}`jax.vmap`. The same logic applies to other optional parameters that refer to specific input or output values of a transformed function, such as `out_axes` in {func}`jax.vmap`. @@ -312,9 +312,9 @@ For built-in pytree node types, the set of keys for any pytree node instance is JAX has the following `jax.tree_util.*` methods for working with key paths: -- {meth}`jax.tree_util.tree_flatten_with_path`: Works similarly to {meth}`jax.tree_util.tree_flatten`, but returns key paths. -- {meth}`jax.tree_util.tree_map_with_path``: Works similarly to {meth}`jax.tree_util.tree_map`, but the function also takes key paths as arguments. -- {meth}`jax.tree_util.keystr`: Given a general key path, returns a reader-friendly string expression. +- {func}`jax.tree_util.tree_flatten_with_path`: Works similarly to {func}`jax.tree_util.tree_flatten`, but returns key paths. +- {func}`jax.tree_util.tree_map_with_path`: Works similarly to {func}`jax.tree_util.tree_map`, but the function also takes key paths as arguments. +- {func}`jax.tree_util.keystr`: Given a general key path, returns a reader-friendly string expression. 
For example, one use case is to print debugging information related to a certain leaf value: @@ -336,7 +336,7 @@ To express key paths, JAX provides a few default key types for the built-in pytr * `DictKey(key: Hashable)`: For dictionaries. * `GetAttrKey(name: str)`: For `namedtuple`s and preferably custom pytree nodes (more in the next section) -You are free to define your own key types for your custom nodes. They will work with {meth}`jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression. +You are free to define your own key types for your custom nodes. They will work with {func}`jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression. ```{code-cell} for key_path, _ in flattened: