
Feature/bicgstab #5299

Merged — 2 commits merged into google:master on Feb 22, 2021
Conversation

sunilkpai (Contributor):

This is an update of #3796, implementing BICGSTAB for the new repo, and it is no longer a WIP. I used the same tests as GMRES and they seem to work fine.
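For context, a minimal usage sketch of the API this PR adds, mirroring the existing cg/gmres interface (the matrix, right-hand side, and tolerance below are illustrative, not taken from the PR's tests):

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

a = jnp.array([[4.0, 1.0],
               [2.0, 3.0]])         # non-symmetric operators are fine for BiCGSTAB
b = jnp.array([1.0, 2.0])

x, info = bicgstab(a, b, tol=1e-6)  # iteratively solve a @ x = b
```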

Review comment on the diff:

```python
return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_

r0 = _sub(b, A(x0))
# need to use _vdot_tree to match dtype (hacky...)
```

Reviewer: You could tree-map ones_like, I think?

sunilkpai (Contributor, author):

Here I'm just trying to get the number 1 to match the dtype (complex or float, 32- or 64-bit). How would that work?

Reviewer:

Sorry... I thought I was being helpful with my drive-by comment... so you actually want a scalar 1. Perhaps https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.result_type.html#jax.numpy.result_type would be helpful, but I'll defer to someone else who knows for sure!

Member:

Yes, I would recommend something like jnp.result_type(tree_leaves(b))
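A minimal sketch of that suggestion (the pytree `b` below is illustrative; the point is getting a scalar 1 whose dtype matches `b`, replacing the _vdot_tree hack):

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves

# Stand-in for the solver's right-hand-side pytree:
b = {'u': jnp.ones(3, dtype=jnp.float32),
     'v': jnp.zeros(2, dtype=jnp.complex64)}

dtype = jnp.result_type(*tree_leaves(b))  # promoted common dtype (complex64 here)
one = jnp.asarray(1, dtype=dtype)         # scalar 1 in the matching dtype
```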

Comment on lines 161 to 175
```python
x_ = lax.cond(
    exit_early,
    lambda _: _add(x, _mul(alpha_, phat)),
    lambda _: _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat))),
    None
)
r_ = lax.cond(
    exit_early,
    lambda _: s,
    lambda _: _sub(s, _mul(omega_, t)),
    None
)
k_ = lax.cond((omega_ == 0) | (alpha_ == 0),
              lambda _: -11, lambda _: k + 1, None)
k_ = lax.cond((rho_ == 0), lambda _: -10, lambda _: k_, None)
```
Member:

As a rule of thumb, it's best to prefer using jnp.where to lax.cond, unless the avoided computation is expensive. cond always does control flow on the CPU, which can easily be a major bottleneck when running on a GPU.

If you're wondering how to do this with pytrees, try something like tree_multimap(partial(where, exit_early), x, y).
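A minimal sketch of that pytree pattern (tree_multimap is the older name; in current JAX the multi-tree form of jax.tree_util.tree_map does the same thing, and the trees below are illustrative):

```python
from functools import partial
import jax.numpy as jnp
from jax.tree_util import tree_map

exit_early = jnp.array(False)
x = {'a': jnp.ones(3)}
y = {'a': jnp.zeros(3)}

# Leafwise select between two pytrees; both branches are computed,
# but no host-side control flow is needed:
selected = tree_map(partial(jnp.where, exit_early), x, y)
```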


Comment on lines 223 to 226
```python
# real-valued positive-definite linear operators are symmetric
def real_valued(x):
  return not issubclass(x.dtype.type, np.complexfloating)

symmetric = all(map(real_valued, tree_leaves(b)))
```
Member:

This part is only valid for CG. I believe BICGSTAB can be used for arbitrary non-symmetric linear operators, and thus should pass symmetric=False into custom_linear_solve. Otherwise I believe we will calculate the wrong gradients for non-symmetric linear operators.
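A minimal sketch of the suggested fix (the operator, right-hand side, and the dense solves standing in for the PR's internal BiCGSTAB solve are all illustrative):

```python
import jax.numpy as jnp
from jax import lax

a = jnp.array([[3.0, 1.0],
               [0.0, 2.0]])              # deliberately non-symmetric
b = jnp.array([1.0, 1.0])

mv = lambda x: a @ x
solve = lambda matvec, rhs: jnp.linalg.solve(a, rhs)
tsolve = lambda matvec, rhs: jnp.linalg.solve(a.T, rhs)

# symmetric=False so autodiff uses the true transpose solve rather than
# assuming A.T == A:
x = lax.custom_linear_solve(mv, b, solve=solve, transpose_solve=tsolve,
                            symmetric=False)
```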

Review comment on the diff:

```python
# dtype.kind is a one-character code, so compare the dtype itself:
using_x64 = solution.dtype in (np.float64, np.complex128)
solution_tol = 1e-8 if using_x64 else 1e-4
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
```

shoyer (Member):

These tests look great, but I realize that we currently don't test gradients at all. That would probably be a worthwhile addition.

See jtu.check_grads() for examples of how we do that elsewhere in JAX. In principle, I think it could be as simple as adding one line here: jtu.check_grads(solve, (a, b), order=2), where solve = lambda A, b: jax.scipy.sparse.linalg.bicgstab(A, b)[0].
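Spelled out, a hedged sketch of that suggested test (the matrices and the relaxed rtol are illustrative; jtu is jax.test_util):

```python
import jax
import jax.numpy as jnp
from jax import test_util as jtu

a = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 0.0])

solve = lambda a, b: jax.scipy.sparse.linalg.bicgstab(a, b)[0]
jtu.check_grads(solve, (a, b), order=2, rtol=2e-1)  # checks JVPs and VJPs
```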

sunilkpai (Contributor, author) — Jan 11, 2021:

@shoyer By just adding this line, GMRES seems to fail this test in a couple of cases, more so for the complex64 and float32 dtypes. But since scipy matches, I expect the same thing would happen if we had used scipy for the iterative solves.

shoyer (Member):

Try relaxing the error tolerance a little bit?

sunilkpai (Contributor, author):

Hmm, OK, it's already 0.2 right now, and I saw this case for GMRES in my local tests:

```
Mismatch: 100%
Max absolute difference: 0.6610913
Max relative difference: 161.79356
 x: array(0.657005, dtype=float32)
 y: array(-0.004086, dtype=float32)
```

shoyer (Member):

Hmm! My guess is that there is a genuine bug in the derivative rule for GMRES. I would suggest commenting out this test for GMRES for now (and opening an issue so we don't forget about it), and we can try to fix it in a follow-up PR.

sunilkpai (Contributor, author):

@shoyer It's ready now. I needed to relax the gradient check a bit more, to 0.3, for BICGSTAB. I think we can address the GMRES gradient issues in a future PR as you mentioned, so I commented that check out for now.

shoyer added the "pull ready" label (Ready for copybara import and testing) on Jan 11, 2021
shoyer (Member) — Jan 13, 2021:

Unfortunately, some of these gradient checks (both GMRES and BICGSTAB) are failing on Google's internal CI, which also runs things on GPU/TPU and with Google's slightly different linear algebra libraries for CPU.

This is rather annoying for non-Googlers to debug, but I suspect it's entirely spurious, because it mostly turns up at low precision. If you want to give this one more try, perhaps set order=1 in check_grads() (on the theory that higher-order derivatives have larger numerical error). Otherwise, I guess we can comment out those checks too, for now.

sunilkpai (Contributor, author):

@shoyer what do you think is the best course of action here?

shoyer (Member) — Jan 27, 2021:

One of the GitHub CI tests seems to be timing out: it used to run in about 30 minutes and is now taking over an hour. Can you check whether the tests this PR introduces are much slower than expected?

The timing-out test sets JAX_NUM_GENERATED_CASES=25:
https://jax.readthedocs.io/en/latest/developer.html#running-the-tests

shoyer (Member) — Jan 27, 2021:

It might be worth squashing your changes and rebasing on master to see if that helps the slow tests.

sunilkpai (Contributor, author):

Hi @shoyer, hmm, no dice on the squashing. I couldn't immediately find which tests took so long (it didn't seem to matter if I increased the number of cases), but the longest tests seemed to be wherever grads are checked for CG. Is there a way I can see which tests took so long during CI?

shoyer (Member) — Jan 27, 2021:

I thought you had commented out the gradient tests? If you can't reproduce this locally, I'm honestly baffled here. Let's see if other JAX maintainers have ideas...

shoyer (Member) — Jan 28, 2021:

If you sync in the latest changes from the JAX repo, the unit tests will run in verbose mode. That should hopefully make it more obvious where the test suite is getting stuck.

sunilkpai (Contributor, author):

@shoyer re: grad tests, I only kept the grad tests that didn't pose an issue in master! And will do!

Commits pushed:

- fixed some bugs in the bicgstab method and adjusted tolerance for scipy comparison
- fixed flake8
- added some tests for gradients, fixed symmetry checks, modified lax.cond -> jnp.where
- comment out gmres grad check, to be addressed on future PR
- increasing tolerance for bicgstab grad test
- change to order 1 checks for bicgstab (gmres still fails in order 1) for internal CI check
- remove grad checks for now
- changing tolerance to pass numpy comparison test

sunilkpai (Contributor, author):

@shoyer I'm wondering if the issue was caused by the x64-enable line I accidentally kept in?

Commit pushed: removing commented-out x64 line
sunilkpai (Contributor, author):

@shoyer I think this should be good to go; let me know if there's anything else you'd like me to change!

copybara-service bot merged commit 234990e into google:master on Feb 22, 2021
shoyer (Member) — Feb 22, 2021:

Thanks for your perseverance here, @sunilkpai! We are really excited to have this feature.

sunilkpai (Contributor, author):

Awesome, thanks @shoyer, happy to contribute!

tetterl — May 4, 2021:

Great work! For visibility: https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.sparse.linalg should be updated as well.

jakevdp (Collaborator) — May 4, 2021:

Thanks! Done in #6647
