brainstate 0.4.2
A correctness-hardening patch release for `brainstate.transform`. A JAX-expert audit of the state-based transformation layer (`jit`, `grad` / `vector_grad` / `jacobian` / `hessian`, `cond` / `switch` / `ifelse`, the bounded and collecting loops, the state-aware mapping engine, `shard_map`, `checkify`, `named_scope`, and `checkpoint`) surfaced a family of stale-cache, tracer-leak, and silent-misbehavior bugs. This release fixes every reproduced issue and tightens argument validation so that previously silent wrong-result paths now fail loudly. The minimum supported JAX is raised to 0.7.0. Each fix ships with a regression test verified to fail before and pass after the change (#207, #208).
Bug Fixes
- Stale compiled trace after an out-of-band state change: when a captured `State`'s shape or dtype changes between calls, `StatefulFunction` no longer replays a stale cached jaxpr (which silently produced wrong results). A state-aval mismatch is now treated as a cache miss, triggering recompilation across `get_arg_cache_key`, `make_jaxpr`, and `__call__` (#207).
- `cond` / `switch` / `ifelse` with asymmetric branch state access: fixed a crash when a state is written in one branch but only read in others, and fixed a state-value misalignment between the merged trace order and each branch's own trace order in the multi-branch wrappers (#207).
- `bounded_while_loop` correctness: fixed wrong results caused by the checkpointed-scan counter bump leaking into user carries, by `max_steps=1` ignoring the loop condition, by missing per-lane masking under `vmap`, and by iteration-cap overshoot (#207).
- Tracer leaks on the failure path: `make_jaxpr`, the state-aware mapping engine, `shard_map`, `checkify`, `vmap_new_states`, `map`, and `eval_shape` now snapshot and restore original state values (including RNG backups) when the wrapped execution raises, so a failed trace no longer leaves dead tracers in global states. The mapping engine additionally detects a stale cached plan via a write-set watcher and rebuilds it once before failing (#207).
- States created inside a trace no longer leak a dead tracer: such a `State` is poisoned after tracing with an `_InvalidatedTraceValue` sentinel. Reading it raises a descriptive `TraceContextError`, and assigning a concrete value clears the poison (#207).
- Cached compilations no longer retain enclosing-trace tracers: original-value snapshots are replaced with their avals before a trace is cached, so `grad`-under-`jit` now passes `jax.checking_leaks()` (#207).
- `grad(..., debug_nan=True)`: fixed an `AttributeError` when the transformed callable is a `functools.partial` (which has no `__name__`); under an enclosing trace, the NaN flag is now routed through `lax.cond` plus an ordered callback instead of being concretized (which raised `TracerBoolConversionError` under `jit`) (#207).
- `hessian` block structure: results are now returned structured like `grad_states` rather than exposing internal id-keyed dictionaries (#207).
- Ahead-of-time `jit` paths (`eval_shape` / `lower` / `trace` / `compile`) no longer perform a spurious state writeback that marked read-only states as written in an enclosing trace (#207).
- `State`s passed via keyword arguments are no longer silently flattened: the in-`kwargs` state check now runs before abstractification in `get_arg_cache_key` (#207).
- `named_scope`: jit-compiled functions are now cached per static configuration; a `conda:false` trace-name typo in `cond`, an incorrect `ifelse` docstring example, and documentation for nonexistent `non_static_*` parameters were all corrected (#207).
- `NewStateCatcher.get_by_tag` now matches against the catcher's tag set instead of failing to find tagged states (#207).
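
To illustrate the stale-trace fix at the top of this list, here is a minimal pure-Python sketch of the idea of treating a state-aval mismatch as a cache miss. It is not brainstate's implementation: `ToyState`, `aval_of`, and `ToyStatefulFunction` are hypothetical names, and an "aval" is reduced to a `(shape, dtype)` pair.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Hashable, List, Tuple

Aval = Tuple[Tuple[int, ...], str]


@dataclass
class ToyState:
    """Hypothetical stand-in for a State; only shape and dtype matter here."""
    shape: Tuple[int, ...]
    dtype: str


def aval_of(state: ToyState) -> Aval:
    # Assumption: the abstract value is just (shape, dtype).
    return (tuple(state.shape), state.dtype)


class ToyStatefulFunction:
    """Caches one 'trace' per (argument key, avals of the captured states)."""

    def __init__(self, trace_fn: Callable[[List[ToyState]], object],
                 states: List[ToyState]):
        self.trace_fn = trace_fn
        self.states = list(states)
        self.trace_count = 0
        self._cache: Dict[Hashable, object] = {}

    def cache_key(self, arg_key: Hashable) -> Hashable:
        # Including state avals in the key means a changed shape or dtype
        # can never hit an entry traced for the old shape or dtype.
        return (arg_key, tuple(aval_of(s) for s in self.states))

    def __call__(self, arg_key: Hashable) -> object:
        key = self.cache_key(arg_key)
        if key not in self._cache:
            self.trace_count += 1  # cache miss: "retrace"
            self._cache[key] = self.trace_fn(self.states)
        return self._cache[key]


weight = ToyState(shape=(3,), dtype="float32")
f = ToyStatefulFunction(lambda states: [aval_of(s) for s in states], [weight])
f("x")
f("x")               # same avals: served from the cache
weight.shape = (4,)  # out-of-band change between calls
f("x")               # aval mismatch: retraced instead of replayed
```

In this toy version the key is recomputed on every call, which keeps the logic obvious at the cost of a small per-call overhead.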
Behavior Changes (stricter validation)
The following paths previously produced silently wrong results or accepted invalid input; they now raise descriptive errors:
- Writing a tracer into a pre-existing `State` outside a `brainstate` trace (for example under raw `jax.jit` / `vmap` / `grad` / `scan`) now raises a `TraceContextError` instead of silently storing the tracer. States created inside the current JAX trace remain exempt, since they die with that trace (#207).
- `grad` / `vector_grad` / `jacobian` / `hessian` reject negative and non-integer `argnums` up front instead of differentiating the wrong argument; `hessian` additionally rejects the `grad_states` + `argnums` combination (#207).
- `jit` aligns user-supplied `in_shardings` / `out_shardings` with the internally prepended state argument and rejects negative `static_argnums` / `donate_argnums`; `checkpoint` / `remat` likewise reject negative `static_argnums` (#207).
- Unhashable static arguments raise an actionable `TypeError` (#207).
- `checkpointed_scan` raises a clear `ValueError` for `length < 1` instead of a math-domain error, and `ProgressBar` frequency validation raises `ValueError` rather than failing an `assert` (#207).
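
As a rough picture of what "reject up front" means for `argnums`, the sketch below validates the value before any differentiation would happen. `validate_argnums` is a hypothetical helper, not a brainstate function, and the exception types and messages are assumptions; the release notes only say that a descriptive error is raised.

```python
from typing import Sequence, Tuple, Union


def validate_argnums(argnums: Union[int, Sequence[int]],
                     name: str = "argnums") -> Tuple[int, ...]:
    """Hypothetical helper: return argnums as a tuple of non-negative ints."""
    items = tuple(argnums) if isinstance(argnums, (tuple, list)) else (argnums,)
    checked = []
    for a in items:
        # bool is a subclass of int in Python, so it is excluded explicitly.
        if isinstance(a, bool) or not isinstance(a, int):
            raise TypeError(
                f"{name} must contain integers, got {a!r} of type {type(a).__name__}"
            )
        if a < 0:
            raise ValueError(f"{name} must be non-negative, got {a}")
        checked.append(a)
    return tuple(checked)


print(validate_argnums(0))       # (0,)
print(validate_argnums([0, 2]))  # (0, 2)
```

Failing at this point keeps the error next to the caller's mistake rather than letting a shifted or truncated index select a different argument later on.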
Build
- Minimum JAX raised to `>=0.7.0` (previously `>=0.6.0`) across all `pyproject.toml` extras (`cpu`, `cuda12`, `cuda13`, `tpu`, `testing`) and `requirements.txt` (#208).
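
Before upgrading, it may help to check whether the installed JAX already satisfies the new floor. The snippet below is a standard-library-only sketch; `parse_release`, `meets_minimum`, and `installed_jax_ok` are hypothetical helper names, and the parser deliberately ignores pre-release and local version suffixes.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

MIN_JAX = (0, 7, 0)  # floor stated in these release notes


def parse_release(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from a version string; patch defaults to 0."""
    m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if m is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = m.groups()
    return (int(major), int(minor), int(patch) if patch is not None else 0)


def meets_minimum(version: str, minimum: Tuple[int, int, int] = MIN_JAX) -> bool:
    # Tuples compare numerically, so "0.10.0" correctly ranks above "0.7.0".
    return parse_release(version) >= minimum


def installed_jax_ok() -> Optional[bool]:
    """Return True/False for the installed JAX, or None if JAX is not installed."""
    try:
        return meets_minimum(metadata.version("jax"))
    except metadata.PackageNotFoundError:
        return None


print(installed_jax_ok())
```

For anything beyond a quick check, a full version parser such as the one in the `packaging` library handles pre-releases and other edge cases that this sketch skips.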
Full Changelog: v0.4.1...v0.4.2