Version 0.2.9
This release introduces a powerful state hook system for advanced state management, refactors neural network modules with enhanced parameter handling, and improves delay mechanisms with frequency-controlled updates.
State Management
State Hook System
Global Hook Infrastructure: Comprehensive hook system for intercepting state operations
- register_read_hook: Register hooks that execute when state values are read
- register_write_hook: Register hooks that execute when state values are written
- register_restore_hook: Register hooks that execute when state values are restored
- HookManager: Thread-safe manager for organizing and executing hooks with priority support
- HookContext: Context manager for scoped hook registration and execution
- Enables advanced use cases: logging, debugging, value transformation, validation
Enhanced State Class: Improved state management with hook integration
- Automatic hook execution on read/write operations
- Better cache key handling for improved performance
- Enhanced thread safety and context management
- Comprehensive test coverage (346 tests for thread safety, 320 tests for hooks)
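To make the hook flow concrete, here is a minimal, self-contained sketch of a priority-ordered, thread-safe hook registry wrapped around a toy state object, plus a context manager for scoped registration. All names (ToyHookManager, ToyState, scoped_hook) are hypothetical and do not reflect the actual signatures of brainstate's HookManager, register_read_hook, register_write_hook, or HookContext; it assumes each hook receives the value and returns a possibly transformed value.

```python
import threading
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Tuple

Hook = Callable[[Any], Any]


class ToyHookManager:
    """Hypothetical priority-ordered hook registry (not brainstate's HookManager)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # event name -> list of (priority, registration order, hook)
        self._hooks: Dict[str, List[Tuple[int, int, Hook]]] = {}
        self._counter = 0

    def register(self, event: str, hook: Hook, priority: int = 0) -> Hook:
        with self._lock:
            self._hooks.setdefault(event, []).append((priority, self._counter, hook))
            self._counter += 1
        return hook

    def unregister(self, event: str, hook: Hook) -> None:
        with self._lock:
            entries = self._hooks.get(event, [])
            self._hooks[event] = [e for e in entries if e[2] is not hook]

    def run(self, event: str, value: Any) -> Any:
        # Take a sorted snapshot under the lock, then call hooks without holding it.
        with self._lock:
            entries = sorted(self._hooks.get(event, []), key=lambda e: (-e[0], e[1]))
        for _, _, hook in entries:
            value = hook(value)  # higher priority first; each hook may transform the value
        return value


class ToyState:
    """Toy state whose reads and writes pass through the hook manager."""

    def __init__(self, value: Any, manager: ToyHookManager) -> None:
        self._manager = manager
        self._value = value  # the initial value bypasses the write hooks

    @property
    def value(self) -> Any:
        return self._manager.run("read", self._value)

    @value.setter
    def value(self, new: Any) -> None:
        self._value = self._manager.run("write", new)


@contextmanager
def scoped_hook(manager: ToyHookManager, event: str, hook: Hook,
                priority: int = 0) -> Iterator[Hook]:
    """Register a hook only for the duration of a with-block."""
    manager.register(event, hook, priority)
    try:
        yield hook
    finally:
        manager.unregister(event, hook)


log: List[Any] = []


def log_hook(v: Any) -> Any:
    log.append(v)
    return v


manager = ToyHookManager()
state = ToyState(1.0, manager)
manager.register("write", lambda v: max(v, 0.0), priority=10)  # validation: clamp negatives
manager.register("write", log_hook, priority=0)                 # logging sees the clamped value
state.value = -3.0
state.value = 2.0
with scoped_hook(manager, "read", lambda v: v * 2):
    inside = state.value   # read hook active
outside = state.value      # read hook removed on exit
```

Running hooks on a snapshot taken under the lock keeps registration thread-safe without holding the lock while user code runs, at the cost of a hook registered mid-call only taking effect on the next operation.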
Neural Network Components
Parameter Management (brainstate.nn.Param and brainstate.nn.Const)
Renamed Classes: Simplified naming convention
- ParaM → Param: Trainable parameter wrapper
- ConstM → Const: Non-trainable constant wrapper
Enhanced Caching System: Improved parameter precomputation and caching
- param_precompute context manager for efficient parameter transformation caching
- cache() method for retrieving cached parameter values
- Support for custom precompute functions
- Automatic cache invalidation and management
- 391 comprehensive tests for caching behavior
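The precompute-and-invalidate idea can be illustrated with a small sketch: a parameter wrapper applies a transform on first access, reuses the cached result, and discards it whenever the raw value is written. CachedParam, cached_value, and the choice of softplus as the transform are hypothetical and are not brainstate's Param, param_precompute, or cache() API.

```python
import math
from typing import Callable


class CachedParam:
    """Hypothetical parameter wrapper with a lazily computed, invalidated cache."""

    def __init__(self, raw: float, transform: Callable[[float], float]) -> None:
        self._raw = raw
        self._transform = transform
        self._cached = 0.0
        self._valid = False
        self.compute_count = 0  # how many times the transform actually ran

    @property
    def raw(self) -> float:
        return self._raw

    @raw.setter
    def raw(self, value: float) -> None:
        self._raw = value
        self._valid = False  # any write invalidates the precomputed value

    def cached_value(self) -> float:
        if not self._valid:
            self._cached = self._transform(self._raw)
            self._valid = True
            self.compute_count += 1
        return self._cached


def softplus(x: float) -> float:
    # Maps any real number to a positive value, a common way to keep a parameter > 0.
    return math.log1p(math.exp(x))


p = CachedParam(0.0, softplus)
first = p.cached_value()   # runs the transform: log(2)
second = p.cached_value()  # served from the cache
p.raw = 1.0                # invalidates the cache
third = p.cached_value()   # recomputed from the new raw value
```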
Hierarchical Parameter Data (brainstate.nn.HiData): New module for structured parameter organization
- define_param_data() method for declaring hierarchical parameter structures
- Support for nested parameter groups
- Improved parameter surgery and manipulation
- Enhanced type hints and documentation
Module System Enhancements
ModuleMapper: New helper for vectorized module operations (formerly Vmap2Module)
- Simplified API for applying vmap2 to module methods
- Automatic state management for vectorized operations
- Consistent interface with Vmap2ModuleCaller
- Comprehensive documentation with usage examples
Enhanced Module Methods:
- parameters(): Iterate over all parameters in the module hierarchy
- named_parameters(): Iterate over parameters with their qualified names
- children(): Access direct child modules
- named_children(): Access child modules with names
- init_all_states(): Initialize states with additional keyword arguments
- Improved Sequential with extend() and insert() methods
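The traversal behind parameters(), named_parameters(), children(), and named_children() can be sketched as a recursive walk that builds dotted qualified names. ToyParam, ToyModule, and the attribute-scanning strategy are assumptions for illustration only; brainstate's Module discovers parameters through its own state and graph machinery.

```python
from typing import Any, Iterator, Tuple


class ToyParam:
    """Marker wrapper for a trainable value (hypothetical, not brainstate.nn.Param)."""

    def __init__(self, value: Any) -> None:
        self.value = value


class ToyModule:
    """Hypothetical module that discovers params and children from its attributes."""

    def named_children(self) -> Iterator[Tuple[str, "ToyModule"]]:
        for name, attr in vars(self).items():
            if isinstance(attr, ToyModule):
                yield name, attr

    def children(self) -> Iterator["ToyModule"]:
        for _, child in self.named_children():
            yield child

    def named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, ToyParam]]:
        for name, attr in vars(self).items():
            full = f"{prefix}{name}"
            if isinstance(attr, ToyParam):
                yield full, attr
            elif isinstance(attr, ToyModule):
                # Recurse into the child, extending the qualified-name prefix.
                yield from attr.named_parameters(prefix=f"{full}.")

    def parameters(self) -> Iterator[ToyParam]:
        for _, param in self.named_parameters():
            yield param


class Linear(ToyModule):
    def __init__(self, n_in: int, n_out: int) -> None:
        self.weight = ToyParam([[0.0] * n_out for _ in range(n_in)])
        self.bias = ToyParam([0.0] * n_out)


class MLP(ToyModule):
    def __init__(self) -> None:
        self.hidden = Linear(2, 4)
        self.out = Linear(4, 1)


model = MLP()
names = [name for name, _ in model.named_parameters()]
# names == ["hidden.weight", "hidden.bias", "out.weight", "out.bias"]
```

Because instance attribute dictionaries preserve insertion order, the names come out in declaration order.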
Delay Mechanisms
Frequency-Controlled Updates: Enhanced Delay class with flexible update strategies
- update_every parameter: Control how often delay buffers are updated
- Support for integer steps (update every N steps)
- Support for time-based updates with physical units (e.g., 1*ms)
- Automatic handling of unit conversions and validation
- Comprehensive tests covering various update strategies
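One plausible reading of update_every: an integer means "write to the buffer every N steps", while a time value is divided by the integration step dt and must resolve to a whole number of steps. The sketch below assumes that semantics, represents physical units as plain floats in milliseconds instead of a unit library, and uses hypothetical names; it is not the Delay class's implementation.

```python
from typing import List, Union


def steps_between_updates(update_every: Union[int, float], dt_ms: float,
                          is_time: bool) -> int:
    """Hypothetical helper: convert update_every into a whole number of steps."""
    if not is_time:
        if not isinstance(update_every, int) or update_every < 1:
            raise ValueError("step-based update_every must be a positive int")
        return update_every
    ratio = update_every / dt_ms  # e.g. 1.0 ms / 0.25 ms -> 4 steps
    steps = round(ratio)
    # Reject intervals that are not (numerically) a positive multiple of dt.
    if steps < 1 or abs(ratio - steps) > 1e-9:
        raise ValueError("time-based update_every must be a positive multiple of dt")
    return steps


class ThrottledRecorder:
    """Appends the input only on steps where step % n_steps == 0 (assumed semantics)."""

    def __init__(self, n_steps: int) -> None:
        self.n_steps = n_steps
        self.step = 0
        self.history: List[float] = []

    def update(self, x: float) -> None:
        if self.step % self.n_steps == 0:
            self.history.append(x)
        self.step += 1


n = steps_between_updates(1.0, dt_ms=0.25, is_time=True)
rec = ThrottledRecorder(n)
for i in range(10):
    rec.update(float(i))
# n == 4, so steps 0, 4 and 8 are recorded
```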
Unified Delay Implementation: Refactored delay mechanism
- Ring buffer implementation for efficient historical value storage
- Support for linear interpolation
- Better handling of multi-dimensional inputs
- Improved integration with neural network modules
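A ring buffer with linear interpolation can be sketched as follows: a fixed-size circular array holds the most recent values, and a fractional delay (in steps) is resolved by blending the two neighboring entries. RingDelay and its methods are hypothetical and work on scalars; the actual Delay operates on JAX arrays and multi-dimensional inputs.

```python
import math
from typing import List


class RingDelay:
    """Hypothetical fixed-capacity history buffer supporting fractional-step reads."""

    def __init__(self, max_delay_steps: int, init: float = 0.0) -> None:
        self.size = max_delay_steps + 1  # current value plus max_delay_steps of history
        self.buf: List[float] = [init] * self.size
        self.head = 0  # index of the most recently written value

    def push(self, x: float) -> None:
        self.head = (self.head + 1) % self.size
        self.buf[self.head] = x  # overwrites the oldest entry

    def at(self, steps_back: int) -> float:
        if not 0 <= steps_back < self.size:
            raise IndexError("delay exceeds buffer length")
        return self.buf[(self.head - steps_back) % self.size]

    def interp(self, delay_steps: float) -> float:
        """Linearly interpolate between the integer delays surrounding delay_steps."""
        lo = math.floor(delay_steps)
        frac = delay_steps - lo
        if frac == 0.0:
            return self.at(lo)
        return (1.0 - frac) * self.at(lo) + frac * self.at(lo + 1)


d = RingDelay(max_delay_steps=3)
for x in [1.0, 2.0, 3.0, 4.0, 5.0]:
    d.push(x)
# The buffer now holds the last four values: 5.0 (0 steps back) down to 2.0 (3 steps back).
```

Writes and integer reads are O(1) regardless of the delay length, which is the usual reason to prefer a ring buffer over shifting an array each step.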
Regularization
Comprehensive Regularization Module (brainstate.nn._regularization, 2840 lines):
- Complete suite of regularization techniques
- L1, L2, and elastic net regularization
- Dropout variants
- Weight decay and other parameter constraints
- 1261 tests for regularization functionality
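For reference, the L1, L2, and elastic-net penalties have standard closed forms; the sketch below writes them out over a flat list of weights. The function names, the l1_ratio mixing parameter, and the 1/2 factor on the L2 term are assumptions (libraries differ on these conventions), not brainstate's regularization API.

```python
from typing import Sequence


def l1_penalty(w: Sequence[float], lam: float) -> float:
    # lam * ||w||_1
    return lam * sum(abs(x) for x in w)


def l2_penalty(w: Sequence[float], lam: float) -> float:
    # 0.5 * lam * ||w||_2^2; the 0.5 makes the gradient equal to lam * w.
    return 0.5 * lam * sum(x * x for x in w)


def elastic_net_penalty(w: Sequence[float], lam: float, l1_ratio: float) -> float:
    """lam * (l1_ratio * ||w||_1 + (1 - l1_ratio) / 2 * ||w||_2^2)."""
    if not 0.0 <= l1_ratio <= 1.0:
        raise ValueError("l1_ratio must be in [0, 1]")
    return l1_penalty(w, lam * l1_ratio) + l2_penalty(w, lam * (1.0 - l1_ratio))


w = [1.0, -2.0, 0.5]
# ||w||_1 = 3.5 and ||w||_2^2 = 5.25, so with lam=1, l1_ratio=0.5:
# 0.5 * 3.5 + 0.25 * 5.25 = 3.0625
```

Setting l1_ratio to 1.0 or 0.0 recovers pure L1 or pure L2, respectively.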
Transform Module (brainstate.nn._transform, 1661 lines):
- Advanced parameter transformations
- Quantization support
- Normalization techniques
- Integration with caching system
- 452 comprehensive tests
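As an illustration of what a quantization transform typically does, here is a sketch of symmetric uniform "fake" quantization: values are scaled onto a signed integer grid, rounded, clipped, and mapped back to floats. The bit width, per-tensor max-abs scaling, and the function name are assumptions; this is not the _transform module's implementation.

```python
from typing import List, Sequence


def fake_quantize(x: Sequence[float], num_bits: int = 8) -> List[float]:
    """Hypothetical symmetric per-tensor quantize-dequantize helper."""
    qmax = 2 ** (num_bits - 1) - 1  # e.g. 127 for 8 bits, 3 for 3 bits
    max_abs = max((abs(v) for v in x), default=0.0)
    if max_abs == 0.0:
        return [0.0 for _ in x]
    scale = max_abs / qmax  # float step between adjacent integer levels
    out = []
    for v in x:
        q = round(v / scale)          # nearest integer level (ties round to even)
        q = max(-qmax, min(qmax, q))  # clip to the representable range
        out.append(q * scale)         # dequantize back to float
    return out


levels = fake_quantize([3.0, -1.2, 0.4, 2.6], num_bits=3)
# scale = 3.0 / 3 = 1.0, so values snap to whole numbers: [3.0, -1.0, 0.0, 3.0]
```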
Transformations
Vectorization and Parallelization
Mapping Function Refactoring: Reorganized mapping implementations
- Renamed _mapping.py → _mapping2.py (primary vmap2 implementation)
- Renamed _mapping_old.py → _mapping1.py (legacy vmap implementation)
- Added _mapping3.py: New pmap2 implementation for parallelization
- vmap2_new_states: Helper for creating new states in vectorized operations
- Relaxed return type requirements for more flexible mapping functions
Enhanced Documentation: Updated tutorials and API documentation
- Comprehensive vmap2 tutorial with practical examples
- Enhanced parallelization documentation for pmap2
- Updated state management guides
- Expanded gradient transformation documentation
Compatibility and Utilities
JAX Compatibility
- Enhanced JAX Integration: Improved compatibility with newer JAX versions
- Updated backend import for JAX version detection
- Enhanced get_aval function for JAX version compatibility
- Standardized jit_named_scope arguments
- Support for JAX 0.8.0+ in CI configuration
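Version-gated code paths like these usually compare parsed version tuples rather than raw strings. A minimal, dependency-free sketch, assuming dotted version strings whose non-numeric suffixes can simply be dropped; in practice the input would be jax.__version__, and parse_version and at_least are hypothetical helpers.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Hypothetical helper: '0.4.35.dev1' -> (0, 4, 35), '0.8.0rc1' -> (0, 8, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first component without a leading number
        parts.append(int(match.group()))
    return tuple(parts)


def at_least(version: str, minimum: str) -> bool:
    a, b = parse_version(version), parse_version(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that "0.8" and "0.8.0" compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b
```

Plain string comparison gets this wrong: "0.10.0" sorts before "0.8.0" lexicographically, while the tuple (0, 10, 0) correctly compares greater than (0, 8, 0).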
Utility Functions
Dataclass Support: Added is_dataclass utility function in brainstate.util.struct
- Robust dataclass type checking
- Better handling of dataclass-based structures
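A detail worth knowing when checking dataclass types: the standard library's dataclasses.is_dataclass returns True for both a dataclass type and its instances. The helpers below are a hypothetical sketch that separates the two cases, not the brainstate.util.struct implementation.

```python
import dataclasses
from typing import Any


def is_dataclass_instance(obj: Any) -> bool:
    """True only for instances of a dataclass, not for the class itself."""
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


def is_dataclass_type(obj: Any) -> bool:
    """True only for the dataclass type itself."""
    return isinstance(obj, type) and dataclasses.is_dataclass(obj)


@dataclasses.dataclass
class Point:
    x: float
    y: float


pt = Point(1.0, 2.0)
```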
Tracer Utilities: New _tracers.py module for JAX tracer handling
- current_jax_trace(): Get current JAX trace context with version compatibility
- Helper functions for working with JAX abstract values
Graph Operations
Context Management (brainstate.graph._context):
- New context management system for graph operations (119 lines)
- TraceContextError: Specialized error class for tracing issues
- Enhanced state tracking during graph construction
- 64 tests for context management
Conversion Utilities (brainstate.graph._convert):
- New conversion utilities for graph operations (278 lines)
- Better handling of graph transformations
- Improved node conversion logic
Random Number Generation
- Enhanced RandomState: Improved random number generation
- Better compatibility with newer JAX versions (98 lines of improvements)
- Enhanced state management for random keys
- Improved thread safety
- Better error messages and validation
Documentation
Comprehensive API Documentation: Expanded documentation across all modules
- brainstate.rst: Reorganized with improved structure (21 lines removed, refactored into submodules)
- environ.rst: Added 48 lines of documentation for environment state and keys
- nn.rst: Added 222 lines documenting neural network components
- transform.rst: Added 132 lines for gradient transformations and mapping functions
Tutorial Updates:
- Updated vectorization tutorial to reflect the vmap → vmap2 transition
- Enhanced examples with ModuleMapper usage
- Improved state management examples
Breaking Changes
Renamed Functions and Classes:
- ParaM → Param
- ConstM → Const
- vmap → vmap2 (old vmap preserved in _mapping1.py for compatibility)
- pmap → pmap2
- _param_data → _hidata
Parameter Naming Standardization:
- fit_par → fit across all modules
- brainscale → braintrace in example files
Method Signature Changes:
- init_all_states() now accepts additional keyword arguments
- param_precompute() signature updated to support caching and custom functions
- Module initialization methods enhanced with keyword argument support
Testing
- Comprehensive Test Coverage: Added 4,000+ lines of new tests
- Thread safety tests: 346 tests ensuring thread-safe operations
- Hook system tests: 320 tests for state hooks
- State management tests: 924 tests expanded coverage
- Parameter caching tests: 391 tests for caching behavior
- Delay mechanism tests: 244 tests for delay functionality
- HiData tests: 463 tests for hierarchical data structures
- Module tests: 661 tests expanded coverage
- Regularization tests: 1,261 tests
- Transform tests: 452 tests
- Mapping tests: Updated for vmap2 and pmap2
Bug Fixes
- Fixed cache key handling in state management
- Improved error messages for missing states in gradient transformations
- Enhanced validation for delay update frequency
- Corrected import paths for better module organization
- Fixed compatibility issues with JAX 0.8.0+
Internal Changes
- Reorganized import statements across all modules for clarity
- Enhanced type hints throughout the codebase
- Improved code documentation with comprehensive docstrings
- Streamlined module exports in __all__ definitions
- Better separation of concerns in module organization
What's Changed
- Enhance random utils and dataclass helpers for newer JAX by @chaoming0625 in #126
- Add State hook system and refactor nn modules and transforms by @chaoming0625 in #127
- Update vectorization docs for vmap2 and relax mapping return type by @chaoming0625 in #128
- Refactor Param and delay APIs and add ModuleMapper/pmap2 helpers by @chaoming0625 in #129
- Enhance Delay with frequency-controlled updates and unit-aware timing by @chaoming0625 in #130
Full Changelog: v0.2.8...v0.2.9