
Version 0.0.6


chaoming0625 released this on 14 Feb 16:21
Commit 337bd12

This release is our most rigorously validated to date, having passed all CPU and GPU test suites. Performance optimizations are not yet included and are planned for version 0.0.7.

Added

  • DataRepresentation base class with buffer registry for mutable named state on sparse matrices (register_buffer, set_buffer, buffers), plus JITCMatrix with full operator overloading (__mul__, __add__, apply, apply2, etc.) (#81)
  • CSR/CSC row slicing via csr_slice_rows with full autodiff support (JVP, transpose, batching) and three backends (numba, warp, pallas); enables csr[row_indices] and csc[col_indices] indexing (#80)
  • SDDMM helpers (sddmm_indices, sddmm_coo_indices, sddmm_bcoo) for Sampled Dense-Dense Matrix Multiplication built on jax.experimental.sparse (#75)
  • Primitive registry (get_registry, get_primitives_by_tags, get_all_primitive_names) with automatic registration of all XLACustomKernel instances (#65)
  • User backend configuration (brainevent/config.py) with JSON persistence, per-primitive default backend selection, Numba threading config, and LFSR algorithm selection (#65, #74)
  • CLI tool (brainevent benchmark-performance) that benchmarks primitives across backends, prints the results as a table, and automatically persists the best-performing backend as the default (#65)
  • Configurable LFSR RNG for both Numba (_numba_random.py) and Pallas (_pallas_random.py) with three algorithm families: LFSR88 (~2^88 period), LFSR113 (~2^113 period), LFSR128 (~2^128 period) (#74)
  • TPU backend support for CSR operations (#72)
  • Event representation classes: IndexedBinary1d/2d, IndexedSpFloat1d/2d for indexed subsets of events, with binary_array_index() extraction function
  • Fixed-connection matmul helpers (binary_fcnmv/mm, fcnmv/mm, spfloat_fcnmv/mm) and JITC matmul helpers for scalar/normal/uniform connectivity (#61)
  • namescope JAX decorator for per-backend JIT compilation caching (#62)
  • Custom error types: KernelNotAvailableError, KernelCompilationError, KernelFallbackExhaustedError, KernelExecutionError
  • Tutorial on BinaryArray usage and optimization techniques (#64)
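The buffer registry added to DataRepresentation keeps mutable named state alongside a sparse matrix. The plain-Python class below is a hypothetical sketch of that idea. Only the method names register_buffer, set_buffer and buffers come from these notes; the signatures, the error handling and the copy-on-read behaviour are assumptions.

```python
class BufferedMatrix:
    """Hypothetical stand-in for a sparse-matrix class with named, mutable buffers."""

    def __init__(self):
        # name -> value; dicts keep insertion order (Python 3.7+)
        self._buffers = {}

    def register_buffer(self, name, value):
        # Declare a new named buffer; registering the same name twice is an error here.
        if name in self._buffers:
            raise KeyError(f"buffer {name!r} already registered")
        self._buffers[name] = value

    def set_buffer(self, name, value):
        # Update an existing buffer; unknown names are rejected in this sketch.
        if name not in self._buffers:
            raise KeyError(f"buffer {name!r} is not registered")
        self._buffers[name] = value

    @property
    def buffers(self):
        # Return a shallow copy so callers cannot mutate the registry directly.
        return dict(self._buffers)
```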
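Row slicing of a CSR matrix only needs the row-pointer array: row r owns the entries indptr[r]:indptr[r+1]. The NumPy sketch below gathers selected rows into a dense array to show these semantics. csr_slice_rows_dense is a hypothetical helper, not brainevent's csr_slice_rows, and it covers none of the JVP, transpose or batching rules. A CSC matrix stores the same arrays for its transpose, so column slicing follows the same pattern.

```python
import numpy as np


def csr_slice_rows_dense(data, indices, indptr, shape, rows):
    """Gather the selected rows of a CSR matrix into a dense (len(rows), n_cols) array.

    Reference semantics only; hypothetical helper, not brainevent's csr_slice_rows.
    """
    n_cols = shape[1]
    rows = np.asarray(rows)
    out = np.zeros((rows.shape[0], n_cols), dtype=data.dtype)
    for k, r in enumerate(rows):
        start, end = indptr[r], indptr[r + 1]
        # np.add.at accumulates correctly even if a column index repeats within a row;
        # out[k] is a view, so the update lands in `out`.
        np.add.at(out[k], indices[start:end], data[start:end])
    return out
```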
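SDDMM evaluates a dense product a @ b only at the positions of a given sparsity pattern, which costs O(nnz * d) instead of O(m * n * d). The NumPy reference below states what is computed for a COO pattern. sddmm_coo is a hypothetical name; the brainevent helpers, which the notes say build on jax.experimental.sparse, are not used here.

```python
import numpy as np


def sddmm_coo(a, b, rows, cols):
    """Sampled dense-dense matmul: (a @ b)[rows[k], cols[k]] for each sampled k only.

    a: (m, d), b: (d, n); returns values aligned with (rows, cols), i.e. COO data.
    """
    # Row rows[k] of a dotted with column cols[k] of b, for every sampled pair k.
    return np.einsum("kd,dk->k", a[rows], b[:, cols])
```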
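Automatic registration can be as simple as each kernel object adding itself to a shared table when it is constructed. Lookups by name or tag then reduce to scanning that table. The sketch below illustrates the idea only. Kernel, registered_names and find_by_tags are hypothetical stand-ins for XLACustomKernel, get_all_primitive_names and get_primitives_by_tags, whose real signatures may differ.

```python
class Kernel:
    """Hypothetical stand-in for a kernel object that registers itself on construction."""

    _registry = {}  # primitive name -> Kernel instance, shared by all instances

    def __init__(self, name, tags=()):
        self.name = name
        self.tags = set(tags)
        Kernel._registry[name] = self  # automatic registration


def registered_names():
    # All registered primitive names, sorted for stable output.
    return sorted(Kernel._registry)


def find_by_tags(*tags):
    # Kernels carrying every requested tag, in registration order.
    wanted = set(tags)
    return [k for k in Kernel._registry.values() if wanted <= k.tags]
```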
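LFSR88 presumably refers to L'Ecuyer's three-component combined Tausworthe generator (taus88), whose period is about 2^88; that mapping is an assumption. The pure-Python sketch below implements taus88 and masks every left shift to 32 bits to mimic unsigned C arithmetic. It is not the code in _numba_random.py or _pallas_random.py, and it does not cover the LFSR113 or LFSR128 variants.

```python
MASK32 = 0xFFFFFFFF


class Taus88:
    """L'Ecuyer's taus88 combined Tausworthe generator (period ~2^88); illustrative sketch."""

    def __init__(self, s1, s2, s3):
        s1, s2, s3 = s1 & MASK32, s2 & MASK32, s3 & MASK32
        # Seed constraints for the three components: s1 > 1, s2 > 7, s3 > 15.
        if not (s1 > 1 and s2 > 7 and s3 > 15):
            raise ValueError("seeds must satisfy s1 > 1, s2 > 7, s3 > 15")
        self.s1, self.s2, self.s3 = s1, s2, s3

    def next_u32(self):
        # One step of each component LFSR; masking emulates uint32 overflow.
        s1, s2, s3 = self.s1, self.s2, self.s3
        b = (((s1 << 13) & MASK32) ^ s1) >> 19
        s1 = (((s1 & 0xFFFFFFFE) << 12) & MASK32) ^ b
        b = (((s2 << 2) & MASK32) ^ s2) >> 25
        s2 = (((s2 & 0xFFFFFFF8) << 4) & MASK32) ^ b
        b = (((s3 << 3) & MASK32) ^ s3) >> 11
        s3 = (((s3 & 0xFFFFFFF0) << 17) & MASK32) ^ b
        self.s1, self.s2, self.s3 = s1, s2, s3
        # The combined output is the XOR of the three component states.
        return s1 ^ s2 ^ s3

    def uniform(self):
        # Map a 32-bit draw to [0, 1).
        return self.next_u32() / 4294967296.0
```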
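An indexed event representation stores which entries fired rather than a full boolean vector, so a product with a weight matrix only touches the active rows. The NumPy sketch below shows the idea. active_indices and event_matvec are hypothetical helpers, and the (indices, count) return layout is an assumption rather than the documented output of binary_array_index().

```python
import numpy as np


def active_indices(spikes):
    """Return (indices, count) of the nonzero entries of a 1-D binary event vector."""
    idx = np.flatnonzero(spikes).astype(np.int32)
    return idx, idx.shape[0]


def event_matvec(weights, idx):
    # Event-driven spikes @ weights: sum only the rows of `weights` selected by the
    # active indices instead of multiplying by every (mostly zero) entry.
    return weights[idx].sum(axis=0)
```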
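In fixed-connection (FCN) layouts every presynaptic row has the same number of targets, so connectivity fits in two dense arrays: conn[i, k] (target index) and weights[i, k]. The NumPy sketch below assumes that layout and a pre-to-post orientation, neither of which these notes specify. fcn_matvec and binary_fcn_matvec are hypothetical stand-ins for fcnmv and binary_fcnmv.

```python
import numpy as np


def fcn_matvec(conn, weights, x, n_post):
    """y[j] = sum_i x[i] * w(i -> j), where row i connects to conn[i, :] with weights[i, :]."""
    y = np.zeros(n_post, dtype=weights.dtype)
    # Scatter-add every presynaptic contribution into its postsynaptic targets.
    np.add.at(y, conn.ravel(), (x[:, None] * weights).ravel())
    return y


def binary_fcn_matvec(conn, weights, spikes, n_post):
    # Event-driven variant: only rows with a spike contribute, and no multiply is needed.
    active = np.flatnonzero(spikes)
    y = np.zeros(n_post, dtype=weights.dtype)
    np.add.at(y, conn[active].ravel(), weights[active].ravel())
    return y
```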
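JITC presumably stands for just-in-time connectivity: instead of storing a sparse matrix, each row's connections are regenerated from a seed whenever they are needed. The sketch below assumes a Bernoulli(prob) mask with a single scalar weight for the "scalar" case, and it keys NumPy's default_rng on (seed, row). brainevent's own generators and parameterisation are likely different; jitc_row and jitc_scalar_matvec are hypothetical.

```python
import numpy as np


def jitc_row(seed, i, n_post, prob):
    # Regenerate row i's connectivity mask deterministically from (seed, i); nothing is stored.
    rng = np.random.default_rng([seed, i])
    return rng.random(n_post) < prob


def jitc_scalar_matvec(x, weight, prob, seed, n_post):
    """y = x @ (weight * M), where M is a Bernoulli(prob) mask regenerated row by row."""
    y = np.zeros(n_post)
    for i in map(int, np.flatnonzero(x)):  # rows with zero input are skipped entirely
        y += x[i] * weight * jitc_row(seed, i, n_post, prob)
    return y
```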

Changed

  • Major codebase restructuring: flat modules reorganized into coherent subpackages (_coo/, _csr/, _dense/, _fcn/, _jit_scalar/, _jit_normal/, _jit_uniform/, _event/) (#59, #69)
  • Consistent function naming convention across all operations: binary_*mv/mm, *mv/mm, spfloat_*mv/mm, update_*_on_binary_pre/post, with _p suffix for raw primitives (#62)
  • EventArray renamed to BinaryArray across the entire codebase (backward-compatible alias retained)
  • JITC class renames: JITCHomoR/C → JITCScalarR/C; module renames _jitc_homo → _jit_scalar, _jitc_normal → _jit_normal, _jitc_uniform → _jit_uniform
  • Pallas RNG class renames: LFSR88RNG → PallasLFSR88RNG, LFSR113RNG → PallasLFSR113RNG; new factory PallasLFSRRNG(seed)
  • Plasticity function renames: csr_on_pre → update_csr_on_binary_pre, coo_on_pre → update_coo_on_binary_pre, etc. (backward-compatible aliases for CSR/dense variants)
  • Configuration system: replaced _config.py singleton with config.py module using JSON file persistence
  • XLACustomKernel enhanced with def_tags(), def_benchmark_data(), benchmark(), available_backends(), set_default(), and KernelEntry dataclass
  • csrmv_yw2y moved to its own module _csr/yw2y.py (#79)
  • Unified sparse-float dense matmul operations across all formats (#77)
  • Project description updated to "Enabling Event-driven Computation in CPU/GPU/TPU"
  • Added Python 3.14 support; dropped Python 3.10 from classifiers
  • Core dependency jax>=0.5.0 now explicitly required
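The notes say old names such as EventArray remain available as backward-compatible aliases, but not how. One common pattern is a module-level __getattr__ (PEP 562) that forwards an old name to the new object. The sketch below builds a throwaway module to demonstrate this; in a real package you would define `def __getattr__(name)` at module top level. The DeprecationWarning is an assumption and not documented brainevent behaviour.

```python
import types
import warnings


class BinaryArray:
    """Placeholder for the renamed class."""


# Throwaway module used only to demonstrate a PEP 562 module-level __getattr__.
pkg = types.ModuleType("pkg_sketch")
pkg.BinaryArray = BinaryArray

_RENAMED = {"EventArray": "BinaryArray"}  # old name -> new name


def _module_getattr(name):
    # Called only when normal attribute lookup on the module fails.
    if name in _RENAMED:
        new = _RENAMED[name]
        warnings.warn(f"{name} is deprecated; use {new}", DeprecationWarning, stacklevel=2)
        return getattr(pkg, new)
    raise AttributeError(f"module {pkg.__name__!r} has no attribute {name!r}")


# Stored in the module's __dict__, which is where PEP 562 looks for __getattr__.
pkg.__getattr__ = _module_getattr
```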
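The notes describe config.py as persisting user settings to a JSON file, including per-primitive default backends, Numba threading and the LFSR algorithm. A minimal sketch of such load/save helpers follows. The key names, default values, file location and the atomic write via os.replace are all assumptions, not the real schema.

```python
import json
import os
import tempfile


def default_config():
    # Hypothetical keys; the real config schema is not described in the notes.
    return {"default_backends": {}, "numba": {"num_threads": None}, "lfsr_algorithm": "lfsr88"}


def load_config(path):
    # Missing file -> defaults; existing file -> defaults overlaid with stored top-level keys.
    cfg = default_config()
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            cfg.update(json.load(f))
    return cfg


def save_config(cfg, path):
    # Write to a temp file in the same directory, then atomically replace the target.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp = tempfile.mkstemp(dir=directory, suffix=".json")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2, sort_keys=True)
    os.replace(tmp, path)


def set_default_backend(cfg, primitive, backend):
    # Record the preferred backend for one primitive.
    cfg["default_backends"][primitive] = backend
```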
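XLACustomKernel now tracks per-backend entries (KernelEntry), reports available_backends() and accepts set_default(), and the release adds error types for missing kernels and exhausted fallbacks. The sketch below shows one way these pieces can fit together: try the default backend first, then fall back through the others. The two exception classes are sketch-local stand-ins that only reuse brainevent's names. The KernelEntry fields, the add_entry method and the fallback order are assumptions.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


class KernelNotAvailableError(RuntimeError):
    """Sketch-local stand-in: the requested backend has no registered implementation."""


class KernelFallbackExhaustedError(RuntimeError):
    """Sketch-local stand-in: every candidate backend failed."""


@dataclass
class KernelEntry:
    backend: str
    fn: Callable


class Kernel:
    """Hypothetical dispatcher; only some method names echo XLACustomKernel."""

    def __init__(self, name: str):
        self.name = name
        self._entries: Dict[str, KernelEntry] = {}
        self._default: Optional[str] = None

    def add_entry(self, backend: str, fn: Callable) -> None:
        self._entries[backend] = KernelEntry(backend, fn)

    def available_backends(self) -> List[str]:
        return list(self._entries)

    def set_default(self, backend: str) -> None:
        if backend not in self._entries:
            raise KernelNotAvailableError(f"{self.name}: no {backend!r} kernel")
        self._default = backend

    def __call__(self, *args, backend: Optional[str] = None):
        # Explicit backend: run it directly, with no fallback.
        if backend is not None:
            if backend not in self._entries:
                raise KernelNotAvailableError(f"{self.name}: no {backend!r} kernel")
            return self._entries[backend].fn(*args)
        # Otherwise try the default first, then the remaining backends in order.
        order = self.available_backends()
        if self._default is not None:
            order.remove(self._default)
            order.insert(0, self._default)
        errors = []
        for b in order:
            try:
                return self._entries[b].fn(*args)
            except Exception as exc:  # record the failure and try the next backend
                errors.append(f"{b}: {exc}")
        raise KernelFallbackExhaustedError("; ".join(errors) or "no kernels registered")
```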

Fixed

  • Pallas GPU binary_densemm kernel corruption: pl.ds() out-of-bounds reads when block_dim > m corrupted adjacent GPU memory; fixed with scalar pl.program_id() indexing and jnp.where instead of jax.lax.cond (#71)
  • Warp tile operation bug: cooperative tile ops (tile_load, tile_store, tile_atomic_add) produced diagonal-like output when launch dimensions < 32 threads; replaced with scalar loops in _jit_normal/float.py (#71)
  • Backend passthrough in AD rules: JVP/transpose/batching rules now correctly forward backend= parameter to *_p_call() functions, preventing silent use of wrong backend for tangent computation (#72)
  • Fixed-connection matmul return values (#62)
  • Bool-to-float conversion added in binary_densemm_p_call before passing to primitive (#71)
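The binary_densemm fix concerns the last block of rows, which extends past m when block_dim > m. The NumPy emulation below mirrors the guard pattern. It clamps the row indices so no read leaves the array, zeroes the invalid lanes with a select (np.where, analogous to jnp.where) instead of branching, and stores only in-bounds rows. It also converts the boolean events to floats first, as the companion fix does. This illustrates the pattern only and is not the Pallas kernel; blocked_binary_densemm is a hypothetical name.

```python
import numpy as np


def blocked_binary_densemm(spikes, weights, block_dim):
    """Compute spikes @ weights block by block, masking the tail block explicitly.

    spikes: (m, k) bool events; weights: (k, n) float.
    """
    x = spikes.astype(weights.dtype)  # bool -> float before any arithmetic
    m = x.shape[0]
    out = np.zeros((m, weights.shape[1]), dtype=weights.dtype)
    n_blocks = -(-m // block_dim)  # ceil(m / block_dim)
    for pid in range(n_blocks):  # pid plays the role of a scalar program id
        rows = pid * block_dim + np.arange(block_dim)
        valid = rows < m
        safe_rows = np.where(valid, rows, 0)  # clamp so no read goes past the last row
        # Invalid lanes read row 0 but are zeroed by the select, so they contribute nothing.
        tile = np.where(valid[:, None], x[safe_rows], 0.0)
        result = tile @ weights
        out[rows[valid]] = result[valid]  # store only in-bounds rows
    return out
```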
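The AD fix threads the caller's backend= choice into every primitive call made inside a JVP, transpose or batching rule. The plain-Python sketch below shows the JVP case for a matrix-vector product and keeps a call log to check which backend each call used. matvec_p_call and matvec_jvp are hypothetical; a real JAX rule would be registered on the primitive rather than called directly.

```python
CALLS = []  # (primitive, backend) pairs, so the forwarding can be checked


def matvec_p_call(w, x, *, backend):
    # Stand-in for a raw primitive call; a real one would dispatch to a backend kernel.
    CALLS.append(("matvec", backend))
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def matvec_jvp(primals, tangents, *, backend):
    w, x = primals
    dw, dx = tangents
    y = matvec_p_call(w, x, backend=backend)
    # d(w @ x) = dw @ x + w @ dx; both tangent terms reuse the caller's backend
    # instead of silently falling back to a default one.
    t1 = matvec_p_call(dw, x, backend=backend)
    t2 = matvec_p_call(w, dx, backend=backend)
    return y, [a + b for a, b in zip(t1, t2)]
```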

Removed

  • BlockCSR class and _block_csr module
  • BlockELL class and _block_ell module
  • BaseArray, BinaryArrayIndex, MaskedFloat, MaskedFloatIndex classes (replaced by new event representations)
  • GPUKernelChoice, pallas_kernel, warp_kernel from _op
  • _primitives.py module (replaced by _registry.py)