Skip to content

Commit

Permalink
Wait for graph to finish instead of querying with fixed intervals (#701)
Browse files Browse the repository at this point in the history
* Allow querying worker meta via api

* Wait for graph finish event instead of querying using fixed intervals
  • Loading branch information
wjsi authored and Xuye (Chris) Qin committed Sep 4, 2019
1 parent 163950a commit ea9fc15
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -77,6 +77,7 @@ mars/*.c*
mars/lib/*.c*
mars/actors/*.c*
mars/actors/pool/*.c*
mars/optimizes/**/*.c*
mars/scheduler/*.c*
mars/serialize/*.c*
mars/worker/*.c*
8 changes: 8 additions & 0 deletions mars/actors/pool/gevent_pool.pyx
Expand Up @@ -185,6 +185,10 @@ cdef class ActorContext:
return self._comm.tell(actor_ref, message, delay=delay,
wait=wait, callback=callback)

@staticmethod
def event():
    """Create and return a new ``gevent.event.Event`` object."""
    new_event = gevent.event.Event()
    return new_event

@staticmethod
def sleep(seconds):
    """Pause by delegating to ``gevent.sleep``.

    :param seconds: value forwarded unchanged to ``gevent.sleep``
    """
    gevent.sleep(seconds)
Expand Down Expand Up @@ -1655,6 +1659,10 @@ cdef class ActorClient:
raise ValueError('address must be provided')
return ref

@staticmethod
def event():
    """Create and return a new ``gevent.event.Event`` object."""
    new_event = gevent.event.Event()
    return new_event

@staticmethod
def sleep(seconds):
    """Pause by delegating to ``gevent.sleep``.

    :param seconds: value forwarded unchanged to ``gevent.sleep``
    """
    gevent.sleep(seconds)
Expand Down
4 changes: 4 additions & 0 deletions mars/api.py
Expand Up @@ -147,6 +147,10 @@ def get_graph_state(self, session_id, graph_key):
state = GraphState(state.lower())
return state

def wait_graph_finish(self, session_id, graph_key, timeout=None):
    """Call ``wait(timeout)`` on the wait actor attached to a graph.

    The graph meta ref is looked up with ``get_graph_meta_ref``; the ref it
    returns from ``get_wait_ref()`` is resolved through
    ``self.actor_client.actor_ref`` and its ``wait`` method is invoked.

    :param session_id: session identifier passed to ``get_graph_meta_ref``
    :param graph_key: graph key passed to ``get_graph_meta_ref``
    :param timeout: value forwarded to the wait actor's ``wait`` call
    """
    meta_ref = self.get_graph_meta_ref(session_id, graph_key)
    wait_ref = self.actor_client.actor_ref(meta_ref.get_wait_ref())
    wait_ref.wait(timeout)

def fetch_data(self, session_id, graph_key, tileable_key, index_obj=None, compressions=None):
graph_uid = GraphActor.gen_uid(session_id, graph_key)
graph_ref = self.get_actor_ref(graph_uid)
Expand Down
2 changes: 2 additions & 0 deletions mars/deploy/local/core.py
Expand Up @@ -139,6 +139,8 @@ def stop_service(self):
def serve_forever(self):
    """Block on ``self._pool.join()`` and always stop the service afterwards.

    A ``KeyboardInterrupt`` raised while joining is swallowed; any other
    exception propagates after ``stop_service()`` has been called.
    """
    try:
        self._pool.join()
    except KeyboardInterrupt:
        # an interrupt simply ends the wait; cleanup happens in ``finally``
        pass
    finally:
        self.stop_service()

Expand Down
8 changes: 5 additions & 3 deletions mars/deploy/local/session.py
Expand Up @@ -127,15 +127,17 @@ def run(self, *tileables, **kw):
graph_key, targets, compose=compose)

exec_start_time = time.time()
while timeout <= 0 or time.time() - exec_start_time <= timeout:
time.sleep(0.1)

time_elapsed = 0
while timeout <= 0 or time_elapsed < timeout:
timeout_val = min(5, timeout - time_elapsed) if timeout > 0 else 5
self._api.wait_graph_finish(self._session_id, graph_key, timeout=timeout_val)
graph_state = self._api.get_graph_state(self._session_id, graph_key)
if graph_state == GraphState.SUCCEEDED:
break
if graph_state == GraphState.FAILED:
# TODO(qin): add traceback
raise ExecutionFailed('Graph execution failed with unknown reason')
time_elapsed = time.time() - exec_start_time

if 0 < timeout < time.time() - exec_start_time:
raise TimeoutError
Expand Down
36 changes: 36 additions & 0 deletions mars/scheduler/graph.py
Expand Up @@ -60,12 +60,15 @@ def __init__(self, session_id, graph_key):
self._graph_key = graph_key

self._kv_store_ref = None
self._graph_wait_ref = None

self._start_time = None
self._end_time = None
self._state = None
self._final_state = None

self._graph_finish_event = None

self._op_infos = defaultdict(dict)
self._state_to_infos = defaultdict(dict)

Expand All @@ -75,6 +78,20 @@ def post_create(self):
if not self.ctx.has_actor(self._kv_store_ref):
self._kv_store_ref = None

self._graph_finish_event = self.ctx.event()
graph_wait_uid = GraphWaitActor.gen_uid(self._session_id, self._graph_key)
try:
graph_wait_uid = self.ctx.distributor.make_same_process(
graph_wait_uid, self.uid)
except AttributeError:
pass
self._graph_wait_ref = self.ctx.create_actor(
GraphWaitActor, self._graph_finish_event, uid=graph_wait_uid)

def pre_destroy(self):
    """Destroy the graph wait ref, then run the parent ``pre_destroy`` hook."""
    wait_ref = self._graph_wait_ref
    wait_ref.destroy()
    super(GraphMetaActor, self).pre_destroy()

def get_graph_info(self):
    """Return ``(start_time, end_time, number_of_op_infos)`` for the graph."""
    started, ended = self._start_time, self._end_time
    return started, ended, len(self._op_infos)

Expand All @@ -83,6 +100,7 @@ def set_graph_start(self):

def set_graph_end(self):
    """Record the current time as the end time and set the finish event."""
    finished_at = time.time()
    self._end_time = finished_at
    # setting the event releases whoever is blocked in its ``wait``
    self._graph_finish_event.set()

def set_state(self, state):
self._state = state
Expand All @@ -93,6 +111,9 @@ def set_state(self, state):
def get_state(self):
    """Return the value currently stored in ``self._state``."""
    current_state = self._state
    return current_state

def get_wait_ref(self):
    """Return the ref stored in ``self._graph_wait_ref``."""
    wait_ref = self._graph_wait_ref
    return wait_ref

def set_final_state(self, state):
self._final_state = state
if self._kv_store_ref is not None:
Expand Down Expand Up @@ -168,6 +189,19 @@ def calc_stats(self):
return ops, transposed, percentage


class GraphWaitActor(SchedulerActor):
    """Actor that exposes a blocking ``wait`` on an event object it is given."""

    @staticmethod
    def gen_uid(session_id, graph_key):
        """Build the actor uid from the session id and the graph key."""
        uid_args = (session_id, graph_key)
        return 's:0:graph_wait$%s$%s' % uid_args

    def __init__(self, graph_event):
        """Store the event object that ``wait`` will block on.

        :param graph_event: object providing a ``wait(timeout)`` method
        """
        super(GraphWaitActor, self).__init__()
        self._graph_event = graph_event

    def wait(self, timeout=None):
        """Call ``wait(timeout)`` on the stored event; the result is discarded."""
        event = self._graph_event
        event.wait(timeout)


class GraphActor(SchedulerActor):
"""
Actor handling execution and status of a Mars graph
Expand Down Expand Up @@ -329,6 +363,7 @@ def _detect_cancel(callback=None):
logger.exception('Failed to start graph execution.')
self.stop_graph()
self.state = GraphState.FAILED
self._graph_meta_ref.set_graph_end(_tell=True)
raise

if len(self._chunk_graph_cache) == 0:
Expand Down Expand Up @@ -365,6 +400,7 @@ def stop_graph(self):
ref.stop_operand(_tell=True)
if not has_stopping:
self.state = GraphState.CANCELLED
self._graph_meta_ref.set_graph_end(_tell=True)

@log_unhandled
def reload_chunk_graph(self):
Expand Down
1 change: 1 addition & 0 deletions mars/tensor/__init__.py
Expand Up @@ -48,6 +48,7 @@
var, std, nanvar, nanstd, nancumsum, nancumprod, count_nonzero, allclose, array_equal
from .reshape import reshape
from .merge import concatenate, stack, hstack, vstack, dstack, column_stack
from .indexing import take, compress, extract, choose, unravel_index, nonzero, flatnonzero
from .rechunk import rechunk
from .lib.index_tricks import mgrid, ogrid, ndindex

Expand Down
27 changes: 21 additions & 6 deletions mars/web/apihandlers.py
Expand Up @@ -128,8 +128,20 @@ class GraphApiHandler(MarsApiRequestHandler):
@gen.coroutine
def get(self, session_id, graph_key):
from ..scheduler.utils import GraphState
wait_timeout = int(self.get_argument('wait_timeout', None))

try:
if wait_timeout:
executor = futures.ThreadPoolExecutor(1)
if wait_timeout <= 0:
wait_timeout = None

def _wait_fun():
web_api = MarsWebAPI(self._scheduler)
return web_api.wait_graph_finish(session_id, graph_key, wait_timeout)

_ = yield executor.submit(_wait_fun)

state = self.web_api.get_graph_state(session_id, graph_key)
except GraphNotExists:
raise web.HTTPError(404, 'Graph not exists')
Expand All @@ -154,7 +166,7 @@ def delete(self, session_id, graph_key):
self._dump_exception(sys.exc_info(), 404)


class GraphDataHandler(MarsApiRequestHandler):
class GraphDataApiHandler(MarsApiRequestHandler):
@gen.coroutine
def get(self, session_id, graph_key, tileable_key):
data_type = self.get_argument('type', None)
Expand Down Expand Up @@ -191,11 +203,14 @@ def delete(self, session_id, graph_key, tileable_key):

class WorkersApiHandler(MarsApiRequestHandler):
    """Web API handler answering GET queries about workers."""

    def get(self):
        """Write worker information to the response as a single JSON document.

        When the ``action`` query argument equals ``'count'`` the result of
        ``web_api.count_workers()`` is written. For any other value, or when
        the argument is absent, the result of ``web_api.get_workers_meta()``
        is written instead.
        """
        # Exactly one JSON document is written per request: an extra
        # unconditional write of the worker count ahead of this dispatch
        # would concatenate two documents into an unparsable body.
        action = self.get_argument('action', None)
        if action == 'count':
            payload = self.web_api.count_workers()
        else:
            payload = self.web_api.get_workers_meta()
        self.write(json.dumps(payload))


class MutableTensorHandler(MarsApiRequestHandler):
class MutableTensorApiHandler(MarsApiRequestHandler):
def get(self, session_id, name):
try:
meta = self.web_api.get_mutable_tensor(session_id, name)
Expand Down Expand Up @@ -239,5 +254,5 @@ def put(self, session_id, name):
register_web_handler('/api/session/(?P<session_id>[^/]+)/graph', GraphsApiHandler)
register_web_handler('/api/session/(?P<session_id>[^/]+)/graph/(?P<graph_key>[^/]+)', GraphApiHandler)
register_web_handler('/api/session/(?P<session_id>[^/]+)/graph/(?P<graph_key>[^/]+)/data/(?P<tileable_key>[^/]+)',
GraphDataHandler)
register_web_handler('/api/session/(?P<session_id>[^/]+)/mutable-tensor/(?P<name>[^/]+)', MutableTensorHandler)
GraphDataApiHandler)
register_web_handler('/api/session/(?P<session_id>[^/]+)/mutable-tensor/(?P<name>[^/]+)', MutableTensorApiHandler)
18 changes: 11 additions & 7 deletions mars/web/session.py
Expand Up @@ -93,10 +93,10 @@ def _set_tileable_graph_key(self, tileable, graph_key):
else:
self._executed_tileables[tileable_key] = graph_key, {tileable_id}

def _check_response_finished(self, graph_url):
def _check_response_finished(self, graph_url, timeout=None):
import requests
try:
resp = self._req_session.get(graph_url)
resp = self._req_session.get(graph_url, params={'wait_timeout': timeout})
except requests.ConnectionError as ex:
err_msg = str(ex)
if 'ConnectionResetError' in err_msg or 'Connection refused' in err_msg:
Expand Down Expand Up @@ -153,26 +153,30 @@ def run(self, *tileables, **kw):

resp_json = self._submit_graph(graph_json, targets_join, compose=compose)
graph_key = resp_json['graph_key']
graph_url = session_url + '/graph/' + graph_key
graph_url = '%s/graph/%s' % (session_url, graph_key)

for t in tileables:
self._set_tileable_graph_key(t, graph_key)

exec_start_time = time.time()
while timeout <= 0 or time.time() - exec_start_time <= timeout:
time_elapsed = 0
while timeout <= 0 or time_elapsed < timeout:
timeout_val = min(5, timeout - time_elapsed) if timeout > 0 else 5
try:
time.sleep(1)
if self._check_response_finished(graph_url):
if self._check_response_finished(graph_url, timeout_val):
break
except KeyboardInterrupt:
resp = self._req_session.delete(graph_url)
if resp.status_code >= 400:
raise ExecutionNotStopped(
'Failed to stop graph execution. Code: %d, Reason: %s, Content:\n%s' %
(resp.status_code, resp.reason, resp.text))
finally:
time_elapsed = time.time() - exec_start_time

if 0 < timeout < time.time() - exec_start_time:
raise TimeoutError

if not fetch:
return
else:
Expand Down Expand Up @@ -384,7 +388,7 @@ def check_service_ready(self, timeout=1):
return True

def count_workers(self):
    """Return the decoded JSON body of the worker-count endpoint.

    A single GET is sent to ``<endpoint>/api/worker?action=count`` with a
    one-second timeout, and the response text is parsed as JSON.
    """
    # Only the ``action=count`` URL is requested; a preliminary request to
    # the bare ``/api/worker`` URL would be a wasted round trip whose
    # response is never used.
    count_url = self._endpoint + '/api/worker?action=count'
    resp = self._req_session.get(count_url, timeout=1)
    return json.loads(resp.text)

def get_task_count(self):
Expand Down
10 changes: 9 additions & 1 deletion mars/web/tests/test_api.py
Expand Up @@ -257,6 +257,14 @@ def normalize_tbs(tb_lines):

service_ep = 'http://127.0.0.1:' + self.web_port

# query worker info
res = requests.get('%s/api/worker' % service_ep)
self.assertEqual(res.status_code, 200)
self.assertEqual(len(json.loads(res.text)), 1)
res = requests.get('%s/api/worker?action=count' % service_ep)
self.assertEqual(res.status_code, 200)
self.assertEqual(int(res.text), 1)

# raise on malicious python version
res = requests.post('%s/api/session' % service_ep, dict(pyver='mal.version'))
self.assertEqual(res.status_code, 400)
Expand Down Expand Up @@ -319,7 +327,7 @@ def data(self, data):
@staticmethod
def mocked_requests_get(*arg, **_):
url = arg[0]
if url.endswith('worker'):
if '/worker' in url:
return MockResponse(200, json_text=1)
if url.split('/')[-2] == 'graph':
return MockResponse(200, json_text={"state": 'success'})
Expand Down

0 comments on commit ea9fc15

Please sign in to comment.