Direct gradient rules to speed-up numpy.linalg.solve #1747

Closed
shoyer opened this issue Nov 22, 2019 · 1 comment · Fixed by #2220


shoyer commented Nov 22, 2019

It can be significantly more efficient, in both runtime and memory, to define derivative rules directly for higher-level linear algebra operations like solve, rather than for their constituent operations (factorization and triangular solve).
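
For reference, the direct rules for x = A^{-1} b come from differentiating A x = b. This is the standard textbook derivation, written out here as a sketch rather than copied from the JAX source; bars denote cotangents:

```latex
\begin{aligned}
A\,x &= b \\
dA\,x + A\,dx &= db
  \quad\Longrightarrow\quad dx = A^{-1}\,(db - dA\,x) \\
\bar{b} &= A^{-\top}\,\bar{x}, \qquad \bar{A} = -\,\bar{b}\,x^{\top}
\end{aligned}
```

The forward rule needs one extra solve with A and the reverse rule one solve with A^T, so a single factorization of A can serve both.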

For large matrices (e.g., 500x500 on CPU), my microbenchmark shows a roughly 3-5x speed-up for general-purpose and symmetric positive definite solves:

from functools import partial
import jax.scipy as jsp
from jax import lax
import jax.numpy as np
import numpy as onp
import jax

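def positive_definite_solve(a, b):
  # Factor once; the Cholesky factors are reused by every solve below.
  factors = jsp.linalg.cho_factor(a)
  def solve(matvec, x):
    # matvec is unused: we solve directly with the precomputed factors.
    return jsp.linalg.cho_solve(factors, x)
  matvec = partial(np.dot, a)
  # symmetric=True lets the transpose solve reuse `solve`.
  return lax.custom_linear_solve(matvec, b, solve, symmetric=True)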

def linear_solve(a, b):
  # Factor once; the same LU factors serve the forward and transpose solves.
  a_factors = jsp.linalg.lu_factor(a)
  def solve(matvec, x):
    return jsp.linalg.lu_solve(a_factors, x)
  def transpose_solve(vecmat, x):
    # trans=1 solves a.T @ y = x using the factorization of a.
    return jsp.linalg.lu_solve(a_factors, x, trans=1)
  matvec = partial(np.dot, a)
  return lax.custom_linear_solve(matvec, b, solve, transpose_solve)

def loss(solve):
  def f(a, b):
    return solve(a, b).sum()
  return f

rs = onp.random.RandomState(0)
a = rs.randn(500, 500)
a = jax.device_put(a.T @ a + 0.1 * np.eye(500))
b = jax.device_put(rs.randn(500))

# general purpose solve
# current
grad = jax.jit(jax.grad(loss(np.linalg.solve)))
%timeit jax.device_get(grad(a, b))
# 33.8 ms per loop
# new
grad = jax.jit(jax.grad(loss(linear_solve)))
%timeit jax.device_get(grad(a, b))
# 10.1 ms per loop

# positive definite solve
# current
grad = jax.jit(jax.grad(loss(partial(jsp.linalg.solve, sym_pos=True))))
%timeit jax.device_get(grad(a, b))
# 23.7 ms per loop
# new
grad = jax.jit(jax.grad(loss(positive_definite_solve)))
%timeit jax.device_get(grad(a, b))
# 4.8 ms per loop
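
Before trusting the timings, it is worth checking that the prototype gives the same gradients as the existing implementation. The sketch below repeats the `linear_solve` prototype and compares it against `jnp.linalg.solve` on a small problem. The 8x8 size, the `jnp` alias, and the tolerances are my assumptions for illustration, not part of the benchmark above.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as onp
from jax import lax

def linear_solve(a, b):
  # Same prototype as in the benchmark: one LU factorization for both solves.
  a_factors = jsp.linalg.lu_factor(a)
  def solve(matvec, x):
    return jsp.linalg.lu_solve(a_factors, x)
  def transpose_solve(vecmat, x):
    return jsp.linalg.lu_solve(a_factors, x, trans=1)
  matvec = partial(jnp.dot, a)
  return lax.custom_linear_solve(matvec, b, solve, transpose_solve)

def loss(solve):
  return lambda a, b: solve(a, b).sum()

# Small, well-conditioned test problem (sizes chosen for illustration only).
rs = onp.random.RandomState(0)
a0 = rs.randn(8, 8)
a = jnp.asarray(a0.T @ a0 + 8 * onp.eye(8), dtype=jnp.float32)
b = jnp.asarray(rs.randn(8), dtype=jnp.float32)

# Gradients with respect to both a and b, from the reference and the prototype.
grads_ref = jax.grad(loss(jnp.linalg.solve), argnums=(0, 1))(a, b)
grads_new = jax.grad(loss(linear_solve), argnums=(0, 1))(a, b)

for g_ref, g_new in zip(grads_ref, grads_new):
  onp.testing.assert_allclose(g_new, g_ref, rtol=1e-3, atol=1e-5)
```

The tolerances are loose because JAX defaults to float32.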

Unfortunately, we can't just use these prototype implementations internally in JAX, for three reasons:

  1. custom_linear_solve (like custom transforms in general) doesn't work with batching yet (see "custom_transforms vjp rule clobbered under vmap" #1249). Update: this was solved by "Batching rule for custom_linear_solve" #2099.
  2. We do an extra optimization in triangular_solve_jvp_rule_a for the case of solving many right-hand sides at the same time with the same left-hand side ("Speedup JVP for triangular solve" #1466). The new gradient rule here doesn't handle this yet. Update: in practice, I don't think this optimization actually matters -- it's the difference between n*m*m+m*m*m time vs 2*m*m*m time.
  3. We need to support multiple right-hand-side arguments. After "Better batching rule for triangular_solve" #2138, we'll be able to do this by batching custom_linear_solve.
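
As a rough illustration of the third point, multiple right-hand sides can be handled by mapping the single-vector prototype over the columns of a matrix with `jax.vmap`. This is a minimal sketch that assumes a JAX version where the batching rules mentioned above are available. `solve_many`, the `jnp` alias, and the 6x6 problem size are hypothetical names and values introduced for the example.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as onp
from jax import lax

def positive_definite_solve(a, b):
  # Single right-hand-side prototype: b has shape (m,).
  factors = jsp.linalg.cho_factor(a)
  def solve(matvec, x):
    return jsp.linalg.cho_solve(factors, x)
  matvec = partial(jnp.dot, a)
  return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

# Hypothetical helper: map over the columns of B while sharing `a`.
solve_many = jax.vmap(positive_definite_solve, in_axes=(None, 1), out_axes=1)

rs = onp.random.RandomState(0)
a0 = rs.randn(6, 6)
a = jnp.asarray(a0.T @ a0 + 6 * onp.eye(6), dtype=jnp.float32)
B = jnp.asarray(rs.randn(6, 3), dtype=jnp.float32)  # three right-hand sides

X = solve_many(a, B)  # shape (6, 3); column j solves a @ X[:, j] = B[:, j]

# The gradient with respect to `a` flows through the batched solve.
grad_a = jax.grad(lambda a, B: solve_many(a, B).sum())(a, B)
grad_a_ref = jax.grad(lambda a, B: jnp.linalg.solve(a, B).sum())(a, B)
```

Because `a` is not mapped (`in_axes=(None, 1)`), the Cholesky factorization is computed once and shared across all columns.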

shoyer commented Feb 3, 2020

Thanks to @mattjj for pointing out that LU solve has a trans argument, which means we can use a single factorization for both the forward and reverse calculations to speed up solves on all types of matrices.
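
A small sketch of what the trans argument provides: one call to `lu_factor`, then solves against both `a` and `a.T`. The matrix size and the `jnp` alias are assumptions made for the example.

```python
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as onp

rs = onp.random.RandomState(0)
# A general (non-symmetric) but well-conditioned matrix, for illustration.
a = jnp.asarray(rs.randn(6, 6) + 6 * onp.eye(6), dtype=jnp.float32)
b = jnp.asarray(rs.randn(6), dtype=jnp.float32)

factors = jsp.linalg.lu_factor(a)                # factor once
x = jsp.linalg.lu_solve(factors, b)              # solves a @ x = b
x_t = jsp.linalg.lu_solve(factors, b, trans=1)   # solves a.T @ x_t = b

# Residuals of both systems should be small (float32 precision).
res_forward = jnp.abs(a @ x - b).max()
res_transpose = jnp.abs(a.T @ x_t - b).max()
```

The reverse-mode rule needs exactly this transposed solve, so no second factorization of `a.T` is required.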
