add timeout for unmet dependencies in task scheduler
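
Tasks submitted for load-balanced execution can now carry a timeout in seconds. The scheduler turns the timeout into an absolute deadline, and a 1 Hz auditor fails any still-waiting task past its deadline with the new DependencyTimeout error. Along the way: AsyncResult._flatten_result becomes _single_result (and a lone failed task now raises its own exception), and the dependency-normalization logic in Client.apply moves into a _build_dependency helper.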

commit 23d190659dd7c2ed4db834054b835d31f92a0932 (1 parent: c221efd), minrk committed Feb 9, 2011
17 IPython/zmq/parallel/asyncresult.py
@@ -36,7 +36,7 @@ def __init__(self, client, msg_ids, fname=''):
self._fname=fname
self._ready = False
self._success = None
- self._flatten_result = len(msg_ids) == 1
+ self._single_result = len(msg_ids) == 1
def __repr__(self):
if self._ready:
@@ -50,7 +50,7 @@ def _reconstruct_result(self, res):
Override me in subclasses for turning a list of results
into the expected form.
"""
- if self._flatten_result:
+ if self._single_result:
return res[0]
else:
return res
@@ -90,7 +90,12 @@ def wait(self, timeout=-1):
try:
results = map(self._client.results.get, self.msg_ids)
self._result = results
- results = error.collect_exceptions(results, self._fname)
+ if self._single_result:
+ r = results[0]
+ if isinstance(r, Exception):
+ raise r
+ else:
+ results = error.collect_exceptions(results, self._fname)
self._result = self._reconstruct_result(results)
except Exception, e:
self._exception = e
@@ -138,7 +143,7 @@ def result(self):
@check_ready
def metadata(self):
"""metadata property."""
- if self._flatten_result:
+ if self._single_result:
return self._metadata[0]
else:
return self._metadata
@@ -165,7 +170,7 @@ def __getitem__(self, key):
return error.collect_exceptions(self._result[key], self._fname)
elif isinstance(key, basestring):
values = [ md[key] for md in self._metadata ]
- if self._flatten_result:
+ if self._single_result:
return values[0]
else:
return values
@@ -190,7 +195,7 @@ class AsyncMapResult(AsyncResult):
def __init__(self, client, msg_ids, mapObject, fname=''):
AsyncResult.__init__(self, client, msg_ids, fname=fname)
self._mapObject = mapObject
- self._flatten_result = False
+ self._single_result = False
def _reconstruct_result(self, res):
"""Perform the gather on the actual results."""
56 IPython/zmq/parallel/client.py
@@ -765,9 +765,26 @@ def _maybe_raise(self, result):
raise result
return result
-
+
+ def _build_dependency(self, dep):
+ """helper for building jsonable dependencies from various input forms"""
+ if isinstance(dep, Dependency):
+ return dep.as_dict()
+ elif isinstance(dep, AsyncResult):
+ return dep.msg_ids
+ elif dep is None:
+ return []
+ elif isinstance(dep, set):
+ return list(dep)
+ elif isinstance(dep, (list,dict)):
+ return dep
+ elif isinstance(dep, str):
+ return [dep]
+ else:
+ raise TypeError("Dependency may be: set,list,dict,Dependency or AsyncResult, not %r"%type(dep))
+
def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
- after=None, follow=None):
+ after=None, follow=None, timeout=None):
"""Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
This is the central execution command for the client.
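
_build_dependency normalizes every accepted dependency form into something the JSON serializer can handle. A self-contained sketch of the same rules (the Dependency and AsyncResult cases are omitted, since those just delegate to as_dict() and msg_ids):

    def build_dependency(dep):
        if dep is None:
            return []                 # no dependency at all
        if isinstance(dep, str):
            return [dep]              # a single msg_id
        if isinstance(dep, set):
            return list(dep)          # sets don't survive JSON encoding
        if isinstance(dep, (list, dict)):
            return dep                # already jsonable
        raise TypeError("unsupported dependency form: %r" % type(dep))

    print(build_dependency(None))       # []
    print(build_dependency('abc123'))   # ['abc123']
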
@@ -817,6 +834,10 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
This job will only be run on an engine where this dependency
is met.
+ timeout : float or None
+ Only for load-balanced execution (targets=None)
+ Specify an amount of time (in seconds) for the scheduler to wait for dependencies to be met before failing with a DependencyTimeout.
+
Returns
-------
if block is False:
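
A usage sketch for the new keyword, assuming a controller and engines are already running; other_ar stands for some previously submitted AsyncResult, and the exact exception the client re-raises depends on its error unwrapping:

    from IPython.zmq.parallel.client import Client

    rc = Client()   # connect to a running controller (args as this tree's Client requires)
    ar = rc.apply(lambda x: 2*x, args=(5,), block=False,
                  targets=None,     # load-balanced, so timeout is honored
                  after=other_ar,   # run only once these msg_ids finish
                  timeout=10.0)     # give up after 10s of unmet dependencies
    try:
        print(ar.get())
    except Exception, e:
        # a DependencyTimeout raised scheduler-side, possibly wrapped
        # as a RemoteError by the client
        print("task failed: %r" % e)
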
@@ -844,33 +865,23 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
raise TypeError("args must be tuple or list, not %s"%type(args))
if not isinstance(kwargs, dict):
raise TypeError("kwargs must be dict, not %s"%type(kwargs))
-
- if isinstance(after, Dependency):
- after = after.as_dict()
- elif isinstance(after, AsyncResult):
- after=after.msg_ids
- elif after is None:
- after = []
- if isinstance(follow, Dependency):
- # if len(follow) > 1 and follow.mode == 'all':
- # warn("complex follow-dependencies are not rigorously tested for reachability", UserWarning)
- follow = follow.as_dict()
- elif isinstance(follow, AsyncResult):
- follow=follow.msg_ids
- elif follow is None:
- follow = []
- options = dict(bound=bound, block=block, after=after, follow=follow)
+
+ after = self._build_dependency(after)
+ follow = self._build_dependency(follow)
+
+ options = dict(bound=bound, block=block)
if targets is None:
- return self._apply_balanced(f, args, kwargs, **options)
+ return self._apply_balanced(f, args, kwargs, timeout=timeout,
+ after=after, follow=follow, **options)
else:
return self._apply_direct(f, args, kwargs, targets=targets, **options)
def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
- after=None, follow=None):
+ after=None, follow=None, timeout=None):
"""The underlying method for applying functions in a load balanced
manner, via the task queue."""
- subheader = dict(after=after, follow=follow)
+ subheader = dict(after=after, follow=follow, timeout=timeout)
bufs = ss.pack_apply_message(f,args,kwargs)
content = dict(bound=bound)
@@ -885,8 +896,7 @@ def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
else:
return ar
- def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
- after=None, follow=None):
+ def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None):
"""Then underlying method for applying functions to specific engines
via the MUX queue."""
4 IPython/zmq/parallel/controller.py
@@ -100,9 +100,9 @@ def construct_schedulers(self):
self.log.warn("task::using no Task scheduler")
else:
- self.log.warn("task::using Python %s Task scheduler"%self.scheme)
+ self.log.info("task::using Python %s Task scheduler"%self.scheme)
sargs = (self.client_addrs['task'], self.engine_addrs['task'], self.monitor_url, self.client_addrs['notification'])
- q = Process(target=launch_scheduler, args=sargs, kwargs = dict(scheme=self.scheme))
+ q = Process(target=launch_scheduler, args=sargs, kwargs = dict(scheme=self.scheme,logname=self.log.name, loglevel=self.log.level))
q.daemon=True
children.append(q)
2 IPython/zmq/parallel/dependency.py
@@ -55,7 +55,7 @@ def require(*names):
return depend(_require, *names)
class Dependency(set):
- """An object for representing a set of dependencies.
+ """An object for representing a set of msg_id dependencies.
Subclassed from set()."""
3 IPython/zmq/parallel/error.py
@@ -154,6 +154,9 @@ class UnmetDependency(KernelError):
class ImpossibleDependency(UnmetDependency):
pass
+class DependencyTimeout(UnmetDependency):
+ pass
+
class RemoteError(KernelError):
"""Error raised elsewhere"""
ename=None
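
Because DependencyTimeout subclasses UnmetDependency, any handler written against the existing class also catches timeouts:

    from IPython.zmq.parallel import error

    assert issubclass(error.DependencyTimeout, error.UnmetDependency)
    assert issubclass(error.DependencyTimeout, error.KernelError)
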
69 IPython/zmq/parallel/scheduler.py
@@ -12,9 +12,9 @@
from __future__ import print_function
import sys
import logging
-from random import randint,random
+from random import randint, random
from types import FunctionType
-
+from datetime import datetime, timedelta
try:
import numpy
except ImportError:
@@ -29,11 +29,11 @@
from IPython.utils.traitlets import Instance, Dict, List, Set
import error
-from client import Client
+# from client import Client
from dependency import Dependency
import streamsession as ss
from entry_point import connect_logger, local_logger
-from factory import LoggingFactory
+from factory import SessionFactory
@decorator
@@ -110,7 +110,7 @@ def leastload(loads):
# store empty default dependency:
MET = Dependency([])
-class TaskScheduler(LoggingFactory):
+class TaskScheduler(SessionFactory):
"""Python TaskScheduler object.
This is the simplest object that supports msg_id based
@@ -125,7 +125,6 @@ class TaskScheduler(LoggingFactory):
engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
- io_loop = Instance(ioloop.IOLoop)
# internals:
dependencies = Dict() # dict by msg_id of [ msg_ids that depend on key ]
@@ -141,20 +140,18 @@ class TaskScheduler(LoggingFactory):
all_failed = Set() # set of all failed tasks
all_done = Set() # set of all finished tasks=union(completed,failed)
blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
- session = Instance(ss.StreamSession)
+ auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
- def __init__(self, **kwargs):
- super(TaskScheduler, self).__init__(**kwargs)
-
- self.session = ss.StreamSession(username="TaskScheduler")
-
+ def start(self):
self.engine_stream.on_recv(self.dispatch_result, copy=False)
self._notification_handlers = dict(
registration_notification = self._register_engine,
unregistration_notification = self._unregister_engine
)
self.notifier_stream.on_recv(self.dispatch_notification)
+ self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 1e3, self.loop) # 1 Hz
+ self.auditor.start()
self.log.info("Scheduler started...%r"%self)
def resume_receiving(self):
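
PeriodicCallback takes its interval in milliseconds, so the 1e3 above gives the 1 Hz audit. The same pattern in isolation, using only pyzmq's tornado-based event loop:

    from zmq.eventloop import ioloop

    loop = ioloop.IOLoop.instance()

    def audit():
        print("checking waiting tasks for expired deadlines...")

    pc = ioloop.PeriodicCallback(audit, 1000, loop)  # fire every 1000ms
    pc.start()
    # loop.start()  # blocks; the scheduler starts its loop in launch_scheduler
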
@@ -261,37 +258,55 @@ def dispatch_submission(self, raw_msg):
# location dependencies
follow = Dependency(header.get('follow', []))
-
# check if unreachable:
if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
- self.depending[msg_id] = [raw_msg,MET,MET]
+ self.depending[msg_id] = [raw_msg,MET,MET,None]
return self.fail_unreachable(msg_id)
+ # turn timeouts into datetime objects:
+ timeout = header.get('timeout', None)
+ if timeout:
+ timeout = datetime.now() + timedelta(0,timeout,0)
+
if after.check(self.all_completed, self.all_failed):
# time deps already met, try to run
if not self.maybe_run(msg_id, raw_msg, follow):
# can't run yet
- self.save_unmet(msg_id, raw_msg, after, follow)
+ self.save_unmet(msg_id, raw_msg, after, follow, timeout)
else:
- self.save_unmet(msg_id, raw_msg, after, follow)
+ self.save_unmet(msg_id, raw_msg, after, follow, timeout)
@logged
- def fail_unreachable(self, msg_id):
+ def audit_timeouts(self):
+ """Audit all waiting tasks for expired timeouts."""
+ now = datetime.now()
+ for msg_id in self.depending.keys():
+ # must recheck, in case one failure cascaded to another:
+ if msg_id in self.depending:
+ raw,after,follow,timeout = self.depending[msg_id]
+ if timeout and timeout < now:
+ self.fail_unreachable(msg_id, timeout=True)
+
+ @logged
+ def fail_unreachable(self, msg_id, timeout=False):
"""a message has become unreachable"""
if msg_id not in self.depending:
self.log.error("msg %r already failed!"%msg_id)
return
- raw_msg, after, follow = self.depending.pop(msg_id)
+ raw_msg, after, follow, _ = self.depending.pop(msg_id) # don't clobber the timeout flag argument
for mid in follow.union(after):
if mid in self.dependencies:
self.dependencies[mid].remove(msg_id)
+ # FIXME: unpacking a message I've already unpacked, but didn't save:
idents,msg = self.session.feed_identities(raw_msg, copy=False)
msg = self.session.unpack_message(msg, copy=False, content=False)
header = msg['header']
+ impossible = error.DependencyTimeout if timeout else error.ImpossibleDependency
+
try:
- raise error.ImpossibleDependency()
+ raise impossible()
except:
content = ss.wrap_exception()
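
The relative timeout from the header becomes an absolute deadline via timedelta's (days, seconds, microseconds) form, and each audit tick just compares it to the clock. The arithmetic in isolation:

    from datetime import datetime, timedelta

    timeout = 10.0                                        # seconds, from the header
    deadline = datetime.now() + timedelta(0, timeout, 0)  # (days, seconds, us)

    # inside the 1 Hz audit:
    if deadline < datetime.now():
        print("deadline passed with dependencies unmet; fail the task")
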
@@ -334,9 +349,9 @@ def can_run(idx):
return True
@logged
- def save_unmet(self, msg_id, raw_msg, after, follow):
+ def save_unmet(self, msg_id, raw_msg, after, follow, timeout):
"""Save a message for later submission when its dependencies are met."""
- self.depending[msg_id] = [raw_msg,after,follow]
+ self.depending[msg_id] = [raw_msg,after,follow,timeout]
# track the ids in follow or after, but not those already finished
for dep_id in after.union(follow).difference(self.all_done):
if dep_id not in self.dependencies:
@@ -413,10 +428,10 @@ def handle_unmet_dependency(self, idents, parent):
if msg_id not in self.blacklist:
self.blacklist[msg_id] = set()
self.blacklist[msg_id].add(engine)
- raw_msg,follow = self.pending[engine].pop(msg_id)
+ raw_msg,follow,timeout = self.pending[engine].pop(msg_id)
if not self.maybe_run(msg_id, raw_msg, follow):
# resubmit failed, put it back in our dependency tree
- self.save_unmet(msg_id, raw_msg, MET, follow)
+ self.save_unmet(msg_id, raw_msg, MET, follow, timeout)
pass
@logged
@@ -435,7 +450,7 @@ def update_dependencies(self, dep_id, success=True):
jobs = self.dependencies.pop(dep_id)
for msg_id in jobs:
- raw_msg, after, follow = self.depending[msg_id]
+ raw_msg, after, follow, timeout = self.depending[msg_id]
# if dep_id in after:
# if after.mode == 'all' and (success or not after.success_only):
# after.remove(dep_id)
@@ -497,9 +512,9 @@ def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, logname='ZMQ', log_a
local_logger(logname, loglevel)
scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
- mon_stream=mons,notifier_stream=nots,
- scheme=scheme,io_loop=loop, logname=logname)
-
+ mon_stream=mons, notifier_stream=nots,
+ scheme=scheme, loop=loop, logname=logname)
+ scheduler.start()
try:
loop.start()
except KeyboardInterrupt:
