Version 0.2.10

@chaoming0625 released this 30 Jan 13:00 (commit 2019cae)

This release introduces a comprehensive NaN debugging system for gradient computations, refactors the module mapping API for improved clarity, and adds graph context utilities for advanced state management.

New Features

NaN Debugging System

  • JIT-Compatible NaN/Inf Debugging: New debugging utilities for identifying NaN and Inf values during gradient computations

    • debug_nan: Analyze a function for NaN/Inf values with detailed reporting
    • debug_nan_if: Conditional NaN debugging with predicate-based activation
    • Full JIT compatibility for seamless integration into compiled workflows
    • Support for debugging NaN in while and scan primitives
    • Detailed analysis output including variable names, shapes, and affected indices
  • Gradient Function Integration: Added debug_nan parameter to gradient transformation functions

    • grad: Enable NaN debugging during gradient computation
    • vector_grad: NaN debugging for vectorized gradients
    • jacobian and jacobian_reverse: NaN debugging for Jacobian computations
    • hessian: NaN debugging for Hessian computations
  • Breakpoint Utility: New breakpoint function for conditional debugging

    • Wraps jax.debug.breakpoint with predicate support
    • Only triggers when the specified condition is True
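The idea behind a JIT-compatible NaN check and a predicate-gated breakpoint can be sketched with plain JAX primitives. This is a minimal illustration, not the library's implementation: `check_finite`, `breakpoint_if`, `_record`, `reports`, `loss`, and `checked_grad` are hypothetical names introduced here, and the reporting is far simpler than the detailed analysis described above. It assumes a JAX version that provides `jax.debug.callback` and `jax.effects_barrier`.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Host-side log of findings (hypothetical; a real tool would report more detail).
reports = []

def _record(name, n_nan, n_inf):
    # Runs on the host, outside the compiled computation.
    if int(n_nan) or int(n_inf):
        reports.append((name, int(n_nan), int(n_inf)))

def check_finite(name, x):
    """Count NaN/Inf entries in x and pass the counts to a host callback."""
    n_nan = jnp.sum(jnp.isnan(x))
    n_inf = jnp.sum(jnp.isinf(x))
    # The name is bound with partial because callback arguments must be arrays.
    jax.debug.callback(partial(_record, name), n_nan, n_inf)
    return x

def breakpoint_if(pred):
    """Enter the JAX debugger only when pred is True at run time."""
    def _stop():
        jax.debug.breakpoint()

    def _skip():
        pass

    jax.lax.cond(pred, _stop, _skip)

def loss(x):
    # sqrt has an infinite derivative at 0 and a NaN derivative below 0.
    return jnp.sum(jnp.sqrt(x))

@jax.jit
def checked_grad(x):
    return check_finite("grad(loss)", jax.grad(loss)(x))

grads = checked_grad(jnp.array([4.0, 0.0, -1.0]))
jax.effects_barrier()  # wait for pending debug callbacks to finish
print(reports)
```

Because the check is expressed as array operations plus a debug callback, it traces cleanly under `jax.jit`. `breakpoint_if` is defined but not called here, since it would open an interactive debugger.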

API Changes

Module System

  • Renamed ModuleMapper to Map: Simplified naming for the vectorized module wrapper

    • Map provides vectorized (vmap2) and parallel (pmap2) mapping over modules
    • ModuleMapper retained as a deprecated alias for backward compatibility
    • Internal _ModuleMapperCalling renamed to _MapCaller for consistency
  • Enhanced Map.map() Method: Now accepts callable functions for flexible mapping operations
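One common way to keep an old class name working after a rename is a thin subclass that emits a `DeprecationWarning`. The sketch below shows that pattern only. The two classes are simplified stand-ins that reuse the names from these notes, and whether the library implements its alias this way is an assumption.

```python
import warnings

class Map:
    """Stand-in for the renamed wrapper (hypothetical, not the library's class)."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, items):
        # Apply the wrapped function to each item in turn.
        return [self.fn(item) for item in items]

class ModuleMapper(Map):
    """Deprecated alias: behaves like Map but warns on construction."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "ModuleMapper is deprecated; use Map instead.",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller's line
        )
        super().__init__(*args, **kwargs)
```

Existing code that constructs `ModuleMapper(...)` keeps working and still passes `isinstance(obj, Map)` checks, while the warning nudges callers toward the new name.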

Bug Fixes

  • Fixed the get_backend import so that it works across different JAX releases
  • Removed abstractmethod decorators from the Regularization class so that it can be instantiated directly
  • Cleaned up unused imports in module initialization files
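The effect of the abstractmethod change can be illustrated in plain Python. `StrictBase` and `RelaxedBase` are hypothetical classes, not the library's `Regularization`, and the `penalty` method is an invented placeholder. The sketch only shows why an `ABC` subclass with abstract methods cannot be instantiated, while the same class without the decorator can.

```python
from abc import ABC, abstractmethod

class StrictBase(ABC):
    """Base class with an abstract method: cannot be instantiated."""

    @abstractmethod
    def penalty(self, weights):
        ...

class RelaxedBase:
    """Same interface without the decorator: instantiable, with a default."""

    def penalty(self, weights):
        # Default behaviour: no regularization penalty.
        return 0.0

try:
    StrictBase()
except TypeError as err:
    print("StrictBase:", err)

base = RelaxedBase()
print("RelaxedBase penalty:", base.penalty([1.0, -2.0]))
```

The trade-off is that subclasses are no longer forced to override the method, so a sensible default (or an explicit `NotImplementedError`) has to take the decorator's place.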

Internal Changes

  • Added comprehensive test suite for NaN debugging (_debug_test.py, 938 lines)
  • Removed deprecated _mapping3.py module and associated tests
  • Streamlined module exports in __init__.py files