In [1]:
from flax.experimental import nnx
import jax
import jax.numpy as jnp



In [2]:
class Linear(nnx.Module):
  """A dense affine layer computing ``x @ w + b``.

  Attributes:
    w: ``nnx.Param`` wrapping a ``(din, dout)`` weight matrix drawn with
      ``jax.random.uniform`` using its default bounds.
    b: ``nnx.Param`` wrapping a ``(dout,)`` bias vector initialised to zeros.
    din: number of input features (plain int, not a parameter).
    dout: number of output features (plain int, not a parameter).
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    """Creates the layer's weight and bias parameters.

    Args:
      din: size of the last axis expected on the inputs.
      dout: size of the last axis produced on the outputs.
      rngs: RNG container; exactly one key is drawn from its ``params``
        stream and used to initialise ``w``.
    """
    weight_shape = (din, dout)
    init_key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(init_key, weight_shape))
    self.b = nnx.Param(jnp.zeros((dout,)))
    # Stored as plain ints (not wrapped in nnx.Param), so they are not
    # trainable parameters; assignment order keeps the display order w, b, din, dout.
    self.din = din
    self.dout = dout

  def __call__(self, x: jax.Array):
    """Applies the affine map; ``x``'s last axis must have size ``din``."""
    projected = x @ self.w
    return projected + self.b

In [4]:
# Fixed params seed (0) so the printed output below is reproducible on re-run.
model = Linear(2, 5, rngs=nnx.Rngs(params=0))
inputs = jnp.ones((1, 2))  # one example with din=2 features
y = model(x=inputs)

print(y)
nnx.display(model)

[[1.245453   0.74195766 0.8553282  0.6763327  1.2617068 ]]
Linear(
  w=Param(
    value=Array(shape=(2, 5), dtype=float32)
  ),
  b=Param(
    value=Array(shape=(5,), dtype=float32)
  ),
  din=2,
  dout=5
)
