0.3.0 Release
[0.3.0] — 2026-06-10
Added
- Jit-once NRPT round loop — the Gibbs + DEO swap scan now lives in a
module-level eqx.filter_jit function, so the compilation cache persists
across nrpt calls. Templates already at β = 1 are reused without
rebasing, and nrpt_adaptive rebases once before the phase loop, so all
tuning phases plus production trace and compile exactly once (β arrays
and states are traced data). Verified by a trace-count regression test;
measured 1.80 s → 0.54 s (~3.3×) per nrpt_adaptive call on an 8-chain
48×48 Ising benchmark (CPU).
- Temperature-linear NRPT mode — nrpt now accepts a single template
(ebm, program) pair (plus an explicit betas array) instead of per-chain
sequences. One base program is built at β = 1 and every interaction array
is scaled by the chain's β inside the vmapped Gibbs kernel — valid for any
model whose interactions are linear in β (the DiscreteEBMFactor family),
consistent with the E_β = β·E_base assumption the swap math already makes.
This avoids constructing one program per chain and storing per-chain copies
of every interaction tensor (a factor of n_chains less interaction memory).
nrpt_adaptive and discover_chain_count use this mode automatically on
their template (ebm=/program=) routes, eliminating all per-phase
EBM/program rebuilds during schedule tuning (~22% faster adaptive tuning
on CPU for an 8-chain 48×48 Ising benchmark). Results are bit-identical
to the per-chain-programs path. Explicit factory routes are unchanged.
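The temperature-linear idea above can be pictured with a small plain-Python sketch. It is not hamon's implementation: `base_J`, `local_field`, and `scaled_field` are hypothetical names chosen for illustration, and a three-site model stands in for a real program. It only shows why scaling one β = 1 interaction array inside the kernel gives the same numbers as storing one pre-scaled copy per chain.

```python
# Illustrative sketch only: plain Python, not hamon's API.
# base_J, local_field, and scaled_field are hypothetical names.

def local_field(J, spins, i):
    """Sum of couplings times neighbouring spins for site i."""
    return sum(J[i][j] * spins[j] for j in range(len(spins)))

base_J = [[0.0, 0.5, -0.25],
          [0.5, 0.0, 1.0],
          [-0.25, 1.0, 0.0]]    # interactions built once at beta = 1
betas = [0.1, 0.4, 1.0]         # one inverse temperature per chain
spins = [1, -1, 1]

# Per-chain path: store one scaled copy of the interactions per chain.
per_chain_J = [[[b * w for w in row] for row in base_J] for b in betas]
fields_copies = [local_field(J, spins, 0) for J in per_chain_J]

# Template path: keep only base_J and scale by beta inside the kernel.
def scaled_field(beta, spins, i):
    return sum((beta * base_J[i][j]) * spins[j] for j in range(len(spins)))

fields_template = [scaled_field(b, spins, 0) for b in betas]
```

Both paths perform the same floating-point multiplications in the same order, so the two lists compare equal exactly; the template path simply never materialises `per_chain_J`.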
Breaking
- Minimum supported Python is now 3.11 (was 3.10). The JAX (≥ 0.9) and
jaxtyping (≥ 0.3) releases hamon is developed against both require
Python ≥ 3.11, so the previous 3.10 floor could only resolve to stale
dependency versions. Python 3.14 is now supported and tested in CI.
- Padded interaction entries are pre-zeroed at program construction —
BlockSamplingProgram masks the sliced interaction tensors with the active
flags once, and the built-in spin/categorical conditionals no longer
multiply by the active mask on every Gibbs step. Custom samplers invoked
through a BlockSamplingProgram may now rely on inactive entries being
zero; samplers called directly with hand-built (unmasked) interactions must
keep applying the active flags themselves.
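For authors of custom samplers, the padded-entry change above roughly amounts to the following. This is a plain-Python sketch under assumed names (`mask_once`, `field_premasked`, `field_unmasked` are hypothetical and not part of hamon); it contrasts weights that were zeroed once at construction with hand-built weights that still need the mask in the hot loop.

```python
# Illustrative sketch only: plain Python, not hamon's API.
# All function and variable names here are hypothetical.

weights = [0.7, -1.2, 9.9, 0.3]       # 9.9 sits in a padded (inactive) slot
active = [True, True, False, True]
neighbours = [1, -1, 1, 1]

def mask_once(weights, active):
    """Zero the padded entries a single time, at construction."""
    return [w if a else 0.0 for w, a in zip(weights, active)]

def field_premasked(masked_weights, neighbours):
    """Hot loop for pre-zeroed weights: no mask multiply needed."""
    return sum(w * s for w, s in zip(masked_weights, neighbours))

def field_unmasked(weights, active, neighbours):
    """Hot loop for hand-built weights: the caller must still mask."""
    return sum(w * a * s for w, a, s in zip(weights, active, neighbours))

masked = mask_once(weights, active)
premasked_result = field_premasked(masked, neighbours)
unmasked_result = field_unmasked(weights, active, neighbours)

# Forgetting the mask on hand-built weights lets the padding leak in.
leaky_result = sum(w * s for w, s in zip(weights, neighbours))
```

The first two results agree, while `leaky_result` picks up the padded 9.9 entry, which is the failure mode a directly called sampler has to guard against.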
Changed
- jaxtyping floor raised to 0.3.10 (was 0.2.23) — picks up the fix for
PyTree[A | B] isinstance checks silently passing (0.3.9) and cloudpickle
round-trips of variadic annotations like Shaped[Array, "..."] (0.3.10)
- optax floor (testing extra) raised to 0.2.8 (was 0.2.4) — older optax
imports the jax_pmap_shmap_merge config option removed in JAX 0.10.0 and
fails at import
- CI test and example matrices now cover Python 3.11–3.14 (was 3.10–3.13)
- Gibbs scan carries one copy of the sampling state, not two — the
_run_blocks scan previously threaded both the per-block state list and
the concatenated global state through the carry, although samplers only
read the global state. The carry now holds just the sampler states and
the global state; per-block states are extracted once after the scan.
from_global_state gained the same contiguous-slice fast path as the
write-back side, so the extraction lowers to static slices. Results are
bit-identical; measured ~6% faster NRPT rounds on CPU (the duplicate
carry was multiplied by vmap across chains) and ~3% faster plain Gibbs.
- Block write-back uses contiguous slice updates instead of scatters —
free blocks always occupy contiguous ranges of the global state (a
BlockSpec layout invariant), so the per-block write-back in the Gibbs
scan and in scatter_block_to_global now lowers to
lax.dynamic_update_slice with a precomputed static offset instead of a
gather-index scatter, which XLA fuses far better (the isolated op is ~30×
faster on CPU; scatters are disproportionately expensive on GPU).
Non-contiguous node sets keep the scatter fallback. Results are
bit-identical.
- track_round_trips=False now skips the index-process update inside the
swap pass instead of only omitting the summary.
- Conditional samplers accumulate in the weights' dtype — the spin and
categorical Gibbs conditionals previously seeded their parameter
accumulators with float32 zeros regardless of the model dtype.
- Documentation refreshed for the new internals: architecture.md now
describes the concatenated (not padded) global state layout and the actual
index-process representation; stale hinton_init and CategoricalNode
docstrings corrected.
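The block write-back entry above distinguishes a slice update at a known offset from an index-by-index scatter. The sketch below mirrors that choice with plain Python lists; it is only an analogy for the array-level behaviour, and `write_block` and its arguments are hypothetical names rather than hamon functions.

```python
# Illustrative sketch only: plain Python lists, not hamon's API.
# write_block and its arguments are hypothetical names.

def write_block(global_state, node_ids, block_values):
    """Return a copy of global_state with one non-empty block written back."""
    out = list(global_state)
    start = node_ids[0]
    contiguous = node_ids == list(range(start, start + len(node_ids)))
    if contiguous:
        # Fast path: a single slice update at a precomputed offset.
        out[start:start + len(node_ids)] = block_values
    else:
        # Fallback: element-by-element scatter through the index list.
        for idx, value in zip(node_ids, block_values):
            out[idx] = value
    return out

state = [0, 0, 0, 0, 0, 0]
fast = write_block(state, [2, 3, 4], [7, 8, 9])   # contiguous block
slow = write_block(state, [1, 4, 5], [7, 8, 9])   # non-contiguous block
```

Because the contiguity check depends only on the layout, it can be decided once up front, which is what allows the offset to be static in the compiled version described above.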
Fixed
- sample_blocks no longer mutates the caller's state and sampler-state
lists.
- NRPT observers received post-swap states paired with pre-swap energies —
in the default (non-cached) energy mode, the base-energy vector was not
permuted after accepted swaps before being handed to the observer, so
base_energies[c] described the state that used to occupy chain c.
Energies are now permuted alongside the states in both energy modes.
- β₀ = 0 produced NaN base energies in NRPT, silently rejecting every swap —
nrpt recovered base energies by dividing the hottest chain's energy by β₀,
which is 0/0 when the ladder is anchored at the reference distribution
(beta_range=(0.0, ...), the discover_chain_count default and the range
ising_sample uses). Swap acceptance became NaN and every swap was rejected
with no error, degrading parallel tempering into independent chains and
inflating Λ estimates. Base energies are now computed from an exact β = 1
copy of the EBM via with_beta(), falling back to the coldest chain's β for
EBM classes without with_beta() (raising a clear error if that β is 0).
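The β₀ = 0 failure above is easy to reproduce in miniature. The following plain-Python sketch is not hamon's code: `ieee_div` and `swap_accepted` are hypothetical helpers, `ieee_div` exists only to mimic array-library division (plain Python raises on 0/0 instead of returning NaN), and the acceptance rule is the standard parallel-tempering test on base energies, assumed here for illustration.

```python
import math
import random

# Illustrative sketch only: plain Python, not hamon's API.
# ieee_div mimics array-library float division, where 0/0 yields NaN
# (plain Python would raise ZeroDivisionError instead).

def ieee_div(a, b):
    if b == 0.0:
        return math.nan if a == 0.0 else math.copysign(math.inf, a)
    return a / b

def swap_accepted(beta_i, beta_j, e_i, e_j, u):
    """Standard parallel-tempering test on base energies e_i, e_j."""
    log_ratio = (beta_i - beta_j) * (e_i - e_j)
    return math.log(u) < log_ratio   # any comparison with NaN is False

beta0, beta1 = 0.0, 0.5
true_base_hot, true_base_cold = -3.0, -5.0

# Old recovery: scaled energy / beta. At beta0 = 0 this is 0/0 = NaN.
recovered_hot = ieee_div(beta0 * true_base_hot, beta0)

rng = random.Random(0)
draws = [1.0 - rng.random() for _ in range(1000)]   # uniform in (0, 1]
old = [swap_accepted(beta0, beta1, recovered_hot, true_base_cold, u)
       for u in draws]
new = [swap_accepted(beta0, beta1, true_base_hot, true_base_cold, u)
       for u in draws]
```

With the NaN energy every entry of `old` is False and nothing signals an error, whereas using the true base energy gives a log acceptance ratio of -1 and a mix of accepted and rejected swaps in `new`.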