diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md
index ca838042a..7c177df67 100644
--- a/docs/nnx/nnx_basics.md
+++ b/docs/nnx/nnx_basics.md
@@ -33,7 +33,7 @@ import jax.numpy as jnp
 ```
 
 ## The Module System
-To begin lets see how to create a `Linear` Module using NNX. The main difference between
+To begin, let's see how to create a `Linear` Module using NNX. The main difference between
 NNX and Module systems like Haiku or Linen is that in NNX everything is **explicit**.
 This means among other things that 1) the Module itself holds the state (e.g. parameters)
 directly, 2) the RNG state is threaded by the user, and 3) all shape information must be provided on
@@ -43,7 +43,7 @@ As shown next, dynamic state is usually stored in `nnx.Param`s, and
 static state (all types not handled by NNX) such as integers or strings are stored
 directly. Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic
 state, although storing them inside `nnx.Variable`s such as `Param` is preferred.
-Also, the `nnx.Rngs` object by can be used to get new unique keys based on a root
+Also, the `nnx.Rngs` object can be used to get new unique keys based on a root
 key passed to the constructor.
 
 ```{code-cell} ipython3
@@ -58,16 +58,16 @@ class Linear(nnx.Module):
     return x @ self.w + self.b
 ```
 
-`nnx.Variable`'s inner values can be accessed using the `.value` property, however
-for convenience they implement all numeric operators and can be used directly in
-arithmetic expressions (as shown above). Additionally, Variables can passed
+`nnx.Variable`'s inner values can be accessed using the `.value` property. However,
+for convenience, they implement all numeric operators and can be used directly in
+arithmetic expressions (as shown above). Additionally, Variables can be passed
 to any JAX function as they implement the `__jax_array__` protocol (as long as
 their inner value is a JAX array).
 
-To actually initialize a Module you simply call the constructor, all the parameters
-of a Module are usually created eagerly. Since Modules hold their own state methods
-can be called directly without the need for a separate `apply` method, this is very
-convenient for debugging as entire structure of the model can be inspected directly.
+To actually initialize a Module, you simply call the constructor: all the parameters
+of a Module are usually created eagerly. Since Modules hold their own state, methods
+can be called directly without the need for a separate `apply` method. This is very
+convenient for debugging as the entire structure of the model can be inspected directly.
 
 ```{code-cell} ipython3
 model = Linear(2, 5, rngs=nnx.Rngs(params=0))
@@ -84,7 +84,7 @@ The above visualization by `nnx.display` is generated using the awesome [Penzai]
 ### Stateful Computation
 
 Implementing layers such as `BatchNorm` requires performing state updates during the
-forward pass. To implement this in NNX you just create a `Variable` and update its
+forward pass. To implement this in NNX, you just create a `Variable` and update its
 `.value` during the forward pass.
 
 ```{code-cell} ipython3
@@ -103,16 +103,16 @@ counter()
 print(f'{counter.count.value = }')
 ```
 
-Mutable references are usually avoided in JAX, however as we'll see in later sections
+Mutable references are usually avoided in JAX. However, as we'll see in later sections,
 NNX provides sound mechanisms to handle them.
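+
+As one more minimal sketch of this pattern, the purely illustrative `EMA` module below
+keeps an exponential moving average of its inputs in an `nnx.BatchStat` variable and
+updates it in place on every call. It assumes the `nnx` and `jnp` imports from the first
+cell; the name `EMA` and its `shape`/`decay` arguments are illustration only, not part of
+the NNX API.
+
+```{code-cell} ipython3
+class EMA(nnx.Module):
+  def __init__(self, shape: tuple, decay: float = 0.99):
+    self.decay = decay  # static state: a plain Python float
+    # dynamic state: a running statistic updated during the forward pass
+    self.mean = nnx.BatchStat(jnp.zeros(shape))
+
+  def __call__(self, x):
+    # update the statistic in place, then use it to center the batch
+    self.mean.value = self.decay * self.mean.value + (1 - self.decay) * x.mean(axis=0)
+    return x - self.mean.value
+
+ema = EMA(shape=(3,), decay=0.9)
+y = ema(jnp.ones((8, 3)))
+print(f'{ema.mean.value = }')
+```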
 
 +++
 
 ### Nested Modules
 
-As expected, Modules can be used to compose other Modules in a nested structure, these can
-be assigned directly as attributes, or inside an attribute of any (nested) pytree type e.g.
- `list`, `dict`, `tuple`, etc. In the example below we define a simple `MLP` Module that
+As expected, Modules can be used to compose other Modules in a nested structure. These can
+be assigned directly as attributes, or inside an attribute of any (nested) PyTree type e.g.
+ `list`, `dict`, `tuple`, etc. In the example below, we define a simple `MLP` Module that
 consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer.
 
 ```{code-cell} ipython3
@@ -134,16 +134,16 @@ y = model(x=jnp.ones((3, 2)))
 nnx.display(model)
 ```
 
-In NNX `Dropout` is a stateful module that stores an `Rngs` object so that it can generate
+In NNX, `Dropout` is a stateful module that stores an `Rngs` object so that it can generate
 new masks during the forward pass without the need for the user to pass a new key each time.
 
 +++
 
 #### Model Surgery
-NNX Modules are mutable by default, this means their structure can be changed at any time,
-this makes model surgery quite easy as any submodule attribute can be replaced with anything
-else e.g. new Modules, existing shared Modules, Modules of different types, etc. More over,
-`Variable`s can also be modified or replaced / shared.
+NNX Modules are mutable by default, meaning that their structure can be changed at any time.
+This makes model surgery quite easy, as any submodule attribute can be replaced with anything
+else e.g. new Modules, existing shared Modules, Modules of different types, etc. Moreover,
+`Variable`s can also be modified or replaced/shared.
 
 The following example shows how to replace the `Linear` layers in the `MLP` model from before
 with `LoraLinear` layers.
@@ -179,11 +179,11 @@ They are supersets of their equivalent JAX counterpart with the addition of
 being aware of the object's state and providing additional APIs to transform it.
 One of the main features of NNX Transforms is the preservation of reference semantics,
 meaning that any mutation of the object graph that occurs inside the transform is
-propagated outisde as long as its legal within the transform rules. In practice this
-means that NNX programs can be express using imperative code, highly simplifying
+propagated outside as long as it's legal within the transform rules. In practice, this
+means that NNX programs can be expressed using imperative code, greatly simplifying
 the user experience.
 
-In the following example we define a `train_step` function that takes a `MLP` model,
+In the following example, we define a `train_step` function that takes an `MLP` model,
 an `Optimizer`, and a batch of data, and returns the loss for that step. The loss
 and the gradients are computed using the `nnx.value_and_grad` transform over the
 `loss_fn`. The gradients are passed to the optimizer's `update` method to update
@@ -214,15 +214,15 @@ print(f'{loss = }')
 print(f'{optimizer.step.value = }')
 ```
 
-Theres a couple of things happening in this example that are worth mentioning:
-1. The updates to the `BatchNorm` and `Dropout` layer's state is automatically propagated
+There are a couple of things happening in this example that are worth mentioning:
+1. The updates to the `BatchNorm` and `Dropout` layers' state are automatically propagated
 from within `loss_fn` to `train_step` all the way to the `model` reference outside.
-2. `optimizer` holds a mutable reference to `model`, this relationship is preserved
+2. `optimizer` holds a mutable reference to `model`. This relationship is preserved
 inside the `train_step` function making it possible to update the model's parameters
 using the optimizer alone.
 
 #### Scan over layers
 
-Next lets take a look at a different example using `nnx.vmap` to create an
+Next, let's take a look at a different example using `nnx.vmap` to create an
 `MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the input
 (scan over layers).
@@ -256,7 +256,7 @@ nnx.display(model)
 ```
 
 How do NNX transforms achieve this? To understand how NNX objects interact with
-JAX transforms lets take a look at the Functional API.
+JAX transforms, let's take a look at the Functional API.
 
 +++
 
@@ -337,11 +337,11 @@ print(f'{model.count.value = }')
 
 The key insight of this pattern is that using mutable references
 is fine within a transform context (including the base eager interpreter)
-but its necessary to use the Functional API when crossing boundaries.
+but it's necessary to use the Functional API when crossing boundaries.
 
-**Why aren't Module's just Pytrees?** The main reason is that it is very
-easy to lose track of shared references by accident this way, for example
-if you pass two Module that have a shared Module through a JAX boundary
+**Why aren't Modules just Pytrees?** The main reason is that it is very
+easy to lose track of shared references by accident this way. For example,
+if you pass two Modules that have a shared Module through a JAX boundary,
 you will silently lose that sharing. The Functional API makes this behavior
 explicit, and thus it is much easier to reason about.
 
@@ -353,7 +353,7 @@ Seasoned Linen and Haiku users might recognize that having all the state in
 a single structure is not always the best choice as there are cases in which
 you might want to handle different subsets of the state differently. This a
 common occurrence when interacting with JAX transforms, for example, not all
-the model's state can or should be differentiated when interacting which `grad`,
+the model's state can or should be differentiated when interacting with `grad`,
 or sometimes there is a need to specify what part of the model's state
 is a carry and what part is not when using `scan`.
 
@@ -368,9 +368,9 @@ graphdef, params, counts = nnx.split(model, nnx.Param, Count)
 nnx.display(params, counts)
 ```
 
-Note that filters must be exhaustive, if a value is not matched an error will be raised.
+Note that filters must be exhaustive: If a value is not matched, an error will be raised.
 
-As expected the `merge` and `update` methods naturally consume multiple States:
+As expected, the `merge` and `update` methods naturally consume multiple States:
 
 ```{code-cell} ipython3
 # merge multiple States