Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions flax/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@
from .tracers import (
check_trace_level as check_trace_level,
current_trace as current_trace,
trace_level as trace_level,
)

from flax.typing import (
Array as Array,
)
)
2 changes: 1 addition & 1 deletion flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def __init__(
self.flags = freeze({} if flags is None else flags)

self._root = parent.root if parent else None
self.trace_level = tracers.trace_level(tracers.current_trace())
self.trace_level = tracers.current_trace()

self.rng_counters = {key: 0 for key in self.rngs}
self.reservations = collections.defaultdict(set)
Expand Down
19 changes: 9 additions & 10 deletions flax/core/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@


def current_trace():
"""Returns the innermost Jax tracer."""
return jax.core.find_top_trace(())


def trace_level(main):
"""Returns the level of the trace of -infinity if it is None."""
if main:
return main.level
return float('-inf')
"""Returns the current JAX state tracer."""
if jax.__version_info__ <= (0, 4, 33):
top = jax.core.find_top_trace(())
if top:
return top.level
else:
return float('-inf')

return jax.core.get_opaque_trace_state(convention="flax")

def check_trace_level(base_level):
level = trace_level(current_trace())
level = current_trace()
if level != base_level:
raise errors.JaxTransformError()
21 changes: 15 additions & 6 deletions flax/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
# Taken from flax/core/tracer.py 🏴‍☠️


from jax.core import MainTrace, thread_local_state
import jax
import jax.core

from flax.nnx import reprlib


def current_jax_trace() -> MainTrace:
"""Returns the innermost Jax tracer."""
return thread_local_state.trace_state.trace_stack.dynamic
def current_jax_trace():
"""Returns the Jax tracing state."""
if jax.__version_info__ <= (0, 4, 33):
return jax.core.thread_local_state.trace_state.trace_stack.dynamic
return jax.core.get_opaque_trace_state(convention="nnx")


class TraceState(reprlib.Representable):
Expand All @@ -36,7 +39,10 @@ def jax_trace(self):
return self._jax_trace

def is_valid(self) -> bool:
return self._jax_trace is current_jax_trace()
if jax.__version_info__ <= (0, 4, 33):
return self._jax_trace is current_jax_trace()

return self._jax_trace == current_jax_trace()

def __nnx_repr__(self):
yield reprlib.Object(f'{type(self).__name__}')
Expand All @@ -52,4 +58,7 @@ def __treescope_repr__(self, path, subtree_renderer):
)

def __eq__(self, other):
return isinstance(other, TraceState) and self._jax_trace is other._jax_trace
if jax.__version_info__ <= (0, 4, 33):
return isinstance(other, TraceState) and self._jax_trace is other._jax_trace

return isinstance(other, TraceState) and self._jax_trace == other._jax_trace
Loading