Skip to content

Commit

Permalink
Fix serialize issue when submitting calculation on parts of svd outpu…
Browse files Browse the repository at this point in the history
…ts (#43)

* try to fix serialize issue when submitting calculation on parts of multiple outputs like svd

* fix stop graph failed due to failure of serialization
  • Loading branch information
qinxuye authored and wjsi committed Dec 18, 2018
1 parent 14172de commit 45f9166
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
7 changes: 7 additions & 0 deletions mars/deploy/local/tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ def testMultipleOutputTensorExecute(self):
np.testing.assert_allclose(U_result, U_expected)
np.testing.assert_allclose(s_result, s_expectd)

with new_session(cluster.endpoint) as session2:
U_result, s_result = session2.run(U + 1, s + 1)
U_expected, s_expectd, _ = np.linalg.svd(raw, full_matrices=False)

np.testing.assert_allclose(U_result, U_expected + 1)
np.testing.assert_allclose(s_result, s_expectd + 1)

def testGraphFail(self):
op = SerializeMustFailOperand(f=3)
tensor = op.new_tensor(None, (3, 3))
Expand Down
3 changes: 2 additions & 1 deletion mars/tensor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ def build_graph(self, graph=None, cls=DAG, tiled=False, compose=True):
graph.add_node(c)
if not graph.has_successor(c, chunk):
graph.add_edge(c, chunk)
nodes.extend([c for c in children if c not in visited])
nodes.extend([c for c in itertools.chain(*[inp.op.outputs for inp in chunk.inputs or []])
if c not in visited])
if tiled and compose:
graph.compose(keys=keys)
return graph
Expand Down
10 changes: 10 additions & 0 deletions mars/tensor/expressions/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def testSVD(self):
self.assertEqual(len(new_outputs), 3)
self.assertEqual(len(set([o.op for o in new_outputs])), 1)

# test tensor graph, do some caculation
graph = DirectedGraph()
(U + 1).build_graph(tiled=False, graph=graph)
(s + 1).build_graph(tiled=False, graph=graph)
new_graph = DirectedGraph.from_json(graph.to_json())
self.assertEqual((len(new_graph)), 6)
new_outputs = [n for n in new_graph if new_graph.count_predecessors(n) == 1]
self.assertEqual(len(new_outputs), 5)
self.assertEqual(len(set([o.op for o in new_outputs])), 3)

def testLU(self):
a = mt.random.randint(1, 10, (6, 6), chunks=3)
p, l, u = mt.linalg.lu(a)
Expand Down

0 comments on commit 45f9166

Please sign in to comment.