Version 0.4.0

@chaoming0625 chaoming0625 released this 01 Jun 12:10
· 25 commits to main since this release
4ecbb68

This release modernizes brainstate.random with JAX typed PRNG keys and comprehensive physical-unit support, and ships inline type information (PEP 561) gated by a mypy CI ratchet. It also adds a new brainstate.interop module for converting models to and from Flax NNX, Flax Linen, and Equinox, and expands brainstate.transform with several new state-aware transformations (vjp/jvp, shard_map, named_call, and the checkify runtime-check family).

Breaking Changes

  • Renamed jit_named_scope to named_scope: The brainstate.transform.jit_named_scope decorator is now exported as brainstate.transform.named_scope. Update any usage accordingly.
  • Removed brainstate.transform.sofo_grad: the second-order forward-mode (SOFO) gradient helper has moved to braintools. Replace brainstate.transform.sofo_grad(fn, ...) with the braintools.optim.SOFO optimizer (see examples/009_sofo_mnist.py for the updated usage).
  • Removed brainstate.graph.NodeDef and brainstate.graph.NodeRef: the graph representation was reworked. A flattened graph is now described by brainstate.graph.NodeSpec together with the new edge types (NodeEdge, StateEdge, StateLeafEdge, PytreeEdge, StaticEdge, Static). Code that referenced NodeDef/NodeRef directly must migrate to these types; users of the high-level graph.flatten / graph.treefy_split / graph.treefy_merge API are unaffected.
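For code that has to run against both the old and the new decorator name during a transition, one option is to resolve the name at import time. The sketch below is only an illustration: first_available is a hypothetical helper (not part of brainstate), and the two SimpleNamespace objects are stand-ins for brainstate.transform before and after this release so that the snippet runs without brainstate installed.

```python
from types import SimpleNamespace


def first_available(namespace, *names):
    """Return the first attribute of `namespace` that exists among `names`."""
    for name in names:
        if hasattr(namespace, name):
            return getattr(namespace, name)
    raise AttributeError(f"none of {names} found on {namespace!r}")


# Stand-ins for brainstate.transform (illustration only, not the real module).
old_api = SimpleNamespace(jit_named_scope=lambda name: f"old:{name}")
new_api = SimpleNamespace(named_scope=lambda name: f"new:{name}")

# Prefer the new name, fall back to the pre-0.4.0 name.
scope_old = first_available(old_api, "named_scope", "jit_named_scope")
scope_new = first_available(new_api, "named_scope", "jit_named_scope")
```

In real code the first argument would be the imported brainstate.transform module. Once support for older releases is dropped, a plain import of named_scope is simpler.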

Typed PRNG Keys in brainstate.random

brainstate.random now uses JAX's modern typed PRNG keys (jax.random.key,
dtype key<fry>, scalar shape ()) everywhere a key is produced, replacing the
legacy raw uint32[2] representation.

  • get_key(), split_key(), split_keys(), self_assign_multi_keys(), and RandomState.value now return typed keys. A single key has shape () (was (2,)); a batch of n keys has shape (n,) (was (n, 2)). Code that asserted key.shape == (2,) or key.dtype == uint32, or that indexed the raw words of a key, must be updated.
  • Key inputs accept three forms: an integer seed, a typed JAX key, or a legacy uint32[2] array (the last is auto-wrapped via jax.random.wrap_key_data). A size-1 integer array is also accepted as a seed. Invalid inputs now raise TypeError (previously ValueError in some paths).
  • RandomState remains transform-compatible: typed keys vmap/jit/grad cleanly over their leading axis, and state-aware transformations that special-case RandomState continue to work unchanged.
  • The module-level DEFAULT generator still constructs without triggering JAX backend initialization at import time: it holds a lazy uint32[2] placeholder that is materialized into a typed key (via wrap_key_data, preserving the exact seed) on first use.

Migration: to recover the raw uint32[2] words from a typed key, use the new
brainstate.random.get_key_data() or jax.random.key_data(key).
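The shape and dtype changes described above can be seen with plain JAX. This is a minimal sketch that uses only jax.random, assuming JAX's default threefry implementation (two uint32 words per key). It does not call brainstate, so it shows the key representation rather than the behavior of brainstate.random itself.

```python
import jax
import jax.numpy as jnp

# Typed key: scalar shape (), dtype key<fry> under the default implementation.
key = jax.random.key(42)

# Raw words behind the typed key: shape (2,), dtype uint32.
raw = jax.random.key_data(key)

# A legacy uint32[2] array can be wrapped back into a typed key.
rewrapped = jax.random.wrap_key_data(raw)

# A batch of n typed keys has shape (n,); its raw data has shape (n, 2).
batch = jax.random.split(key, 4)
batch_raw = jax.random.key_data(batch)


def is_typed_key(k):
    """True if `k` is a typed PRNG key array rather than raw uint32 words."""
    return bool(jnp.issubdtype(k.dtype, jax.dtypes.prng_key))


# Typed keys map over their leading axis like any other array.
per_key_samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(batch)
```

Checks written as key.shape == (2,) can usually be replaced with is_typed_key(key), or applied to jax.random.key_data(key) when the raw words are really needed.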

New Features

Inline Type Information (PEP 561)

  • py.typed marker added: brainstate now ships inline type information, so downstream projects' type checkers (mypy, pyright, etc.) pick up brainstate's annotations automatically.
  • Typing correctness gate: a mypy configuration with a per-module "ratchet" enforces type correctness in CI, starting with brainstate.typing. Coverage expands module-by-module over time.
  • All annotations are evaluated lazily (from __future__ import annotations), so they impose no import-time or runtime cost.
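PEP 561 marks a package as typed with an empty py.typed file in the package directory. The following sketch shows one way a downstream project could check for that marker without importing the package. ships_type_info is a hypothetical helper written for this illustration and relies only on the standard library.

```python
import importlib.util
from pathlib import Path


def ships_type_info(package: str) -> bool:
    """Return True if the top-level `package` contains a PEP 561 py.typed marker."""
    spec = importlib.util.find_spec(package)
    # Missing packages and plain modules have no package directory to inspect.
    if spec is None or not spec.submodule_search_locations:
        return False
    return any(
        (Path(location) / "py.typed").is_file()
        for location in spec.submodule_search_locations
    )
```

Calling ships_type_info("brainstate") in an environment with this release installed would be expected to return True. With an older release, or without brainstate installed, it returns False.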

Physical Unit Support in brainstate.random

Random distributions are now comprehensively and strictly compatible with
brainunit physical units, with a consistent location–scale convention.

  • Location/scale parameters carry the output unit: normal, laplace, logistic, gumbel, wald, and truncated_normal propagate the unit of their loc/scale (or mean/bounds) into the samples. When only one of loc/scale carries a unit, the plain value is interpreted in that same unit; a compatible-but-different unit (e.g. volt against mV) is converted, while an incompatible one raises UnitMismatchError.
  • Scale-only distributions carry the scale unit: exponential, gamma, rayleigh, and weibull_min propagate the unit of their scale parameter.
  • multivariate_normal carries the unit of mean (with cov required to be mean-unit squared).
  • Shape / rate / count / probability parameters are strictly dimensionless: parameters such as df, a/b, lam, n, p, alpha, logits, kappa, and concentration reject a dimensional Quantity with a clear ValueError. A genuinely dimensionless Quantity (e.g. 3.0 * u.UNITLESS) is accepted.
  • No units → plain arrays: every distribution returns a plain array when given plain inputs, so existing unitless code is unaffected.
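The reason location and scale parameters determine the output unit follows from how a location–scale family is built: a sample is loc + scale * z, where z is a dimensionless standard draw, so loc, scale, and the sample must all share one unit. The sketch below illustrates this with plain JAX and plain floats. It is not brainstate's implementation; location_scale_normal is a hypothetical helper, and the unit is tracked only in variable names and comments.

```python
import jax
import jax.numpy as jnp


def location_scale_normal(key, loc, scale, shape):
    """Draw normal samples as loc + scale * z, with z a standard normal."""
    z = jax.random.normal(key, shape)  # dimensionless
    return loc + scale * z             # same unit as loc and scale


key = jax.random.key(0)

# Both numbers are understood to be in millivolts.
loc_mV = -65.0
scale_mV = 2.0

# The samples are therefore in millivolts as well.
samples_mV = location_scale_normal(key, loc_mV, scale_mV, (10_000,))
```

If scale had been given in volts, it would have to be converted to millivolts before the addition. This is the conversion the release notes describe for compatible-but-different units.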

Raw Key Interop Helper

  • brainstate.random.get_key_data() returns the current global key as a raw uint32[2] array (via jax.random.key_data), for interfacing with code that still expects the legacy representation.

Framework Interoperability (brainstate.interop)

A new brainstate.interop module converts modules to and from other JAX
frameworks, with an extensible layer registry:

  • Flax NNX: to_nnx / from_nnx.
  • Flax Linen: to_linen / from_linen.
  • Equinox: to_equinox / from_equinox.
  • Registry: register_layer_mapping, supported_layers, LayerMapping.
  • Typed errors: InteropError and its subclasses (MissingDependencyError, UnmappedLayerError, UnsupportedLayerError, UnsupportedStructureError, MissingShapeError, ConversionError).

New Transformations

brainstate.transform gains several state-aware transformations:

  • vjp / jvp: state-aware reverse- and forward-mode differentiation products (companions to grad).
  • shard_map: a state-aware wrapper over jax.shard_map for SPMD sharding.
  • named_call: attach a name to a sub-computation for clearer jaxprs and profiles.
  • Runtime checks (checkify family): checkify, check, check_error, and the error-class selectors nan_checks, div_checks, index_checks, float_checks, user_checks, automatic_checks, all_checks.
  • register_prim_handler: register custom primitive handlers for the IR/codegen pipeline.
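The runtime-check names listed above match those of jax.experimental.checkify, on which this release rebuilds NaN/Inf debugging (#152). The sketch below uses the JAX API directly to show the general pattern of a functionalized check. Whether the brainstate.transform wrappers take exactly the same arguments is an assumption here, and safe_log is a hypothetical example function.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # User-defined runtime check; it is traced and survives jit.
    checkify.check(jnp.all(x > 0), "safe_log expects strictly positive input")
    return jnp.log(x)


# Functionalize the check: the wrapped function returns (error, output).
checked_log = checkify.checkify(safe_log, errors=checkify.user_checks)
jitted = jax.jit(checked_log)

err_ok, out_ok = jitted(jnp.array([1.0, 2.0]))
err_bad, out_bad = jitted(jnp.array([1.0, -2.0]))

ok_message = err_ok.get()    # None when no check failed
bad_message = err_bad.get()  # a string describing the failed check
```

Calling err_bad.throw() outside the jitted function would raise the stored error as a Python exception instead of returning a message.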

Bug Fixes

  • multivariate_normal now propagates physical units: previously the output unit was read after the mantissa had already been stripped from mean, so units were silently dropped. Samples now correctly carry the unit of mean.
  • truncated_normal now accepts unit-carrying bounds with default loc/scale: the shared output unit is inferred from whichever of lower/upper/loc/scale carries one, and plain values are interpreted in that unit (previously a unit on the bounds with the default plain loc/scale raised UnitMismatchError).
  • brainstate.transform.vjp now supports state-only differentiation: calling vjp(fun, grad_states=...) with no differentiable positional argument (e.g. a loss that closes over trainable parameters) previously raised IndexError. It now returns a pullback yielding just the state cotangents, matching brainstate.transform.grad semantics.
  • brainstate.transform.vjp accepts argnums=None: like grad, argnums=None disables positional-argument differentiation so the pullback returns only state cotangents.
  • Clearer vjp errors: out-of-range argnums now raises a descriptive ValueError instead of a bare IndexError, and supplying neither positional primals nor grad_states raises an explanatory ValueError.
  • No jax.core.DropVar deprecation warning on import: the JAX compatibility layer now sources DropVar from jax.extend.core on JAX >= 0.10, removing a redundant deprecated import.

Known Issues

Known defects deferred to a future patch release (each has a skipped regression
test capturing the repro):

  • nn.AdaptiveAvgPool2d/3d (and Max variants) raise TypeError when a target dimension is None, despite documenting None as "do not pool this dimension".
  • random.truncated_normal / nn.init.TruncatedNormal() crash when lower/upper are left at their None defaults.
  • nn.weight_standardization raises when given a unit-carrying Quantity input.
  • The nn collective-op vmap-call helpers can leak a JAX BatchTracer into newly created state values.
  • nn delay unit retrieval can fail with a pytree-node mismatch (Quantity history vs Unit).
  • nn event fixed-probability connectivity with efferent_target='pre' can crash (and, with afferent_ratio < 1, abort) inside the brainevent CSC path.
  • State filtering with the documented {filter: axis} mapping form raises TypeError.

What's Changed

  • Expand JAX compat to 0.10 and refactor version handling by @chaoming0625 in #140
  • deps(deps): bump the production-dependencies group with 5 updates by @dependabot[bot] in #141
  • deps(deps-dev): update braintools requirement from >=0.1.0 to >=0.1.8 in the development-dependencies group by @dependabot[bot] in #142
  • deps(deps): bump appleboy/ssh-action from 1.2.0 to 1.2.5 by @dependabot[bot] in #143
  • deps(deps): bump appleboy/scp-action from 0.1.7 to 1.0.0 by @dependabot[bot] in #144
  • deps(deps): bump actions/checkout from 4 to 6 by @dependabot[bot] in #146
  • deps(deps): update brainx-sphinx-header requirement from >=0.1.0 to >=0.3.0 in the production-dependencies group by @dependabot[bot] in #147
  • deps(deps): bump actions/setup-python from 5 to 6 by @dependabot[bot] in #145
  • deps(deps): update brainx-sphinx-header requirement from >=0.3.0 to >=0.4.0 in the production-dependencies group by @dependabot[bot] in #148
  • deps(deps): bump the production-dependencies group with 2 updates by @dependabot[bot] in #150
  • deps(deps): bump actions/download-artifact from 5 to 8 by @dependabot[bot] in #149
  • Rebuild NaN/Inf debugging on jax.experimental.checkify by @chaoming0625 in #152
  • refactor(transform): move SOFO to braintools.optim by @chaoming0625 in #153
  • feat(transform): harden _ir* modules (robustness, correctness, coverage) by @chaoming0625 in #154
  • Reimplement eval_shape on StatefulFunction for state consistency by @chaoming0625 in #155
  • Unify vmap2/pmap2 and legacy vmap onto a shared state-aware engine by @chaoming0625 in #156
  • Add PEP 561 py.typed marker and mypy type-checking gate by @chaoming0625 in #157
  • test(transform): repoint _mapping1_test helper imports to _mapping_core by @chaoming0625 in #158
  • test: shared test infrastructure (Phase 0) by @chaoming0625 in #159
  • feat(transform): add state-aware vjp and jvp autodiff primitives by @chaoming0625 in #160
  • feat(transform): add parallel-in-time associative_scan and linear_recurrence by @chaoming0625 in #161
  • test(core): comprehensive State/hook/error/env/mixin/typing tests + environ deadlock fix by @chaoming0625 in #163
  • Add state-aware shard_map transform by @chaoming0625 in #162
  • Add state-aware custom_vjp and custom_jvp transforms by @chaoming0625 in #164
  • Rewrite graph engine on a clean-room flat index-keyed IR by @chaoming0625 in #166
  • Add state-aware pure_callback and io_callback transforms by @chaoming0625 in #165
  • Add state-aware functional checkify transform with check/check_error by @chaoming0625 in #167
  • Add lightweight state-transparent named_call transform by @chaoming0625 in #168
  • Add comprehensive transform-subpackage tests (78%->98%) + fix two _ir_tocode bugs by @chaoming0625 in #169
  • Add comprehensive graph-subpackage tests (77%->95%) + fix copy.deepcopy on nodes by @chaoming0625 in #170
  • test(nn): comprehensive nn API tests (79%→97%) + init.param State fix by @chaoming0625 in #171
  • test(random): comprehensive random tests (82%→99%) + 3 distribution/key bug fixes by @chaoming0625 in #172
  • test(util): comprehensive util subpackage tests (89%→97%) by @chaoming0625 in #173
  • test: guard version-sensitive tests for jax<0.10 CI compatibility by @chaoming0625 in #174
  • test: widen jax<0.10 pmap/empty guards to cover construction by @chaoming0625 in #175
  • ci: add test coverage badge (Codecov) by @chaoming0625 in #176
  • refactor(transform): remove non-state-aware associative_scan and linear_recurrence by @chaoming0625 in #177
  • refactor(transform): rename jit_named_scope to named_scope by @chaoming0625 in #178
  • fix(transform): support state-only vjp + clearer argnums errors, with examples by @chaoming0625 in #180
  • feat(random): typed PRNG keys + comprehensive physical-unit support by @chaoming0625 in #179
  • fix(transform): trace collectives during shard_map state discovery by @chaoming0625 in #181
  • docs(examples): add runnable shard_map, vjp, and jvp transform examples by @chaoming0625 in #182
  • refactor(transform): move shard_map version resolution into _compatible_import by @chaoming0625 in #183
  • refactor(transform): remove custom_vjp/custom_jvp APIs by @chaoming0625 in #184
  • refactor(transform): remove pure_callback and io_callback wrappers by @chaoming0625 in #185
  • feat(repr): clean up repr for graph.Node and nn.Module by @chaoming0625 in #186
  • test(random): comprehensive physical-unit tests; fix uniform unit symmetry by @chaoming0625 in #187
  • feat(transform): jax.vmap parity for vmap/vmap2 — collectives, pytree in_axes, mapped kwargs by @chaoming0625 in #188
  • vmap2 examples by @chaoming0625 in #189
  • feat(interop): convert models between brainstate.nn and flax.nnx / flax.linen / equinox by @chaoming0625 in #190
  • docs(api): add docstrings and type annotations across public APIs by @chaoming0625 in #191
  • feat(transform): add tqdm-free fallback for ProgressBar, make tqdm optional by @chaoming0625 in #192
  • docs: reformat Google/Sphinx docstrings to NumPy style by @chaoming0625 in #193
  • docs(api): reorganize API reference; split nn/transform, add interop & hooks by @chaoming0625 in #194
  • worktree docs reorganization by @chaoming0625 in #195
  • fix: stop emitting jax.core.DropVar deprecation warning by @chaoming0625 in #196
  • release: prep 0.4.0 — fix publish version check, finalize changelog, pin deps by @chaoming0625 in #197
  • docs: use braintools.init instead of brainstate.nn.init; drop init API page by @chaoming0625 in #198

Full Changelog: v0.3.0...v0.4.0