Version 0.2.9

@chaoming0625 released this 16 Jan 14:16 · 105 commits to main since this release · c509d95

This release introduces a powerful state hook system for advanced state management, refactors neural network modules with enhanced parameter handling, and improves delay mechanisms with frequency-controlled updates.

State Management

State Hook System

  • Global Hook Infrastructure: Comprehensive hook system for intercepting state operations

    • register_read_hook: Register hooks that execute when state values are read
    • register_write_hook: Register hooks that execute when state values are written
    • register_restore_hook: Register hooks that execute when state values are restored
    • HookManager: Thread-safe manager for organizing and executing hooks with priority support
    • HookContext: Context manager for scoped hook registration and execution
    • Enables advanced use cases: logging, debugging, value transformation, validation
  • Enhanced State Class: Improved state management with hook integration

    • Automatic hook execution on read/write operations
    • Better cache key handling for improved performance
    • Enhanced thread safety and context management
    • Comprehensive test coverage (346 tests for thread safety, 320 tests for hooks)
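
To illustrate the idea behind read and write hooks with priorities, here is a minimal, self-contained sketch in plain Python. It does not use the brainstate API: `TinyState`, `add_read_hook`, and `add_write_hook` are hypothetical names invented for this example, and the real signatures of `register_read_hook`, `register_write_hook`, and `HookManager` may differ.

```python
import threading


class TinyState:
    """Hypothetical stand-in for a state object; NOT brainstate.State."""

    def __init__(self, value):
        self._value = value
        self._lock = threading.Lock()
        self._read_hooks = []   # list of (priority, callable)
        self._write_hooks = []  # list of (priority, callable)

    def _add(self, hooks, fn, priority):
        with self._lock:
            hooks.append((priority, fn))
            # Higher priority runs first; the sort is stable for ties.
            hooks.sort(key=lambda item: -item[0])

    def add_read_hook(self, fn, priority=0):
        self._add(self._read_hooks, fn, priority)

    def add_write_hook(self, fn, priority=0):
        self._add(self._write_hooks, fn, priority)

    @property
    def value(self):
        with self._lock:
            hooks = list(self._read_hooks)
        for _, fn in hooks:
            fn(self._value)  # read hooks only observe the value
        return self._value

    @value.setter
    def value(self, new_value):
        with self._lock:
            hooks = list(self._write_hooks)
        for _, fn in hooks:
            new_value = fn(new_value)  # write hooks may validate or transform
        self._value = new_value


log = []


def clamp_non_negative(v):
    return max(v, 0.0)


def log_write(v):
    log.append(("write", v))
    return v


state = TinyState(1.0)
state.add_read_hook(lambda v: log.append(("read", v)))
state.add_write_hook(clamp_non_negative, priority=10)  # runs first
state.add_write_hook(log_write)                        # sees the clamped value

state.value = -3.0
current = state.value
```

In this sketch the clamping hook runs before the logging hook because of its higher priority, so the log records the transformed value rather than the raw one.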

Neural Network Components

Parameter Management (brainstate.nn.Param and brainstate.nn.Const)

  • Renamed Classes: Simplified naming convention

    • ParaM → Param: Trainable parameter wrapper
    • ConstM → Const: Non-trainable constant wrapper
  • Enhanced Caching System: Improved parameter precomputation and caching

    • param_precompute context manager for efficient parameter transformation caching
    • cache() method for retrieving cached parameter values
    • Support for custom precompute functions
    • Automatic cache invalidation and management
    • 391 comprehensive tests for caching behavior
  • Hierarchical Parameter Data (brainstate.nn.HiData): New module for structured parameter organization

    • define_param_data() method for declaring hierarchical parameter structures
    • Support for nested parameter groups
    • Improved parameter surgery and manipulation
    • Enhanced type hints and documentation
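
The caching behavior described above (compute a transformed parameter once inside a scope, then invalidate it on exit) can be sketched in plain Python. This is an illustration under assumptions, not the brainstate implementation: `TinyParam` and `precompute` are hypothetical names, and the real `param_precompute` and `cache()` signatures may differ.

```python
import math
from contextlib import contextmanager


class TinyParam:
    """Hypothetical parameter with a transform; NOT brainstate.nn.Param."""

    def __init__(self, raw, transform):
        self.raw = raw
        self.transform = transform
        self._cached = None
        self._caching = False
        self.n_calls = 0  # counts how often the transform actually runs

    def _apply(self):
        self.n_calls += 1
        return self.transform(self.raw)

    def value(self):
        if not self._caching:
            return self._apply()
        if self._cached is None:
            self._cached = self._apply()
        return self._cached


@contextmanager
def precompute(*params):
    """Cache transformed values inside the block; drop them on exit."""
    for p in params:
        p._caching = True
    try:
        yield
    finally:
        for p in params:
            p._caching = False
            p._cached = None  # invalidate so later updates to `raw` are seen


w = TinyParam(0.0, math.exp)  # exp keeps the effective value positive
with precompute(w):
    total = sum(w.value() for _ in range(100))
```

Inside the `with` block the transform runs once even though the value is read 100 times; outside it, every read recomputes the transform.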

Module System Enhancements

  • ModuleMapper: New helper for vectorized module operations (formerly Vmap2Module)

    • Simplified API for applying vmap2 to module methods
    • Automatic state management for vectorized operations
    • Consistent interface with Vmap2ModuleCaller
    • Comprehensive documentation with usage examples
  • Enhanced Module Methods:

    • parameters(): Iterate over all parameters in the module hierarchy
    • named_parameters(): Iterate over parameters with their qualified names
    • children(): Access direct child modules
    • named_children(): Access child modules with names
    • init_all_states(): Initialize states with additional keyword arguments
    • Improved Sequential with extend() and insert() methods
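
As a rough picture of what qualified names mean in `named_parameters()`, the sketch below walks a small module tree and joins names with dots. `TinyModule` is a hypothetical class written for this example; it is not `brainstate.nn.Module`, and the separator and traversal order used by the library are assumptions here.

```python
class TinyModule:
    """Hypothetical module tree; NOT brainstate.nn.Module."""

    def __init__(self, params=None, **children):
        self.params = dict(params or {})
        self._children = dict(children)

    def named_children(self):
        # Direct children only, paired with their attribute-style names.
        yield from self._children.items()

    def named_parameters(self, prefix=""):
        # Own parameters first, then recurse into children depth-first.
        for name, p in self.params.items():
            yield prefix + name, p
        for child_name, child in self.named_children():
            yield from child.named_parameters(prefix + child_name + ".")


net = TinyModule(
    {"scale": 1.0},
    encoder=TinyModule({"w": 0.5, "b": 0.0}),
    decoder=TinyModule({"w": 2.0}),
)
names = [name for name, _ in net.named_parameters()]
```

Here `names` contains `scale`, `encoder.w`, `encoder.b`, and `decoder.w`, in that order.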

Delay Mechanisms

  • Frequency-Controlled Updates: Enhanced Delay class with flexible update strategies

    • update_every parameter: Control how often delay buffers are updated
    • Support for integer steps (update every N steps)
    • Support for time-based updates with physical units (e.g., 1*ms)
    • Automatic handling of unit conversions and validation
    • Comprehensive tests covering various update strategies
  • Unified Delay Implementation: Refactored delay mechanism

    • Ring buffer implementation for efficient historical value storage
    • Support for linear interpolation
    • Better handling of multi-dimensional inputs
    • Improved integration with neural network modules
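
The combination of a ring buffer, an integer `update_every`, and linear interpolation can be sketched as follows. This is a simplified pure-Python illustration, not the library's `Delay` class: `TinyDelay`, `retrieve`, and `retrieve_interp` are hypothetical names, delays are measured in buffer slots, and time-based intervals with physical units are left out.

```python
class TinyDelay:
    """Hypothetical ring-buffer delay; NOT brainstate's Delay class."""

    def __init__(self, length, update_every=1, init=0.0):
        if length < 1 or update_every < 1:
            raise ValueError("length and update_every must be >= 1")
        self.buffer = [init] * length
        self.head = 0  # index of the most recently written slot
        self.update_every = update_every
        self.step = 0

    def update(self, value):
        # Write only on every `update_every`-th call; other calls are skipped.
        if self.step % self.update_every == 0:
            self.head = (self.head + 1) % len(self.buffer)
            self.buffer[self.head] = value
        self.step += 1

    def retrieve(self, delay):
        """Return the value stored `delay` writes ago (0 = most recent)."""
        if not 0 <= delay < len(self.buffer):
            raise IndexError("delay outside buffer range")
        return self.buffer[(self.head - delay) % len(self.buffer)]

    def retrieve_interp(self, delay):
        """Linearly interpolate between neighbouring slots."""
        lo = int(delay)
        frac = delay - lo
        if frac == 0:
            return self.retrieve(lo)
        return (1 - frac) * self.retrieve(lo) + frac * self.retrieve(lo + 1)


d = TinyDelay(length=4, update_every=2)
for t in range(8):
    d.update(float(t))  # only t = 0, 2, 4, 6 are written

latest = d.retrieve(0)
previous = d.retrieve(1)
halfway = d.retrieve_interp(0.5)
```

With `update_every=2`, the buffer holds the values written at even steps, so the most recent entry is 6.0, the one before it is 4.0, and a fractional delay of 0.5 interpolates to 5.0.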

Regularization

  • Comprehensive Regularization Module (brainstate.nn._regularization, 2840 lines):

    • Complete suite of regularization techniques
    • L1, L2, and elastic net regularization
    • Dropout variants
    • Weight decay and other parameter constraints
    • 1261 tests for regularization functionality
  • Transform Module (brainstate.nn._transform, 1661 lines):

    • Advanced parameter transformations
    • Quantization support
    • Normalization techniques
    • Integration with caching system
    • 452 comprehensive tests
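
For readers unfamiliar with the penalties named in the regularization list, the sketch below shows L1, L2, and one common elastic net parameterization in plain Python. The function names and the `strength` and `l1_ratio` arguments are assumptions made for this example; they are not taken from `brainstate.nn._regularization`, and other libraries scale the L2 term differently.

```python
def l1_penalty(weights):
    """Sum of absolute values; encourages sparsity."""
    return sum(abs(w) for w in weights)


def l2_penalty(weights):
    """Sum of squares; shrinks weights towards zero."""
    return sum(w * w for w in weights)


def elastic_net_penalty(weights, strength=1.0, l1_ratio=0.5):
    """strength * (l1_ratio * L1 + (1 - l1_ratio) * L2)."""
    if not 0.0 <= l1_ratio <= 1.0:
        raise ValueError("l1_ratio must lie in [0, 1]")
    mix = l1_ratio * l1_penalty(weights) + (1.0 - l1_ratio) * l2_penalty(weights)
    return strength * mix


weights = [1.0, -2.0, 0.0]
penalty = elastic_net_penalty(weights, strength=0.1, l1_ratio=0.5)
```

Setting `l1_ratio=1.0` recovers a pure L1 penalty and `l1_ratio=0.0` a pure L2 penalty, which is why elastic net is often described as an interpolation between the two.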

Transformations

Vectorization and Parallelization

  • Mapping Function Refactoring: Reorganized mapping implementations

    • Renamed _mapping.py → _mapping2.py (primary vmap2 implementation)
    • Renamed _mapping_old.py → _mapping1.py (legacy vmap implementation)
    • Added _mapping3.py: New pmap2 implementation for parallelization
    • vmap2_new_states: Helper for creating new states in vectorized operations
    • Relaxed return type requirements for more flexible mapping functions
  • Enhanced Documentation: Updated tutorials and API documentation

    • Comprehensive vmap2 tutorial with practical examples
    • Enhanced parallelization documentation for pmap2
    • Updated state management guides
    • Expanded gradient transformation documentation

Compatibility and Utilities

JAX Compatibility

  • Enhanced JAX Integration: Improved compatibility with newer JAX versions
    • Updated backend import for JAX version detection
    • Enhanced get_aval function for JAX version compatibility
    • Standardized jit_named_scope arguments
    • Support for JAX 0.8.0+ in CI configuration

Utility Functions

  • Dataclass Support: Added is_dataclass utility function in brainstate.util.struct

    • Robust dataclass type checking
    • Better handling of dataclass-based structures
  • Tracer Utilities: New _tracers.py module for JAX tracer handling

    • current_jax_trace(): Get current JAX trace context with version compatibility
    • Helper functions for working with JAX abstract values

Graph Operations

  • Context Management (brainstate.graph._context):

    • New context management system for graph operations (119 lines)
    • TraceContextError: Specialized error class for tracing issues
    • Enhanced state tracking during graph construction
    • 64 tests for context management
  • Conversion Utilities (brainstate.graph._convert):

    • New conversion utilities for graph operations (278 lines)
    • Better handling of graph transformations
    • Improved node conversion logic

Random Number Generation

  • Enhanced RandomState: Improved random number generation
    • Better compatibility with newer JAX versions (98 lines of improvements)
    • Enhanced state management for random keys
    • Improved thread safety
    • Better error messages and validation

Documentation

  • Comprehensive API Documentation: Expanded documentation across all modules

    • brainstate.rst: Reorganized with improved structure (21 lines removed, refactored into submodules)
    • environ.rst: Added 48 lines of documentation for environment state and keys
    • nn.rst: Added 222 lines documenting neural network components
    • transform.rst: Added 132 lines for gradient transformations and mapping functions
  • Tutorial Updates:

    • Updated vectorization tutorial to reflect the vmap → vmap2 transition
    • Enhanced examples with ModuleMapper usage
    • Improved state management examples

Breaking Changes

  • Renamed Functions and Classes:

    • ParaM → Param
    • ConstM → Const
    • vmap → vmap2 (old vmap preserved in _mapping1.py for compatibility)
    • pmap → pmap2
    • _param_data → _hidata
  • Parameter Naming Standardization:

    • fit_par → fit across all modules
    • brainscale → braintrace in example files
  • Method Signature Changes:

    • init_all_states() now accepts additional keyword arguments
    • param_precompute() signature updated to support caching and custom functions
    • Module initialization methods enhanced with keyword argument support

Testing

  • Comprehensive Test Coverage: Added 4,000+ lines of new tests
    • Thread safety tests: 346 tests ensuring thread-safe operations
    • Hook system tests: 320 tests for state hooks
    • State management tests: 924 tests, expanding existing coverage
    • Parameter caching tests: 391 tests for caching behavior
    • Delay mechanism tests: 244 tests for delay functionality
    • HiData tests: 463 tests for hierarchical data structures
    • Module tests: 661 tests, expanding existing coverage
    • Regularization tests: 1,261 tests
    • Transform tests: 452 tests
    • Mapping tests: Updated for vmap2 and pmap2

Bug Fixes

  • Fixed cache key handling in state management
  • Improved error messages for missing states in gradient transformations
  • Enhanced validation for delay update frequency
  • Corrected import paths for better module organization
  • Fixed compatibility issues with JAX 0.8.0+

Internal Changes

  • Reorganized import statements across all modules for clarity
  • Enhanced type hints throughout the codebase
  • Improved code documentation with comprehensive docstrings
  • Streamlined module exports in __all__ definitions
  • Better separation of concerns in module organization

What's Changed

  • Enhance random utils and dataclass helpers for newer JAX by @chaoming0625 in #126
  • Add State hook system and refactor nn modules and transforms by @chaoming0625 in #127
  • Update vectorization docs for vmap2 and relax mapping return type by @chaoming0625 in #128
  • Refactor Param and delay APIs and add ModuleMapper/pmap2 helpers by @chaoming0625 in #129
  • Enhance Delay with frequency-controlled updates and unit-aware timing by @chaoming0625 in #130

Full Changelog: v0.2.8...v0.2.9