Version 0.0.7
This release significantly improves CUDA kernel performance: the FCN, CSR, and JITC kernels are substantially faster, with JITC kernels reaching speedups of up to 500× in the best cases.
Added
- CUDA kernel compilation pipeline (`cuda_raw` backend): Native nvcc-based compilation system. Compiles `.cu` files on the fly with source-hash caching, automatic XLA FFI registration, and multi-dtype dispatch (f16, bf16, f32, f64). Key APIs: `load_cuda_file`, `load_cuda_inline`, `load_cuda_dir`, `load_cpp_file`, `load_cpp_inline` (#88)
- BitPacked binary event representations: `BitPackedBinary` compresses 32 spike values into a single uint32 word (a 32x memory reduction relative to one 32-bit value per spike). `CompactBinary` combines bitpacking with stream compaction to skip inactive rows in scatter kernels. Factory methods: `BitPackedBinary.from_array(x)`, `CompactBinary.from_array(x)`, and a standalone `bitpack()` utility (#97)
- BitPack FCN kernels: `bitpack_binary_fcnmv`, `bitpack_binary_fcnmm`, `compact_binary_fcnmv`, `compact_binary_fcnmm`, with both Numba CPU and CUDA GPU backends, for event-driven matmul on packed spike representations (#97)
- Parallel RNN training (`brainevent.pararnn`): O(log T) parallel training via Newton's method and parallel prefix reduction. Includes the single-function `parallel_rnn()` API, `AutoRNNCell` with automatic Jacobian structure detection (diagonal, block-diagonal, dense), pre-built cells (`GRUDiagMH`, `LSTMCIFGDiagMH`), fused CUDA kernels for GRU/LSTM forward and backward passes, and a configurable Newton solver (#85)
- Warp kernel support for CSR matrix-vector multiplication and various binary/sparse operations across the COO, CSR, Dense, and FCN modules (#86)
- Shared CUDA headers (`brainevent/include/`): `common.h` (`BE::Tensor`, `BE::DType`, error-check macros), `cuda_common.h` (warp reductions, dtype macros, atomics), and `dispatch.h` (type dispatch macros) for consistent CUDA kernel development
- CUDA compilation diagnostics: `print_diagnostics()`, `get_cache_dir()`, `set_cache_dir()`, `clear_cache()` for cache management; `CompiledModule`, `register_ffi_target`, `list_registered_targets` for FFI target management
- Tutorials for custom GPU operators with Warp and Numba CUDA (#83)
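The bitpacked and compact representations can be illustrated in plain Python. The sketch below is hypothetical and is not brainevent's implementation: `bitpack` packs 32 booleans per 32-bit word (bit `i % 32` of word `i // 32` is an assumed bit order), `active_indices` performs the stream-compaction step by listing only the set bits, and `fcn_scatter_mv` is an illustrative event-driven scatter product that assumes an FCN layout in which each presynaptic row stores its target indices and weights.

```python
from typing import List, Sequence

WORD_BITS = 32  # one uint32 word holds 32 spikes


def bitpack(spikes: Sequence[bool]) -> List[int]:
    """Pack booleans into 32-bit words: spike i -> bit (i % 32) of word (i // 32)."""
    words = [0] * ((len(spikes) + WORD_BITS - 1) // WORD_BITS)
    for i, spike in enumerate(spikes):
        if spike:
            words[i // WORD_BITS] |= 1 << (i % WORD_BITS)
    return words


def active_indices(words: Sequence[int], n: int) -> List[int]:
    """Stream compaction: return the indices of set bits, in ascending order."""
    out = []
    for w, word in enumerate(words):
        while word:
            lowest = word & -word  # isolate the lowest set bit
            i = w * WORD_BITS + lowest.bit_length() - 1
            if i < n:
                out.append(i)
            word ^= lowest  # clear that bit and continue
    return out


def fcn_scatter_mv(spikes: Sequence[bool],
                   indices: Sequence[Sequence[int]],
                   weights: Sequence[Sequence[float]],
                   n_post: int) -> List[float]:
    """Hypothetical event-driven scatter: row i contributes only if neuron i spiked.

    indices[i][k] / weights[i][k] are the k-th target and weight of
    presynaptic neuron i (an assumed fixed-connection-number layout).
    """
    y = [0.0] * n_post
    packed = bitpack(spikes)
    for i in active_indices(packed, len(spikes)):
        for j, w in zip(indices[i], weights[i]):
            y[j] += w
    return y
```

In this sketch the inner accumulation loop runs only for rows whose spike bit is set, which is the property that compaction is meant to exploit in scatter kernels.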
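Parallel RNN methods of this kind (Newton iterations whose linearized step is solved by a parallel prefix reduction) rely on the fact that a linear recurrence `h_t = a_t * h_{t-1} + b_t` can be evaluated with an associative scan. The following is a minimal scalar (diagonal-Jacobian) sketch of that scan only; it is not the `brainevent.pararnn` implementation, the function names are hypothetical, and each parallel round is simulated sequentially here.

```python
from typing import List, Sequence, Tuple

Pair = Tuple[float, float]  # affine map h -> a * h + b


def combine(first: Pair, second: Pair) -> Pair:
    """Compose two affine maps: apply `first`, then `second` (associative)."""
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


def parallel_linear_scan(a: Sequence[float], b: Sequence[float],
                         h0: float = 0.0) -> List[float]:
    """Solve h_t = a_t * h_{t-1} + b_t for all t with a Hillis-Steele scan.

    Every update within a round reads only the previous round's values, so
    on parallel hardware the scan takes ceil(log2(T)) rounds, not T steps.
    """
    elems: List[Pair] = list(zip(a, b))
    step = 1
    while step < len(elems):
        elems = [
            combine(elems[t - step], elems[t]) if t >= step else elems[t]
            for t in range(len(elems))
        ]
        step *= 2
    # elems[t] now maps h0 to h_t.
    return [A * h0 + B for A, B in elems]


def sequential_linear_scan(a: Sequence[float], b: Sequence[float],
                           h0: float = 0.0) -> List[float]:
    """Reference O(T) loop used to check the scan."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out
```

For vector states the scalar `a_t` becomes a Jacobian, which is why diagonal or block-diagonal structure keeps each `combine` cheap.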
Changed
- CUDA raw as default GPU backend: All operations (COO, CSR, Dense, FCN, JIT*) now default to the `cuda_raw` backend on GPU, with automatic fallback to numba/pallas when CUDA is unavailable (#94)
- Namespace migration: The `brainevent.kernix` namespace has moved into `brainevent._op`, and its contents are re-exported directly under `brainevent.*` (e.g., `brainevent.load_cuda_file`). The old `kernix` namespace has been removed (#96)
- Backend rename: The `"tvmffi"` backend has been renamed to `"cuda_raw"` throughout the codebase (#87, #96)
- Versioned cache directory: The compiled kernel cache moved from `~/.cache/brainevent/` to `~/.cache/brainevent/<version>/` to prevent incompatibilities across versions
- FCN kernel launch optimization: For moderate `n_conn` (33–512), scatter/gather kernels switched from a block-per-row (`<<<n_pre, 256>>>`) to a thread-per-row (`<<<ceil(n_pre/256), 256>>>`) launch strategy, yielding up to a 6.4x speedup on COBA benchmarks (#84, #97)
- FCN interface streamlining: `fcnmv`/`fcnmm` now dispatch to the optimal kernel based on the input type (dense, bitpacked, or compact) (#96)
- JAX >= 0.9.1 compatibility: Added a JAX Zero init helper and refactored JVP utilities for forward compatibility (#93)
- JIT/CSR CUDA module splitting: Reorganized the CUDA kernel files for JIT and CSR operations into separate modules with updated Warp kernel implementations (#86)
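The two FCN launch shapes can be compared by counting threads. This hypothetical helper only reproduces the grid arithmetic given above; `idle_slot_fraction` additionally assumes a block-per-row kernel whose 256 threads stride over a row's `n_conn` connections, and none of it models the measured 6.4x speedup.

```python
from typing import Tuple

THREADS_PER_BLOCK = 256  # block size used by both launch shapes


def block_per_row(n_pre: int) -> Tuple[int, int]:
    """<<<n_pre, 256>>>: one block per presynaptic row."""
    return n_pre, THREADS_PER_BLOCK


def thread_per_row(n_pre: int) -> Tuple[int, int]:
    """<<<ceil(n_pre/256), 256>>>: one thread per presynaptic row."""
    return (n_pre + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK, THREADS_PER_BLOCK


def idle_slot_fraction(n_conn: int) -> float:
    """Share of thread slots with no work when a 256-thread block strides over
    one row of n_conn >= 1 connections (assumed block-per-row kernel structure)."""
    passes = (n_conn + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
    return 1.0 - n_conn / (passes * THREADS_PER_BLOCK)
```

Under that assumption, `idle_slot_fraction(100)` is 0.609375, so about 61% of each block's threads would have nothing to do, while `thread_per_row(4000)` launches 16 blocks instead of the 4000 launched by `block_per_row(4000)`.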
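Source-hash caching combined with a versioned cache directory can be sketched as a cache key derived from the kernel source. The function below is hypothetical: the file naming (`<sha256 prefix>.so`) and the inclusion of compiler flags in the hash are assumptions, not brainevent's actual cache layout.

```python
import hashlib
from pathlib import Path
from typing import Optional, Sequence


def kernel_cache_path(source: str, version: str, flags: Sequence[str] = (),
                      root: Optional[Path] = None) -> Path:
    """Hypothetical cache location: <root>/<version>/<hash>.so.

    The hash covers the kernel source and (by assumption) the compiler flags,
    so editing either yields a new entry, while each package version gets its
    own directory and never reuses binaries built by another version.
    """
    if root is None:
        root = Path.home() / ".cache" / "brainevent"
    digest = hashlib.sha256(source.encode("utf-8"))
    for flag in flags:
        # NUL separators make ("-O", "3") and ("-O3",) hash differently.
        digest.update(b"\0" + flag.encode("utf-8"))
    return root / version / f"{digest.hexdigest()[:16]}.so"
```

With this scheme, upgrading from one release to the next changes only the directory, so stale binaries from an older version are simply never looked up.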
Removed
- `sparse_float` module and all related operations
- `IndexedBinary1d`, `IndexedBinary2d`, `IndexedSpFloat1d`, `IndexedSpFloat2d` classes (replaced by bitpack/compact representations)
- `brainevent.kernix` namespace (absorbed into `brainevent._op`, re-exported at top level)
- `ell_mv` function (superseded by FCN operations)
Fixed
- Binary FCN CUDA kernel correctness: Fixed kernel launch parameter issues causing incorrect results in scatter/gather operations (#87)
- Warp tile operation bug in JIT modules: Cooperative tile ops produced diagonal-like output when launch dimensions were smaller than 32; they have been replaced with scalar loops (#86)
- CSR matrix-vector multiplication tolerance: Adjusted assertion tolerances in tests for numerical stability
What's Changed
- Docs: Add tutorials for Warp, Numba CUDA, and Numba CPU operators by @chaoming0625 in #83
- perf: implementing cuda kernels for most operators by @chaoming0625 in #84
- feat: add parallel RNN support by @chaoming0625 in #85
- Refactor: Split JIT/CSR CUDA modules and update Warp kernels by @chaoming0625 in #86
- Refactor: Extract TVM FFI module and fix binary FCN CUDA kernels by @chaoming0625 in #87
- feat: introduce CUDA kernel compilation pipeline with `cuda_raw` backend by @chaoming0625 in #88
- Docs: Rename kernix docs to kernel and update tutorials by @chaoming0625 in #90
- Feat: Add JAX Zero init helper and refactor JVP utilities by @chaoming0625 in #93
- Refactor: Make CUDA raw kernels default GPU backend for ops by @chaoming0625 in #94
- chore: add Apache License information by @chaoming0625 in #95
- refactor: migrate to kernix namespace, remove sparse_float, and streamline fcnmv/fcnmm interfaces by @chaoming0625 in #96
- feat: bitpack binary representations, FCNMM/FCNMV kernel enhancements, and public API cleanup by @chaoming0625 in #97
Full Changelog: v0.0.6...v0.0.7