Version 0.0.7

@chaoming0625 released this 11 Mar 16:18
· 84 commits to main since this release
008f3f5

This release improves CUDA kernel performance significantly. The FCN, CSR, and JITC kernels now run much faster, with JITC achieving up to 500× acceleration in the best cases.

Added

  • CUDA kernel compilation pipeline (cuda_raw backend): Native nvcc-based compilation system. Compile .cu files on-the-fly with source-hash caching, automatic XLA FFI registration, and multi-dtype dispatch (f16, bf16, f32, f64). Key APIs: load_cuda_file, load_cuda_inline, load_cuda_dir, load_cpp_file, load_cpp_inline (#88)
  • BitPacked binary event representations: BitPackedBinary compresses 32 spike values into a single uint32 word (32× memory reduction). CompactBinary combines bitpacking with stream compaction to skip inactive rows in scatter kernels. Factory methods: BitPackedBinary.from_array(x), CompactBinary.from_array(x), and standalone bitpack() utility (#97)
  • BitPack FCN kernels: bitpack_binary_fcnmv, bitpack_binary_fcnmm, compact_binary_fcnmv, compact_binary_fcnmm with both Numba CPU and CUDA GPU backends for event-driven matmul on packed spike representations (#97)
  • Parallel RNN training (brainevent.pararnn): O(log T) parallel training via Newton's method and parallel prefix reduction. Includes parallel_rnn() single-function API, AutoRNNCell with automatic Jacobian structure detection (diagonal, block-diagonal, dense), pre-built cells (GRUDiagMH, LSTMCIFGDiagMH), fused CUDA kernels for GRU/LSTM forward and backward passes, and configurable Newton solver (#85)
  • Warp kernel support for CSR matrix-vector multiplication and various binary/sparse operations across COO, CSR, Dense, and FCN modules (#86)
  • Shared CUDA headers (brainevent/include/): common.h (BE::Tensor, BE::DType, error-check macros), cuda_common.h (warp reductions, dtype macros, atomics), dispatch.h (type dispatch macros) for consistent CUDA kernel development
  • CUDA compilation diagnostics: print_diagnostics(), get_cache_dir(), set_cache_dir(), clear_cache() for cache management; CompiledModule, register_ffi_target, list_registered_targets for FFI target management
  • Tutorials for custom GPU operators with Warp and Numba CUDA (#83)
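
To make the bitpacked representation concrete, the following is a minimal pure-Python sketch of packing 32 binary spike values into one 32-bit word. It is not the brainevent implementation: the helper names pack_bits_u32 and unpack_bits_u32 are hypothetical, and the least-significant-bit-first order is an assumption. The library's own entry points are BitPackedBinary.from_array(x) and bitpack().

```python
from typing import List

WORD_BITS = 32  # one uint32 word holds 32 binary spike values


def pack_bits_u32(spikes: List[int]) -> List[int]:
    """Pack a 0/1 sequence into 32-bit words (hypothetical helper).

    Assumption: element i is stored in bit (i % 32) of word (i // 32),
    i.e. least-significant bit first. The library may use another order.
    """
    n_words = (len(spikes) + WORD_BITS - 1) // WORD_BITS  # ceiling division
    words = [0] * n_words
    for i, s in enumerate(spikes):
        if s:
            words[i // WORD_BITS] |= 1 << (i % WORD_BITS)
    return words


def unpack_bits_u32(words: List[int], n: int) -> List[int]:
    """Recover the first n binary values from the packed words."""
    return [(words[i // WORD_BITS] >> (i % WORD_BITS)) & 1 for i in range(n)]


# 36 spike values fit into 2 words; the last word is only partly used.
spikes = [1, 0, 0, 1] + [0] * 30 + [1, 1]
packed = pack_bits_u32(spikes)
assert unpack_bits_u32(packed, len(spikes)) == spikes
```

A kernel working on such words can test 32 inputs with a single load, which is where the memory saving of the packed format comes from.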

Changed

  • CUDA raw as default GPU backend: All operations (COO, CSR, Dense, FCN, JIT*) now default to cuda_raw backend on GPU, with automatic fallback to numba/pallas when CUDA is unavailable (#94)
  • Namespace migration: brainevent.kernix namespace moved into brainevent._op and re-exported directly under brainevent.* (e.g., brainevent.load_cuda_file). Old kernix namespace removed (#96)
  • Backend rename: "tvmffi" backend renamed to "cuda_raw" throughout the codebase (#87, #96)
  • Versioned cache directory: Compiled kernel cache moved from ~/.cache/brainevent/ to ~/.cache/brainevent/<version>/ to prevent cross-version incompatibilities
  • FCN kernel launch optimization: Scatter/gather kernels switched from block-per-row (<<<n_pre, 256>>>) to thread-per-row (<<<ceil(n_pre/256), 256>>>) strategy for moderate n_conn (33–512), yielding up to 6.4× speedup on COBA benchmarks (#84, #97)
  • FCN interface streamlining: Unified fcnmv/fcnmm dispatch to optimal kernel based on input type (dense, bitpacked, or compact) (#96)
  • JAX >= 0.9.1 compatibility: Added JAX Zero init helper and refactored JVP utilities for forward compatibility (#93)
  • JIT/CSR CUDA module splitting: Reorganized CUDA kernel files for JIT and CSR operations into separate modules with updated Warp kernel implementations (#86)
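
The two FCN launch strategies named above differ only in how many blocks are launched for a given n_pre. The sketch below computes both launch configurations in plain Python so the difference is easy to see. The function names are hypothetical and this is not the library's dispatch code. The 33–512 range check mirrors the note above; what the kernels do outside that range is not described here, so the sketch leaves it open.

```python
from typing import Tuple

THREADS_PER_BLOCK = 256  # block size used in both launch configurations


def block_per_row_launch(n_pre: int) -> Tuple[int, int]:
    """<<<n_pre, 256>>>: one 256-thread block per pre-synaptic row."""
    return n_pre, THREADS_PER_BLOCK


def thread_per_row_launch(n_pre: int) -> Tuple[int, int]:
    """<<<ceil(n_pre/256), 256>>>: one thread per pre-synaptic row."""
    n_blocks = (n_pre + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
    return n_blocks, THREADS_PER_BLOCK


def in_thread_per_row_range(n_conn: int) -> bool:
    """True for the 'moderate' n_conn range (33-512) quoted in the notes.

    Hypothetical helper: behaviour outside this range is not specified
    in the release notes and is deliberately not modelled.
    """
    return 33 <= n_conn <= 512


n_pre = 10_000
blocks_old, _ = block_per_row_launch(n_pre)
blocks_new, _ = thread_per_row_launch(n_pre)
print(f"block-per-row: {blocks_old} blocks, thread-per-row: {blocks_new} blocks")
```

With n_pre = 10000, the block-per-row launch uses 10000 blocks while the thread-per-row launch uses 40, so each thread handles a whole row instead of a whole block sharing one.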

Removed

  • sparse_float module and all related operations
  • IndexedBinary1d, IndexedBinary2d, IndexedSpFloat1d, IndexedSpFloat2d classes (replaced by bitpack/compact representations)
  • brainevent.kernix namespace (absorbed into brainevent._op, re-exported at top level)
  • ell_mv function (superseded by FCN operations)

Fixed

  • Binary FCN CUDA kernel correctness: Fixed kernel launch parameter issues causing incorrect results in scatter/gather operations (#87)
  • Warp tile operation bug in JIT modules: Cooperative tile ops produced diagonal-like output when launch dimensions < 32; replaced with scalar loops (#86)
  • CSR matrix-vector multiplication tolerance: Adjusted the assertion tolerance in tests to make them numerically robust

What's Changed

  • Docs: Add tutorials for Warp, Numba CUDA, and Numba CPU operators by @chaoming0625 in #83
  • perf: implementing cuda kernels for most operators by @chaoming0625 in #84
  • feat: add parallel RNN support by @chaoming0625 in #85
  • Refactor: Split JIT/CSR CUDA modules and update Warp kernels by @chaoming0625 in #86
  • Refactor: Extract TVM FFI module and fix binary FCN CUDA kernels by @chaoming0625 in #87
  • feat: introduce CUDA kernel compilation pipeline with cuda_raw backend by @chaoming0625 in #88
  • Docs: Rename kernix docs to kernel and update tutorials by @chaoming0625 in #90
  • Feat: Add JAX Zero init helper and refactor JVP utilities by @chaoming0625 in #93
  • Refactor: Make CUDA raw kernels default GPU backend for ops by @chaoming0625 in #94
  • chore: add Apache License information by @chaoming0625 in #95
  • refactor: migrate to kernix namespace, remove sparse_float, and streamline fcnmv/fcnmm interfaces by @chaoming0625 in #96
  • feat: bitpack binary representations, FCNMM/FCNMV kernel enhancements, and public API cleanup by @chaoming0625 in #97

Full Changelog: v0.0.6...v0.0.7