Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
from .transforms.transforms import cond as cond
from .transforms.transforms import switch as switch
from .transforms.transforms import checkify as checkify
from .transforms.transforms import make_jaxpr as make_jaxpr
from .transforms.iteration import while_loop as while_loop
from .transforms.iteration import fori_loop as fori_loop
from .transforms.iteration import StateAxes as StateAxes
Expand Down
95 changes: 95 additions & 0 deletions flax/nnx/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,101 @@ def checkify_wrapper(*args, **kwargs):
return checkify_wrapper # type: ignore


@dataclasses.dataclass(eq=False)
class SimpleMakeJaxprFn:
f: tp.Callable[..., tp.Any]
graph: bool

def __post_init__(self):
functools.update_wrapper(self, self.f, updated=())

@extract.treemap_copy_args
def __call__(self, *args, **kwargs):
if self.graph:
args, kwargs = extract.from_tree2((args, kwargs))
out = self.f(*args, **kwargs)
if self.graph:
out = extract.to_tree2(out)
extract.check_no_aliases('make_jaxpr', args=args, kwargs=kwargs, out=out)
return out


@tp.overload
def make_jaxpr(
f: tp.Callable[..., A],
*,
graph: bool | None = None,
graph_updates: bool | None = None,
static_argnums: int | tp.Sequence[int] = (),
) -> tp.Callable[..., tp.Any]: ...

@tp.overload
def make_jaxpr(
*,
graph: bool | None = None,
graph_updates: bool | None = None,
static_argnums: int | tp.Sequence[int] = (),
) -> tp.Callable[[F], tp.Callable[..., tp.Any]]: ...

def make_jaxpr(
f: tp.Callable[..., A] | Missing = MISSING,
*,
graph: bool | None = None,
graph_updates: bool | None = None,
static_argnums: int | tp.Sequence[int] = (),
) -> tp.Callable[..., tp.Any] | tp.Callable[[F], tp.Callable[..., tp.Any]]:
"""A "lifted" version of `jax.make_jaxpr <https://jax.readthedocs.io/en/latest/jaxpr.html>`_
that can handle `flax.nnx.Module <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
/ graph nodes as arguments.

Args:
f: the function to be transformed into a Jaxpr.
graph: If ``True`` (default), uses graph-mode which supports the full
NNX feature set including shared references and reference semantics.
If ``False``, uses tree-mode which treats Modules as regular JAX
pytrees, avoiding the overhead of the graph protocol.
graph_updates: If ``True``, propagates updates on graph structure
that happen inside the transform to the input graphs, has no
effect when ``graph=False``. ``nnx.make_jaxpr`` raises an error
if ``graph_updates=True``.
static_argnums: Optional, int or sequence of ints. Specifies which
positional argument(s) to treat as static (compile-time constant).
"""
if isinstance(f, Missing):
return functools.partial(
make_jaxpr,
graph=graph,
graph_updates=graph_updates,
static_argnums=static_argnums,
)

if graph_updates is None:
graph_updates = graphlib.set_graph_updates.current_value()
if graph_updates:
raise ValueError('nnx.make_jaxpr does not support graph_updates=True.')

f_call, _, was_bound = _resolve_bound_callable(f)
if was_bound:
_raise_bound_method_error('make_jaxpr')
if graph is None:
graph = graphlib.set_graph_mode.current_value()

jaxpr_maker = jax.make_jaxpr(
SimpleMakeJaxprFn(f_call, graph=graph),
static_argnums=static_argnums,
)

@functools.wraps(f)
def jaxpr_wrapper(*args, **kwargs):
if graph:
args, kwargs = extract.to_tree2((args, kwargs))
extract.check_no_aliases('make_jaxpr', args=args, kwargs=kwargs)
jaxpr = jaxpr_maker(*args, **kwargs)
return jaxpr

return jaxpr_wrapper


@dataclasses.dataclass(eq=False)
class SimpleCondFn:
f: tp.Callable[..., tp.Any]
Expand Down
29 changes: 29 additions & 0 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5224,6 +5224,35 @@ def f(c):
np.testing.assert_allclose(out, 1)


class TestMakeJaxpr(parameterized.TestCase):

def test_make_jaxpr_graph_updates_error(self):
m = nnx.Dict(a=nnx.Param(jnp.array(1)))

def f(m):
return m['a'][...]

with self.assertRaisesRegex(
ValueError, 'nnx.make_jaxpr does not support graph_updates=True.'
):
nnx.make_jaxpr(f, graph=True, graph_updates=True)(m)

@parameterized.parameters(True, False)
def test_make_jaxpr_with_variable_update(self, graph):
class Counter(nnx.Module):
def __init__(self):
self.count = nnx.Variable(jnp.array(0))

def __call__(self):
self.count[...] += 1
return self.count[...]

m = Counter()
jaxpr = nnx.make_jaxpr(lambda m: m(), graph=graph, graph_updates=False)(m)
self.assertIsNotNone(jaxpr)
self.assertEqual(m.count[...], 0)


class TestBoundMethodTransforms(parameterized.TestCase):
def test_remat_with_bound_method_raises(self):
class M(nnx.Module):
Expand Down
Loading