add timeout for unmet dependencies in task scheduler
minrk committed Apr 8, 2011
1 parent c221efd commit 23d1906
Showing 6 changed files with 92 additions and 59 deletions.
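Taken together, the changes let a client bound how long a load-balanced task may sit in the scheduler waiting for unmet dependencies before failing with the new DependencyTimeout error. A minimal usage sketch, assuming a running controller and engines; the sleeper/lambda tasks and the 5-second value are illustrative, AsyncResult.get() is assumed to mirror the multiprocessing API as elsewhere in this module, and the timeout may surface wrapped as a RemoteError on the client:

from IPython.zmq.parallel.client import Client

client = Client()

def sleeper():
    # a dependency that will not be met within 5 seconds
    import time
    time.sleep(60)

slow = client.apply(sleeper, targets=None)

# load-balanced task (targets=None): give the scheduler 5s to meet `after`
ar = client.apply(lambda: 'ran', targets=None, after=slow, timeout=5)

try:
    ar.get()
except Exception, e:
    print e   # expect a DependencyTimeout (possibly as a RemoteError)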
17 changes: 11 additions & 6 deletions IPython/zmq/parallel/asyncresult.py
@@ -36,7 +36,7 @@ def __init__(self, client, msg_ids, fname=''):
self._fname=fname
self._ready = False
self._success = None
self._flatten_result = len(msg_ids) == 1
self._single_result = len(msg_ids) == 1

def __repr__(self):
if self._ready:
@@ -50,7 +50,7 @@ def _reconstruct_result(self, res):
Override me in subclasses for turning a list of results
into the expected form.
"""
if self._flatten_result:
if self._single_result:
return res[0]
else:
return res
@@ -90,7 +90,12 @@ def wait(self, timeout=-1):
try:
results = map(self._client.results.get, self.msg_ids)
self._result = results
results = error.collect_exceptions(results, self._fname)
if self._single_result:
r = results[0]
if isinstance(r, Exception):
raise r
else:
results = error.collect_exceptions(results, self._fname)
self._result = self._reconstruct_result(results)
except Exception, e:
self._exception = e
@@ -138,7 +143,7 @@ def result(self):
@check_ready
def metadata(self):
"""metadata property."""
if self._flatten_result:
if self._single_result:
return self._metadata[0]
else:
return self._metadata
@@ -165,7 +170,7 @@ def __getitem__(self, key):
return error.collect_exceptions(self._result[key], self._fname)
elif isinstance(key, basestring):
values = [ md[key] for md in self._metadata ]
if self._flatten_result:
if self._single_result:
return values[0]
else:
return values
@@ -190,7 +195,7 @@ class AsyncMapResult(AsyncResult):
def __init__(self, client, msg_ids, mapObject, fname=''):
AsyncResult.__init__(self, client, msg_ids, fname=fname)
self._mapObject = mapObject
self._flatten_result = False
self._single_result = False

def _reconstruct_result(self, res):
"""Perform the gather on the actual results."""
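Note that the wait() change above is behavioral, not just a rename: for a single-task result (_single_result), the task's own exception is now re-raised directly, where previously even a lone failure was routed through error.collect_exceptions like a list of results. A hedged illustration, with a hypothetical failing task:

ar = client.apply(lambda: 1/0, targets=None)  # one msg_id -> _single_result
try:
    ar.get()
except Exception, e:
    # the lone result's exception is raised as-is, not collected/wrapped
    print type(e), e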
56 changes: 33 additions & 23 deletions IPython/zmq/parallel/client.py
@@ -765,9 +765,26 @@ def _maybe_raise(self, result):
raise result

return result


def _build_dependency(self, dep):
"""helper for building jsonable dependencies from various input forms"""
if isinstance(dep, Dependency):
return dep.as_dict()
elif isinstance(dep, AsyncResult):
return dep.msg_ids
elif dep is None:
return []
elif isinstance(dep, set):
return list(dep)
elif isinstance(dep, (list,dict)):
return dep
elif isinstance(dep, str):
return [dep]
else:
raise TypeError("Dependency may be: set,list,dict,Dependency or AsyncResult, not %r"%type(dep))

def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
after=None, follow=None):
after=None, follow=None, timeout=None):
"""Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
This is the central execution command for the client.
@@ -817,6 +834,10 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
This job will only be run on an engine where this dependency
is met.
timeout : float or None
Only for load-balanced execution (targets=None)
Specify an amount of time (in seconds) for the scheduler to
wait for dependencies to be met before failing with a DependencyTimeout.
Returns
-------
if block is False:
@@ -844,33 +865,23 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
raise TypeError("args must be tuple or list, not %s"%type(args))
if not isinstance(kwargs, dict):
raise TypeError("kwargs must be dict, not %s"%type(kwargs))

if isinstance(after, Dependency):
after = after.as_dict()
elif isinstance(after, AsyncResult):
after=after.msg_ids
elif after is None:
after = []
if isinstance(follow, Dependency):
# if len(follow) > 1 and follow.mode == 'all':
# warn("complex follow-dependencies are not rigorously tested for reachability", UserWarning)
follow = follow.as_dict()
elif isinstance(follow, AsyncResult):
follow=follow.msg_ids
elif follow is None:
follow = []
options = dict(bound=bound, block=block, after=after, follow=follow)

after = self._build_dependency(after)
follow = self._build_dependency(follow)

options = dict(bound=bound, block=block)

if targets is None:
return self._apply_balanced(f, args, kwargs, **options)
return self._apply_balanced(f, args, kwargs, timeout=timeout,
after=after, follow=follow, **options)
else:
return self._apply_direct(f, args, kwargs, targets=targets, **options)

def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
after=None, follow=None):
after=None, follow=None, timeout=None):
"""The underlying method for applying functions in a load balanced
manner, via the task queue."""
subheader = dict(after=after, follow=follow)
subheader = dict(after=after, follow=follow, timeout=timeout)
bufs = ss.pack_apply_message(f,args,kwargs)
content = dict(bound=bound)

@@ -885,8 +896,7 @@ def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
else:
return ar

def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
after=None, follow=None):
def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None):
"""Then underlying method for applying functions to specific engines
via the MUX queue."""

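The new _build_dependency helper replaces the duplicated after/follow normalization shown above with one mapping to jsonable forms. A sketch of what each accepted input becomes (c is a Client; ar and dep are hypothetical AsyncResult and Dependency instances):

c._build_dependency(None)               # -> []
c._build_dependency('a1b2c3')           # -> ['a1b2c3']  (bare msg_id string)
c._build_dependency(set(['m1', 'm2']))  # -> ['m1', 'm2'] (set becomes list)
c._build_dependency(['m1', 'm2'])       # -> passed through (list or dict)
c._build_dependency(ar)                 # -> ar.msg_ids
c._build_dependency(dep)                # -> dep.as_dict()
c._build_dependency(42)                 # raises TypeError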
4 changes: 2 additions & 2 deletions IPython/zmq/parallel/controller.py
@@ -100,9 +100,9 @@ def construct_schedulers(self):
self.log.warn("task::using no Task scheduler")

else:
self.log.warn("task::using Python %s Task scheduler"%self.scheme)
self.log.info("task::using Python %s Task scheduler"%self.scheme)
sargs = (self.client_addrs['task'], self.engine_addrs['task'], self.monitor_url, self.client_addrs['notification'])
q = Process(target=launch_scheduler, args=sargs, kwargs = dict(scheme=self.scheme))
q = Process(target=launch_scheduler, args=sargs, kwargs = dict(scheme=self.scheme,logname=self.log.name, loglevel=self.log.level))
q.daemon=True
children.append(q)

2 changes: 1 addition & 1 deletion IPython/zmq/parallel/dependency.py
@@ -55,7 +55,7 @@ def require(*names):
return depend(_require, *names)

class Dependency(set):
"""An object for representing a set of dependencies.
"""An object for representing a set of msg_id dependencies.
Subclassed from set()."""

3 changes: 3 additions & 0 deletions IPython/zmq/parallel/error.py
@@ -154,6 +154,9 @@ class UnmetDependency(KernelError):
class ImpossibleDependency(UnmetDependency):
pass

class DependencyTimeout(UnmetDependency):
pass

class RemoteError(KernelError):
"""Error raised elsewhere"""
ename=None
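The new exception slots into the existing hierarchy, so handlers catching UnmetDependency also catch timeouts; a quick check of the relationships defined in this file:

from IPython.zmq.parallel import error

assert issubclass(error.DependencyTimeout, error.UnmetDependency)
assert issubclass(error.ImpossibleDependency, error.UnmetDependency)
assert issubclass(error.UnmetDependency, error.KernelError)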
69 changes: 42 additions & 27 deletions IPython/zmq/parallel/scheduler.py
@@ -12,9 +12,9 @@
from __future__ import print_function
import sys
import logging
from random import randint,random
from random import randint, random
from types import FunctionType

from datetime import datetime, timedelta
try:
import numpy
except ImportError:
@@ -29,11 +29,11 @@
from IPython.utils.traitlets import Instance, Dict, List, Set

import error
from client import Client
# from client import Client
from dependency import Dependency
import streamsession as ss
from entry_point import connect_logger, local_logger
from factory import LoggingFactory
from factory import SessionFactory


@decorator
@@ -110,7 +110,7 @@ def leastload(loads):
# store empty default dependency:
MET = Dependency([])

class TaskScheduler(LoggingFactory):
class TaskScheduler(SessionFactory):
"""Python TaskScheduler object.
This is the simplest object that supports msg_id based
@@ -125,7 +125,6 @@ class TaskScheduler(LoggingFactory):
engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
io_loop = Instance(ioloop.IOLoop)

# internals:
dependencies = Dict() # dict by msg_id of [ msg_ids that depend on key ]
@@ -141,20 +140,18 @@ class TaskScheduler(LoggingFactory):
all_failed = Set() # set of all failed tasks
all_done = Set() # set of all finished tasks=union(completed,failed)
blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
session = Instance(ss.StreamSession)
auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')


def __init__(self, **kwargs):
super(TaskScheduler, self).__init__(**kwargs)

self.session = ss.StreamSession(username="TaskScheduler")

def start(self):
self.engine_stream.on_recv(self.dispatch_result, copy=False)
self._notification_handlers = dict(
registration_notification = self._register_engine,
unregistration_notification = self._unregister_engine
)
self.notifier_stream.on_recv(self.dispatch_notification)
self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 1e3, self.loop) # 1 Hz
self.auditor.start()
self.log.info("Scheduler started...%r"%self)

def resume_receiving(self):
@@ -261,37 +258,55 @@ def dispatch_submission(self, raw_msg):

# location dependencies
follow = Dependency(header.get('follow', []))

# check if unreachable:
if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
self.depending[msg_id] = [raw_msg,MET,MET]
self.depending[msg_id] = [raw_msg,MET,MET,None]
return self.fail_unreachable(msg_id)

# turn timeouts into datetime objects:
timeout = header.get('timeout', None)
if timeout:
timeout = datetime.now() + timedelta(0,timeout,0)

if after.check(self.all_completed, self.all_failed):
# time deps already met, try to run
if not self.maybe_run(msg_id, raw_msg, follow):
# can't run yet
self.save_unmet(msg_id, raw_msg, after, follow)
self.save_unmet(msg_id, raw_msg, after, follow, timeout)
else:
self.save_unmet(msg_id, raw_msg, after, follow)
self.save_unmet(msg_id, raw_msg, after, follow, timeout)

@logged
def fail_unreachable(self, msg_id):
def audit_timeouts(self):
"""Audit all waiting tasks for expired timeouts."""
now = datetime.now()
for msg_id in self.depending.keys():
# must recheck, in case one failure cascaded to another:
if msg_id in self.depending:
raw,after,follow,timeout = self.depending[msg_id]
if timeout and timeout < now:
self.fail_unreachable(msg_id, timeout=True)

@logged
def fail_unreachable(self, msg_id, timeout=False):
"""a message has become unreachable"""
if msg_id not in self.depending:
self.log.error("msg %r already failed!"%msg_id)
return
raw_msg, after, follow = self.depending.pop(msg_id)
raw_msg, after, follow, timeout = self.depending.pop(msg_id)
for mid in follow.union(after):
if mid in self.dependencies:
self.dependencies[mid].remove(msg_id)

# FIXME: unpacking a message I've already unpacked, but didn't save:
idents,msg = self.session.feed_identities(raw_msg, copy=False)
msg = self.session.unpack_message(msg, copy=False, content=False)
header = msg['header']

impossible = error.DependencyTimeout if timeout else error.ImpossibleDependency

try:
raise error.ImpossibleDependency()
raise impossible()
except:
content = ss.wrap_exception()

@@ -334,9 +349,9 @@ def can_run(idx):
return True

@logged
def save_unmet(self, msg_id, raw_msg, after, follow):
def save_unmet(self, msg_id, raw_msg, after, follow, timeout):
"""Save a message for later submission when its dependencies are met."""
self.depending[msg_id] = [raw_msg,after,follow]
self.depending[msg_id] = [raw_msg,after,follow,timeout]
# track the ids in follow or after, but not those already finished
for dep_id in after.union(follow).difference(self.all_done):
if dep_id not in self.dependencies:
@@ -413,10 +428,10 @@ def handle_unmet_dependency(self, idents, parent):
if msg_id not in self.blacklist:
self.blacklist[msg_id] = set()
self.blacklist[msg_id].add(engine)
raw_msg,follow = self.pending[engine].pop(msg_id)
raw_msg,follow,timeout = self.pending[engine].pop(msg_id)
if not self.maybe_run(msg_id, raw_msg, follow):
# resubmit failed, put it back in our dependency tree
self.save_unmet(msg_id, raw_msg, MET, follow)
self.save_unmet(msg_id, raw_msg, MET, follow, timeout)
pass

@logged
Expand All @@ -435,7 +450,7 @@ def update_dependencies(self, dep_id, success=True):
jobs = self.dependencies.pop(dep_id)

for msg_id in jobs:
raw_msg, after, follow = self.depending[msg_id]
raw_msg, after, follow, timeout = self.depending[msg_id]
# if dep_id in after:
# if after.mode == 'all' and (success or not after.success_only):
# after.remove(dep_id)
@@ -497,9 +512,9 @@ def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, logname='ZMQ', log_a
local_logger(logname, loglevel)

scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
mon_stream=mons,notifier_stream=nots,
scheme=scheme,io_loop=loop, logname=logname)

mon_stream=mons, notifier_stream=nots,
scheme=scheme, loop=loop, logname=logname)
scheduler.start()
try:
loop.start()
except KeyboardInterrupt:
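In miniature, the scheduler-side mechanics added here: a relative timeout becomes an absolute deadline when the unmet task is saved, and a 1 Hz PeriodicCallback sweeps the depending table, failing anything past its deadline. A standalone sketch of that pattern with the ZMQ machinery stripped away (names simplified):

from datetime import datetime, timedelta

depending = {}  # msg_id -> (raw_msg, deadline or None)

def save_unmet(msg_id, raw_msg, timeout=None):
    # convert relative seconds into an absolute deadline at submission time
    deadline = datetime.now() + timedelta(0, timeout, 0) if timeout else None
    depending[msg_id] = (raw_msg, deadline)

def audit_timeouts():
    now = datetime.now()
    for msg_id in depending.keys():  # keys() copies: the dict shrinks as we fail tasks
        if msg_id in depending:      # recheck: one failure can cascade to others
            raw_msg, deadline = depending[msg_id]
            if deadline and deadline < now:
                depending.pop(msg_id)
                print '%s failed: DependencyTimeout' % msg_id

# in the scheduler proper this sweep is driven by:
#   self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 1e3, self.loop)  # 1000 ms
#   self.auditor.start()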
