brainstate 0.5.0

chaoming0625 released this 13 Jun 17:10
1fa7964

A repository-wide correctness release. Following the brainstate.transform audit shipped in 0.4.x, this cycle extended the same expert-audit discipline to nearly every remaining module — random, graph, interop, nn, util, the vmap / pmap / shard_map mapping engine, and the exp_euler integrator — and closed out a single consolidated cross-module audit. Every fix ships with a behavioral regression test, and the suite is green across the full CI JAX matrix (0.7.0, 0.8.0, 0.9.0, latest). The release also lands a graph-layer performance pass.

No public APIs are removed or renamed. The only behavioral changes are previously-silent wrong-result or invalid-input paths that now fail loudly with descriptive errors.

Performance

  • Graph flatten/unflatten fast paths (#218): type-keyed value-classification cache backing the node predicates, encoder dispatch, and flattening kernel; exact-type decoder dispatch; all-static hashable pytrees collapse to a single StaticEdge; graph_to_tree reads States directly from the RefMap; shared States de-duplicated in iter_leaf / states.

Bug Fixes

brainstate.random (#211) — six distribution bugs, each contradicting its own docstring:

  • standard_t with an array df and size=None now infers the output shape from df (the shape(size) branch was dead code).
  • weibull_min now multiplies by scale (was dividing).
  • triangular reimplemented as the true triangular(left, mode, right, size) via inverse-CDF (was a Rademacher ±1 draw).
  • geometric now has support {1, 2, ...}, returns an integer dtype, and satisfies P(k == 1) == p (was off by one and returned floats).
  • randint_like default high uses u.math.max (no longer raises on >1-D templates).
  • chisquare uses the 2·Gamma(df/2) relation, valid for any positive real / array df.
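The relations named in the list above can be checked with a few lines of standard-library Python. This is a minimal sketch of the underlying math, not the brainstate.random code: the helper names are hypothetical, and the Weibull parameterization (a standard Weibull draw multiplied by scale) is an assumption based on the bullet.

```python
import math
import random


def triangular_icdf(u, left, mode, right):
    """Inverse CDF: map uniform u in [0, 1) to triangular(left, mode, right)."""
    span = right - left
    cut = (mode - left) / span  # value of the CDF at the mode
    if u < cut:
        return left + math.sqrt(u * span * (mode - left))
    return right - math.sqrt((1.0 - u) * span * (right - mode))


def geometric_icdf(u, p):
    """Inverse CDF: map uniform u in [0, 1) to an int on {1, 2, ...}, 0 < p < 1."""
    # Counts trials up to and including the first success, so P(k == 1) == p.
    return int(math.floor(math.log1p(-u) / math.log1p(-p))) + 1


def weibull_min_icdf(u, shape, scale):
    """Standard Weibull draw multiplied (not divided) by scale."""
    return scale * (-math.log1p(-u)) ** (1.0 / shape)


def chisquare_from_gamma(rng, df):
    """Chi-square(df) as 2 * Gamma(df / 2, scale=1); df is any positive real."""
    return 2.0 * rng.gammavariate(df / 2.0, 1.0)
```

For example, `geometric_icdf(u, 0.3)` returns 1 exactly when `u < 0.3`, which is the `P(k == 1) == p` property, and the Gamma form of the chi-square draw works for non-integer `df` such as 2.5.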

brainstate.graph (#212) — merge_context yields the live index dict; Node.check_valid_context derives validity from reachable States instead of raising AttributeError; pop_states detaches every shared alias of a popped state.

brainstate.interop (#213) — nnx Conv input_dilation guard; norm-channel extraction from framework metadata (fixes affine-less LayerNorm / RMSNorm / GroupNorm); bst_set_norm / bst_set_batchnorm None-handling; lookup_export no longer rebuilds an O(N) dict per call.

brainstate.nn (#215) — module-wide audit: dropout self-normalizing constants & unbatched mask dims; default softmax axis; ScaledWSLinear / AllToAll shapes; Precision / Recall weighted average; saturation-free, unit-safe bijective transforms; Delay / update_every / FixedNumConn correctness; numerous Module and collective-op fixes (including a vmap_new_states BatchTracer leak).
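As one illustration of the Precision / Recall item, a support-weighted average can be written out by hand. This sketch assumes "weighted" means weighting each class by its number of true instances, which is the common convention; the release notes do not spell out the exact semantics, and `weighted_precision` is a hypothetical helper, not a brainstate function.

```python
def weighted_precision(y_true, y_pred):
    """Per-class precision averaged with weights equal to class support."""
    total = len(y_true)
    score = 0.0
    for label in sorted(set(y_true)):
        predicted = sum(1 for p in y_pred if p == label)
        correct = sum(
            1 for t, p in zip(y_true, y_pred) if t == label and p == label
        )
        support = sum(1 for t in y_true if t == label)
        # A class that is never predicted contributes a precision of 0.
        precision = correct / predicted if predicted else 0.0
        score += (support / total) * precision
    return score
```

With `y_true = [0, 0, 0, 1]` and `y_pred = [0, 0, 1, 1]`, class 0 has precision 1.0 and support 3, class 1 has precision 0.5 and support 1, so the weighted result is 0.875, while an unweighted mean over classes would give 0.75.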

brainstate.transform mapping engine (#216) — eight vmap / pmap / shard_map bugs: warm/cold consistency for 'auto' RMW states, pmap2_new_states without RandomState, RMW-vs-scatter disambiguation, axis_size and 0-d validation, clearer shard_map spec errors, and static_argnums no longer mapping its argument.

brainstate.nn.exp_euler (#210) — corrected Jacobian unit conversion in the drift calculation.
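For context, a scalar exponential Euler step looks roughly like the sketch below. It is a generic textbook form written under assumptions, not the brainstate code: `exp_euler_step` and its arguments are hypothetical names, and units are not modeled. The comment marks where units matter: the product of the Jacobian and the time step has to be dimensionless before it enters the exponential.

```python
import math


def exp_euler_step(x, drift, jacobian, dt):
    """One exponential Euler step for dx/dt = drift(x), scalar case."""
    f = drift(x)
    j = jacobian(x)  # d(drift)/dx, with units of 1 / time
    z = j * dt       # must be dimensionless before exponentiation
    if abs(z) < 1e-12:
        return x + dt * f  # limit of expm1(z) / z as z -> 0 is 1
    return x + dt * (math.expm1(z) / z) * f
```

For a linear drift such as `-2.0 * x + 1.0` the step is exact: starting at `x = 0.0` with `dt = 0.5` it returns `0.5 * (1 - exp(-1))`, matching the analytic solution.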

Cross-module hardening (#217) — resolves every dev/issues.md finding; assert-based validation (stripped under python -O) is replaced with descriptive TypeError / ValueError across nn, random, transform, util, graph, interop, and the core.
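The motivation for dropping assert-based validation is that Python removes `assert` statements when run with `-O`, so the check silently disappears. A minimal sketch of the replacement pattern follows; `check_axis_size` is a hypothetical validator written for illustration, not a brainstate function.

```python
def check_axis_size(axis_size):
    """Validate with explicit exceptions, which survive `python -O`."""
    # bool is a subclass of int, so it is excluded explicitly.
    if isinstance(axis_size, bool) or not isinstance(axis_size, int):
        raise TypeError(
            f"axis_size must be an int, got {type(axis_size).__name__}"
        )
    if axis_size <= 0:
        raise ValueError(f"axis_size must be positive, got {axis_size}")
    return axis_size
```

The equivalent `assert axis_size > 0` would pass any value through under `-O`, whereas the explicit checks raise in every interpreter mode and say what was wrong.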

Quality

  • Full suite: 5296 passed, 23 skipped; mypy clean; patch coverage 100% (audit) / 98% (mapping engine).
  • Verified green on the CI JAX matrix: 0.7.0, 0.8.0, 0.9.0, latest.

Full Changelog: v0.4.2...v0.5.0