0.3.0 Release

@dek3rr dek3rr released this 10 Jun 13:50
· 8 commits to main since this release

[0.3.0] — 2026-06-10

Added

  • Jit-once NRPT round loop — the Gibbs + DEO swap scan now lives in a
    module-level eqx.filter_jit function, so the compilation cache persists
    across nrpt calls. Templates already at β = 1 are reused without
    rebasing, and nrpt_adaptive rebases once before the phase loop, so all
    tuning phases plus the production phase trace and compile exactly once (β arrays
    and states are traced data). Verified by a trace-count regression test;
    measured 1.80s → 0.54s (~3.3×) per nrpt_adaptive call on an 8-chain
    48×48 Ising benchmark (CPU).

  • Temperature-linear NRPT mode: nrpt now accepts a single template
    (ebm, program) pair (plus an explicit betas array) instead of per-chain
    sequences. One base program is built at β = 1 and every interaction array
    is scaled by the chain's β inside the vmapped Gibbs kernel — valid for any
    model whose interactions are linear in β (the DiscreteEBMFactor family),
    consistent with the E_β = β·E_base assumption the swap math already makes.
    This avoids constructing one program per chain and storing per-chain copies
    of every interaction tensor (n_chains× less interaction memory).
    nrpt_adaptive and discover_chain_count use this mode automatically on
    their template (ebm=/program=) routes, eliminating all per-phase
    EBM/program rebuilds during schedule tuning (~22% faster adaptive tuning
    on CPU for an 8-chain 48×48 Ising benchmark). Results are bit-identical
    to the per-chain-programs path. Explicit factory routes are unchanged.
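
The two items above share one idea: a single base model at β = 1, a β array passed as traced data, and a module-level jitted function whose cache survives repeated calls. The sketch below illustrates that pattern in plain JAX. It is not hamon's implementation; `base_energy`, `chain_energies`, and `TRACE_COUNT` are hypothetical names, and a dense pairwise spin energy stands in for the real interaction tensors.

```python
import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # incremented only while JAX traces the function


def base_energy(weights, state):
    # Pairwise energy at beta = 1 for spins in {-1, +1} (illustrative model).
    return -0.5 * state @ weights @ state


@jax.jit
def chain_energies(weights, betas, states):
    # Module-level jitted function: its compilation cache outlives any
    # single call, and betas/states are traced data, not static arguments.
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect, so it runs at trace time only

    def one_chain(beta, state):
        # Scale the shared base interactions by this chain's beta.
        return base_energy(beta * weights, state)

    return jax.vmap(one_chain)(betas, states)


key_w, key_s = jax.random.split(jax.random.PRNGKey(0))
n = 6
w = jax.random.normal(key_w, (n, n))
weights = 0.5 * (w + w.T)  # symmetric couplings, stored once
states = jnp.where(jax.random.bernoulli(key_s, 0.5, (4, n)), 1.0, -1.0)

betas_a = jnp.linspace(0.0, 1.0, 4)
betas_b = jnp.linspace(0.2, 1.0, 4)

scaled = chain_energies(weights, betas_a, states)
chain_energies(weights, betas_b, states)  # new ladder, same shapes: cache hit

# Reference: E_beta = beta * E_base, computed from the unscaled model.
reference = betas_a * jax.vmap(lambda s: base_energy(weights, s))(states)
```

Because only array values change between the two calls, `TRACE_COUNT` stays at 1, which is the kind of property a trace-count regression test can assert.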

Breaking

  • Minimum supported Python is now 3.11 (was 3.10). The JAX (≥ 0.9) and
    jaxtyping (≥ 0.3) releases hamon is developed against both require
    Python ≥ 3.11, so the previous 3.10 floor could only resolve to stale
    dependency versions. Python 3.14 is now supported and tested in CI.
  • Padded interaction entries are pre-zeroed at program construction:
    BlockSamplingProgram masks the sliced interaction tensors with the active
    flags once, and the built-in spin/categorical conditionals no longer
    multiply by the active mask on every Gibbs step. Custom samplers invoked
    through a BlockSamplingProgram may now rely on inactive entries being
    zero; samplers called directly with hand-built (unmasked) interactions must
    keep applying the active flags themselves.
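
The consequence for custom samplers can be seen in a small sketch. The array layout and function names below are assumptions made for illustration, not hamon's actual data structures: a padded weight table with an `active` mask, one conditional that masks on every step, and one that trusts pre-zeroed weights.

```python
import jax.numpy as jnp

# Hypothetical padded layout: each node has up to 3 neighbour slots.
weights = jnp.array([[0.5, -1.0, 2.0],
                     [1.5, 9.9, 9.9]])   # 9.9 stands for padding garbage
active = jnp.array([[True, True, True],
                    [True, False, False]])
neighbour_spins = jnp.array([[1.0, -1.0, 1.0],
                             [1.0, 1.0, 1.0]])


def field_masking_every_step(weights, active, spins):
    # Old pattern: apply the active flags inside every conditional update.
    return jnp.sum(weights * active * spins, axis=-1)


def field_prezeroed(masked_weights, spins):
    # New pattern: the sampler trusts that inactive entries are already zero.
    return jnp.sum(masked_weights * spins, axis=-1)


# Mask once, at "construction time".
masked_weights = jnp.where(active, weights, 0.0)

old = field_masking_every_step(weights, active, neighbour_spins)
new = field_prezeroed(masked_weights, neighbour_spins)
# Calling the new-style sampler on unmasked weights leaks the padding values.
leaky = field_prezeroed(weights, neighbour_spins)
```

Here `old` and `new` agree, while `leaky` differs on the padded row, which is why samplers fed hand-built interactions still need their own masking.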

Changed

  • jaxtyping floor raised to 0.3.10 (was 0.2.23) — picks up the fix for
    PyTree[A | B] isinstance checks silently passing (0.3.9) and cloudpickle
    round-trips of variadic annotations like Shaped[Array, "..."] (0.3.10).
  • optax floor (testing extra) raised to 0.2.8 (was 0.2.4) — older optax
    imports the jax_pmap_shmap_merge config option removed in JAX 0.10.0 and
    fails at import.
  • CI test and example matrices now cover Python 3.11–3.14 (was 3.10–3.13).
  • Gibbs scan carries one copy of the sampling state, not two — the
    _run_blocks scan previously threaded both the per-block state list and
    the concatenated global state through the carry, although samplers only
    read the global state. The carry now holds just the sampler states and
    the global state; per-block states are extracted once after the scan.
    from_global_state gained the same contiguous-slice fast path as the
    write-back side, so the extraction lowers to static slices. Results are
    bit-identical; measured ~6% faster NRPT rounds on CPU (the duplicate
    carry was multiplied by vmap across chains) and ~3% faster plain Gibbs.
  • Block write-back uses contiguous slice updates instead of scatters:
    free blocks always occupy contiguous ranges of the global state (a
    BlockSpec layout invariant), so the per-block write-back in the Gibbs
    scan and in scatter_block_to_global now lowers to
    lax.dynamic_update_slice with a precomputed static offset instead of a
    gather-index scatter, which XLA fuses far better (the isolated op is ~30×
    faster on CPU; scatters are disproportionately expensive on GPU).
    Non-contiguous node sets keep the scatter fallback. Results are
    bit-identical.
  • track_round_trips=False now skips the index-process update inside the
    swap pass instead of only omitting the summary.
  • Conditional samplers accumulate in the weights' dtype — the spin and
    categorical Gibbs conditionals previously seeded their parameter
    accumulators with float32 zeros regardless of the model dtype.
  • Documentation refreshed for the new internals: architecture.md now
    describes the concatenated (not padded) global state layout and the actual
    index-process representation; stale hinton_init and CategoricalNode
    docstrings corrected.
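
The write-back change can be illustrated outside hamon with two equivalent updates on a flat state vector. This is a minimal sketch under assumed shapes: `offset`, `write_scatter`, and `write_slice` are hypothetical names, and the block is assumed to occupy one contiguous range.

```python
import jax
import jax.numpy as jnp
from jax import lax

global_state = jnp.zeros(10, dtype=jnp.int8)
block_values = jnp.array([1, -1, 1], dtype=jnp.int8)
offset = 4  # static start of the block's contiguous range


@jax.jit
def write_scatter(state, values):
    # General fallback: scatter through an explicit index array.
    idx = jnp.arange(offset, offset + values.shape[0])
    return state.at[idx].set(values)


@jax.jit
def write_slice(state, values):
    # Contiguous fast path: one slice update at a known offset.
    return lax.dynamic_update_slice(state, values, (offset,))


via_scatter = write_scatter(global_state, block_values)
via_slice = write_slice(global_state, block_values)
```

Both functions produce the same array; the difference is only in how the update lowers, so any speedup should be measured on the target backend rather than assumed from this sketch.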

Fixed

  • sample_blocks no longer mutates the caller's state and sampler-state
    lists.
  • NRPT observers received post-swap states paired with pre-swap energies:
    in the default (non-cached) energy mode, the base-energy vector was not
    permuted after accepted swaps before being handed to the observer, so
    base_energies[c] described the state that used to occupy chain c.
    Energies are now permuted alongside the states in both energy modes.
  • β₀ = 0 produced NaN base energies in NRPT, silently rejecting every swap:
    nrpt recovered base energies by dividing the hottest chain's energy by β₀,
    which is 0/0 when the ladder is anchored at the reference distribution
    (beta_range=(0.0, ...), the discover_chain_count default and the range
    ising_sample uses). Swap acceptance became NaN and every swap was rejected
    with no error, degrading parallel tempering into independent chains and
    inflating Λ estimates. Base energies are now computed from an exact β = 1
    copy of the EBM via with_beta(), falling back to the coldest chain's β for
    EBM classes without with_beta() (raising a clear error if that β is 0).
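
The β₀ = 0 failure mode is easy to reproduce in isolation. The sketch below is not hamon's code; it assumes the standard parallel-tempering acceptance ratio under E_β = β·E_base, and the names `swap_accept_prob`, `base_true`, and `base_recovered` are hypothetical.

```python
import jax.numpy as jnp

betas = jnp.array([0.0, 0.5, 1.0])          # ladder anchored at beta_0 = 0
base_true = jnp.array([-3.0, -5.0, -8.0])   # E_base of each chain's state
scaled = betas * base_true                  # energies the chains evaluate

# Recovering base energies by dividing by beta: chain 0 gives 0/0 = NaN,
# and JAX returns the NaN silently instead of raising.
base_recovered = scaled / betas


def swap_accept_prob(base, i, j):
    # min(1, exp((beta_i - beta_j) * (E_i - E_j))) with base energies E.
    log_ratio = (betas[i] - betas[j]) * (base[i] - base[j])
    return jnp.minimum(1.0, jnp.exp(log_ratio))


u = 0.3  # a uniform draw, fixed here for reproducibility
p_bad = swap_accept_prob(base_recovered, 0, 1)
p_good = swap_accept_prob(base_true, 0, 1)

accepted_bad = bool(u < p_bad)    # any comparison with NaN is False
accepted_good = bool(u < p_good)
```

With the NaN base energy the comparison is always false, so the swap is rejected without any error, whereas the exact base energies accept it for this draw.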