Skip to content


Merge pull request #5299 from sunilkpai:feature/bicgstab
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 358845323
  • Loading branch information
jax authors committed Feb 22, 2021
2 parents 25c03f7 + d35ae4c commit 234990e
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 35 deletions.
193 changes: 158 additions & 35 deletions jax/_src/scipy/sparse/linalg.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import operator

Expand Down Expand Up @@ -50,6 +51,11 @@ def _vdot_real_tree(x, y):
return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))

def _vdot_tree(x, y):
  """Sum ``jnp.vdot`` over the matching leaves of pytrees ``x`` and ``y``.

  Every leaf pair is contracted with ``lax.Precision.HIGHEST``. The result of
  ``jnp.vdot`` is summed as-is; no real-part projection is applied here.
  """
  leafwise_vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
  per_leaf = tree_multimap(leafwise_vdot, x, y)
  return sum(tree_leaves(per_leaf))

def _norm(x):
  """Return ``sqrt`` of the summed ``_vdot_real_part(leaf, leaf)`` over ``x``.

  ``x`` may be a single array or any pytree of arrays; every leaf contributes
  to the same scalar result.
  """
  leaves = tree_leaves(x)
  squared_norm = sum(_vdot_real_part(leaf, leaf) for leaf in leaves)
  return jnp.sqrt(squared_norm)
Expand Down Expand Up @@ -123,10 +129,99 @@ def body_fun(value):
return x_final

# aliases for working with pytrees

def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
  """Run preconditioned BiCGSTAB iterations for ``A(x) = b`` starting at ``x0``.

  Parameters
  ----------
  A : function
    Computes the matrix-vector product for a pytree argument.
  b : array or tree of arrays
    Right-hand side.
  x0 : array or tree of arrays
    Initial guess. It is forwarded to ``A`` unchanged, so a ``None`` value is
    not replaced here.
  maxiter : integer
    Upper bound on the number of loop iterations.
  tol, atol : float
    Iteration stops once ``<r, r> <= max(tol**2 * <b, b>, atol**2)``.
  M : function
    Preconditioner applied to the search direction and to ``s``.

  Returns
  -------
  The final iterate ``x`` after ``lax.while_loop`` terminates.
  """

  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.bicgstab
  bs = _vdot_real_tree(b, b)
  atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

  def cond_fun(value):
    x, r, *_, k = value
    rs = _vdot_real_tree(r, r)
    # the last condition checks breakdown: body_fun sets k to a negative
    # sentinel when one of its scalars is exactly zero
    return (rs > atol2) & (k < maxiter) & (k >= 0)

  def body_fun(value):
    x, r, rhat, alpha, omega, rho, p, q, k = value
    rho_ = _vdot_tree(rhat, r)
    beta = rho_ / rho * alpha / omega
    p_ = _add(r, _mul(beta, _sub(p, _mul(omega, q))))
    phat = M(p_)
    q_ = A(phat)
    alpha_ = rho_ / _vdot_tree(rhat, q_)
    s = _sub(r, _mul(alpha_, q_))
    # if the half-step residual s is already small enough, keep the half-step
    # update of x and use s as the new residual
    exit_early = _vdot_real_tree(s, s) < atol2
    shat = M(s)
    t = A(shat)
    omega_ = _vdot_tree(t, s) / _vdot_tree(t, t)  # make cases?
    x_ = tree_multimap(
        partial(jnp.where, exit_early),
        _add(x, _mul(alpha_, phat)),
        _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat))))
    r_ = tree_multimap(
        partial(jnp.where, exit_early),
        s, _sub(s, _mul(omega_, t)))
    # negative iteration counts act as breakdown codes; cond_fun stops on k < 0
    k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
    k_ = jnp.where((rho_ == 0), -10, k_)
    return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_

  r0 = _sub(b, A(x0))
  # scalar one in the result dtype of the leaves of b
  rho0 = alpha0 = omega0 = jnp.ones(1, dtype=jnp.result_type(*tree_leaves(b)))[0]
  # order: x, r, rhat, alpha, omega, rho, p, q, k
  initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

  return x_final

def _shapes(pytree):
  """Return the shape of every leaf of ``pytree`` as a list of tuples.

  A concrete list is returned instead of a ``map`` object so that two results
  compare by value with ``==``/``!=`` and print readably in error messages,
  whichever ``map`` the surrounding module has in scope.
  """
  return [jnp.shape(leaf) for leaf in tree_leaves(pytree)]

def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
            maxiter=None, M=None, check_symmetric=False):
  """Shared driver for the iterative solvers built on ``custom_linear_solve``.

  Parameters
  ----------
  _isolve_solve : function
    Solver kernel called as ``_isolve_solve(A, b, x0=..., tol=..., atol=...,
    maxiter=..., M=...)``.
  A : function or array
    Linear operator; passed through ``_normalize_matvec``.
  b : array or tree of arrays
    Right-hand side.
  x0 : array or tree of arrays, optional
    Starting guess; defaults to zeros shaped like ``b``.
  tol, atol : float
    Tolerances forwarded to the solver kernel.
  maxiter : integer, optional
    Iteration limit; defaults to ten times the total number of elements of
    ``b``.
  M : function or array, optional
    Preconditioner; defaults to ``_identity``.
  check_symmetric : bool
    When true, the operator is declared symmetric to ``custom_linear_solve``
    if every leaf of ``b`` has a non-complex dtype.

  Returns
  -------
  ``(x, info)`` where ``info`` is always ``None``.

  Raises
  ------
  ValueError
    If ``x0`` and ``b`` differ in tree structure or in leaf shapes.
  """
  if x0 is None:
    x0 = tree_map(jnp.zeros_like, b)

  b, x0 = device_put((b, x0))

  if maxiter is None:
    size = sum(bi.size for bi in tree_leaves(b))
    maxiter = 10 * size  # copied from scipy

  if M is None:
    M = _identity
  A = _normalize_matvec(A)
  M = _normalize_matvec(M)

  if tree_structure(x0) != tree_structure(b):
    raise ValueError(
        'x0 and b must have matching tree structure: '
        f'{tree_structure(x0)} vs {tree_structure(b)}')

  # materialize the shapes so the comparison is by value even if _shapes
  # hands back a lazy iterator
  x0_shapes = list(_shapes(x0))
  b_shapes = list(_shapes(b))
  if x0_shapes != b_shapes:
    raise ValueError(
        'arrays in x0 and b must have matching shapes: '
        f'{x0_shapes} vs {b_shapes}')

  isolve_solve = partial(
      _isolve_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)

  # real-valued positive-definite linear operators are symmetric
  def real_valued(x):
    return not issubclass(x.dtype.type, np.complexfloating)
  symmetric = (all(map(real_valued, tree_leaves(b)))
               if check_symmetric else False)
  # the same kernel serves the forward and the transposed solve
  x = lax.custom_linear_solve(
      A, b, solve=isolve_solve, transpose_solve=isolve_solve,
      symmetric=symmetric)
  info = None
  return x, info

def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
"""Use Conjugate Gradient iteration to solve ``Ax = b``.
Expand Down Expand Up @@ -180,41 +275,9 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
if x0 is None:
x0 = tree_map(jnp.zeros_like, b)

b, x0 = device_put((b, x0))

if maxiter is None:
size = sum(bi.size for bi in tree_leaves(b))
maxiter = 10 * size # copied from scipy

if M is None:
M = _identity
A = _normalize_matvec(A)
M = _normalize_matvec(M)

if tree_structure(x0) != tree_structure(b):
raise ValueError(
'x0 and b must have matching tree structure: '
f'{tree_structure(x0)} vs {tree_structure(b)}')

if _shapes(x0) != _shapes(b):
raise ValueError(
'arrays in x0 and b must have matching shapes: '
f'{_shapes(x0)} vs {_shapes(b)}')

cg_solve = partial(
_cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)

# real-valued positive-definite linear operators are symmetric
def real_valued(x):
return not issubclass(x.dtype.type, np.complexfloating)
symmetric = all(map(real_valued, tree_leaves(b)))
x = lax.custom_linear_solve(
A, b, solve=cg_solve, transpose_solve=cg_solve, symmetric=symmetric)
info = None # TODO(shoyer): return the real iteration count here
return x, info
return _isolve(_cg_solve,
A=A, b=b, x0=x0, tol=tol, atol=atol,
maxiter=maxiter, M=M, check_symmetric=True)

def _safe_normalize(x, thresh=None):
Expand Down Expand Up @@ -624,3 +687,63 @@ def _solve(A, b):
failed = jnp.isnan(_norm(x))
info = jnp.where(failed, x=-1, y=0)
return x, info

def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
  """Use Bi-Conjugate Gradient Stable iteration to solve ``Ax = b``.

  The numerics of JAX's ``bicgstab`` should exactly match SciPy's
  ``bicgstab`` (up to numerical precision), but note that the interface
  is slightly different: you need to supply the linear operator ``A`` as
  a function instead of a sparse matrix or ``LinearOperator``.

  As with ``cg``, derivatives of ``bicgstab`` are implemented via implicit
  differentiation with another ``bicgstab`` solve, rather than by
  differentiating *through* the solver. They will be accurate only if
  both solves converge.

  Parameters
  ----------
  A : function
    Function that calculates the matrix-vector product ``Ax`` when called
    like ``A(x)``. ``A`` can represent any general (nonsymmetric) linear
    operator, and must return array(s) with the same structure and shape as
    its argument.
  b : array or tree of arrays
    Right hand side of the linear system representing a single vector. Can be
    stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
    The converged solution. Has the same structure as ``b``.
  info : None
    Placeholder for convergence information. In the future, JAX will report
    the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array
    Starting guess for the solution. Must have the same structure as ``b``.
  tol, atol : float, optional
    Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
    We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
    differ from SciPy unless you explicitly pass ``atol`` to SciPy's
    ``bicgstab``.
  maxiter : integer
    Maximum number of iterations. Iteration will stop after maxiter
    steps even if the specified tolerance has not been achieved.
  M : function
    Preconditioner for A. The preconditioner should approximate the
    inverse of A. Effective preconditioning dramatically improves the
    rate of convergence, which implies that fewer iterations are needed
    to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.bicgstab
  jax.lax.custom_linear_solve
  """

  return _isolve(_bicgstab_solve,
                 A=A, b=b, x0=x0, tol=tol, atol=atol,
                 maxiter=maxiter, M=M)
1 change: 1 addition & 0 deletions jax/scipy/sparse/linalg.py
Expand Up @@ -16,4 +16,5 @@
from jax._src.scipy.sparse.linalg import (
111 changes: 111 additions & 0 deletions tests/
Expand Up @@ -52,8 +52,10 @@ def solver(func, A, b, M=None, atol=0.0, **kwargs):

# Each alias binds one solver implementation to the shared ``solver`` helper,
# giving the JAX and SciPy variants the same call signature in the tests.
lax_cg = partial(solver, jax.scipy.sparse.linalg.cg)
lax_gmres = partial(solver, jax.scipy.sparse.linalg.gmres)
lax_bicgstab = partial(solver, jax.scipy.sparse.linalg.bicgstab)
scipy_cg = partial(solver, scipy.sparse.linalg.cg)
scipy_gmres = partial(solver, scipy.sparse.linalg.gmres)
scipy_bicgstab = partial(solver, scipy.sparse.linalg.bicgstab)

def rand_sym_pos_def(rng, shape, dtype):
Expand Down Expand Up @@ -193,6 +195,113 @@ def tree_unflatten(cls, aux_data, children):
actual, _ =, b)
self.assertAllClose(expected, actual.value)

jtu.format_shape_dtype_string(shape, dtype),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
for shape in [(5, 5)]
for dtype in [np.float64, np.complex128]
for preconditioner in [None, 'identity', 'exact', 'random']
# Exercises lax_bicgstab and scipy_bicgstab with the same preconditioner M on
# one (A, b) pair; the test is skipped unless the jax_enable_x64 flag is set.
# NOTE(review): the statements that consume the partial(...) expressions below
# are not visible in this view -- presumably each scipy/lax pair is handed to
# a comparison helper together with args_maker; confirm against the full file.
def test_bicgstab_against_scipy(
self, shape, dtype, preconditioner):
if not config.FLAGS.jax_enable_x64:
raise unittest.SkipTest("requires x64 mode")

# A has the parameterized `shape`; b has shape `shape[:1]`.
rng = jtu.rand_default(self.rng())
A = rng(shape, dtype)
b = rng(shape[:1], dtype)
M = self._fetch_preconditioner(preconditioner, A, rng=rng)

# Supplies the same (A, b) arguments to whichever solver is being called.
def args_maker():
return A, b

# Both solvers limited to a single iteration.
partial(scipy_bicgstab, M=M, maxiter=1),
partial(lax_bicgstab, M=M, maxiter=1),

# Both solvers limited to two iterations.
partial(scipy_bicgstab, M=M, maxiter=2),
partial(lax_bicgstab, M=M, maxiter=2),

# NOTE(review): repeats the maxiter=1 pair from above -- TODO confirm intent.
partial(scipy_bicgstab, M=M, maxiter=1),
partial(lax_bicgstab, M=M, maxiter=1),

# JAX solver alone with an explicit absolute tolerance.
partial(lax_bicgstab, M=M, atol=1e-6),

jtu.format_shape_dtype_string(shape, dtype),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner}
for shape in [(2, 2), (7, 7)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner):
  """Solve ``I x = b`` with bicgstab and check the all-ones solution."""
  A = jnp.eye(shape[1], dtype=dtype)
  solution = jnp.ones(shape[1], dtype=dtype)
  rng = jtu.rand_default(self.rng())
  M = self._fetch_preconditioner(preconditioner, A, rng=rng)
  b = matmul_high_precision(A, solution)
  tol = shape[0] * jnp.finfo(dtype).eps
  # pass the parameterized preconditioner through to the solver
  x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M)
  # ``dtype.kind`` is a one-character code such as 'f' or 'c' and never equals
  # a NumPy scalar type, so compare the scalar type itself
  using_x64 = solution.dtype.type in {np.float64, np.complex128}
  solution_tol = 1e-8 if using_x64 else 1e-4
  self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

jtu.format_shape_dtype_string(shape, dtype),
"shape": shape, "dtype": dtype, "preconditioner": preconditioner
for shape in [(2, 2), (4, 4)]
for dtype in float_types + complex_types
for preconditioner in [None, 'identity', 'exact']
def test_bicgstab_on_random_system(self, shape, dtype, preconditioner):
  """Solve ``A x = b`` for a generated ``A`` and a known ``solution``."""
  rng = jtu.rand_default(self.rng())
  A = rng(shape, dtype)
  solution = rng(shape[1:], dtype)
  M = self._fetch_preconditioner(preconditioner, A, rng=rng)
  # build b from the known solution so the expected answer is exact
  b = matmul_high_precision(A, solution)
  tol = shape[0] * jnp.finfo(A.dtype).eps
  x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M)
  # ``dtype.kind`` is a one-character code such as 'f' or 'c' and never equals
  # a NumPy scalar type, so compare the scalar type itself; 64-bit inputs are
  # now really held to the tighter 1e-8 tolerance
  using_x64 = solution.dtype.type in {np.float64, np.complex128}
  solution_tol = 1e-8 if using_x64 else 1e-4
  self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
  # NOTE(review): gradient check is disabled; the reason is not visible here.
  # solve = lambda A, b: jax.scipy.sparse.linalg.bicgstab(A, b)[0]
  # jtu.check_grads(solve, (A, b), order=1, rtol=3e-1)

def test_bicgstab_pytree(self):
  """bicgstab works when the operator and right-hand side are dicts."""
  def matvec(x):
    # 2x2 system [[1, 0.5], [0.5, 1]] written over the keys "a" and "b"
    return {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}

  rhs = {"a": 1.0, "b": -4.0}
  want = {"a": 4.0, "b": -6.0}
  got, _ = jax.scipy.sparse.linalg.bicgstab(matvec, rhs)
  self.assertEqual(want.keys(), got.keys())
  for key in ("a", "b"):
    self.assertAlmostEqual(want[key], got[key], places=5)

Expand Down Expand Up @@ -302,6 +411,8 @@ def test_gmres_on_random_system(self, shape, dtype, preconditioner,
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
solution_tol = 1e-8 if using_x64 else 1e-4
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
# solve = lambda A, b: jax.scipy.sparse.linalg.gmres(A, b)[0]
# jtu.check_grads(solve, (A, b), order=1, rtol=2e-1)

def test_gmres_pytree(self):
A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
Expand Down

0 comments on commit 234990e

Please sign in to comment.