Skip to content

Commit

Permalink
Respect dtype in svd compression_matrix #2849 (#6802)
Browse files Browse the repository at this point in the history
This prevents from promoting datatypes of less then 4 bytes precision to promote to 8 bytes as explained in #2849. This also improves the performances drastically as the following QR and SVD will be faster.
  • Loading branch information
RogerMoens committed Nov 12, 2020
1 parent 061613e commit aeb7815
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
5 changes: 4 additions & 1 deletion dask/array/linalg.py
Expand Up @@ -675,9 +675,12 @@ def compression_matrix(data, q, n_power_iter=0, seed=None, compute=False):
state = seed
else:
state = RandomState(seed)
datatype = np.float64
if (data.dtype).type in {np.float32, np.complex64}:
datatype = np.float32
omega = state.standard_normal(
size=(n, comp_level), chunks=(data.chunks[1], (comp_level,))
)
).astype(datatype, copy=False)
mat_h = data.dot(omega)
for j in range(n_power_iter):
if compute:
Expand Down
9 changes: 9 additions & 0 deletions dask/array/tests/test_linalg.py
Expand Up @@ -480,6 +480,15 @@ def test_svd_compressed():
assert_eq(s, s_exact) # s must contain the singular values


@pytest.mark.parametrize(
"input_dtype, output_dtype", [(np.float32, np.float32), (np.float64, np.float64)]
)
def test_svd_compressed_dtype_preservation(input_dtype, output_dtype):
x = da.random.random((50, 50), chunks=(50, 50)).astype(input_dtype)
u, s, vt = svd_compressed(x, 1, seed=4321)
assert u.dtype == s.dtype == vt.dtype == output_dtype


@pytest.mark.parametrize("chunks", [(10, 50), (50, 10), (-1, -1)])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_svd_dtype_preservation(chunks, dtype):
Expand Down

0 comments on commit aeb7815

Please sign in to comment.