Skip to content

Commit

Permalink
Fix corner cases in JAX SVD: a) Clamp negative singular values to zer…
Browse files Browse the repository at this point in the history
…o. b) Return all NaN for matrices with non-finite values.

PiperOrigin-RevId: 540015938
  • Loading branch information
jax authors committed Jun 13, 2023
1 parent 0b20251 commit 21051ff
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 27 deletions.
73 changes: 56 additions & 17 deletions jax/_src/lax/svd.py
Expand Up @@ -42,21 +42,46 @@
from jax._src import core


@functools.partial(jax.jit, static_argnums=(1, 2))
def _zero_svd(a: Any,
full_matrices: bool,
compute_uv: bool = True) -> Union[Any, Sequence[Any]]:
@functools.partial(jax.jit, static_argnums=(2, 3))
def _constant_svd(
a: Any, return_nan: bool, full_matrices: bool, compute_uv: bool = True
) -> Union[Any, Sequence[Any]]:
"""SVD on matrix of all zeros."""
m, n = a.shape
k = min(m, n)
s = jnp.zeros(shape=(k,), dtype=a.real.dtype)
s = jnp.where(
return_nan,
jnp.full(shape=(k,), fill_value=jnp.nan, dtype=a.real.dtype),
jnp.zeros(shape=(k,), dtype=a.real.dtype),
)
if compute_uv:
fill_value = (
jnp.nan + 1j * jnp.nan
if jnp.issubdtype(a.dtype, jnp.complexfloating)
else jnp.nan
)
if full_matrices:
u = jnp.eye(m, m, dtype=a.dtype)
vh = jnp.eye(n, n, dtype=a.dtype)
u = jnp.where(
return_nan,
jnp.full((m, m), fill_value, dtype=a.dtype),
jnp.eye(m, m, dtype=a.dtype),
)
vh = jnp.where(
return_nan,
jnp.full((n, n), fill_value, dtype=a.dtype),
jnp.eye(n, n, dtype=a.dtype),
)
else:
u = jnp.eye(m, k, dtype=a.dtype)
vh = jnp.eye(k, n, dtype=a.dtype)
u = jnp.where(
return_nan,
jnp.full((m, k), fill_value, dtype=a.dtype),
jnp.eye(m, k, dtype=a.dtype),
)
vh = jnp.where(
return_nan,
jnp.full((k, n), fill_value, dtype=a.dtype),
jnp.eye(k, n, dtype=a.dtype),
)
return (u, s, vh)
else:
return s
Expand Down Expand Up @@ -86,6 +111,9 @@ def _svd_tall_and_square_input(

# TODO: Uses `eigvals_only=True` if `compute_uv=False`.
v, s = lax.linalg.eigh(h)
# Singular values are non-negative by definition. But eigh could return small
# negative values, so we clamp them to zero.
s = jnp.maximum(s, 0.0)

# Flips the singular values in descending order.
s_out = jnp.flip(s)
Expand Down Expand Up @@ -228,11 +256,22 @@ def svd(a: Any,
# X_{k+1} = X_k(a_k I + b_k {X_k}^H X_k)(I + c_k {X_k}^H X_k)^{−1} and
# X_0 = A/alpha, where alpha = ||A||_2, the triplet (a_k, b_k, c_k) are
# weighting parameters, and X_k denotes the k^{th} iterate.
return jax.lax.cond(jnp.all(a == 0),
functools.partial(_zero_svd, full_matrices=full_matrices,
compute_uv=compute_uv),
functools.partial(_qdwh_svd, full_matrices=full_matrices,
compute_uv=compute_uv,
hermitian=hermitian,
max_iterations=max_iterations),
operand=(a))
all_zero = jnp.all(a == 0)
non_finite = jnp.logical_not(jnp.all(jnp.isfinite(a)))
return lax.cond(
jnp.logical_or(all_zero, non_finite),
functools.partial(
_constant_svd,
return_nan=non_finite,
full_matrices=full_matrices,
compute_uv=compute_uv,
),
functools.partial(
_qdwh_svd,
full_matrices=full_matrices,
compute_uv=compute_uv,
hermitian=hermitian,
max_iterations=max_iterations,
),
operand=(a),
)
43 changes: 33 additions & 10 deletions tests/svd_test.py
Expand Up @@ -173,21 +173,44 @@ def testSingularValues(self, m, n, log_cond, full_matrices):
np.testing.assert_array_less(actual_diff, np.zeros_like(actual_diff))

@jtu.sample_product(
[dict(m=m, n=n) for m, n in zip([2, 4, 8], [4, 4, 6])],
full_matrices=[True, False],
compute_uv=[True, False],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
[dict(m=m, n=n) for m, n in zip([2, 4, 8], [4, 4, 6])],
full_matrices=[True, False],
compute_uv=[True, False],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
def testSvdOnZero(self, m, n, full_matrices, compute_uv, dtype):
"""Tests SVD on matrix of all zeros."""
osp_fun = functools.partial(osp_linalg.svd, full_matrices=full_matrices,
compute_uv=compute_uv)
lax_fun = functools.partial(svd.svd, full_matrices=full_matrices,
compute_uv=compute_uv)
def testSvdAllZero(self, m, n, full_matrices, compute_uv, dtype):
"""Tests SVD on matrix of all zeros, +/-infinity or NaN."""
osp_fun = functools.partial(
osp_linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv
)
lax_fun = functools.partial(
svd.svd, full_matrices=full_matrices, compute_uv=compute_uv
)
args_maker_svd = lambda: [jnp.zeros((m, n), dtype=dtype)]
self._CheckAgainstNumpy(osp_fun, lax_fun, args_maker_svd)
self._CompileAndCheck(lax_fun, args_maker_svd)

@jtu.sample_product(
[dict(m=m, n=n) for m, n in zip([2, 4, 8], [4, 4, 6])],
fill_value=[-np.inf, np.inf, np.nan],
full_matrices=[True, False],
compute_uv=[True, False],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
def testSvdNonFiniteValues(
self, m, n, fill_value, full_matrices, compute_uv, dtype
):
"""Tests SVD on matrix of all zeros, +/-infinity or NaN."""
lax_fun = functools.partial(
svd.svd, full_matrices=full_matrices, compute_uv=compute_uv
)
args_maker_svd = lambda: [
jnp.full((m, n), fill_value=fill_value, dtype=dtype)
]
result = lax_fun(args_maker_svd()[0])
for r in result:
self.assertTrue(jnp.all(jnp.isnan(r)))
self._CompileAndCheck(lax_fun, args_maker_svd)

@jtu.sample_product(
[dict(m=m, n=n, r=r, c=c)
Expand Down

0 comments on commit 21051ff

Please sign in to comment.