From c8bfd2a6311fb3ebed31caef305b7bd62b8790a0 Mon Sep 17 00:00:00 2001
From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com>
Date: Thu, 10 Oct 2024 23:21:26 +0000
Subject: [PATCH] Update Flax NNX vs JAX Transformations guide

---
 docs_nnx/guides/jax_and_nnx_transforms.rst | 66 ++++++++++++----------
 1 file changed, 35 insertions(+), 31 deletions(-)

diff --git a/docs_nnx/guides/jax_and_nnx_transforms.rst b/docs_nnx/guides/jax_and_nnx_transforms.rst
index 2b5edbd54..0a21351ed 100644
--- a/docs_nnx/guides/jax_and_nnx_transforms.rst
+++ b/docs_nnx/guides/jax_and_nnx_transforms.rst
@@ -1,12 +1,11 @@
-Flax NNX vs JAX Transformations
-==========================
+Flax NNX vs JAX transformations
+===============================
 
-.. attention::
-  This page relates to the new Flax NNX API.
-
-In this guide, you will learn the differences using Flax NNX and JAX transformations, and how to
-seamlessly switch between them or use them together. We will be focusing on the ``jit`` and
-``grad`` function transformations in this guide.
+This guide describes the differences between
+`Flax NNX transformations `__
+and `JAX transformations `__,
+and how to seamlessly switch between them or use them side-by-side. The examples here will focus on
+``nnx.jit``, ``jax.jit``, ``nnx.grad`` and ``jax.grad`` function transformations (transforms).
 
 First, let's set up imports and generate some dummy data:
 
@@ -18,27 +17,34 @@ First, let's set up imports and generate some dummy data:
   x = jax.random.normal(jax.random.key(0), (1, 2))
   y = jax.random.normal(jax.random.key(1), (1, 3))
 
-Differences between NNX and JAX transformations
-***********************************************
+Differences
+***********
+
+Flax NNX transformations can transform functions that are not pure and make mutations and
+side-effects:
+- Flax NNX transforms enable you to transform functions that take in Flax NNX graph objects as
+arguments - such as ``nnx.Module``, ``nnx.Rngs``, ``nnx.Optimizer``, and so on - even those whose state
+will be mutated.
+- In comparison, these kinds of objects aren't recognized in JAX transformations.
+
+The Flax NNX `Functional API `_
+provides a way to convert graph structures to `pytrees `__
+and back. By doing this at every function boundary you can effectively use graph structures with any
+JAX transforms and propagate state updates in a way consistent with functional purity.
 
-The primary difference between Flax NNX and JAX transformations is that Flax NNX transformations allow you to
-transform functions that take in Flax NNX graph objects as arguments (`Module`, `Rngs`, `Optimizer`, etc),
-even those whose state will be mutated, whereas they aren't recognized in JAX transformations.
-Therefore Flax NNX transformations can transform functions that are not pure and make mutations and
-side-effects.
+Flax NNX custom transforms, such as ``nnx.jit`` and ``nnx.grad``, simply remove the boilerplate, and
+as a result the code looks stateful.
 
-Flax NNX's `Functional API `_
-provides a way to convert graph structures to pytrees and back. By doing this at every function
-boundary you can effectively use graph structures with any JAX transform and propagate state updates
-in a way consistent with functional purity. Flax NNX custom transforms such as ``nnx.jit`` and ``nnx.grad``
-simply remove the boilerplate, as a result the code looks stateful.
+Below is an example of using the ``nnx.jit`` and ``nnx.grad`` transforms compared to the
+the code that uses ``jax.jit`` and ``jax.grad`` transforms.
 
-Below is an example of using the ``nnx.jit`` and ``nnx.grad`` transformations compared to using the
-``jax.jit`` and ``jax.grad`` transformations. Notice the function signature of Flax NNX-transformed
-functions can accept the ``nnx.Linear`` module directly and can make stateful updates to the module,
-whereas the function signature of JAX-transformed functions can only accept the pytree-registered
-``State`` and ``GraphDef`` objects and must return an updated copy of them to maintain the purity of
-the transformed function.
+Notice that:
+
+- The function signature of Flax NNX-transformed functions can accept the ``nnx.Linear``
+  ``nnx.Module`` instances directly and make stateful updates to the ``Module``.
+- The function signature of JAX-transformed functions can only accept the pytree-registered
+  ``nnx.State`` and ``nnx.GraphDef`` objects, and must return an updated copy of them to maintain the
+  purity of the transformed function.
 
 .. codediff::
   :title: Flax NNX transforms, JAX transforms
@@ -79,11 +85,11 @@ the transformed function.
   graphdef, state = train_step(graphdef, state, x, y) #!
 
 
-Mixing Flax NNX and JAX transformations
+Mixing Flax NNX and JAX transforms
 **********************************
 
-Flax NNX and JAX transformations can be mixed together, so long as the JAX-transformed function is
-pure and has valid argument types that are recognized by JAX.
+Both Flax NNX transforms and JAX transforms can be mixed together, so long as the JAX-transformed function
+in your code is pure and has valid argument types that are recognized by JAX.
 
 .. codediff::
   :title: Using ``nnx.jit`` with ``jax.grad``, Using ``jax.jit`` with ``nnx.grad``
@@ -121,5 +127,3 @@ pure and has valid argument types that are recognized by JAX.
 
   graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
   graphdef, state = train_step(graphdef, state, x, y)
-
-