Version 0.2.8
This release ensures compatibility with JAX 0.8.2+ and removes the experimental module that was superseded by upstream changes.
Compatibility
- JAX 0.8.2+ Support: Added compatibility with JAX 0.8.2 and later. The library now calls `jax.make_jaxpr` directly for JAX >= 0.8.2 while remaining backward compatible with earlier versions.
Breaking Changes
- Removed `abstracted_axes` parameter: The `abstracted_axes` parameter has been removed from:
  - `StatefulFunction.__init__`
  - `StatefulMapping.__init__`
  - `make_jaxpr` function
  - `_make_jaxpr` internal function
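Calling code that still passes `abstracted_axes` has to drop it. One hedged way to migrate gradually is a small forwarding shim that strips the keyword and emits a warning.

Everything in this sketch is hypothetical:
- `call_without_abstracted_axes` is not part of the library.
- `new_style_make_jaxpr` is a stand-in for any of the updated entry points, whose real signatures may differ.

```python
import warnings
from typing import Any, Callable


def call_without_abstracted_axes(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Forward a call after dropping the removed `abstracted_axes` keyword.

    Hypothetical migration helper: a function whose signature no longer lists
    `abstracted_axes` (and takes no **kwargs) raises TypeError if it is passed.
    """
    if "abstracted_axes" in kwargs:
        del kwargs["abstracted_axes"]
        warnings.warn(
            "`abstracted_axes` was removed in 0.2.8 and is being ignored",
            DeprecationWarning,
            stacklevel=2,
        )
    return fn(*args, **kwargs)


def new_style_make_jaxpr(fun: Callable[..., Any], static_argnums: tuple = ()) -> tuple:
    """Hypothetical stand-in for an updated API without `abstracted_axes`."""
    return ("traced", fun.__name__, static_argnums)
```

Removing the keyword from call sites is the cleaner long-term fix. The shim only keeps old call sites working during the upgrade.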
Improvements
- Debug mode support: Added a `debug_call` method to `StatefulFunction` so that it executes correctly when `jax.config.jax_disable_jit` is enabled. This improves debugging workflows by allowing stateful functions to run without JIT compilation.
- Lazy loading optimization: The `RandomState` import in the `_mapping` module is now loaded lazily via `_import_rand_state()`, improving initial import performance and reducing circular dependency issues.
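The debug-mode behavior can be pictured as a dispatch on the JIT-disable flag. The sketch below is an assumption about the general shape, not the actual `StatefulFunction` implementation. `StatefulFunctionSketch`, `FakeConfig` and the call counters are hypothetical. A real version would read `jax.config.jax_disable_jit` rather than a local stand-in.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FakeConfig:
    """Hypothetical stand-in for jax.config, holding only the flag this sketch needs."""
    jax_disable_jit: bool = False


config = FakeConfig()


class StatefulFunctionSketch:
    """Hypothetical wrapper illustrating the dispatch; not the library's StatefulFunction."""

    def __init__(self, fn: Callable[..., Any]):
        self.fn = fn
        self.compiled_calls = 0  # times the (pretend) compiled path ran
        self.debug_calls = 0     # times the eager debug path ran

    def _compiled_call(self, *args: Any) -> Any:
        # Real code would trace to a jaxpr and run it under jit; here we only count.
        self.compiled_calls += 1
        return self.fn(*args)

    def debug_call(self, *args: Any) -> Any:
        # Eager path: call the Python function directly so prints and breakpoints work.
        self.debug_calls += 1
        return self.fn(*args)

    def __call__(self, *args: Any) -> Any:
        # Route to the eager path whenever JIT is disabled globally.
        if config.jax_disable_jit:
            return self.debug_call(*args)
        return self._compiled_call(*args)
```

In real JAX code the flag is typically toggled with `jax.config.update("jax_disable_jit", True)` or with the `jax.disable_jit()` context manager.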
Internal Changes
- Removed the module-level imports of `annotate` and `api_boundary` from `jax._src`; they are now imported only where needed
- Removed the internal helper functions `_broadcast_prefix` and `_flat_axes_specs`
- Simplified the `_abstractify` function by removing abstracted-axes handling
- Updated example files to reflect the API changes
What's Changed
- fix: compatible with `jax>=0.8.2` by @chaoming0625 in #124
- chore(changelog): update release notes for version 0.2.8 by @chaoming0625 in #125
Full Changelog: v0.2.7...v0.2.8