Version 0.2.3
This release introduces IR (Intermediate Representation) optimization capabilities for JAX computation graphs, a comprehensive refactoring of state management for vectorized mapping operations, and extensive improvements to the testing infrastructure.
New Features
IR Optimization (brainstate.transform._ir_optim)
- Intermediate Representation Optimization Module (876 lines): Complete suite of compiler-level optimizations for JAX computation graphs
  - constant_fold: Evaluates constant expressions at compile time, reducing runtime computation
  - dead_code_elimination: Removes equations whose outputs are unused, reducing computation overhead
  - common_subexpression_elimination: Identifies and reuses results of identical computations
  - copy_propagation: Eliminates unnecessary copy operations by propagating original variables
  - algebraic_simplification: Applies algebraic identities (x+0=x, x*1=x, x-x=0, etc.)
  - optimize_jaxpr: Orchestrates multiple optimization passes with configurable iteration and verbose mode
- IdentitySet Class: Custom set implementation using object identity (id()) instead of equality
  - Enables proper handling of JAX variables and Literals in optimization passes
  - Implements the MutableSet interface with full collection protocol support
  - Essential for tracking variable usage without relying on equality comparisons
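The sketch below illustrates why identity-based membership matters for a pass like dead_code_elimination. It is a minimal, self-contained approximation, not brainstate's code: the toy Var and Eqn classes and the body of dead_code_elimination are assumptions made for illustration, and nothing here operates on real Jaxprs. Var deliberately compares equal by name, so a regular set treats two distinct variables named t as one, while IdentitySet keeps them apart.

```python
from collections.abc import MutableSet
from dataclasses import dataclass


class IdentitySet(MutableSet):
    """Set whose membership test uses id(), not __eq__/__hash__ (illustrative sketch)."""

    def __init__(self, items=()):
        # id(obj) -> obj; keeping a reference to obj keeps it alive, so its id stays unique
        self._items = {}
        for item in items:
            self.add(item)

    def __contains__(self, item):
        return id(item) in self._items

    def __iter__(self):
        return iter(self._items.values())

    def __len__(self):
        return len(self._items)

    def add(self, item):
        self._items[id(item)] = item

    def discard(self, item):
        self._items.pop(id(item), None)


class Var:
    """Hypothetical toy variable that compares equal by name."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        return isinstance(other, Var) and self.name == other.name

    def __hash__(self):
        return hash(self.name)


@dataclass
class Eqn:
    """Hypothetical toy equation: outvar = prim(*invars)."""
    prim: str
    invars: list
    outvar: Var


def dead_code_elimination(eqns, outvars):
    """Drop equations whose results never reach outvars (backward liveness)."""
    live = IdentitySet(outvars)
    kept = []
    for eqn in reversed(eqns):  # consumers are visited before their producers
        if eqn.outvar in live:
            kept.append(eqn)
            live |= eqn.invars  # MutableSet mixin: adds every input as live
    return kept[::-1]


# Two distinct variables that happen to share the name "t".
x, y = Var("x"), Var("y")
t1, t2 = Var("t"), Var("t")
program = [Eqn("sin", [x], t1), Eqn("cos", [x], t2), Eqn("neg", [t2], y)]
print([e.prim for e in dead_code_elimination(program, [y])])
```

With a regular set in place of IdentitySet, the unused sin equation would survive, because its output t1 compares equal to the live output t2.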
Optimization Features
- Interface Preservation: All optimizations preserve the function's input and output variables (invars/outvars)
  - Identity equations are added automatically when needed to maintain correct interfaces
  - Uses the convert_element_type primitive with matching input and output dtypes as the identity operation
  - Ensures optimized functions remain drop-in replacements
- Optimization Pipeline: Configurable multi-pass optimization with convergence detection
  - Customizable optimization sequence via the optimizations parameter
  - Automatic convergence detection when no further reductions are possible
  - Maximum iteration control with the max_iterations parameter
  - Verbose mode with detailed statistics and progress tracking
- JAX Integration: Full support for JAX primitives and special cases
  - Blacklist of primitives that should not be folded (broadcast_in_dim, broadcast)
  - Proper handling of the closed_call and scan primitives
  - Support for both Jaxpr and ClosedJaxpr inputs
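A toy model of constant folding with a primitive blacklist and of copy propagation with interface preservation, under stated assumptions: equations are plain (prim, ins, out) records rather than JAX JaxprEqn objects, and RULES, NO_FOLD, and both function bodies are hypothetical, not brainstate's implementation. constant_fold evaluates equations whose inputs are all known, except for blacklisted primitives; copy_propagation removes copy equations and, when a removed copy defined an output variable, emits an identity equation so the outvars stay the same.

```python
import operator
from dataclasses import dataclass


@dataclass(frozen=True)
class Eqn:
    """Hypothetical toy equation: ins are variable names (str) or literal numbers."""
    prim: str
    ins: tuple
    out: str


RULES = {
    "add": operator.add,
    "mul": operator.mul,
    "neg": operator.neg,
    # Evaluable, but deliberately never folded because it is blacklisted below.
    "broadcast_in_dim": lambda value, n: [value] * n,
}
# Mirrors the blacklist in the notes; one plausible (assumed) reason is that folding
# a broadcast would embed the full expanded array in the program as a constant.
NO_FOLD = {"broadcast_in_dim", "broadcast"}


def constant_fold(eqns, outvars):
    known = {}  # variable name -> value computed at "compile time"
    kept = []
    for eqn in eqns:
        ins = tuple(known.get(v, v) if isinstance(v, str) else v for v in eqn.ins)
        foldable = (
            eqn.prim in RULES
            and eqn.prim not in NO_FOLD
            and all(not isinstance(v, str) for v in ins)  # every input is a literal
            and eqn.out not in outvars  # outputs keep a defining equation
        )
        if foldable:
            known[eqn.out] = RULES[eqn.prim](*ins)
        else:
            kept.append(Eqn(eqn.prim, ins, eqn.out))
    return kept


def copy_propagation(eqns, outvars):
    alias = {}  # copied variable -> the variable (or literal) it copies
    kept = []
    for eqn in eqns:
        ins = tuple(alias.get(v, v) if isinstance(v, str) else v for v in eqn.ins)
        if eqn.prim == "copy":
            alias[eqn.out] = ins[0]  # later uses read the source directly
        else:
            kept.append(Eqn(eqn.prim, ins, eqn.out))
    for out in outvars:
        if out in alias:
            # Identity equation keeps `out` defined; in JAX this role is played by
            # convert_element_type to the value's own dtype.
            kept.append(Eqn("convert_element_type", (alias[out],), out))
    return kept


program = [
    Eqn("add", (2, 3), "a"),                  # all-literal inputs: folds to 5
    Eqn("copy", ("x",), "xc"),                # copy of input x: propagated away
    Eqn("mul", ("a", "xc"), "b"),             # becomes mul(5, x)
    Eqn("broadcast_in_dim", ("a", 4), "c"),   # blacklisted: kept
    Eqn("copy", ("b",), "d"),                 # defines an outvar: replaced by identity
]
for eqn in copy_propagation(constant_fold(program, ["c", "d"]), ["c", "d"]):
    print(eqn)
```

The five input equations shrink to three, and both outvars (c and d) are still defined by name.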
State Management Refactoring (brainstate.transform._mapping)
- Renamed vmap to vmap2: Major refactoring of the vectorized mapping implementation (647 lines)
  - Enhanced state management with improved axis tracking
  - Better error messages and validation
  - Streamlined state value restoration logic
- Old vmap Implementation Preserved (_mapping_old.py, 579 lines): Legacy vmap with explicit state management
  - Exports the original vmap and vmap_new_states functions
  - Maintains backward compatibility for existing code
  - Specialized for stateful functions with explicit state parameters
Documentation
API Documentation
- transform.rst: Added an IR Optimization section (24 lines)
  - Module description explaining the compiler optimizations
  - All 6 optimization functions documented with autosummary
  - Explanation of the benefits: reduced computation overhead and improved runtime performance
  - Positioned between the Compilation Tools and Gradient Computations sections
- NumPy-style Docstrings: All optimization functions include:
  - Comprehensive parameter descriptions with types and defaults
  - Detailed return value documentation
  - Notes sections explaining preservation of function interfaces
  - Multiple practical examples demonstrating usage
  - Algorithm descriptions for complex optimizations
  - Cross-references between related functions
Enhancements
Optimization Pipeline
- Progress Tracking: Verbose mode shows equation count changes after each optimization
  - Displays initial, intermediate, and final equation counts
  - Shows reduction statistics with percentages
  - Indicates when convergence is detected
  - Reports iteration counts
- Validation: Runtime checks ensure that optimization preserves the function interface
  - Verifies that input variables are unchanged after optimization
  - Validates that output variables are preserved
  - Raises clear errors if the interface is violated
  - Checks that requested optimization names are valid
- Flexibility: Customizable optimization sequences
  - Apply all optimizations in the recommended order (default)
  - Select specific optimizations only
  - Control iteration limits
  - Toggle verbose output
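The pipeline behavior described above (configurable pass order, name validation, iteration limits, convergence detection, and verbose statistics) can be sketched as a small driver loop. The optimize function and its signature are hypothetical and are not brainstate's optimize_jaxpr API; only the parameter names optimizations and max_iterations are borrowed from these notes. The two passes are stand-ins that operate on a list of strings so the loop runs without JAX, and the invars/outvars checks are omitted.

```python
def optimize(eqns, passes, optimizations=None, max_iterations=5, verbose=False):
    """Hypothetical driver: run named passes until a full sweep removes nothing."""
    order = list(passes) if optimizations is None else list(optimizations)
    unknown = [name for name in order if name not in passes]
    if unknown:
        raise ValueError(f"unknown optimization(s) {unknown}; valid: {sorted(passes)}")
    initial = len(eqns)
    iterations = 0
    for iterations in range(1, max_iterations + 1):
        before = len(eqns)
        for name in order:
            count = len(eqns)
            eqns = passes[name](eqns)
            if verbose:
                print(f"[iter {iterations}] {name}: {count} -> {len(eqns)} equations")
        if len(eqns) == before:  # no reduction in this sweep: converged
            if verbose:
                print(f"converged after {iterations} iteration(s)")
            break
    if verbose and initial:
        pct = 100.0 * (initial - len(eqns)) / initial
        print(f"total: {initial} -> {len(eqns)} equations ({pct:.1f}% reduction)")
    return eqns, iterations


def merge_duplicates(eqns):
    """Stand-in pass: collapse adjacent identical entries."""
    out = []
    for e in eqns:
        if not out or out[-1] != e:
            out.append(e)
    return out


def drop_nops(eqns):
    """Stand-in pass: remove no-op entries."""
    return [e for e in eqns if e != "nop"]


PASSES = {"merge_duplicates": merge_duplicates, "drop_nops": drop_nops}

# drop_nops exposes a duplicate that merge_duplicates can only remove on the next sweep.
optimize(["sin", "nop", "sin"], PASSES, verbose=True)
```

Convergence here is judged by equation count, matching the notes' description of stopping when no further reductions are possible; a pass that rewrites equations without removing any does not, by itself, trigger another sweep.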
JAX Integration
- JaxprEqn Construction: Proper handling of the required ctx parameter
  - Uses JaxprEqnContext(None, True) for identity equations
  - Ensures compatibility with JAX's internal API
  - Maintains proper equation structure
- Primitive Handling: Special cases for JAX primitives
  - Blacklist for primitives that shouldn't be optimized
  - Proper parameter extraction and validation
  - Support for effects and source_info fields
Bug Fixes
- Fixed JaxprEqn constructor calls to include the required ctx parameter (7th positional argument)
- Corrected import paths for vmap2 in test files and tutorials
- Fixed RandomState.uniform() calls to use the size parameter instead of shape
- Enhanced test assertions for proper state axis handling
- Improved error messages for batch axis mismatches
Refactoring
Transform Module
- Renamed Files:
  - vmap → vmap2 in _mapping.py
  - Preserved the original vmap in _mapping_old.py for compatibility
- Module Exports: Updated __init__.py to export both the old and new vmap implementations
  - vmap from _mapping_old.py (legacy)
  - vmap2 from _mapping.py (new)
  - vmap_new_states from both modules
What's Changed
- Introduce JAXPR optimizations and enhance stateful mapping by @chaoming0625 in #108
Full Changelog: v0.2.2...v0.2.3