Version 0.3.0
This release delivers on-device NaN debugging, a unified compilation cache, simplified JAX compatibility, and major internal cleanup — with a net reduction of ~1,800 lines of code. It raises the minimum requirements to Python 3.11 and JAX 0.6.0.
Breaking Changes
- Python >= 3.11 required: Dropped support for Python 3.10. The `requires-python` field and classifiers now start at 3.11.
- JAX >= 0.6.0 required: All dependency groups (`cpu`, `cuda12`, `cuda13`, `tpu`, `testing`) now mandate `jax>=0.6.0`.
- Unified compilation cache in `StatefulFunction`: The four separate internal caches (`_cached_jaxpr`, `_cached_out_shapes`, `_cached_jaxpr_out_tree`, `_cached_state_trace`) have been consolidated into a single `_compilation_cache` storing `_CachedCompilation` objects. `get_cache_stats()` now returns `{'compilation_cache': {...}}` instead of four individual entries.
- Immutable `CacheKey` replaces `hashabledict`: `get_arg_cache_key()` now returns a `CacheKey` (`NamedTuple`) instead of the mutable `hashabledict`. Code that directly inspected or constructed cache keys must be updated.
- Removed internal `_make_jaxpr` function: The custom tracing implementation has been deleted in favor of using `jax.make_jaxpr()` directly (available in JAX >= 0.6.0).
- Removed `debug_depth` and `debug_context` from `GradientTransform`: The `depth` and `context` parameters for NaN debugging no longer exist following the debug module rewrite.
- Removed `breakpoint_if` function: The conditional breakpoint helper has been removed from `brainstate.transform._debug`.
- Removed `extend_axis_env_nd` from compatible imports: This compatibility shim is no longer exported.
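The move from a mutable hashable dict to an immutable `NamedTuple` key can be pictured with the sketch below. It is an illustration only: `SketchCacheKey`, its field names, and `make_key` are hypothetical and are not taken from brainstate's actual `CacheKey`.

```python
from typing import Any, NamedTuple


class SketchCacheKey(NamedTuple):
    """Hypothetical stand-in for CacheKey; the real fields may differ."""
    static_args: tuple[Any, ...]
    static_kwargs: tuple[tuple[str, Any], ...]
    dyn_avals: tuple[tuple[tuple[int, ...], str], ...]


def make_key(static_args, static_kwargs, dyn_shapes_dtypes):
    """Build an immutable key; kwargs are sorted so ordering does not matter."""
    return SketchCacheKey(
        static_args=tuple(static_args),
        static_kwargs=tuple(sorted(static_kwargs.items())),
        dyn_avals=tuple((tuple(shape), dtype) for shape, dtype in dyn_shapes_dtypes),
    )


key_a = make_key([2], {"mode": "fast", "axis": 0}, [((3, 4), "float32")])
key_b = make_key([2], {"axis": 0, "mode": "fast"}, [((3, 4), "float32")])

# Equal keys hash equally, so either one finds the cached entry.
cache = {key_a: "compiled artifact"}
print(cache[key_b])
```

Because a `NamedTuple` cannot be modified after construction, a key that is already stored in a dictionary cannot silently change its hash, which is the usual hazard with a mutable dict that defines `__hash__`.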
New Features
On-Device NaN/Inf Detection
- Complete rewrite of the NaN debugging system (`brainstate.transform._debug`). NaN checking now runs on-device via JAX primitives rather than pulling data to the host, providing significantly better performance.
- Uses `jax.debug.callback` with thread-local storage to collect and report NaN findings.
- Error tracebacks now point to the user's source code via `source_info_util.user_context`, producing IDE-clickable source locations extracted from jaxpr equations.
- Recursive instrumentation of nested primitives (`jit`, `cond`, `while`, `scan`) for comprehensive NaN detection throughout the computation graph.
- More compact and informative error messages via `_format_nan_message()`.
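The idea of descending into nested primitives and reporting a source location per equation can be sketched on a toy data structure. Nothing here is JAX or brainstate code: `ToyEqn` and `find_non_finite` are hypothetical, and the check runs in plain Python on the host, whereas the release performs it on-device.

```python
import math
from dataclasses import dataclass, field


@dataclass
class ToyEqn:
    """Toy stand-in for one equation; not JAX's JaxprEqn."""
    primitive: str
    source: str                    # "file:line" of the user code
    result: float = 0.0
    body: list["ToyEqn"] = field(default_factory=list)  # nested equations


def find_non_finite(eqns, path=()):
    """Depth-first walk that descends into the bodies of nested primitives."""
    findings = []
    for eqn in eqns:
        here = path + (eqn.primitive,)
        if eqn.body:
            findings.extend(find_non_finite(eqn.body, here))
        elif math.isnan(eqn.result) or math.isinf(eqn.result):
            findings.append((" > ".join(here), eqn.source))
    return findings


program = [
    ToyEqn("add", "model.py:10", 1.0),
    ToyEqn("jit", "model.py:12", body=[
        ToyEqn("scan", "model.py:14", body=[
            ToyEqn("log", "model.py:15", float("nan")),
        ]),
        ToyEqn("div", "model.py:18", float("inf")),
    ]),
]

for where, source in find_non_finite(program):
    print(f"non-finite value in {where} at {source}")
```

The walk reports the innermost equation that produced the bad value together with the chain of enclosing primitives, which is the kind of information a compact error message needs.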
JAX Traceback Filtering
- Registered brainstate with JAX's `traceback_util.register_exclusion()` so internal frames are hidden in user-facing error tracebacks. Follows the same pattern as Flax, Equinox, and other JAX ecosystem libraries.
- Users can still see full tracebacks via `JAX_TRACEBACK_FILTERING=off`.
State Validation at Call Time
- New `_validate_state_shapes()` method checks that current state shapes and dtypes match those recorded at compile time.
- `StatefulFunction.__call__()` automatically validates before execution, catching state shape mismatches early with clear error messages.
- Added `static_argnums` bounds validation: `make_jaxpr()` now raises `ValueError` if indices exceed the number of positional arguments.
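Both checks above amount to comparing what was recorded against what is supplied at call time. The following is a minimal sketch under stated assumptions: the function names, the `(shape, dtype)` tuple representation, and the use of `ValueError` for state mismatches are illustrative choices, and negative indices are not handled.

```python
def validate_state_shapes(recorded, current):
    """Raise if any state's (shape, dtype) differs from the recorded one."""
    for name, expected in recorded.items():
        actual = current.get(name)
        if actual != expected:
            raise ValueError(
                f"State {name!r} was compiled with {expected} but now has {actual}"
            )


def check_static_argnums(static_argnums, num_positional):
    """Raise if a static index points past the positional arguments."""
    for index in static_argnums:
        if index >= num_positional:
            raise ValueError(
                f"static_argnums contains {index}, but only "
                f"{num_positional} positional argument(s) were given"
            )


recorded = {"weight": ((3, 4), "float32")}

validate_state_shapes(recorded, {"weight": ((3, 4), "float32")})  # passes

try:
    validate_state_shapes(recorded, {"weight": ((5, 4), "float32")})
except ValueError as err:
    print(err)

try:
    check_static_argnums((0, 2), num_positional=2)
except ValueError as err:
    print(err)
```

Running the comparison before execution means the failure names the offending state directly, instead of surfacing later as a shape error inside compiled code.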
New Compatible Import
- Added `mapped_aval` import with version-based routing: `jax.core.mapped_aval` for JAX < 0.8.2, `jax.extend.core.mapped_aval` for JAX >= 0.8.2.
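Version-based routing of this kind usually comes down to comparing a parsed version tuple against a threshold. The sketch below shows one way to do that without importing JAX; `version_tuple` and `mapped_aval_module` are hypothetical helper names, not part of brainstate.

```python
import re


def version_tuple(version):
    """Turn '0.8.2' or '0.8.2.dev1' into a comparable tuple like (0, 8, 2)."""
    numbers = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    return tuple(numbers)


def mapped_aval_module(jax_version):
    """Return the module path that would provide mapped_aval for this version."""
    if version_tuple(jax_version) >= (0, 8, 2):
        return "jax.extend.core"
    return "jax.core"


for v in ("0.6.0", "0.8.1", "0.8.2", "0.10.0"):
    print(v, "->", mapped_aval_module(v))
```

Comparing integer tuples rather than raw strings matters here: as strings, `"0.10.0"` sorts before `"0.8.2"`, which would route a newer release to the old location.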
Improvements
- Atomic cache writes: Compilation results are only stored on success, eliminating partial cache entries on error. Uses a double-checked locking pattern for thread safety during compilation.
- Better cache key hashing: Dynamic args/kwargs are now flattened via `jax.tree.flatten()` before hashing, fixing non-deterministic hashing issues with custom pytree nodes (e.g., `Quantity`).
- Modern Python type annotations: Migrated from `typing.Tuple`, `typing.List`, `typing.Dict`, `typing.Optional`, `typing.Union` to built-in `tuple`, `list`, `dict`, `X | None`, `X | Y` syntax across the codebase.
- IR visualization compatibility: Replaced direct `jax.core.X` references with compatible imports (`Var`, `ClosedJaxpr`, `Jaxpr`, `JaxprEqn`, `Literal`, `DropVar`) in the IR visualizer.
- Deterministic error reporting: `jax.debug.callback` in `_error_if.py` now uses `ordered=True` for deterministic error callback ordering.
- Graph operations cleanup: Major refactoring of `_operation.py`, `_node.py`, `_convert.py`, and `_context.py` with streamlined docstrings, better thread-safety documentation, and cleaner context managers.
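The atomic cache write and double-checked locking described in the first item of this list can be sketched as follows. This is a generic illustration of the pattern, assuming a single lock and a plain dictionary; `SketchCompilationCache` is hypothetical and does not reproduce the internals of `StatefulFunction`.

```python
import threading


class SketchCompilationCache:
    """Hypothetical cache; not brainstate's actual implementation."""

    def __init__(self):
        self._entries = {}
        self._lock = threading.Lock()

    def get_or_compile(self, key, compile_fn):
        entry = self._entries.get(key)        # first check, without the lock
        if entry is not None:
            return entry
        with self._lock:
            entry = self._entries.get(key)    # second check, under the lock
            if entry is None:
                entry = compile_fn()          # if this raises, nothing is stored
                self._entries[key] = entry    # single write, only on success
            return entry


cache = SketchCompilationCache()
compile_calls = []


def failing_compile():
    raise RuntimeError("tracing failed")


def working_compile():
    compile_calls.append(1)
    return "compiled result"


try:
    cache.get_or_compile("key", failing_compile)
except RuntimeError:
    pass  # the failed attempt leaves no partial entry behind

threads = [
    threading.Thread(target=cache.get_or_compile, args=("key", working_compile))
    for _ in range(8)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(compile_calls), cache._entries)
```

One trade-off of this simple form is that a single lock serializes compilation for unrelated keys; a per-key lock would avoid that at the cost of more bookkeeping.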
Bug Fixes
- Fixed `Delay.__init__` initialization order: `update_every` is now initialized before `register_entry` is called, preventing attribute errors during entry registration (#135).
- Fixed `graph_to_tree` private attribute access: Replaced internal `_mapping` access with public API usage in `_convert.py`.
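The initialization-order problem behind the `Delay.__init__` fix is a common one: a method called from the constructor reads an attribute that the constructor has not assigned yet. The toy classes below reproduce that failure mode; they are hypothetical and share only the attribute and method names with the real `Delay`.

```python
class BrokenToyDelay:
    """Registers entries before update_every exists."""

    def __init__(self, entry_names):
        self.entries = {}
        for name in entry_names:
            self.register_entry(name)   # reads self.update_every too early
        self.update_every = 1

    def register_entry(self, name):
        self.entries[name] = self.update_every


class FixedToyDelay:
    """Assigns update_every first, then registers entries."""

    def __init__(self, entry_names):
        self.entries = {}
        self.update_every = 1
        for name in entry_names:
            self.register_entry(name)

    def register_entry(self, name):
        self.entries[name] = self.update_every


try:
    BrokenToyDelay(["a"])
except AttributeError as err:
    print("broken order:", err)

print("fixed order:", FixedToyDelay(["a"]).entries)
```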
Internal Changes
- Massive docstring reduction across the graph module (~1,000+ lines removed), replacing verbose multi-paragraph docstrings with concise descriptions.
- Cleaned up TypeVar usage: removed unused `C` and `Names` aliases, renamed `Node` TypeVar to `N`, removed `Hashable` bound from type variables.
- Removed unused tests (`test_all_exports`, `test_function_imports_availability`) from compatible import tests.
- Rewrote debug and make_jaxpr test suites to match the new APIs.
- IR optimization imports are now lazy-loaded inside `make_jaxpr()` only when `ir_optimizations` is configured.
CI/CD
- Bumped `actions/upload-artifact` from v6 to v7.
- Bumped `actions/download-artifact` from v7 to v8.
What's Changed
- fix(nn): initialize update_every before register_entry by @Routhleck in #135
- deps(deps): bump actions/upload-artifact from 6 to 7 by @dependabot[bot] in #133
- deps(deps): bump actions/download-artifact from 7 to 8 by @dependabot[bot] in #132
- Simplify JAX compat: use jax.make_jaxpr and aval helpers by @chaoming0625 in #137
- Refactor graph ops, update JAX/Python requirements, improve tests by @chaoming0625 in #138
- Add on-device NaN debugging and unify StatefulFunction cache by @chaoming0625 in #139
Full Changelog: v0.2.10...v0.3.0