
add retries flag to LoadBalancedView

also add some lbv tests, and related fixes

closes gh-412
1 parent 21b0f4c commit 6549d091655829eb6487b2bfa7134f7f7483025e @minrk committed May 4, 2011
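
For context, a minimal usage sketch of the new flag from the client side (assumes a running IPython.parallel cluster; the flaky task function is illustrative, not from this commit):

from IPython.parallel import Client

rc = Client()
lbv = rc.load_balanced_view()

# retry each failing task up to 3 times before propagating the error
lbv.set_flags(retries=3)

def flaky():
    import random
    if random.random() < 0.5:
        raise RuntimeError("transient failure")
    return "ok"

ar = lbv.apply_async(flaky)
print(ar.get())  # raises only once all retries are exhausted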
@@ -19,7 +19,7 @@
import zmq
from IPython.testing import decorators as testdec
-from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat
+from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat, CInt
from IPython.external.decorator import decorator
@@ -791,9 +791,10 @@ class LoadBalancedView(View):
follow=Any()
after=Any()
timeout=CFloat()
+ retries = CInt(0)
_task_scheme = Any()
- _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout'])
+ _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
def __init__(self, client=None, socket=None, **flags):
super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
@@ -851,7 +852,7 @@ def set_flags(self, **kwargs):
whether to create a MessageTracker to allow the user to
safely edit after arrays and buffers during non-copying
sends.
- #
+
after : Dependency or collection of msg_ids
Only for load-balanced execution (targets=None)
Specify a list of msg_ids as a time-based dependency.
@@ -869,6 +870,9 @@ def set_flags(self, **kwargs):
Specify an amount of time (in seconds) for the scheduler to
wait for dependencies to be met before failing with a
DependencyTimeout.
+
+ retries : int
+ Number of times a task will be retried on failure.
"""
super(LoadBalancedView, self).set_flags(**kwargs)
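
To illustrate the flags documented above, a short hedged sketch: `set_flags` persists on the view, while `temp_flags` (assumed to be available on View in this version) scopes the change to a with-block:

# persistent: every later apply on this view may be retried twice
lbv.set_flags(retries=2, timeout=10.0)

# scoped: only tasks submitted inside the block get 5 retries
with lbv.temp_flags(retries=5):
    ar = lbv.apply_async(flaky)  # flaky as in the sketch above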
@@ -892,7 +896,7 @@ def set_flags(self, **kwargs):
@save_ids
def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
after=None, follow=None, timeout=None,
- targets=None):
+ targets=None, retries=None):
"""calls f(*args, **kwargs) on a remote engine, returning the result.
This method temporarily sets all of `apply`'s flags for a single call.
@@ -933,10 +937,11 @@ def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
raise RuntimeError(msg)
if self._task_scheme == 'pure':
- # pure zmq scheme doesn't support dependencies
- msg = "Pure ZMQ scheduler doesn't support dependencies"
- if (follow or after):
- # hard fail on DAG dependencies
+ # pure zmq scheme doesn't support extra features
+ msg = "Pure ZMQ scheduler doesn't support the following flags:"
+ "follow, after, retries, targets, timeout"
+ if (follow or after or retries or targets or timeout):
+ # hard fail on Scheduler flags
raise RuntimeError(msg)
if isinstance(f, dependent):
# soft warn on functional dependencies
@@ -948,18 +953,22 @@ def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
block = self.block if block is None else block
track = self.track if track is None else track
after = self.after if after is None else after
+ retries = self.retries if retries is None else retries
follow = self.follow if follow is None else follow
timeout = self.timeout if timeout is None else timeout
targets = self.targets if targets is None else targets
+ if not isinstance(retries, int):
+ raise TypeError('retries must be int, not %r'%type(retries))
+
if targets is None:
idents = []
else:
idents = self.client._build_targets(targets)[0]
after = self._render_dependency(after)
follow = self._render_dependency(follow)
- subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
+ subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
subheader=subheader)
@@ -137,6 +137,7 @@ class TaskScheduler(SessionFactory):
# internals:
graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
+ retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
# waiting = List() # list of msg_ids ready to run, but haven't due to HWM
depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
pending = Dict() # dict by engine_uuid of submitted tasks
@@ -205,6 +206,8 @@ def _register_engine(self, uid):
self.pending[uid] = {}
if len(self.targets) == 1:
self.resume_receiving()
+ # rescan the graph:
+ self.update_graph(None)
def _unregister_engine(self, uid):
"""Existing engine with ident `uid` became unavailable."""
@@ -215,11 +218,11 @@ def _unregister_engine(self, uid):
# handle any potentially finished tasks:
self.engine_stream.flush()
- self.completed.pop(uid)
- self.failed.pop(uid)
- # don't pop destinations, because it might be used later
+ # don't pop destinations, because they might be used later
# map(self.destinations.pop, self.completed.pop(uid))
# map(self.destinations.pop, self.failed.pop(uid))
+
+ # prevent this engine from receiving work
idx = self.targets.index(uid)
self.targets.pop(idx)
self.loads.pop(idx)
@@ -229,28 +232,40 @@ def _unregister_engine(self, uid):
if self.pending[uid]:
dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
dc.start()
+ else:
+ self.completed.pop(uid)
+ self.failed.pop(uid)
+
@logged
def handle_stranded_tasks(self, engine):
"""Deal with jobs resident in an engine that died."""
- lost = self.pending.pop(engine)
-
- for msg_id, (raw_msg, targets, MET, follow, timeout) in lost.iteritems():
- self.all_failed.add(msg_id)
- self.all_done.add(msg_id)
+ lost = self.pending[engine]
+ for msg_id in lost.keys():
+ if msg_id not in self.pending[engine]:
+ # prevent double-handling of messages
+ continue
+
+ raw_msg = lost[msg_id][0]
+
idents,msg = self.session.feed_identities(raw_msg, copy=False)
msg = self.session.unpack_message(msg, copy=False, content=False)
parent = msg['header']
- idents = [idents[0],engine]+idents[1:]
- # print (idents)
+ idents = [engine, idents[0]]
+
+ # build fake error reply
try:
raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
except:
content = error.wrap_exception()
- msg = self.session.send(self.client_stream, 'apply_reply', content,
- parent=parent, ident=idents)
- self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
- self.update_graph(msg_id)
+ msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
+ raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
+ # and dispatch it
+ self.dispatch_result(raw_reply)
+
+ # finally scrub completed/failed lists
+ self.completed.pop(engine)
+ self.failed.pop(engine)
#-----------------------------------------------------------------------
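
With this change, a task stranded on a dead engine comes back to the client as an ordinary error reply instead of vanishing. A hedged sketch of what the caller would see (the fabricated reply presumably surfaces as a RemoteError wrapping the EngineError; the exact exception type is an assumption here):

from IPython.parallel import error

ar = lbv.apply_async(some_long_task)   # suppose the engine dies mid-task
try:
    ar.get()
except error.RemoteError as e:
    # the scheduler fabricated this reply on the dead engine's behalf
    print("task failed remotely:", e.ename, e.evalue)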
@@ -277,6 +292,8 @@ def dispatch_submission(self, raw_msg):
# targets
targets = set(header.get('targets', []))
+ retries = header.get('retries', 0)
+ self.retries[msg_id] = retries
# time dependencies
after = Dependency(header.get('after', []))
@@ -315,7 +332,9 @@ def dispatch_submission(self, raw_msg):
# time deps already met, try to run
if not self.maybe_run(msg_id, *args):
# can't run yet
- self.save_unmet(msg_id, *args)
+ if msg_id not in self.all_failed:
+ # could have failed as unreachable
+ self.save_unmet(msg_id, *args)
else:
self.save_unmet(msg_id, *args)
@@ -328,7 +347,7 @@ def audit_timeouts(self):
if msg_id in self.depending:
raw,after,targets,follow,timeout = self.depending[msg_id]
if timeout and timeout < now:
- self.fail_unreachable(msg_id, timeout=True)
+ self.fail_unreachable(msg_id, error.TaskTimeout)
@logged
def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
@@ -369,7 +388,7 @@ def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
# we need a can_run filter
def can_run(idx):
# check hwm
- if self.loads[idx] == self.hwm:
+ if self.hwm and self.loads[idx] == self.hwm:
return False
target = self.targets[idx]
# check blacklist
@@ -382,6 +401,7 @@ def can_run(idx):
return follow.check(self.completed[target], self.failed[target])
indices = filter(can_run, range(len(self.targets)))
+
if not indices:
# couldn't run
if follow.all:
@@ -395,12 +415,14 @@ def can_run(idx):
for m in follow.intersection(relevant):
dests.add(self.destinations[m])
if len(dests) > 1:
+ self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
self.fail_unreachable(msg_id)
return False
if targets:
# check blacklist+targets for impossibility
targets.difference_update(blacklist)
if not targets or not targets.intersection(self.targets):
+ self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
self.fail_unreachable(msg_id)
return False
return False
@@ -454,20 +476,34 @@ def dispatch_result(self, raw_msg):
idents,msg = self.session.feed_identities(raw_msg, copy=False)
msg = self.session.unpack_message(msg, content=False, copy=False)
engine = idents[0]
- idx = self.targets.index(engine)
- self.finish_job(idx)
+ try:
+ idx = self.targets.index(engine)
+ except ValueError:
+ pass # skip load-update for dead engines
+ else:
+ self.finish_job(idx)
except Exception:
self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True)
return
header = msg['header']
+ parent = msg['parent_header']
if header.get('dependencies_met', True):
success = (header['status'] == 'ok')
- self.handle_result(idents, msg['parent_header'], raw_msg, success)
- # send to Hub monitor
- self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
+ msg_id = parent['msg_id']
+ retries = self.retries[msg_id]
+ if not success and retries > 0:
+ # failed
+ self.retries[msg_id] = retries - 1
+ self.handle_unmet_dependency(idents, parent)
+ else:
+ del self.retries[msg_id]
+ # relay to client and update graph
+ self.handle_result(idents, parent, raw_msg, success)
+ # send to Hub monitor
+ self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
else:
- self.handle_unmet_dependency(idents, msg['parent_header'])
+ self.handle_unmet_dependency(idents, parent)
@logged
def handle_result(self, idents, parent, raw_msg, success=True):
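
The retry accounting in dispatch_result above boils down to a small decision rule; a self-contained model (names are illustrative, not the scheduler's actual attributes):

# standalone sketch of the scheduler's per-task retry bookkeeping
retries = {}  # msg_id -> retries remaining (non-negative ints)

def on_submit(msg_id, n_retries=0):
    retries[msg_id] = n_retries

def on_result(msg_id, success):
    remaining = retries[msg_id]
    if not success and remaining > 0:
        retries[msg_id] = remaining - 1
        return 'resubmit'   # routed like an unmet dependency
    del retries[msg_id]
    return 'relay'          # final answer is forwarded to the client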
@@ -511,13 +547,19 @@ def handle_unmet_dependency(self, idents, parent):
self.depending[msg_id] = args
self.fail_unreachable(msg_id)
elif not self.maybe_run(msg_id, *args):
- # resubmit failed, put it back in our dependency tree
- self.save_unmet(msg_id, *args)
+ # resubmit failed
+ if msg_id not in self.all_failed:
+ # put it back in our dependency tree
+ self.save_unmet(msg_id, *args)
if self.hwm:
- idx = self.targets.index(engine)
- if self.loads[idx] == self.hwm-1:
- self.update_graph(None)
+ try:
+ idx = self.targets.index(engine)
+ except ValueError:
+ pass # skip load-update for dead engines
+ else:
+ if self.loads[idx] == self.hwm-1:
+ self.update_graph(None)
@@ -526,7 +568,7 @@ def update_graph(self, dep_id=None, success=True):
"""dep_id just finished. Update our dependency
graph and submit any jobs that just became runnable.
- Called with dep_id=None to update graph for hwm, but without finishing
+ Called with dep_id=None to update entire graph for hwm, but without finishing
a task.
"""
# print ("\n\n***********")
@@ -538,9 +580,11 @@ def update_graph(self, dep_id=None, success=True):
# print ("\n\n***********\n\n")
# update any jobs that depended on the dependency
jobs = self.graph.pop(dep_id, [])
- # if we have HWM and an engine just become no longer full
- # recheck *all* jobs:
- if self.hwm and any( [ load==self.hwm-1 for load in self.loads]):
+
+ # recheck *all* jobs if
+ # a) we have HWM and an engine just became no longer full,
+ # or b) dep_id was given as None
+ if dep_id is None or (self.hwm and any([ load==self.hwm-1 for load in self.loads ])):
jobs = self.depending.keys()
for msg_id in jobs:
@@ -48,7 +48,7 @@ def start(self):
def setup():
cp = TestProcessLauncher()
cp.cmd_and_args = ipcontroller_cmd_argv + \
- ['--profile', 'iptest', '--log-level', '99', '-r', '--usethreads']
+ ['--profile', 'iptest', '--log-level', '99', '-r']
cp.start()
launchers.append(cp)
cluster_dir = os.path.join(get_ipython_dir(), 'cluster_iptest')
