diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py index cf0ffa562ef7..1657ea4c9d63 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/lax/svd.py @@ -42,21 +42,46 @@ from jax._src import core -@functools.partial(jax.jit, static_argnums=(1, 2)) -def _zero_svd(a: Any, - full_matrices: bool, - compute_uv: bool = True) -> Union[Any, Sequence[Any]]: +@functools.partial(jax.jit, static_argnums=(2, 3)) +def _constant_svd( + a: Any, return_nan: bool, full_matrices: bool, compute_uv: bool = True +) -> Union[Any, Sequence[Any]]: """SVD on matrix of all zeros.""" m, n = a.shape k = min(m, n) - s = jnp.zeros(shape=(k,), dtype=a.real.dtype) + s = jnp.where( + return_nan, + jnp.full(shape=(k,), fill_value=jnp.nan, dtype=a.real.dtype), + jnp.zeros(shape=(k,), dtype=a.real.dtype), + ) if compute_uv: + fill_value = ( + jnp.nan + 1j * jnp.nan + if jnp.issubdtype(a.dtype, jnp.complexfloating) + else jnp.nan + ) if full_matrices: - u = jnp.eye(m, m, dtype=a.dtype) - vh = jnp.eye(n, n, dtype=a.dtype) + u = jnp.where( + return_nan, + jnp.full((m, m), fill_value, dtype=a.dtype), + jnp.eye(m, m, dtype=a.dtype), + ) + vh = jnp.where( + return_nan, + jnp.full((n, n), fill_value, dtype=a.dtype), + jnp.eye(n, n, dtype=a.dtype), + ) else: - u = jnp.eye(m, k, dtype=a.dtype) - vh = jnp.eye(k, n, dtype=a.dtype) + u = jnp.where( + return_nan, + jnp.full((m, k), fill_value, dtype=a.dtype), + jnp.eye(m, k, dtype=a.dtype), + ) + vh = jnp.where( + return_nan, + jnp.full((k, n), fill_value, dtype=a.dtype), + jnp.eye(k, n, dtype=a.dtype), + ) return (u, s, vh) else: return s @@ -86,6 +111,9 @@ def _svd_tall_and_square_input( # TODO: Uses `eigvals_only=True` if `compute_uv=False`. v, s = lax.linalg.eigh(h) + # Singular values are non-negative by definition. But eigh could return small + # negative values, so we clamp them to zero. + s = jnp.maximum(s, 0.0) # Flips the singular values in descending order. s_out = jnp.flip(s) @@ -228,11 +256,22 @@ def svd(a: Any, # X_{k+1} = X_k(a_k I + b_k {X_k}^H X_k)(I + c_k {X_k}^H X_k)^{−1} and # X_0 = A/alpha, where alpha = ||A||_2, the triplet (a_k, b_k, c_k) are # weighting parameters, and X_k denotes the k^{th} iterate. - return jax.lax.cond(jnp.all(a == 0), - functools.partial(_zero_svd, full_matrices=full_matrices, - compute_uv=compute_uv), - functools.partial(_qdwh_svd, full_matrices=full_matrices, - compute_uv=compute_uv, - hermitian=hermitian, - max_iterations=max_iterations), - operand=(a)) + all_zero = jnp.all(a == 0) + non_finite = jnp.logical_not(jnp.all(jnp.isfinite(a))) + return lax.cond( + jnp.logical_or(all_zero, non_finite), + functools.partial( + _constant_svd, + return_nan=non_finite, + full_matrices=full_matrices, + compute_uv=compute_uv, + ), + functools.partial( + _qdwh_svd, + full_matrices=full_matrices, + compute_uv=compute_uv, + hermitian=hermitian, + max_iterations=max_iterations, + ), + operand=(a), + ) diff --git a/tests/svd_test.py b/tests/svd_test.py index b11ed508b22a..abb1f4037194 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -173,21 +173,44 @@ def testSingularValues(self, m, n, log_cond, full_matrices): np.testing.assert_array_less(actual_diff, np.zeros_like(actual_diff)) @jtu.sample_product( - [dict(m=m, n=n) for m, n in zip([2, 4, 8], [4, 4, 6])], - full_matrices=[True, False], - compute_uv=[True, False], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, + [dict(m=m, n=n) for m, n in zip([2, 4, 8], [4, 4, 6])], + full_matrices=[True, False], + compute_uv=[True, False], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) - def testSvdOnZero(self, m, n, full_matrices, compute_uv, dtype): - """Tests SVD on matrix of all zeros.""" - osp_fun = functools.partial(osp_linalg.svd, full_matrices=full_matrices, - compute_uv=compute_uv) - lax_fun = functools.partial(svd.svd, full_matrices=full_matrices, - compute_uv=compute_uv) + def testSvdAllZero(self, m, n, full_matrices, compute_uv, dtype): + """Tests SVD on matrix of all zeros, +/-infinity or NaN.""" + osp_fun = functools.partial( + osp_linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv + ) + lax_fun = functools.partial( + svd.svd, full_matrices=full_matrices, compute_uv=compute_uv + ) args_maker_svd = lambda: [jnp.zeros((m, n), dtype=dtype)] self._CheckAgainstNumpy(osp_fun, lax_fun, args_maker_svd) self._CompileAndCheck(lax_fun, args_maker_svd) + @jtu.sample_product( + [dict(m=m, n=n) for m, n in zip([2, 4, 8], [4, 4, 6])], + fill_value=[-np.inf, np.inf, np.nan], + full_matrices=[True, False], + compute_uv=[True, False], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + def testSvdNonFiniteValues( + self, m, n, fill_value, full_matrices, compute_uv, dtype + ): + """Tests SVD on matrix of all zeros, +/-infinity or NaN.""" + lax_fun = functools.partial( + svd.svd, full_matrices=full_matrices, compute_uv=compute_uv + ) + args_maker_svd = lambda: [ + jnp.full((m, n), fill_value=fill_value, dtype=dtype) + ] + result = lax_fun(args_maker_svd()[0]) + for r in result: + self.assertTrue(jnp.all(jnp.isnan(r))) + self._CompileAndCheck(lax_fun, args_maker_svd) @jtu.sample_product( [dict(m=m, n=n, r=r, c=c)