
Add parallel associative-scan algorithms for the quasiseparable solver#269

Merged
dfm merged 1 commit into main from parallel-associative-scan on May 9, 2026

Conversation

@dfm (Owner) commented on May 9, 2026

Note

This builds on and supersedes #210, which included only the parallel matmul implementation.

This PR adds jax.lax.associative_scan-based implementations of the core quasiseparable operations alongside the existing jax.lax.scan versions, and threads a parallel: bool flag from the GP layer down to select between them.

Usage

gp = GaussianProcess(kernel, x, diag=diag, parallel=True)

or directly on the QSM types:

qsm.cholesky(parallel=True)
qsm.matmul(y, parallel=True)
factor.solve(y, parallel=True)

The default remains parallel=False (sequential lax.scan), so existing behavior is unchanged.

Why

The sequential scans have O(N) depth, which serializes badly on GPUs. The associative-scan formulations have O(log N) depth at the cost of a constant factor more FLOPs, which is a much better fit for accelerators.

Rough wall-clock numbers on a single GPU at J=8:

| op       | N    | sequential | parallel | speedup |
|----------|------|------------|----------|---------|
| matmul   | 65k  | 1.3 s      | 0.9 ms   | ~1500×  |
| matmul   | 262k | 5.6 s      | 2.1 ms   | ~2600×  |
| cholesky | 65k  | 2.4 s      | 8.4 ms   | ~300×   |
| cholesky | 262k | 10.4 s     | 18 ms    | ~600×   |

On CPU the picture is mixed: the parallel matmul is roughly competitive with the sequential version at large N, but the parallel Cholesky is several times slower on CPU because each combine step does a small linalg.solve. So parallel=True is recommended for GPU/TPU and parallel=False (the default) for CPU.
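
A rough sketch of how timings like these could be measured at the GP level (the kernel, problem size, and use of log_probability as the timed operation are illustrative assumptions, not the benchmark script behind the table above):

```python
import time

import jax
import jax.numpy as jnp
from tinygp import GaussianProcess, kernels

# Illustrative setup; the kernel and N are assumptions, not the benchmark config.
N = 262_144
x = jnp.linspace(0.0, 100.0, N)
y = jnp.sin(x)
kernel = kernels.quasisep.Matern32(scale=1.5)

def time_log_prob(parallel: bool) -> float:
    gp = GaussianProcess(kernel, x, diag=1e-4, parallel=parallel)
    loglike = jax.jit(gp.log_probability)
    loglike(y).block_until_ready()  # compile outside the timed region
    start = time.perf_counter()
    loglike(y).block_until_ready()
    return time.perf_counter() - start

print("sequential:", time_log_prob(False), "s")
print("parallel:  ", time_log_prob(True), "s")
```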

Math

The matmul and triangular-solve recurrences are affine in the carry, $f_n = a_n f_{n-1} + b_n$, so they parallelize directly as a prefix scan over the monoid $(A, B) \bullet (A', B') = (A' A,\ A' B + B')$. The Cholesky carry is a Riccati-type (quadratic-over-linear) update that does not fit the affine monoid. Reading each step as a Kalman predict-then-update identifies it with the associative filtering element of Särkkä & García-Fernández: each step is represented by a triple $(A, F, G)$ and composition follows from multiplying the corresponding $2J \times 2J$ Hamiltonian matrices, giving a combine that needs one $J \times J$ linear solve. The forward pass of SymmQSM.inv carries the same Riccati state, and its backward pass reduces to a linear conjugation recurrence $z_k = \ell_k^\top z_{k+1} \ell_k + B_k$ that uses the affine-style operator again.
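
For concreteness, here is a minimal, self-contained sketch of the affine prefix scan (generic JAX, not the actual ops.py code): each element $(a_n, b_n)$ represents the map $f \mapsto a_n f + b_n$, the combine composes two such maps, and jax.lax.associative_scan evaluates the cumulative composition at $O(\log N)$ depth.

```python
import jax
import jax.numpy as jnp

def affine_combine(left, right):
    # (A, B) • (A', B') = (A' A, A' B + B'): apply the earlier (left) map,
    # then the later (right) one. Shapes: (..., J, J) and (..., J).
    A1, B1 = left
    A2, B2 = right
    return A2 @ A1, jnp.einsum("...ij,...j->...i", A2, B1) + B2

def affine_parallel(a, b):
    # f_n = a_n f_{n-1} + b_n with f_{-1} = 0, at O(log N) depth.
    _, f = jax.lax.associative_scan(affine_combine, (a, b))
    return f

def affine_sequential(a, b):
    # O(N)-depth reference using lax.scan.
    def step(f, inputs):
        a_n, b_n = inputs
        f = a_n @ f + b_n
        return f, f
    return jax.lax.scan(step, jnp.zeros(b.shape[-1]), (a, b))[1]

a = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (64, 3, 3))
b = jax.random.normal(jax.random.PRNGKey(1), (64, 3))
assert jnp.allclose(affine_parallel(a, b), affine_sequential(a, b), atol=1e-5)
```

The Cholesky path plays the same game with the $(A, F, G)$ triple in place of the affine pair; the reason multiplying $2J \times 2J$ block matrices composes the steps is the standard matrix Möbius identity: if $T_M(P) = (M_{21} + M_{22} P)(M_{11} + M_{12} P)^{-1}$, then $T_{M_2} \circ T_{M_1} = T_{M_2 M_1}$, so composing steps reduces to block matrix products, with a $J \times J$ solve to recover the quotient.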

What's included

  • solvers/quasisep/ops.py: parallel implementations of lower_matmul, upper_matmul, lower_solve, upper_solve, cholesky, and symm_inv, plus their sequential counterparts factored out of core.py. The Cholesky and symmetric-inverse forward passes share a common _riccati_scan helper.
  • solvers/quasisep/core.py: all matmul/solve/cholesky/inv methods on the QSM classes take parallel: bool = False and dispatch to ops.py (a toy sketch of this dispatch pattern appears below).
  • solvers/quasisep/solver.py: QuasisepSolver accepts parallel and uses it for the factorization, triangular solves, and matrix products.
  • solvers/quasisep/block.py: Block gains .mT, batched to_dense(), and batched __matmul__/__rmatmul__. This also fixes a pre-existing bug where LowerTriQSM.inv() (used in condition) failed on summed kernels.

SquareQSM.inv remains sequential — its asymmetric Riccati recurrence needs a larger associative operator and it isn't on any hot GP path.
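
For readers new to the pattern, here is a toy, self-contained illustration of the parallel-flag dispatch; a cumulative sum stands in for the real quasiseparable operation, and none of these names come from core.py or ops.py:

```python
import jax
import jax.numpy as jnp

def _running_total_sequential(x):
    # O(N)-depth reference: lax.scan carries the running total.
    return jax.lax.scan(lambda c, v: (c + v, c + v), jnp.zeros(()), x)[1]

def _running_total_parallel(x):
    # O(log N)-depth version: lax.associative_scan with the same result.
    return jax.lax.associative_scan(jnp.add, x)

def running_total(x, *, parallel: bool = False):
    # The same kind of boolean that core.py threads down to pick the ops.py routine.
    return _running_total_parallel(x) if parallel else _running_total_sequential(x)

x = jnp.arange(8.0)
assert jnp.allclose(running_total(x), running_total(x, parallel=True))
```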

Tests

  • New test_ops.py checks each parallel op against its sequential reference on Matern32/Matern52 kernels (a simplified sketch of this kind of check follows the list).
  • test_core.py parameterizes the matmul/solve/cholesky tests over parallel={False, True}.
  • test_solver.py::test_consistent_with_direct is parameterized over parallel and gains a summed-kernel case to exercise the Block path end-to-end.
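
A simplified sketch of the parallel-vs-sequential consistency check (the kernel, grid, and the to_symm_qsm/matmul calls are assumptions based on the usage above; the real tests in test_ops.py and test_core.py are more thorough):

```python
import numpy as np
import jax.numpy as jnp
from tinygp.kernels import quasisep

def test_parallel_matmul_matches_sequential():
    # Kernel, grid, and right-hand side are illustrative assumptions.
    x = jnp.linspace(0.0, 10.0, 128)
    y = jnp.sin(x)[:, None]
    qsm = quasisep.Matern32(scale=1.5).to_symm_qsm(x)
    np.testing.assert_allclose(
        qsm.matmul(y, parallel=True),
        qsm.matmul(y, parallel=False),
        rtol=1e-5,
        atol=1e-5,
    )
```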

@dfm force-pushed the parallel-associative-scan branch from cff229a to 6003864 on May 9, 2026 18:31

The existing quasiseparable operations use jax.lax.scan with O(N) sequential
depth, which serializes badly on accelerators. This adds
jax.lax.associative_scan-based implementations with O(log N) depth, selectable
via a parallel=True flag threaded from GaussianProcess through QuasisepSolver
down to the QSM matmul/solve/cholesky/inv methods.

- ops.py: parallel and sequential implementations of lower/upper matmul,
  lower/upper solve, cholesky, and symm_inv. The cholesky and symm_inv forward
  passes share a common Riccati associative scan.
- core.py: QSM methods take parallel: bool = False and dispatch to ops.
- solver.py: QuasisepSolver(..., parallel=True) uses the parallel path for
  factorization, triangular solves, and matrix products.
- block.py: Block gains .mT, batched to_dense(), and batched matmul. This also
  fixes a pre-existing crash in LowerTriQSM.inv() on summed kernels.

SquareQSM.inv remains sequential.

@dfm force-pushed the parallel-associative-scan branch from 6003864 to f4e651d on May 9, 2026 18:31
@dfm merged commit f3466f0 into main on May 9, 2026
7 of 8 checks passed