In [1]:
import numpy as np
import nnx
import jax
import jax.numpy as jnp

class Linear(nnx.Module):
    """Affine layer computing ``x @ w + b``.

    Besides the two ``nnx.param`` variables, this version also keeps a plain
    JAX array and a plain NumPy array as attributes. This lets ``partition()``
    show how non-Variable arrays are handled: the recorded output of the next
    cell lists both of them inside ``State``, next to ``w`` and ``b``.

    Args:
        din: number of input features; stored as a static attribute.
        dout: number of output features; stored as a static attribute.
        ctx: ``nnx.Context`` that supplies the "params" RNG key used for ``w``.
    """

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
        # Static configuration (plain ints, reported under ModuleDef.static_fields).
        self.din, self.dout = din, dout

        # Variables: a kernel sampled with jax.random.uniform and a zero bias.
        params_key = ctx.make_rng("params")
        self.w = nnx.param(jax.random.uniform(params_key, (din, dout)))
        self.b = nnx.param(jnp.zeros((dout,)))

        # Extra non-Variable array state, kept only for the partition() demo.
        self.jax_array = jnp.array(1)
        self.numpy_array = np.array(1)

    def __call__(self, x):
        """Apply the layer and return ``x @ w + b``."""
        projected = x @ self.w
        return projected + self.b
    
# Every cell that builds a Linear with nnx.context(0) records the same w,
# so the argument appears to act as the RNG seed.
ctx = nnx.context(0)
linear = Linear(din=2, dout=2, ctx=ctx)

# With an all-ones (2, 2) input, each output row equals the column sums of w plus b.
y = linear(jnp.ones((2, 2)))

print(linear, y, sep="\n")

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Linear(
  din=2,
  dout=2
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [2]:
# Split the module into its array state and its static structure description.
state, moduledef = linear.partition()

print(state, moduledef, sep="\n")

State({
  ('b',): Variable(
    collection=params,
    value=Array([0., 0.], dtype=float32)
  ),
  ('jax_array',): Array(1, dtype=int32, weak_type=True),
  ('numpy_array',): array(1),
  ('w',): Variable(
    collection=params,
    value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32)
  )
})
ModuleDef(
  type=Linear,
  index=0,
  submodules=(),
  static_fields=(('din', 2), ('dout', 2))
)


In [3]:
class Linear(nnx.Module):
    """Affine layer whose ``submodule`` attribute refers back to itself.

    This redefinition of ``Linear`` introduces a reference cycle
    (``self.submodule is self``). ``__call__`` reads ``w`` and ``b``
    through that cycle, so the demo exercises how ``partition``/``merge``
    cope with a module that appears twice in its own graph.

    Args:
        din: number of input features; stored as a static attribute.
        dout: number of output features; stored as a static attribute.
        ctx: ``nnx.Context`` that supplies the "params" RNG key used for ``w``.
    """

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
        self.din, self.dout = din, dout
        params_key = ctx.make_rng("params")
        self.w = nnx.param(jax.random.uniform(params_key, (din, dout)))
        self.b = nnx.param(jnp.zeros((dout,)))
        # Deliberate self-reference: makes the module graph cyclic.
        self.submodule = self

    def __call__(self, x):
        """Return ``x @ w + b``, fetching both variables via ``self.submodule``."""
        weights = self.submodule.w
        bias = self.submodule.b
        return x @ weights + bias
    
# Same context argument as before, so w matches the earlier cell's output.
ctx = nnx.context(0)
linear = Linear(din=2, dout=2, ctx=ctx)

y = linear(jnp.ones((2, 2)))

print(linear, y, sep="\n")

Linear(
  din=2,
  dout=2,
  submodule=Linear(...)
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [4]:
# Partition the cyclic module: the self-reference should show up in the
# ModuleDef's submodules rather than as duplicated state.
state, moduledef = linear.partition()

print(state, moduledef, sep="\n")

State({
  ('b',): Variable(
    collection=params,
    value=Array([0., 0.], dtype=float32)
  ),
  ('w',): Variable(
    collection=params,
    value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32)
  )
})
ModuleDef(
  type=Linear,
  index=0,
  submodules=(
    ('submodule', 0)
  ),
  static_fields=(('din', 2), ('dout', 2))
)


In [5]:
# Rebuild a module from the (moduledef, state) pair produced by partition().
linear2 = moduledef.merge(state)

# Check whether the self-reference survived the round trip.
is_self_reference = linear2.submodule is linear2
print(is_self_reference)

True


In [6]:

class Linear(nnx.Module):
    """Affine layer that records its latest output as an "intermediate" variable.

    This redefinition of ``Linear`` drops the extra array attributes. Its
    ``__call__`` stores the activation on ``self.y`` wrapped in
    ``nnx.var("intermediate", ...)``, so it can later be extracted with
    ``pop_state("intermediate")``.

    Args:
        din: number of input features; stored as a static attribute.
        dout: number of output features; stored as a static attribute.
        ctx: ``nnx.Context`` that supplies the "params" RNG key used for ``w``.
    """

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
        # static attributes
        self.din, self.dout = din, dout
        # variables
        params_key = ctx.make_rng("params")
        self.w = nnx.param(jax.random.uniform(params_key, (din, dout)))
        self.b = nnx.param(jnp.zeros((dout,)))

    def __call__(self, x):
        """Compute ``x @ w + b``, record it on ``self.y``, and return it.

        Side effect: each call reassigns ``self.y`` with a fresh
        "intermediate" variable holding this call's output.
        """
        out = x @ self.w + self.b
        self.y = nnx.var("intermediate", out)
        return out
    
ctx = nnx.context(0)
linear = Linear(din=2, dout=2, ctx=ctx)

# Calling the layer also populates linear.y in the "intermediate" collection.
y = linear(jnp.ones((2, 2)))

print(linear, y, sep="\n")

Linear(
  din=2,
  dout=2
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [7]:
# Remove the "intermediate" variables first; the partition below then
# contains only the params (w and b), as the recorded output shows.
intermediates = linear.pop_state("intermediate")
state, moduledef = linear.partition()

print(intermediates, state, sep="\n")

State({
  ('y',): Variable(
    collection=intermediate,
    value=Array([[0.63114893, 1.2928092 ],
           [0.63114893, 1.2928092 ]], dtype=float32)
  )
})
State({
  ('b',): Variable(
    collection=params,
    value=Array([0., 0.], dtype=float32)
  ),
  ('w',): Variable(
    collection=params,
    value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32)
  )
})


In [8]:
class Foo(nnx.Module):
    """Container module demonstrating ``nnx.Map`` and ``nnx.Sequence``.

    Each container mixes three kinds of entries: an ``nnx.param`` variable, a
    ``Linear`` submodule, and a plain string. ``partition()`` can then show
    how each kind is split between ``State`` and ``ModuleDef``.

    Args:
        ctx: ``nnx.Context`` passed to both ``Linear`` submodules. The Map's
            ``Linear`` is constructed before the Sequence's.
    """

    def __init__(self, ctx: nnx.Context):
        # Mapping container: its keys appear as path parts, e.g. ('dict', 'a').
        map_entries = {
            "a": nnx.param(jnp.ones((2, 2))),
            "b": Linear(2, 2, ctx=ctx),
            "c": "aaa",
        }
        self.dict = nnx.Map(map_entries)

        # Sequence container: its positions appear as path parts, e.g. ('seq', '0').
        sequence_items = [
            nnx.param(jnp.ones((2, 2))),
            Linear(2, 2, ctx=ctx),
            "bbb",
        ]
        self.seq = nnx.Sequence(sequence_items)

# NOTE(review): reuses the ctx created for the previous Linear cell, which has
# already handed out a "params" key; presumably why these w values differ
# from the earlier cells.
foo = Foo(ctx=ctx)
state, moduledef = foo.partition()

print(foo, state, moduledef, sep="\n")

Foo(
  dict=Map(
    b=Linear(
      din=2,
      dout=2
    ),
    c=aaa
  ),
  seq=Sequence(
    1=Linear(
      din=2,
      dout=2
    ),
    2=bbb
  )
)
State({
  ('dict', 'a'): Variable(
    collection=params,
    value=Array([[1., 1.],
           [1., 1.]], dtype=float32)
  ),
  ('dict', 'b', 'b'): Variable(
    collection=params,
    value=Array([0., 0.], dtype=float32)
  ),
  ('dict', 'b', 'w'): Variable(
    collection=params,
    value=Array([[0.58674836, 0.94791114],
           [0.71813095, 0.8924824 ]], dtype=float32)
  ),
  ('seq', '0'): Variable(
    collection=params,
    value=Array([[1., 1.],
           [1., 1.]], dtype=float32)
  ),
  ('seq', '1', 'b'): Variable(
    collection=params,
    value=Array([0., 0.], dtype=float32)
  ),
  ('seq', '1', 'w'): Variable(
    collection=params,
    value=Array([[0.35177028, 0.38650823],
           [0.9912766 , 0.1499139 ]], dtype=float32)
  )
})
ModuleDef(
  type=Foo,
  index=0,
  submodules=(
    ('dict', ModuleDef(
      type