Skip to content

Commit

Permalink
[linalg] Adds compute_uv to TPU SVD.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 442864883
  • Loading branch information
tlu7 authored and jax authors committed Apr 19, 2022
1 parent cc2b830 commit 5a1c5ba
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 8 deletions.
27 changes: 19 additions & 8 deletions jax/_src/lax/svd.py
Expand Up @@ -35,39 +35,45 @@

import functools

from typing import Sequence
from typing import Sequence, Union

import jax
from jax import core
from jax import lax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=(1, 2))
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def _svd(a: jnp.ndarray,
is_hermitian: bool,
max_iterations: int) -> Sequence[jnp.ndarray]:
compute_uv: bool,
max_iterations: int) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
"""Singular value decomposition for m x n matrix and m >= n.
Args:
a: A matrix of shape `m x n` with `m >= n`.
is_hermitian: True if `a` is Hermitian.
compute_uv: Whether to compute also `u` and `v` in addition to `s`.
max_iterations: The predefined maximum number of iterations of QDWH.
Returns:
A 3-tuple (`u`, `s`, `v`), where `u` is a unitary matrix of shape `m x n`,
`s` is vector of length `n` containing the singular values in the descending
order, `v` is a unitary matrix of shape `n x n`, and
`a = (u * s) @ v.T.conj()`.
`a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned.
"""

u, h, _, _ = lax.linalg.qdwh(a, is_hermitian, max_iterations)

# TODO: Uses `eigvals_only=True` if `compute_uv=False`.
v, s = lax.linalg.eigh(h)

# Flips the singular values in descending order.
s_out = jnp.flip(s)

if not compute_uv:
return s_out

# Reorders eigenvectors.
v_out = jnp.fliplr(v)

Expand All @@ -92,22 +98,24 @@ def correct_rank_deficiency(u_out):
return (u_out, s_out, v_out)


@functools.partial(jax.jit, static_argnums=(1, 2))
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def svd(a: jnp.ndarray,
is_hermitian: bool = False,
max_iterations: int = 10) -> Sequence[jnp.ndarray]:
compute_uv: bool = True,
max_iterations: int = 10) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
"""Singular value decomposition.
Args:
a: A matrix of shape `m x n`.
is_hermitian: True if `a` is Hermitian.
compute_uv: Whether to compute also `u` and `v` in addition to `s`.
max_iterations: The predefined maximum number of iterations of QDWH.
Returns:
A 3-tuple (`u`, `s`, `vh`), where `u` is a unitary matrix of shape `m x k`,
`s` is vector of length `k` containing the singular values in the descending
order, `vh` is a unitary matrix of shape `k x n`, `k = min(m, n)`, and
`a = (u * s) @ vh`.
`a = (u * s) @ vh`. For `compute_uv=False`, only `s` is returned.
"""

is_hermitian = core.concrete_or_error(
Expand All @@ -132,7 +140,10 @@ def svd(a: jnp.ndarray,
q, a = lax.linalg.qr(a, full_matrices=False)
reduce_to_square = True

u_out, s_out, v_out = _svd(a, is_hermitian, max_iterations)
if not compute_uv:
return _svd(a, is_hermitian, compute_uv, max_iterations)

u_out, s_out, v_out = _svd(a, is_hermitian, compute_uv, max_iterations)

if reduce_to_square:
u_out = q @ u_out
Expand Down
40 changes: 40 additions & 0 deletions tests/svd_test.py
Expand Up @@ -131,6 +131,46 @@ def testSvdWithOnRankDeficientInput(self, m, r, log_cond):

np.testing.assert_almost_equal(diff, 1E-4, decimal=2)

@parameterized.named_parameters(jtu.cases_from_list(
{ # pylint:disable=g-complex-comprehension
'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
'm': m, 'n': n, 'log_cond': log_cond}
for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
@jtu.skip_on_devices("rocm") # will be fixed on rocm-5.1
def testSingularValues(self, m, n, log_cond):
"""Tests singular values."""
with jax.default_matmul_precision('float32'):
a = np.random.uniform(
low=0.3, high=0.9, size=(m, n)).astype(_SVD_TEST_DTYPE)
u, s, v = osp_linalg.svd(a, full_matrices=False)
cond = 10**log_cond
s = np.linspace(cond, 1, min(m, n))
a = (u * s) @ v
a = a + 1j * a

# Only computes singular values.
compute_uv = False

osp_linalg_fn = functools.partial(
osp_linalg.svd, full_matrices=False, compute_uv=compute_uv)
actual_s = svd.svd(a, compute_uv=compute_uv)

expected_s = osp_linalg_fn(a)

args_maker = lambda: [a]

with self.subTest('Test JIT compatibility'):
self._CompileAndCheck(svd.svd, args_maker)

with self.subTest('Test s.'):
self.assertAllClose(expected_s, actual_s, rtol=_SVD_RTOL, atol=1E-6)

with self.subTest('Test non-increasing order.'):
# Computes `actual_diff[i] = s[i+1] - s[i]`.
actual_diff = jnp.diff(actual_s, append=0)
np.testing.assert_array_less(actual_diff, np.zeros_like(actual_diff))


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 5a1c5ba

Please sign in to comment.