BrainTrace 0.2.0

@chaoming0625 chaoming0625 released this 27 May 06:11
d0627d4

BrainTrace 0.2.0 is a major step forward. It adds a family of spiking neural network (SNN) online-learning algorithms, rewrites the eligibility-trace compiler around primitive-type dispatch, generalizes every ETP primitive to support multiple trainable inputs (fixing a silent bias-gradient drop), delivers substantial performance gains for D-RTRL and multi-step rollouts, and hardens the package with PEP 561 typing and a BPTT-oracle-backed test suite.

Highlights

New: SNN online-learning algorithms

Five SNN online-learning algorithms ship as flat ETraceVjpAlgorithm subclasses, all exported at the top level: EProp, OSTL (OSTLRecurrent / OSTLFeedforward), OTPE, OTTT, and OSTTP. A new _compute_learning_signal hook supports target-projection algorithms (OSTTP), and new trace helpers (PresynapticTrace, KappaFilter, FixedRandomFeedback) back them. The algorithms are cross-checked for regime equivalence and verified to decrease loss in integration smoke tests.
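The notes do not describe the internals of PresynapticTrace or KappaFilter, so the following is only a generic illustration of the kind of exponential low-pass filter that algorithms in this family commonly apply to spike trains. The function name `lowpass_trace` and the `decay` parameter are hypothetical and are not part of the BrainTrace API.

```python
def lowpass_trace(spikes, decay):
    """Return the running trace z_t = decay * z_{t-1} + s_t for a spike train.

    Generic sketch only; not BrainTrace's PresynapticTrace or KappaFilter.
    """
    trace = 0.0
    history = []
    for s in spikes:
        # Old activity decays geometrically; each new spike adds 1.
        trace = decay * trace + s
        history.append(trace)
    return history


spike_train = [1, 0, 0, 1, 0]
print(lowpass_trace(spike_train, decay=0.5))
# [1.0, 0.5, 0.25, 1.125, 0.5625]
```

The trace needs only its previous value at each step, which is what makes such filters usable in online learning: no spike history has to be stored.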

ETP compiler rewrite

The eligibility-trace compiler now dispatches on primitive-type identity rather than string-matching op/trace names, with structured, leveled diagnostics (DiagnosticKind, DiagnosticLevel, CompilationRecord) replacing ad-hoc warnings. New compile-time diagnostics surface previously silent issues — e.g. TRAINABLE_INVAR_NOT_PARAMSTATE flags a trainable input (such as a constant bias) that does not trace to a ParamState.
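To see why identity-based dispatch is more robust than name matching, consider this minimal sketch. The `Primitive` class, the rule strings, and both lookup functions are hypothetical stand-ins; they do not reproduce the BrainTrace compiler.

```python
class Primitive:
    """Hypothetical stand-in for the primitive object behind a traced op."""

    def __init__(self, name):
        self.name = name


matmul_p = Primitive("matmul")
# An unrelated user primitive whose name happens to contain "matmul".
custom_p = Primitive("my_matmul_like_op")


def rule_by_name(prim):
    # Fragile: any primitive with "matmul" in its name gets the matmul rule.
    return "matmul_rule" if "matmul" in prim.name else None


# Keyed by the primitive object itself; default hashing uses object identity.
RULES = {matmul_p: "matmul_rule"}


def rule_by_identity(prim):
    # Only the exact registered primitive matches.
    return RULES.get(prim)


print(rule_by_name(custom_p))      # matmul_rule (a false match)
print(rule_by_identity(custom_p))  # None
print(rule_by_identity(matmul_p))  # matmul_rule
```

With identity dispatch, an unregistered primitive produces an explicit miss, which a compiler can then report as a diagnostic instead of silently applying the wrong rule.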

Multi-trainable-input ETP primitives (bias gradients)

Every ETP primitive was generalized from a single-weight assumption to an arbitrary named dict of trainable inputs, fixing a silent bias-gradient drop and a LoRA executor signature mismatch in one coherent refactor. All built-in primitives (elemwise, dense mm/mv, conv, sparse mm/mv, lora) now have first-class bias gradient support, each verified element-wise against a BPTT oracle. Layout-aware conv axis handling (NHWC/NCHW, OIHW/HWIO) and non-square dense weight broadcasting were also fixed.
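The idea of keeping one eligibility trace per named trainable input can be sketched on a scalar recurrent unit. This is an illustration under stated assumptions, not BrainTrace code: the model, the function names, and the fixed recurrent gain `r` are hypothetical, and the check uses central finite differences as a stand-in for the BPTT oracle mentioned above.

```python
import math


def forward_with_traces(params, xs, ys, r=0.5):
    """Run h_t = tanh(w*x_t + b + r*h_{t-1}) with online (RTRL-style) gradients.

    `params` is a named dict of trainable inputs; one trace is kept per name.
    """
    h = 0.0
    traces = {name: 0.0 for name in params}  # dh_t / dparam
    grads = {name: 0.0 for name in params}
    loss = 0.0
    for x, y in zip(xs, ys):
        h_new = math.tanh(params["w"] * x + params["b"] + r * h)
        dtanh = 1.0 - h_new * h_new
        # Direct partial of the pre-activation w.r.t. each trainable input.
        direct = {"w": x, "b": 1.0}
        for name in params:
            traces[name] = dtanh * (direct[name] + r * traces[name])
            grads[name] += (h_new - y) * traces[name]
        loss += 0.5 * (h_new - y) ** 2
        h = h_new
    return loss, grads


def finite_difference(params, xs, ys, name, eps=1e-6):
    """Central-difference estimate of d(loss)/d(params[name])."""
    plus, minus = dict(params), dict(params)
    plus[name] += eps
    minus[name] -= eps
    loss_plus = forward_with_traces(plus, xs, ys)[0]
    loss_minus = forward_with_traces(minus, xs, ys)[0]
    return (loss_plus - loss_minus) / (2 * eps)


params = {"w": 0.8, "b": 0.1}
xs = [0.5, -1.0, 0.25, 2.0]
ys = [0.1, -0.2, 0.3, 0.0]
_, grads = forward_with_traces(params, xs, ys)
for name in params:
    assert abs(grads[name] - finite_difference(params, xs, ys, name)) < 1e-6
```

If the loop tracked only `"w"`, the bias would receive no gradient and nothing would fail, which is the kind of silent drop that a per-input check against an oracle is meant to catch.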

Performance

  • D-RTRL einsum fast path (fast_solve=True, default on): direct einsum kernels for mm/mv/elemwise replace nested vmap-of-vjp and per-step lax.cond overhead.
  • Reduced-precision trace storage (trace_dtype, e.g. bf16/fp16) halves the dominant B*N^2 trace bandwidth while keeping Jacobians, learning signals, and final gradients in fp32.
  • Multi-step trace fusion: the per-step eligibility-trace roll for exact algorithms is threaded into the forward scan, eliminating an O(T × Jacobian) HBM round-trip.
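The saving from reduced-precision trace storage is simple arithmetic, sketched below. The helper name and the example sizes are hypothetical, and a dense B x N x N trace layout is assumed purely for illustration.

```python
def trace_bytes(batch, n, bytes_per_element):
    """Bytes needed to store a dense B x N x N eligibility trace."""
    return batch * n * n * bytes_per_element


B, N = 32, 1024  # hypothetical batch size and hidden width
fp32 = trace_bytes(B, N, 4)  # 4 bytes per fp32 element
bf16 = trace_bytes(B, N, 2)  # 2 bytes per bf16/fp16 element

print(fp32 // 2**20, "MiB in fp32")  # 128 MiB in fp32
print(bf16 // 2**20, "MiB in bf16")  # 64 MiB in bf16
```

Because the trace is read and written every step, halving its element size halves the bytes moved per step as well as the bytes stored.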

Typing, testing & packaging

  • The package is now PEP 561 compliant (ships py.typed), with a pragmatic mypy config and CI type/packaging checks.
  • A BPTT gradient oracle and a layered correctness suite (P2–P8) cover per-operator rules, public-API contracts, exact-class element-wise equivalence, approximate-class direction alignment, transform/integration invariance, and per-cell compiler relation guardrails.
  • All public-API docstrings converted to NumPy-doc style with math, references, and runnable examples.
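For approximate algorithms, element-wise equality with the exact gradient is not expected, so a direction check is the natural test. The sketch below shows one plausible form of such a check using cosine similarity; the function name and the gradient values are made up for illustration and are not taken from the BrainTrace test suite.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two flattened gradient vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


exact_grad = [0.4, -1.2, 0.3]    # hypothetical oracle gradient
approx_grad = [0.5, -1.0, 0.1]   # hypothetical online approximation

alignment = cosine_similarity(exact_grad, approx_grad)
# A positive value means a step along the approximate gradient also
# decreases the loss to first order.
assert alignment > 0.0
```

An exact algorithm would instead be held to a tight element-wise tolerance against the oracle.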

Deprecations

The entire v0.1.x class-based operator/parameter API is deprecated in favor of the new primitive-based ETP user-API. The legacy classes still work (as thin shims that route through the new primitives) but each emits a DeprecationWarning on first use and will be removed in a future release.

Deprecated (v0.1.x)          | Use instead (v0.2.0)
-----------------------------|--------------------------------------------------------
MatMulOp                     | braintrace.matmul
ElemWiseOp                   | braintrace.element_wise
ConvOp                       | braintrace.conv
SpMatMulOp                   | braintrace.sparse_matmul
LoraOp                       | braintrace.lora_matmul
ETraceParam / ElemWiseParam  | brainstate.ParamState + an ETP primitive function
NonTempParam                 | brainstate.ParamState + plain JAX ops
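The "thin shim that warns on first use" pattern described above can be sketched with the standard `warnings` module. The names `LegacyMatMulOp` and `new_matmul` are hypothetical placeholders; the real BrainTrace shims and their signatures are not shown in these notes.

```python
import warnings


def new_matmul(x, w):
    """Hypothetical replacement function (a plain dot product here)."""
    return sum(a * b for a, b in zip(x, w))


class LegacyMatMulOp:
    """Hypothetical legacy class kept as a thin shim over `new_matmul`."""

    _warned = False  # class-level flag so the warning fires only once

    def __init__(self, w):
        cls = type(self)
        if not cls._warned:
            warnings.warn(
                "LegacyMatMulOp is deprecated; use new_matmul instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            cls._warned = True
        self.w = w

    def __call__(self, x):
        # Route through the new function so behavior stays identical.
        return new_matmul(x, self.w)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    op = LegacyMatMulOp([1.0, 2.0])
    LegacyMatMulOp([3.0, 4.0])  # second use: no additional warning
    result = op([10.0, 20.0])

print(result, len(caught))  # 50.0 1
```

When migrating, running the test suite with `python -W error::DeprecationWarning` turns each remaining legacy use into a failure that is easy to locate.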

Breaking changes

  1. OSTL factory removed — use OSTLRecurrent or OSTLFeedforward directly.
  2. OTTT and OTPE require an explicit leak — no longer inferred from model.states(); both reject hidden groups with num_state > 1 at compile time.
  3. Unit dependency change: brainunit is replaced by saiunit.
  4. ETPPrimitiveSpec removed — custom primitives register layout metadata via register_primitive keyword arguments (trainable_invars_fn, x_invar_index, y_outvar_index).

See the full changelog for the complete migration guide.