Permalink
Browse files

allow load-balancing across subsets of engines

  • Loading branch information...
1 parent 75d9c51 commit a514d133c6a8bb25e471dce528dd4dcfb7a87d3d @minrk minrk committed Feb 18, 2011
Showing with 237 additions and 99 deletions.
  1. +127 −43 IPython/zmq/parallel/client.py
  2. +64 −33 IPython/zmq/parallel/scheduler.py
  3. +46 −23 IPython/zmq/parallel/view.py
View
170 IPython/zmq/parallel/client.py
@@ -326,7 +326,7 @@ def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipytho
else:
self._registration_socket.connect(url)
self._engines = ReverseDict()
- self._ids = set()
+ self._ids = []
self.outstanding=set()
self.results = {}
self.metadata = {}
@@ -370,7 +370,8 @@ def _update_engines(self, engines):
for k,v in engines.iteritems():
eid = int(k)
self._engines[eid] = bytes(v) # force not unicode
- self._ids.add(eid)
+ self._ids.append(eid)
+ self._ids = sorted(self._ids)
if sorted(self._engines.keys()) != range(len(self._engines)) and \
self._task_scheme == 'pure' and self._task_socket:
self._stop_scheduling_tasks()
@@ -470,7 +471,6 @@ def _register_engine(self, msg):
eid = content['id']
d = {eid : content['queue']}
self._update_engines(d)
- self._ids.add(int(eid))
def _unregister_engine(self, msg):
"""Unregister an engine that has died."""
@@ -664,9 +664,9 @@ def remote(self):
"""property for convenient RemoteFunction generation.
>>> @client.remote
- ... def f():
+ ... def getpid():
import os
- print (os.getpid())
+ return os.getpid()
"""
return remote(self, block=self.block)
@@ -867,6 +867,7 @@ def _build_dependency(self, dep):
# pass to Dependency constructor
return list(Dependency(dep))
+ @defaultblock
def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
after=None, follow=None, timeout=None):
"""Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
@@ -903,24 +904,9 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
Run on each specified engine
if int:
Run on single engine
-
- after : Dependency or collection of msg_ids
- Only for load-balanced execution (targets=None)
- Specify a list of msg_ids as a time-based dependency.
- This job will only be run *after* the dependencies
- have been met.
-
- follow : Dependency or collection of msg_ids
- Only for load-balanced execution (targets=None)
- Specify a list of msg_ids as a location-based dependency.
- This job will only be run on an engine where this dependency
- is met.
- timeout : float/int or None
- Only for load-balanced execution (targets=None)
- Specify an amount of time (in seconds) for the scheduler to
- wait for dependencies to be met before failing with a
- DependencyTimeout.
+ after,follow,timeout only used in `apply_balanced`. See that docstring
+ for details.
Returns
-------
@@ -947,25 +933,88 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
if not isinstance(kwargs, dict):
raise TypeError("kwargs must be dict, not %s"%type(kwargs))
- options = dict(bound=bound, block=block)
+ options = dict(bound=bound, block=block, targets=targets)
if targets is None:
- if self._task_socket:
- return self._apply_balanced(f, args, kwargs, timeout=timeout,
- after=after, follow=follow, **options)
- else:
- msg = "Task farming is disabled"
- if self._task_scheme == 'pure':
- msg += " because the pure ZMQ scheduler cannot handle"
- msg += " disappearing engines."
- raise RuntimeError(msg)
+ return self.apply_balanced(f, args, kwargs, timeout=timeout,
+ after=after, follow=follow, **options)
else:
- return self._apply_direct(f, args, kwargs, targets=targets, **options)
+ if follow or after or timeout:
+ msg = "follow, after, and timeout args are only used for load-balanced"
+ msg += "execution."
+ raise ValueError(msg)
+ return self._apply_direct(f, args, kwargs, **options)
- def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
+ @defaultblock
+ def apply_balanced(self, f, args, kwargs, bound=True, block=None, targets=None,
after=None, follow=None, timeout=None):
- """The underlying method for applying functions in a load balanced
- manner, via the task queue."""
+ """call f(*args, **kwargs) remotely in a load-balanced manner.
+
+ Parameters
+ ----------
+
+ f : function
+ The function to be called remotely
+ args : tuple/list
+ The positional arguments passed to `f`
+ kwargs : dict
+ The keyword arguments passed to `f`
+ bound : bool (default: True)
+ Whether to execute in the Engine(s) namespace, or in a clean
+ namespace not affecting the engine.
+ block : bool (default: self.block)
+ Whether to wait for the result, or return immediately.
+ False:
+ returns AsyncResult
+ True:
+ returns actual result(s) of f(*args, **kwargs)
+ if multiple targets:
+ list of results, matching `targets`
+ targets : int,list of ints, 'all', None
+ Specify the destination of the job.
+ if None:
+ Submit via Task queue for load-balancing.
+ if 'all':
+ Run on all active engines
+ if list:
+ Run on each specified engine
+ if int:
+ Run on single engine
+
+ after : Dependency or collection of msg_ids
+ Only for load-balanced execution (targets=None)
+ Specify a list of msg_ids as a time-based dependency.
+ This job will only be run *after* the dependencies
+ have been met.
+
+ follow : Dependency or collection of msg_ids
+ Only for load-balanced execution (targets=None)
+ Specify a list of msg_ids as a location-based dependency.
+ This job will only be run on an engine where this dependency
+ is met.
+
+ timeout : float/int or None
+ Only for load-balanced execution (targets=None)
+ Specify an amount of time (in seconds) for the scheduler to
+ wait for dependencies to be met before failing with a
+ DependencyTimeout.
+
+ Returns
+ -------
+ if block is False:
+ return AsyncResult wrapping msg_id
+ output of AsyncResult.get() is identical to that of `apply(...block=True)`
+ else:
+ wait for, and return actual result of `f(*args, **kwargs)`
+
+ """
+
+ if self._task_socket is None:
+ msg = "Task farming is disabled"
+ if self._task_scheme == 'pure':
+ msg += " because the pure ZMQ scheduler cannot handle"
+ msg += " disappearing engines."
+ raise RuntimeError(msg)
if self._task_scheme == 'pure':
# pure zmq scheme doesn't support dependencies
@@ -978,9 +1027,26 @@ def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
warnings.warn(msg, RuntimeWarning)
+ # defaults:
+ args = args if args is not None else []
+ kwargs = kwargs if kwargs is not None else {}
+
+ if targets:
+ idents,_ = self._build_targets(targets)
+ else:
+ idents = []
+
+ # enforce types of f,args,kwargs
+ if not callable(f):
+ raise TypeError("f must be callable, not %s"%type(f))
+ if not isinstance(args, (tuple, list)):
+ raise TypeError("args must be tuple or list, not %s"%type(args))
+ if not isinstance(kwargs, dict):
+ raise TypeError("kwargs must be dict, not %s"%type(kwargs))
+
after = self._build_dependency(after)
follow = self._build_dependency(follow)
- subheader = dict(after=after, follow=follow, timeout=timeout)
+ subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
bufs = ss.pack_apply_message(f,args,kwargs)
content = dict(bound=bound)
@@ -991,31 +1057,40 @@ def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
self.history.append(msg_id)
ar = AsyncResult(self, [msg_id], fname=f.__name__)
if block:
- return ar.get()
+ try:
+ return ar.get()
+ except KeyboardInterrupt:
+ return ar
else:
return ar
def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None):
"""Then underlying method for applying functions to specific engines
- via the MUX queue."""
+ via the MUX queue.
+
+ Not to be called directly!
+ """
- queues,targets = self._build_targets(targets)
+ idents,targets = self._build_targets(targets)
subheader = {}
content = dict(bound=bound)
bufs = ss.pack_apply_message(f,args,kwargs)
msg_ids = []
- for queue in queues:
+ for ident in idents:
msg = self.session.send(self._mux_socket, "apply_request",
- content=content, buffers=bufs,ident=queue, subheader=subheader)
+ content=content, buffers=bufs, ident=ident, subheader=subheader)
msg_id = msg['msg_id']
self.outstanding.add(msg_id)
self.history.append(msg_id)
msg_ids.append(msg_id)
ar = AsyncResult(self, msg_ids, fname=f.__name__)
if block:
- return ar.get()
+ try:
+ return ar.get()
+ except KeyboardInterrupt:
+ return ar
else:
return ar
@@ -1037,6 +1112,15 @@ def remote(self, bound=True, targets='all', block=True):
"""Decorator for making a RemoteFunction."""
return remote(self, bound=bound, targets=targets, block=block)
+ def view(self, targets=None, balanced=False):
+ """Method for constructing View objects"""
+ if not balanced:
+ if not targets:
+ targets = slice(None)
+ return self[targets]
+ else:
+ return LoadBalancedView(self, targets)
+
#--------------------------------------------------------------------------
# Data movement
#--------------------------------------------------------------------------
View
97 IPython/zmq/parallel/scheduler.py
@@ -265,6 +265,9 @@ def dispatch_submission(self, raw_msg):
msg_id = header['msg_id']
self.all_ids.add(msg_id)
+ # targets
+ targets = set(header.get('targets', []))
+
# time dependencies
after = Dependency(header.get('after', []))
if after.all:
@@ -279,28 +282,31 @@ def dispatch_submission(self, raw_msg):
# location dependencies
follow = Dependency(header.get('follow', []))
+ # turn timeouts into datetime objects:
+ timeout = header.get('timeout', None)
+ if timeout:
+ timeout = datetime.now() + timedelta(0,timeout,0)
+
+ args = [raw_msg, targets, after, follow, timeout]
+
+ # validate and reduce dependencies:
for dep in after,follow:
# check valid:
if msg_id in dep or dep.difference(self.all_ids):
- self.depending[msg_id] = [raw_msg,MET,MET,None]
+ self.depending[msg_id] = args
return self.fail_unreachable(msg_id, error.InvalidDependency)
# check if unreachable:
if dep.unreachable(self.all_failed):
- self.depending[msg_id] = [raw_msg,MET,MET,None]
+ self.depending[msg_id] = args
return self.fail_unreachable(msg_id)
- # turn timeouts into datetime objects:
- timeout = header.get('timeout', None)
- if timeout:
- timeout = datetime.now() + timedelta(0,timeout,0)
-
if after.check(self.all_completed, self.all_failed):
# time deps already met, try to run
- if not self.maybe_run(msg_id, raw_msg, follow, timeout):
+ if not self.maybe_run(msg_id, *args):
# can't run yet
- self.save_unmet(msg_id, raw_msg, after, follow, timeout)
+ self.save_unmet(msg_id, *args)
else:
- self.save_unmet(msg_id, raw_msg, after, follow, timeout)
+ self.save_unmet(msg_id, *args)
# @logged
def audit_timeouts(self):
@@ -309,17 +315,18 @@ def audit_timeouts(self):
for msg_id in self.depending.keys():
# must recheck, in case one failure cascaded to another:
if msg_id in self.depending:
- raw,after,follow,timeout = self.depending[msg_id]
+ raw,after,targets,follow,timeout = self.depending[msg_id]
if timeout and timeout < now:
self.fail_unreachable(msg_id, timeout=True)
@logged
def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
- """a message has become unreachable"""
+ """a task has become unreachable, send a reply with an ImpossibleDependency
+ error."""
if msg_id not in self.depending:
self.log.error("msg %r already failed!"%msg_id)
return
- raw_msg, after, follow, timeout = self.depending.pop(msg_id)
+ raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
for mid in follow.union(after):
if mid in self.graph:
self.graph[mid].remove(msg_id)
@@ -344,45 +351,59 @@ def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
self.update_graph(msg_id, success=False)
@logged
- def maybe_run(self, msg_id, raw_msg, follow=None, timeout=None):
+ def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
"""check location dependencies, and run if they are met."""
-
- if follow:
+ blacklist = self.blacklist.setdefault(msg_id, set())
+ if follow or targets or blacklist:
+ # we need a can_run filter
def can_run(idx):
target = self.targets[idx]
- return target not in self.blacklist.get(msg_id, []) and\
- follow.check(self.completed[target], self.failed[target])
+ # check targets
+ if targets and target not in targets:
+ return False
+ # check blacklist
+ if target in blacklist:
+ return False
+ # check follow
+ return follow.check(self.completed[target], self.failed[target])
indices = filter(can_run, range(len(self.targets)))
if not indices:
+ # couldn't run
if follow.all:
+ # check follow for impossibility
dests = set()
relevant = self.all_completed if follow.success_only else self.all_done
for m in follow.intersection(relevant):
dests.add(self.destinations[m])
if len(dests) > 1:
self.fail_unreachable(msg_id)
-
-
+ return False
+ if targets:
+ # check blacklist+targets for impossibility
+ targets.difference_update(blacklist)
+ if not targets or not targets.intersection(self.targets):
+ self.fail_unreachable(msg_id)
+ return False
return False
else:
indices = None
- self.submit_task(msg_id, raw_msg, follow, timeout, indices)
+ self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
return True
@logged
- def save_unmet(self, msg_id, raw_msg, after, follow, timeout):
+ def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
"""Save a message for later submission when its dependencies are met."""
- self.depending[msg_id] = [raw_msg,after,follow,timeout]
+ self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
# track the ids in follow or after, but not those already finished
for dep_id in after.union(follow).difference(self.all_done):
if dep_id not in self.graph:
self.graph[dep_id] = set()
self.graph[dep_id].add(msg_id)
@logged
- def submit_task(self, msg_id, raw_msg, follow, timeout, indices=None):
+ def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
"""Submit a task to any of a subset of our targets."""
if indices:
loads = [self.loads[i] for i in indices]
@@ -396,7 +417,7 @@ def submit_task(self, msg_id, raw_msg, follow, timeout, indices=None):
self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
self.engine_stream.send_multipart(raw_msg, copy=False)
self.add_job(idx)
- self.pending[target][msg_id] = (raw_msg, follow, timeout)
+ self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
content = dict(msg_id=msg_id, engine_id=target)
self.session.send(self.mon_stream, 'task_destination', content=content,
ident=['tracktask',self.session.session])
@@ -406,6 +427,7 @@ def submit_task(self, msg_id, raw_msg, follow, timeout, indices=None):
#-----------------------------------------------------------------------
@logged
def dispatch_result(self, raw_msg):
+ """dispatch method for result replies"""
try:
idents,msg = self.session.feed_identities(raw_msg, copy=False)
msg = self.session.unpack_message(msg, content=False, copy=False)
@@ -424,6 +446,7 @@ def dispatch_result(self, raw_msg):
@logged
def handle_result(self, idents, parent, raw_msg, success=True):
+ """handle a real task result, either success or failure"""
# first, relay result to client
engine = idents[0]
client = idents[1]
@@ -448,21 +471,30 @@ def handle_result(self, idents, parent, raw_msg, success=True):
@logged
def handle_unmet_dependency(self, idents, parent):
+ """handle an unmet dependency"""
engine = idents[0]
msg_id = parent['msg_id']
+
if msg_id not in self.blacklist:
self.blacklist[msg_id] = set()
self.blacklist[msg_id].add(engine)
- raw_msg,follow,timeout = self.pending[engine].pop(msg_id)
- if not self.maybe_run(msg_id, raw_msg, follow, timeout):
+
+ args = self.pending[engine].pop(msg_id)
+ raw,targets,after,follow,timeout = args
+
+ if self.blacklist[msg_id] == targets:
+ self.depending[msg_id] = args
+ return self.fail_unreachable(msg_id)
+
+ elif not self.maybe_run(msg_id, *args):
# resubmit failed, put it back in our dependency tree
- self.save_unmet(msg_id, raw_msg, MET, follow, timeout)
- pass
+ self.save_unmet(msg_id, *args)
+
@logged
def update_graph(self, dep_id, success=True):
"""dep_id just finished. Update our dependency
- table and submit any jobs that just became runable."""
+ graph and submit any jobs that just became runnable."""
# print ("\n\n***********")
# pprint (dep_id)
# pprint (self.graph)
@@ -475,7 +507,7 @@ def update_graph(self, dep_id, success=True):
jobs = self.graph.pop(dep_id)
for msg_id in jobs:
- raw_msg, after, follow, timeout = self.depending[msg_id]
+ raw_msg, targets, after, follow, timeout = self.depending[msg_id]
# if dep_id in after:
# if after.all and (success or not after.success_only):
# after.remove(dep_id)
@@ -484,8 +516,7 @@ def update_graph(self, dep_id, success=True):
self.fail_unreachable(msg_id)
elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
- self.depending[msg_id][1] = MET
- if self.maybe_run(msg_id, raw_msg, follow, timeout):
+ if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
self.depending.pop(msg_id)
for mid in follow.union(after):
View
69 IPython/zmq/parallel/view.py
@@ -66,10 +66,15 @@ class View(object):
Don't use this class, use subclasses.
"""
- _targets = None
block=None
bound=None
history=None
+ outstanding = set()
+ results = {}
+
+ _targets = None
+ _apply_name = 'apply'
+ _default_names = ['targets', 'block']
def __init__(self, client, targets=None):
self.client = client
@@ -80,6 +85,9 @@ def __init__(self, client, targets=None):
self.history = []
self.outstanding = set()
self.results = {}
+ for name in self._default_names:
+ setattr(self, name, getattr(self, name, None))
+
def __repr__(self):
strtargets = str(self._targets)
@@ -95,11 +103,23 @@ def targets(self):
def targets(self, value):
self._targets = value
# raise AttributeError("Cannot set my targets argument after construction!")
-
+
+ def _defaults(self, *excludes):
+ """return dict of our default attributes, excluding names given."""
+ d = {}
+ for name in self._default_names:
+ if name not in excludes:
+ d[name] = getattr(self, name)
+ return d
+
@sync_results
def spin(self):
"""spin the client, and sync"""
self.client.spin()
+
+ @property
+ def _apply(self):
+ return getattr(self.client, self._apply_name)
@sync_results
@save_ids
@@ -113,7 +133,7 @@ def apply(self, f, *args, **kwargs):
else:
returns actual result of f(*args, **kwargs)
"""
- return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=self.bound)
+ return self._apply(f, args, kwargs, **self._defaults())
@save_ids
def apply_async(self, f, *args, **kwargs):
@@ -123,7 +143,8 @@ def apply_async(self, f, *args, **kwargs):
returns msg_id
"""
- return self.client.apply(f,args,kwargs, block=False, targets=self.targets, bound=False)
+ d = self._defaults('block', 'bound')
+ return self._apply(f,args,kwargs, block=False, bound=False, **d)
@spin_after
@save_ids
@@ -135,7 +156,8 @@ def apply_sync(self, f, *args, **kwargs):
returns: actual result of f(*args, **kwargs)
"""
- return self.client.apply(f,args,kwargs, block=True, targets=self.targets, bound=False)
+ d = self._defaults('block', 'bound')
+ return self._apply(f,args,kwargs, block=True, bound=False, **d)
@sync_results
@save_ids
@@ -150,7 +172,8 @@ def apply_bound(self, f, *args, **kwargs):
This method has access to the targets' globals
"""
- return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=True)
+ d = self._defaults('bound')
+ return self._apply(f, args, kwargs, bound=True, **d)
@sync_results
@save_ids
@@ -163,7 +186,8 @@ def apply_async_bound(self, f, *args, **kwargs):
This method has access to the targets' globals
"""
- return self.client.apply(f, args, kwargs, block=False, targets=self.targets, bound=True)
+ d = self._defaults('block', 'bound')
+ return self._apply(f, args, kwargs, block=False, bound=True, **d)
@spin_after
@save_ids
@@ -175,7 +199,8 @@ def apply_sync_bound(self, f, *args, **kwargs):
This method has access to the targets' globals
"""
- return self.client.apply(f, args, kwargs, block=True, targets=self.targets, bound=True)
+ d = self._defaults('block', 'bound')
+ return self._apply(f, args, kwargs, block=True, bound=True, **d)
@spin_after
@save_ids
@@ -337,24 +362,22 @@ class LoadBalancedView(View):
Typically created via:
- >>> lbv = client[None]
- <LoadBalancedView tcp://127.0.0.1:12345>
+ >>> v = client[None]
+ <LoadBalancedView None>
but can also be created with:
- >>> lbc = LoadBalancedView(client)
+ >>> v = client.view([1,3],balanced=True)
+
+ which would restrict loadbalancing to between engines 1 and 3.
- TODO: allow subset of engines across which to balance.
"""
- def __repr__(self):
- return "<%s %s>"%(self.__class__.__name__, self.client._config['url'])
- @property
- def targets(self):
- return None
-
- @targets.setter
- def targets(self, value):
- raise AttributeError("Cannot set targets for LoadbalancedView!")
-
-
+ _apply_name = 'apply_balanced'
+ _default_names = ['targets', 'block', 'bound', 'follow', 'after', 'timeout']
+
+ def __init__(self, client, targets=None):
+ super(LoadBalancedView, self).__init__(client, targets)
+ self._ntargets = 1
+ self._apply_name = 'apply_balanced'
+

0 comments on commit a514d13

Please sign in to comment.