Version 0.2.2
This release focuses on enhancing hidden state management for recurrent neural networks and eligibility trace-based learning, along with comprehensive testing and documentation improvements.
New Features
Hidden State Classes
- HiddenGroupState: New class for managing multiple hidden states within a single array
  - Stores multiple states in the last dimension of a single array
  - Provides get_value() and set_value() methods for accessing individual states by index or name
  - Optimized for LSTM-style architectures with multiple hidden components (h, c)
  - Includes name2index mapping for convenient state access
- HiddenTreeState: New class for managing multiple hidden states with different physical units
  - Supports PyTree structure (dict or sequence) of hidden states
  - Preserves physical units (e.g., voltage, current, conductance) via brainunit integration
  - Provides name2unit and index2unit mappings for unit tracking
  - Ideal for neuroscience models with heterogeneous state variables
  - Maintains compatibility with BrainScale online learning
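To make the HiddenGroupState layout concrete, here is a minimal NumPy sketch of the idea: several hidden components share one array, stacked along the last dimension, and are addressed by index or by name. The class `GroupedHidden` and its constructor are hypothetical stand-ins written for illustration and are not the brainstate implementation; only the names get_value, set_value, and name2index come from these notes.

```python
import numpy as np


class GroupedHidden:
    """Hypothetical stand-in: several hidden states packed into one array.

    Each state occupies one slot of the LAST dimension, so an LSTM-style
    pair (h, c) with shape (batch, features) is stored as
    (batch, features, 2).
    """

    def __init__(self, **named_states):
        self.names = list(named_states)
        # name -> position along the last axis
        self.name2index = {name: i for i, name in enumerate(self.names)}
        self.value = np.stack(
            [np.asarray(v) for v in named_states.values()], axis=-1
        )

    def _index(self, key):
        # Accept either an integer position or a state name.
        return self.name2index[key] if isinstance(key, str) else int(key)

    def get_value(self, key):
        return self.value[..., self._index(key)]

    def set_value(self, key, new):
        # NumPy arrays can be updated in place. JAX arrays are immutable,
        # so the JAX equivalent is the functional update
        # `value = value.at[..., i].set(new)`.
        self.value[..., self._index(key)] = new


state = GroupedHidden(h=np.zeros((4, 8)), c=np.ones((4, 8)))
state.set_value("h", np.full((4, 8), 0.5))
```

Packing the components into one array means a single state object carries both h and c through a transformation, while name-based access keeps call sites readable.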
State Utilities
- maybe_state: New utility function for flexible value extraction
  - Extracts values from State objects automatically
  - Returns non-State values unchanged
  - Simplifies writing functions that accept both states and raw values
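The behavior described for maybe_state can be sketched in a few lines. This is an assumption-based illustration, not the library's source: the `State` class below is a hypothetical minimal stand-in with a single `value` attribute, and the function body is inferred from the three bullet points above.

```python
class State:
    """Hypothetical minimal stand-in for a state container."""

    def __init__(self, value):
        self.value = value


def maybe_state(x):
    """Return the wrapped value for State objects, anything else unchanged."""
    return x.value if isinstance(x, State) else x


def scale(x, factor=2.0):
    # Works with both a State and a raw value; no branching at the call site.
    return maybe_state(x) * factor


print(scale(State(3.0)))  # 6.0
print(scale(3.0))         # 6.0
```

The helper function `scale` is also hypothetical; it only shows how one signature can serve callers that hold a state and callers that hold a plain value.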
Enhancements
State Classes
- HiddenState: Enhanced documentation and type checking
  - Restricted to numpy.ndarray, jax.Array, and brainunit.Quantity types only
  - Added comprehensive docstrings with examples
  - Clarified equivalence to brainscale.ETraceState for online learning
  - Improved error messages for invalid input types
- BatchState: Now properly exported in the public API
  - Available via brainstate.BatchState
  - Enhanced documentation for batch data management
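The type restriction on HiddenState can be pictured as an isinstance check that fails early with a descriptive message. The validator below is a hypothetical sketch, not the library's code. To stay runnable without extra dependencies it accepts only numpy.ndarray, whereas the notes above list jax.Array and brainunit.Quantity as accepted types too.

```python
import numpy as np

# Assumption: the real accepted set is numpy.ndarray, jax.Array, and
# brainunit.Quantity. Only numpy is used here to keep the sketch standalone.
ALLOWED_TYPES = (np.ndarray,)


def check_hidden_value(value):
    """Hypothetical validator: reject anything that is not an allowed type."""
    if not isinstance(value, ALLOWED_TYPES):
        allowed = ", ".join(t.__name__ for t in ALLOWED_TYPES)
        raise TypeError(
            f"Hidden state value must be one of ({allowed}), "
            f"got {type(value).__name__}."
        )
    return value


check_hidden_value(np.zeros(3))  # accepted

try:
    check_hidden_value([0.0, 0.0, 0.0])  # a plain list is rejected
except TypeError as err:
    message = str(err)
```

Naming both the expected types and the received type in the message is one way to achieve the "improved error messages" goal; the exact wording used by the library may differ.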
Documentation
- API Reference: Completely reorganized brainstate.rst documentation
  - Organized into 6 major sections, including Core State Classes, State Management, State Utilities, Error Handling, and Submodules
  - Added detailed descriptions for each section and subsection
  - Included comprehensive bullet-point summaries for all APIs
  - Enhanced deprecation warnings with clear migration paths
  - Added module-level descriptions for all submodules
- State Classes: Enhanced documentation for all state types
  - Added detailed use case descriptions
  - Included practical examples for each state type
  - Clarified semantic distinctions between state types
  - Documented integration with JAX transformations
- JAX Transformations: Improved documentation for stateful transforms
  - Enhanced docstrings for jit, grad, vmap, scan, and other transforms
  - Added examples showing state management patterns
  - Documented state tracing behavior
  - Clarified interaction with StateTraceStack
Transform System
- Enhanced State Finding: New _find_state.py module for automatic state discovery
  - Improved state detection in nested structures
  - Better handling of state dependencies
  - Enhanced error messages for state-related issues
- StatefulFunction: Major enhancements to make_jaxpr functionality
  - Improved Jaxpr generation for stateful computations
  - Better handling of state read/write tracking
  - Enhanced debugging support
- Mapping Transformations: Significant refactoring of vmap and pmap
  - Improved state management across vectorized operations
  - Better handling of state broadcasting
  - Enhanced error reporting for mapping operations
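As an illustration of what state detection in nested structures involves, the sketch below walks dicts, lists, and tuples recursively and records the path to every state it meets. It is a simplified, hypothetical example and does not reproduce the _find_state.py module; the `State` class and the `find_states` function are stand-ins invented for this sketch.

```python
class State:
    """Hypothetical minimal stand-in for a state container."""

    def __init__(self, value, name):
        self.value = value
        self.name = name


def find_states(tree, path=()):
    """Yield (path, state) pairs for every State found in nested containers."""
    if isinstance(tree, State):
        yield path, tree
    elif isinstance(tree, dict):
        for key, sub in tree.items():
            yield from find_states(sub, path + (key,))
    elif isinstance(tree, (list, tuple)):
        for i, sub in enumerate(tree):
            yield from find_states(sub, path + (i,))
    # Any other leaf (numbers, strings, None, ...) holds no state and is skipped.


model = {
    "encoder": [State(0.0, "w1"), State(1.0, "w2")],
    "decoder": {"cell": (State(2.0, "h"), 3.14)},
    "step": 10,
}

found = {path: s.name for path, s in find_states(model)}
```

Keeping the path alongside each state is a design choice that makes error messages more specific, since a problem can be reported against a location such as ("decoder", "cell", 0) rather than against an anonymous object.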
Random Number Generation
- Module Reorganization: Complete refactoring of random module structure
  - Renamed _rand_funs.py to _fun.py
  - Renamed _rand_seed.py to _seed.py
  - Renamed _rand_state.py to _state.py
  - Extracted distribution implementations to new _impl.py module (691 lines)
- Improved Random State: Enhanced RandomState class with better state management
  - Simplified implementation (reduced from 534 to ~300 lines)
  - Better integration with JAX's random number generation
  - Improved thread safety and state isolation
Testing
- Comprehensive Test Suite: Added 102 tests covering all state functionality
  - TestBasicState (13 tests): Core State class operations
  - TestShortTermState (2 tests): Short-term state behavior
  - TestLongTermState (2 tests): Long-term state behavior
  - TestParamState (2 tests): Parameter state usage patterns
  - TestBatchState (2 tests): Batch state functionality
  - TestHiddenState (7 tests): Hidden state with different array types
  - TestHiddenGroupState (9 tests): Multiple hidden state management
  - TestHiddenTreeState (12 tests): PyTree hidden states with units
  - TestFakeState (4 tests): Lightweight state alternative
  - TestStateDictManager (6 tests): State collection management
  - TestStateTraceStack (11 tests): State tracing and recovery
  - TestTreefyState (6 tests): PyTree state references
  - TestContextManagers (6 tests): State context managers
  - TestStateCatcher (8 tests): State catching utilities
  - TestIntegrationScenarios (5 tests): Real-world use cases
Bug Fixes
- Fixed HiddenGroupState.set_value() to work correctly with JAX arrays
- Improved error handling in hidden state value validation
- Enhanced type checking for hidden state initialization
Documentation
Tutorial Reorganization
- Basics Tutorials: Complete rewrite and expansion
  - 01_getting_started.ipynb: Enhanced introduction with practical examples
  - 02_state_management.ipynb: Comprehensive state management guide
  - 03_random_numbers.ipynb: In-depth random number generation tutorial
- Neural Networks Tutorials: Restructured and expanded
  - 01_module_basics.ipynb: New comprehensive module system guide
  - 02_basic_layers.ipynb: Enhanced layer documentation with examples
  - 03_activations_normalization.ipynb: Detailed activation and normalization guide
  - 04_recurrent_networks.ipynb: New RNN tutorial with practical examples
  - 05_dynamics_systems.ipynb: New dynamical systems tutorial
- Examples: Reorganized and enhanced
  - Renamed 10_image_classification.ipynb to 01_image_classification.ipynb
  - Renamed 11_sequence_modeling.ipynb to 02_sequence_modeling.ipynb
  - Added 03_brain_inspired_computing.ipynb: New brain-inspired computing examples
  - Renamed 18_optimization_tricks.ipynb to 04_optimization_tricks.ipynb
  - Renamed 19_model_deployment.ipynb to 05_model_deployment.ipynb
- Transforms Tutorials: Reorganized for better flow
  - 01_jit_compilation.ipynb: New comprehensive JIT guide
  - 02_automatic_differentiation.ipynb: Enhanced autodiff tutorial
  - 03_vectorization.ipynb: Improved vmap/pmap guide
  - 04_loops_conditions.ipynb: Enhanced control flow guide
  - 05_other_transforms.ipynb: Other transformation utilities
- Advanced Tutorials: Renumbered for clarity
  - 01_graph_operations.ipynb (formerly 14_graph_operations.ipynb)
  - 02_mixin_system.ipynb (formerly 15_mixin_system.ipynb)
  - 03_typing_system.ipynb (formerly 16_typing_system.ipynb)
  - 04_utilities.ipynb (formerly 17_utilities.ipynb)
- Migration Guides: Updated and simplified
  - 01_migration_from_pytorch.ipynb: Enhanced PyTorch migration guide
  - Removed outdated BrainPy integration notebook
- Supplementary: Reorganized
  - 01_performance_optimization.ipynb
  - 02_debugging_tips.ipynb
  - 03_faq.ipynb: Updated FAQ with new content
API Documentation
- Enhanced module documentation in nn.rst with 306 lines of improvements
- Updated transform.rst with new transform APIs
- Improved environ.rst and graph.rst documentation
Refactoring
- Removed deprecated eval_shape module and tests
- Removed deprecated _random.py transform module
- Cleaned up unused imports across all modules
- Improved code organization in neural network layers
- Enhanced type hints and docstrings throughout
Infrastructure
- Added development dependency for tutorial generation
- Updated benchmark scripts for performance testing
- Improved test coverage across transformation modules
What's Changed
- update logo by @xinzhu-L in #102
- Refactor random API: Extract distributions and rename modules by @chaoming0625 in #103
- Enhance stateful JAX transforms and update tutorials by @chaoming0625 in #104
- Updates by @oujago in #105
- Docs by @oujago in #106
- Enhance HiddenState and add HiddenGroupState, HiddenTreeState by @chaoming0625 in #107
New Contributors
Full Changelog: v0.2.0...v0.2.2