Skip to content

Commit

Permalink
[linalg] Adds full_matrices option to TPU SVD.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 443163571
  • Loading branch information
tlu7 authored and jax authors committed Apr 20, 2022
1 parent 29d54e3 commit 455c9f8
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 40 deletions.
50 changes: 33 additions & 17 deletions jax/_src/lax/svd.py
Expand Up @@ -45,14 +45,14 @@

@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def _svd(a: jnp.ndarray,
is_hermitian: bool,
hermitian: bool,
compute_uv: bool,
max_iterations: int) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
"""Singular value decomposition for m x n matrix and m >= n.
Args:
a: A matrix of shape `m x n` with `m >= n`.
is_hermitian: True if `a` is Hermitian.
hermitian: True if `a` is Hermitian.
compute_uv: Whether to compute also `u` and `v` in addition to `s`.
max_iterations: The predefined maximum number of iterations of QDWH.
Expand All @@ -63,7 +63,7 @@ def _svd(a: jnp.ndarray,
`a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned.
"""

u, h, _, _ = lax.linalg.qdwh(a, is_hermitian, max_iterations)
u, h, _, _ = lax.linalg.qdwh(a, hermitian, max_iterations)

# TODO: Use `eigvals_only=True` when `compute_uv=False`.
v, s = lax.linalg.eigh(h)
Expand Down Expand Up @@ -98,28 +98,33 @@ def correct_rank_deficiency(u_out):
return (u_out, s_out, v_out)


@functools.partial(jax.jit, static_argnums=(1, 2, 3))
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def svd(a: jnp.ndarray,
is_hermitian: bool = False,
full_matrices: bool,
compute_uv: bool = True,
hermitian: bool = False,
max_iterations: int = 10) -> Union[jnp.ndarray, Sequence[jnp.ndarray]]:
"""Singular value decomposition.
Args:
a: A matrix of shape `m x n`.
is_hermitian: True if `a` is Hermitian.
full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
respectively. If False, the shapes are `m x k` and `k x n`, respectively,
where `k = min(m, n)`.
compute_uv: Whether to compute `u` and `vh` in addition to `s`.
hermitian: True if `a` is Hermitian.
max_iterations: The predefined maximum number of iterations of QDWH.
Returns:
A 3-tuple (`u`, `s`, `vh`), where `u` is a unitary matrix of shape `m x k`,
`s` is vector of length `k` containing the singular values in the descending
order, `vh` is a unitary matrix of shape `k x n`, `k = min(m, n)`, and
`a = (u * s) @ vh`. For `compute_uv=False`, only `s` is returned.
A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices,
`s` is a vector of length `k` containing the singular values in
non-increasing order, and `k = min(m, n)`. The shapes of `u` and `vh`
depend on the value of `full_matrices`. For `compute_uv=False`,
only `s` is returned.
"""

is_hermitian = core.concrete_or_error(
bool, is_hermitian, 'The `is_hermitian` argument must be statically '
hermitian = core.concrete_or_error(
bool, hermitian, 'The `hermitian` argument must be statically '
'specified to use `qdwh` within JAX transformations.')

max_iterations = core.concrete_or_error(
Expand All @@ -135,19 +140,30 @@ def svd(a: jnp.ndarray,
is_flip = True

reduce_to_square = False
if m > 1.15 * n:
m = n
q, a = lax.linalg.qr(a, full_matrices=False)
if full_matrices:
q_full, a_full = lax.linalg.qr(a, full_matrices=True)
q = q_full[:, :n]
u_out_null = q_full[:, n:]
a = a_full[:n, :]
reduce_to_square = True
else:
# The constant `1.15` comes from Yuji Nakatsukasa's implementation
# https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav
if m > 1.15 * n:
q, a = lax.linalg.qr(a, full_matrices=False)
reduce_to_square = True

if not compute_uv:
return _svd(a, is_hermitian, compute_uv, max_iterations)
return _svd(a, hermitian, compute_uv, max_iterations)

u_out, s_out, v_out = _svd(a, is_hermitian, compute_uv, max_iterations)
u_out, s_out, v_out = _svd(a, hermitian, compute_uv, max_iterations)

if reduce_to_square:
u_out = q @ u_out

if full_matrices:
u_out = jnp.hstack((u_out, u_out_null))

if is_flip:
return(v_out, s_out, u_out.T.conj())

Expand Down
59 changes: 36 additions & 23 deletions tests/svd_test.py
Expand Up @@ -47,45 +47,55 @@ class SvdTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
{ # pylint:disable=g-complex-comprehension
'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
'm': m, 'n': n, 'log_cond': log_cond}
'testcase_name': '_m={}_by_n={}_log_cond={}_full_matrices={}'.format(
m, n, log_cond, full_matrices),
'm': m, 'n': n, 'log_cond': log_cond, 'full_matrices': full_matrices}
for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)
for full_matrices in [True, False]))
@jtu.skip_on_devices("rocm") # will be fixed on rocm-5.1
def testSvdWithRectangularInput(self, m, n, log_cond):
def testSvdWithRectangularInput(self, m, n, log_cond, full_matrices):
"""Tests SVD with rectangular input."""
with jax.default_matmul_precision('float32'):
a = np.random.uniform(
low=0.3, high=0.9, size=(m, n)).astype(_SVD_TEST_DTYPE)
u, s, v = jnp.linalg.svd(a, full_matrices=False)
u, s, v = osp_linalg.svd(a, full_matrices=False)
cond = 10**log_cond
s = jnp.linspace(cond, 1, min(m, n))
a = (u * s) @ v
a = a + 1j * a

osp_linalg_fn = functools.partial(osp_linalg.svd, full_matrices=False)
actual_u, actual_s, actual_v = svd.svd(a)
osp_linalg_fn = functools.partial(
osp_linalg.svd, full_matrices=full_matrices)
actual_u, actual_s, actual_v = svd.svd(a, full_matrices=full_matrices)

k = min(m, n)
if m > n:
unitary_u = jnp.abs(actual_u.T.conj() @ actual_u)
unitary_v = jnp.abs(actual_v.T.conj() @ actual_v)
unitary_u = jnp.real(actual_u.T.conj() @ actual_u)
unitary_v = jnp.real(actual_v.T.conj() @ actual_v)
unitary_u_size = m if full_matrices else k
unitary_v_size = k
else:
unitary_u = jnp.abs(actual_u @ actual_u.T.conj())
unitary_v = jnp.abs(actual_v @ actual_v.T.conj())
unitary_u = jnp.real(actual_u @ actual_u.T.conj())
unitary_v = jnp.real(actual_v @ actual_v.T.conj())
unitary_u_size = k
unitary_v_size = n if full_matrices else k

_, expected_s, _ = osp_linalg_fn(a)

svd_fn = lambda a: svd.svd(a, full_matrices=full_matrices)
args_maker = lambda: [a]

with self.subTest('Test JIT compatibility'):
self._CompileAndCheck(svd.svd, args_maker)
self._CompileAndCheck(svd_fn, args_maker)

with self.subTest('Test unitary u.'):
self.assertAllClose(np.eye(k), unitary_u, rtol=_SVD_RTOL, atol=2E-3)
self.assertAllClose(np.eye(unitary_u_size), unitary_u, rtol=_SVD_RTOL,
atol=2E-3)

with self.subTest('Test unitary v.'):
self.assertAllClose(np.eye(k), unitary_v, rtol=_SVD_RTOL, atol=2E-3)
self.assertAllClose(np.eye(unitary_v_size), unitary_v, rtol=_SVD_RTOL,
atol=2E-3)

with self.subTest('Test s.'):
self.assertAllClose(
Expand All @@ -100,7 +110,7 @@ def testSvdWithSkinnyTallInput(self, m, n):
with jax.default_matmul_precision('float32'):
np.random.seed(1235)
a = np.random.randn(m, n).astype(_SVD_TEST_DTYPE)
u, s, v = svd.svd(a, is_hermitian=False)
u, s, v = svd.svd(a, full_matrices=False, hermitian=False)

relative_diff = np.linalg.norm(a - (u * s) @ v) / np.linalg.norm(a)

Expand All @@ -126,19 +136,21 @@ def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
a = (u * s) @ v

with jax.default_matmul_precision('float32'):
u, s, v = svd.svd(a, is_hermitian=False)
u, s, v = svd.svd(a, full_matrices=False, hermitian=False)
diff = np.linalg.norm(a - (u * s) @ v)

np.testing.assert_almost_equal(diff, 1E-4, decimal=2)

@parameterized.named_parameters(jtu.cases_from_list(
{ # pylint:disable=g-complex-comprehension
'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
'm': m, 'n': n, 'log_cond': log_cond}
'testcase_name': '_m={}_by_n={}_log_cond={}_full_matrices={}'.format(
m, n, log_cond, full_matrices),
'm': m, 'n': n, 'log_cond': log_cond, 'full_matrices': full_matrices}
for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)
for full_matrices in [True, False]))
@jtu.skip_on_devices("rocm") # will be fixed on rocm-5.1
def testSingularValues(self, m, n, log_cond):
def testSingularValues(self, m, n, log_cond, full_matrices):
"""Tests singular values."""
with jax.default_matmul_precision('float32'):
a = np.random.uniform(
Expand All @@ -153,15 +165,16 @@ def testSingularValues(self, m, n, log_cond):
compute_uv = False

osp_linalg_fn = functools.partial(
osp_linalg.svd, full_matrices=False, compute_uv=compute_uv)
actual_s = svd.svd(a, compute_uv=compute_uv)
osp_linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv)
actual_s = svd.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

expected_s = osp_linalg_fn(a)

svd_fn = lambda a: svd.svd(a, full_matrices=full_matrices)
args_maker = lambda: [a]

with self.subTest('Test JIT compatibility'):
self._CompileAndCheck(svd.svd, args_maker)
self._CompileAndCheck(svd_fn, args_maker)

with self.subTest('Test s.'):
self.assertAllClose(expected_s, actual_s, rtol=_SVD_RTOL, atol=1E-6)
Expand Down

0 comments on commit 455c9f8

Please sign in to comment.