
Version 0.2.3


Released by @chaoming0625 on 13 Oct 14:47
Commit: 2235c36

This release introduces IR (Intermediate Representation) optimization passes for JAX computation graphs, a refactoring of state management in the vectorized mapping (vmap) operations, and expanded testing infrastructure.

New Features

IR Optimization (brainstate.transform._ir_optim)

  • Intermediate Representation Optimization Module (876 lines): A set of compiler-style optimization passes for JAX computation graphs

    • constant_fold: Evaluates constant expressions at compile time, reducing runtime computation
    • dead_code_elimination: Removes equations whose outputs are unused, reducing computation overhead
    • common_subexpression_elimination: Identifies and reuses results of identical computations
    • copy_propagation: Eliminates unnecessary copy operations by propagating original variables
    • algebraic_simplification: Applies algebraic identities (x+0=x, x*1=x, x-x=0, etc.)
    • optimize_jaxpr: Orchestrates multiple optimization passes with a configurable iteration limit and an optional verbose mode
  • IdentitySet Class: Custom set implementation using object identity (id()) instead of equality

    • Enables proper handling of JAX variables and Literals in optimization passes
    • Implements MutableSet interface with full collection protocol support
    • Essential for tracking variable usage without relying on equality comparisons
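The following is a minimal sketch, not brainstate's implementation, of why an identity-based set matters for these passes. It assumes a toy IR in which hypothetical `Var` and `Eqn` classes stand in for jaxpr variables and equations, and it uses the set to track live variables in a simple dead-code-elimination pass. `Var` is deliberately a value-compared dataclass: two distinct variables with the same name compare equal, and the class is unhashable, so a builtin `set` cannot hold it. An identity-based set handles both cases.

```python
from collections.abc import MutableSet
from dataclasses import dataclass


class IdentitySet(MutableSet):
    """Set whose membership is decided by id(), not by __eq__/__hash__."""

    def __init__(self, items=()):
        self._items = {}  # id(obj) -> obj; storing obj keeps its id valid
        for item in items:
            self.add(item)

    def __contains__(self, item):
        return id(item) in self._items

    def __iter__(self):
        return iter(self._items.values())

    def __len__(self):
        return len(self._items)

    def add(self, item):
        self._items[id(item)] = item

    def discard(self, item):
        self._items.pop(id(item), None)


@dataclass
class Var:
    # Hypothetical stand-in for a jaxpr variable. The dataclass default eq=True
    # compares by value and sets __hash__ to None, so a builtin set rejects it.
    name: str


@dataclass
class Eqn:
    prim: str      # primitive name, e.g. "mul"
    invars: list   # input Vars
    outvars: list  # output Vars


def dead_code_elimination(eqns, outvars):
    """Keep only equations whose outputs are (transitively) needed by outvars.

    Assumes every primitive is side-effect free; a real pass must also keep
    equations that carry effects.
    """
    live = IdentitySet(outvars)
    kept = []
    for eqn in reversed(eqns):  # walk backwards from the outputs
        if any(v in live for v in eqn.outvars):
            kept.append(eqn)
            for v in eqn.invars:
                live.add(v)
    kept.reverse()
    return kept


x, y = Var("x"), Var("y")
a, b, c = Var("a"), Var("b"), Var("c")
eqns = [
    Eqn("mul", [x, y], [a]),
    Eqn("sin", [x], [b]),  # b is never used
    Eqn("add", [a, x], [c]),
]
pruned = dead_code_elimination(eqns, [c])
print([e.prim for e in pruned])             # ['mul', 'add']
print(Var("t") == Var("t"))                 # True: equal by value...
print(Var("t") in IdentitySet([Var("t")]))  # False: ...but distinct objects
```

Keying on `id()` while also storing the object keeps each id valid for as long as the element is in the set, which is what makes identity-based membership safe here.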

Optimization Features

  • Interface Preservation: All optimizations preserve function input/output variables (invars/outvars)

    • Identity equations are added automatically when needed to keep the interface intact
    • The convert_element_type primitive, with matching input and output dtypes, serves as the identity operation
    • Optimized functions therefore remain drop-in replacements for the originals
  • Optimization Pipeline: Configurable multi-pass optimization with convergence detection

    • Customizable optimization sequence via optimizations parameter
    • Automatic convergence detection: iteration stops once no further reductions are possible
    • Maximum iteration control with max_iterations parameter
    • Verbose mode with detailed statistics and progress tracking
  • JAX Integration: Full support for JAX primitives and special cases

    • Blacklist for primitives that shouldn't be folded (broadcast_in_dim, broadcast)
    • Proper handling of closed_call and scan primitives
    • Support for both Jaxpr and ClosedJaxpr inputs
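A minimal sketch of algebraic simplification with interface preservation, on a hypothetical toy IR rather than real jaxprs. It only handles right-identity rules (x + 0, x - 0, x * 1, x / 1) and ignores dtypes. When a rewrite would turn a function output into an alias of another variable, the sketch appends an identity equation named `convert_element_type`, mirroring the approach described above, so the original output variable is still defined.

```python
from dataclasses import dataclass


class Var:
    """Hypothetical jaxpr variable; default __eq__/__hash__ give identity semantics."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


@dataclass(frozen=True)
class Lit:
    value: float  # toy literal constant


@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list


# Rules of the form prim(v, lit) == v, keyed by primitive name (toy subset).
RIGHT_IDENTITY = {"add": 0, "sub": 0, "mul": 1, "div": 1}


def algebraic_simplification(invars, outvars, eqns):
    subst = {}  # eliminated Var -> Var that now provides its value

    def resolve(v):
        while isinstance(v, Var) and v in subst:
            v = subst[v]
        return v

    new_eqns = []
    for eqn in eqns:
        ins = [resolve(v) for v in eqn.invars]
        ident = RIGHT_IDENTITY.get(eqn.prim)
        if (ident is not None and isinstance(ins[0], Var)
                and isinstance(ins[1], Lit) and ins[1].value == ident):
            subst[eqn.outvars[0]] = ins[0]  # drop the equation, alias its output
            continue
        new_eqns.append(Eqn(eqn.prim, ins, eqn.outvars))

    # Interface preservation: every original outvar must still be defined.
    for out in outvars:
        src = resolve(out)
        if src is not out:
            new_eqns.append(Eqn("convert_element_type", [src], [out]))
    return invars, outvars, new_eqns


x = Var("x")
a, out = Var("a"), Var("out")
eqns = [
    Eqn("mul", [x, Lit(1.0)], [a]),    # a = x * 1   -> a aliases x
    Eqn("add", [a, Lit(0.0)], [out]),  # out = a + 0 -> out aliases x
]
_, outs, simplified = algebraic_simplification([x], [out], eqns)
print(simplified)
# [Eqn(prim='convert_element_type', invars=[x], outvars=[out])]
```

Both arithmetic equations disappear, yet `out` is still produced by an equation, so callers see the same outvars as before.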

State Management Refactoring (brainstate.transform._mapping)

  • Renamed vmap to vmap2: Major refactoring of vectorized mapping implementation (647 lines)

    • Enhanced state management with improved axis tracking
    • Better error messages and validation
    • Streamlined state value restoration logic
  • Old vmap Implementation Preserved (_mapping_old.py, 579 lines): Legacy vmap with explicit state management

    • Exports original vmap and vmap_new_states functions
    • Maintains backward compatibility for existing code
    • Specialized for stateful functions with explicit state parameters

Documentation

API Documentation

  • transform.rst: Added comprehensive IR Optimization section (24 lines)

    • Detailed module description explaining compiler optimizations
    • All 6 optimization functions documented with autosummary
    • Clear explanation of benefits: reduced computation overhead, improved runtime performance
    • Positioned between Compilation Tools and Gradient Computations sections
  • NumPy-style Docstrings: All optimization functions include:

    • Comprehensive parameter descriptions with types and defaults
    • Detailed return value documentation
    • Notes sections explaining preservation of function interfaces
    • Multiple practical examples demonstrating usage
    • Algorithm descriptions for complex optimizations
    • Cross-references between related functions

Enhancements

Optimization Pipeline

  • Progress Tracking: Verbose mode shows equation count changes after each optimization

    • Displays initial, intermediate, and final equation counts
    • Shows reduction statistics with percentages
    • Reports when convergence is detected
    • Reports iteration counts
  • Validation: Runtime checks ensure optimization correctness

    • Verifies that input variables are unchanged after optimization
    • Verifies that output variables are preserved
    • Raises a clear error if the interface is violated
    • Rejects invalid optimization names
  • Flexibility: Customizable optimization sequences

    • Apply all optimizations in recommended order (default)
    • Select specific optimizations only
    • Control iteration limits
    • Toggle verbose output
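The orchestration described above (named passes, an iteration cap, convergence detection, verbose statistics, and interface validation) can be sketched as follows. Everything here is hypothetical: `optimize` only borrows the `optimizations`, `max_iterations`, and `verbose` parameter names from these notes and is not brainstate's `optimize_jaxpr` signature. The two passes (CSE and DCE) operate on a toy program whose variables are plain strings, and convergence is judged by the equation count.

```python
from typing import Callable, Dict, NamedTuple, Tuple


class Eqn(NamedTuple):
    prim: str
    invars: Tuple[str, ...]
    outvars: Tuple[str, ...]


class Program(NamedTuple):
    invars: Tuple[str, ...]
    outvars: Tuple[str, ...]
    eqns: Tuple[Eqn, ...]


def cse(p: Program) -> Program:
    """Toy CSE (no params/effects): reuse the first identical equation."""
    seen = {}    # (prim, invars) -> outvars of the first occurrence
    rename = {}  # duplicate var -> canonical var
    eqns = []
    for e in p.eqns:
        ins = tuple(rename.get(v, v) for v in e.invars)
        key = (e.prim, ins)
        if key in seen:
            rename.update(zip(e.outvars, seen[key]))
            continue
        seen[key] = e.outvars
        eqns.append(e._replace(invars=ins))
    # Keep the interface: a removed output is re-defined by an identity equation.
    for out in p.outvars:
        if out in rename:
            eqns.append(Eqn("convert_element_type", (rename[out],), (out,)))
    return p._replace(eqns=tuple(eqns))


def dce(p: Program) -> Program:
    """Dead code elimination: drop equations whose outputs are never used."""
    live = set(p.outvars)
    kept = []
    for e in reversed(p.eqns):
        if live.intersection(e.outvars):
            kept.append(e)
            live.update(e.invars)
    return p._replace(eqns=tuple(reversed(kept)))


PASSES: Dict[str, Callable[[Program], Program]] = {"cse": cse, "dce": dce}


def optimize(p, optimizations=None, max_iterations=3, verbose=False):
    names = list(PASSES) if optimizations is None else list(optimizations)
    unknown = [n for n in names if n not in PASSES]
    if unknown:
        raise ValueError(f"unknown optimizations {unknown}; valid: {sorted(PASSES)}")
    start = len(p.eqns)
    for it in range(1, max_iterations + 1):
        before = len(p.eqns)
        for name in names:
            n = len(p.eqns)
            p_new = PASSES[name](p)
            # Validation: passes may rewrite equations but never the interface.
            if p_new.invars != p.invars or p_new.outvars != p.outvars:
                raise RuntimeError(f"pass {name!r} changed the function interface")
            p = p_new
            if verbose:
                print(f"iteration {it}, {name}: {n} -> {len(p.eqns)} equations")
        if len(p.eqns) == before:  # a full round removed nothing: converged
            if verbose:
                print(f"converged after {it} iteration(s)")
            break
    if verbose and start:
        saved = 100 * (start - len(p.eqns)) / start
        print(f"total: {start} -> {len(p.eqns)} equations ({saved:.1f}% fewer)")
    return p


prog = Program(
    invars=("x", "y"),
    outvars=("d",),
    eqns=(
        Eqn("mul", ("x", "y"), ("a",)),
        Eqn("mul", ("x", "y"), ("b",)),   # duplicate of a
        Eqn("add", ("a", "b"), ("d",)),
        Eqn("sin", ("x",), ("unused",)),  # dead
    ),
)
opt = optimize(prog, verbose=True)
print([e.prim for e in opt.eqns])  # ['mul', 'add']
```

Running CSE before DCE lets DCE clean up anything CSE orphans; the second round removes nothing, which is what ends the loop.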

JAX Integration

  • JaxprEqn Construction: Proper handling of required ctx parameter

    • Uses JaxprEqnContext(None, True) for identity equations
    • Ensures compatibility with JAX internal API
    • Maintains proper equation structure
  • Primitive Handling: Special cases for JAX primitives

    • Blacklist for primitives that shouldn't be optimized
    • Proper parameter extraction and validation
    • Support for effects and source_info fields
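As a sketch of how such a blacklist can sit inside constant folding, the following toy pass (hypothetical `Var`, `Lit`, `Eqn`, and `EVAL` rule table; not brainstate's code) folds equations whose inputs are all literals, except primitives listed in `NO_FOLD`. One plausible reason to exclude broadcasts is that folding them would embed potentially large constant arrays in the graph. Folded function outputs are re-materialized with an identity equation so the outvars stay defined.

```python
import math
from dataclasses import dataclass


class Var:
    """Hypothetical jaxpr variable with identity semantics."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


@dataclass(frozen=True)
class Lit:
    value: object  # toy literal constant


@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list


# Hypothetical evaluation rules for a few primitives.
EVAL = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "sin": math.sin,
    "broadcast_in_dim": lambda a, n: [a] * n,
}
# Never folded, even though EVAL could compute them.
NO_FOLD = {"broadcast_in_dim", "broadcast"}


def constant_fold(outvars, eqns):
    known = {}  # folded Var -> Lit holding its value

    def lookup(v):
        return v if isinstance(v, Lit) else known.get(v, v)

    new_eqns = []
    for eqn in eqns:
        ins = [lookup(v) for v in eqn.invars]
        if (eqn.prim in EVAL and eqn.prim not in NO_FOLD
                and all(isinstance(v, Lit) for v in ins)):
            known[eqn.outvars[0]] = Lit(EVAL[eqn.prim](*(v.value for v in ins)))
            continue  # evaluated now, so the equation is dropped
        new_eqns.append(Eqn(eqn.prim, ins, eqn.outvars))
    # Interface preservation: re-define any folded output with an identity equation.
    for out in outvars:
        if out in known:
            new_eqns.append(Eqn("convert_element_type", [known[out]], [out]))
    return new_eqns


x = Var("x")
c, d, s, out = Var("c"), Var("d"), Var("s"), Var("out")
eqns = [
    Eqn("mul", [Lit(2.0), Lit(3.0)], [c]),     # c = 6.0: folded
    Eqn("broadcast_in_dim", [c, Lit(4)], [d]),  # blacklisted: kept
    Eqn("add", [x, c], [s]),                    # depends on input x: kept
    Eqn("sin", [Lit(0.0)], [out]),              # out = 0.0: folded
]
folded = constant_fold([d, s, out], eqns)
print([q.prim for q in folded])  # ['broadcast_in_dim', 'add', 'convert_element_type']
```

Note that the blacklisted broadcast still benefits from folding upstream: its input `c` is replaced by the literal 6.0.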

Bug Fixes

  • Fixed JaxprEqn constructor calls to include required ctx parameter (7th positional argument)
  • Corrected import paths for vmap2 in test files and tutorials
  • Fixed RandomState.uniform() calls to use the size parameter instead of shape
  • Enhanced test assertions for proper state axis handling
  • Improved error messages for batch axis mismatches

Refactoring

Transform Module

  • Renamed Files:

    • vmap → vmap2 in _mapping.py
    • Preserved original vmap in _mapping_old.py for compatibility
  • Module Exports: Updated __init__.py to export both old and new vmap implementations

    • vmap from _mapping_old.py (legacy)
    • vmap2 from _mapping.py (new)
    • vmap_new_states from both modules

What's Changed

  • Introduce JAXPR optimizations and enhance stateful mapping by @chaoming0625 in #108

Full Changelog: v0.2.2...v0.2.3