In [1]:
import nnx
import jax
import jax.numpy as jnp

class Linear(nnx.Module):
    """Affine layer computing ``x @ w + b``.

    Args:
        din: number of input features (rows of ``w``).
        dout: number of output features (columns of ``w`` and length of ``b``).
        ctx: nnx context; one key is drawn from its ``"params"`` stream to
            initialise ``w``.
    """

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
        # Static sizes first, then the two parameters (same attribute order as before).
        self.din, self.dout = din, dout
        key = ctx.make_rng("params")
        # jax.random.uniform with default bounds samples from [0, 1).
        self.w = nnx.param(jax.random.uniform(key, (din, dout)))
        self.b = nnx.param(jnp.zeros((dout,)))

    def __call__(self, x):
        # Record the input under the "interm" collection so it can be popped later.
        self.interm = nnx.var("interm", x)
        projected = x @ self.w
        return projected + self.b
    
# Context seeded with PRNGKey(0). `Linear.__init__` draws its weight key from this
# context's "params" stream, so the weights shown below are reproducible.
ctx = nnx.Context(jax.random.PRNGKey(0))
linear = Linear(2, 2, ctx=ctx)

# Display the module: static fields (din, dout) plus the `w`/`b` param variables.
linear

Linear(
  din=2,
  dout=2,
  w=MutableVariable(
      collection='params',
      value=Array([[0.31696808, 0.55285215],
             [0.31418085, 0.7399571 ]], dtype=float32)
  ),
  b=MutableVariable(
      collection='params',
      value=Array([0., 0.], dtype=float32)
  )
)

In [2]:

# The forward pass stores its input on the module as an "interm" variable.
# `pop("interm")` takes that variable off the module and returns it as a State,
# so the later `split()` only sees the params.
y = linear(jnp.ones((2, 2)))
interm = linear.pop("interm")
y

Array([[0.6308594, 1.2929688],
       [0.6308594, 1.2929688]], dtype=float32)

In [3]:
# The popped intermediates: a State keyed by attribute path, holding the input
# that was recorded during the forward pass in the previous cell.
interm

State({
  ('interm',): Variable(
      collection='interm',
      value=Array([[1., 1.],
             [1., 1.]], dtype=float32),
      sharding=None
  )
})

In [4]:
# Split the module into `state` (its Variables, keyed by attribute path) and
# `moduledef` (type, submodule graph and static fields such as din/dout).
state, moduledef = linear.split()
# One call, same text as printing each object on its own line.
print(state, moduledef, sep="\n")

State({
  ('b',): Variable(
      collection='params',
      value=Array([0., 0.], dtype=float32),
      sharding=None
  ),
  ('w',): Variable(
      collection='params',
      value=Array([[0.31696808, 0.55285215],
             [0.31418085, 0.7399571 ]], dtype=float32),
      sharding=None
  )
})
ModuleDef(
  type=Linear,
  index=0,
  submodules=(),
  static_fields=(('din', 2), ('dout', 2))
)


In [5]:
class SelfReferenceLinear(nnx.Module):
    """Affine layer whose ``submodule`` attribute points back at itself.

    Used to show that ``split``/``merge`` cope with a module that references
    itself. It also holds a bare array attribute ``x`` that is not wrapped in
    ``nnx.param``.

    Args:
        din: number of input features.
        dout: number of output features.
        ctx: nnx context; one key is drawn from its ``"params"`` stream to
            initialise ``w``.
    """

    def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
        self.din, self.dout = din, dout
        # Plain jax array (not an nnx.param); it still shows up in the split state.
        self.x = jnp.array(1)
        key = ctx.make_rng("params")
        self.w = nnx.param(jax.random.uniform(key, (din, dout)))
        self.b = nnx.param(jnp.zeros((dout,)))
        # Deliberate reference cycle: the module is its own submodule.
        self.submodule = self

    def __call__(self, x):
        # Reach the parameters through the self-reference.
        inner = self.submodule
        return x @ inner.w + inner.b
    
# Fresh context with the same seed as cell 1, so `w` (and therefore `y`) matches
# the plain Linear example.
# NOTE(review): this rebinds `ctx`, `linear` and `y`. Re-running cells 2-4 after
# this cell would act on the SelfReferenceLinear, which records no "interm"
# variable. Consider a distinct name such as `self_ref_linear`.
ctx = nnx.Context(jax.random.PRNGKey(0))
linear = SelfReferenceLinear(2, 2, ctx=ctx)

y = linear(jnp.ones((2, 2)))
y

Array([[0.6308594, 1.2929688],
       [0.6308594, 1.2929688]], dtype=float32)

In [6]:
# Split the self-referencing module. The bare array `x` lands in `state`, and the
# self-reference is recorded in `moduledef.submodules` rather than duplicated.
state, moduledef = linear.split()
# One call, same text as printing each object on its own line.
print(state, moduledef, sep="\n")

State({
  ('b',): Variable(
      collection='params',
      value=Array([0., 0.], dtype=float32),
      sharding=None
  ),
  ('w',): Variable(
      collection='params',
      value=Array([[0.31696808, 0.55285215],
             [0.31418085, 0.7399571 ]], dtype=float32),
      sharding=None
  ),
  ('x',): Array(1, dtype=int32, weak_type=True)
})
ModuleDef(
  type=SelfReferenceLinear,
  index=0,
  submodules=(('submodule', 0),),
  static_fields=(('din', 2), ('dout', 2))
)


In [7]:
# Rebuild a module from the split pieces. The self-reference comes back as an
# identity (the check below evaluates to True), not as a separate copy.
linear2 = moduledef.merge(state)
linear2.submodule is linear2

True

In [8]:
class Foo(nnx.Module):
    """Module holding an ``nnx.Map`` and an ``nnx.Seq`` with mixed contents.

    Each container holds a raw parameter, an ``nnx.Linear`` submodule and a
    plain string, to show how ``split`` treats each kind of entry.

    Args:
        ctx: nnx context passed to both ``nnx.Linear`` submodules.
    """

    def __init__(self, ctx: nnx.Context):
        # Mapping entries: built (and its Linear initialised) before the sequence.
        mapping_entries = {
            "a": nnx.param(jnp.ones((2, 2))),
            "b": nnx.Linear(2, 2, ctx=ctx),
            "c": "aaa",
        }
        self.dict = nnx.Map(mapping_entries)

        # Sequence entries: indexed by position ("0", "1", ...) in the split state.
        sequence_entries = [
            nnx.param(jnp.ones((2, 2))),
            nnx.Linear(2, 2, ctx=ctx),
            "bbb",
        ]
        self.seq = nnx.Seq(sequence_entries)

# NOTE(review): `ctx` is the context rebound in cell 5, so these weights depend on
# the RNG draws already made there.
foo = Foo(ctx=ctx)
# Split the nested containers. Paths such as ('dict', 'b', 'kernel') reflect the
# nesting; the plain strings do not appear in `state`.
state, moduledef = foo.split()
# One call, same text as printing each object on its own line.
print(state, moduledef, sep="\n")

State({
  ('dict', 'a'): Variable(
      collection='params',
      value=Array([[1., 1.],
             [1., 1.]], dtype=float32),
      sharding=None
  ),
  ('dict', 'b', 'bias'): Variable(
      collection='params',
      value=Array([0., 0.], dtype=float32),
      sharding=None
  ),
  ('dict', 'b', 'kernel'): Variable(
      collection='params',
      value=Array([[0.16806164, 1.1717631 ],
             [0.44064113, 0.92326844]], dtype=float32),
      sharding=None
  ),
  ('seq', '0'): Variable(
      collection='params',
      value=Array([[1., 1.],
             [1., 1.]], dtype=float32),
      sharding=None
  ),
  ('seq', '1', 'bias'): Variable(
      collection='params',
      value=Array([0., 0.], dtype=float32),
      sharding=None
  ),
  ('seq', '1', 'kernel'): Variable(
      collection='params',
      value=Array([[-0.35356975,  0.06913038],
             [-1.2359275 ,  0.2171132 ]], dtype=float32),
      sharding=None
  )
})
ModuleDef(
  type=Foo,
  index=0,
  submodules=(('d