Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

match return shape in AsyncResult to sync results

  • Loading branch information...
commit e9e0d814dce8531935890b0fe59f3450ec55f491 1 parent 7d08ffd
@minrk minrk authored
View
5 IPython/zmq/parallel/asyncresult.py
@@ -21,9 +21,10 @@ class AsyncResult(object):
Provides the same interface as :py:class:`multiprocessing.AsyncResult`.
"""
- def __init__(self, client, msg_ids):
+ def __init__(self, client, msg_ids, targets=None):
self._client = client
self.msg_ids = msg_ids
+ self._targets=targets
self._ready = False
self._success = None
@@ -41,6 +42,8 @@ def _reconstruct_result(self, res):
"""
if len(res) == 1:
return res[0]
+ elif self.targets is not None:
+ return dict(zip(self._targets, res))
else:
return res
View
41 IPython/zmq/parallel/client.py
@@ -632,7 +632,7 @@ def run(self, code, block=None):
whether or not to wait until done
"""
- result = self.apply(execute, (code,), targets=None, block=block, bound=False)
+ result = self.apply(_execute, (code,), targets=None, block=block, bound=False)
return result
def _maybe_raise(self, result):
@@ -721,6 +721,18 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
if not isinstance(kwargs, dict):
raise TypeError("kwargs must be dict, not %s"%type(kwargs))
+ if isinstance(after, Dependency):
+ after = after.as_dict()
+ elif isinstance(after, AsyncResult):
+ after=after.msg_ids
+ elif after is None:
+ after = []
+ if isinstance(follow, Dependency):
+ follow = follow.as_dict()
+ elif isinstance(follow, AsyncResult):
+ follow=follow.msg_ids
+ elif follow is None:
+ follow = []
options = dict(bound=bound, block=block, after=after, follow=follow)
if targets is None:
@@ -732,18 +744,11 @@ def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
after=None, follow=None):
"""The underlying method for applying functions in a load balanced
manner, via the task queue."""
- if isinstance(after, Dependency):
- after = after.as_dict()
- elif after is None:
- after = []
- if isinstance(follow, Dependency):
- follow = follow.as_dict()
- elif follow is None:
- follow = []
- subheader = dict(after=after, follow=follow)
+ subheader = dict(after=after, follow=follow)
bufs = ss.pack_apply_message(f,args,kwargs)
content = dict(bound=bound)
+
msg = self.session.send(self._task_socket, "apply_request",
content=content, buffers=bufs, subheader=subheader)
msg_id = msg['msg_id']
@@ -761,17 +766,11 @@ def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
via the MUX queue."""
queues,targets = self._build_targets(targets)
- bufs = ss.pack_apply_message(f,args,kwargs)
- if isinstance(after, Dependency):
- after = after.as_dict()
- elif after is None:
- after = []
- if isinstance(follow, Dependency):
- follow = follow.as_dict()
- elif follow is None:
- follow = []
+
subheader = dict(after=after, follow=follow)
content = dict(bound=bound)
+ bufs = ss.pack_apply_message(f,args,kwargs)
+
msg_ids = []
for queue in queues:
msg = self.session.send(self._mux_socket, "apply_request",
@@ -783,7 +782,7 @@ def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
if block:
self.barrier(msg_ids)
else:
- return AsyncResult(self, msg_ids)
+ return AsyncResult(self, msg_ids, targets=targets)
if len(msg_ids) == 1:
return self._maybe_raise(self.results[msg_ids[0]])
else:
@@ -850,7 +849,7 @@ def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
else:
r = self.push({key: partition}, targets=engineid, block=False)
msg_ids.extend(r.msg_ids)
- r = AsyncResult(self, msg_ids)
+ r = AsyncResult(self, msg_ids,targets)
if block:
return r.get()
else:
View
8 IPython/zmq/parallel/view.py
@@ -263,21 +263,19 @@ def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None):
Partition a Python sequence and send the partitions to a set of engines.
"""
block = block if block is not None else self.block
- if targets is None:
- targets = self.targets
+ targets = targets if targets is not None else self.targets
return self.client.scatter(key, seq, dist=dist, flatten=flatten,
targets=targets, block=block)
@sync_results
@save_ids
- def gather(self, key, dist='b', targets=None, block=True):
+ def gather(self, key, dist='b', targets=None, block=None):
"""
Gather a partitioned sequence on a set of engines as a single local seq.
"""
block = block if block is not None else self.block
- if targets is None:
- targets = self.targets
+ targets = targets if targets is not None else self.targets
return self.client.gather(key, dist=dist, targets=targets, block=block)
Please sign in to comment.
Something went wrong with that request. Please try again.