
brainstate 0.4.2


chaoming0625 released this 10 Jun 08:54
64c40a3

A correctness-hardening patch release for brainstate.transform. A JAX-expert audit of the state-based transformation layer (jit, grad / vector_grad / jacobian / hessian, cond / switch / ifelse, the bounded and collecting loops, the state-aware mapping engine, shard_map, checkify, named_scope, and checkpoint) surfaced a family of stale-cache, tracer-leak, and silent-misbehavior bugs. This release fixes every reproduced issue and tightens argument validation, so code paths that previously returned silently wrong results now fail loudly. The minimum supported JAX version is raised to 0.7.0. Each fix ships with a regression test verified to fail before the change and pass after it (#207, #208).

Bug Fixes

  • Stale compiled trace after an out-of-band state change: when a captured State's shape or dtype changes between calls, StatefulFunction no longer replays a stale cached jaxpr (which silently produced wrong results). A state-aval mismatch is now treated as a cache miss, triggering recompilation across get_arg_cache_key, make_jaxpr, and __call__ (#207).
  • cond / switch / ifelse with asymmetric branch state access: fixed a crash when a state is written in one branch but only read in others, and fixed a state-value misalignment between the merged trace order and each branch's own trace order in the multi-branch wrappers (#207).
  • bounded_while_loop correctness: fixed wrong results caused by the checkpointed-scan counter bump leaking into user carries, by max_steps=1 ignoring the loop condition, by missing per-lane masking under vmap, and by iteration-cap overshoot (#207).
  • Tracer leaks on the failure path: make_jaxpr, the state-aware mapping engine, shard_map, checkify, vmap_new_states, map, and eval_shape now snapshot and restore original state values (including RNG backups) when the wrapped execution raises, so a failed trace no longer leaves dead tracers in global states. The mapping engine additionally detects a stale cached plan via a write-set watcher and rebuilds it once before failing (#207).
  • States created inside a trace no longer leak a dead tracer: such a State is poisoned after tracing with an _InvalidatedTraceValue sentinel — reading it raises a descriptive TraceContextError, and assigning a concrete value clears the poison (#207).
  • Cached compilations no longer retain enclosing-trace tracers: original-value snapshots are replaced with their avals before a trace is cached, so grad-under-jit now passes jax.checking_leaks() (#207).
  • grad(..., debug_nan=True): fixed an AttributeError when the transformed callable is a functools.partial (which has no __name__); under an enclosing trace, the NaN flag is now routed through lax.cond plus an ordered callback instead of being concretized (which raised TracerBoolConversionError under jit) (#207).
  • hessian block structure: results now mirror the structure of grad_states instead of exposing internal id-keyed dictionaries (#207).
  • Ahead-of-time jit paths (eval_shape / lower / trace / compile) no longer perform a spurious state writeback that marked read-only states as written in an enclosing trace (#207).
  • States passed via keyword arguments are no longer silently flattened: the in-kwargs state check now runs before abstractification in get_arg_cache_key (#207).
  • named_scope and documentation: named_scope now caches its jit-compiled functions per static configuration. Also corrected: the conda:false trace-name typo in cond, an incorrect ifelse docstring example, and documentation for non_static_* parameters that do not exist (#207).
  • NewStateCatcher.get_by_tag now matches against the catcher's tag set instead of failing to find tagged states (#207).
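The stale-cache fix amounts to including every captured state's abstract value (shape and dtype) in the cache key, so an out-of-band change becomes a cache miss instead of a replay. The sketch below is a hypothetical, framework-free model of that idea using NumPy arrays as stand-ins for state values; `ToyState`, `aval_of`, and `CachingFunction` are invented names and are not brainstate's StatefulFunction internals.

```python
import numpy as np


class ToyState:
    """Hypothetical stand-in for a State: a mutable box around an array."""

    def __init__(self, value):
        self.value = np.asarray(value)


def aval_of(x):
    """Abstract value of an array: the (shape, dtype) a trace depends on."""
    x = np.asarray(x)
    return (x.shape, str(x.dtype))


class CachingFunction:
    """Caches one 'compiled' entry per (argument avals, state avals) key."""

    def __init__(self, fn, states):
        self.fn = fn
        self.states = states
        self.cache = {}
        self.compile_count = 0

    def cache_key(self, *args):
        arg_avals = tuple(aval_of(a) for a in args)
        # Including captured-state avals turns a shape/dtype change into a miss.
        state_avals = tuple(aval_of(s.value) for s in self.states)
        return arg_avals, state_avals

    def __call__(self, *args):
        key = self.cache_key(*args)
        if key not in self.cache:
            self.compile_count += 1  # stands in for tracing + compilation
            self.cache[key] = self.fn
        return self.cache[key](*args)


w = ToyState(np.ones(3))
f = CachingFunction(lambda x: w.value * x, [w])

f(np.float32(2.0))
f(np.float32(3.0))          # same avals: cache hit
w.value = np.ones((2, 3))   # out-of-band shape change
f(np.float32(2.0))          # state aval changed: cache miss, "recompiles"
print(f.compile_count)      # 2
```

Keying on avals rather than on concrete values mirrors how JAX itself decides whether a trace can be reused: values may change freely, but any shape or dtype change must invalidate the trace.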
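The failure-path fixes follow a snapshot-and-restore pattern: record each state's value before running the wrapped function, and write those values back if it raises. Here is a minimal, hypothetical pure-Python sketch of that pattern; `Box` and `run_with_restore` are illustrative names, not brainstate APIs, and a plain string stands in for what would be a dead tracer in a real trace.

```python
class Box:
    """Hypothetical mutable state container (stand-in for a State)."""

    def __init__(self, value):
        self.value = value


def run_with_restore(fn, states, *args):
    """Run fn; if it raises, restore every state to its pre-call value."""
    snapshot = [s.value for s in states]
    try:
        return fn(*args)
    except BaseException:
        # Undo partial writes so no state keeps a value from the failed run.
        for s, v in zip(states, snapshot):
            s.value = v
        raise


counter = Box(0)
rng_key = Box(42)  # plays the role of an RNG state that also needs a backup


def faulty_step(x):
    counter.value = "dead-tracer-placeholder"  # partial write before failing
    rng_key.value = 43
    raise ValueError("trace failed")


try:
    run_with_restore(faulty_step, [counter, rng_key], 1.0)
except ValueError:
    pass

print(counter.value, rng_key.value)  # 0 42
```

The exception is re-raised after restoration, so the caller still sees the original error while global state is left exactly as it was before the call.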
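The poisoning behavior for states created inside a trace can be modeled with a sentinel that makes reads fail and is cleared by any assignment. The code below is a hypothetical pure-Python model only: `_Invalidated`, `PoisonableState`, and the local `TraceContextError` class are illustrative stand-ins and not brainstate's actual `_InvalidatedTraceValue` or `TraceContextError`.

```python
class TraceContextError(RuntimeError):
    """Local stand-in for brainstate's TraceContextError."""


class _Invalidated:
    """Sentinel marking a value whose creating trace has ended."""

    def __init__(self, name):
        self.name = name


class PoisonableState:
    def __init__(self, value, name="state"):
        self._value = value
        self.name = name

    def poison(self):
        # Would be called when the trace that created this state finishes.
        self._value = _Invalidated(self.name)

    @property
    def value(self):
        if isinstance(self._value, _Invalidated):
            raise TraceContextError(
                f"{self.name} was created inside a trace that has ended; "
                "assign a concrete value before reading it."
            )
        return self._value

    @value.setter
    def value(self, new_value):
        # Assigning any concrete value replaces the sentinel, clearing the poison.
        self._value = new_value


h = PoisonableState(1.0, name="h")
h.poison()
try:
    h.value
except TraceContextError as err:
    print(err)
h.value = 0.0   # concrete assignment clears the poison
print(h.value)  # 0.0
```

Compared with leaving a dead tracer in place, the sentinel turns a confusing downstream JAX error into an immediate, descriptive one at the point of the read.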

Behavior Changes (stricter validation)

The following paths previously produced silently wrong results or accepted invalid input; they now raise descriptive errors:

  • Writing a tracer into a pre-existing State outside a brainstate trace (for example under raw jax.jit / vmap / grad / scan) now raises a TraceContextError instead of silently storing the tracer. States created inside the current JAX trace remain exempt, since they die with that trace (#207).
  • grad / vector_grad / jacobian / hessian reject negative and non-integer argnums up front instead of differentiating the wrong argument; hessian additionally rejects the grad_states + argnums combination (#207).
  • jit aligns user-supplied in_shardings / out_shardings with the internally prepended state argument and rejects negative static_argnums / donate_argnums; checkpoint / remat likewise reject negative static_argnums (#207).
  • Unhashable static arguments now raise an actionable TypeError (#207).
  • checkpointed_scan raises a clear ValueError for length < 1 instead of a math-domain error, and ProgressBar frequency validation raises ValueError rather than failing an assert (#207).
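Rejecting invalid argnums is a cheap check that runs before any differentiation happens. The helper below is a hypothetical sketch of that kind of validation; `normalize_argnums` is an invented name, and the exact error types and messages brainstate uses may differ.

```python
def normalize_argnums(argnums):
    """Return argnums as a tuple of non-negative ints, or raise.

    Accepts an int or a tuple/list of ints. Rejects bools, non-integers,
    and negative indices instead of silently differentiating the wrong argument.
    """
    items = tuple(argnums) if isinstance(argnums, (tuple, list)) else (argnums,)
    for a in items:
        # bool is a subclass of int in Python, so exclude it explicitly.
        if isinstance(a, bool) or not isinstance(a, int):
            raise TypeError(f"argnums must be integers, got {a!r}")
        if a < 0:
            raise ValueError(f"argnums must be non-negative, got {a}")
    return items


print(normalize_argnums(1))       # (1,)
print(normalize_argnums([0, 2]))  # (0, 2)
for bad in (-1, 1.5, (0, -2)):
    try:
        normalize_argnums(bad)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, err)
```

Validating up front means a bad index surfaces as a clear error at the call site rather than as a gradient taken with respect to an unintended argument.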
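Static arguments become part of a compilation cache key, so they must be hashable. A hypothetical sketch of turning Python's bare "unhashable type" error into an actionable message follows; the helper name `check_static_hashable` and its wording are illustrative, not brainstate's implementation.

```python
def check_static_hashable(args, static_argnums):
    """Raise a descriptive TypeError if any static argument is unhashable."""
    for i in static_argnums:
        try:
            hash(args[i])
        except TypeError as err:
            raise TypeError(
                f"Static argument at position {i} has unhashable type "
                f"{type(args[i]).__name__}; pass a hashable value (for example "
                "a tuple instead of a list) or make the argument non-static."
            ) from err


check_static_hashable((1.0, (2, 3)), static_argnums=(1,))  # passes silently
try:
    check_static_hashable((1.0, [2, 3]), static_argnums=(1,))
except TypeError as err:
    print(err)
```

Chaining with `from err` keeps the original Python error attached for debugging while leading with a message that tells the user what to change.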

Build

  • Minimum JAX raised to >=0.7.0 (previously >=0.6.0) across all pyproject.toml extras (cpu, cuda12, cuda13, tpu, testing) and requirements.txt (#208).

Full Changelog: v0.4.1...v0.4.2