In [1]:
from flax import nnx
import jax
import jax.numpy as jnp

We're going to review what's going on with NNX randomness under the hood.

NNX uses an `nnx.Rngs` object to handle all random number generation. This object holds "streams" of random number keys that can be used for initializing weights, randomly dropping connections in the model (i.e. dropout), etc.

You specify the seed for a stream with a keyword argument, e.g. `params=0`.

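`nnx.Rngs` contains "streams" of random numbers. What this means technically is that each stream behaves like an unbounded, procedurally generated sequence of unique keys: it stores a base key and a counter, and every request derives a fresh key and advances the counter.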

Passing `params` creates the stream used for parameter initialization, e.g. for model weights.

In [12]:
nnx.Rngs(params=0)

Rngs( # RngState: 2 (12 B)
  params=RngStream( # RngState: 2 (12 B)
    tag='params',
    key=RngKey( # 1 (8 B)
      value=Array((), dtype=key<fry>) overlaying:
      [0 0],
      tag="'params'"
    ),
    count=RngCount( # 1 (4 B)
      value=Array(0, dtype=uint32),
      tag="'params'"
    )
  )
)

Calling `.params()` returns a new key from the `params` stream every time it is called.

In [19]:
nnx.Rngs(params=0).params()

Array((), dtype=key<fry>) overlaying:
[1797259609 2579123966]

This differs markedly from plain JAX random number generation, where you have to explicitly split new keys off the original key yourself so that different parts of the model receive different random values.
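
For comparison, a minimal sketch of the manual bookkeeping in plain JAX (array shapes and variable names are arbitrary choices for illustration). Reusing one key reproduces the same numbers, so the key has to be split by hand before each independent draw.

```python
import jax

key = jax.random.key(0)

# Reusing the same key yields identical samples.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
reused_equal = bool((a == b).all())

# Splitting produces independent subkeys; we keep `key` for later splits.
key, k1, k2 = jax.random.split(key, 3)
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))
split_equal = bool((c == d).all())

print(reused_equal, split_equal)
```

`nnx.Rngs` takes over this bookkeeping: each call on a stream hands out the next key, so there is no key to thread through the code manually.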

From Gemini:

In [21]:
class MLP(nnx.Module):
    def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
        # 1. Pass 'rngs' to the first layer.
        # It calls rngs.params(), gets Key #1, and advances the stream.
        self.layer1 = nnx.Linear(din, dhidden, rngs=rngs)

        # 2. Pass the SAME 'rngs' object to the second layer.
        # It calls rngs.params(), gets Key #2 (automatically unique).
        self.layer2 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        # Pass x through layer 1, then layer 2
        x = self.layer1(x)
        x = jax.nn.relu(x) # Standard activation function
        x = self.layer2(x)
        return x

# --- USAGE ---

# We only need ONE seed for the whole network
rngs = nnx.Rngs(params=0)

# We initialize the whole MLP at once
model = MLP(din=2, 
            dhidden=10, 
            dout=5, 
            rngs=rngs)

# Run it
y = model(jnp.ones((1, 2)))
print(y)
nnx.display(model)

[[-0.15895452 -1.1049775   1.2165147   0.62828314  0.571608  ]]
