# Why NNX?

Four years ago we developed the Flax "Linen" API to support modeling research on JAX, with a focus on scaling and performance. We've learned a lot from our users over these years.

We introduced some ideas that have proven to be good:
 - Organizing variables into [collections](https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections) or types to support JAX transforms and segregation of different data types in training loops.
 - Automatic and efficient [PRNG management](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) (with support for splitting/broadcast control across map transforms)
 - [Variable Metadata](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) for SPMD annotations, optimizer metadata, and other uses.

One choice we made was to use functional "define by call" semantics for NN programming via the lazy (i.e. just-in-time) initialization of parameters. This made for concise (`compact`) implementation code and allowed for a single specification when transforming a layer. It also brought our API closer to Haiku. However, lazy initialization meant that the semantics of variables in Flax were non-pythonic and often surprising. It also led to implementation complexity and obscured the core ideas of transformations on neural nets.
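To make the contrast concrete, here is a minimal sketch of the two initialization styles using plain Python objects and `jax.numpy`. `LazyDense` and `EagerDense` are hypothetical classes, not Flax APIs; Linen's real lazy initialization runs through functional `init`/`apply` calls rather than a stateful object, so this only approximates the user-visible difference: with lazy initialization, parameter shapes are unknown until the first input arrives.

```python
import jax
import jax.numpy as jnp


class LazyDense:
  """Hypothetical define-by-call layer: parameters are created on the first
  call, once the input feature size is known."""

  def __init__(self, features, key):
    self.features, self.key = features, key
    self.params = None  # nothing exists yet

  def __call__(self, x):
    if self.params is None:  # the kernel shape depends on the first input
      kernel = jax.random.normal(self.key, (x.shape[-1], self.features))
      self.params = {'kernel': kernel}
    return x @ self.params['kernel']


class EagerDense:
  """Hypothetical constructor-initialized layer: parameters exist after __init__."""

  def __init__(self, din, dout, key):
    self.kernel = jax.random.normal(key, (din, dout))

  def __call__(self, x):
    return x @ self.kernel


key = jax.random.PRNGKey(0)

lazy = LazyDense(8, key)
print(lazy.params)  # None: weights cannot be inspected before a forward pass
lazy(jnp.ones((2, 4)))
print(lazy.params['kernel'].shape)  # (4, 8): the shape came from the input

eager = EagerDense(4, 8, key)
print(eager.kernel.shape)  # (4, 8): available right after construction
```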

NNX is an attempt to keep the features that made Linen great while introducing some new principles:

- Regular Python semantics for Modules, including (within JIT boundaries) support for mutability and shared references.
- A simple API to interact directly with JAX, including the ability to easily implement custom lifted Modules and other purely functional tricks.

### NNX is Pythonic
The main feature of NNX Modules is that they adhere to Python semantics. This means that:

* fields are mutable, so you can perform in-place updates
* Module references can be shared between multiple Modules
* Module construction implies parameter initialization
* Module methods can be called directly
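Shared references and mutable fields are ordinary Python object behavior, which is what NNX builds on. As a minimal sketch with plain Python classes (`Embed` and `EncoderDecoder` are hypothetical stand-ins, not NNX classes), two fields can point at the same sub-object, and an update made through one path is visible through the other:

```python
import jax.numpy as jnp


class Embed:
  """Hypothetical stand-in for a Module that owns an embedding table."""

  def __init__(self, vocab, dim):
    self.table = jnp.zeros((vocab, dim))


class EncoderDecoder:
  """Hypothetical model whose encoder and decoder share one Embed instance."""

  def __init__(self, embed):
    self.encoder_embed = embed
    self.decoder_embed = embed  # same object, not a copy


model = EncoderDecoder(Embed(10, 4))
# Update a field of the shared Embed object through the encoder path.
model.encoder_embed.table = model.encoder_embed.table + 1.0

print(model.decoder_embed is model.encoder_embed)  # True
print(float(model.decoder_embed.table[0, 0]))  # 1.0: visible through both paths
```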

```python
from flax.experimental import nnx
import jax
from jax import random, numpy as jnp

class Count(nnx.Variable): pass

class CounterLinear(nnx.Module):
  def __init__(self, din, dout, *, rngs): # explicit RNG threading
    self.linear = nnx.Linear(din, dout, rngs=rngs)
    self.count = Count(jnp.zeros((), jnp.int32)) # typed Variable collections

  def __call__(self, x):
    self.count += 1  # in-place stateful updates
    return self.linear(x)


model = CounterLinear(4, 4, rngs=nnx.Rngs(0))  # no special `init` method
y = model(jnp.ones((2, 4)))  # call methods directly

print(f'{model = }')
```



```
model = CounterLinear(
  linear=Linear(
    in_features=4,
    out_features=4,
    use_bias=True,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x7f5d3c57baf0>,
    bias_init=<function zeros at 0x7f5ddf0e4ca0>,
    dot_general=<function dot_general at 0x7f5ddf79d4c0>
  )
)
```


Because NNX Modules contain their own state, they are very easy to inspect:

```python
print(f'{model.count = }')
print(f'{model.linear.kernel = }')
```

```
model.count = Array(1, dtype=int32)
model.linear.kernel = Array([[ 0.4541089 , -0.5264876 , -0.36505195, -0.57566494],
       [ 0.38802508,  0.5655534 ,  0.4870657 ,  0.2267774 ],
       [-0.9015767 ,  0.24465278, -0.5844707 ,  0.18421966],
       [-0.06992685, -0.64693886,  0.20232596,  1.1200062 ]],      dtype=float32)
```


#### Intuitive Surgery

In NNX, surgery can be done at the Module level by simply updating or replacing existing fields.

```python
def load_pretrained():
  return nnx.Linear(4, 4, rngs=nnx.Rngs(42))  # pretend this is pretrained

model.linear = load_pretrained()  # you can replace modules

y = model(jnp.ones((2, 4)))
```

The benefit is not only that this is easier than manipulating nested dictionary structures, but also that you can replace a field with a completely different Module type, or even change the architecture (e.g. share two Modules that were not shared before).

```python
from functools import partial

rngs = nnx.Rngs(0)
model = nnx.Sequence(
  [
    nnx.Conv(1, 16, [3, 3], padding='SAME', rngs=rngs),
    partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)),
    nnx.Conv(16, 32, [3, 3], padding='SAME', rngs=rngs),
    partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)),
    lambda x: x.reshape((x.shape[0], -1)),  # flatten
    nnx.Linear(32 * 7 * 7, 10, rngs=rngs),
  ]
)

y = model(jnp.ones((2, 28, 28, 1)))

# swap every Conv for a Linear with the same feature sizes
for i, layer in enumerate(model):
  if isinstance(layer, nnx.Conv):
    model[i] = nnx.Linear(layer.in_features, layer.out_features, rngs=rngs)

y = model(jnp.ones((2, 28, 28, 1)))
```

Note that replacing `Conv` with `Linear` here is a silly example; in practice you would do things like replacing a layer with its quantized version or swapping it for an optimized implementation.
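As a sketch of the quantization case, the snippet below does the same kind of surgery on plain Python objects instead of NNX Modules. `Dense` and `Int8Dense` are hypothetical classes, and the weight-only int8 scheme (symmetric, one scale per output column) is just one simple choice assumed for illustration, not an NNX feature. Because each layer owns its parameters, the quantized layer can be built directly from the float one and dropped into the same slot.

```python
import jax
import jax.numpy as jnp


class Dense:
  """Hypothetical plain-Python layer that owns its parameters."""

  def __init__(self, din, dout, key):
    self.kernel = jax.random.normal(key, (din, dout))
    self.bias = jnp.zeros((dout,))

  def __call__(self, x):
    return x @ self.kernel + self.bias


class Int8Dense:
  """Hypothetical weight-only int8 replacement for Dense: symmetric
  quantization with one scale per output column."""

  def __init__(self, dense):
    scale = jnp.max(jnp.abs(dense.kernel), axis=0) / 127.0
    self.scale = jnp.where(scale == 0, 1.0, scale)  # avoid division by zero
    self.qkernel = jnp.round(dense.kernel / self.scale).astype(jnp.int8)
    self.bias = dense.bias

  def __call__(self, x):
    # dequantize on the fly, then apply the same affine map as Dense
    return x @ (self.qkernel.astype(x.dtype) * self.scale) + self.bias


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
layers = [Dense(4, 8, k1), jnp.tanh, Dense(8, 2, k2)]
original = list(layers)  # keep references to the float layers

# Surgery: swap every Dense for its quantized version, in place.
for i, layer in enumerate(layers):
  if isinstance(layer, Dense):
    layers[i] = Int8Dense(layer)

x = jnp.ones((3, 4))
for layer in layers:
  x = layer(x)
print(x.shape)  # (3, 2)
```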

### Interacting with JAX is easy

While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations. NNX has two very simple APIs to interact with JAX: `split` and `merge`.

The `Module.split` method converts a Module into a `State`, a dict-like object that contains the Module's dynamic state, and a `ModuleDef`, an object that contains its static structure.

```python
model = CounterLinear(4, 4, rngs=nnx.Rngs(0))

state, static = model.split()

print(f'{state = }')
```

```
state = State({
  'count': Array(0, dtype=int32),
  'linear/bias': Array([0., 0., 0., 0.], dtype=float32),
  'linear/kernel': Array([[ 0.4541089 , -0.5264876 , -0.36505195, -0.57566494],
         [ 0.38802508,  0.5655534 ,  0.4870657 ,  0.2267774 ],
         [-0.9015767 ,  0.24465278, -0.5844707 ,  0.18421966],
         [-0.06992685, -0.64693886,  0.20232596,  1.1200062 ]],      dtype=float32)
})
```


The `ModuleDef.merge` method takes a `ModuleDef` and one or more `State` objects and merges them back into a `Module` object.

Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. Here is a simple jitted `forward` function as an example:

```python
@jax.jit
def forward(state: nnx.State, x: jax.Array):
  model = static.merge(state)
  y = model(x)
  state, _ = model.split()
  return y, state

x = jnp.ones((2, 4))
y, state = forward(state, x)

print(f'{y.shape = }')
print(f'{state["count"] = }')
```

```
y.shape = (2, 4)
state["count"] = Array(1, dtype=int32)
```
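To show what `split` and `merge` have to accomplish, here is a toy sketch that does not use NNX at all. `ToyModule`, `ToyLinear`, `ToyCounterLinear`, `split` and `merge` are hypothetical helpers written with plain Python objects, `vars()` and `jax.jit`. The idea mirrors the pattern above: the arrays cross the `jit` boundary as a flat dict pytree keyed by paths such as `'linear/kernel'`, while the class structure stays outside as static data.

```python
import jax
import jax.numpy as jnp


class ToyModule:
  """Hypothetical base class standing in for nnx.Module (plain Python object)."""


class ToyLinear(ToyModule):
  def __init__(self, din, dout, key):
    self.kernel = jax.random.normal(key, (din, dout))
    self.bias = jnp.zeros((dout,))

  def __call__(self, x):
    return x @ self.kernel + self.bias


class ToyCounterLinear(ToyModule):
  def __init__(self, din, dout, key):
    self.linear = ToyLinear(din, dout, key)
    self.count = jnp.zeros((), jnp.int32)

  def __call__(self, x):
    self.count += 1  # ordinary attribute update
    return self.linear(x)


def split(module, prefix=''):
  """Returns (state, static): a flat {'path/to/field': array} dict and a
  nested (class, fields) description of the object structure."""
  state, fields = {}, {}
  for name, value in vars(module).items():
    path = prefix + name
    if isinstance(value, ToyModule):
      sub_state, fields[name] = split(value, path + '/')
      state.update(sub_state)
    else:
      state[path] = value
      fields[name] = None  # marks an array leaf
  return state, (type(module), fields)


def merge(static, state, prefix=''):
  """Rebuilds a fresh object from `static` and `state` without calling
  __init__, so no parameters are re-initialized."""
  cls, fields = static
  module = object.__new__(cls)
  for name, sub_static in fields.items():
    path = prefix + name
    if sub_static is None:
      setattr(module, name, state[path])
    else:
      setattr(module, name, merge(sub_static, state, path + '/'))
  return module


model = ToyCounterLinear(4, 4, jax.random.PRNGKey(0))
state, static = split(model)


@jax.jit
def forward(state, x):
  model = merge(static, state)  # rebuild a stateful object inside jit
  y = model(x)                  # mutate it freely
  state, _ = split(model)       # export the updated arrays as a pytree
  return y, state


y, state = forward(state, jnp.ones((2, 4)))
print(sorted(state))  # ['count', 'linear/bias', 'linear/kernel']
print(y.shape)  # (2, 4)
print(int(state['count']))  # 1
```

Unlike NNX, this toy `split` does not track shared references: a sub-object reachable from two fields would simply be exported twice.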


#### Custom lifted Modules

By using the same mechanism inside Module methods, you can implement lifted Modules, that is, Modules that use a JAX transformation to get a distinct behavior. One of Linen's current pain points is that it is not easy to interact with JAX transformations that the framework does not already support. NNX makes this easy enough that it is realistic to implement custom lifted Modules for specific use cases.

As an example, here we will create a `LinearEnsemble` Module that uses `jax.vmap` in both `__init__` and `__call__` to vectorize the computation over multiple `CounterLinear` models (defined above). The example is a bit longer, but notice how each method is conceptually very simple:

```python
class LinearEnsemble(nnx.Module):
  def __init__(self, din, dout, *, num_models, rngs: nnx.Rngs):
    # get raw rng seeds
    keys = rngs.fork(num_models) # split all keys into `num_models`

    # define pure init fn and vmap
    def vmap_init(keys):
      return CounterLinear(din, dout, rngs=nnx.Rngs(keys)).split(
        nnx.Param, Count
      )

    params, counts, static = jax.vmap(
      vmap_init, in_axes=(0,), out_axes=(0, None, None)
    )(keys)
    # update wrapped submodule reference
    self.models = static.merge(params, counts)

  def __call__(self, x):
    # get module values, define pure fn
    params, counts, static = self.models.split(nnx.Param, Count)

    def vmap_apply(x, params, counts, static):
      model = static.merge(params, counts)
      y = model(x)
      params, counts, static = model.split(nnx.Param, Count)
      return y, params, counts, static

    # vmap and call
    y, params, counts, static = jax.vmap(
      vmap_apply, in_axes=(None, 0, None, None), out_axes=(0, 0, None, None)
    )(x, params, counts, static)
    # update wrapped module
    self.models.update(params, counts, static) # use `update` to integrate the new state
    return y

x = jnp.ones((4,))
ensemble = LinearEnsemble(4, 4, num_models=8, rngs=nnx.Rngs(0))

# forward pass
y = ensemble(x)

print(f'{y.shape = }')
print(f'{ensemble.models.count = }')
print(f'state = {jax.tree_util.tree_map(jnp.shape, ensemble.get_state())}')
```

```
y.shape = (8, 4)
ensemble.models.count = Array(1, dtype=int32)
state = State({
  'models/count': (),
  'models/linear/bias': (8, 4),
  'models/linear/kernel': (8, 4, 4)
})
```
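Stripped of the Module machinery, the functional core of `LinearEnsemble` is two `jax.vmap` calls: one over a batch of PRNG keys to stack the parameters of every member along a leading axis, and one over that axis at call time while broadcasting the shared input. The sketch below shows only that core in plain JAX; `init_linear` and `apply_linear` are hypothetical helpers, and the unbatched `count` that the NNX version keeps via `out_axes=None` is omitted.

```python
import jax
import jax.numpy as jnp


def init_linear(key, din, dout):
  """Hypothetical pure init fn returning the params of one ensemble member."""
  return {'kernel': jax.random.normal(key, (din, dout)), 'bias': jnp.zeros((dout,))}


def apply_linear(params, x):
  return x @ params['kernel'] + params['bias']


num_models, din, dout = 8, 4, 4
keys = jax.random.split(jax.random.PRNGKey(0), num_models)  # one key per member

# vmap over the keys: every leaf gains a leading ensemble axis of size num_models
params = jax.vmap(lambda k: init_linear(k, din, dout))(keys)

x = jnp.ones((din,))
# map over the params' leading axis, share the same x with every member
y = jax.vmap(apply_linear, in_axes=(0, None))(params, x)

print(params['kernel'].shape)  # (8, 4, 4)
print(y.shape)  # (8, 4)
```

Row `i` of `y` is what member `i` alone would produce, which is why the stacked parameter shapes in the state printed above carry a leading axis of size 8.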


### Why are Modules not Pytrees?

Finally, one of the most common questions we get is why NNX Modules are not Pytrees. Given the existence of Pytree-based NN frameworks like Equinox, Treex, and [PytreeClass](https://github.com/ASEM000/PyTreeClass), it is a fair question.

The answer is that Pytrees assume value semantics (referential transparency) while Modules assume reference semantics, so it's not a good idea for Modules to be Pytrees. As an example, let's look at what would happen if we allowed this very simple program to be valid:

```python
@jax.jit
def f(m1: nnx.Module, m2: nnx.Module):
  return m1, m2
```

Here we are just creating a jitted function `f` that takes in two Modules `(m1, m2)` and returns them as is. What could go wrong?

There are two main problems with this:
* Shared references are not maintained: if `m1.shared is m2.shared` outside `f`, this will NOT be true inside `f` or at the output of `f`.
* Even if you accepted this and added code to compensate for it, `f` would now behave differently depending on whether it is `jit`ted or not. This is an undesired asymmetry, and `jit` would no longer be a no-op.
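Both problems can be reproduced with any object registered as a pytree. In this sketch `Node` is a hypothetical toy class registered through `jax.tree_util.register_pytree_node_class`, not an NNX type: flattening visits the shared child once per path, and `jit` rebuilds its outputs through `tree_unflatten`, so identity survives a plain Python call but not a jitted one.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Node:
  """Hypothetical toy 'module' registered as a pytree (not an NNX class)."""

  def __init__(self, value, child=None):
    self.value = value
    self.child = child

  def tree_flatten(self):
    return (self.value, self.child), None  # (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)  # always builds a NEW object


shared = Node(jnp.ones(3))
m1 = Node(jnp.zeros(()), child=shared)
m2 = Node(jnp.zeros(()), child=shared)


def f(m1, m2):
  return m1, m2


# Flattening reaches `shared` through both m1 and m2, so its array appears twice.
print(len(tree_util.tree_leaves((m1, m2))))  # 4

a1, a2 = f(m1, m2)           # plain Python call: identity is preserved
b1, b2 = jax.jit(f)(m1, m2)  # jitted call: outputs are rebuilt from leaves
print(a1.child is a2.child)  # True
print(b1.child is b2.child)  # False
```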