# Tiny NNX
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cgarciae/nnx/blob/main/docs/tiny_nnx.ipynb)

A pedagogical implementation of NNX's core APIs.

## Core APIs

In [67]:
import hashlib
import typing as tp
import jax
import jax.numpy as jnp
from jax import random
import dataclasses

# Payload type held by a Variable.
A = tp.TypeVar("A")
# Self-type for Module methods, so partition()/merge() keep the concrete subclass.
M = tp.TypeVar("M", bound="Module")
# Sharding metadata attached to a Variable: a tuple of optional strings.
Sharding = tp.Tuple[tp.Optional[str], ...]
# `jax.random.KeyArray` was deprecated and later removed from JAX, so looking it
# up fails on current releases. PRNG keys are ordinary `jax.Array`s, which is
# also the type the rest of this file uses in its annotations.
KeyArray = jax.Array


class Variable(tp.Generic[A]):
    """A value tagged with the name of its collection and optional sharding info.

    Attributes:
        value: the wrapped payload; the only part treated as pytree data.
        collection: name of the group this variable belongs to (e.g. "params").
        sharding: optional sharding metadata, stored as-is.
    """

    def __init__(
        self,
        collection: str,
        value: A,
        *,
        sharding: tp.Optional[Sharding] = None,
    ):
        self.value = value
        self.collection = collection
        self.sharding = sharding

    def __repr__(self) -> str:
        # Build the "name=value" pieces first, then join them into one line.
        fields = (
            f"value={self.value}",
            f"collection={self.collection}",
            f"sharding={self.sharding}",
        )
        return "Variable(" + ", ".join(fields) + ")"
    

def _flatten_variable(variable):
    """Split a Variable into its single child (the value) and static metadata."""
    children = (variable.value,)
    aux = (variable.collection, variable.sharding)
    return children, aux


def _unflatten_variable(aux, children):
    """Rebuild a Variable from the metadata and child produced by flattening."""
    collection, sharding = aux
    return Variable(collection, children[0], sharding=sharding)


# Make Variable a pytree node: only `value` is traced/transformed by JAX,
# while `collection` and `sharding` travel along as static auxiliary data.
jax.tree_util.register_pytree_node(
    Variable,
    _flatten_variable,
    _unflatten_variable,
)

class State(dict[str, Variable[tp.Any]]):
    """Flat dictionary mapping "/"-joined attribute paths to Variables."""

    def filter(self, collection: str) -> "State":
        """Return a new State with only the Variables that belong to `collection`."""
        selected = State()
        for path, variable in self.items():
            if variable.collection == collection:
                selected[path] = variable
        return selected

    def __repr__(self) -> str:
        # One entry per line; continuation lines of a multi-line Variable repr
        # are indented so they stay visually attached to their entry.
        entries = []
        for path, variable in self.items():
            entry = f"'{path}': {variable}"
            entries.append(entry.replace("\n", "\n    "))
        body = ",\n  ".join(entries)
        return "State({\n  " + body + "\n})"


def _flatten_state(state):
    """Flatten a State into (variables, paths), both ordered by path.

    Sorting makes the flattened order, and therefore the pytree structure,
    independent of dict insertion order. Two States holding the same paths
    (for example one rebuilt by merging filtered pieces) then share a treedef.
    """
    paths = tuple(sorted(state))
    variables = tuple(state[path] for path in paths)
    return variables, paths


def _unflatten_state(paths, variables):
    """Rebuild a State by pairing each path with its Variable."""
    return State(dict(zip(paths, variables)))


# Make State a pytree node: the Variables are the children and the
# (sorted) paths are static auxiliary data.
jax.tree_util.register_pytree_node(
    State,
    _flatten_state,
    _unflatten_state,
)


@dataclasses.dataclass
class ModuleDef(tp.Generic[M]):
    """Static description of a Module tree, i.e. everything except its Variables.

    Produced by `Module.partition`; combined with a `State` it can rebuild an
    equivalent Module via `merge`.
    """

    # Concrete Module subclass to re-instantiate.
    type: tp.Type[M]
    # Order in which this module was first visited during partitioning.
    index: int
    # Attribute name -> nested ModuleDef, or the int index of a module that was
    # already visited (a shared reference).
    submodules: tp.Dict[str, tp.Union["ModuleDef[Module]", int]]
    # Attribute name -> any value that is neither a Module nor a Variable.
    static_fields: tp.Dict[str, tp.Any]

    def apply(self, state: State) -> tp.Callable[..., tuple[tp.Any, tuple[State, "ModuleDef[M]"]]]:
        """Return a function that rebuilds the module, calls it, and re-partitions.

        The returned callable forwards its arguments to the rebuilt module and
        returns `(output, (new_state, new_moduledef))`.
        """

        def call_and_repartition(*args, **kwargs):
            rebuilt = self.merge(state)
            result = rebuilt(*args, **kwargs)  # type: ignore
            return result, rebuilt.partition()

        return call_and_repartition

    def merge(self, state: State) -> M:
        """Rebuild the Module tree and install the Variables from `state`."""
        rebuilt = ModuleDef._build_module_recursive(self, {})
        rebuilt.update_state(state)
        return rebuilt

    @staticmethod
    def _build_module_recursive(
        moduledef: tp.Union["ModuleDef[M]", int],
        index_to_module: tp.Dict[int, "Module"],
    ) -> M:
        """Instantiate `moduledef` (without calling `__init__`) and its submodules.

        `index_to_module` records every module built so far so that integer
        back-references resolve to the already-created instance.
        """
        # an int refers to a module that has already been built
        if isinstance(moduledef, int):
            return index_to_module[moduledef]  # type: ignore

        assert moduledef.index not in index_to_module

        # register the empty instance *before* recursing so that references
        # back to this module do not recurse forever
        instance = object.__new__(moduledef.type)
        index_to_module[moduledef.index] = instance

        children = {}
        for name, child_def in moduledef.submodules.items():
            children[name] = ModuleDef._build_module_recursive(child_def, index_to_module)

        attributes = vars(instance)
        attributes.update(moduledef.static_fields)
        attributes.update(children)
        return instance


class Module:
    """Base class for objects that can be split into a `State` and a `ModuleDef`."""

    def partition(self: M) -> tp.Tuple[State, ModuleDef[M]]:
        """Split this module into its Variables (`State`) and static structure."""
        collected = State()
        root_def = Module._partition_recursive(
            module=self,
            module_id_to_index={},
            path_parts=(),
            state=collected,
        )
        assert isinstance(root_def, ModuleDef)
        return collected, root_def

    @staticmethod
    def _partition_recursive(
        module: M,
        module_id_to_index: tp.Dict[int, int],
        path_parts: tp.Tuple[str, ...],
        state: State,
    ) -> tp.Union[ModuleDef[M], int]:
        """Walk `module`, filling `state` in place and returning its ModuleDef.

        A module that was already visited is represented by its integer index
        instead of a second ModuleDef.
        """
        module_id = id(module)
        if module_id in module_id_to_index:
            return module_id_to_index[module_id]

        # indices are handed out in first-visit order
        index = len(module_id_to_index)
        module_id_to_index[module_id] = index

        child_defs = {}
        static = {}
        attributes = vars(module)

        # visit attributes in name order so the result is deterministic
        for name in sorted(attributes):
            value = attributes[name]
            value_path = (*path_parts, name)
            if isinstance(value, Module):
                # nested module: recurse with the extended path
                child_defs[name] = Module._partition_recursive(
                    value, module_id_to_index, value_path, state)
            elif isinstance(value, Variable):
                # variable: store it under its "/"-joined path
                state["/".join(value_path)] = value
            else:
                # anything else is treated as static structure
                static[name] = value

        return ModuleDef(
            type=type(module),
            index=index,
            submodules=child_defs,
            static_fields=static,
        )

    def update_state(self, state: State) -> None:
        """Assign every Variable in `state` to the attribute its path names."""
        for path, variable in state.items():
            Module._set_value_at_path(self, path.split("/"), variable)

    @staticmethod
    def _set_value_at_path(module: "Module", path_parts: tp.Sequence[str], value: Variable[tp.Any]) -> None:
        """Follow all but the last path part via getattr, then setattr the last."""
        owner = module
        for part in path_parts[:-1]:
            owner = getattr(owner, part)
        setattr(owner, path_parts[-1], value)


@dataclasses.dataclass
class Context:
    """Carries a base PRNG key plus counters used to derive fresh keys from it."""

    # Base key; never replaced, only folded into.
    key: KeyArray
    # Number of keys/forks handed out by this context so far.
    count: int = 0
    # Counts of the ancestor contexts at the moment each fork happened.
    count_path: tuple[int, ...] = ()

    def fork(self) -> "Context":
        """Create a child Context that shares the key but has its own counter path.

        The child's `count_path` is this context's path extended with the
        current `count`, and this context's `count` is then advanced, so the
        child is intended to produce random numbers different from the ones
        generated here. Fork is used to create a new Context that can be
        passed to a JAX transform.
        """
        child_path = (*self.count_path, self.count)
        self.count += 1
        return Context(self.key, count_path=child_path)

    def make_rng(self) -> jax.Array:
        """Return a new key derived from `key` and the current position, then advance."""
        position = (*self.count_path, self.count)
        fold_data = self._stable_hash(position)
        self.count += 1
        return random.fold_in(self.key, fold_data)  # type: ignore

    @staticmethod
    def _stable_hash(data: tuple[int, ...]) -> int:
        """Hash a tuple of ints to an integer that fits in 32 bits, via blake2s."""
        text = " ".join(map(str, data))
        digest = hashlib.blake2s(text.encode()).digest()
        # the first 4 bytes, read big endian, form a uint32
        return int.from_bytes(digest[:4], byteorder="big")

def _flatten_context(context):
    """Expose the key as the only child; the counters are static metadata."""
    return (context.key,), (context.count, context.count_path)


def _unflatten_context(aux, children):
    """Rebuild a Context from its key and its (count, count_path) metadata."""
    count, count_path = aux
    return Context(children[0], count, count_path)


# In the real NNX, Context is not a pytree; it has a partition/merge API
# similar to Module's. For simplicity it is registered as a pytree here.
jax.tree_util.register_pytree_node(
    Context,
    _flatten_context,
    _unflatten_context,
)

## Basic Layers

In [84]:
class Linear(Module):
    """Affine layer computing `x @ w + b`."""

    def __init__(self, din: int, dout: int, *, ctx: Context):
        """Create a (din, dout) weight drawn with `random.uniform` and a zero bias.

        Exactly one key is drawn from `ctx`.
        """
        self.din = din
        self.dout = dout
        weight = random.uniform(ctx.make_rng(), (din, dout))
        bias = jnp.zeros((dout,))
        self.w = Variable("params", weight)
        self.b = Variable("params", bias)

    def __call__(self, x: jax.Array) -> jax.Array:
        weight, bias = self.w.value, self.b.value
        return x @ weight + bias

class BatchNorm(Module):
    """Batch normalization over all axes but the last, with running statistics.

    `scale`/`bias` live in the "params" collection; the running `mean`/`var`
    live in "batch_stats" and are updated in place during training calls.
    """

    def __init__(self, din: int, mu: float = 0.95):
        # mu is the weight given to the previous running value in the update
        self.mu = mu
        self.scale = Variable("params", jnp.ones((din,)))
        self.bias = Variable("params", jnp.zeros((din,)))
        self.mean = Variable("batch_stats", jnp.zeros((din,)))
        self.var = Variable("batch_stats", jnp.ones((din,)))

    def __call__(self, x, train: bool) -> jax.Array:
        """Normalize `x`; when `train` is true, also update the running stats."""
        if not train:
            # inference: use the stored running statistics
            mean, var = self.mean.value, self.var.value
        else:
            # reduce over every axis except the trailing feature axis
            reduce_axes = tuple(range(x.ndim - 1))
            mean = jnp.mean(x, axis=reduce_axes, keepdims=True)
            var = jnp.var(x, axis=reduce_axes, keepdims=True)
            # exponential moving average update
            # NOTE: because of keepdims=True the batch statistics keep size-1
            # leading axes, so after a training call the stored running stats
            # take that broadcast shape (the recorded output shows
            # 'layers/bn/mean' as (4, 1, 10)) instead of the initial (din,).
            decay = self.mu
            self.mean.value = decay * self.mean.value + (1 - decay) * mean
            self.var.value = decay * self.var.value + (1 - decay) * var

        normalized = (x - mean) / jnp.sqrt(var + 1e-5)
        return normalized * self.scale.value + self.bias.value
    
class Dropout(Module):
    """Inverted dropout: zeroes elements with probability `rate` during training.

    Kept elements are divided by `1 - rate` so the expected value is unchanged.
    """

    def __init__(self, rate: float):
        # probability of dropping each element
        self.rate = rate

    def __call__(self, x: jax.Array, *, train: bool, ctx: Context) -> jax.Array:
        """Apply dropout to `x` when `train` is true; otherwise return `x` as is.

        Args:
            x: input array.
            train: whether dropout is active.
            ctx: source of the PRNG key; one key is drawn per training call.

        Returns:
            An array with the same shape as `x`.
        """
        if not train:
            return x

        # Draw the key unconditionally so ctx's counter advances the same way
        # regardless of the rate.
        key = ctx.make_rng()
        keep_prob = 1 - self.rate
        if keep_prob == 0:
            # rate == 1 drops everything. The general formula below would
            # compute 0 / 0 and fill the output with NaN.
            return jnp.zeros_like(x)

        mask = random.bernoulli(key, keep_prob, x.shape)
        return x * mask / keep_prob

## Scan Over Layers Example

In [None]:

class Block(Module):
    """Linear -> BatchNorm -> GELU -> Dropout(0.1)."""

    def __init__(self, din: int, dout: int, *, ctx: Context):
        # only the Linear layer draws a key from ctx
        self.linear = Linear(din, dout, ctx=ctx)
        self.bn = BatchNorm(dout)
        self.dropout = Dropout(0.1)

    def __call__(self, x: jax.Array, *, train: bool, ctx: Context) -> jax.Array:
        """Run the four stages in order; `train` toggles BatchNorm and Dropout."""
        projected = self.linear(x)
        normalized = self.bn(projected, train=train)
        activated = jax.nn.gelu(normalized)
        return self.dropout(activated, train=train, ctx=ctx)
    

class ScanMLP(Module):
    """A stack of identical `Block`s run with `jax.lax.scan`, then a final `Linear`.

    The Block variables are created under `jax.vmap`, so every Variable in
    `self.layers` carries a leading axis of size `n_layers - 1` (one slice per
    scanned Block). `self.linear` is a regular, un-stacked layer, which brings
    the total to `n_layers`.
    """

    def __init__(self, hidden_size: int, n_layers: int, *, ctx: Context):
        self.n_layers = n_layers
        
        # lift init
        # one PRNG key per scanned Block
        key = random.split(ctx.make_rng(), n_layers - 1)
        # placeholder; assigned by init_fn while vmap traces it
        moduledef: ModuleDef[Block] = None # type: ignore

        def init_fn(key):
            # Build a single Block from its own key and return only its State;
            # the static ModuleDef is smuggled out through the enclosing scope
            # because it is not array data that vmap could return.
            nonlocal moduledef
            state, moduledef = Block(hidden_size, hidden_size, ctx=Context(key)).partition()
            return state
        
        # vmap stacks the per-Block States along a new leading axis
        state = jax.vmap(init_fn)(key)
        # a single Block object whose Variables hold the stacked arrays
        self.layers = moduledef.merge(state)
        # created after the stacked Blocks, so it draws the next key from ctx
        self.linear = Linear(hidden_size, hidden_size, ctx=ctx)

    def __call__(self, x: jax.Array, *, train: bool, ctx: Context) -> jax.Array:
        """Scan `x` through the stacked Blocks, then apply the final Linear.

        The (possibly updated) stacked State returned by the scan is written
        back into `self.layers`.
        """
        # lift call
        # one PRNG key per scanned Block, consumed as a scanned input
        key: jax.Array = random.split(ctx.make_rng(), self.n_layers - 1) # type: ignore
        state, moduledef = self.layers.partition()

        def scan_fn(x, inputs: tuple[jax.Array, State]):
            # `x` is the carry; `inputs` is this step's (key, State) slice
            key, state = inputs
            # rebuild the Block for this slice, run it, and collect its new State
            x, (state, _) = moduledef.apply(state)(
                x, train=train, ctx=Context(key)
            )
            return x, state
        
        # scan over the leading axis of (key, state); the per-step States are
        # stacked back into one State
        x, state = jax.lax.scan(scan_fn, x, (key, state))
        self.layers.update_state(state)
        x = self.linear(x)
        return x

In [87]:

# Build the model, run one training-mode forward pass, then inspect the result
# of partitioning it.
module = ScanMLP(hidden_size=10, n_layers=5, ctx=Context(random.PRNGKey(0)))
x = jax.random.normal(random.PRNGKey(0), (2, 10))
y = module(x, train=True, ctx=Context(random.PRNGKey(1)))

state, moduledef = module.partition()
# `jax.tree_map` is deprecated and removed in newer JAX releases; use the
# `jax.tree_util` spelling, which works on old and new versions alike.
print("state =", jax.tree_util.tree_map(jnp.shape, state))
print("moduledef =", moduledef)

state = State({
  'layers/bn/bias': Variable(value=(4, 10), collection=params, sharding=None),
  'layers/bn/mean': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),
  'layers/bn/scale': Variable(value=(4, 10), collection=params, sharding=None),
  'layers/bn/var': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),
  'layers/linear/b': Variable(value=(4, 10), collection=params, sharding=None),
  'layers/linear/w': Variable(value=(4, 10, 10), collection=params, sharding=None),
  'linear/b': Variable(value=(10,), collection=params, sharding=None),
  'linear/w': Variable(value=(10, 10), collection=params, sharding=None)
})
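moduledef = ModuleDef(type=<class '__main__.ScanMLP'>, index=0, submodules={'layers': ModuleDef(type=<class '__main__.Block'>, index=1, submodules={'bn': ModuleDef(type=<class '__main__.BatchNorm'>, index=2, submodules={}, static_fields={'mu': 0.95}), 'dropout': ModuleDef(type=<class '__main__.Dropout'>, index=3, submodules={}, static_fields={'rate': 0.1}), 'linear': ModuleDef(type=<class '__main__.Linear'>, index=4, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={}), 'linear': ModuleDef(type=<class '__main__.Linear'>, index=5, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={'n_layers': 5})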

### Filtering State

In [89]:
# split: one State per collection
params = state.filter("params")
batch_stats = state.filter("batch_stats")
# merge: the two filtered States have disjoint paths, so a plain dict union
# recovers every entry
state = State({**params, **batch_stats})

# `jax.tree_map` is deprecated and removed in newer JAX releases; use the
# `jax.tree_util` spelling, which works on old and new versions alike.
print("params =", jax.tree_util.tree_map(jnp.shape, params))
print("batch_stats =", jax.tree_util.tree_map(jnp.shape, batch_stats))

params = State({
  'layers/bn/bias': Variable(value=(4, 10), collection=params, sharding=None),
  'layers/bn/scale': Variable(value=(4, 10), collection=params, sharding=None),
  'layers/linear/b': Variable(value=(4, 10), collection=params, sharding=None),
  'layers/linear/w': Variable(value=(4, 10, 10), collection=params, sharding=None),
  'linear/b': Variable(value=(10,), collection=params, sharding=None),
  'linear/w': Variable(value=(10, 10), collection=params, sharding=None)
})
batch_stats = State({
  'layers/bn/mean': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),
  'layers/bn/var': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None)
})
