Skip to content

Commit

Permalink
fix web session when submit multiply tensors (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
hekaisheng authored and qinxuye committed Dec 17, 2018
1 parent e55c291 commit 8693310
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
13 changes: 8 additions & 5 deletions mars/web/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..compat import six, TimeoutError
from ..serialize import dataserializer
from ..errors import ExecutionInterrupted
from ..graph import DirectedGraph

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,12 +52,14 @@ def _main(self):
content = json.loads(resp.text)
self._session_id = content['session_id']

def run(self, tensors, compose=True, wait=True, timeout=-1):
from ..graph import DirectedGraph
graph = DirectedGraph()
def run(self, *tensors, **kw):
timeout = kw.pop('timeout', -1)
compose = kw.pop('compose', True)
wait = kw.pop('wait', True)
if kw:
raise TypeError('run got unexpected key arguments {0}'.format(', '.join(kw.keys())))

if not isinstance(tensors, (list, tuple, set)):
tensors = [tensors]
graph = DirectedGraph()
for t in tensors:
graph = t.build_graph(graph=graph, tiled=False, compose=compose)
targets = [t.key for t in tensors]
Expand Down
5 changes: 5 additions & 0 deletions mars/web/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ def testApi(self):
value = sess.run(c, timeout=120)
assert_array_equal(value, va.dot(vb))

# test test multiple outputs
a = mt.random.rand(10, 10)
U, s, V, raw = sess.run(list(mt.linalg.svd(a)) + [a])
np.testing.assert_allclose(U.dot(np.diag(s).dot(V)), raw)

# check web UI requests
res = requests.get(service_ep)
self.assertEqual(res.status_code, 200)
Expand Down

0 comments on commit 8693310

Please sign in to comment.