
Use batched library routines when available. #28544

@jpbrodrick89

Description


I recently did some profiling comparing the tridiagonal solvers in JAX (jax.lax.linalg.tridiagonal_solve) and lineax. I observed that the JAX call, which dispatches to the cusparse gtsv2 routine on GPU, was much faster than lineax's partially unrolled implementation of the Thomas algorithm when unbatched or at small batch sizes. However, as the batch size grows there is a crossover point beyond which lineax becomes faster, because jax.lax.linalg scales roughly linearly with batch size while lineax scales sub-linearly (close to constant time). If I manually change the lineax unroll to a full unroll, lineax is faster at every batch size.
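For concreteness, a rough sketch of the kind of comparison I ran is below (the sizes, the diagonally dominant test matrices, and the use of vmap for batching are placeholders for illustration, not the exact benchmark):

```python
import jax
import jax.numpy as jnp
import lineax as lx

m, batch = 1024, 256
keys = jax.random.split(jax.random.PRNGKey(0), 4)
d = 4.0 + jax.random.uniform(keys[0], (batch, m))                # main diagonal (diagonally dominant)
dl = jax.random.uniform(keys[1], (batch, m)).at[:, 0].set(0.0)   # sub-diagonal; dl[0] must be 0
du = jax.random.uniform(keys[2], (batch, m)).at[:, -1].set(0.0)  # super-diagonal; du[-1] must be 0
b = jax.random.uniform(keys[3], (batch, m))                      # right-hand sides

# jax.lax path: each solve lowers to cusparse gtsv2 on GPU; batched via vmap.
@jax.jit
def solve_jax(dl, d, du, b):
    return jax.vmap(jax.lax.linalg.tridiagonal_solve)(dl, d, du, b[..., None])[..., 0]

# lineax path: Thomas algorithm, batched via vmap.
@jax.jit
def solve_lineax(dl, d, du, b):
    def one(dl, d, du, b):
        op = lx.TridiagonalLinearOperator(d, dl[1:], du[:-1])
        return lx.linear_solve(op, b, solver=lx.Tridiagonal()).value
    return jax.vmap(one)(dl, d, du, b)

x1 = jax.block_until_ready(solve_jax(dl, d, du, b))
x2 = jax.block_until_ready(solve_lineax(dl, d, du, b))
```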

I was looking at making a PR in lineax that takes the best of both worlds, but then I started to wonder why the JAX implementation does not scale as well. It turns out that it does not use the batched cusparse routines (i.e. cusparse<t>gtsv2StridedBatch() and cusparse<t>gtsvInterleavedBatch()); instead it simply performs a for loop over the batch dimensions.
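One way to look at this from Python, without digging into the C++ layer, is to dump the lowered and compiled IR for a batched solve and see how the batch is handled (the exact custom-call names in the IR are backend dependent; this is only a sketch):

```python
import jax
import jax.numpy as jnp

m, batch = 64, 8
dl = jnp.zeros((batch, m))
d = jnp.full((batch, m), 2.0)
du = jnp.zeros((batch, m))
b = jnp.ones((batch, m, 1))

lowered = jax.jit(jax.vmap(jax.lax.linalg.tridiagonal_solve)).lower(dl, d, du, b)
print(lowered.as_text())            # StableHLO; the per-batch loop shows up here
print(lowered.compile().as_text())  # post-optimization HLO for the chosen backend
```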

My immediate objective is to get this addressed sooner rather than later, so that I can cleanly finish my lineax PR without having to revisit it later, while improving (or at least maintaining) current performance. Note that my current use case does not involve batched solves (I am doing an implicit solve of the 1D heat equation); in the future I may want batched solves for ADI, multi-group diffusion, or batched simulations.
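For reference, my unbatched use case looks roughly like the following backward Euler step for the 1D heat equation (grid size, diffusivity, and time step here are placeholder values):

```python
import jax
import jax.numpy as jnp

m, alpha, dt, dx = 256, 1.0, 1e-3, 1.0 / 256
r = alpha * dt / dx**2

d = jnp.full((m,), 1.0 + 2.0 * r)       # main diagonal of (I - dt*alpha*Lap)
dl = jnp.full((m,), -r).at[0].set(0.0)  # sub-diagonal; dl[0] must be 0
du = jnp.full((m,), -r).at[-1].set(0.0) # super-diagonal; du[-1] must be 0

@jax.jit
def step(u):
    # Solve (I - dt*alpha*Lap) u_next = u for one implicit time step.
    return jax.lax.linalg.tridiagonal_solve(dl, d, du, u[:, None])[:, 0]

u0 = jnp.exp(-((jnp.linspace(0.0, 1.0, m) - 0.5) ** 2) / 0.01)  # initial condition
u1 = step(u0)
```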

I would be more than happy to help, but I begin to lose track at the Python/C++ barrier when trawling through the source code.

I have worded the issue title in a very general way; if we want it to refer only to tridiagonal solves, I am happy to rename it.
