Skip to content

Commit

Permalink
Switch the graph into FAILED when preparation failed (#24)
Browse files Browse the repository at this point in the history
* switch the graph into FAILED when there are errors in the preparation stage

* add tests on web
  • Loading branch information
wjsi authored and hekaisheng committed Dec 12, 2018
1 parent 7db71f6 commit 298c56b
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 29 deletions.
14 changes: 10 additions & 4 deletions mars/scheduler/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,16 @@ def execute_graph(self):
self._start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
self.state = GraphState.PREPARING

self.prepare_graph()
self.scan_node()
self.place_initial_chunks()
self.create_operand_actors()
try:
self.prepare_graph()
self.scan_node()
self.place_initial_chunks()
self.create_operand_actors()
except:
logger.exception('Failed to start graph execution.')
self.stop_graph()
self.state = GraphState.FAILED
raise

def stop_graph(self):
"""
Expand Down
44 changes: 24 additions & 20 deletions mars/scheduler/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ def check_process_statuses(self):
if self.proc_worker.poll() is not None:
raise SystemError('Worker not started. exit code %s' % self.proc_worker.poll())

def wait_for_termination(self, session_ref, graph_key):
check_time = time.time()
while True:
time.sleep(1)
self.check_process_statuses()
if time.time() - check_time > 60:
raise SystemError('Check graph status timeout')
if session_ref.graph_state(graph_key) in GraphState.TERMINATED_STATES:
return session_ref.graph_state(graph_key)

def testMain(self):
session_id = uuid.uuid1()
scheduler_address = '127.0.0.1:' + self.scheduler_port
Expand All @@ -142,20 +152,21 @@ def testMain(self):
session_ref.submit_tensor_graph(json.dumps(graph.to_json()),
graph_key, target_tensors=targets)

check_time = time.time()
while True:
time.sleep(1)
self.check_process_statuses()
if time.time() - check_time > 60:
raise SystemError('Check graph status timeout')
if session_ref.graph_state(graph_key) == GraphState.SUCCEEDED:
result = session_ref.fetch_result(graph_key, c.key)
break
state = self.wait_for_termination(session_ref, graph_key)
self.assertEqual(state, GraphState.SUCCEEDED)

result = session_ref.fetch_result(graph_key, c.key)
expected = (np.ones(a.shape) * 2 * 1 + 1) ** 2 * 2 + 1

assert_array_equal(loads(result), expected.sum())

graph_key = uuid.uuid1()
session_ref.submit_tensor_graph(json.dumps(graph.to_json()),
graph_key, target_tensors=targets)

# todo this behavior may change when eager mode is introduced
state = self.wait_for_termination(session_ref, graph_key)
self.assertEqual(state, GraphState.FAILED)

a = ones((100, 50), chunks=30) * 2 + 1
b = ones((50, 200), chunks=30) * 2 + 1
c = a.dot(b)
Expand All @@ -165,14 +176,7 @@ def testMain(self):
session_ref.submit_tensor_graph(json.dumps(graph.to_json()),
graph_key, target_tensors=targets)

check_time = time.time()
while True:
time.sleep(1)
self.check_process_statuses()
if time.time() - check_time > 60:
raise SystemError('Check graph status timeout')
if session_ref.graph_state(graph_key) == GraphState.SUCCEEDED:
result = session_ref.fetch_result(graph_key, c.key)
break

state = self.wait_for_termination(session_ref, graph_key)
self.assertEqual(state, GraphState.SUCCEEDED)
result = session_ref.fetch_result(graph_key, c.key)
assert_array_equal(loads(result), np.ones((100, 200)) * 450)
13 changes: 8 additions & 5 deletions mars/web/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,14 @@ def run(self, tensors, compose=True, wait=True, timeout=-1):
raise ExecutionInterrupted
elif resp_json['state'] == 'failed':
# TODO add traceback
traceback = resp_json['traceback']
if isinstance(traceback, list):
traceback = ''.join(str(s) for s in traceback)
raise SystemError('Graph execution failed.\nMessage: %s\nTraceback from server:\n%s' %
(resp_json['msg'], traceback))
if 'traceback' in resp_json:
traceback = resp_json['traceback']
if isinstance(traceback, list):
traceback = ''.join(str(s) for s in traceback)
raise SystemError('Graph execution failed.\nMessage: %s\nTraceback from server:\n%s' %
(resp_json['msg'], traceback))
else:
raise SystemError('Graph execution failed with unknown reason.')
else:
raise SystemError('Unknown graph execution state %s' % resp_json['state'])
except KeyboardInterrupt:
Expand Down
4 changes: 4 additions & 0 deletions mars/web/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def testApi(self):
value = sess.run(c)
assert_array_equal(value[0], np.ones((100, 100)) * 100)

# todo this behavior may change when eager mode is introduced
with self.assertRaises(SystemError):
sess.run(c)

va = np.random.randint(0, 10000, (100, 100))
vb = np.random.randint(0, 10000, (100, 100))
a = mt.array(va, chunks=30)
Expand Down

0 comments on commit 298c56b

Please sign in to comment.