In [8]:
from flax.experimental import nnx
import jax
import jax.numpy as jnp


class Linear(nnx.Module):
  """Affine layer computing ``x @ w + b``.

  Attributes:
    din: number of input features (plain int, kept as a static field).
    dout: number of output features (plain int, kept as a static field).
    w: ``nnx.Param`` of shape ``(din, dout)`` filled by ``jax.random.uniform``.
    b: ``nnx.Param`` of shape ``(dout,)`` filled with zeros.
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    """Creates the layer, taking the key for ``w`` from ``rngs.params()``."""
    # Plain Python values are stored as static attributes.
    self.din = din
    self.dout = dout
    # Array-valued attributes are wrapped in nnx.Param so they become variables.
    kernel_key = rngs.params()
    kernel_shape = (din, dout)
    self.w = nnx.Param(jax.random.uniform(kernel_key, shape=kernel_shape))
    self.b = nnx.Param(jnp.zeros(shape=(dout,)))

  def __call__(self, x):
    """Applies the affine map to ``x`` and returns the result."""
    projected = x @ self.w
    return projected + self.b


# Build a 2 -> 2 layer; the Rngs object is constructed from the value 0.
linear = Linear(din=2, dout=2, rngs=nnx.Rngs(0))

# Run the layer on a 2x2 input of ones.
y = linear(jnp.ones(shape=(2, 2)))

# Show the module repr followed by the output array, one per line.
print(linear, y, sep='\n')

Linear(
  din=2,
  dout=2
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [9]:
# Break the module apart: `state` holds the variable values, `graphdef`
# holds the static description of the module.
state, graphdef = linear.split()

print(state, graphdef, sep='\n')

State({
  'b': Param(
    value=Array([0., 0.], dtype=float32)
  ),
  'w': Param(
    value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32)
  )
})
GraphDef(
  type=Linear,
  index=0,
  static_fields=(('din', 2), ('dout', 2)),
  variables=(('b', Param(
      value=Empty
    )), ('w', Param(
      value=Empty
    ))),
  submodules=()
)


In [10]:
class Linear(nnx.Module):
  """Affine layer that also stores a reference to itself as ``submodule``.

  The self-reference makes the module graph cyclic; ``__call__`` reads the
  parameters through that reference rather than directly from ``self``.
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    """Creates the parameters and then points ``submodule`` back at ``self``."""
    self.din = din
    self.dout = dout
    kernel_key = rngs.params()
    kernel_shape = (din, dout)
    self.w = nnx.Param(jax.random.uniform(kernel_key, shape=kernel_shape))
    self.b = nnx.Param(jnp.zeros(shape=(dout,)))
    # The module is its own child: this introduces a cycle on purpose.
    self.submodule = self

  def __call__(self, x):
    """Computes ``x @ w + b`` using the parameters reached via ``submodule``."""
    inner = self.submodule
    projected = x @ inner.w
    return projected + inner.b


# Instantiate the self-referential layer with the same arguments as before.
linear = Linear(din=2, dout=2, rngs=nnx.Rngs(0))

# Run the layer on a 2x2 input of ones.
y = linear(jnp.ones(shape=(2, 2)))

# Show the module repr followed by the output array, one per line.
print(linear, y, sep='\n')

Linear(
  din=2,
  dout=2,
  submodule=Linear(...)
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [11]:
# Split the self-referential module into variable values and static structure.
state, graphdef = linear.split()

print(state, graphdef, sep='\n')

State({
  'b': Param(
    value=Array([0., 0.], dtype=float32)
  ),
  'w': Param(
    value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32)
  )
})
GraphDef(
  type=Linear,
  index=0,
  static_fields=(('din', 2), ('dout', 2)),
  variables=(('b', Param(
      value=Empty
    )), ('w', Param(
      value=Empty
    ))),
  submodules=(
    ('submodule', 0)
  )
)


In [12]:
# Reassemble a module from the static structure and the variable values.
linear2 = graphdef.merge(state)

# Check that the rebuilt module's `submodule` is the rebuilt module itself.
linear2 is linear2.submodule

True

In [13]:
class Linear(nnx.Module):
  """Affine layer that records its latest output as an ``nnx.Intermediate``.

  Each call stores the computed value on ``self.y`` in addition to
  returning it.
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    """Creates the layer, taking the key for ``w`` from ``rngs.params()``."""
    # Plain Python values are stored as static attributes.
    self.din = din
    self.dout = dout
    # Array-valued attributes are wrapped in nnx.Param so they become variables.
    kernel_key = rngs.params()
    kernel_shape = (din, dout)
    self.w = nnx.Param(jax.random.uniform(kernel_key, shape=kernel_shape))
    self.b = nnx.Param(jnp.zeros(shape=(dout,)))

  def __call__(self, x):
    """Computes ``x @ w + b``, stores it on ``self.y``, and returns it."""
    projected = x @ self.w
    out = projected + self.b
    # Keep the result on the module so it can be extracted after the call.
    self.y = nnx.Intermediate(out)
    return out


# Instantiate the Intermediate-recording layer with the same arguments as before.
linear = Linear(din=2, dout=2, rngs=nnx.Rngs(0))

# Calling the layer also stores the result on `linear.y`.
y = linear(jnp.ones(shape=(2, 2)))

# Show the module repr followed by the output array, one per line.
print(linear, y, sep='\n')

Linear(
  din=2,
  dout=2
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [14]:
# Pull the Intermediate variables out first; the split that follows then
# yields a state containing only the Param entries.
intermediates = linear.pop(nnx.Intermediate)
state, graphdef = linear.split()

print(intermediates, state, sep='\n')

State({
  'y': Intermediate(
    value=Array([[0.63114893, 1.2928092 ],
           [0.63114893, 1.2928092 ]], dtype=float32)
  )
})
State({
  'b': Param(
    value=Array([0., 0.], dtype=float32)
  ),
  'w': Param(
    value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32)
  )
})
