In [1]:
from flax.experimental import nnx
import jax
import jax.numpy as jnp


class Linear(nnx.Module):
  """A minimal dense layer computing ``x @ w + b``.

  Args:
    din: size of the input (last) dimension contracted by the matmul.
    dout: size of the output dimension.
    rngs: RNG streams; the ``params`` stream seeds the weight init.
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    # Plain Python ints end up as static fields of the GraphDef.
    self.din, self.dout = din, dout
    # nnx.Param-wrapped arrays end up as entries of the State.
    weight_key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(weight_key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    weight, bias = self.w.value, self.b.value
    return x @ weight + bias


linear = Linear(2, 2, rngs=nnx.Rngs(0))
y = linear(jnp.ones((2, 2)))

# The repr lists only the static attributes (din, dout); Param values are not shown.
print(linear, y, sep='\n')

Linear(
  din=2,
  dout=2
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [2]:
# split() separates the module into its array State and a static GraphDef.
state, graphdef = linear.split()

print(state, graphdef, sep='\n')

State({
  'w': Param(
    raw_value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32)
  ),
  'b': Param(
    raw_value=Array([0., 0.], dtype=float32)
  )
})
GraphDef(
  type=Linear,
  index=0,
  attributes=('din', 'dout', 'w', 'b'),
  subgraphs={},
  static_fields={
    'din': 2,
    'dout': 2
  },
  variables={
    'w': VariableDef(
      type=Param,
      index=1,
      metadata={
        'get_value_hooks': (),
        'set_value_hooks': (),
        'create_value_hooks': (),
        'add_axis_hooks': (),
        'remove_axis_hooks': ()
      }
    ),
    'b': VariableDef(
      type=Param,
      index=2,
      metadata={
        'get_value_hooks': (),
        'set_value_hooks': (),
        'create_value_hooks': (),
        'add_axis_hooks': (),
        'remove_axis_hooks': ()
      }
    )
  },
  metadata=<class '__main__.Linear'>
)


In [3]:
class Nested(nnx.Module):
  """Wraps a single ``Linear`` to show that the State mirrors the module tree."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    inner = Linear(din, dout, rngs=rngs)
    self.linear = inner
    
module = Nested(2, 2, rngs=nnx.Rngs(0))

# Call the second result `graphdef`, as the other split() cells do.
# The child's Variables appear nested under its attribute name ('linear').
state, graphdef = module.split()
state

State({
  'linear': {
    'w': Param(
      raw_value=Array([[0.31696808, 0.55285215],
             [0.31418085, 0.7399571 ]], dtype=float32)
    ),
    'b': Param(
      raw_value=Array([0., 0.], dtype=float32)
    )
  }
})

In [4]:
class Linear(nnx.Module):
  """``Linear`` variant that holds a reference to itself in ``submodule``.

  Used to show that split()/merge() handle reference cycles. The cycle is
  recorded in the GraphDef (``subgraphs``) rather than duplicated in the State.
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.din = din
    self.dout = dout
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    # Self-reference: creates a cycle in the module graph.
    self.submodule = self

  def __call__(self, x):
    # Read the parameters through the cycle; `sub` is the very same object as self.
    sub = self.submodule
    return x @ sub.w.value + sub.b.value


linear = Linear(2, 2, rngs=nnx.Rngs(0))
# Same seed and init as the first Linear, so the forward pass gives the same values.
y = linear(jnp.ones((2, 2)))

# The repr marks the cycle as `submodule=Linear(...)` instead of recursing.
print(linear, y, sep='\n')

Linear(
  din=2,
  dout=2,
  submodule=Linear(...)
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [5]:
state, graphdef = linear.split()

# The State holds w and b once. The self-reference shows up only in the GraphDef,
# as `subgraphs={'submodule': 0}`, where 0 is the root module's own index.
print(state, graphdef, sep='\n')

State({
  'w': Param(
    raw_value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32)
  ),
  'b': Param(
    raw_value=Array([0., 0.], dtype=float32)
  )
})
GraphDef(
  type=Linear,
  index=0,
  attributes=('din', 'dout', 'w', 'b', 'submodule'),
  subgraphs={
    'submodule': 0
  },
  static_fields={
    'din': 2,
    'dout': 2
  },
  variables={
    'w': VariableDef(
      type=Param,
      index=1,
      metadata={
        'get_value_hooks': (),
        'set_value_hooks': (),
        'create_value_hooks': (),
        'add_axis_hooks': (),
        'remove_axis_hooks': ()
      }
    ),
    'b': VariableDef(
      type=Param,
      index=2,
      metadata={
        'get_value_hooks': (),
        'set_value_hooks': (),
        'create_value_hooks': (),
        'add_axis_hooks': (),
        'remove_axis_hooks': ()
      }
    )
  },
  metadata=<class '__main__.Linear'>
)


In [6]:
# merge() rebuilds the module from (graphdef, state) and restores the cycle.
linear2 = graphdef.merge(state)
linear2 is linear2.submodule

True

In [7]:
class Linear(nnx.Module):
  """``Linear`` variant that stores its latest output as an ``nnx.Intermediate``."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    # static attributes
    self.din = din
    self.dout = dout
    # variables
    init_key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(init_key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    out = x @ self.w.value + self.b.value
    # Side effect: each call overwrites `self.y`; collect it with `.pop(nnx.Intermediate)`.
    self.y = nnx.Intermediate(out)
    return out


linear = Linear(2, 2, rngs=nnx.Rngs(0))
# The call stores `y` on the module as an Intermediate.
y = linear(jnp.ones((2, 2)))

# The repr still lists only the static attributes.
print(linear, y, sep='\n')

Linear(
  din=2,
  dout=2
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [8]:
# Remove the Intermediate values from the module before splitting it.
intermediates = linear.pop(nnx.Intermediate)
state, graphdef = linear.split()

# In the recorded output the popped attribute still appears in `state`
# as an Intermediate whose raw_value is Empty.
print(intermediates, state, sep='\n')

State({
  'y': Intermediate(
    raw_value=Array([[0.63114893, 1.2928092 ],
           [0.63114893, 1.2928092 ]], dtype=float32)
  )
})
State({
  'w': Param(
    raw_value=Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32)
  ),
  'b': Param(
    raw_value=Array([0., 0.], dtype=float32)
  ),
  'y': Intermediate(
    raw_value=Empty
  )
})


In [9]:
class Foo(nnx.Module):
    """Two ``nnx.Linear`` layers whose kernels are tied (they share one Param)."""

    def __init__(self, *, rngs: nnx.Rngs) -> None:
        self.bar = nnx.Linear(2, 2, rngs=rngs)
        self.baz = nnx.Linear(2, 2, rngs=rngs)

        # Tie the weights by pointing both layers at the same Param object.
        # Earlier cells access Variables directly as attributes (e.g. `self.w.value`),
        # so the Param is shared through the attribute itself rather than the older
        # `.variables.kernel` indirection.
        # NOTE(review): the recorded output of this cell (raw arrays in State) looks
        # like it came from an earlier nnx version; re-run to refresh it.
        self.baz.kernel = self.bar.kernel

model = Foo(rngs=nnx.Rngs(0))
# Call the second result `graphdef`, as the other split() cells do.
state, graphdef = model.split()

# The shared kernel should appear once in State (under 'bar'); 'baz' keeps only its bias.
print(f'{state = }')
print(f'{graphdef = }')

state = State({
  'bar': {
    'kernel': Array([[-0.3641057 ,  0.10192434],
           [-0.37005556,  0.49028906]], dtype=float32),
    'bias': Array([0., 0.], dtype=float32)
  },
  'baz': {
    'bias': Array([0., 0.], dtype=float32)
  }
})
static = GraphDef(
  type=Foo,
  index=0,
  attributes=('bar', 'baz'),
  subgraphs={
    'bar': GraphDef(
      type=Linear,
      index=1,
      attributes=('kernel', 'bias', 'in_features', 'out_features', 'use_bias', 'dtype', 'param_dtype', 'precision', 'kernel_init', 'bias_init', 'dot_general'),
      subgraphs={},
      static_fields={
        'in_features': 2,
        'out_features': 2,
        'use_bias': True,
        'dtype': None,
        'param_dtype': <class 'jax.numpy.float32'>,
        'precision': None,
        'kernel_init': <function variance_scaling.<locals>.init at 0x1391425f0>,
        'bias_init': <function zeros at 0x1294268c0>,
        'dot_general': <function dot_general at 0x11f78f1c0>
      },
      variables={
        'kern