From fc19c5dcc136b17d7a2c406831126d93bbdac6d7 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 24 Sep 2024 16:53:13 -0700 Subject: [PATCH] Update libraries to use JAX's limited (and ill-advised) trace-state-querying APIs rather than depending on JAX's deeper internals, which are about to change. PiperOrigin-RevId: 678446080 --- flax/core/__init__.py | 3 +-- flax/core/scope.py | 2 +- flax/core/tracers.py | 19 +++++++++---------- flax/nnx/tracers.py | 21 +++++++++++++++------ 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/flax/core/__init__.py b/flax/core/__init__.py index bca72392d..f90775f73 100644 --- a/flax/core/__init__.py +++ b/flax/core/__init__.py @@ -48,9 +48,8 @@ from .tracers import ( check_trace_level as check_trace_level, current_trace as current_trace, - trace_level as trace_level, ) from flax.typing import ( Array as Array, -) \ No newline at end of file +) diff --git a/flax/core/scope.py b/flax/core/scope.py index e056d6ddb..ea8a586b1 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -454,7 +454,7 @@ def __init__( self.flags = freeze({} if flags is None else flags) self._root = parent.root if parent else None - self.trace_level = tracers.trace_level(tracers.current_trace()) + self.trace_level = tracers.current_trace() self.rng_counters = {key: 0 for key in self.rngs} self.reservations = collections.defaultdict(set) diff --git a/flax/core/tracers.py b/flax/core/tracers.py index 9d8472bdc..fe2ff874c 100644 --- a/flax/core/tracers.py +++ b/flax/core/tracers.py @@ -20,18 +20,17 @@ def current_trace(): - """Returns the innermost Jax tracer.""" - return jax.core.find_top_trace(()) - - -def trace_level(main): - """Returns the level of the trace of -infinity if it is None.""" - if main: - return main.level - return float('-inf') + """Returns the current JAX state tracer.""" + if jax.__version_info__ <= (0, 4, 33): + top = jax.core.find_top_trace(()) + if top: + return top.level + else: + return float('-inf') + return jax.core.get_opaque_trace_state(convention="flax") def check_trace_level(base_level): - level = trace_level(current_trace()) + level = current_trace() if level != base_level: raise errors.JaxTransformError() diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py index 3db066376..cc7859739 100644 --- a/flax/nnx/tracers.py +++ b/flax/nnx/tracers.py @@ -15,14 +15,17 @@ # Taken from flax/core/tracer.py 🏴‍☠️ -from jax.core import MainTrace, thread_local_state +import jax +import jax.core from flax.nnx import reprlib -def current_jax_trace() -> MainTrace: - """Returns the innermost Jax tracer.""" - return thread_local_state.trace_state.trace_stack.dynamic +def current_jax_trace(): + """Returns the Jax tracing state.""" + if jax.__version_info__ <= (0, 4, 33): + return jax.core.thread_local_state.trace_state.trace_stack.dynamic + return jax.core.get_opaque_trace_state(convention="nnx") class TraceState(reprlib.Representable): @@ -36,7 +39,10 @@ def jax_trace(self): return self._jax_trace def is_valid(self) -> bool: - return self._jax_trace is current_jax_trace() + if jax.__version_info__ <= (0, 4, 33): + return self._jax_trace is current_jax_trace() + + return self._jax_trace == current_jax_trace() def __nnx_repr__(self): yield reprlib.Object(f'{type(self).__name__}') @@ -52,4 +58,7 @@ def __treescope_repr__(self, path, subtree_renderer): ) def __eq__(self, other): - return isinstance(other, TraceState) and self._jax_trace is other._jax_trace + if jax.__version_info__ <= (0, 4, 33): + return isinstance(other, TraceState) and self._jax_trace is other._jax_trace + + return isinstance(other, TraceState) and self._jax_trace == other._jax_trace