Add a tridiagonal eigh solver.
hawkinsp committed May 3, 2021
1 parent 75b00a1 commit 41b2bff
Showing 4 changed files with 169 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,9 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
be used within `jit` ({jax-issue}`#6501`)
* {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`).
* {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`#6569`).
* Added {func}`jax.scipy.linalg.eigh_tridiagonal` that computes the
eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at
present.
* Breaking changes:
* The following function names have changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed
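For reference, the new API added in this commit can be exercised as in the following minimal sketch (the example matrix is illustrative; note that eigvals_only=True must be passed, since eigenvector computation raises NotImplementedError):

import numpy as np
import jax

# Symmetric tridiagonal matrix with diagonal d and off-diagonal e.
d = np.array([2.0, 2.0, 2.0], dtype=np.float32)
e = np.array([-1.0, -1.0], dtype=np.float32)
w = jax.scipy.linalg.eigh_tridiagonal(d, e, eigvals_only=True)
print(w)  # approximately [0.586, 2.0, 3.414], i.e. 2 - sqrt(2), 2, 2 + sqrt(2)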
139 changes: 139 additions & 0 deletions jax/_src/scipy/linalg.py
@@ -15,6 +15,7 @@

from functools import partial

import numpy as np
import scipy.linalg
import textwrap

@@ -434,3 +435,141 @@ def block_diag(*arrs):
acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0)))
acc = lax.concatenate([acc, a], dimension=0)
return acc


# TODO(phawkins): use static_argnames when jaxlib 0.1.66 is the minimum and
# remove this wrapper.
@_wraps(scipy.linalg.eigh_tridiagonal)
def eigh_tridiagonal(d, e, tol=None, eigvals_only=False):
return _eigh_tridiagonal(d, e, tol, eigvals_only)

@partial(jit, static_argnums=(3,))
def _eigh_tridiagonal(d, e, tol, eigvals_only):
if not eigvals_only:
raise NotImplementedError("Calculation of eigenvectors is not implemented")

alpha = d
beta = e

def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
"""Implements the Sturm sequence recurrence."""
n = alpha.shape[0]
zeros = jnp.zeros(x.shape, dtype=jnp.int32)
ones = jnp.ones(x.shape, dtype=jnp.int32)

# The first step in the Sturm sequence recurrence
# requires special care if x is equal to alpha[0].
def sturm_step0():
q = alpha[0] - x
count = jnp.where(q < 0, ones, zeros)
q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
return q, count

# Subsequent steps all take this form:
def sturm_step(i, q, count):
q = alpha[i] - beta_sq[i - 1] / q - x
count = jnp.where(q <= pivmin, count + 1, count)
q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
return q, count
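    # The q's are the pivots of the LDL^T factorization of T - x*I; by
    # Sylvester's law of inertia, the number of negative pivots equals the
    # number of eigenvalues of T below x. Treating any pivot <= pivmin as
    # negative and clamping it to -pivmin avoids overflow in the division
    # above without materially changing the count.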

# The first step initializes q and count.
q, count = sturm_step0()

# Peel off ((n-1) % blocksize) steps from the main loop, so we can run
# the bulk of the iterations unrolled by a factor of blocksize.
blocksize = 16
i = 1
peel = (n - 1) % blocksize
unroll_cnt = peel

def unrolled_steps(args):
start, q, count = args
for j in range(unroll_cnt):
q, count = sturm_step(start + j, q, count)
return start + unroll_cnt, q, count

i, q, count = unrolled_steps((i, q, count))

    # Run the remaining steps of the Sturm sequence using a partially
    # unrolled while loop.
unroll_cnt = blocksize
def cond(iqc):
i, q, count = iqc
return jnp.less(i, n)
_, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
return count

alpha = jnp.asarray(alpha)
beta = jnp.asarray(beta)
supported_dtypes = (jnp.float32, jnp.float64)
if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
    raise TypeError("Only float32 and float64 inputs are supported by "
                    "jax.scipy.linalg.eigh_tridiagonal, got "
                    f"{alpha.dtype} and {beta.dtype}")
n = alpha.shape[0]
if n == 1:
return alpha
beta_abs = jnp.abs(beta)
beta_sq = beta * beta

  # Compute an interval containing the eigenvalues of T using the Gershgorin
  # circle theorem.
  # |beta_i| + |beta_{i+1}|
  off_diag_abs_row_sum = jnp.concatenate(
      [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
fudge_factor = 1.1 # We widen the Gershgorin interval a bit.
lambda_max = jnp.amax(alpha + fudge_factor * off_diag_abs_row_sum)
lambda_min = jnp.amin(alpha - fudge_factor * off_diag_abs_row_sum)
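  # By the Gershgorin circle theorem, every eigenvalue of T lies in one of the
  # intervals [alpha[i] - r[i], alpha[i] + r[i]], where r[i] is the absolute
  # off-diagonal row sum computed above, so [lambda_min, lambda_max] brackets
  # the entire spectrum.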

# Compute the smallest allowed pivot in the Sturm sequence to avoid
# overflow.
finfo = np.finfo(alpha.dtype)
one = np.ones([], dtype=alpha.dtype)
safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
pivmin = safemin * jnp.amax(beta_sq)
alpha0_perturbation = jnp.square(finfo.eps * beta[0])
abs_tol = finfo.eps * jnp.maximum(
jnp.abs(lambda_min), jnp.abs(lambda_max))
if tol is not None:
abs_tol = jnp.maximum(tol, abs_tol)
  # In the worst case, when the absolute tolerance is eps*lambda_max and
  # lambda_max = -lambda_min, we have to take as many bisection steps as there
  # are bits in the mantissa plus 1: each step halves the bracketing interval,
  # so shrinking its width from 2*lambda_max down to eps*lambda_max takes
  # log2(2/eps) = nmant + 1 steps.
max_it = finfo.nmant + 1

# We want to find [lambda_0, lambda_1, ..., lambda_{n-1}], such that the
# number of eigenvalues of T less than lambda_i is i.
  # TODO(rmlarsen): Extend this logic to support the "select" keyword to
  # specify a subset of eigenvalues to compute.
target_counts = jnp.arange(n)

# Run binary search for all desired eigenvalues in parallel, starting from
# the interval [lambda_min, lambda_max].
upper = jnp.broadcast_to(lambda_max, shape=target_counts.shape)
lower = jnp.broadcast_to(lambda_min, shape=target_counts.shape)
mid = 0.5 * (upper + lower)

# Pre-broadcast the fixed scalars used in the Sturm sequence for improved
# performance.
pivmin = jnp.broadcast_to(pivmin, target_counts.shape)
alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation,
target_counts.shape)

# Start parallel binary searches.
def cond(args):
i, lower, _, upper = args
return jnp.logical_and(
jnp.less(i, max_it),
jnp.less(abs_tol, jnp.amax(upper - lower)))

def body(args):
i, lower, mid, upper = args
counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
lower = jnp.where(counts <= target_counts, mid, lower)
upper = jnp.where(counts > target_counts, mid, upper)
mid = 0.5 * (lower + upper)
return i + 1, lower, mid, upper

_, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
return mid
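For intuition, here is a pure-NumPy sketch of the same Sturm-count bisection that _sturm and the vectorized binary search above implement; the helper name sturm_count, the fixed pivmin, and the 3x3 example are illustrative assumptions, not part of this commit:

import numpy as np

def sturm_count(alpha, beta, x, pivmin=1e-30):
  # Number of eigenvalues of tridiag(beta, alpha, beta) below x, counted as
  # the number of (near-)negative pivots of the LDL^T factorization of T - x*I.
  q = alpha[0] - x
  count = 1 if q <= pivmin else 0
  if q <= pivmin:
    q = -pivmin  # Clamp away from zero so the division below stays finite.
  for i in range(1, len(alpha)):
    q = alpha[i] - beta[i - 1] ** 2 / q - x
    if q <= pivmin:
      count += 1
      q = min(q, -pivmin)
  return count

alpha = np.array([2.0, 2.0, 2.0])
beta = np.array([-1.0, -1.0])
# Gershgorin bounds bracket the spectrum.
r = 2 * np.abs(beta).max()
lower, upper = alpha.min() - r, alpha.max() + r
# Bisect for the k-th smallest eigenvalue (k = 0: the smallest).
k = 0
for _ in range(60):
  mid = 0.5 * (lower + upper)
  if sturm_count(alpha, beta, mid) <= k:
    lower = mid
  else:
    upper = mid
print(0.5 * (lower + upper))  # ~0.5858 = 2 - sqrt(2); cf. np.linalg.eigvalsh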
1 change: 1 addition & 0 deletions jax/scipy/linalg.py
@@ -21,6 +21,7 @@
cho_solve,
det,
eigh,
eigh_tridiagonal,
expm,
expm_frechet,
inv,
26 changes: 26 additions & 0 deletions tests/linalg_test.py
@@ -18,6 +18,7 @@
import unittest

import numpy as np
import scipy
import scipy as osp

from absl.testing import absltest
@@ -1411,5 +1412,30 @@ def expm(x):
jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
rtol=tol)

class EighTridiagonalTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_n={n}_dtype={dtype.__name__}",
"n": n, "dtype": dtype}
for n in [2, 3, 7, 8, 100]
for dtype in float_types))
def testToeplitz(self, n, dtype):
jtu.skip_if_unsupported_type(dtype)
for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
if (jtu.device_under_test() == "cpu" and dtype == np.float64 and
b == 1e-10):
# TODO(phawkins): this test fails on CPU but not on GPU.
continue
alpha = a * np.ones([n], dtype=dtype)
beta = b * np.ones([n - 1], dtype=dtype)
eigvals_expected = scipy.linalg.eigh_tridiagonal(
alpha, beta, eigvals_only=True)
eigvals = jax.scipy.linalg.eigh_tridiagonal(
alpha, beta, eigvals_only=True)
finfo = np.finfo(dtype)
atol = 4 * np.sqrt(n) * finfo.eps * np.amax(np.abs(eigvals_expected))
self.assertAllClose(eigvals_expected, eigvals, atol=atol, rtol=1e-4)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
