<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/flax/basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ! pip install -U flax treescope

In [None]:
from flax import nnx
import jax
from jax import numpy as jnp


# A simple linear layer

In [None]:
class Linear(nnx.Module):
  """Minimal fully connected layer computing ``jnp.dot(x, w) + b``.

  The weight matrix has shape ``(din, dout)`` and is sampled with
  ``jax.random.uniform`` from a single key drawn from the ``params`` stream
  of ``rngs``. The bias has shape ``(dout,)`` and starts at zero.
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    # The bare ``*`` makes ``rngs`` keyword-only: callers must write
    # ``rngs=...`` explicitly.
    weight_init = jax.random.uniform(rngs.params(), (din, dout))
    bias_init = jnp.zeros((dout,))
    self.w = nnx.Param(weight_init)
    self.b = nnx.Param(bias_init)
    # Sizes are kept as plain ints (not wrapped in nnx.Param).
    self.din = din
    self.dout = dout

  def __call__(self, x: jnp.ndarray):
    """Apply the affine map; the last axis of ``x`` must have size ``din``."""
    projected = jnp.dot(x, self.w.value)
    return projected + self.b.value

# Build a 3 -> 5 layer with a fixed params seed, run it on a batch of six
# all-ones inputs (so every output row is identical), then print the result
# and render the module tree.
model = Linear(din=3, dout=5, rngs=nnx.Rngs(params=0))
x = jnp.ones(shape=(6, 3))
y = model(x)
print(y)
nnx.display(model)


[[1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]
 [1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]
 [1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]
 [1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]
 [1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]
 [1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]]
Linear(
  w=Param(
    value=Array(shape=(3, 5), dtype=float32)
  ),
  b=Param(
    value=Array(shape=(5,), dtype=float32)
  ),
  din=3,
  dout=5
)


# Stateful computation
How to update state during the forward pass (e.g. the running statistics in BatchNorm).

In [None]:
class Count(nnx.Variable):
  """Custom ``nnx.Variable`` subtype that ``Counter`` uses to hold its count.

  Having a dedicated variable type (rather than reusing ``nnx.Param``)
  presumably lets this state be told apart from trainable parameters.
  """

class Counter(nnx.Module):
  """Module whose forward pass mutates its own stored state (a call count)."""

  def __init__(self):
    # Start from zero, wrapped in the Count variable type rather than kept
    # as a bare array attribute.
    self.count = Count(jnp.array(0))

  def __call__(self):
    """Increase the stored count by one. Returns None."""
    current = self.count.value
    self.count.value = current + 1

# Create a counter, show its initial value, call it once (the call mutates
# counter.count in place on the module), then show the value again: 0 -> 1.
counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')

counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)
