Feature/bicgstab #5299
Conversation
jax/_src/scipy/sparse/linalg.py
Outdated
  return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_

r0 = _sub(b, A(x0))
# need to use _vdot_tree to match dtype (hacky...)
You could tree map ones_like, I think?
Here I am just trying to get the number 1 to match the dtype (complex or float, 32 or 64 bit). How would that work?
Sorry... thought I was being helpful with my drive-by comment... so you actually want a scalar 1. Perhaps https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.result_type.html#jax.numpy.result_type would be helpful, but I'll defer to someone else who knows for sure!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I would recommend something like jnp.result_type(tree_leaves(b))
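A minimal sketch of that suggestion (the helper name scalar_one_like is hypothetical, and the leaves are unpacked with * so result_type sees each array individually):

import jax.numpy as jnp
from jax.tree_util import tree_leaves

def scalar_one_like(b):
    # Promote across all leaf dtypes of the pytree b
    # (float32/float64, complex64/complex128).
    dtype = jnp.result_type(*tree_leaves(b))
    return jnp.array(1, dtype=dtype)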
jax/_src/scipy/sparse/linalg.py
Outdated
x_ = lax.cond(
    exit_early,
    lambda _: _add(x, _mul(alpha_, phat)),
    lambda _: _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat))),
    None
)
r_ = lax.cond(
    exit_early,
    lambda _: s,
    lambda _: _sub(s, _mul(omega_, t)),
    None
)
k_ = lax.cond((omega_ == 0) | (alpha_ == 0),
              lambda _: -11, lambda _: k + 1, None)
k_ = lax.cond((rho_ == 0), lambda _: -10, lambda _: k_, None)
As a rule of thumb, it's best to prefer using jnp.where to lax.cond, unless the avoided computation is expensive. cond always does control flow on the CPU, which can easily be a major bottleneck when running on a GPU. If you're wondering how to do this with pytrees, try something like tree_multimap(partial(where, exit_early), x, y).
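A sketch of that pattern; select_tree is a hypothetical helper, and pred, on_true, and on_false stand in for the loop's predicate and the two candidate pytrees:

from functools import partial

import jax.numpy as jnp
from jax.tree_util import tree_multimap  # an alias of tree_map in newer JAX

def select_tree(pred, on_true, on_false):
    # Leaf-wise select between two pytrees. Both branches are computed,
    # but no host-side control flow is needed, unlike lax.cond.
    return tree_multimap(partial(jnp.where, pred), on_true, on_false)

So the update above would become, e.g., x_ = select_tree(exit_early, x_early, x_full), with both candidate trees computed up front.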
jax/_src/scipy/sparse/linalg.py
Outdated
# real-valued positive-definite linear operators are symmetric
def real_valued(x):
  return not issubclass(x.dtype.type, np.complexfloating)
symmetric = all(map(real_valued, tree_leaves(b)))
This part is only valid for CG. I believe BICGSTAB can be used for arbitrary non-symmetric linear operators, and thus should pass symmetric=False into custom_linear_solve. Otherwise I believe we will calculate the wrong gradients for non-symmetric linear operators.
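A runnable sketch of what passing symmetric=False through lax.custom_linear_solve looks like; the 2x2 operator and the direct solves are stand-ins for the BICGSTAB iteration, not the PR's actual wiring:

import jax.numpy as jnp
from jax import lax

M = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])  # deliberately non-symmetric

def matvec(x):
    return M @ x

def solve(matvec, b):
    # Stand-in for the iterative solver.
    return jnp.linalg.solve(M, b)

def transpose_solve(vecmat, b):
    # Needed for reverse-mode gradients through the solve.
    return jnp.linalg.solve(M.T, b)

b = jnp.array([1.0, 2.0])
x = lax.custom_linear_solve(matvec, b, solve, transpose_solve, symmetric=False)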
using_x64 = solution.dtype in {np.float64, np.complex128}
solution_tol = 1e-8 if using_x64 else 1e-4
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
These tests look great, but I realize that we currently don't do any tests of gradients. That would probably be a worthwhile addition. See jtu.check_grads() for examples of how we do that elsewhere in JAX. In principle, I think that could be as simple as adding one line here: jtu.check_grads(solve, (a, b), order=2), where solve = lambda A, b: jax.scipy.sparse.linalg.bicgstab(A, b)[0].
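Spelled out, that one-liner looks roughly like this; a and b here are a hypothetical small, well-conditioned test problem, not the arrays from the test suite, and the matrix is wrapped in a closure so bicgstab sees a linear map:

import jax.numpy as jnp
import jax.scipy.sparse.linalg
import jax.test_util as jtu

a = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

solve = lambda a, b: jax.scipy.sparse.linalg.bicgstab(lambda x: a @ x, b)[0]

# Checks forward- and reverse-mode derivatives up to second order.
jtu.check_grads(solve, (a, b), order=2)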
@shoyer By just adding this line, GMRES seems to fail this test in a couple of cases, more so for the complex64 and float32 dtypes. But since scipy matches, I expect the same thing would happen if we had used scipy for the iterative solves.
Try relaxing the error tolerance a little bit?
Hmm, ok, it's already 0.2 right now, and I saw this case for GMRES in my local tests:
Mismatch: 100%
Max absolute difference: 0.6610913
Max relative difference: 161.79356
x: array(0.657005, dtype=float32)
y: array(-0.004086, dtype=float32)
Hmm! My guess is that there is a genuine bug in the derivative rule for GMRES. I would suggest commenting out this test for GMRES for now (opening an issue so we don't forget about it), and we can try to fix that in a follow-on PR.
@shoyer It's ready now. I needed to relax the gradient check a bit more, to 0.3, for BICGSTAB, and I think we can address the GMRES gradient issues in a future PR as you mentioned, so I commented it out for now.
Unfortunately some of these gradient checks (both GMRES and BICGSTAB) are failing on Google's internal CI, which also runs things on GPU/TPU and Google's slightly different linear algebra libraries for CPU. This is rather annoying for non-Googlers to debug, but I suspect it's entirely spurious because it mostly turns up for low precision. If you want to give this one more try, perhaps set …
@shoyer what do you think is the best course of action here?
Force-pushed from a87d7a0 to d8831d4
One of the GitHub CI tests seems to be timing out. It used to run in about 30 minutes and is now taking over an hour. Can you check whether the tests this PR introduces are much slower than expected? The timing-out test sets …
It might be worth squashing your changes and rebasing on master to see if that helps the slow tests.
Force-pushed from d8831d4 to cd902bd
Hi @shoyer, hmm, no dice on the squashing. I couldn't immediately find which tests took so long (it didn't seem to matter if I increased the number of cases), but the longest tests seemed to be wherever grads are checked for CG. Is there a way I can see which tests took so long during CI?
I thought you commented out the gradient tests? If you can't reproduce this locally, I'm honestly baffled here. Let's see if other JAX maintainers have ideas...
If you sync in the latest changes from the JAX repo, the unit tests will run in verbose mode. That should hopefully make it more obvious where the test suite is getting stuck.
@shoyer re: grad tests, I only kept the grad tests that didn't pose an issue in master! And will do!
Force-pushed from cd902bd to 4e84ad3
Force-pushed from 7a10647 to e1b77c0
fixed some bugs in the bicgstab method and adjusted tolerance for scipy comparison
fixed flake8
added some tests for gradients, fixed symmetry checks, modified lax.cond -> jnp.where
comment out gmres grad check, to be addressed on future PR
increasing tolerance for bicgstab grad test
change to order 1 checks for bicgstab (gmres still fails in order 1) for internal CI check
remove grad checks for now
changing tolerance to pass numpy comparison test
Force-pushed from e1b77c0 to 997ad31
@shoyer I'm wondering if the issue was because of the x64 enable line I accidentally kept?
removing commented-out x64 line
Force-pushed from 0748c9e to d35ae4c
@shoyer I think this should be good to go; let me know if there's anything else you'd like me to change!
Thanks for your perseverance here, @sunilkpai! We are really excited to have this feature.
Awesome, thanks @shoyer, happy to contribute!
Great work! For visibility, https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.sparse.linalg should be updated as well.
Thanks! Done in #6647
This is an update of #3796, an implementation of BICGSTAB for the new repo, and it is no longer a WIP. I used the same tests as GMRES and they seem to work fine.