In [1]:
import numpy as np
import nnx
import jax
import jax.numpy as jnp

class Linear(nnx.Module):
    """Affine layer ``x @ w + b`` that holds several kinds of attributes.

    In the ``partition()`` output of this notebook, ``din``/``dout`` show up in
    the ModuleDef's ``static_fields``, ``w``/``b`` show up in the State as
    ``'params'`` nodes, and the plain JAX/NumPy arrays show up in the State as
    raw array values.
    """

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
        weight_shape = (din, dout)
        # Static attributes: plain Python ints.
        self.din = din
        self.dout = dout
        # Variables: uniform-random weight and zero bias, both wrapped by nnx.param.
        params_key = ctx.make_rng("params")
        self.w = nnx.param(jax.random.uniform(params_key, weight_shape))
        self.b = nnx.param(jnp.zeros((dout,)))
        # Other state: arrays assigned directly, without an nnx wrapper.
        self.jax_array = jnp.array(1)
        self.numpy_array = np.array(1)

    def __call__(self, x):
        """Return ``x @ w + b``."""
        projected = x @ self.w
        return projected + self.b
    
# Build the layer with a fixed RNG seed and apply it to a batch of ones.
linear = Linear(din=2, dout=2, ctx=nnx.context(0))
y = linear(jnp.ones((2, 2)))

print(linear, y, sep="\n")

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Linear(
  din=2,
  dout=2
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [2]:
# Split the module into a State (variables and array values) and a ModuleDef
# (type, submodules, static fields).
state, moduledef = linear.partition()
print(state, moduledef, sep="\n")

State({
  'b': Node(
    value=Array([0., 0.], dtype=float32),
    collection='params',
    sharding=None
  ),
  'jax_array': Array(1, dtype=int32, weak_type=True),
  'numpy_array': array(1),
  'w': Node(
    value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32),
    collection='params',
    sharding=None
  )
})
ModuleDef(
  type=Linear,
  index=0,
  submodules=(),
  static_fields=(('din', 2), ('dout', 2))
)


In [3]:
class Linear(nnx.Module):
    """Affine layer whose ``submodule`` attribute points back at itself.

    Used to show how ``partition``/``merge`` handle a reference cycle: the
    ModuleDef printed below lists ``('submodule', 0)``, i.e. the module's own index.
    """

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
        self.din, self.dout = din, dout
        params_key = ctx.make_rng("params")
        self.w = nnx.param(jax.random.uniform(params_key, (din, dout)))
        self.b = nnx.param(jnp.zeros((dout,)))
        # Deliberate self-reference: the module is its own submodule.
        self.submodule = self

    def __call__(self, x):
        """Return ``x @ w + b``, reading the parameters through ``self.submodule``."""
        inner = self.submodule
        return x @ inner.w + inner.b
    
# Same construction as before, now with the self-referencing layer.
linear = Linear(din=2, dout=2, ctx=nnx.context(0))
y = linear(jnp.ones((2, 2)))

print(linear, y, sep="\n")

Linear(
  din=2,
  dout=2,
  submodule=Linear(...)
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [4]:
# Partition the self-referencing module; the cycle is recorded in the ModuleDef.
state, moduledef = linear.partition()
print(state, moduledef, sep="\n")

State({
  'b': Node(
    value=Array([0., 0.], dtype=float32),
    collection='params',
    sharding=None
  ),
  'w': Node(
    value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32),
    collection='params',
    sharding=None
  )
})
ModuleDef(
  type=Linear,
  index=0,
  submodules=(
    ('submodule', 0)
  ),
  static_fields=(('din', 2), ('dout', 2))
)


In [6]:
# Rebuild a module from the ModuleDef + State produced by `partition`.
linear2 = moduledef.merge(state)

# The self-reference must survive the round trip. Assert it so that Restart & Run All
# fails loudly on a regression, and still display the result for the reader.
self_ref_preserved = linear2.submodule is linear2
assert self_ref_preserved, "merge() did not preserve the submodule self-reference"
self_ref_preserved

True

In [7]:

class Linear(nnx.Module):
    """Affine layer that also records its output in the ``"intermediate"`` collection.

    ``din``/``dout`` and the ``nnx.param`` variables ``w``/``b`` are set in
    ``__init__``; ``y`` is only assigned once the module has been called.
    """

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
        # Static attributes: plain Python ints.
        self.din, self.dout = din, dout
        # Variables: uniform-random weight and zero bias.
        params_key = ctx.make_rng("params")
        self.w = nnx.param(jax.random.uniform(params_key, (din, dout)))
        self.b = nnx.param(jnp.zeros((dout,)))

    def __call__(self, x):
        """Return ``x @ w + b``, also storing it as ``self.y`` in collection ``"intermediate"``."""
        out = x @ self.w + self.b
        self.y = nnx.variable("intermediate", out)
        return out
    
# Calling the layer is what creates the "intermediate" variable `y`.
linear = Linear(din=2, dout=2, ctx=nnx.context(0))
y = linear(jnp.ones((2, 2)))

print(linear, y, sep="\n")

AttributeError: module 'nnx' has no attribute 'var'

In [8]:
# Pop the "intermediate" collection first, then partition what remains.
# NOTE(review): the empty `State({})` shown below comes from a run in which the
# previous cell raised (its saved output is a stale `nnx.var` AttributeError), so
# no intermediate had been recorded. Re-run from a fresh kernel to refresh it.
intermediates = linear.pop_state("intermediate")
state, moduledef = linear.partition()

print(intermediates, state, sep="\n")

State({})
State({
  'b': Node(
    value=Array([0., 0.], dtype=float32),
    collection='params',
    sharding=None
  ),
  'w': Node(
    value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32),
    collection='params',
    sharding=None
  )
})
