<!-- add open in colab -->
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cgarciae/nnx/blob/main/docs/tiny_nnx.ipynb)

In [31]:
import hashlib
import typing as tp
import jax
import jax.numpy as jnp
from jax import random
import dataclasses

# Type of the payload wrapped by a `Variable`.
A = tp.TypeVar("A")
# A concrete `Module` subclass; the bound is a forward reference because
# `Module` is defined further down.
M = tp.TypeVar("M", bound="Module")
# Variable-length tuple of optional strings stored as `Variable.sharding`
# (presumably one axis name or None per array dimension -- TODO confirm).
Sharding = tp.Tuple[tp.Optional[str], ...]
# `jax.random.KeyArray` was deprecated in JAX 0.4.16 and removed in 0.4.24;
# `jax.Array` is the documented replacement for annotating PRNG keys.
KeyArray = jax.Array


class Variable(tp.Generic[A]):
    """A value tagged with the name of the collection it belongs to.

    Attributes:
        value: the wrapped payload.
        collection: collection name; `State.filter` selects on it.
        sharding: optional sharding annotation, stored as given.
    """

    def __init__(
        self,
        collection: str,
        value: A,
        *,
        sharding: tp.Optional[Sharding] = None,
    ):
        self.value = value
        self.collection = collection
        self.sharding = sharding

    def __repr__(self) -> str:
        return (
            f"Variable(value={self.value}, "
            f"collection={self.collection}, "
            f"sharding={self.sharding})"
        )
    

def _variable_flatten(variable):
    """Split a Variable into its single child (value) and static metadata."""
    children = (variable.value,)
    metadata = (variable.collection, variable.sharding)
    return children, metadata


def _variable_unflatten(metadata, value):
    """Rebuild a Variable from the metadata and child produced by flatten."""
    return Variable(metadata[0], value[0], sharding=metadata[1])


# Register Variable as a pytree node: only `value` is a child, while the
# collection name and sharding travel as auxiliary data.
jax.tree_util.register_pytree_node(
    Variable,
    _variable_flatten,
    _variable_unflatten,
)

class State(dict[str, Variable[tp.Any]]):
    """Flat mapping from a "/"-joined attribute path to a Variable."""

    def filter(self, collection: str) -> "State":
        """Return a new State holding only the Variables in `collection`.

        Entries keep their relative order; the Variable objects themselves
        are shared with this State, not copied.
        """
        selected = State()
        for path, variable in self.items():
            if variable.collection == collection:
                selected[path] = variable
        return selected


def _state_flatten(state):
    """Flatten a State into (variables, paths), both ordered by path.

    Sorting by path makes the auxiliary data -- and therefore the treedef --
    independent of dict insertion order. Without it, two States with the
    same entries (e.g. one from `Module.partition` and one assembled as
    `State(**params, **counts)`) would have different tree structures.
    """
    paths = tuple(sorted(state))
    return tuple(state[path] for path in paths), paths


def _state_unflatten(paths, values):
    """Rebuild a State by pairing each path with its Variable."""
    return State(dict(zip(paths, values)))


# Register State as a pytree node whose children are its Variables and
# whose auxiliary data is the sorted tuple of paths.
jax.tree_util.register_pytree_node(
    State,
    _state_flatten,
    _state_unflatten,
)


@dataclasses.dataclass
class ModuleDef(tp.Generic[M]):
    """Static description of a Module graph, i.e. everything except Variables.

    Combined with a State it can rebuild the module (`merge`) or run it
    (`apply`).
    """

    # Class that `merge` instantiates for this node.
    type: tp.Type[M]
    # Number assigned to this module while partitioning; shared references
    # point back to it.
    index: int
    # Attribute name -> nested ModuleDef, or the int index of a module that
    # was already described elsewhere in the graph.
    submodules: tp.Dict[str, tp.Union["ModuleDef[Module]", int]]
    # Attribute name -> value for attributes that are neither Modules nor
    # Variables.
    static_fields: tp.Dict[str, tp.Any]

    def apply(self, state: State) -> tp.Callable[..., tuple[tp.Any, tuple[State, "ModuleDef[M]"]]]:
        """Return a function that rebuilds the module and calls it.

        The returned function merges `state` into a fresh module, forwards
        its arguments to the module's `__call__`, and returns
        `(output, (new_state, new_moduledef))`, with the second element
        obtained by partitioning the module again after the call.
        """
        def _apply(*args, **kwargs):
            rebuilt = self.merge(state)
            result = rebuilt(*args, **kwargs) # type: ignore
            repartitioned = rebuilt.partition()
            return result, repartitioned

        return _apply

    def merge(self, state: State) -> M:
        """Instantiate the module graph and install the Variables of `state`."""
        registry: tp.Dict[int, "Module"] = {}
        module = ModuleDef._build_module_recursive(self, registry)
        module.update_state(state)
        return module

    @staticmethod
    def _build_module_recursive(
        moduledef: tp.Union["ModuleDef[M]", int],
        index_to_module: tp.Dict[int, "Module"],
    ) -> M:
        """Build the module described by `moduledef`, reusing shared modules.

        Args:
            moduledef: a ModuleDef to instantiate, or the int index of a
                module that has already been created.
            index_to_module: registry of modules created so far, keyed by
                index; filled in as the recursion proceeds.
        """
        if isinstance(moduledef, int):
            # Back-reference to a module that is already registered.
            return index_to_module[moduledef] # type: ignore

        assert moduledef.index not in index_to_module

        # Create the instance without running __init__ and register it
        # before descending, so references back to it resolve instead of
        # recursing forever.
        instance = object.__new__(moduledef.type)
        index_to_module[moduledef.index] = instance

        children = {}
        for attr_name, child_def in moduledef.submodules.items():
            children[attr_name] = ModuleDef._build_module_recursive(
                child_def, index_to_module)

        # Static fields first, then submodules, written straight into the
        # instance dict.
        attributes = vars(instance)
        attributes.update(moduledef.static_fields)
        attributes.update(children)
        return instance


class Module:
    """Base class whose instances can be split into a State and a ModuleDef."""

    def partition(self: M) -> tp.Tuple[State, ModuleDef[M]]:
        """Split this module into `(state, moduledef)`.

        `state` collects every Variable reachable through nested Modules,
        keyed by its "/"-joined attribute path; `moduledef` describes
        everything else.
        """
        collected = State()
        root_def = Module._partition_recursive(
            module=self, module_id_to_index={}, path_parts=(), state=collected)
        # The root is always visited first, so it can never be a bare index.
        assert isinstance(root_def, ModuleDef)
        return collected, root_def

    @staticmethod
    def _partition_recursive(
        module: M,
        module_id_to_index: tp.Dict[int, int],
        path_parts: tp.Tuple[str, ...],
        state: State,
    ) -> tp.Union[ModuleDef[M], int]:
        """Describe `module`, adding its Variables to `state` in place.

        Args:
            module: the module to describe.
            module_id_to_index: maps `id()` of each visited module to the
                index it was given; updated in place.
            path_parts: attribute names leading from the root to `module`.
            state: accumulator that receives the Variables found.

        Returns:
            A ModuleDef on the first visit of `module`, or the previously
            assigned int index if it was already visited.
        """
        known_index = module_id_to_index.get(id(module))
        if known_index is not None:
            return known_index

        # Indices are handed out in visiting order.
        index = len(module_id_to_index)
        module_id_to_index[id(module)] = index

        submodules = {}
        static_fields = {}

        # Walk attributes in name order so the result is deterministic.
        attributes = vars(module)
        for name in sorted(attributes):
            value = attributes[name]
            value_path = (*path_parts, name)
            if isinstance(value, Module):
                # Nested module: describe it recursively.
                submodules[name] = Module._partition_recursive(
                    value, module_id_to_index, value_path, state)
                continue
            if isinstance(value, Variable):
                # Variable: goes into the state under its full path.
                state["/".join(value_path)] = value
                continue
            # Anything else is a static field.
            static_fields[name] = value

        return ModuleDef(
            type=type(module),
            index=index,
            submodules=submodules,
            static_fields=static_fields,
        )

    def update_state(self, state: State) -> None:
        """Assign every Variable in `state` to the attribute named by its path."""
        for path, variable in state.items():
            Module._set_value_at_path(self, path.split("/"), variable)

    @staticmethod
    def _set_value_at_path(module: "Module", path_parts: tp.Sequence[str], value: Variable[tp.Any]) -> None:
        """Follow `path_parts` from `module` and set the last one to `value`."""
        head = path_parts[0]
        remaining = path_parts[1:]
        if not remaining:
            setattr(module, head, value)
            return
        Module._set_value_at_path(getattr(module, head), remaining, value)


@dataclasses.dataclass
class Context:
    """Hands out PRNG keys derived from one base key and a running counter."""

    # Base key that every generated key is folded from.
    key: KeyArray
    # Number of keys/forks handed out by this context so far.
    count: int = 0
    # Counter values of the ancestors this context was forked from.
    count_path: tuple[int, ...] = ()

    def fork(self) -> "Context":
        """Create a child Context whose random numbers differ from this one's.

        The child shares the base key but extends `count_path` with this
        context's current counter, which is then incremented. Fork is used
        to create a new Context that can be passed to a JAX transform.
        """
        child_path = self.count_path + (self.count,)
        self.count += 1
        return Context(self.key, count_path=child_path)

    def make_rng(self, name: str) -> jax.Array:
        """Return a new key and advance the counter.

        `name` is accepted but not used by this implementation.
        """
        position = self.count_path + (self.count,)
        fold_data = self._stable_hash(position)
        self.count += 1
        return random.fold_in(self.key, fold_data) # type: ignore

    @staticmethod
    def _stable_hash(data: tuple[int, ...]) -> int:
        """Hash a tuple of ints to an integer built from 4 digest bytes."""
        encoded = " ".join(map(str, data)).encode()
        digest = hashlib.blake2s(encoded).digest()
        # uint32 is represented as 4 bytes in big endian
        return int.from_bytes(digest[:4], byteorder="big")

def _context_flatten(ctx):
    """Split a Context into its key (child) and its counters (static data)."""
    children = (ctx.key,)
    metadata = (ctx.count, ctx.count_path)
    return children, metadata


def _context_unflatten(metadata, value):
    """Rebuild a Context from the key and the (count, count_path) metadata."""
    return Context(value[0], *metadata)


# In the real NNX, Context is not a pytree; instead it has a
# partition/merge API similar to Module. For simplicity we register it
# as a pytree here.
jax.tree_util.register_pytree_node(
    Context,
    _context_flatten,
    _context_unflatten,
)

In [30]:
class Linear(Module):
    """Affine layer `x @ w + b` that also counts how often it is called."""

    def __init__(self, din: int, dout: int, *, ctx: Context):
        """Create a (din, dout) weight, a zero bias and a call counter.

        The weight is drawn with `random.uniform` from a key supplied by
        `ctx`; weight and bias go in the "params" collection and the
        counter in "counts".
        """
        self.din, self.dout = din, dout
        weight_key = ctx.make_rng("params")
        weight_shape = (din, dout)
        self.w = Variable("params", random.uniform(weight_key, weight_shape))
        self.b = Variable("params", jnp.zeros((dout,)))
        self.count = Variable("counts", 0)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Increment the call counter and return `x @ w + b`."""
        self.count.value += 1
        weights, bias = self.w.value, self.b.value
        return x @ weights + bias

# Build a 2x2 Linear layer from a fixed seed.
ctx = Context(random.PRNGKey(0))
module = Linear(2, 2, ctx=ctx)

# Call it twice so the call counter reaches 2.
for _ in range(2):
    y = module(jnp.ones((1, 2)))

# Split the module into its Variables and its static description.
state, moduledef = module.partition()
print(f"{state=}")
print(f"{moduledef=}")

state={'b': Variable(value=[0. 0.], collection=params, sharding=None), 'count': Variable(value=2, collection=counts, sharding=None), 'w': Variable(value=[[0.31696808 0.55285215]
 [0.31418085 0.7399571 ]], collection=params, sharding=None)}
moduledef=ModuleDef(type=<class '__main__.Linear'>, index=0, submodules={}, static_fields={'din': 2, 'dout': 2})


In [26]:
# Split the state by collection name.
params, counts = (state.filter(name) for name in ("params", "counts"))

print(f"{params=}")
print(f"{counts=}")

params={'b': Variable(value=[0. 0.], collection=params, sharding=None), 'w': Variable(value=[[0.31696808 0.55285215]
 [0.31418085 0.7399571 ]], collection=params, sharding=None)}
counts={'count': Variable(value=2, collection=counts, sharding=None)}


In [27]:
# Recombine both collections into one State and run the module through
# the functional API; this yields the output plus the updated partition.
merged_state = State(**params, **counts)
apply_fn = moduledef.apply(merged_state)
y, (state, moduledef) = apply_fn(jnp.ones((1, 2)))
y

Array([[0.63114893, 1.2928092 ]], dtype=float32)

In [28]:
# Rebuild a stateful module from the partition returned by `apply`.
module = moduledef.merge(state)

# Two direct calls plus one call through `apply` give a count of 3.
call_count = module.count.value
assert call_count == 3