Version 0.4.0
This release modernizes brainstate.random with JAX typed PRNG keys and comprehensive physical-unit support, ships inline type information (PEP 561) gated by a mypy CI ratchet, adds a new brainstate.interop module for converting models to/from Flax NNX, Flax Linen, and Equinox, and expands brainstate.transform with several new state-aware transformations (vjp/jvp, shard_map, named_call, and the checkify runtime-check family).
Breaking Changes
- Renamed jit_named_scope to named_scope: the brainstate.transform.jit_named_scope decorator is now exported as brainstate.transform.named_scope. Update any usage accordingly.
- Removed brainstate.transform.sofo_grad: the second-order forward-mode (SOFO) gradient helper has moved to braintools. Replace brainstate.transform.sofo_grad(fn, ...) with the braintools.optim.SOFO optimizer (see examples/009_sofo_mnist.py for the updated usage).
- Removed brainstate.graph.NodeDef and brainstate.graph.NodeRef: the graph representation was reworked. A flattened graph is now described by brainstate.graph.NodeSpec together with the new edge types (NodeEdge, StateEdge, StateLeafEdge, PytreeEdge, StaticEdge, Static). Code that referenced NodeDef/NodeRef directly must migrate to these types; users of the high-level graph.flatten / graph.treefy_split / graph.treefy_merge API are unaffected.
Typed PRNG Keys in brainstate.random
brainstate.random now uses JAX's modern typed PRNG keys (jax.random.key,
dtype key<fry>, scalar shape ()) everywhere a key is produced, replacing the
legacy raw uint32[2] representation.
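A minimal JAX sketch of the representation change, using plain jax.random calls only (no brainstate API is assumed): a typed key is a scalar array with a key dtype, a split batch has shape (n,), and a legacy PRNGKey built from the same seed drives the same samples under JAX's default threefry implementation.

```python
import jax
import jax.numpy as jnp

# Modern typed key: a scalar array whose dtype is a PRNG key type.
typed = jax.random.key(42)
print(typed.shape)   # ()
print(typed.dtype)   # key<fry>

# Legacy raw key: two uint32 words.
legacy = jax.random.PRNGKey(42)
print(legacy.shape, legacy.dtype)  # (2,) uint32

# Splitting a typed key yields a batch of shape (n,), not (n, 2).
batch = jax.random.split(typed, 4)
print(batch.shape)   # (4,)

# Same seed, same stream: both representations produce identical samples.
same = jnp.array_equal(jax.random.normal(typed, (3,)),
                       jax.random.normal(legacy, (3,)))
```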
- get_key(), split_key(), split_keys(), self_assign_multi_keys(), and RandomState.value now return typed keys. A single key has shape () (was (2,)); a batch of n keys has shape (n,) (was (n, 2)). Code that asserted key.shape == (2,) or key.dtype == uint32, or that indexed the raw words of a key, must be updated.
- Key inputs accept three forms: an integer seed, a typed JAX key, or a legacy uint32[2] array (the last is auto-wrapped via jax.random.wrap_key_data). Passing an integer seed array of size 1 is also accepted. Invalid inputs now raise TypeError (previously ValueError in some paths).
- RandomState remains transform-compatible: typed keys pass cleanly through vmap (mapped over their leading axis), jit, and grad, and state-aware transformations that special-case RandomState continue to work unchanged.
- The module-level DEFAULT generator still constructs without triggering JAX backend initialization at import time: it holds a lazy uint32[2] placeholder that is materialized into a typed key (via wrap_key_data, preserving the exact seed) on first use.
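The accepted input forms and the lazy DEFAULT placeholder can be sketched in plain JAX as follows. as_typed_key and LazyDefaultKey are hypothetical names used only for illustration, not brainstate's internal implementation. The sketch is eager-only (it calls np.asarray on the input) and assumes JAX's default threefry key implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def as_typed_key(seed_or_key):
    """Hypothetical: normalize an int seed, a typed key, or uint32[2] data to a typed key."""
    # Typed JAX key: pass through unchanged.
    dtype = getattr(seed_or_key, "dtype", None)
    if dtype is not None and jnp.issubdtype(dtype, jax.dtypes.prng_key):
        return seed_or_key
    # Plain integer seed.
    if isinstance(seed_or_key, (int, np.integer)):
        return jax.random.key(int(seed_or_key))
    arr = np.asarray(seed_or_key)  # eager-only: needs a concrete value
    # Legacy raw key words: wrap them without changing the random stream.
    if arr.dtype == np.uint32 and arr.shape == (2,):
        return jax.random.wrap_key_data(jnp.asarray(arr))
    # Integer seed array holding a single value.
    if np.issubdtype(arr.dtype, np.integer) and arr.size == 1:
        return jax.random.key(int(arr.item()))
    raise TypeError(
        f"expected an int seed, a typed key, or uint32[2] data; got {arr.dtype}{arr.shape}"
    )


class LazyDefaultKey:
    """Hypothetical: hold raw seed words and build the typed key on first use."""

    def __init__(self, seed: int):
        # NumPy only: constructing this object does not touch a JAX backend.
        # [0, seed] matches jax.random.PRNGKey(seed) for 0 <= seed < 2**32.
        self._raw = np.array([0, seed], dtype=np.uint32)
        self._key = None

    @property
    def value(self):
        if self._key is None:  # materialize once, preserving the exact seed
            self._key = jax.random.wrap_key_data(jnp.asarray(self._raw))
        return self._key
```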
Migration: to recover the raw uint32[2] words from a typed key, use jax.random.key_data(key), or the new brainstate.random.get_key_data() for the current global key.
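A short migration sketch for code that previously inspected raw key words, using only jax.random.key_data and jax.random.wrap_key_data:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Old-style checks such as `key.shape == (2,)` or `key.dtype == jnp.uint32`
# no longer hold; check for a key dtype instead.
assert key.shape == ()
assert jnp.issubdtype(key.dtype, jax.dtypes.prng_key)

# Extract the raw uint32[2] words explicitly when legacy code needs them.
words = jax.random.key_data(key)
first_word = words[0]  # indexing raw words now goes through key_data

# Wrapping the words again gives an equivalent typed key.
restored = jax.random.wrap_key_data(words)
```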
New Features
Inline Type Information (PEP 561)
- py.typed marker added: brainstate now ships inline type information, so downstream projects' type checkers (mypy, pyright, etc.) pick up brainstate's annotations automatically.
- Typing correctness gate: a mypy configuration with a per-module "ratchet" enforces type correctness in CI, starting with brainstate.typing. Coverage expands module-by-module over time.
- All annotations are evaluated lazily (from __future__ import annotations), so they impose no import-time or runtime cost.
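To illustrate why lazy annotations are cheap, here is a small standalone sketch (not taken from brainstate): with the future import, annotations are stored as strings and are not evaluated when the function is defined, so a name imported only under typing.TYPE_CHECKING costs nothing at runtime. The future import must be the first statement in its module.

```python
from __future__ import annotations

import typing

if typing.TYPE_CHECKING:
    # Only static type checkers follow this import; it never runs at runtime.
    from fractions import Fraction


def halve(x: Fraction) -> Fraction:
    """Annotated with a name that is not bound at runtime."""
    return x / 2


# Annotations are kept as unevaluated strings, so defining halve did not need Fraction.
print(halve.__annotations__)  # {'x': 'Fraction', 'return': 'Fraction'}
```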
Physical Unit Support in brainstate.random
Random distributions are now comprehensively and strictly compatible with
brainunit physical units, with a consistent location–scale convention.
- Location/scale parameters carry the output unit: normal, laplace, logistic, gumbel, wald, and truncated_normal propagate the unit of their loc/scale (or mean/bounds) into the samples. When only one of loc/scale carries a unit, the plain value is interpreted in that same unit; a compatible-but-different unit (e.g. volt against mV) is converted, while an incompatible one raises UnitMismatchError.
- Scale-only distributions carry the scale unit: exponential, gamma, rayleigh, and weibull_min propagate the unit of their scale parameter.
- multivariate_normal carries the unit of mean (with cov required to be in units of mean squared).
- Shape / rate / count / probability parameters are strictly dimensionless: parameters such as df, a/b, lam, n, p, alpha, logits, kappa, concentration, and friends reject a dimensional Quantity with a clear ValueError. A genuinely dimensionless Quantity (e.g. 3.0 * u.UNITLESS) is accepted.
- No units → plain arrays: every distribution returns a plain array when given plain inputs, so existing unitless code is unaffected.
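The location–scale convention above can be sketched with a toy unit model. Q, UNITS, convert, require_dimensionless, normal, and gamma are hypothetical NumPy stand-ins, not brainunit or brainstate.random. In particular, the rule that loc's unit wins when both loc and scale carry compatible units is an assumption of this sketch.

```python
from dataclasses import dataclass

import numpy as np

# Hypothetical toy unit table: unit name -> (dimension, factor to the SI base unit).
UNITS = {
    "V": ("voltage", 1.0),
    "mV": ("voltage", 1e-3),
    "s": ("time", 1.0),
    "ms": ("time", 1e-3),
    "1": ("dimensionless", 1.0),
}


class UnitMismatchError(ValueError):
    """Raised when two units have different physical dimensions."""


@dataclass
class Q:
    """Toy stand-in for a Quantity: a mantissa plus a unit name."""
    value: np.ndarray
    unit: str


def convert(q: Q, target: str) -> np.ndarray:
    """Return q's mantissa expressed in `target`, or raise on a dimension mismatch."""
    dim, factor = UNITS[q.unit]
    tdim, tfactor = UNITS[target]
    if dim != tdim:
        raise UnitMismatchError(f"cannot convert {q.unit} to {target}")
    return np.asarray(q.value) * (factor / tfactor)


def require_dimensionless(name: str, x) -> np.ndarray:
    """Shape/rate/probability parameters must carry no physical dimension."""
    if isinstance(x, Q):
        if UNITS[x.unit][0] != "dimensionless":
            raise ValueError(f"{name} must be dimensionless, got unit {x.unit}")
        return np.asarray(x.value)
    return np.asarray(x)


def normal(loc, scale, size, rng):
    """Location-scale sampling: out = loc + scale * z with z ~ N(0, 1)."""
    units = [x.unit for x in (loc, scale) if isinstance(x, Q)]
    z = rng.standard_normal(size)
    if not units:
        return np.asarray(loc) + np.asarray(scale) * z  # no units: plain array
    out_unit = units[0]  # assumption: loc's unit wins when both carry one

    def mantissa(x):
        # Quantities are converted to out_unit; plain values are read as out_unit.
        return convert(x, out_unit) if isinstance(x, Q) else np.asarray(x)

    return Q(mantissa(loc) + mantissa(scale) * z, out_unit)


def gamma(a, scale, size, rng):
    """Scale-only: shape `a` must be dimensionless; samples carry scale's unit."""
    g = rng.standard_gamma(require_dimensionless("a", a), size)
    if isinstance(scale, Q):
        return Q(np.asarray(scale.value) * g, scale.unit)
    return np.asarray(scale) * g
```

Under these assumptions, normal(Q(0.0, "mV"), 2.0, (5,), rng) returns mV-valued samples, a volt-valued scale is rescaled into mV, and a ms-valued scale against a mV location raises UnitMismatchError.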
Raw Key Interop Helper
brainstate.random.get_key_data() returns the current global key as a raw uint32[2] array (via jax.random.key_data), for interfacing with code that still expects the legacy representation.
Framework Interoperability (brainstate.interop)
A new brainstate.interop module converts modules to and from other JAX
frameworks, with an extensible layer registry:
- Flax NNX: to_nnx / from_nnx.
- Flax Linen: to_linen / from_linen.
- Equinox: to_equinox / from_equinox.
- Registry: register_layer_mapping, supported_layers, LayerMapping.
- Typed errors: InteropError and its subclasses (MissingDependencyError, UnmappedLayerError, UnsupportedLayerError, UnsupportedStructureError, MissingShapeError, ConversionError).
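A generic sketch of how an extensible layer registry with typed errors can work. register, convert, SrcLinear, and DstDense are hypothetical, and the InteropError / UnmappedLayerError classes here are local stand-ins, not imports from brainstate.interop; the actual signature of register_layer_mapping is not shown in these notes and may differ.

```python
from typing import Callable, Dict, Type


class InteropError(Exception):
    """Local stand-in for a base conversion error."""


class UnmappedLayerError(InteropError):
    """Local stand-in: no converter is registered for a layer type."""


# Hypothetical registry: source layer type -> converter callable.
_REGISTRY: Dict[Type, Callable] = {}


def register(src_type: Type) -> Callable[[Callable], Callable]:
    """Decorator that registers a converter for `src_type` (illustrative only)."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[src_type] = fn
        return fn
    return decorator


def convert(layer):
    """Dispatch on the layer's type; subclasses reuse a parent's converter via the MRO."""
    for cls in type(layer).__mro__:
        if cls in _REGISTRY:
            return _REGISTRY[cls](layer)
    raise UnmappedLayerError(f"no mapping registered for {type(layer).__name__}")


# Toy source and target layers standing in for framework modules.
class SrcLinear:
    def __init__(self, w):
        self.w = w


class DstDense:
    def __init__(self, kernel):
        self.kernel = kernel


@register(SrcLinear)
def _linear_to_dense(layer: SrcLinear) -> DstDense:
    return DstDense(kernel=layer.w)
```

Catching the base class lets callers handle every conversion failure in one place while still inspecting the specific subclass.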
New Transformations
brainstate.transform gains several state-aware transformations:
- vjp / jvp: state-aware vector-Jacobian (reverse-mode) and Jacobian-vector (forward-mode) products, companions to grad.
- shard_map: a state-aware wrapper over jax.shard_map for SPMD sharding.
- named_call: attach a name to a sub-computation for clearer jaxprs and profiles.
- Runtime checks (checkify family): checkify, check, check_error, and the error-class selectors nan_checks, div_checks, index_checks, float_checks, user_checks, automatic_checks, all_checks.
- register_prim_handler: register custom primitive handlers for the IR/codegen pipeline.
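The checkify family names match those in jax.experimental.checkify, which PR #152 lists as the basis for the rebuilt NaN/Inf debugging. The sketch below uses the plain JAX API rather than brainstate's state-aware wrapper, whose signatures may differ: a user check and an automatic NaN check, each returning an error value instead of raising inside jit.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # User check: recorded in the returned error value rather than raised under jit.
    checkify.check(jnp.all(x > 0), "safe_log expects strictly positive inputs")
    return jnp.log(x)


# checkify makes the function return (error, output); jit wraps the checkified function.
checked_log = jax.jit(checkify.checkify(safe_log))

err, out = checked_log(jnp.array([1.0, 2.0]))
ok_message = err.get()    # None: every check passed

err, out = checked_log(jnp.array([1.0, -2.0]))
bad_message = err.get()   # a string describing the failed check


def log_only(x):
    return jnp.log(x)


# Automatic checks: nan_checks flags primitives that turn non-NaN inputs into NaN.
nan_checked = checkify.checkify(log_only, errors=checkify.nan_checks)
nan_err, _ = nan_checked(jnp.array(-1.0))
nan_message = nan_err.get()  # not None: log(-1.0) produced a NaN
```

Calling err.throw() instead of err.get() raises the recorded error on the host.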
Bug Fixes
- multivariate_normal now propagates physical units: previously the output unit was read after the mantissa had already been stripped from mean, so units were silently dropped. Samples now correctly carry the unit of mean.
- truncated_normal now accepts unit-carrying bounds with default loc/scale: the shared output unit is inferred from whichever of lower/upper/loc/scale carries one, and plain values are interpreted in that unit (previously a unit on the bounds with the default plain loc/scale raised UnitMismatchError).
- brainstate.transform.vjp now supports state-only differentiation: calling vjp(fun, grad_states=...) with no differentiable positional argument (e.g. a loss that closes over trainable parameters) previously raised IndexError. It now returns a pullback yielding just the state cotangents, matching brainstate.transform.grad semantics.
- brainstate.transform.vjp accepts argnums=None: as with grad, argnums=None disables positional-argument differentiation, so the pullback returns only state cotangents.
- Clearer vjp errors: out-of-range argnums now raises a descriptive ValueError instead of a bare IndexError, and supplying neither positional primals nor grad_states raises an explanatory ValueError.
- No jax.core.DropVar deprecation warning on import: the JAX compatibility layer now sources DropVar from jax.extend.core on JAX >= 0.10, removing a redundant deprecated import.
Known Issues
Known defects deferred to a future patch release (each has a skipped regression
test capturing the repro):
- nn.AdaptiveAvgPool2d/3d (and the Max variants) raise TypeError when a target dimension is None, despite documenting None as "do not pool this dimension".
- random.truncated_normal / nn.init.TruncatedNormal() crash when lower/upper are left at their None defaults.
- nn.weight_standardization raises when given a unit-carrying Quantity input.
- The nn collective-op vmap-call helpers can leak a JAX BatchTracer into newly created state values.
- nn delay unit retrieval can fail with a pytree-node mismatch (Quantity history vs Unit).
- nn event fixed-probability connectivity with efferent_target='pre' can crash (and, with afferent_ratio < 1, abort) inside the brainevent CSC path.
- State filtering with the documented {filter: axis} mapping form raises TypeError.
What's Changed
- Expand JAX compat to 0.10 and refactor version handling by @chaoming0625 in #140
- deps(deps): bump the production-dependencies group with 5 updates by @dependabot[bot] in #141
- deps(deps-dev): update braintools requirement from >=0.1.0 to >=0.1.8 in the development-dependencies group by @dependabot[bot] in #142
- deps(deps): bump appleboy/ssh-action from 1.2.0 to 1.2.5 by @dependabot[bot] in #143
- deps(deps): bump appleboy/scp-action from 0.1.7 to 1.0.0 by @dependabot[bot] in #144
- deps(deps): bump actions/checkout from 4 to 6 by @dependabot[bot] in #146
- deps(deps): update brainx-sphinx-header requirement from >=0.1.0 to >=0.3.0 in the production-dependencies group by @dependabot[bot] in #147
- deps(deps): bump actions/setup-python from 5 to 6 by @dependabot[bot] in #145
- deps(deps): update brainx-sphinx-header requirement from >=0.3.0 to >=0.4.0 in the production-dependencies group by @dependabot[bot] in #148
- deps(deps): bump the production-dependencies group with 2 updates by @dependabot[bot] in #150
- deps(deps): bump actions/download-artifact from 5 to 8 by @dependabot[bot] in #149
- Rebuild NaN/Inf debugging on jax.experimental.checkify by @chaoming0625 in #152
- refactor(transform): move SOFO to braintools.optim by @chaoming0625 in #153
- feat(transform): harden _ir* modules (robustness, correctness, coverage) by @chaoming0625 in #154
- Reimplement eval_shape on StatefulFunction for state consistency by @chaoming0625 in #155
- Unify vmap2/pmap2 and legacy vmap onto a shared state-aware engine by @chaoming0625 in #156
- Add PEP 561 py.typed marker and mypy type-checking gate by @chaoming0625 in #157
- test(transform): repoint _mapping1_test helper imports to _mapping_core by @chaoming0625 in #158
- test: shared test infrastructure (Phase 0) by @chaoming0625 in #159
- feat(transform): add state-aware vjp and jvp autodiff primitives by @chaoming0625 in #160
- feat(transform): add parallel-in-time associative_scan and linear_recurrence by @chaoming0625 in #161
- test(core): comprehensive State/hook/error/env/mixin/typing tests + environ deadlock fix by @chaoming0625 in #163
- Add state-aware shard_map transform by @chaoming0625 in #162
- Add state-aware custom_vjp and custom_jvp transforms by @chaoming0625 in #164
- Rewrite graph engine on a clean-room flat index-keyed IR by @chaoming0625 in #166
- Add state-aware pure_callback and io_callback transforms by @chaoming0625 in #165
- Add state-aware functional checkify transform with check/check_error by @chaoming0625 in #167
- Add lightweight state-transparent named_call transform by @chaoming0625 in #168
- Add comprehensive transform-subpackage tests (78%->98%) + fix two _ir_tocode bugs by @chaoming0625 in #169
- Add comprehensive graph-subpackage tests (77%->95%) + fix copy.deepcopy on nodes by @chaoming0625 in #170
- test(nn): comprehensive nn API tests (79%→97%) + init.param State fix by @chaoming0625 in #171
- test(random): comprehensive random tests (82%→99%) + 3 distribution/key bug fixes by @chaoming0625 in #172
- test(util): comprehensive util subpackage tests (89%→97%) by @chaoming0625 in #173
- test: guard version-sensitive tests for jax<0.10 CI compatibility by @chaoming0625 in #174
- test: widen jax<0.10 pmap/empty guards to cover construction by @chaoming0625 in #175
- ci: add test coverage badge (Codecov) by @chaoming0625 in #176
- refactor(transform): remove non-state-aware associative_scan and linear_recurrence by @chaoming0625 in #177
- refactor(transform): rename jit_named_scope to named_scope by @chaoming0625 in #178
- fix(transform): support state-only vjp + clearer argnums errors, with examples by @chaoming0625 in #180
- feat(random): typed PRNG keys + comprehensive physical-unit support by @chaoming0625 in #179
- fix(transform): trace collectives during shard_map state discovery by @chaoming0625 in #181
- docs(examples): add runnable shard_map, vjp, and jvp transform examples by @chaoming0625 in #182
- refactor(transform): move shard_map version resolution into _compatible_import by @chaoming0625 in #183
- refactor(transform): remove custom_vjp/custom_jvp APIs by @chaoming0625 in #184
- refactor(transform): remove pure_callback and io_callback wrappers by @chaoming0625 in #185
- feat(repr): clean up repr for graph.Node and nn.Module by @chaoming0625 in #186
- test(random): comprehensive physical-unit tests; fix uniform unit symmetry by @chaoming0625 in #187
- feat(transform): jax.vmap parity for vmap/vmap2 — collectives, pytree in_axes, mapped kwargs by @chaoming0625 in #188
- vmap2 examples by @chaoming0625 in #189
- feat(interop): convert models between brainstate.nn and flax.nnx / flax.linen / equinox by @chaoming0625 in #190
- docs(api): add docstrings and type annotations across public APIs by @chaoming0625 in #191
- feat(transform): add tqdm-free fallback for ProgressBar, make tqdm optional by @chaoming0625 in #192
- docs: reformat Google/Sphinx docstrings to NumPy style by @chaoming0625 in #193
- docs(api): reorganize API reference; split nn/transform, add interop & hooks by @chaoming0625 in #194
- worktree docs reorganization by @chaoming0625 in #195
- fix: stop emitting jax.core.DropVar deprecation warning by @chaoming0625 in #196
- release: prep 0.4.0 — fix publish version check, finalize changelog, pin deps by @chaoming0625 in #197
- docs: use braintools.init instead of brainstate.nn.init; drop init API page by @chaoming0625 in #198
Full Changelog: v0.3.0...v0.4.0