<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/flax/basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
!pip install --upgrade treescope

Collecting treescope
  Downloading treescope-0.1.5-py3-none-any.whl.metadata (5.9 kB)
Downloading treescope-0.1.5-py3-none-any.whl (174 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/174.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.2/174.2 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: treescope
Successfully installed treescope-0.1.5


In [1]:
from flax import nnx
import jax
from jax import numpy as jnp


# Simple Linear

In [2]:
class Linear(nnx.Module):
  """Affine layer computing ``x @ w + b``.

  ``w`` is an ``nnx.Param`` of shape ``(din, dout)`` drawn with
  ``jax.random.uniform``; ``b`` is an ``nnx.Param`` of zeros with shape
  ``(dout,)``.
  """

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    """Create the weight and bias parameters.

    Args:
      din: Number of input features.
      dout: Number of output features.
      rngs: RNG container; one key is taken from ``rngs.params()``. The bare
        ``*`` in the signature makes this argument keyword-only.
    """
    w_key = rngs.params()
    w_init = jax.random.uniform(w_key, (din, dout))
    b_init = jnp.zeros((dout,))
    self.w = nnx.Param(w_init)
    self.b = nnx.Param(b_init)
    # Sizes are stored as plain attributes so other code can read them back.
    self.din = din
    self.dout = dout

  def __call__(self, x: jnp.ndarray):
    """Apply the layer to ``x`` and return ``x @ w + b``."""
    projected = x @ self.w
    return projected + self.b

# Build a 3 -> 5 layer; the `params` RNG stream is created from seed 0.
model = Linear(din=3, dout=5, rngs=nnx.Rngs(params=0))

# Six identical all-ones rows, so every output row is the same.
x = jnp.ones(shape=(6, 3))
y = model(x)

print(y)
nnx.display(model)


[[1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]
 [1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]
 [1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]
 [1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]
 [1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]
 [1.4492468 1.6527531 1.1650311 1.6205449 1.5738977]]
Linear(
  w=Param(
    value=Array(shape=(3, 5), dtype=float32)
  ),
  b=Param(
    value=Array(shape=(5,), dtype=float32)
  ),
  din=3,
  dout=5
)


# Stateful computation
How to update state during the forward pass (e.g., as BatchNorm does).

In [4]:
class Count(nnx.Variable):
  """Custom ``nnx.Variable`` subclass used to hold a counter value."""

class Counter(nnx.Module):
  """Module holding a ``Count`` variable that is incremented on every call."""

  def __init__(self):
    """Initialise the counter with a scalar zero."""
    self.count = Count(jnp.array(0))

  def __call__(self):
    """Increment the stored count by one; returns ``None``."""
    # Update through `.value` instead of `self.count += 1`: in-place operators
    # applied directly to a Variable are rejected by newer Flax NNX releases,
    # while `.value` also works on the version this notebook was run with.
    # NOTE(review): confirm against the installed Flax version.
    self.count.value += 1

# Show the stored count before and after a single call to the module.
counter = Counter()
print(f'{counter.count.value = }')
counter()  # mutates counter.count; nothing is returned
print(f'{counter.count.value = }')

counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)


# Nested Modules

In [10]:
class MLP(nnx.Module):
  """Two-layer perceptron: Linear -> BatchNorm -> Dropout -> GELU -> Linear."""

  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    """Build the sub-modules.

    Args:
      din: Input feature size.
      dmid: Hidden feature size.
      dout: Output feature size.
      rngs: Keyword-only RNG container, passed to every sub-module.
    """
    self.lin1 = Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.lin2 = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jnp.ndarray):
    """Run the forward pass and return the output of ``lin2``."""
    normalized = self.bn(self.lin1(x))
    activated = nnx.gelu(self.dropout(normalized))
    return self.lin2(activated)

# 2 input features, 16 hidden units, 5 outputs; every RNG stream derives from seed 0.
model = MLP(din=2, dmid=16, dout=5, rngs=nnx.Rngs(0))

# Forward pass on a batch of three all-ones rows.
y = model(x=jnp.ones(shape=(3, 2)))

nnx.display(model)
print(y)

MLP(
  lin1=Linear(
    w=Param(
      value=Array(shape=(2, 16), dtype=float32)
    ),
    b=Param(
      value=Array(shape=(16,), dtype=float32)
    ),
    din=2,
    dout=16
  ),
  dropout=Dropout(rate=0.5, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
    default=RngStream(
      key=RngKey(
        value=Array((), dtype=key<fry>) overlaying:
        [0 0],
        tag='default'
      ),
      count=RngCount(
        value=Array(5, dtype=uint32),
        tag='default'
      )
    )
  )),
  bn=BatchNorm(
    mean=BatchStat(
      value=Array(shape=(16,), dtype=float32)
    ),
    var=BatchStat(
      value=Array(shape=(16,), dtype=float32)
    ),
    scale=Param(
      value=Array(shape=(16,), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(16,), dtype=float32)
    ),
    num_features=16,
    use_running_average=False,
    axis=-1,
    momentum=0.99,
    epsilon=1e-05,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    use_

# Model surgery

In [6]:
class LoraParam(nnx.Param):
  """``nnx.Param`` subclass that marks the LoRA matrices ``A`` and ``B``."""

class LoraLinear(nnx.Module):
  """Wraps a ``Linear`` layer and adds the low-rank term ``x @ A @ B``."""

  def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
    """Store the wrapped layer and create the two LoRA matrices.

    Args:
      linear: Layer to wrap; its ``din`` and ``dout`` size the new matrices.
      rank: Inner dimension shared by ``A`` and ``B``.
      rngs: RNG container; called twice, first for ``A`` and then for ``B``.
    """
    din, dout = linear.din, linear.dout
    self.linear = linear
    a_init = jax.random.normal(rngs(), (din, rank))
    self.A = LoraParam(a_init)
    b_init = jax.random.normal(rngs(), (rank, dout))
    self.B = LoraParam(b_init)

  def __call__(self, x: jax.Array):
    """Return the wrapped layer's output plus ``x @ A @ B``."""
    base = self.linear(x)
    low_rank = (x @ self.A) @ self.B
    return base + low_rank

# One Rngs object is shared by the MLP and by both LoRA wrappers.
rngs = nnx.Rngs(0)
model = MLP(din=2, dmid=32, dout=5, rngs=rngs)

# Model surgery: replace each Linear sub-module in place with a LoraLinear
# that wraps the original layer.
model.lin1 = LoraLinear(linear=model.lin1, rank=4, rngs=rngs)
model.lin2 = LoraLinear(linear=model.lin2, rank=4, rngs=rngs)

# The patched model is called exactly like the original one.
y = model(x=jnp.ones(shape=(3, 2)))

nnx.display(model)

MLP(
  lin1=LoraLinear(
    linear=Linear(
      w=Param(
        value=Array(shape=(2, 32), dtype=float32)
      ),
      b=Param(
        value=Array(shape=(32,), dtype=float32)
      ),
      din=2,
      dout=32
    ),
    A=LoraParam(
      value=Array(shape=(2, 4), dtype=float32)
    ),
    B=LoraParam(
      value=Array(shape=(4, 32), dtype=float32)
    )
  ),
  dropout=Dropout(rate=0.5, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
    default=RngStream(
      key=RngKey(
        value=Array((), dtype=key<fry>) overlaying:
        [0 0],
        tag='default'
      ),
      count=RngCount(
        value=Array(9, dtype=uint32),
        tag='default'
      )
    )
  )),
  bn=BatchNorm(
    mean=BatchStat(
      value=Array(shape=(32,), dtype=float32)
    ),
    var=BatchStat(
      value=Array(shape=(32,), dtype=float32)
    ),
    scale=Param(
      value=Array(shape=(32,), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(32,), dtype=

# Train a simple MLP

In [7]:
import optax

# Fresh model for the training demo: 2 inputs, 16 hidden units, 10 outputs.
model = MLP(din=2, dmid=16, dout=10, rngs=nnx.Rngs(0))

# Adam with learning rate 1e-3, wrapped in an nnx.Optimizer bound to the model.
opt = optax.adam(learning_rate=1e-3)
optimizer = nnx.Optimizer(model, opt)

@nnx.jit
def train_step(model, optimizer, x, y):
  """Run one optimizer update and return the loss computed before it.

  Args:
    model: Module called as ``model(x)`` to produce predictions.
    optimizer: Object whose ``update(grads)`` applies the gradients.
    x: Input batch.
    y: Targets compared against the predictions.

  Returns:
    Mean squared error between ``model(x)`` and ``y``.
  """

  def mse(net):
    residual = net(x) - y
    return jnp.mean(residual**2)

  value_and_grad_fn = nnx.value_and_grad(mse)
  loss, grads = value_and_grad_fn(model)
  optimizer.update(grads)
  return loss

# Dummy batch: five all-ones inputs and five all-ones targets.
x = jnp.ones(shape=(5, 2))
y = jnp.ones(shape=(5, 10))

# A single training step; the optimizer's step counter is printed afterwards.
loss = train_step(model, optimizer, x, y)

print(f'{loss = }')
print(f'{optimizer.step.value = }')

loss = Array(1.0000308, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)


# vmap and scan

In [11]:
@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
  """Build one ``MLP(10, 32, 10)`` from ``key``.

  Wrapped in ``nnx.vmap`` with ``in_axes=0`` and ``out_axes=0``, so it is
  mapped over axis 0 of the keys it receives.
  """
  rngs = nnx.Rngs(key)
  return MLP(10, 32, 10, rngs=rngs)

# Five keys split from a single root key, one per mapped model instance.
keys = jax.random.split(jax.random.key(0), num=5)
model = create_model(keys)

@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
  """Apply ``model`` to ``x`` and return the result.

  Wrapped in ``nnx.scan``: ``model`` is scanned along axis 0, while ``x`` is
  the carry on the way in and on the way out.
  """
  return model(x)

# The carry keeps the same shape through the scan because the MLP maps 10 features to 10.
x = jnp.ones(shape=(3, 10))
y = forward(model, x)

print(f'{y.shape = }')
nnx.display(model)

y.shape = (3, 10)
MLP(
  bn=BatchNorm(
    axis=-1,
    axis_index_groups=None,
    axis_name=None,
    bias=Param(
      value=Array(shape=(5, 32), dtype=float32)
    ),
    bias_init=<function zeros at 0x7cbe3bd120e0>,
    dtype=None,
    epsilon=1e-05,
    mean=BatchStat(
      value=Array(shape=(5, 32), dtype=float32)
    ),
    momentum=0.99,
    num_features=32,
    param_dtype=<class 'jax.numpy.float32'>,
    scale=Param(
      value=Array(shape=(5, 32), dtype=float32)
    ),
    scale_init=<function ones at 0x7cbe3bd12290>,
    use_bias=True,
    use_fast_variance=True,
    use_running_average=False,
    use_scale=True,
    var=BatchStat(
      value=Array(shape=(5, 32), dtype=float32)
    )
  ),
  dropout=Dropout(rate=0.5, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
    default=RngStream(
      count=RngCount(
        tag='default',
        value=Array(shape=(5,), dtype=uint32)
      ),
      key=RngKey(
        tag='default',
        value=Arr