
WIP: Add bicgstab implementation #3796

Closed
wants to merge 4 commits

Conversation

@sunilkpai (Contributor) commented Jul 19, 2020

I am working on the bicgstab implementation (using the current pytree convention in cg); I am most of the way there as the implementation was very similar to cg and the tests generally seem to work. This addresses one of the goals of #1531.

Before finalizing this PR, there are some points of confusion on the tests that have to do with numerical issues in iterative solvers. I would like some advice on this before proceeding.

  1. If you use an iterative solver, you only get the correct answer up to some tolerance. This makes gradient checking harder and different from a lot of the other, more predictable behaviors in jax. I noticed the biggest shape in the *_solve gradient tests is (2, 2), and when I checked larger shapes, there was an error (depending on the type, either an array mismatch or a grad check failure)! What is the general consensus on testing with respect to gradient checks, especially for these approximate solvers?
  2. Mid-iteration convergence and breakdown checks may require some special implementation of loops. To truly match scipy for cg and bicgstab more generally, I think this will need to be implemented somehow. I have a TODO comment somewhere in this PR explaining how it would be implemented if jit were not a requirement.

I also want to make sure that the changes I made to cg to avoid code duplication are kosher.
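To make the discussion concrete, here is a minimal sketch of textbook (unpreconditioned) BiCGSTAB on a single dense array, written with lax.while_loop; `bicgstab_sketch` and its structure are illustrative, not the code in this PR (which follows the cg pytree convention):

```python
import jax
import jax.numpy as jnp

def bicgstab_sketch(A, b, x0=None, tol=1e-5, maxiter=1000):
    # Textbook unpreconditioned BiCGSTAB on a dense array (no pytrees).
    x0 = jnp.zeros_like(b) if x0 is None else x0
    one = jnp.asarray(1.0, b.dtype)
    r0 = b - A(x0)
    bnorm = jnp.linalg.norm(b)

    def cond(state):
        _, r, *_, k = state
        return (jnp.linalg.norm(r) > tol * bnorm) & (k < maxiter)

    def body(state):
        x, r, rhat, rho, alpha, omega, v, p, k = state
        rho_ = jnp.vdot(rhat, r)
        beta = (rho_ / rho) * (alpha / omega)
        p = r + beta * (p - omega * v)
        v = A(p)
        alpha = rho_ / jnp.vdot(rhat, v)
        s = r - alpha * v                          # first half-step residual
        t = A(s)
        omega = jnp.vdot(t, s) / jnp.vdot(t, t)    # no breakdown guard yet
        x = x + alpha * p + omega * s
        r = s - omega * t
        return x, r, rhat, rho_, alpha, omega, v, p, k + 1

    zeros = jnp.zeros_like(b)
    init = (x0, r0, r0, one, one, one, zeros, zeros, 0)
    x, *_ = jax.lax.while_loop(cond, body, init)
    return x
```

Note the `omega` line has no guard against `t` being (near) zero; that is exactly the mid-iteration breakdown issue raised below.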

@googlebot (Collaborator)

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.



@sunilkpai (Contributor Author)

@googlebot I signed it!

@googlebot (Collaborator)

CLAs look good, thanks!


----------
A : function
Function that calculates the matrix-vector product ``Ax`` when called
like ``A(x)``. ``A`` must represent a hermitian, positive definite
Member:

I thought bicgstab works on arbitrary matrices, without requiring symmetry? It would be good to clarify this :)

Contributor Author:

That's a copying error, you are correct!

isolve = partial(_isolve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)

real_valued = lambda x: not issubclass(x.dtype.type, np.complexfloating)
symmetric = all(map(real_valued, tree_leaves(b)))
Member:

I think this condition is only true for cg, not bicgstab? I would expect arguments to bicgstab to essentially never be symmetric. As I understand it, if they were, there are better solvers to use (e.g., cg).

Contributor Author:

I think if you have a matrix that is symmetric but NOT positive definite (e.g. pos semidefinite), this would still be useful, no? Don't know how often that would be the case in practice but it's certainly a possibility.

Member:

OK, but either way it's definitely a different condition for bicgstab vs. cg: a solution with all real values does not imply that the matrix is symmetric.

So it seems like isolve should grow a symmetric argument, which could be set in a solver specific way. Potentially we could even expose a new symmetric argument on the public bicgstab function.
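A sketch of that plumbing (the function names mirror the diff above, but the signatures and stub bodies here are hypothetical, just to show where the flag would flow):

```python
import numpy as np
import jax.numpy as jnp
from jax.tree_util import tree_leaves

def _isolve(_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
            maxiter=None, M=None, symmetric=False):
    # Stub: a real implementation would run the solver inside
    # lax.custom_linear_solve and pass `symmetric` to the transpose rule.
    return symmetric

def cg(A, b, **kwargs):
    # cg's existing heuristic: all-real leaves => A is treated as symmetric
    real_valued = lambda x: not issubclass(np.asarray(x).dtype.type,
                                           np.complexfloating)
    symmetric = all(map(real_valued, tree_leaves(b)))
    return _isolve(None, A, b, symmetric=symmetric, **kwargs)

def bicgstab(A, b, symmetric=False, **kwargs):
    # bicgstab targets general matrices: default to False, but expose the
    # flag for callers with symmetric (possibly indefinite) systems
    return _isolve(None, A, b, symmetric=symmetric, **kwargs)
```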

Contributor Author:

Alright, and I also wonder if there is some utility in treating isolve as a public function or no? I can give it an underscore in the name if not!

Member:

Yes, all the helper functions should definitely be marked as private.

Comment on lines 120 to 125
# TODO(?): stop early
# It requires accessing cond_fun like this, but it's
# not possible with jit...
# if cond_fun((x, s, r0, alpha_, omega, rho_, p_, q_, k)):
# x_ = _add(x, _mul(alpha_, phat))
# return x_, s, rhat, alpha_, omega, rho_, p_, q, k
Member:

We could probably do this with some cleverness (e.g., involving lax.cond), but I think it's actually totally fine to skip it -- at least for now. In the best case, this would save one extra function evaluation.

On GPUs, dynamic control flow is so expensive that I expect we could actually get better performance in many cases by only checking for stopping every n steps.
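That "check every n steps" idea can be sketched by nesting a fixed-length fori_loop inside the while_loop; `step` here stands in for one solver iteration, and all names are illustrative:

```python
import jax
import jax.numpy as jnp

def iterate_with_sparse_checks(step, state, residual, tol, max_checks, n=10):
    # Run `n` unconditional iterations between each dynamic stopping test,
    # amortizing the cost of control flow on accelerators.
    def cond(carry):
        state, k = carry
        return (residual(state) > tol) & (k < max_checks)

    def body(carry):
        state, k = carry
        state = jax.lax.fori_loop(0, n, lambda i, s: step(s), state)
        return state, k + 1

    state, _ = jax.lax.while_loop(cond, body, (state, 0))
    return state
```

The trade-off is running up to n - 1 extra iterations past convergence in exchange for n-fold fewer host-visible convergence checks.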


rng = jtu.rand_default(self.rng())
A = poisson(shape, dtype) # use scipy's test to check bicgstab instead of random matrix
# A = rand_sym_pos_def(rng, shape, dtype)
Member:

I would be in favor of testing the Poisson equation as well as random matrices, but I would definitely test random matrices (with no structure) as well. This seems much more likely to turn up issues.

Contributor Author:

Yeah, I agree, I thought of this as well. I will implement that suggestion!

Comment on lines 273 to 339
def test_bicgstab_pytree(self):
A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
b = {"a": 1.0, "b": -4.0}
expected = {"a": 4.0, "b": -6.0}
actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
self.assertEqual(expected.keys(), actual.keys())
self.assertAlmostEqual(expected["a"], actual["a"], places=6)
self.assertAlmostEqual(expected["b"], actual["b"], places=6)

def test_bicgstab_errors(self):
A = lambda x: x
b = jnp.zeros((2,))
with self.assertRaisesRegex(
ValueError, "x0 and b must have matching tree structure"):
jax.scipy.sparse.linalg.bicgstab(A, {'x': b}, {'y': b})
with self.assertRaisesRegex(
ValueError, "x0 and b must have matching shape"):
jax.scipy.sparse.linalg.bicgstab(A, b, b[:, np.newaxis])

def test_bicgstab_without_pytree_equality(self):

@register_pytree_node_class
class MinimalPytree:
def __init__(self, value):
self.value = value
def tree_flatten(self):
return [self.value], None
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)

A = lambda x: MinimalPytree(2 * x.value)
b = MinimalPytree(jnp.arange(5.0))
expected = b.value / 2
actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
self.assertAllClose(expected, actual.value)
Member:

You have both cg and bicgstab using most of the same code. So I don't think it provides much value to add these repeated tests. I would either drop them or make the original tests parametric so they test both cg and bicgstab with the same method.
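One way the duplicated pytree test could fold into a single parametric test (a sketch using unittest subTests; the real suite would more likely use jtu's parameterized helpers):

```python
import unittest
import jax.numpy as jnp
import jax.scipy.sparse.linalg as sla

class IterativeSolverTest(unittest.TestCase):
    # The same pytree check runs against both solvers instead of being
    # copy-pasted once per solver.
    SOLVERS = {"cg": sla.cg, "bicgstab": sla.bicgstab}

    def test_pytree(self):
        A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
        b = {"a": 1.0, "b": -4.0}
        expected = {"a": 4.0, "b": -6.0}
        for name, solver in self.SOLVERS.items():
            with self.subTest(solver=name):
                actual, _ = solver(A, b)
                self.assertEqual(expected.keys(), actual.keys())
                # looser than the cg-only test, since bicgstab is approximate
                self.assertAlmostEqual(expected["a"], float(actual["a"]), places=3)
                self.assertAlmostEqual(expected["b"], float(actual["b"]), places=3)
```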

Contributor Author:

I agree, will do!

Comment on lines 214 to 215
a = poisson(shape, dtype) # TODO(sunilkpai): random doesn't work (bicg numerics), so using poisson instead...
# a = rng(shape, dtype)
Member:

This is at least a little surprising. I wouldn't expect random 2x2 matrices to be a problem -- that suggests that maybe there's a bug in the algorithm?

Contributor Author (@sunilkpai, Jul 20, 2020):

Yeah, I am surprised as well. I think having breakdown checks is perhaps important. If scipy returns -10 or -11, that means rho = 0 or omega = 0, respectively, and maybe that is causing some issues. I will look into this; I do recall that small matrices actually trip up even scipy's bicgstab.

shat = M(s)
t = A(shat)
# omega_ = _vdot_tree(_conj(s), t) / _vdot_tree_real(t, t)
omega_ = _vdot_tree(s, t) / _vdot_tree_real(t, t)
Member:

I am concerned that this might be effectively dividing by zero if the algorithm converges in the first half of the step. I wonder if this could be related to your testing challenges?

If I'm understanding correctly, this might not be happening for cg only because the denominator is also the termination condition on the loop.

One way to guard against this would be to replace the denominator with something like jnp.maximum(_vdot_tree_real(t, t), epsilon).
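That guard, in isolation (a sketch on plain arrays; `safe_omega` and the `eps` value are invented here for illustration):

```python
import jax.numpy as jnp

def safe_omega(s, t, eps=1e-30):
    # Clamp the denominator away from zero so mid-step convergence
    # (t ~ 0) yields a harmless omega instead of NaN.
    return jnp.vdot(t, s) / jnp.maximum(jnp.vdot(t, t).real, eps)
```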

Contributor Author:

I am happy to add that, as I have seen it used in other contexts, but I haven't really seen it done in scipy's Fortran revcom or any other pseudocode. That said, I have noticed that my implementation and scipy tend to have similar breakdown scenarios (usually breakdown checks test whether rho or omega is zero, and that is returned as part of info). Do you recommend we add this in anyway?

Member:

My general rule of thumb is to try not to innovate in the details of numerical methods. I'm not an expert on most of them, and it can be easy to get the details wrong.

I would definitely suggest copying the safety/breakdown checks you notice in other implementations like SciPy rather than my ideas! Let me know if you have questions about how to implement them in JAX.

Contributor Author:

The safety/breakdown checks might require some deeper changes in jax (similar to why info is currently None). Might be a future thing to add after getting the basic bicgstab to work?

Member:

You should definitely be able to add early breakdown without pervasive changes to JAX.

There are at least two ways to do it:

  1. You could nest a call to jax.lax.cond, either returning half-way through or doing the single GMRES step.
  2. You could keep track of whether an early exit should happen in a boolean variable, and use lax.select/jnp.where (basically the same) to selectively update variables; e.g., compare how the loopy logic for sampling from the Poisson distribution is implemented in JAX vs. NumPy.

They have slightly different performance tradeoffs, but either solution should be fine here. (1) might be a little easier.
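Option (2) in miniature (a hedged sketch; the update rule below is a stand-in for a real solver iteration, not the PR's code):

```python
import jax.numpy as jnp

def step_with_freeze(x, r, done):
    # Stand-in solver update: halve the residual each iteration.
    x_new = x + 0.5 * r
    r_new = 0.5 * r
    done_new = done | (jnp.linalg.norm(r_new) < 1e-6)
    # Where `done` is already set, keep the old values instead of
    # updating, so every iteration stays a fixed, jit-able computation.
    x = jnp.where(done, x, x_new)
    r = jnp.where(done, r, r_new)
    return x, r, done_new
```

Once `done` flips, later iterations are no-ops on the state, which is how early breakdown can be recorded without dynamic control flow.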

Contributor Author:

Thanks, this is helpful!

@shoyer (Member) commented Jul 20, 2020

  • What is the general consensus on testing with respect to gradient checks, especially for these approximate solvers?

It's fine to relax error tolerances in gradient checks as long as we are still quite confident that the result is correct (up to numerical precision).

For example, I just tried switching the cg gradient test to a random 4x4 matrix. This turns up the following error:

E     AssertionError:
E     Not equal to tolerance rtol=0.04, atol=0.008
E
E     Mismatch: 25%
E     Max absolute difference: 1.0214233
E     Max relative difference: 0.09353104
E      x: array([-37.997494,   3.304602,  70.99634 , 318.19443 ], dtype=float32)
E      y: array([-37.312508,   3.021955,  70.98675 , 317.173   ], dtype=float32)

But really, this is close enough. So if we used these tests, I would be fine relaxing the error tolerance as needed.
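Relaxing the tolerances along these lines can be sketched with jax.test_util.check_grads; the matrix and tolerance values below are illustrative, not the test suite's actual settings:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg
from jax.test_util import check_grads

def solve(a, b):
    # Differentiable linear solve routed through the iterative solver.
    x, _ = cg(lambda v: a @ v, b, tol=1e-10)
    return x

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # small SPD example
b = jnp.array([1.0, -2.0])
# Looser-than-default tolerances, acknowledging the solver is approximate.
check_grads(solve, (a, b), order=1, modes=("rev",), atol=1e-2, rtol=1e-2)
```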

@sunilkpai sunilkpai marked this pull request as draft July 20, 2020 02:34
@sunilkpai (Contributor Author) commented Jul 24, 2020

Hi @shoyer, I'm still fighting my code to get scipy and bicgstab to agree; they agree in many cases (especially at higher numerical precision) but begin to deviate in some special cases. I've done a number of things to check what the differences are, and I'm beginning to suspect that numerical precision becomes increasingly important after several steps (generally not after the first step). I checked scipy's implementation (and even hacked into their bicgstab function to print out intermediate values). I'm wondering what your testing procedure was, so I can resolve these issues more efficiently.

The test I have started to run now is to write a numpy version of bicgstab and compare it with the jax version, to make sure things roughly agree given essentially the same code. But even this starts to disagree (with both scipy and jax); my current hypothesis is that this is due to slight differences in the implementation of methods such as vdot (the order of addition seems to matter at times for numerical precision).

Here is some example failure case output tracking bicgstab iterations for the scipy implementation and my numpy implementation for float32[4, 4]. It's worth noting that I've observed simulations in the past where bicgstab fails for complex64 and not complex128, so precision appears to be a key factor in determining the success of bicgstab.

Iteration 1: # this iteration looks fine
r: numpy [-0.7498561 -2.9607344  5.9396653 -3.6000998] scipy [-0.74985594 -2.9607344   5.9396653  -3.6000998 ]
p: numpy [ 2.845481    0.93953496 -0.43305653 -1.9908385 ] scipy [ 2.845481    0.93953496 -0.43305653 -1.9908385 ]
t: numpy [44.767788  -6.7145605 15.403215  21.610672 ] scipy [44.767788  -6.7145605 15.403215  21.610674 ]
q: numpy [ 13.488404  15.618371 -25.598196   6.009314] scipy [ 13.488404   15.618371  -25.598196    6.0093145]
shat: numpy [-0.5489147 -2.9908729  6.008803  -3.5030997] scipy [-0.5489147 -2.9908729  6.008803  -3.5030997]

Iteration 2: # things start to change more here but still fine
r: numpy [-1.3085905 -1.9914868  4.1542206 -3.6563268] scipy [-1.3085957 -1.9914829  4.1542125 -3.656325 ]
p: numpy [-4.5599647 -4.1502147  6.374942  -0.8395064] scipy [-4.559961   -4.1502132   6.3749413  -0.83950925]
t: numpy [ 28.715466 -17.281141  27.14217   29.973507] scipy [ 28.715366 -17.281208  27.142246  29.973562]
q: numpy [ 24.778402 -29.029495  51.462784  14.110635] scipy [ 24.778421 -29.029476  51.462746  14.110645]
shat: numpy [-1.8608497 -1.6591338  3.6322193 -4.232781 ] scipy [-1.8608582 -1.6591254  3.6322048 -4.2327857]

Iteration 3: # oh boy...
r: numpy [-0.00711237 -0.12015637  0.1177693  -0.10020799] scipy [-0.0025171  -0.11048866  0.0973189  -0.08428508]
p: numpy [ 2.093349   1.931221  -1.9813652 -3.183013 ] scipy [ 2.0936446  1.9315835 -1.9819365 -3.1829684]
t: numpy [ 1.0321703  -0.27053687  0.5031162   0.84241974] scipy [ 0.8828936  -0.21744165  0.4108755   0.73308885]
q: numpy [  5.698114   8.345237 -17.993662  15.740916] scipy [  5.6956105   8.347048  -17.997128   15.740109 ]
shat: numpy [-0.027246   -0.11487925  0.10795546 -0.11664033] scipy [-0.02184983 -0.10572734  0.08832195 -0.10033754]

Incidentally, after using more random cases than is currently recommended, I noticed cg failed a few scipy comparison tests, either due to similar reasons or changes I've made to the tests. I'll get back to you on exactly what those failure cases are.

@shoyer (Member) commented Jul 24, 2020

@sunilkpai This looks close enough to me? How does the accuracy of the final solution compare?

SciPy is probably always using float64 precision. It might be worth enabling x64 mode in JAX, just for the sake of comparison. This currently requires a special flag:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision

If the error does not decrease when you turn on x64 mode, then perhaps we are missing something important and should be worried! In that case I would look into the early termination conditions.
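The flag in question can be set as below (a minimal snippet; note it must run before any JAX arrays are created):

```python
from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
x = jnp.ones(3)  # default floats are now 64-bit
```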

@sunilkpai (Contributor Author):

@shoyer Yeah, you might be right for this case, but I think the deviation becomes magnified as n gets larger (e.g., n = 32 becomes a problem; this was just an unlucky case for n = 4 that was easy to use as an example of what may go wrong). I will try your suggestion and see if anything changes!

@sunilkpai (Contributor Author):

Also, re: early termination: this may actually be necessary when s = 0 in bicgstab, which is an error case generated by some of the pytree tests. If s = 0, then t = 0 as well, resulting in a NaN evaluation for omega!

@sunilkpai (Contributor Author):

Hi @shoyer, this is still WIP, but I am getting closer, I think, thanks to your help! I have been adding your suggestions for the testing framework. Some checks between numpy and scipy are necessary (since there are differences and numerical failure cases) and are therefore now part of the tests.

The exact test for cg seems to be the culprit in the above CI failure (either due to some unintentional changes I made to the test or just the fact that it's a different matrix). There are also random failures for bicgstab in a more exhaustive test setting; it all depends on the matrix (I verified this by adding extra characters to the test name to generate a new set of cases). Generally this occurs on roughly 5% of the tests; sometimes the error is small (twice the relative tolerance), other times very large (orders of magnitude larger than the tolerance).
