A repository-wide correctness release. Following the brainstate.transform audit shipped in 0.4.x, this cycle extended the same expert-audit discipline to nearly every remaining module — random, graph, interop, nn, util, the vmap / pmap / shard_map mapping engine, and the exp_euler integrator — and closed out a single consolidated cross-module audit. Every fix ships with a behavioral regression test, and the suite is green across the full CI JAX matrix (0.7.0, 0.8.0, 0.9.0, latest). The release also lands a graph-layer performance pass.
No public APIs are removed or renamed. The only behavioral changes are previously-silent wrong-result or invalid-input paths that now fail loudly with descriptive errors.
Performance
- Graph flatten/unflatten fast paths (#218): a type-keyed value-classification cache backing the node predicates, encoder dispatch, and flattening kernel; exact-type decoder dispatch; all-static hashable pytrees collapse to a single `StaticEdge`; `graph_to_tree` reads `State`s directly from the `RefMap`; shared `State`s are de-duplicated in `iter_leaf`/`states`.
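Two of the ideas behind these fast paths can be sketched in plain Python. The first is a classification cache keyed on a value's exact type, so the expensive check runs once per type rather than once per value. The second is de-duplication of shared states by object identity. The `State`/`Node` classes, `classify_type`, `TypeKeyedCache`, and `dedup_by_identity` are hypothetical stand-ins, not brainstate's internals.

```python
from typing import Any, Callable, Dict, Iterable, List


class State:  # hypothetical stand-in for a stateful leaf
    def __init__(self, value: Any) -> None:
        self.value = value


class Node:  # hypothetical stand-in for a graph node
    pass


def classify_type(t: type) -> str:
    # The "expensive" classification: issubclass walks the MRO.
    if issubclass(t, State):
        return "state"
    if issubclass(t, Node):
        return "node"
    return "leaf"


class TypeKeyedCache:
    """Memoizes a type -> category function, keyed on the exact type."""

    def __init__(self, fn: Callable[[type], str]) -> None:
        self._fn = fn
        self._cache: Dict[type, str] = {}
        self.misses = 0

    def __call__(self, value: Any) -> str:
        t = type(value)
        kind = self._cache.get(t)
        if kind is None:  # first value of this exact type
            self.misses += 1
            kind = self._cache[t] = self._fn(t)
        return kind


def dedup_by_identity(states: Iterable[State]) -> List[State]:
    # A state reachable through several aliases is kept once, in first-seen order.
    seen = set()
    out = []
    for s in states:
        if id(s) not in seen:
            seen.add(id(s))
            out.append(s)
    return out


classify = TypeKeyedCache(classify_type)
shared = State(1.0)
values = [shared, Node(), 3, shared, State(2.0), 4]
kinds = [classify(v) for v in values]
unique = dedup_by_identity(v for v in values if classify(v) == "state")
```

Keying on the exact type trades a small dictionary for skipping repeated `issubclass` walks; a subclass simply gets its own cache entry on first sight.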
Bug Fixes
`brainstate.random` (#211): six distribution bugs, each of which contradicted its own docstring:
- `standard_t` with an array `df` and `size=None` (previously a dead `shape(size)` branch) now infers its shape from `df`.
- `weibull_min` now multiplies by `scale` (it was dividing).
- `triangular` is reimplemented as a true `triangular(left, mode, right, size)` via inverse-CDF sampling (it was a Rademacher ±1 draw).
- `geometric` now has support `{1, 2, ...}` with an integer dtype and `P(k == 1) == p` (it was off by one and returned floats).
- `randint_like`'s default `high` uses `u.math.max`, so it no longer raises on templates with more than one dimension.
- `chisquare` uses the `2·Gamma(df/2)` relation, valid for any positive real or array `df`.
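The corrected behaviors correspond to textbook sampling identities, sketched below with Python's standard `random` and `math` modules rather than brainstate's JAX implementation. The function names are hypothetical, and the sketch assumes `left < right`, `0 < p <= 1`, and a uniform draw `u` in [0, 1). Passing `rng.random()` through each inverse CDF yields one sample.

```python
import math
import random


def triangular_icdf(u: float, left: float, mode: float, right: float) -> float:
    # Inverse CDF of the triangular distribution on [left, right] peaking at mode.
    width = right - left
    c = (mode - left) / width  # CDF value at the mode
    if u < c:
        return left + math.sqrt(u * width * (mode - left))
    return right - math.sqrt((1.0 - u) * width * (right - mode))


def geometric_icdf(u: float, p: float) -> int:
    # Trials up to and including the first success: integer support {1, 2, ...},
    # with P(k == 1) == P(u < p) == p.
    if p == 1.0:
        return 1
    return math.floor(math.log1p(-u) / math.log1p(-p)) + 1


def weibull_min_icdf(u: float, c: float, scale: float) -> float:
    # Standard Weibull draw with shape c, then multiplied (not divided) by scale.
    return scale * (-math.log1p(-u)) ** (1.0 / c)


def chisquare(rng: random.Random, df: float) -> float:
    # chi2(df) == 2 * Gamma(shape=df/2, scale=1), valid for any real df > 0.
    return 2.0 * rng.gammavariate(df / 2.0, 1.0)
```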
`brainstate.graph` (#212): `merge_context` yields the live index dict; `Node.check_valid_context` derives validity from reachable `State`s instead of raising `AttributeError`; `pop_states` detaches every shared alias of a popped state.
`brainstate.interop` (#213): nnx `Conv` `input_dilation` guard; norm-channel extraction from framework metadata (fixes affine-less `LayerNorm` / `RMSNorm` / `GroupNorm`); `None` handling in `bst_set_norm` / `bst_set_batchnorm`; `lookup_export` no longer rebuilds an O(N) dict per call.
`brainstate.nn` (#215): module-wide audit covering dropout self-normalizing constants and unbatched mask dims; the default softmax axis; `ScaledWSLinear` / `AllToAll` shapes; `Precision` / `Recall` weighted average; saturation-free, unit-safe bijective transforms; `Delay` / `update_every` / `FixedNumConn` correctness; and numerous `Module` and collective-op fixes (including a `vmap_new_states` `BatchTracer` leak).
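Assuming the "self-normalizing constants" refer to alpha dropout (the dropout variant paired with SELU activations), the affine correction it relies on can be checked analytically. This is a plain-Python sketch of the textbook formulas, not brainstate's dropout code; `alpha_dropout_affine` and `output_moments` are hypothetical names.

```python
import math  # noqa: F401  (kept for users extending the sketch)

# SELU constants from Klambauer et al. (2017), "Self-Normalizing Neural Networks".
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805
# Dropped units are set to SELU's negative saturation value, not to zero.
ALPHA_PRIME = -SELU_SCALE * SELU_ALPHA


def alpha_dropout_affine(rate: float) -> tuple[float, float]:
    """Return (a, b) such that a * z + b keeps mean 0 and variance 1."""
    keep = 1.0 - rate
    a = (keep + ALPHA_PRIME ** 2 * keep * rate) ** -0.5
    b = -a * rate * ALPHA_PRIME
    return a, b


def output_moments(rate: float) -> tuple[float, float]:
    # Analytic mean/variance of a * z + b for an input x with mean 0, variance 1,
    # where z = x with probability (1 - rate) and z = ALPHA_PRIME otherwise.
    keep = 1.0 - rate
    mean_z = rate * ALPHA_PRIME
    var_z = keep + ALPHA_PRIME ** 2 * rate * keep
    a, b = alpha_dropout_affine(rate)
    return a * mean_z + b, a * a * var_z
```

If either constant is wrong, the output drifts away from zero mean and unit variance, which undermines the self-normalizing property the layer exists to preserve.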
`brainstate.transform` mapping engine (#216): eight `vmap` / `pmap` / `shard_map` bugs, including warm/cold consistency for `'auto'` RMW states, `pmap2_new_states` without `RandomState`, RMW-vs-scatter disambiguation, `axis_size` and 0-d validation, clearer `shard_map` spec errors, and `static_argnums` no longer mapping its argument.
`brainstate.nn.exp_euler` (#210): corrected the Jacobian unit conversion in the drift calculation.
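The exponential Euler step for a scalar ODE dx/dt = f(x) can be sketched assuming the standard form x_{n+1} = x_n + dt·φ1(dt·J)·f(x_n), with J = ∂f/∂x and φ1(z) = (e^z − 1)/z. It shows where units enter: J carries units of 1/time, so dt·J must be dimensionless before it is exponentiated. This is plain Python with hypothetical names (`phi1`, `exp_euler_step`), not brainstate's code, which operates on unit-carrying arrays and may organize the drift computation differently.

```python
import math
from typing import Callable


def phi1(z: float) -> float:
    # phi1(z) = (e^z - 1) / z, with limit 1 at z = 0; expm1 keeps small z accurate.
    return 1.0 if z == 0.0 else math.expm1(z) / z


def exp_euler_step(f: Callable[[float], float],
                   dfdx: Callable[[float], float],
                   x: float, dt: float) -> float:
    # J has units of 1/time, so dt * J is dimensionless and safe to exponentiate;
    # dt * f(x) has the units of x, so the update is unit-consistent.
    J = dfdx(x)
    return x + dt * phi1(dt * J) * f(x)


# Leaky integrator dx/dt = (I - x) / tau, with all times in ms.
tau, I, dt = 10.0, 2.0, 0.1


def drift(x: float) -> float:
    return (I - x) / tau


def drift_jacobian(x: float) -> float:
    return -1.0 / tau


x = 0.0
for _ in range(100):  # integrate to t = 10 ms
    x = exp_euler_step(drift, drift_jacobian, x, dt)

exact = I * (1.0 - math.exp(-10.0 / tau))
```

For a linear drift such as this leaky integrator, the exponential Euler step is exact, so `x` matches the closed-form solution to floating-point precision for any `dt`.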
Cross-module hardening (#217): resolves every `dev/issues.md` finding; `assert`-based validation (stripped under `python -O`) is replaced with descriptive `TypeError` / `ValueError` exceptions across `nn`, `random`, `transform`, `util`, `graph`, `interop`, and the core.
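The reason `assert` is unsuitable for input validation fits in a short sketch: Python removes `assert` statements when run with `-O`, whereas an explicit `raise` always executes. The helper names below are hypothetical and only illustrate the pattern described above.

```python
def check_positive(name: str, value: float) -> float:
    # Explicit checks still run under `python -O`, unlike `assert`.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError(f"{name} must be a real number, got {type(value).__name__}")
    if not value > 0:  # written this way so NaN is rejected too
        raise ValueError(f"{name} must be positive, got {value!r}")
    return value


def check_positive_with_assert(value: float) -> float:
    # Anti-pattern: under `python -O` this assert is compiled away entirely,
    # so invalid input would flow through silently.
    assert value > 0, "value must be positive"
    return value
```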
Quality
- Full suite: 5296 passed, 23 skipped; `mypy` clean; patch coverage 100% (audit) / 98% (mapping engine).
- Verified green on the CI JAX matrix: 0.7.0, 0.8.0, 0.9.0, latest.
Full Changelog: v0.4.2...v0.5.0