Differentiable matrix-free linear algebra, optimization and equation solving #1531
Comments
I'll start things off. I've written initial versions of higher-order primitives defining gradients for non-linear root finding (see discussion in #1448) and linear equation solving (#1402). I've also written experimental implementations in JAX for …
Yes, I'm working on a general feature for referring to (and extracting, injecting, differentiating w.r.t.) intermediate/aux values; it's coming soon 🙂.
I'm glad to hear @jekbradbury is thinking about this, because I have not! I agree it's important. In this case, auxiliary outputs should not have derivatives defined. If we were to solve this entirely inside …
@jekbradbury Awesome, looking forward to it! (@ all JAX maintainers/contributors) I'm loving the style and direction of JAX. Keep up the great work!
#2566 adds …
@shoyer, what's the status of …?
I have a very naive implementation of GMRES with preconditioning that you can find here: https://gist.github.com/shoyer/cbac2cf8c8675b2f3a45e4837e3bed80 It needs more careful testing (and possibly improvements to the numerics) before merging into JAX.
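For readers following along, the calling convention at stake is the matrix-free one that `cg` (from #2566) already uses: `A` may be a function rather than a matrix. A minimal sketch assuming the `jax.scipy.sparse.linalg.cg` API; the circulant operator here is made up for illustration:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Matrix-free operator: a diagonally dominant (hence SPD) circulant stencil,
# applied without ever materializing the matrix.
def matvec(x):
    return 3.0 * x - jnp.roll(x, 1) - jnp.roll(x, -1)

b = jnp.ones(100)
x, info = cg(matvec, b, tol=1e-6, maxiter=500)
print(jnp.linalg.norm(matvec(x) - b))  # residual should be small
```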
To be clear, I have no immediate plans to continue work on my …
@shoyer thanks for sharing! I think it would be nice to combine your implementation with the sparse matrix–vector product from #3717. The jit/GPU implementation still can't beat SciPy, and I suspect this is due to the COO representation of the sparse matrix (SciPy uses CSR: https://github.com/scipy/scipy/blob/v1.5.1/scipy/sparse/base.py#L532). I will do some testing in this direction first.
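As a point of reference for the COO discussion, a COO matvec in JAX reduces to a gather plus a segment sum. A minimal sketch (the `coo_matvec` name and the use of `jax.ops.segment_sum` are my own choices for illustration, not the #3717 implementation):

```python
from functools import partial

import jax.numpy as jnp
from jax import jit
from jax.ops import segment_sum

@partial(jit, static_argnums=4)
def coo_matvec(data, row, col, x, num_rows):
    # For every stored entry k: y[row[k]] += data[k] * x[col[k]].
    # The gather x[col] is cheap; the scatter-add is the segment sum.
    return segment_sum(data * x[col], row, num_segments=num_rows)
```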
@romanodev this is really awesome work.
This other PR has a vectorized version of Gram–Schmidt, which I think could replace …
OK, here's a version with vectorized Gram–Schmidt, which is perhaps 10-20% faster on CPU and ~10x faster on GPU: … The GPU is still spending most of its time waiting when solving for a 100-dimensional vector, but for large enough solves the performance should be reasonable. I suspect the main remaining improvement (which we should probably have before merging) would be adding some form of early termination based on residual values.
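For context, the vectorization being discussed replaces the Python loop over existing Krylov vectors with two matrix products. A minimal sketch of classical Gram–Schmidt in this style (not the PR's code; note that classical GS is less numerically robust than modified GS and is often applied twice in practice):

```python
import jax.numpy as jnp

def gram_schmidt_step(Q, v):
    # Q: (k, n) stack of orthonormal basis vectors; v: (n,) new candidate.
    # Two matmuls replace a loop of k individual projections.
    h = Q @ v            # coefficients of v along each basis vector
    v_new = v - Q.T @ h  # subtract all projections at once
    return v_new, h
```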
@shoyer, great. Any reason for not using experimental.loops? Just for the sake of prototyping, I rewrote your version using loops (just personal taste, for faster iteration). Using jnp.lstsq, I plotted the residual at each iteration. Within loops we can easily handle early termination, but I am not sure how to do it with lax.scan. In any case, the preconditioning does not seem to work yet. I considered a simple case where we take the inverse of the diagonal of A, and it doesn't match SciPy. Although this choice is not justified, since A does not have a dominant diagonal, it should at least serve for testing. Here is the gist: https://gist.github.com/romanodev/e3f6bd23c499cd8a5f26b26c140abcac
It's just a matter of taste. Personally I don't find it much clearer than using functions if there's only one level of loops. It's also a little easier to avoid inadvertently using Python control flow instead of the control flow functions; e.g., in your example there's one place where you should be using …

For printing intermediate outputs from inside compiled code like "for" loops, take a look at …

For early termination logic in general, take a look at the source code for … (jax/jax/scipy/sparse/linalg.py, line 54, at commit 0d81e98):
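For reference, the early-termination pattern being pointed at looks roughly like this: `lax.while_loop` stops as soon as a residual test fails, which `lax.scan` cannot do. A minimal CG-shaped sketch (not the actual `jax.scipy.sparse.linalg` code); note that `lax.while_loop` is not reverse-mode differentiable, which is exactly why these solvers need custom derivative rules:

```python
import jax.numpy as jnp
from jax import lax

def cg_loop(matvec, b, x0, tol=1e-5, maxiter=1000):
    def cond(state):
        _, r, _, k = state
        # Early termination: keep iterating only while the residual is large.
        return (jnp.linalg.norm(r) > tol) & (k < maxiter)

    def body(state):
        x, r, p, k = state
        Ap = matvec(p)
        alpha = (r @ r) / (p @ Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        beta = (r_new @ r_new) / (r @ r)
        return x, r_new, r_new + beta * p, k + 1

    r0 = b - matvec(x0)
    x, _, _, _ = lax.while_loop(cond, body, (x0, r0, r0, 0))
    return x
```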
Hi, in parallel with this, I'd like to add an implementation of bicgstab (which I believe is also matrix-free)! I've started working on something similar in CuPy (cupy/cupy#3569), so I thought I might as well add it to JAX.
Sure, we would love to see an implementation of bicgstab. The implementation should be relatively straightforward (easier than gmres). Please copy the style of the existing cg implementation, including tests.
Excellent! If GMRES works out too, then we can even do a more general form of BiCGSTAB, which may be useful for EM problems (as suggested by Meep). Not sure if this is implemented in SciPy, so I'm unsure where something like that would go?
Meanwhile, as a reference, this is what I have so far on GMRES (using …): https://gist.github.com/romanodev/be02bd4b7e90c5ebb3dc84ebebf4e76f
With regards to the "more general version of GMRES": I'm sure this could be useful, but my preference would be to stick to established algorithms in JAX. Inclusion in SciPy is one good indication of that. I would suggest releasing more experimental algorithms in separate packages.
Hi @shoyer, I have a general question as I prepare this PR. What is the motivation for defining pytree functions like …?
The reason for doing this is that it lets us solve linear equations on arbitrarily shaped arrays in arbitrary structures. In some cases, this is a much more natural way to implement linear operators; e.g., you might use a handful of 3D arrays to represent the solution of a PDE on a grid. By supporting pytrees, we don't need to copy these arrays into a single flattened vector. Flattening can actually add significant overhead due to the extra copying; e.g., it's 60-90% slower to solve a Poisson equation in 2D using CG with a flattened array than with a 2D array.

That said, these algorithms are much easier to implement/verify on single vectors, and in the long term I'd like to solve vectorization with a …

So if you prefer, I would also be OK with implementing these algorithms on single vectors and adding explicit flattening/unflattening to handle pytrees. You can find an example of doing this sort of thing inside … (lines 210 to 214, at commit fa2a027).
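The explicit flattening/unflattening alternative mentioned above can be written with `jax.flatten_util.ravel_pytree`. A minimal sketch, where `solver` stands in for any flat-vector solver (e.g., the `cg_loop` sketch earlier in the thread):

```python
from jax.flatten_util import ravel_pytree

def solve_pytree(solver, matvec, b, x0):
    # Flatten the pytree right-hand side once; wrap the user's pytree-valued
    # matvec so the core solver only ever sees flat 1D arrays.
    b_flat, unravel = ravel_pytree(b)
    x0_flat, _ = ravel_pytree(x0)

    def matvec_flat(v_flat):
        out = matvec(unravel(v_flat))
        out_flat, _ = ravel_pytree(out)
        return out_flat

    x_flat = solver(matvec_flat, b_flat, x0_flat)
    return unravel(x_flat)
```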
@shoyer I added a PR for …
I would just like to add another comment on @shoyer's point about supporting operators on pytree arrays. I have a use case where I need to compute the vector norm / dot product for arrays structured as pytrees, so it would be very useful to support these operations.
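For what it's worth, pytree dot products and norms are short to write with the tree utilities. A minimal sketch (`tree_vdot`/`tree_norm` are hypothetical names for illustration, though the in-tree solvers use essentially this pattern):

```python
import operator

import jax.numpy as jnp
from jax import tree_util

def tree_vdot(x, y):
    # Inner product over whole pytrees: vdot each pair of leaves, then sum.
    leaf_dots = tree_util.tree_map(jnp.vdot, x, y)
    return tree_util.tree_reduce(operator.add, leaf_dots)

def tree_norm(x):
    # Real part guards against a complex dtype from jnp.vdot.
    return jnp.sqrt(jnp.real(tree_vdot(x, x)))
```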
To be clear, this was a suggestion for making it easier to write new solvers. Actually, it would even be fine not to support pytrees at all in new solvers.
We should absolutely feel free to copy SciPy's tests. #3101 (which I will be merging shortly) has a good example of how to do this. |
Functioning and apparently efficient JAX implementations of eigs, eigsh using implicitly restarted Lanczos, and gmres are already present in https://github.com/google/TensorNetwork (matrix-free methods are very important in tensor network computations), and perhaps it would make more sense to merge some or all of these into JAX. I've started a conversation in google/TensorNetwork#785 with the other TensorNetwork devs on the matter. We are also planning to implement LGMRES, a modification of GMRES that performs better on nearly-symmetric operators. This would probably be my job, since I wrote the GMRES implementation, and I'd be happy to simply write it here instead.
Note that both LGMRES and bicgstab are accessible through SciPy.
Agreed, I would love to upstream all of these into JAX.
@alewis I noted that preconditioning is not supported for GMRES with the JAX backend in TensorNetwork. Is there any plan in that direction?
@romanodev Preconditioning has been sort of a vague long-term goal, but certainly we'd be happy to have it.
@alewis I see GMRES is close to being added to the main sparse namespace. Nice work! I have a couple of questions:

1. Is ILU preconditioning on the roadmap? I am trying SciPy's GMRES and, in my (sparse) case, neither Jacobi nor Gauss–Seidel preconditioners speed up convergence. Only ILU seems to work.
2. In my use case, I have to solve many similar sparse problems. Is there a way of exposing some of the Krylov subspace so it can be reused?
You need to specify your own preconditioner as a function. I'm not sure what ILU is, but if you can do that, you're good. Currently some of the low-level design of the sparse solvers makes it impossible to expose the Krylov subspace, but this is likely to be modified.
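To make "specify your own preconditioner as a function" concrete, here is a minimal sketch with a Jacobi (diagonal) preconditioner, using the `gmres` signature that later landed in `jax.scipy.sparse.linalg`; the toy system is made up:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

# Toy diagonally dominant tridiagonal system.
n = 50
A = 4.0 * jnp.eye(n) + jnp.diag(jnp.ones(n - 1), 1) + jnp.diag(jnp.ones(n - 1), -1)
b = jnp.ones(n)

# Jacobi preconditioner, passed as a function: M should act roughly like
# A^{-1} applied to a vector.
diag_A = jnp.diag(A)
M = lambda v: v / diag_A

x, info = gmres(A, b, M=M, tol=1e-6)
```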
Hey, I just added a separate PR #5299 (in case I'd need to do a lot of rebase work due to the addition of GMRES) with the BiCGSTAB implementation and tests. I feel like we should think about test consolidation at some point, but it appears to work as is for now. I used your suggestions for using …
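For anyone wanting to try #5299, a usage sketch assuming it lands with the same matrix-free interface as `cg` (which is the signature `jax.scipy.sparse.linalg.bicgstab` eventually shipped with):

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

# Unlike cg, BiCGSTAB does not require a symmetric operator.
def matvec(x):
    return 3.0 * x + jnp.roll(x, 1) - 0.5 * jnp.roll(x, -1)

b = jnp.ones(100)
x, info = bicgstab(matvec, b, tol=1e-6)
```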
I was looking at implementing a simplified lobpcg, though the result will probably be highly inefficient until a dense generalized eigh is available.
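For reference, the dense generalized eigh that lobpcg wants can be emulated with a Cholesky reduction to a standard symmetric problem. A minimal sketch (`generalized_eigh` is a hypothetical helper, not a JAX API):

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def generalized_eigh(A, B):
    # Solve A v = lam * B v for symmetric A and SPD B by reducing to a
    # standard problem: with B = L L^T, eigh(L^{-1} A L^{-T}) gives lam,
    # and the original eigenvectors are v = L^{-T} w.
    L = jnp.linalg.cholesky(B)
    C = solve_triangular(L, A, lower=True)        # L^{-1} A
    C = solve_triangular(L, C.T, lower=True)      # L^{-1} A L^{-T} (symmetric)
    lam, W = jnp.linalg.eigh(C)
    V = solve_triangular(L.T, W, lower=False)     # map back: v = L^{-T} w
    return lam, V
```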
Wrt …
This is a meta-issue for keeping track of progress on implementing differentiable higher-order functions from SciPy in JAX, e.g.:

- `scipy.sparse.linalg.gmres` and `cg`: matrix-free linear solves
- `scipy.sparse.linalg.eigs` and `eigsh`: matrix-free eigenvalue problems
- `scipy.optimize.root`: nonlinear equation solving
- `scipy.optimize.fixed_point`: solving for fixed points
- `scipy.integrate.odeint`: solving ordinary differential equations
- `scipy.optimize.minimize`: nonlinear minimization

These higher-order functions are important for implementing sophisticated differentiable programs, both for scientific applications and for machine learning.

Implementations should leverage and build upon JAX's custom transformation capabilities. For example, `scipy.optimize.root` should leverage autodiff for calculating the Jacobians or Jacobian-vector products needed for Newton's method.

In most cases, I think the right way to do this involves two separate steps, which could happen in parallel:

1. Custom differentiation rules, e.g., `lax.custom_linear_solve` from the lax.custom_linear_solve primitive PR (#1402); see the sketch after this list.
2. Implementations of the algorithms themselves, written in JAX (e.g., using `while_loop`) or leveraging existing external implementations on particular backends. Either way they will almost certainly need custom derivative rules, rather than differentiation through the forward algorithm.

There's lots of work to be done here, so please comment if you're interested in using or implementing any of these.
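To make step 1 concrete, here is a minimal sketch of implicit differentiation for root finding using `lax.custom_root` (the companion primitive to `custom_linear_solve` from #1402); the cube-root problem is just a toy:

```python
import jax.numpy as jnp
from jax import grad, lax

def cube_root(a):
    # Solve x**3 - a = 0 for x. The gradient w.r.t. a comes from the implicit
    # function theorem, not from differentiating through the Newton iterations.
    def f(x):
        return x ** 3 - a

    def solve(f, x0):
        # Plain Newton iteration, deliberately opaque to autodiff.
        x = x0
        for _ in range(20):
            x = x - f(x) / (3 * x ** 2)
        return x

    def tangent_solve(g, y):
        # g is linear, so g(1.0) recovers its (scalar) slope; solve g(x) = y.
        return y / g(1.0)

    return lax.custom_root(f, 1.0, solve, tangent_solve)

print(cube_root(8.0))        # ~2.0
print(grad(cube_root)(8.0))  # ~1/12, i.e. d(a**(1/3))/da at a = 8
```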