What's the issue?
Updating to jax==0.5.* causes a number of test failures. Most of them appear to be precision issues: arrays agree only to about four decimal places, which is not enough to pass the stricter tests. In addition, some of the kernel gradients appear to have their signs flipped.
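To make the two failure modes concrete, here is a minimal sketch using NumPy's testing helpers. The arrays, the `3e-5` offset, and the helper names `agrees_to` and `is_sign_flipped` are hypothetical and do not come from the project's test suite. They only illustrate what "equal to about four decimal places" and "flipped gradients" look like when checked programmatically.

```python
import numpy as np

# Hypothetical reference values and a result that drifts in the fifth
# decimal place, mimicking the discrepancies seen under jax 0.5.*.
expected = np.array([0.123456, 1.234567, 2.345678])
actual = expected + 3e-5


def agrees_to(decimal):
    """Return True if `actual` matches `expected` to `decimal` places.

    assert_array_almost_equal checks abs(desired - actual) < 1.5 * 10**(-decimal).
    """
    try:
        np.testing.assert_array_almost_equal(actual, expected, decimal=decimal)
        return True
    except AssertionError:
        return False


def is_sign_flipped(got, want, rtol=1e-4):
    """Hypothetical check: `got` matches `-want` but not `want` itself."""
    return np.allclose(got, -want, rtol=rtol) and not np.allclose(got, want, rtol=rtol)


print(agrees_to(4))  # True: the 3e-5 difference is below 1.5e-4
print(agrees_to(6))  # False: the 3e-5 difference exceeds 1.5e-6
print(is_sign_flipped(-expected, expected))  # True
```

A relative-tolerance check such as `np.testing.assert_allclose` would be the usual alternative when values span several orders of magnitude.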
Once this is fixed, remove the `!=0.5.*` specifier from the jax dependency in `pyproject.toml`.