Skip to content

Commit

Permalink
fix chunk index of tensor of singular values calculated by svd decomp…
Browse files Browse the repository at this point in the history
…osition (#58)
  • Loading branch information
qinxuye authored and hekaisheng committed Dec 20, 2018
1 parent 2ff2cfe commit 599819f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
10 changes: 6 additions & 4 deletions mars/tensor/expressions/linalg/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ def tile(cls, op):
s_shape = (stage2_r_chunk.shape[1],)
v_shape = (stage2_r_chunk.shape[1],) * 2
stage2_usv_chunks = svd_op.new_chunks([stage2_r_chunk], [u_shape, s_shape, v_shape],
index=stage2_r_chunk.index,
kws=[{'side': 'U', 'dtype': U_dtype},
{'side': 's', 'dtype': s_dtype},
{'side': 'V', 'dtype': V_dtype}])
kws=[{'side': 'U', 'dtype': U_dtype,
'index': stage2_r_chunk.index},
{'side': 's', 'dtype': s_dtype,
'index': stage2_r_chunk.index[1:]},
{'side': 'V', 'dtype': V_dtype,
'index': stage2_r_chunk.index}])
stage2_u_chunk, stage2_s_chunk, stage2_v_chunk = stage2_usv_chunks

# stage 4, U = Q @ u
Expand Down
10 changes: 6 additions & 4 deletions mars/tensor/expressions/linalg/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,13 @@ def tile(cls, op):
in_chunk = in_tensor.chunks[0]
chunk_op = op.copy().reset_key()
svd_chunks = chunk_op.new_chunks([in_chunk], (U_shape, s_shape, V_shape),
index=in_chunk.index,
kws=[
{'side': 'U', 'dtype': U_dtype},
{'side': 's', 'dtype': s_dtype},
{'side': 'V', 'dtype': V_dtype}
{'side': 'U', 'dtype': U_dtype,
'index': in_chunk.index},
{'side': 's', 'dtype': s_dtype,
'index': in_chunk.index[1:]},
{'side': 'V', 'dtype': V_dtype,
'index': in_chunk.index}
])
U_chunk, s_chunk, V_chunk = svd_chunks

Expand Down
18 changes: 18 additions & 0 deletions mars/tensor/expressions/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,24 @@ def testSVD(self):
self.assertEqual(len(s.chunks), 1)
self.assertEqual(len(V.chunks), 1)

self.assertEqual(s.ndim, 1)
self.assertEqual(len(s.chunks[0].index), 1)

a = mt.random.rand(9, 6, chunks=(9, 6))
U, s, V = mt.linalg.svd(a)

self.assertEqual(U.shape, (9, 6))
self.assertEqual(s.shape, (6,))
self.assertEqual(V.shape, (6, 6))

U.tiles()
self.assertEqual(len(U.chunks), 1)
self.assertEqual(len(s.chunks), 1)
self.assertEqual(len(V.chunks), 1)

self.assertEqual(s.ndim, 1)
self.assertEqual(len(s.chunks[0].index), 1)

rs = mt.random.RandomState(1)
a = rs.rand(9, 6, chunks=(3, 6))
U, s, V = mt.linalg.svd(a)
Expand Down

0 comments on commit 599819f

Please sign in to comment.