Version 0.0.6
This release is our most rigorously validated to date, having passed all CPU and GPU test suites. Performance optimizations are not yet included and are planned for version 0.0.7.
Added
- DataRepresentation base class with buffer registry for mutable named state on sparse matrices (register_buffer, set_buffer, buffers), plus JITCMatrix with full operator overloading (__mul__, __add__, apply, apply2, etc.) (#81)
- CSR/CSC row slicing via csr_slice_rows with full autodiff support (JVP, transpose, batching) and three backends (numba, warp, pallas); enables csr[row_indices] and csc[col_indices] indexing (#80)
- SDDMM helpers (sddmm_indices, sddmm_coo_indices, sddmm_bcoo) for Sampled Dense-Dense Matrix Multiplication built on jax.experimental.sparse (#75)
- Primitive registry (get_registry, get_primitives_by_tags, get_all_primitive_names) with automatic registration of all XLACustomKernel instances (#65)
- User backend configuration (brainevent/config.py) with JSON persistence, per-primitive default backend selection, Numba threading config, and LFSR algorithm selection (#65, #74)
- CLI tool (brainevent benchmark-performance) for automated benchmarking across backends with tabular output and automatic optimal-default persistence (#65)
- Configurable LFSR RNG for both Numba (_numba_random.py) and Pallas (_pallas_random.py) with three algorithm families: LFSR88 (~2^88 period), LFSR113 (~2^113 period), LFSR128 (~2^128 period) (#74)
- TPU backend support for CSR operations (#72)
- Event representation classes: IndexedBinary1d/2d, IndexedSpFloat1d/2d for indexed subsets of events, with binary_array_index() extraction function
- Fixed-connection matmul helpers (binary_fcnmv/mm, fcnmv/mm, spfloat_fcnmv/mm) and JITC matmul helpers for scalar/normal/uniform connectivity (#61)
- namescope JAX decorator for per-backend JIT compilation caching (#62)
- Custom error types: KernelNotAvailableError, KernelCompilationError, KernelFallbackExhaustedError, KernelExecutionError
- Tutorial on BinaryArray usage and optimization techniques (#64)
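
For readers unfamiliar with the LFSR families listed above, the following is a minimal pure-Python sketch of the LFSR113 recurrence as published by L'Ecuyer (four combined Tausworthe components, period of roughly 2^113). It is illustrative only: the class name, the seeding scheme, and the method names are assumptions made for this example, and the code in _numba_random.py and _pallas_random.py may differ in detail.

```python
MASK32 = 0xFFFFFFFF  # keep every intermediate value within 32 bits


class LFSR113:
    """Illustrative LFSR113 generator (not the brainevent implementation)."""

    def __init__(self, seed: int):
        # Hypothetical seeding: expand one integer into four state words with
        # a simple LCG, then enforce the minimum values LFSR113 requires
        # (z1 >= 2, z2 >= 8, z3 >= 16, z4 >= 128).
        s = seed & MASK32
        state = []
        for minimum in (2, 8, 16, 128):
            s = (69069 * s + 1234567) & MASK32
            state.append(s if s >= minimum else s + minimum)
        self.z1, self.z2, self.z3, self.z4 = state

    def next_u32(self) -> int:
        """Advance the four components and return a 32-bit unsigned integer."""
        b = (((self.z1 << 6) & MASK32) ^ self.z1) >> 13
        self.z1 = (((self.z1 & 0xFFFFFFFE) << 18) & MASK32) ^ b
        b = (((self.z2 << 2) & MASK32) ^ self.z2) >> 27
        self.z2 = (((self.z2 & 0xFFFFFFF8) << 2) & MASK32) ^ b
        b = (((self.z3 << 13) & MASK32) ^ self.z3) >> 21
        self.z3 = (((self.z3 & 0xFFFFFFF0) << 7) & MASK32) ^ b
        b = (((self.z4 << 3) & MASK32) ^ self.z4) >> 12
        self.z4 = (((self.z4 & 0xFFFFFF80) << 13) & MASK32) ^ b
        return self.z1 ^ self.z2 ^ self.z3 ^ self.z4

    def next_float(self) -> float:
        """Return a float in [0, 1) by scaling the 32-bit output."""
        return self.next_u32() / 2**32


rng = LFSR113(seed=42)
samples = [rng.next_float() for _ in range(3)]
```

LFSR88 follows the same pattern with three components instead of four, which is why its period is shorter.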
Changed
- Major codebase restructuring: flat modules reorganized into coherent subpackages (_coo/, _csr/, _dense/, _fcn/, _jit_scalar/, _jit_normal/, _jit_uniform/, _event/) (#59, #69)
- Consistent function naming convention across all operations: binary_*mv/mm, *mv/mm, spfloat_*mv/mm, update_*_on_binary_pre/post, with _p suffix for raw primitives (#62)
- EventArray renamed to BinaryArray across the entire codebase (backward-compatible alias retained)
- JITC class renames: JITCHomoR/C → JITCScalarR/C; module renames _jitc_homo → _jit_scalar, _jitc_normal → _jit_normal, _jitc_uniform → _jit_uniform
- Pallas RNG class renames: LFSR88RNG → PallasLFSR88RNG, LFSR113RNG → PallasLFSR113RNG; new factory PallasLFSRRNG(seed)
- Plasticity function renames: csr_on_pre → update_csr_on_binary_pre, coo_on_pre → update_coo_on_binary_pre, etc. (backward-compatible aliases for CSR/dense variants)
- Configuration system: replaced _config.py singleton with config.py module using JSON file persistence
- XLACustomKernel enhanced with def_tags(), def_benchmark_data(), benchmark(), available_backends(), set_default(), and KernelEntry dataclass
- csrmv_yw2y moved to its own module _csr/yw2y.py (#79)
- Unified sparse-float dense matmul operations across all formats (#77)
- Project description updated to "Enabling Event-driven Computation in CPU/GPU/TPU"
- Added Python 3.14 support; dropped Python 3.10 from classifiers
- Core dependency jax>=0.5.0 now explicitly required
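
The backward-compatible aliases mentioned in the renames above can be pictured with the generic sketch below. It shows two common patterns: a plain class alias and a function alias that emits a DeprecationWarning. Everything here is a local stand-in defined for illustration: the class body, the function signature, the deprecated_alias helper, and the choice to warn are assumptions, not the library's actual code.

```python
import warnings


class BinaryArray:
    """Local stand-in for the renamed class (not the real implementation)."""

    def __init__(self, value):
        self.value = value


# Pattern 1: a plain alias, so old code using the old name keeps working.
EventArray = BinaryArray


def update_csr_on_binary_pre(*args, **kwargs):
    # Stand-in body; the real function's signature is not reproduced here.
    return ("update_csr_on_binary_pre", args, kwargs)


def deprecated_alias(new_func, old_name):
    """Hypothetical helper: wrap new_func under an old name with a warning."""

    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{old_name} is deprecated; use {new_func.__name__} instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return new_func(*args, **kwargs)

    wrapper.__name__ = old_name
    return wrapper


# Pattern 2: the old function name forwards to the new one.
csr_on_pre = deprecated_alias(update_csr_on_binary_pre, "csr_on_pre")
```

In either pattern, new code is better served by the new names, since the source notes that aliases exist only for some variants (CSR/dense for the plasticity functions).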
Fixed
- Pallas GPU binary_densemm kernel corruption: pl.ds() out-of-bounds reads when block_dim > m corrupted adjacent GPU memory; fixed with scalar pl.program_id() indexing and jnp.where instead of jax.lax.cond (#71)
- Warp tile operation bug: cooperative tile ops (tile_load, tile_store, tile_atomic_add) produced diagonal-like output when launch dimensions < 32 threads; replaced with scalar loops in _jit_normal/float.py (#71)
- Backend passthrough in AD rules: JVP/transpose/batching rules now correctly forward backend= parameter to *_p_call() functions, preventing silent use of wrong backend for tangent computation (#72)
- Fixed-connection matmul return values (#62)
- Bool-to-float conversion added in binary_densemm_p_call before passing to primitive (#71)
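
The out-of-bounds fix above rests on a general idea: when a block is larger than the remaining data (for example block_dim > m), clamp the indices to a safe location and mask the results with jnp.where rather than branching. The sketch below illustrates that idea in plain JAX. It is not the Pallas kernel from this release; the function name masked_block_sums and its arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def masked_block_sums(x, block_dim):
    """Sum a 1-D array in fixed-size blocks, masking the out-of-range tail."""
    m = x.shape[0]
    n_blocks = -(-m // block_dim)  # ceiling division

    def one_block(i):
        idx = i * block_dim + jnp.arange(block_dim)
        valid = idx < m
        # Clamp invalid indices to 0 so the gather never reads past the end.
        safe_idx = jnp.where(valid, idx, 0)
        # Zero out the lanes that did not correspond to real data.
        vals = jnp.where(valid, x[safe_idx], 0.0)
        return vals.sum()

    return jax.vmap(one_block)(jnp.arange(n_blocks))


x = jnp.arange(5.0)
oversized = masked_block_sums(x, block_dim=8)  # one block wider than the data
ragged = masked_block_sums(x, block_dim=2)     # last block only partly filled
```

Because every lane executes the same instructions and only the mask differs, this form stays branch-free, which is the property the jnp.where-based fix relies on.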
Removed
- BlockCSR class and _block_csr module
- BlockELL class and _block_ell module
- BaseArray, BinaryArrayIndex, MaskedFloat, MaskedFloatIndex classes (replaced by new event representations)
- GPUKernelChoice, pallas_kernel, warp_kernel from _op_primitives.py module (replaced by _registry.py)