
Differentiable matrix-free linear algebra, optimization and equation solving #1531

shoyer opened this issue Oct 20, 2019 · 38 comments
Labels: enhancement (New feature or request), NVIDIA GPU (Issues specific to NVIDIA GPUs), P2 (eventual) (This ought to be addressed, but has no schedule at the moment. Assignee optional)



shoyer commented Oct 20, 2019

This is a meta-issue for keeping track of progress on implementing differentiable higher-order functions from SciPy in JAX, e.g.,

  • scipy.sparse.linalg.gmres and cg: matrix-free linear solves
  • scipy.sparse.linalg.eigs and eigsh: matrix-free eigenvalue problems
  • scipy.optimize.root: nonlinear equation solving
  • scipy.optimize.fixed_point: solving for fixed points
  • scipy.integrate.odeint: solving ordinary differential equations
  • scipy.optimize.minimize: nonlinear minimization

These higher-order functions are important for implementing sophisticated differentiable programs, both for scientific applications and for machine learning.

Implementations should leverage and build upon JAX's custom transformation capabilities. For example, scipy.optimize.root should leverage autodiff for calculating the Jacobians or Jacobian-vector products needed for Newton's method.
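For illustration, a minimal sketch (not the eventual JAX API, and differentiating through the forward iteration rather than via a custom rule) of how autodiff can supply the Jacobian for a Newton-style root finder:

import jax
import jax.numpy as jnp

def newton_root(f, x0, num_steps=20):
    # Plain Newton iteration; the Jacobian comes from forward-mode autodiff
    # instead of being supplied by the user.
    def step(x, _):
        J = jax.jacfwd(f)(x)
        x = x - jnp.linalg.solve(J, f(x))
        return x, None
    x, _ = jax.lax.scan(step, x0, None, length=num_steps)
    return x

# Example: roots of f(x) = [x0^2 - 2, x1^2 - 3]
f = lambda x: jnp.array([x[0] ** 2 - 2.0, x[1] ** 2 - 3.0])
print(newton_root(f, jnp.array([1.0, 1.0])))  # approximately [1.4142, 1.7321]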

In most cases, I think the right way to do this involves two separate steps, which could happen in parallel:

  1. Higher-order primitives for defining automatic differentiation rules, not specialized to any particular algorithm, e.g., lax.custom_linear_solve from #1402 (lax.custom_linear_solve primitive); a usage sketch follows after this list.
  2. Implementations of particular algorithms for the forward problems, e.g., a conjugate gradient method for linear solves. These could either be implemented from scratch using JAX's functional control flow (e.g., while_loop) or could leverage existing external implementations on particular backends. Either way they will almost certainly need custom derivative rules, rather than differentiation through the forward algorithm.
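As a concrete illustration of point 1, a minimal sketch of wrapping a solver in lax.custom_linear_solve so that differentiation uses the implicit function theorem rather than unrolling the solver's iterations (solve_spd and the dense jnp.linalg.solve stand-in are illustrative, not the intended matrix-free implementation):

import jax
import jax.numpy as jnp
from jax import lax

def solve_spd(A, b):
    matvec = lambda x: A @ x
    # Stand-in "black box" solver; a matrix-free version would run CG here.
    solve = lambda mv, rhs: jnp.linalg.solve(A, rhs)
    return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x = solve_spd(A, b)
# Reverse-mode differentiation is handled by another linear solve,
# not by differentiating through the inner solver.
dx_db = jax.jacrev(lambda rhs: solve_spd(A, rhs))(b)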

There's lots of work to be done here, so please comment if you're interested in using or implementing any of these.


shoyer commented Oct 20, 2019

I'll start things off.

I've written initial versions of higher order primitives defining gradients for non-linear root finding (see discussion in #1448) and linear equation solving (#1402).

I've also written experimental implementations in JAX for gmres and a Thick-Restart Lanczos method (for eigsh). These are not shared publicly yet, but I could drop them into jax.experimental. (EDIT: see #3114 for thick-restart lanczos)


gehring commented Oct 30, 2019

@shoyer With this addition and, in general, implicit diff related features, are there any plans for a mechanism to extract intermediate/aux values when differentiating to allow us to log things like number of iterations for tangent solve, residual, etc.? (related: #1197 #844)

@jekbradbury commented:

Yes, I'm working on a general feature for referring to (and extracting, injecting, differentiating w.r.t.) intermediate/aux values; it's coming soon 🙂.


shoyer commented Oct 30, 2019

I'm glad to hear @jekbradbury is thinking about this, because I have not! I agree it's important.

In this case, auxiliary outputs should not have derivatives defined. If we were to solve this entirely inside custom_linear_solve etc then I would suggest simply adding an explicit has_aux argument that changes the function signature to return an extra argument. But then it isn't obvious how we could pipe out auxiliary outputs from the forward or transpose passes.
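For reference, a minimal example of the existing has_aux convention this alludes to (the diagnostic values here are made up): the function returns (value, aux), and the transformation passes aux through without defining derivatives for it.

import jax
import jax.numpy as jnp

def loss_with_aux(x):
    loss = jnp.sum(x ** 2)
    aux = {"num_iterations": 3, "residual": 1e-7}   # illustrative diagnostics
    return loss, aux

grads, aux = jax.grad(loss_with_aux, has_aux=True)(jnp.arange(3.0))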


gehring commented Oct 30, 2019

@jekbradbury Awesome, looking forward to it!

(@ all JAX maintainers/contributors) I'm loving the style and direction of JAX; keep up the great work!


shoyer commented Apr 5, 2020

#2566 adds cg, so at least that's a start.

@romanodev commented:

@shoyer, what's the status of gmres? I started implementing it today, but then I realized you already have something.


shoyer commented Jul 7, 2020

I have a very naive implementation of GMRES with preconditioning that you can find here: https://gist.github.com/shoyer/cbac2cf8c8675b2f3a45e4837e3bed80

It needs more careful testing (and possibly improvements to the numerics) before merging into JAX.


shoyer commented Jul 7, 2020

To be clear, I have no immediate plans to continue work on my gmres solver. If you want to take this on, that would be fantastic!


romanodev commented Jul 11, 2020

@shoyer thanks for sharing! I think it would be nice to combine your implementation with the dot product between a sparse matrix and a vector #3717. The jit/GPU implementation still can't beat Scipy and I suspect this is due to the COO representation of the sparse matrix (Scipy uses CSR https://github.com/scipy/scipy/blob/v1.5.1/scipy/sparse/base.py#L532). I will do some testing in this direction first.


mattjj commented Jul 11, 2020

@romanodev this is really awesome work.


shoyer commented Jul 11, 2020

> @shoyer thanks for sharing! I think it would be nice to combine your implementation with the dot product between a sparse matrix and a vector #3717. The jit/GPU implementation still can't beat Scipy and I suspect this is due to the COO representation of the sparse matrix (Scipy uses CSR https://github.com/scipy/scipy/blob/v1.5.1/scipy/sparse/base.py#L532). I will do some testing in this direction first.

This other PR has a vectorized version of Gram Schmidt, which I think could replace _inner in my implementation of GMRES above: #3114 (it is basically the same algorithm)


shoyer commented Jul 12, 2020

OK, here's a version with vectorized Gram Schmidt, which is perhaps 10-20% faster on CPU and ~10x faster on GPUs:
https://gist.github.com/shoyer/dc33a5850337b6a87d48ed97b4727d29
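For readers unfamiliar with the idea, a minimal sketch of what "vectorized Gram-Schmidt" means here (illustrative code, not the gist's implementation): the projections of a new Krylov vector onto all previous basis vectors are computed with one matrix-vector product instead of a Python loop over columns.

import jax.numpy as jnp

def gram_schmidt_step(Q, w):
    # Q: (n, k) matrix with orthonormal columns; w: (n,) new Krylov vector.
    h = Q.T @ w                      # all inner products at once
    w = w - Q @ h                    # subtract every projection in one matmul
    return w / jnp.linalg.norm(w), h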

The GPU is still spending most of its time waiting when solving for a 100-dimensional vector, but for large enough solves the performance should be reasonable.

I suspect the main remaining improvement (which we should probably have before merging) would be adding some form of early termination based on residual values.

@romanodev commented:

@shoyer, great. Any reason for not using experimental.loops?

Just for the sake of prototyping, I rewrote your version using loops (just personal taste for faster iterations). Using jnp.linalg.lstsq, I plotted the residual at each iteration. Within loops we can easily handle early termination, but I am not sure how to do it with lax.scan.

In any case, the preconditioning does not seem to work yet. I considered a simple case where we take the inverse of the diagonal of A, and it doesn't match scipy. Although this choice is not justified since A does not have a dominant diagonal, it should at least serve for testing.

Here is the gist:

https://gist.github.com/romanodev/e3f6bd23c499cd8a5f26b26c140abcac
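For context, the diagonal preconditioner described above amounts to applying M = diag(A)^{-1}; a minimal sketch (illustrative helper, not the gist's code):

import jax.numpy as jnp

def jacobi_preconditioner(A):
    d = jnp.diag(A)
    return lambda x: x / d   # apply M = diag(A)^{-1} to a vector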


shoyer commented Jul 13, 2020

> Any reason for not using experimental.loops?

It's just a matter of taste. Personally I don't find it much clearer than using functions if there's only one level of loops.

It's also a little easier to avoid inadvertently using Python control flow with the functional control flow primitives, e.g., in your example there's one place where you should be using s.range() instead of range() (at least in the "real" version; maybe you avoided that intentionally for printing?).

For printing intermediate outputs from inside compiled code like "for" loops, take a look at jax.experimental.host_callback: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html

For early termination logic in general, take a look at the source code for cg. The loop needs to be written in terms of while_loop instead of scan:

def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
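For illustration, a minimal sketch of the pattern (not the actual cg source; the Richardson update is just a stand-in for the real solver's update): the loop carry holds the current iterate and residual, and lax.while_loop stops as soon as the residual norm drops below the tolerance or maxiter is reached.

import jax.numpy as jnp
from jax import lax

def richardson_solve(matvec, b, omega=0.1, tol=1e-5, maxiter=1000):
    def cond_fun(state):
        _, r, k = state
        return (jnp.linalg.norm(r) > tol) & (k < maxiter)

    def body_fun(state):
        x, r, k = state
        x = x + omega * r       # simple Richardson update (stand-in for CG)
        r = b - matvec(x)       # recompute the residual
        return x, r, k + 1

    x0 = jnp.zeros_like(b)
    x, _, _ = lax.while_loop(cond_fun, body_fun, (x0, b - matvec(x0), 0))
    return x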

@sunilkpai commented:

Hi, in parallel with this, I'd like to add an implementation for bicgstab (which I believe is also matrix-free)! I've started working on something similar in cupy, so I thought I might as well add it to jax.


shoyer commented Jul 17, 2020 via email

@sunilkpai commented:

Excellent! If GMRES works out too, then we can even do a more general form of BiCGSTAB, which may be useful for EM problems (as suggested by meep). Not sure if this is implemented in scipy, so unsure where something like that would go?

@romanodev commented:

Meanwhile, as a reference, this is what I have so far on GMRES (using while_loop, still needs to be tested carefully)

https://gist.github.com/romanodev/be02bd4b7e90c5ebb3dc84ebebf4e76f


shoyer commented Jul 17, 2020

With regard to the "more general version of GMRES": I'm sure this could be useful, but my preference would be to stick to established algorithms in JAX. Inclusion in SciPy is one good indication of that. I would suggest releasing more experimental algorithms in separate packages.

@sunilkpai commented:

Hi @shoyer, I have a general question as I prepare this PR. What is the motivation of defining pytree functions like _add and _sub in the _cg_solve method? Is this for jit compilation? It appears I will need to add a more general _vdot_tree function for bicgstab to work for more general matrices, which is why I'm asking!


shoyer commented Jul 18, 2020

> Hi @shoyer, I have a general question as I prepare this PR. What is the motivation of defining pytree functions like _add and _sub in the _cg_solve method? Is this for jit compilation? It appears I will need to add a more general _vdot_tree function for bicgstab to work for more general matrices, which is why I'm asking!

The reason for doing this is that it lets us invert linear equations on arbitrarily shaped arrays in arbitrary structures. In some cases, this is a much more natural way to implement linear operators, e.g., you might use a handful of 3D arrays for representing a solution to a PDE on a grid.

By supporting pytrees, we don't need to copy these arrays into a single flattened vector. Flattening can actually add significant overhead due to extra copying, e.g., it's 60-90% slower to solve a Poisson equation in 2D using CG with a flattened array than with a 2D array:
https://gist.github.com/shoyer/6826d02949e4d2ce82122a8bd5c62cf7
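As a small illustration of the pytree-as-vector idea (the helper names here are illustrative, not the cg internals), dot products and additions can be mapped over the leaves of an arbitrary pytree without ever building the flattened copy:

import jax.numpy as jnp
from jax import tree_util

def _vdot_tree(x, y):
    # Sum of per-leaf dot products, i.e. <x, y> over the whole pytree.
    leaves = tree_util.tree_map(lambda a, b: jnp.vdot(a, b), x, y)
    return sum(tree_util.tree_leaves(leaves))

def _add_tree(x, y):
    return tree_util.tree_map(jnp.add, x, y)

# Example: a "vector" made of a 2D field and a scalar offset.
u = {"field": jnp.ones((4, 4)), "offset": jnp.array(2.0)}
v = {"field": 2.0 * jnp.ones((4, 4)), "offset": jnp.array(1.0)}
print(_vdot_tree(u, v))   # 34.0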

That said, these algorithms are much easier to implement/verify on single vectors, and in the long term I'd like to solve vectorization with a tree_vectorize code transformation instead -- see #3263. This will hopefully be merged into JAX within the next month or so.

So if you prefer, I would also be OK with implementing these algorithms on single vectors and adding explicit flattening/unflattening to handle pytrees. You can find an example of doing this sort of thing inside _odeint_wrapper -- you could basically use the exact same thing for something like cg by omitting the jax.vmap at the end:

jax/jax/experimental/ode.py

Lines 210 to 214 in fa2a027

def _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args):
y0, unravel = ravel_pytree(y0)
func = ravel_first_arg(func, unravel)
out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
return jax.vmap(unravel)(out)
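A minimal sketch of that suggestion for a linear solver (pytree_solver and flat_solver are hypothetical names): flatten the pytree right-hand side with ravel_pytree, run a flat-vector solver, and unravel the result.

from jax.flatten_util import ravel_pytree

def pytree_solver(matvec, b, flat_solver):
    b_flat, unravel = ravel_pytree(b)
    def matvec_flat(x_flat):
        y_flat, _ = ravel_pytree(matvec(unravel(x_flat)))
        return y_flat
    x_flat = flat_solver(matvec_flat, b_flat)   # e.g. a cg/gmres loop over 1D arrays
    return unravel(x_flat)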

@sunilkpai commented:

@shoyer I added a PR for bicgstab but it's still WIP; I was just looking for some comments before I finalize the implementation and tests, but I think it's close! I think changing to the flattened version would be easy to extend to both cg and bicgstab, so we could handle that in a separate PR?

@ethanluoyc commented:

I would just like to add another comment on @shoyer's point about supporting operators on pytree arrays. I have a use case where I need to compute the vector norm / dot product for arrays structured as pytrees; it would be very useful to support these operations.


shoyer commented Jul 21, 2020

> I think changing to the flattened version would be easy to extend to both cg and bicgstab, so we could handle that in a separate PR?

To be clear, this was a suggestion for making it easier to write new solvers. Actually, it would even be fine not to support pytrees at all on new solvers.

cg already supports pytrees without flattening, and we should keep this functionality!


sunilkpai commented Jul 29, 2020

Hi @shoyer, re: #3796 I think we should think about either a more robust testing pipeline for all matrix-free methods or just copy scipy's tests and use that as our standard. What do you think are the appropriate next steps?


shoyer commented Jul 29, 2020

> Hi @shoyer, re: #3796 I think we should think about either a more robust testing pipeline for all matrix-free methods or just copy scipy's tests and use that as our standard. What do you think are the appropriate next steps?

We should absolutely feel free to copy SciPy's tests. #3101 (which I will be merging shortly) has a good example of how to do this.


alewis commented Aug 20, 2020

Functioning and apparently efficient Jax implementations of eigs, eigsh using implicitly restarted Lanczos, and gmres are already present in https://github.com/google/TensorNetwork (matrix-free methods are very important in tensor network computations), and perhaps it would make more sense to merge some or all of these into Jax. I've started a conversation in google/TensorNetwork#785 with the other TensorNetwork devs on the matter.

We are also planning to implement LGMRES, which is a modification to GMRES that performs better on nearly-symmetric operators. This would probably be my job since I wrote the GMRES implementation, and I'd be happy to simply write it here instead.


alewis commented Aug 20, 2020

Note that both LGMRES and bicgstab are accessible through SciPy.


shoyer commented Aug 20, 2020

> Functioning and apparently efficient Jax implementations of eigs, eigsh using implicitly restarted Lanczos, and gmres are already present in https://github.com/google/TensorNetwork (matrix-free methods are very important in tensor network computations), and perhaps it would make more sense to merge some or all of these into Jax

Agreed, I would love to upstream all of these into JAX.

@romanodev commented:

@alewis I noticed preconditioning is not supported for GMRES with the JAX backend in TensorNetwork. Are there any plans in that direction?


alewis commented Aug 20, 2020

@romanodev Preconditioning has been sort of a vague long-term goal, but certainly we'd be happy to have it

@romanodev commented:

@alewis I am seeing GMRES is close to being added to the main sparse namespace. Nice work! I have a couple of questions:

  1. Is ILU preconditioning on the roadmap? I am trying scipy's GMRES and, in my (sparse) case, neither Jacobi nor Gauss–Seidel preconditioners speed up convergence. Only ILU seems to be working.

  2. In my use case, I have to solve many similar sparse problems. Is there a way of exposing some of the Krylov subspace so it can be reused?


alewis commented Nov 24, 2020 via email

@sunilkpai commented:

Hey, I just added a separate PR #5299 (in case I'd need to do a lot of rebase work due to the addition of GMRES) with the BICGSTAB implementation and tests. I feel like we should think about test consolidation at some point, but it appears to work as is for now. I used your suggestion of using lax.cond to check for early exit / breakdown conditions to match scipy's info.
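For illustration, a minimal sketch of that lax.cond pattern (the function and threshold names are made up, not the PR's code): classify the solver state as converged, broken down, or still iterating without Python-level control flow, so the check stays jit-compatible.

import jax.numpy as jnp
from jax import lax

def classify_state(rnorm, rho, tol, breakdown_eps=1e-30):
    # info = 0: converged, info = -1: breakdown, info = 1: keep iterating
    breakdown_or_continue = lambda _: lax.cond(
        jnp.abs(rho) < breakdown_eps,
        lambda _: jnp.int32(-1),
        lambda _: jnp.int32(1),
        None)
    return lax.cond(rnorm < tol, lambda _: jnp.int32(0), breakdown_or_continue, None)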


jackd commented Jan 19, 2021

I was looking at implementing a simplified lobpcg, though the result will probably be highly inefficient until a dense generalized eigh is available.

@patrick-kidger commented:

Re scipy.integrate.odeint, for those who come across this thread (I think it would be way out of scope to merge into JAX itself): check out Diffrax, a library of ODE (and SDE, etc.) solvers written in pure JAX.


denis-bz commented Jul 5, 2022

@romanodev, fwiw this gist has nice test matrices for linear solvers and eigensolvers, e.g. Poisson2 - 4.0001 I.
Would you know of other non-posdef test cases -- the shorter the better, sizeable with N=, with logfiles?

sudhakarsingh27 added the NVIDIA GPU (Issues specific to NVIDIA GPUs) and P1 (soon) (Assignee is working on this now, among other tasks. Assignee required) labels on Aug 10, 2022
sudhakarsingh27 added the enhancement (New feature or request) and P2 (eventual) (This ought to be addressed, but has no schedule at the moment. Assignee optional) labels and removed the P1 (soon) label on Sep 21, 2022