
Version 0.2.8


Released by chaoming0625 on 19 Dec at 06:21 (commit ac51f5f)

This release ensures compatibility with JAX 0.8.2+ and removes the experimental module that was superseded by upstream changes.

Compatibility

  • JAX 0.8.2+ Support: Added compatibility with JAX version 0.8.2 and later. The library now uses jax.make_jaxpr directly for JAX >= 0.8.2 while maintaining backward compatibility with earlier versions.
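The version gate described above can be sketched as a small helper. This is a minimal illustration, not the library's actual code: the names parse_jax_version and use_upstream_make_jaxpr are hypothetical, and the only fact taken from these notes is the 0.8.2 threshold. Comparing integer tuples rather than raw strings matters because "0.10.0" sorts before "0.8.2" as a string.

```python
import re

# Hypothetical helper: matches the leading "major.minor.patch" of strings such as
# "0.8.2", "0.8.2.dev20251201" or "0.9.0rc1".
_VERSION_RE = re.compile(r"^(\d+)\.(\d+)\.(\d+)")


def parse_jax_version(version: str) -> tuple[int, int, int]:
    """Return (major, minor, patch) as integers so versions compare numerically."""
    match = _VERSION_RE.match(version)
    if match is None:
        raise ValueError(f"unrecognized JAX version string: {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch


def use_upstream_make_jaxpr(version: str) -> bool:
    """True when jax.make_jaxpr can be used directly (JAX >= 0.8.2, per these notes)."""
    return parse_jax_version(version) >= (0, 8, 2)
```

In library code, a check like this would typically run once at import time against jax.__version__, selecting jax.make_jaxpr on new versions and a legacy code path otherwise.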

Breaking Changes

  • Removed abstracted_axes parameter: The abstracted_axes parameter has been removed from:
    • StatefulFunction.__init__
    • StatefulMapping.__init__
    • make_jaxpr function
    • _make_jaxpr internal function
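Downstream code that still passes abstracted_axes to any of the callables listed above will need updating. One way to migrate gradually is a wrapper that strips the keyword. This is a hypothetical shim, not part of the library, and it assumes abstracted_axes=None was equivalent to omitting the argument; any non-None value is rejected rather than silently ignored, since it requested behavior that no longer exists.

```python
import functools
from typing import Any, Callable


def drop_abstracted_axes(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical shim: remove a no-op abstracted_axes kwarg before calling fn."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if "abstracted_axes" in kwargs:
            value = kwargs.pop("abstracted_axes")
            if value is not None:
                # A real axis specification cannot be honored after 0.2.8.
                raise TypeError(
                    "abstracted_axes was removed in 0.2.8; "
                    "pass concrete-shaped inputs instead"
                )
        return fn(*args, **kwargs)

    return wrapper
```

Wrapping a call site once (for example around make_jaxpr) lets old call sites that passed abstracted_axes=None keep working, while surfacing the ones that relied on the removed feature.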

Improvements

  • Debug mode support: Added debug_call method to StatefulFunction for proper execution when jax.config.jax_disable_jit is enabled. This improves debugging workflows by allowing stateful functions to execute without JIT compilation.

  • Lazy loading optimization: RandomState import in the _mapping module is now lazily loaded via _import_rand_state(), improving initial import performance and reducing circular dependency issues.
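The debug_call improvement can be pictured as a dispatch on the jax_disable_jit flag. The sketch below is a hypothetical stand-in for StatefulFunction that shows only that branching; it does not reproduce the library's tracing. Both paths call the Python function here, and the comments mark where the real compiled path would trace and evaluate a jaxpr. Whether StatefulFunction dispatches automatically or expects callers to use debug_call explicitly is not stated in these notes.

```python
from typing import Any, Callable, Optional


def jax_jit_disabled() -> bool:
    """Read jax.config.jax_disable_jit if JAX is installed; otherwise report False."""
    try:
        import jax
    except ImportError:
        return False
    return bool(jax.config.jax_disable_jit)


class StatefulFunctionSketch:
    """Hypothetical stand-in for StatefulFunction, showing only the dispatch logic."""

    def __init__(self, fn: Callable[..., Any],
                 jit_disabled: Callable[[], bool] = jax_jit_disabled):
        self.fn = fn
        self.jit_disabled = jit_disabled
        self.last_path: Optional[str] = None  # records which branch ran

    def compiled_call(self, *args: Any, **kwargs: Any) -> Any:
        # Placeholder: the real library would trace fn (e.g. with make_jaxpr)
        # and evaluate the resulting jaxpr. This sketch just calls fn.
        self.last_path = "compiled"
        return self.fn(*args, **kwargs)

    def debug_call(self, *args: Any, **kwargs: Any) -> Any:
        # Eager path: run the Python function directly, so print() and
        # breakpoints behave as in ordinary Python code.
        self.last_path = "debug"
        return self.fn(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self.jit_disabled():
            return self.debug_call(*args, **kwargs)
        return self.compiled_call(*args, **kwargs)
```

With JAX installed, wrapping a call in the jax.disable_jit() context manager flips jax.config.jax_disable_jit, so the sketch would route through debug_call for the duration of the block.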

Internal Changes

  • Removed the module-level imports of annotate and api_boundary from jax._src; they are now imported locally, only where needed
  • Removed internal helper functions _broadcast_prefix and _flat_axes_specs
  • Simplified _abstractify function by removing abstracted axes handling
  • Updated example files to reflect API changes


Full Changelog: v0.2.7...v0.2.8