Skip to content

Commit

Permalink
[TPU] Switch the default eigendecomposition implementation on TPU to …
Browse files Browse the repository at this point in the history
…use QDWH-eig.

Adds a new non-differentiable primitive `eigh_jacobi` that calls the XLA Jacobi eigh implementation for use inside the TPU QDWH-eigh lowering rule.

PiperOrigin-RevId: 451471088
  • Loading branch information
hawkinsp authored and jax authors committed May 27, 2022
1 parent 0553f9e commit 5ccdcc5
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 20 deletions.
24 changes: 16 additions & 8 deletions jax/_src/lax/eigh.py
Expand Up @@ -24,6 +24,7 @@
import jax._src.numpy.linalg as jnp_linalg
from jax import lax
from jax._src.lax import qdwh
from jax._src.lax import linalg as lax_linalg
from jax._src.lax.stack import Stack


Expand Down Expand Up @@ -384,16 +385,19 @@ def loop_body(state):
return blocks[:, 0], eigenvectors


def eigh(H, *, precision="float32", termination_size=256, n=None):
def eigh(H, *, precision="float32", termination_size=256, n=None,
sort_eigenvalues=True):
""" Computes the eigendecomposition of the symmetric/Hermitian matrix H.
Args:
H: The `n x n` Hermitian input.
H: The `n x n` Hermitian input, padded to `N x N`.
precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
symmetrize: If True, `0.5 * (H + H.conj().T)` rather than `H` is used.
termination_size: Recursion ends once the blocks reach this linear size.
n: the true (dynamic) size of the matrix.
sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest to
highest.
Returns:
vals: The `n` eigenvalues of `H`, sorted from lowest to highest.
vals: The `n` eigenvalues of `H`.
vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
to numerical error.
Expand All @@ -403,7 +407,10 @@ def eigh(H, *, precision="float32", termination_size=256, n=None):
raise TypeError(f"Input H of shape {H.shape} must be square.")

if N <= termination_size:
return jnp_linalg.eigh(H)
if n is not None:
H = _mask(H, (n, n), jnp.eye(N, dtype=H.dtype))
return lax_linalg.eigh_jacobi(
H, sort_eigenvalues=sort_eigenvalues)

# TODO(phawkins): consider rounding N up to a larger size to maximize reuse
# between matrices.
Expand All @@ -412,7 +419,8 @@ def eigh(H, *, precision="float32", termination_size=256, n=None):
with jax.default_matmul_precision(precision):
eig_vals, eig_vecs = _eigh_work(H, n, termination_size=termination_size)
eig_vals = _mask(jnp.real(eig_vals), (n,), jnp.nan)
sort_idxs = jnp.argsort(eig_vals)
eig_vals = eig_vals[sort_idxs]
eig_vecs = eig_vecs[:, sort_idxs]
if sort_eigenvalues:
sort_idxs = jnp.argsort(eig_vals)
eig_vals = eig_vals[sort_idxs]
eig_vecs = eig_vecs[:, sort_idxs]
return eig_vals, eig_vecs
97 changes: 89 additions & 8 deletions jax/_src/lax/linalg.py
Expand Up @@ -35,6 +35,8 @@
from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
_input_dtype)
from jax._src.lax import control_flow
from jax._src.lax import eigh as lax_eigh
from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
import jax._src.lib
Expand Down Expand Up @@ -574,17 +576,54 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,

# Symmetric/Hermitian eigendecomposition

def _eigh_impl(operand, *, lower, sort_eigenvalues):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower,

def eigh_jacobi(x, *, lower: bool = True, sort_eigenvalues: bool = True):
"""Helper Jacobi eigendecomposition implemented by XLA.
Used as a subroutine of QDWH-eig on TPU."""
w, v = eigh_jacobi_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues)
return w, v

def _eigh_jacobi_impl(operand, *, lower, sort_eigenvalues):
w, v = xla.apply_primitive(eigh_jacobi_p, operand, lower=lower,
sort_eigenvalues=sort_eigenvalues)
return v, w
return w, v

def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
sort_eigenvalues):
def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
if isinstance(operand, ShapedArray):
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
raise ValueError(
"Argument to symmetric eigendecomposition must have shape [..., n, n],"
"got shape {}".format(operand.shape))

batch_dims = operand.shape[:-2]
n = operand.shape[-1]
w = operand.update(shape=batch_dims + (n,),
dtype=lax_internal._complex_basetype(operand.dtype))
v = operand.update(shape=batch_dims + (n, n))
else:
w, v = operand, operand
return w, v

def _eigh_jacobi_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
sort_eigenvalues):
operand_aval, = avals_in
if operand_aval.shape[-1] == 0:
return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
return xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
return [xops.Real(xops.Reshape(operand, operand_aval.shape[:-1])), operand]
v, w = xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
return w, v

eigh_jacobi_p = Primitive('eigh_jacobi')
eigh_jacobi_p.multiple_results = True
eigh_jacobi_p.def_impl(_eigh_jacobi_impl)
eigh_jacobi_p.def_abstract_eval(_eigh_jacobi_abstract_eval)
xla.register_translation(eigh_jacobi_p, _eigh_jacobi_translation_rule)


def _eigh_impl(operand, *, lower, sort_eigenvalues):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
sort_eigenvalues=sort_eigenvalues)
return v, w

def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
if isinstance(operand, ShapedArray):
Expand Down Expand Up @@ -625,6 +664,45 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
w, _nan_like_mhlo(w_aval))
return [v, w]

def _eigh_tpu_impl(x, *, lower, sort_eigenvalues):
*_, m, n = x.shape
assert m == n, (m, n)

termination_size = 256

if m <= termination_size:
eig_vals, eig_vecs = eigh_jacobi(x, lower=lower,
sort_eigenvalues=sort_eigenvalues)
return eig_vecs, eig_vals

def eigh_qdwh(x):
if len(x.shape) > 2:
return control_flow.map(eigh_qdwh, x)

# We should only look at elements from the lower/upper triangle. Reflects
# that triangle into the other triangle to form a Hermitian matrix.
if lower:
mask = jnp.tri(n, k=0, dtype=bool)
else:
mask = jnp.logical_not(jnp.tri(n, k=-1, dtype=bool))
if dtypes.issubdtype(x.dtype, jnp.complexfloating):
re = lax.select(mask, lax.real(x), _T(lax.real(x)))
if lower:
im_mask = jnp.tri(n, k=-1, dtype=bool)
else:
im_mask = jnp.logical_not(jnp.tri(n, k=0, dtype=bool))
im = lax.select(im_mask, lax.imag(x), jnp.zeros_like(lax.imag(x)))
im = lax.select(mask, im, -_T(im))
x = lax.complex(re, im)
else:
x = lax.select(mask, x, _T(x))

return lax_eigh.eigh(x, sort_eigenvalues=sort_eigenvalues,
termination_size=termination_size)

eig_vals, eig_vecs = eigh_qdwh(x)
return eig_vecs, eig_vals

def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
# Derivative for eigh in the simplest case of distinct eigenvalues.
# This is classic nondegenerate perurbation theory, but also see
Expand Down Expand Up @@ -663,7 +741,6 @@ def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
eigh_p.multiple_results = True
eigh_p.def_impl(_eigh_impl)
eigh_p.def_abstract_eval(_eigh_abstract_eval)
xla.register_translation(eigh_p, _eigh_translation_rule)
ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
batching.primitive_batchers[eigh_p] = _eigh_batching_rule

Expand All @@ -685,6 +762,10 @@ def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
eigh_p, partial(_eigh_cpu_gpu_lowering, solver_apis.syevd_mhlo),
platform='gpu')

mlir.register_lowering(
eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True),
platform='tpu')


triangular_solve_dtype_rule = partial(
naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -1012,6 +1012,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"xla_pmap",
"geqrf",
"orgqr",
"eigh_jacobi",
]

tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
Expand Down
8 changes: 4 additions & 4 deletions tests/linalg_test.py
Expand Up @@ -332,7 +332,7 @@ def testEigBatching(self, shape, dtype):
jtu.format_shape_dtype_string((n,n), dtype), lower,
sort_eigenvalues),
"n": n, "dtype": dtype, "lower": lower}
for n in [0, 4, 5, 50]
for n in [0, 4, 5, 50, 512]
for dtype in float_types + complex_types
for lower in [True, False]
for sort_eigenvalues in [True, False]))
Expand Down Expand Up @@ -459,16 +459,16 @@ def testEighGradPrecision(self):
{"testcase_name":
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
"shape": shape, "dtype": dtype}
for shape in [(1, 1), (4, 4), (5, 5)]
for shape in [(1, 1), (4, 4), (5, 5), (300, 300)]
for dtype in float_types + complex_types))
def testEighBatching(self, shape, dtype):
rng = jtu.rand_default(self.rng())
shape = (10,) + shape
args = rng(shape, dtype)
args = (args + np.conj(T(args))) / 2
ws, vs = vmap(jsp.linalg.eigh)(args)
self.assertTrue(np.all(np.linalg.norm(
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
norm = np.max(np.linalg.norm(np.matmul(args, vs) - ws[..., None, :] * vs))
self.assertTrue(norm < 3e-2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
Expand Down

0 comments on commit 5ccdcc5

Please sign in to comment.