Version 0.2.2

@chaoming0625 chaoming0625 released this 12 Oct 15:36
· 132 commits to main since this release
74c0b31

This release focuses on enhancing hidden state management for recurrent neural networks and eligibility trace-based learning, along with comprehensive testing and documentation improvements.

New Features

Hidden State Classes

  • HiddenGroupState: New class for managing multiple hidden states within a single array

    • Stores multiple states in the last dimension of a single array
    • Provides get_value() and set_value() methods for accessing individual states by index or name
    • Optimized for LSTM-style architectures with multiple hidden components (h, c)
    • Includes name2index mapping for convenient state access
  • HiddenTreeState: New class for managing multiple hidden states with different physical units

    • Supports PyTree structure (dict or sequence) of hidden states
    • Preserves physical units (e.g., voltage, current, conductance) via brainunit integration
    • Provides name2unit and index2unit mappings for unit tracking
    • Ideal for neuroscience models with heterogeneous state variables
    • Maintains compatibility with BrainScale online learning
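
The packing scheme described for HiddenGroupState (several same-shaped states stored along the last axis of one array, addressed by index or name) can be illustrated with a small stand-alone sketch. This is not the brainstate implementation: the class name `GroupedHiddenSketch`, its constructor, and the `get`/`set` method signatures are hypothetical and chosen only for illustration, and NumPy is used in place of JAX arrays.

```python
import numpy as np


class GroupedHiddenSketch:
    """Toy stand-in (NOT the brainstate class): states packed along the last axis."""

    def __init__(self, **named_states):
        names = list(named_states)
        # Map each state name to its slot in the last dimension.
        self.name2index = {name: i for i, name in enumerate(names)}
        # All states must share one shape; stack them into a single array.
        self.value = np.stack(
            [np.asarray(named_states[n], dtype=float) for n in names], axis=-1
        )

    def _slot(self, key):
        # Accept either a state name or an integer index.
        return self.name2index[key] if isinstance(key, str) else key

    def get(self, key):
        # Slice one state out of the packed array.
        return self.value[..., self._slot(key)]

    def set(self, key, new_value):
        # Build a fresh array instead of mutating the old one, similar in
        # spirit to how immutable arrays have to be updated.
        updated = self.value.copy()
        updated[..., self._slot(key)] = new_value
        self.value = updated


# LSTM-style pair of states (h, c), each of shape (batch=2, features=4).
state = GroupedHiddenSketch(h=np.zeros((2, 4)), c=np.ones((2, 4)))
state.set("h", np.full((2, 4), 0.5))
```

Keeping both components in one array means a single object carries the whole hidden state, while the name-to-index mapping keeps call sites readable.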

State Utilities

  • maybe_state: New utility function for flexible value extraction
    • Extracts values from State objects automatically
    • Returns non-State values unchanged
    • Simplifies writing functions that accept both states and raw values
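
The behavior listed above (unwrap a State, pass anything else through) can be sketched as follows. The `State` class here is a toy stand-in with a single `value` attribute, and `maybe_state_sketch` is a hypothetical re-creation for illustration, not the library's own code.

```python
class State:
    """Toy container standing in for a state object that wraps a value."""

    def __init__(self, value):
        self.value = value


def maybe_state_sketch(x):
    # Unwrap State instances; return every other value unchanged.
    return x.value if isinstance(x, State) else x


def scale(x, factor=2.0):
    # Works the same whether the caller passes a State or a raw number.
    return maybe_state_sketch(x) * factor


from_state = scale(State(3.0))
from_raw = scale(3.0)
```

With a helper of this shape, a function such as `scale` needs no branching on its input type.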

Enhancements

State Classes

  • HiddenState: Enhanced documentation and type checking

    • Accepts only numpy.ndarray, jax.Array, and brainunit.Quantity values
    • Added comprehensive docstrings with examples
    • Clarified equivalence to brainscale.ETraceState for online learning
    • Improved error messages for invalid input types
  • BatchState: Now properly exported in the public API

    • Available via brainstate.BatchState
    • Enhanced documentation for batch data management

Documentation

  • API Reference: Completely reorganized brainstate.rst documentation

    • Organized into major sections, including Core State Classes, State Management, State Utilities, Error Handling, and Submodules
    • Added detailed descriptions for each section and subsection
    • Included comprehensive bullet-point summaries for all APIs
    • Enhanced deprecation warnings with clear migration paths
    • Added module-level descriptions for all submodules
  • State Classes: Enhanced documentation for all state types

    • Added detailed use case descriptions
    • Included practical examples for each state type
    • Clarified semantic distinctions between state types
    • Documented integration with JAX transformations
  • JAX Transformations: Improved documentation for stateful transforms

    • Enhanced docstrings for jit, grad, vmap, scan, and other transforms
    • Added examples showing state management patterns
    • Documented state tracing behavior
    • Clarified interaction with StateTraceStack

Transform System

  • Enhanced State Finding: New _find_state.py module for automatic state discovery

    • Improved state detection in nested structures
    • Better handling of state dependencies
    • Enhanced error messages for state-related issues
  • StatefulFunction: Major enhancements to make_jaxpr functionality

    • Improved Jaxpr generation for stateful computations
    • Better handling of state read/write tracking
    • Enhanced debugging support
  • Mapping Transformations: Significant refactoring of vmap and pmap

    • Improved state management across vectorized operations
    • Better handling of state broadcasting
    • Enhanced error reporting for mapping operations

Random Number Generation

  • Module Reorganization: Complete refactoring of random module structure

    • Renamed _rand_funs.py to _fun.py
    • Renamed _rand_seed.py to _seed.py
    • Renamed _rand_state.py to _state.py
    • Extracted distribution implementations to new _impl.py module (691 lines)
  • Improved Random State: Enhanced RandomState class with better state management

    • Simplified implementation (reduced from 534 to ~300 lines)
    • Better integration with JAX's random number generation
    • Improved thread safety and state isolation

Testing

  • Comprehensive Test Suite: Added 102 tests covering all state functionality
    • TestBasicState (13 tests): Core State class operations
    • TestShortTermState (2 tests): Short-term state behavior
    • TestLongTermState (2 tests): Long-term state behavior
    • TestParamState (2 tests): Parameter state usage patterns
    • TestBatchState (2 tests): Batch state functionality
    • TestHiddenState (7 tests): Hidden state with different array types
    • TestHiddenGroupState (9 tests): Multiple hidden state management
    • TestHiddenTreeState (12 tests): PyTree hidden states with units
    • TestFakeState (4 tests): Lightweight state alternative
    • TestStateDictManager (6 tests): State collection management
    • TestStateTraceStack (11 tests): State tracing and recovery
    • TestTreefyState (6 tests): PyTree state references
    • TestContextManagers (6 tests): State context managers
    • TestStateCatcher (8 tests): State catching utilities
    • TestIntegrationScenarios (5 tests): Real-world use cases

Bug Fixes

  • Fixed HiddenGroupState.set_value() to work correctly with JAX arrays
  • Improved error handling in hidden state value validation
  • Enhanced type checking for hidden state initialization

Documentation

Tutorial Reorganization

  • Basics Tutorials: Complete rewrite and expansion

    • 01_getting_started.ipynb: Enhanced introduction with practical examples
    • 02_state_management.ipynb: Comprehensive state management guide
    • 03_random_numbers.ipynb: In-depth random number generation tutorial
  • Neural Networks Tutorials: Restructured and expanded

    • 01_module_basics.ipynb: New comprehensive module system guide
    • 02_basic_layers.ipynb: Enhanced layer documentation with examples
    • 03_activations_normalization.ipynb: Detailed activation and normalization guide
    • 04_recurrent_networks.ipynb: New RNN tutorial with practical examples
    • 05_dynamics_systems.ipynb: New dynamical systems tutorial
  • Examples: Reorganized and enhanced

    • Renamed 10_image_classification.ipynb to 01_image_classification.ipynb
    • Renamed 11_sequence_modeling.ipynb to 02_sequence_modeling.ipynb
    • Added 03_brain_inspired_computing.ipynb: New brain-inspired computing examples
    • Renamed 18_optimization_tricks.ipynb to 04_optimization_tricks.ipynb
    • Renamed 19_model_deployment.ipynb to 05_model_deployment.ipynb
  • Transforms Tutorials: Reorganized for better flow

    • 01_jit_compilation.ipynb: New comprehensive JIT guide
    • 02_automatic_differentiation.ipynb: Enhanced autodiff tutorial
    • 03_vectorization.ipynb: Improved vmap/pmap guide
    • 04_loops_conditions.ipynb: Enhanced control flow guide
    • 05_other_transforms.ipynb: Other transformation utilities
  • Advanced Tutorials: Renumbered for clarity

    • 01_graph_operations.ipynb (formerly 14_graph_operations.ipynb)
    • 02_mixin_system.ipynb (formerly 15_mixin_system.ipynb)
    • 03_typing_system.ipynb (formerly 16_typing_system.ipynb)
    • 04_utilities.ipynb (formerly 17_utilities.ipynb)
  • Migration Guides: Updated and simplified

    • 01_migration_from_pytorch.ipynb: Enhanced PyTorch migration guide
    • Removed outdated BrainPy integration notebook
  • Supplementary: Reorganized

    • 01_performance_optimization.ipynb
    • 02_debugging_tips.ipynb
    • 03_faq.ipynb: Updated FAQ with new content

API Documentation

  • Enhanced module documentation in nn.rst (306 lines of improvements)
  • Updated transform.rst with new transform APIs
  • Improved environ.rst and graph.rst documentation

Refactoring

  • Removed deprecated eval_shape module and tests
  • Removed deprecated _random.py transform module
  • Cleaned up unused imports across all modules
  • Improved code organization in neural network layers
  • Enhanced type hints and docstrings throughout

Infrastructure

  • Added development dependency for tutorial generation
  • Updated benchmark scripts for performance testing
  • Improved test coverage across transformation modules

Full Changelog: v0.2.0...v0.2.2