Version 0.2.4
This release introduces the new ArrayParam state type for parameter arrays with custom transformations, experimental BPU backend export support, enhanced JAXPR optimization capabilities, and improved module organization.
New Features
ArrayParam State Type
- ArrayParam Class: New state type for managing parameter arrays with advanced transformation control
  - Supports custom transformations (e.g., quantization, normalization) that preserve array identity
  - Enables vmap, pmap, and other JAX transformations to correctly handle stateful parameters
  - Provides an identity() method that returns the raw array without applying custom transformations
  - Integrates seamlessly with the existing State management infrastructure
  - Useful for implementing quantization-aware training and other advanced parameter manipulations
  - Comprehensive documentation with usage examples and best practices
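The idea of a parameter whose reads go through a custom transformation, with identity() as an escape hatch to the raw array, can be sketched in a few lines. This is a minimal pure-Python illustration, not brainstate's ArrayParam API: ToyArrayParam, its constructor arguments, the value property, and the quantize helper are hypothetical. Only the identity() method name comes from these notes, and plain lists stand in for JAX arrays to keep the example self-contained.

```python
from typing import Callable, List


class ToyArrayParam:
    """Hypothetical stand-in for a parameter array with a custom transform.

    NOT brainstate's ArrayParam: the constructor signature and the `value`
    property are illustrative assumptions; only `identity()` is named in
    the release notes.
    """

    def __init__(self, raw: List[float], transform: Callable[[List[float]], List[float]]):
        self._raw = list(raw)        # stored, untransformed parameters
        self._transform = transform  # e.g. quantization or normalization

    @property
    def value(self) -> List[float]:
        # Reading the parameter applies the custom transformation.
        return self._transform(self._raw)

    def identity(self) -> List[float]:
        # Return a copy of the raw array, bypassing the transformation.
        return list(self._raw)


def quantize(xs: List[float], step: float = 0.5) -> List[float]:
    # Hypothetical helper: round each entry to the nearest multiple of `step`.
    return [round(x / step) * step for x in xs]


w = ToyArrayParam([0.2, 0.74, -1.3], quantize)
print(w.value)       # [0.0, 0.5, -1.5]
print(w.identity())  # [0.2, 0.74, -1.3]
```

Reading value yields the quantized weights a forward pass would use, while identity() exposes the full-precision array that optimizer updates would modify, which is the usual arrangement in quantization-aware training.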
Experimental BPU Backend Export (brainstate.experimental.gdiist_bpu)
- BPU Backend Export Support: Complete infrastructure for exporting models to the GDiist BPU hardware backend (727 lines)
  - export.py: Main export API with a to_bpu() function for model conversion
  - parser.py: Operation parser that analyzes JAXPR to identify operations and connections (305 lines)
  - data.py: Data structures and analysis utilities for operation representation (215 lines)
- Operation Parser Features:
  - Automatic detection of operations from JAXPR equations using brainevent primitives
  - Data flow analysis to identify connections between operations
  - Support for various operation types, including slice, add, multiply, and more
  - Detailed analysis output showing equations, inputs, outputs, and connections
- Analysis and Debugging Tools:
  - display_analysis_results(): Comprehensive visualization of parsed operations
  - Shows operation details including equation count, variable mappings, and connections
  - Displays connection information with producer/consumer operations and variable details
- Example implementation in examples/400_CUBA_2005_bpu.py
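The data flow analysis described above comes down to matching each variable an operation consumes against the operation that produced it. The sketch below shows that producer/consumer matching on a simplified operation list. The Op dataclass, find_connections, and display_connections are hypothetical names for illustration only; they are not the API of parser.py or display_analysis_results(), and real JAXPR equations carry much more information than plain variable names.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Op:
    # Hypothetical, simplified operation record (not brainstate's data.py).
    name: str
    inputs: List[str]   # variable names consumed
    outputs: List[str]  # variable names produced


def find_connections(ops: List[Op]) -> List[Tuple[str, str, str]]:
    """Return (producer, consumer, variable) triples for every data edge."""
    producer_of: Dict[str, str] = {}
    for op in ops:
        for var in op.outputs:
            producer_of[var] = op.name
    connections = []
    for op in ops:
        for var in op.inputs:
            src = producer_of.get(var)
            # Inputs with no producer (e.g. model arguments) are not connections.
            if src is not None and src != op.name:
                connections.append((src, op.name, var))
    return connections


def display_connections(connections: List[Tuple[str, str, str]]) -> None:
    for src, dst, var in connections:
        print(f"{src} -> {dst} via {var}")


ops = [
    Op("slice_0", ["x"], ["a"]),
    Op("mul_0", ["a", "w"], ["b"]),
    Op("add_0", ["b", "a"], ["c"]),
]
display_connections(find_connections(ops))
# slice_0 -> mul_0 via a
# mul_0 -> add_0 via b
# slice_0 -> add_0 via a
```

Building the producer map first keeps the analysis linear in the number of variables, at the cost of assuming each variable is produced exactly once, which holds for JAXPR's single-assignment form.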
Enhancements
JAXPR Optimization Improvements
- Enhanced Constant Folding:
  - Better handling of literal values in constant folding optimization
  - Improved detection and elimination of redundant literal operations
  - More efficient constant propagation through computation graphs
- Identity Equation Optimization:
  - Optimized handling of Literal outputs to avoid unnecessary bridging equations
  - Improved identity equation creation for interface preservation
  - Better handling of edge cases in optimization passes
- Error Handling:
  - Added a fallback source info utility for better error messages
  - Fixed potential NoneType errors in equation handling
  - Improved validation of optimization results
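Constant folding with constant propagation, as mentioned under Enhanced Constant Folding, evaluates any equation whose inputs are all known at trace time and substitutes the result into later equations. The toy sketch below shows the technique on a simplified tuple-based equation list. The IR, fold_constants, and the PRIMS table are hypothetical and far simpler than real JAXPR equations (no Literal objects, effects, or sub-jaxprs); this is not brainstate's optimizer.

```python
import operator
from typing import Dict, List, Tuple, Union

# Hypothetical toy IR: an atom is a variable name or a literal number.
Atom = Union[str, float]
Eqn = Tuple[str, str, List[Atom]]  # (output variable, primitive, inputs)

# Only pure elementwise primitives are safe to evaluate ahead of time.
PRIMS = {"add": operator.add, "mul": operator.mul}


def fold_constants(eqns: List[Eqn]) -> Tuple[List[Eqn], Dict[str, float]]:
    known: Dict[str, float] = {}  # variables proven constant so far
    kept: List[Eqn] = []
    for out, prim, args in eqns:
        # Propagate: replace any variable whose value is already known.
        args = [known.get(a, a) if isinstance(a, str) else a for a in args]
        if all(not isinstance(a, str) for a in args):
            known[out] = PRIMS[prim](*args)  # fold: evaluate now
        else:
            kept.append((out, prim, args))   # still depends on runtime input
    return kept, known


eqns = [
    ("a", "add", [2.0, 3.0]),
    ("b", "mul", ["a", 4.0]),
    ("c", "add", ["x", "b"]),
]
kept, known = fold_constants(eqns)
print(kept)   # [('c', 'add', ['x', 20.0])]
print(known)  # {'a': 5.0, 'b': 20.0}
```

After folding, only the equation that depends on the runtime input x survives, and it consumes the literal 20.0 directly instead of a chain of intermediate variables.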
State Management
- Enhanced State Tests: Comprehensive test refactoring with improved coverage (454 tests)
  - Better organization of state type tests
  - More thorough validation of state behavior
  - Enhanced test readability and maintainability
What's Changed
- deps(deps): bump actions/download-artifact from 5 to 6 by @dependabot[bot] in #109
- deps(deps): bump actions/upload-artifact from 4 to 5 by @dependabot[bot] in #110
- Add ArrayParam and integrate JAXPR optimizations by @chaoming0625 in #112
- Add experimental BPU backend export support by @chaoming0625 in #111
- Standardize module attribution for random and transform by @chaoming0625 in #113
Full Changelog: v0.2.3...v0.2.4