
Version 0.3.0


@chaoming0625 released this 11 Mar 17:18

This release delivers on-device NaN debugging, a unified compilation cache, simplified JAX compatibility, and major internal cleanup, with a net reduction of roughly 1,800 lines of code. It raises the minimum requirements to Python 3.11 and JAX 0.6.0.

Breaking Changes

  • Python >= 3.11 required: Dropped support for Python 3.10. The requires-python field and classifiers now start at 3.11.
  • JAX >= 0.6.0 required: All dependency groups (cpu, cuda12, cuda13, tpu, testing) now mandate jax>=0.6.0.
  • Unified compilation cache in StatefulFunction: The four separate internal caches (_cached_jaxpr, _cached_out_shapes, _cached_jaxpr_out_tree, _cached_state_trace) have been consolidated into a single _compilation_cache storing _CachedCompilation objects. get_cache_stats() now returns {'compilation_cache': {...}} instead of four individual entries.
  • Immutable CacheKey replaces hashabledict: get_arg_cache_key() now returns a CacheKey (NamedTuple) instead of the mutable hashabledict. Code that directly inspected or constructed cache keys must be updated.
  • Removed internal _make_jaxpr function: The custom tracing implementation has been deleted in favor of using jax.make_jaxpr() directly (available in JAX >= 0.6.0).
  • Removed debug_depth and debug_context from GradientTransform: The depth and context parameters for NaN debugging no longer exist following the debug module rewrite.
  • Removed breakpoint_if function: The conditional breakpoint helper has been removed from brainstate.transform._debug.
  • Removed extend_axis_env_nd from compatible imports: This compatibility shim is no longer exported.
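The cache consolidation and the switch to an immutable key can be pictured with a minimal sketch in plain Python. Apart from the names `CacheKey`, `_CachedCompilation`, `_compilation_cache`, and `get_cache_stats` taken from the notes above, everything here is hypothetical: the fields of both NamedTuples, the `TinyCache` class, and the `'size'` statistic are assumptions, not brainstate's actual definitions. What it shows is one dictionary keyed by a hashable, immutable NamedTuple in place of four parallel dictionaries keyed by a mutable dict subclass.

```python
from typing import Any, NamedTuple


class CacheKey(NamedTuple):
    """Hypothetical immutable cache key; the real field layout may differ."""
    static_args: tuple
    dynamic_structure: Any


class _CachedCompilation(NamedTuple):
    """One record replacing the four former caches (field names assumed)."""
    jaxpr: Any
    out_shapes: Any
    jaxpr_out_tree: Any
    state_trace: Any


class TinyCache:
    """Hypothetical stand-in for StatefulFunction's single compilation cache."""

    def __init__(self) -> None:
        # One dict instead of four, so the artifacts of one compilation
        # are always stored and looked up together.
        self._compilation_cache: dict[CacheKey, _CachedCompilation] = {}

    def store(self, key: CacheKey, entry: _CachedCompilation) -> None:
        self._compilation_cache[key] = entry

    def get_cache_stats(self) -> dict[str, dict[str, int]]:
        # A single top-level entry; the inner statistics are assumed.
        return {'compilation_cache': {'size': len(self._compilation_cache)}}


cache = TinyCache()
key = CacheKey(static_args=(3,), dynamic_structure=('f32[4]',))
cache.store(key, _CachedCompilation('jaxpr', 'shapes', 'tree', 'trace'))
print(cache.get_cache_stats())  # {'compilation_cache': {'size': 1}}
```

Because `CacheKey` is a tuple, assigning to one of its fields raises `AttributeError`, so a key cannot be mutated after it has been used to store an entry.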

New Features

On-Device NaN/Inf Detection

  • Complete rewrite of the NaN debugging system (brainstate.transform._debug). NaN checking now runs on-device via JAX primitives instead of transferring arrays to the host, which makes it significantly faster.
  • Uses jax.debug.callback with thread-local storage to collect and report NaN findings.
  • Error tracebacks now point to the user's source code via source_info_util.user_context, producing IDE-clickable source locations extracted from jaxpr equations.
  • Recursive instrumentation of nested primitives (jit, cond, while, scan) for comprehensive NaN detection throughout the computation graph.
  • More compact and informative error messages via _format_nan_message().

JAX Traceback Filtering

  • Registered brainstate with JAX's traceback_util.register_exclusion() so internal frames are hidden in user-facing error tracebacks. Follows the same pattern as Flax, Equinox, and other JAX ecosystem libraries.
  • Users can still see full tracebacks via JAX_TRACEBACK_FILTERING=off.

State Validation at Call Time

  • New _validate_state_shapes() method checks that current state shapes and dtypes match those recorded at compile time.
  • StatefulFunction.__call__() automatically validates before execution, catching state shape mismatches early with clear error messages.
  • Added static_argnums bounds validation — make_jaxpr() now raises ValueError if indices exceed the number of positional arguments.
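A minimal sketch of both checks, assuming each state is described by a shape tuple and a dtype string. `StateSpec`, `validate_state_shapes`, and `check_static_argnums` are hypothetical names; they are not brainstate's private `_validate_state_shapes()` or the internals of its `make_jaxpr()`.

```python
from typing import NamedTuple


class StateSpec(NamedTuple):
    """Shape and dtype of one state, recorded at compile time (hypothetical)."""
    shape: tuple[int, ...]
    dtype: str


def validate_state_shapes(recorded: list[StateSpec], current: list[StateSpec]) -> None:
    """Raise before execution if any state's shape or dtype has changed."""
    if len(recorded) != len(current):
        raise ValueError(f'expected {len(recorded)} states, got {len(current)}')
    for i, (old, new) in enumerate(zip(recorded, current)):
        if old != new:
            raise ValueError(
                f'state {i}: compiled with shape={old.shape}, dtype={old.dtype}, '
                f'but now has shape={new.shape}, dtype={new.dtype}'
            )


def check_static_argnums(static_argnums: tuple[int, ...], n_positional: int) -> None:
    """Reject indices that exceed the number of positional arguments."""
    # Only the upper bound described in the release note is checked here;
    # negative indices are outside the scope of this sketch.
    for idx in static_argnums:
        if idx >= n_positional:
            raise ValueError(
                f'static_argnums index {idx} is out of range for '
                f'{n_positional} positional argument(s)'
            )


recorded = [StateSpec((4,), 'float32')]
validate_state_shapes(recorded, [StateSpec((4,), 'float32')])  # matches: no error
try:
    validate_state_shapes(recorded, [StateSpec((8,), 'float32')])
except ValueError as err:
    print(err)
```

Failing before the compiled function runs turns a confusing downstream shape error into a message that names the offending state.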

New Compatible Import

  • Added mapped_aval import with version-based routing: jax.core.mapped_aval for JAX < 0.8.2, jax.extend.core.mapped_aval for >= 0.8.2.
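The version-based routing amounts to a single comparison. The sketch expresses it as a pure function so it runs without JAX installed; `mapped_aval_location` is a hypothetical helper, and the commented lines show how its result could drive the real import, assuming both module paths named in the item above expose `mapped_aval` in their respective version ranges.

```python
import re


def mapped_aval_location(jax_version: str) -> str:
    """Return the module path that provides mapped_aval for a JAX version string."""
    # Compare only the leading major.minor.patch numbers; suffixes such as
    # '.dev' or 'rc1' are ignored in this sketch.
    match = re.match(r'(\d+)\.(\d+)\.(\d+)', jax_version)
    if match is None:
        raise ValueError(f'unrecognized JAX version: {jax_version!r}')
    version = tuple(int(part) for part in match.groups())
    return 'jax.extend.core' if version >= (0, 8, 2) else 'jax.core'


# With JAX installed, the result could drive the import, e.g.:
#   import importlib, jax
#   mapped_aval = importlib.import_module(
#       mapped_aval_location(jax.__version__)).mapped_aval

print(mapped_aval_location('0.6.0'))  # jax.core
print(mapped_aval_location('0.8.2'))  # jax.extend.core
```

Comparing integer tuples rather than raw strings matters here: as strings, `'0.10.0'` would sort before `'0.8.2'`.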

Improvements

  • Atomic cache writes: Compilation results are only stored on success, eliminating partial cache entries on error. Uses a double-checked locking pattern for thread safety during compilation.
  • Better cache key hashing: Dynamic args/kwargs are now flattened via jax.tree.flatten() before hashing, fixing non-deterministic hashing issues with custom pytree nodes (e.g., Quantity).
  • Modern Python type annotations: Migrated from typing.Tuple, typing.List, typing.Dict, typing.Optional, and typing.Union to the built-in tuple, list, and dict generics and the X | None and X | Y union syntax across the codebase.
  • IR visualization compatibility: Replaced direct jax.core.X references with compatible imports (Var, ClosedJaxpr, Jaxpr, JaxprEqn, Literal, DropVar) in the IR visualizer.
  • Deterministic error reporting: jax.debug.callback in _error_if.py now uses ordered=True for deterministic error callback ordering.
  • Graph operations cleanup: Major refactoring of _operation.py, _node.py, _convert.py, and _context.py with streamlined docstrings, better thread-safety documentation, and cleaner context managers.
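The store-on-success behavior and the double-checked locking pattern from the atomic cache writes item can be sketched as follows. `CompilationCache` and `get_or_compile` are hypothetical names. A single lock is an assumption made for brevity: it serializes all compilations, which a per-key lock would avoid. The lock-free first read also relies on a single dict lookup being atomic in CPython, another assumption of the sketch.

```python
import threading
from collections.abc import Callable, Hashable
from typing import Any

_MISSING = object()  # sentinel, so that even None could be a cached value


class CompilationCache:
    """Hypothetical cache: store-on-success with double-checked locking."""

    def __init__(self) -> None:
        self._entries: dict[Hashable, Any] = {}
        self._lock = threading.Lock()

    def get_or_compile(self, key: Hashable, compile_fn: Callable[[], Any]) -> Any:
        # First check, without the lock: the common cache-hit path stays cheap.
        entry = self._entries.get(key, _MISSING)
        if entry is not _MISSING:
            return entry
        with self._lock:
            # Second check, under the lock: another thread may have finished
            # compiling this key while we were waiting to acquire it.
            entry = self._entries.get(key, _MISSING)
            if entry is not _MISSING:
                return entry
            # If compile_fn raises, the assignment below never runs,
            # so a failed compilation leaves no partial entry behind.
            entry = compile_fn()
            self._entries[key] = entry
            return entry


cache = CompilationCache()
calls = []


def compile_f() -> str:
    calls.append('f')
    return 'compiled f'


threads = [threading.Thread(target=cache.get_or_compile, args=('f', compile_f)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(calls))  # 1: only the first thread to take the lock compiled


def compile_g() -> str:
    raise RuntimeError('tracing failed')


try:
    cache.get_or_compile('g', compile_g)
except RuntimeError:
    pass
print('g' in cache._entries)  # False
```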

Bug Fixes

  • Fixed Delay.__init__ initialization order: update_every is now initialized before register_entry is called, preventing attribute errors during entry registration (#135).
  • Fixed graph_to_tree private attribute access: Replaced internal _mapping access with public API usage in _convert.py.
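The `Delay` fix addresses a common initialization-order bug: a method called from `__init__` reads an attribute that has not been assigned yet. A minimal reproduction with hypothetical classes (not brainstate's actual `Delay` or `register_entry` code):

```python
class Base:
    """Hypothetical base class whose register_entry reads a subclass attribute."""

    def __init__(self) -> None:
        self.entries: dict[str, float] = {}

    def register_entry(self, name: str) -> None:
        # Depends on self.update_every, so it must exist before this runs.
        self.entries[name] = self.update_every


class BrokenDelay(Base):
    def __init__(self, update_every: float) -> None:
        super().__init__()
        self.register_entry('default')  # raises AttributeError: update_every not set yet
        self.update_every = update_every


class FixedDelay(Base):
    def __init__(self, update_every: float) -> None:
        super().__init__()
        self.update_every = update_every  # assign first...
        self.register_entry('default')    # ...then register entries that read it


try:
    BrokenDelay(0.1)
except AttributeError as err:
    print(f'BrokenDelay failed: {err}')

print(FixedDelay(0.1).entries)  # {'default': 0.1}
```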

Internal Changes

  • Massive docstring reduction across the graph module (more than 1,000 lines removed), replacing verbose multi-paragraph docstrings with concise descriptions.
  • Cleaned up TypeVar usage: removed unused C and Names aliases, renamed Node TypeVar to N, removed Hashable bound from type variables.
  • Removed unused tests (test_all_exports, test_function_imports_availability) from compatible import tests.
  • Rewrote debug and make_jaxpr test suites to match the new APIs.
  • IR optimization imports are now lazy-loaded inside make_jaxpr() only when ir_optimizations is configured.

CI/CD

  • Bumped actions/upload-artifact from v6 to v7.
  • Bumped actions/download-artifact from v7 to v8.

What's Changed

  • fix(nn): initialize update_every before register_entry by @Routhleck in #135
  • deps(deps): bump actions/upload-artifact from 6 to 7 by @dependabot[bot] in #133
  • deps(deps): bump actions/download-artifact from 7 to 8 by @dependabot[bot] in #132
  • Simplify JAX compat: use jax.make_jaxpr and aval helpers by @chaoming0625 in #137
  • Refactor graph ops, update JAX/Python requirements, improve tests by @chaoming0625 in #138
  • Add on-device NaN debugging and unify StatefulFunction cache by @chaoming0625 in #139

Full Changelog: v0.2.10...v0.3.0