Skip to content

Commit

Permalink
Consolidate compute calls in svd test
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Sep 2, 2015
1 parent 1019c88 commit 24977ef
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions dask/array/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,16 @@ def test_linalg_consistent_names():
def test_svd_compressed():
m, n = 300, 250
r = 10
np.random.seed(1234)
np.random.seed(4321)
mat1 = np.random.randn(m, r)
mat2 = np.random.randn(r, n)
mat = mat1.dot(mat2)
data = da.from_array(mat, chunks=(50, 50))

n_iter = 6
for i in range(n_iter):
u, s, vt = svd_compressed(data, r, seed=1234)
u = np.array(u)
s = np.array(s)
vt = np.array(vt)
u, s, vt = svd_compressed(data, r, seed=4321)
u, s, vt = da.compute(u, s, vt)
if i == 0:
usvt = np.dot(u, np.dot(np.diag(s), vt))
else:
Expand All @@ -128,10 +126,11 @@ def test_svd_compressed():
np.linalg.norm(mat),
rtol=tol, atol=tol) # average accuracy check

u, s, vt = svd_compressed(data, r, seed=1234)
u = np.array(u)[:, :r]
s = np.array(s)[:r]
vt = np.array(vt)[:r, :]
u, s, vt = svd_compressed(data, r, seed=4321)
u, s, vt = da.compute(u, s, vt)
u = u[:, :r]
s = s[:r]
vt = vt[:r, :]

s_exact = np.linalg.svd(mat)[1]
s_exact = s_exact[:r]
Expand Down

0 comments on commit 24977ef

Please sign in to comment.