Update nnx_basics.md #4050

Open · wants to merge 2 commits into base: main
68 changes: 34 additions & 34 deletions docs/nnx/nnx_basics.md
```{code-cell} ipython3
from flax import nnx
import jax
import jax.numpy as jnp
```

## The Module System
To begin, let's see how to create a `Linear` Module using NNX. The main difference between
NNX and Module systems like Haiku or Linen is that in NNX everything is **explicit**. This
means, among other things, that 1) the Module itself holds the state (e.g. parameters) directly,
2) the RNG state is threaded by the user, and 3) all shape information must be provided on
As shown next, dynamic state is usually stored in `nnx.Param`s, and static state
(all types not handled by NNX) such as integers or strings are stored directly.
Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic
state, although storing them inside `nnx.Variable`s such as `Param` is preferred.
Also, the `nnx.Rngs` object can be used to get new unique keys based on a root
key passed to the constructor.
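To illustrate the idea of deriving unique keys from a root key, here is a toy pure-Python sketch (the `KeyStream` class below is hypothetical and hash-based, not the actual `nnx.Rngs`, which wraps `jax.random` keys), showing that each call on a named stream yields a fresh but deterministic key:

```python
import hashlib

class KeyStream:
    """Toy analogue of an Rngs-like object: derives fresh, deterministic
    keys from a root seed plus a per-stream call counter."""

    def __init__(self, root: int):
        self.root = root
        self.counts = {}  # stream name -> number of keys issued so far

    def __call__(self, stream: str = "params") -> bytes:
        n = self.counts.get(stream, 0)
        self.counts[stream] = n + 1
        # fold the stream name and the call count into the root seed
        data = f"{self.root}/{stream}/{n}".encode()
        return hashlib.sha256(data).digest()

rngs = KeyStream(root=0)
k1 = rngs("params")
k2 = rngs("params")
assert k1 != k2                      # every call yields a fresh key
assert KeyStream(0)("params") == k1  # but the derivation is deterministic
```

The real mechanism differs, but the contract is the same: a single root key passed to the constructor is enough to generate an unbounded, reproducible sequence of unique keys per named stream.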

```{code-cell} ipython3
class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b
```

`nnx.Variable`'s inner values can be accessed using the `.value` property. However,
for convenience, they implement all numeric operators and can be used directly in
arithmetic expressions (as shown above). Additionally, Variables can be passed
to any JAX function as they implement the `__jax_array__` protocol (as long as their
inner value is a JAX array).
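As a rough pure-Python analogue of this wrapper behavior (the `Box` class below is hypothetical, not the actual `nnx.Variable` implementation), a container can expose `.value` while forwarding numeric operators to it:

```python
class Box:
    """Toy stand-in for a Variable-like wrapper: holds a value, exposes
    it via .value, and forwards numeric operators to the inner value."""

    def __init__(self, value):
        self.value = value

    @staticmethod
    def _unwrap(other):
        return other.value if isinstance(other, Box) else other

    def __add__(self, other):
        return self.value + Box._unwrap(other)

    def __radd__(self, other):
        return Box._unwrap(other) + self.value

    def __mul__(self, other):
        return self.value * Box._unwrap(other)

b = Box(3)
assert b.value + 1 == 4    # explicit access via .value
assert b + 4 == 7          # operator forwarding
assert 2 + b == 5          # reflected operator
assert Box(2) + Box(3) == 5
```

`nnx.Variable` implements the full numeric protocol (and `__jax_array__`), so the same convenience extends to JAX functions.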

To actually initialize a Module, you simply call the constructor: all the parameters
of a Module are usually created eagerly. Since Modules hold their own state, methods
can be called directly without the need for a separate `apply` method. This is very
convenient for debugging as the entire structure of the model can be inspected directly.

```{code-cell} ipython3
model = Linear(2, 5, rngs=nnx.Rngs(params=0))
nnx.display(model)
```

The above visualization by `nnx.display` is generated using the awesome [Penzai](https://github.com/google-deepmind/penzai) library.
### Stateful Computation

Implementing layers such as `BatchNorm` requires performing state updates during the
forward pass. To implement this in NNX, you just create a `Variable` and update its
`.value` during the forward pass.

```{code-cell} ipython3
class Count(nnx.Variable): pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))

  def __call__(self):
    self.count.value += 1

counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')
```

Mutable references are usually avoided in JAX. However, as we'll see in later sections,
NNX provides sound mechanisms to handle them.

+++

### Nested Modules

As expected, Modules can be used to compose other Modules in a nested structure. These can
be assigned directly as attributes, or inside an attribute of any (nested) pytree type, e.g.
`list`, `dict`, `tuple`, etc. In the example below, we define a simple `MLP` Module that
consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer.

```{code-cell} ipython3
class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.relu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

model = MLP(2, 16, 5, rngs=nnx.Rngs(0))

y = model(x=jnp.ones((3, 2)))
nnx.display(model)
```

In NNX, `Dropout` is a stateful module that stores an `Rngs` object so that it can generate
new masks during the forward pass without the need for the user to pass a new key each time.
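The design can be sketched in plain Python (the `ToyDropout` class below is a hypothetical illustration, not `nnx.Dropout`): the module owns its RNG state, so callers never thread keys through the forward pass.

```python
import random

class ToyDropout:
    """Toy sketch of a dropout layer that owns its RNG: each call
    advances the internal random state and draws a fresh mask."""

    def __init__(self, rate: float, seed: int):
        self.rate = rate
        self.rng = random.Random(seed)  # internal state, advances per call

    def __call__(self, xs):
        keep = 1.0 - self.rate
        # zero each element with probability `rate`; rescale survivors
        # by 1/keep so the expected value is unchanged (inverted dropout)
        return [x / keep if self.rng.random() < keep else 0.0 for x in xs]

drop = ToyDropout(rate=0.5, seed=0)
a = drop([1.0] * 8)
b = drop([1.0] * 8)  # a new mask is drawn; no key passed by the caller
```

Because the RNG state lives in the module, the sequence of masks is still fully reproducible from the seed, which is the same property `Rngs` gives NNX modules.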

+++

#### Model Surgery
NNX Modules are mutable by default, meaning that their structure can be changed at any time.
This makes model surgery quite easy, as any submodule attribute can be replaced with anything
else, e.g. new Modules, existing shared Modules, Modules of different types, etc. Moreover,
`Variable`s can also be modified or replaced/shared.

The following example shows how to replace the `Linear` layers in the `MLP` model
from before with `LoraLinear` layers.
## Transforms

NNX Transforms are supersets of their equivalent JAX counterparts with the addition of
being aware of the object's state and providing additional APIs to transform
it. One of the main features of NNX Transforms is the preservation of reference semantics,
meaning that any mutation of the object graph that occurs inside the transform is
propagated outside as long as it's legal within the transform rules. In practice, this
means that NNX programs can be expressed using imperative code, greatly simplifying
the user experience.

In the following example, we define a `train_step` function that takes an `MLP` model,
an `Optimizer`, and a batch of data, and returns the loss for that step. The loss
and the gradients are computed using the `nnx.value_and_grad` transform over the
`loss_fn`. The gradients are passed to the optimizer's `update` method to update
the model's parameters.

```{code-cell} ipython3
import optax

optimizer = nnx.Optimizer(model, optax.adamw(1e-3))

@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  return loss

x, y = jnp.ones((5, 2)), jnp.ones((5, 5))
loss = train_step(model, optimizer, x, y)

print(f'{loss = }')
print(f'{optimizer.step.value = }')
```

There are a couple of things happening in this example that are worth mentioning:
1. The updates to the `BatchNorm` and `Dropout` layers' state are automatically propagated
from within `loss_fn` to `train_step` all the way to the `model` reference outside.
2. `optimizer` holds a mutable reference to `model`. This relationship is preserved
inside the `train_step` function, making it possible to update the model's parameters
using the optimizer alone.
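The shared-reference relationship in point 2 can be illustrated with a pure-Python sketch (the `ToyModel` and `ToySGD` classes below are hypothetical, not `nnx.Optimizer`): because the optimizer stores the model object itself rather than a copy, stepping the optimizer mutates the parameters visible through the original reference.

```python
class ToyModel:
    def __init__(self):
        self.w = 1.0  # a single "parameter"

class ToySGD:
    """Holds a reference to the model, so calling update() mutates
    the model's parameters in place through that shared reference."""

    def __init__(self, model: ToyModel, lr: float):
        self.model = model  # shared reference, not a copy
        self.lr = lr
        self.step = 0

    def update(self, grad_w: float):
        self.model.w -= self.lr * grad_w
        self.step += 1

model = ToyModel()
opt = ToySGD(model, lr=0.5)
opt.update(grad_w=1.0)
assert model.w == 0.5      # visible through the original reference
assert opt.model is model  # same object, not a copy
assert opt.step == 1
```

NNX transforms preserve exactly this kind of aliasing when both objects are passed through `train_step`, which is why updating via the optimizer alone is sufficient.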

#### Scan over layers
Next, let's take a look at a different example using `nnx.vmap` to create an
`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the
input (scan over layers).

```{code-cell} ipython3
nnx.display(model)
```

How do NNX transforms achieve this? To understand how NNX objects interact with
JAX transforms, let's take a look at the Functional API.

+++

```{code-cell} ipython3
print(f'{model.count.value = }')
```

The key insight of this pattern is that using mutable references is
fine within a transform context (including the base eager interpreter)
but it's necessary to use the Functional API when crossing boundaries.

**Why aren't Modules just Pytrees?** The main reason is that it is very
easy to lose track of shared references by accident this way. For example,
if you pass two Modules that have a shared Module through a JAX boundary,
you will silently lose that sharing. The Functional API makes this
behavior explicit, and thus it is much easier to reason about.
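A pure-Python sketch of the failure mode (mimicking, not using, a JAX boundary crossing, where each structure's leaves get re-materialized independently): copying two containers separately duplicates their shared child, while handling them as one graph preserves the sharing.

```python
import copy

class Block:
    """Stand-in for a shared submodule."""
    def __init__(self, name):
        self.name = name

# two "modules" that share the same sub-block
shared = Block("embed")
left = {"child": shared}
right = {"child": shared}
assert left["child"] is right["child"]  # one object, referenced twice

# crossing a boundary per-object: each copy rebuilds its own child,
# and the sharing is silently lost
left2 = copy.deepcopy(left)
right2 = copy.deepcopy(right)
assert left2["child"] is not right2["child"]

# crossing as a single graph: the shared child is rebuilt only once
both = copy.deepcopy({"left": left, "right": right})
assert both["left"]["child"] is both["right"]["child"]
```

This is the intuition behind the Functional API: `split`/`merge` operate on the whole object graph at once, so shared references survive the round trip instead of being duplicated leaf by leaf.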

Seasoned Linen and Haiku users might recognize that having all the state in
a single structure is not always the best choice as there are cases in which
you might want to handle different subsets of the state differently. This is a
common occurrence when interacting with JAX transforms, for example, not all
the model's state can or should be differentiated when interacting with `grad`,
or sometimes there is a need to specify what part of the model's state is a
carry and what part is not when using `scan`.
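The partitioning logic can be sketched in plain Python (the `split_state` function and the tagged-tuple state representation below are hypothetical, not NNX's actual filter machinery): each filter claims the entries it matches, and an unmatched entry is an error.

```python
def split_state(state, *filters):
    """Toy analogue of filtering state into groups: each filter is a
    predicate; every entry must match some filter, or we raise."""
    groups = [{} for _ in filters]
    for path, value in state.items():
        for group, pred in zip(groups, filters):
            if pred(path, value):
                group[path] = value
                break
        else:
            raise ValueError(f"no filter matched state entry {path!r}")
    return groups

# state entries tagged with a kind, mimicking Variable subtypes
state = {
    "linear/w": ("param", [1.0, 2.0]),
    "linear/b": ("param", [0.0]),
    "counter/count": ("count", 3),
}

is_param = lambda path, v: v[0] == "param"
is_count = lambda path, v: v[0] == "count"

params, counts = split_state(state, is_param, is_count)
assert set(params) == {"linear/w", "linear/b"}
assert set(counts) == {"counter/count"}
```

Dropping `is_count` from the call would raise, mirroring NNX's requirement that filters be exhaustive.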

```{code-cell} ipython3
# split the state into Params and Counts
graphdef, params, counts = nnx.split(model, nnx.Param, Count)
nnx.display(params, counts)
```

Note that filters must be exhaustive: if a value is not matched, an error will be raised.

As expected, the `merge` and `update` methods naturally consume multiple States:

```{code-cell} ipython3
# merge multiple States
model = nnx.merge(graphdef, params, counts)
```