WIP: Add bicgstab implementation #3796
Conversation
Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). 📝 Please visit https://cla.developers.google.com/ to sign. Once you've signed (or fixed any issues), please reply here. What to do if you already signed the CLA: Individual signers / Corporate signers. ℹ️ Googlers: Go here for more info.

@googlebot I signed it!

CLAs look good, thanks! ℹ️ Googlers: Go here for more info.
jax/scipy/sparse/linalg.py
Outdated
----------
A : function
    Function that calculates the matrix-vector product ``Ax`` when called
    like ``A(x)``. ``A`` must represent a hermitian, positive definite
I thought bicgstab works on arbitrary matrices, without requiring symmetry? It would be good to clarify this :)
That's a copying error, you are correct!
jax/scipy/sparse/linalg.py
Outdated
isolve = partial(_isolve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)

real_valued = lambda x: not issubclass(x.dtype.type, np.complexfloating)
symmetric = all(map(real_valued, tree_leaves(b)))
I think this condition is only true for `cg`, not `bicgstab`? I would expect arguments to `bicgstab` to essentially never be symmetric. As I understand it, if they were, there are better solvers to use (e.g., `cg`).
I think if you have a matrix that is symmetric but NOT positive definite (e.g. positive semidefinite), this would still be useful, no? Don't know how often that would be the case in practice, but it's certainly a possibility.
OK, but either way it's definitely a different condition for `bicgstab` vs `cg`. All real values for the solution do not imply that the matrix is symmetric.
So it seems like `isolve` should grow a `symmetric` argument, which could be set in a solver-specific way. Potentially we could even expose a new `symmetric` argument on the public `bicgstab` function.
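As a rough illustration of that suggestion (plain NumPy stand-ins, not the PR's actual code; `_isolve` and `_cg_body` here are assumed shapes of the API), the shared driver could take `symmetric` as an explicit keyword that each public wrapper pins:

```python
from functools import partial
import numpy as np

def _isolve(solve_fn, A, b, x0=None, *, maxiter=50, symmetric=False):
    # Shared driver: handle defaults once, then dispatch to the solver body.
    # `symmetric` is set per solver instead of being inferred from dtypes.
    x0 = np.zeros_like(b) if x0 is None else x0
    return solve_fn(A, b, x0, maxiter), symmetric

def _cg_body(A, b, x0, maxiter):
    # Textbook conjugate-gradient iteration, standing in for the real body.
    x = x0
    r = b - A @ x
    p = r.copy()
    for _ in range(maxiter):
        rr = r @ r
        if rr < 1e-12:  # converged; avoids 0/0 in alpha below
            break
        Ap = A @ p
        alpha = rr / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        p = r + ((r @ r) / rr) * p
    return x

# cg knows its operator is symmetric; a bicgstab wrapper would pass
# symmetric=False instead of inferring it from real-valued inputs.
cg = partial(_isolve, _cg_body, symmetric=True)

A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x, symmetric = cg(A, b)
print(np.allclose(A @ x, b))  # True
```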
Alright, and I also wonder if there is some utility in treating `isolve` as a public function or not? I can give it an underscore in the name if not!
Yes, all the helper functions should definitely be marked as private.
jax/scipy/sparse/linalg.py
Outdated
# TODO(?): stop early
# It requires accessing cond_fun like this, but it's
# not possible with jit...
# if cond_fun((x, s, r0, alpha_, omega, rho_, p_, q_, k)):
#   x_ = _add(x, _mul(alpha_, phat))
#   return x_, s, rhat, alpha_, omega, rho_, p_, q, k
We could probably do this with some cleverness (e.g., involving `lax.cond`), but I think it's actually totally fine to skip it -- at least for now. In the best case, this would save one extra function evaluation.
On GPUs, dynamic control flow is so expensive that I expect we could actually get better performance in many cases by only checking for stopping every `n` steps.
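A sketch of that chunked-checking idea (hypothetical helper, not part of the PR): run `n` unconditional solver steps per `lax.while_loop` iteration, so the convergence test only runs once per chunk:

```python
import jax.numpy as jnp
from jax import lax

def solve_chunked(step, not_done, state, n=10, max_chunks=100):
    # Run `n` unconditional steps per chunk; test for stopping once per chunk.
    def chunk(state):
        return lax.fori_loop(0, n, lambda i, s: step(s), state)
    def loop_cond(carry):
        k, state = carry
        return jnp.logical_and(k < max_chunks, not_done(state))
    def loop_body(carry):
        k, state = carry
        return k + 1, chunk(state)
    _, state = lax.while_loop(loop_cond, loop_body, (0, state))
    return state

# Toy stand-in for a solver step: x <- 0.5 * x, "converged" when |x| <= 1e-6.
x = solve_chunked(lambda x: 0.5 * x,
                  lambda x: jnp.abs(x) > 1e-6,
                  jnp.float32(1.0))
```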
tests/lax_scipy_sparse_test.py
Outdated
rng = jtu.rand_default(self.rng())
A = poisson(shape, dtype)  # use scipy's test to check bicgstab instead of random matrix
# A = rand_sym_pos_def(rng, shape, dtype)
I would be in favor of testing the Poisson equation as well as random matrices, but I would definitely test random matrices (with no structure) as well. This seems much more likely to turn up issues.
Yeah, I agree, I thought of this as well. I will implement that suggestion!
tests/lax_scipy_sparse_test.py
Outdated
def test_bicgstab_pytree(self):
  A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
  b = {"a": 1.0, "b": -4.0}
  expected = {"a": 4.0, "b": -6.0}
  actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
  self.assertEqual(expected.keys(), actual.keys())
  self.assertAlmostEqual(expected["a"], actual["a"], places=6)
  self.assertAlmostEqual(expected["b"], actual["b"], places=6)

def test_bicgstab_errors(self):
  A = lambda x: x
  b = jnp.zeros((2,))
  with self.assertRaisesRegex(
      ValueError, "x0 and b must have matching tree structure"):
    jax.scipy.sparse.linalg.bicgstab(A, {'x': b}, {'y': b})
  with self.assertRaisesRegex(
      ValueError, "x0 and b must have matching shape"):
    jax.scipy.sparse.linalg.bicgstab(A, b, b[:, np.newaxis])

def test_bicgstab_without_pytree_equality(self):

  @register_pytree_node_class
  class MinimalPytree:
    def __init__(self, value):
      self.value = value
    def tree_flatten(self):
      return [self.value], None
    @classmethod
    def tree_unflatten(cls, aux_data, children):
      return cls(*children)

  A = lambda x: MinimalPytree(2 * x.value)
  b = MinimalPytree(jnp.arange(5.0))
  expected = b.value / 2
  actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
  self.assertAllClose(expected, actual.value)
You have both `cg` and `bicgstab` using most of the same code, so I don't think it provides much value to add these repeated tests. I would either drop them or make the original tests parametric so they test both `cg` and `bicgstab` with the same method.
I agree, will do!
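For example (a sketch using stdlib `unittest.subTest` with NumPy stand-ins for the solvers; the JAX test suite would more likely use its own parameterized test decorators):

```python
import unittest
import numpy as np

def cg_solve(A, b):
    return np.linalg.solve(A, b)        # stand-in for jax cg
def bicgstab_solve(A, b):
    return np.linalg.solve(A, b)        # stand-in for jax bicgstab

SOLVERS = {"cg": cg_solve, "bicgstab": bicgstab_solve}

class SolverTest(unittest.TestCase):
    def test_solves_linear_system(self):
        A = np.array([[4.0, 1.0], [1.0, 3.0]])
        b = np.array([1.0, 2.0])
        # One test body shared by every solver, instead of copy-pasted
        # cg/bicgstab variants.
        for name, solve in SOLVERS.items():
            with self.subTest(solver=name):
                x = solve(A, b)
                np.testing.assert_allclose(A @ x, b, atol=1e-6)
```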
tests/lax_scipy_sparse_test.py
Outdated
a = poisson(shape, dtype)  # TODO(sunilkpai): random doesn't work (bicg numerics), so using poisson instead...
# a = rng(shape, dtype)
This is at least a little surprising. I wouldn't expect random 2x2 matrices to be a problem -- that suggests that maybe there's a bug in the algorithm?
Yeah, I am surprised as well. I think having breakdown checks is perhaps important. If `scipy` returns `-10` or `-11`, that means `rho = 0` or `omega = 0` respectively, and maybe that is resulting in some issues. I will look into this; I do recall small matrices actually trip up even `scipy`'s `bicgstab`.
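A minimal sketch of what those breakdown checks could look like (the `-10`/`-11` codes follow the SciPy convention mentioned above; the helper name is hypothetical):

```python
import numpy as np

def check_breakdown(rho, omega, eps=np.finfo(np.float64).eps):
    # SciPy-style info codes: a (near-)zero rho or omega means the
    # bicgstab recurrences cannot continue.
    if abs(rho) < eps:
        return -10
    if abs(omega) < eps:
        return -11
    return 0

print(check_breakdown(0.0, 1.0), check_breakdown(1.0, 0.0))  # -10 -11
```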
jax/scipy/sparse/linalg.py
Outdated
shat = M(s)
t = A(shat)
# omega_ = _vdot_tree(_conj(s), t) / _vdot_tree_real(t, t)
omega_ = _vdot_tree(s, t) / _vdot_tree_real(t, t)
I am concerned that this might be effectively dividing by zero if the algorithm converges in the first half of the step. I wonder if this could be related to your testing challenges?
If I'm understanding correctly, this might not be happening for `cg` only because the denominator is also the termination condition on the loop.
One way to guard against this would be to replace the denominator with something like `jnp.maximum(_vdot_tree_real(t, t), epsilon)`.
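Concretely, the guard could look like this (hypothetical helper; the choice of `epsilon` is an assumption and would depend on dtype):

```python
import jax.numpy as jnp

def safe_div(num, denom, epsilon=1e-30):
    # Clamp the denominator away from zero: if the half-step has already
    # converged (t ~ 0), omega comes out finite instead of nan/inf.
    return num / jnp.maximum(denom, epsilon)

omega = safe_div(jnp.float32(0.0), jnp.float32(0.0))
print(omega)  # 0.0 rather than nan
```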
I am happy to add that in as I have seen it used in other contexts, but I haven't really seen this done in scipy's fortran revcom or any other pseudocodes. That said, I have noticed that my implementation and scipy tend to have similar breakdown scenarios (usually we add breakdown checks for whether rho or omega are zero, and that is returned as part of `info`). Do you recommend we add this in anyway?
My general rule of thumb is to try not to innovate in the details of numerical methods. I'm not an expert on most of them, and it can be easy to get the details wrong.
I would definitely suggest copying the safety/breakdown checks you notice in other implementations like SciPy rather than my ideas! Let me know if you have questions about how to implement them in JAX.
The safety/breakdown checks might require some deeper changes in jax (similar to why `info` is currently `None`). Might be a future thing to add after getting the basic `bicgstab` to work?
You should definitely be able to add early breakdown without pervasive changes to JAX.
There are at least two ways to do it:
1. You could nest a call to `jax.lax.cond`, either returning half-way through or doing the single GMRES step.
2. You could keep track of whether early exiting should happen in a boolean variable, and use `lax.select`/`jnp.where` (basically the same) to selectively update variables, e.g., compare how loopy logic for sampling from the Poisson distribution is implemented in JAX vs NumPy.
They have slightly different performance tradeoffs, but either solution should be fine here. (1) might be a little easier.
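A minimal sketch of option (2) with a toy update (not the PR's solver state): carry a boolean `done` flag and use `jnp.where` to freeze the state once it flips, so the jit-compatible loop can keep running harmlessly:

```python
import jax.numpy as jnp
from jax import lax

def body(carry):
    x, done = carry
    x_new = 0.5 * x                   # stand-in for one solver step
    done_new = jnp.abs(x_new) < 1e-3  # stand-in breakdown/convergence test
    # After `done` first becomes True, the state stops changing.
    x = jnp.where(done, x, x_new)
    return x, jnp.logical_or(done, done_new)

x, done = lax.fori_loop(0, 100, lambda i, c: body(c),
                        (jnp.float32(1.0), jnp.array(False)))
```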
Thanks, this is helpful!
It's fine to relax error tolerances in gradient checks as long as we are still quite confident that the result is correct (up to numerical precision). For example, I just tried switching the cg gradient test to a random 4x4 matrix. This turns up the following error:
But really, this is close enough. So if we used these tests, I would be fine relaxing the error tolerance as needed.
Hi @shoyer, I'm still fighting my code to get scipy and bicgstab to agree; they agree for many cases (especially at higher numerical precision) but begin to deviate in some special cases. I've done a number of things to check what the differences are, and I'm beginning to suspect that it has something to do with numerical precision becoming increasingly important after several steps (generally not after the first step). I checked scipy's implementation (and even hacked into their Fortran revcom). The test I have started to run now is to write a numpy version of the algorithm. Here is some example failure case output tracking the intermediate variables:
Incidentally, after using more random cases than is currently recommended, I noticed
@sunilkpai This looks close enough to me? How does the accuracy of the final solution compare? SciPy is probably always using float64 precision. It might be worth enabling x64 mode for JAX, just for the sake of comparison. This currently requires a special flag: If the error does not decrease when you turn on x64 mode, then perhaps we are missing something important and should be worried! In that case I would look into the early termination conditions.
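For reference, x64 mode can be enabled like this (standard JAX configuration; it must be set before running any computations whose precision matters):

```python
# Enable 64-bit (double precision) mode in JAX, which otherwise defaults
# to float32. Equivalently, set the environment variable JAX_ENABLE_X64=1.
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
x = jnp.ones(3)
print(x.dtype)  # float64 once x64 mode is enabled
```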
@shoyer Yeah, you might be right for this case, but I think this becomes magnified as
Also, re: early termination, this case may actually be necessary when s = 0 for
Force-pushed from 21266dc to 807081a.
Hi @shoyer, this is still WIP but I am getting closer, I think, thanks to your help! I have been adding your suggestions for the testing framework. Some checks between numpy and scipy are necessary (since there are differences and numerical failure cases) and therefore are now part of the tests. The
I am working on the `bicgstab` implementation (using the current pytree convention in `cg`); I am most of the way there, as the implementation was very similar to `cg` and the tests generally seem to work. This addresses one of the goals of #1531.
Before finalizing this PR, there are some points of confusion on the tests that have to do with numerical issues in iterative solvers. I would like some advice on this before proceeding.
1. I noticed the biggest shape in the gradient tests for `*_solve` in `jax` is `(2, 2)`, and when I checked larger shapes, there was an error (depending on the type, either array not equal or grad check failure)! What is the general consensus on testing with respect to gradient checks, especially for these approximate solvers?
2. To match `scipy` for `cg` and `bicgstab` more generally, I think this will need to be implemented somehow. I have a TODO comment somewhere in this PR explaining how it would be implemented if `jit` were not a requirement.
I also want to make sure that the changes I made to `cg` to avoid code duplication are kosher.