In [1]:
from flax.experimental import nnx
import jax
import jax.numpy as jnp


class Linear(nnx.Module):
  """Affine layer holding a weight ``w`` and bias ``b``; computes ``x @ w + b``."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    """Store the layer sizes and create the two parameters.

    Args:
      din: first dimension of the weight matrix ``w``.
      dout: second dimension of ``w`` and the length of the bias ``b``.
      rngs: RNG container; one key is taken from its ``params`` stream to
        sample ``w``.
    """
    # Plain Python values kept as ordinary (static) attributes.
    self.din, self.dout = din, dout

    # Parameters: `w` is sampled with jax.random.uniform, `b` starts at zero.
    weight_key = rngs.params()
    weight_shape = (din, dout)
    self.w = nnx.Param(jax.random.uniform(weight_key, weight_shape))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    """Apply the layer to ``x`` and return ``x @ w + b``."""
    projected = x @ self.w
    return projected + self.b


# Build a 2x2 layer from seed 0 and apply it to a 2x2 array of ones.
linear = Linear(din=2, dout=2, rngs=nnx.Rngs(0))

y = linear(jnp.ones(shape=(2, 2)))

# Show the module's repr followed by the layer output, one per line.
print(linear, y, sep="\n")

A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.


Linear(
  din=2,
  dout=2
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [2]:
# `split` returns a pair: the module's State (the arrays) and its GraphDef
# (the structure description); print both, one after the other.
(state, graphdef) = linear.split()

print(state, graphdef, sep="\n")

State({
  'w': Array([[0.31696808, 0.55285215],
         [0.31418085, 0.7399571 ]], dtype=float32),
  'b': Array([0., 0.], dtype=float32)
})
GraphDef(
  type=Linear,
  index=0,
  subgraphs=(),
  static_fields=(('din', 2), ('dout', 2)),
  variables=(('w', Param(
      value=Empty
    )), ('b', Param(
      value=Empty
    ))),
  metadata=<class '__main__.Linear'>
)


In [3]:
class Nested(nnx.Module):
  """Module whose only attribute is a ``Linear`` child stored as ``linear``."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    """Create the child layer, forwarding ``din``, ``dout`` and ``rngs``."""
    self.linear = Linear(din=din, dout=dout, rngs=rngs)
    
# Build the nested module from seed 0.
module = Nested(din=2, dout=2, rngs=nnx.Rngs(0))

# Split it and display the state as the cell's result.
(state, static) = module.split()
state

State({
  'linear': {
    'w': Array([[0.31696808, 0.55285215],
           [0.31418085, 0.7399571 ]], dtype=float32),
    'b': Array([0., 0.], dtype=float32)
  }
})

In [4]:
class Linear(nnx.Module):
  """Affine layer that additionally stores itself under ``submodule``."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    """Store the sizes, create ``w`` and ``b``, then add the self-reference.

    Args:
      din: first dimension of the weight matrix ``w``.
      dout: second dimension of ``w`` and the length of the bias ``b``.
      rngs: RNG container; one key is taken from its ``params`` stream to
        sample ``w``.
    """
    self.din, self.dout = din, dout

    weight_key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(weight_key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

    # The module points at itself, so the attribute graph contains a cycle.
    self.submodule = self

  def __call__(self, x):
    """Return ``x @ w + b``, reading both parameters through ``submodule``."""
    inner = self.submodule
    return x @ inner.w + inner.b


# Instantiate the self-referential layer from seed 0 and run it on ones.
linear = Linear(din=2, dout=2, rngs=nnx.Rngs(0))

y = linear(jnp.ones(shape=(2, 2)))

# Module repr first, then the output array.
print(linear, y, sep="\n")

Linear(
  din=2,
  dout=2,
  submodule=Linear(...)
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [5]:
# Split the self-referential module and print its state and graph definition.
(state, graphdef) = linear.split()

print(state, graphdef, sep="\n")

State({
  'w': Array([[0.31696808, 0.55285215],
         [0.31418085, 0.7399571 ]], dtype=float32),
  'b': Array([0., 0.], dtype=float32)
})
GraphDef(
  type=Linear,
  index=0,
  subgraphs=(
    ('submodule', 0)
  ),
  static_fields=(('din', 2), ('dout', 2)),
  variables=(('w', Param(
      value=Empty
    )), ('b', Param(
      value=Empty
    ))),
  metadata=<class '__main__.Linear'>
)


In [6]:
# Rebuild a module from the graph definition and the state.
linear2 = graphdef.merge(state)

# Identity check: is the rebuilt module's `submodule` the rebuilt module itself?
linear2 is linear2.submodule

True

In [7]:
class Linear(nnx.Module):
  """Affine layer that stores each call's output as an ``nnx.Intermediate``."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    """Store the layer sizes and create the ``w`` and ``b`` parameters.

    Args:
      din: first dimension of the weight matrix ``w``.
      dout: second dimension of ``w`` and the length of the bias ``b``.
      rngs: RNG container; one key is taken from its ``params`` stream to
        sample ``w``.
    """
    # Plain Python values kept as ordinary (static) attributes.
    self.din, self.dout = din, dout

    # Parameters: `w` is sampled with jax.random.uniform, `b` starts at zero.
    weight_key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(weight_key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    """Compute ``x @ w + b``, store it on ``self.y`` and return it."""
    output = x @ self.w + self.b
    # Keep a copy of the result on the module, wrapped as an Intermediate.
    self.y = nnx.Intermediate(output)
    return output


# Instantiate the recording layer from seed 0; calling it also sets `linear.y`.
linear = Linear(din=2, dout=2, rngs=nnx.Rngs(0))

y = linear(jnp.ones(shape=(2, 2)))

# Module repr first, then the output array.
print(linear, y, sep="\n")

Linear(
  din=2,
  dout=2
)
[[0.63114893 1.2928092 ]
 [0.63114893 1.2928092 ]]


In [8]:
# Take the Intermediate entries out of the module first, then split what is
# left into state and graph definition.
intermediates = linear.pop(nnx.Intermediate)
(state, graphdef) = linear.split()

# Print the popped intermediates, then the remaining state.
print(intermediates, state, sep="\n")

State({
  'y': Array([[0.63114893, 1.2928092 ],
         [0.63114893, 1.2928092 ]], dtype=float32)
})
State({
  'w': Array([[0.31696808, 0.55285215],
         [0.31418085, 0.7399571 ]], dtype=float32),
  'b': Array([0., 0.], dtype=float32),
  'y': Empty
})
