Permalink
Browse files

multitarget returns list instead of dict

  • Loading branch information...
1 parent 39eab52 commit 643109253cbc45b7089024151e0a803919d0870f @minrk minrk committed Jan 28, 2011
Showing with 14 additions and 23 deletions.
  1. +6 −8 IPython/zmq/parallel/asyncresult.py
  2. +8 −15 IPython/zmq/parallel/client.py
@@ -21,18 +21,18 @@ class AsyncResult(object):
Provides the same interface as :py:class:`multiprocessing.AsyncResult`.
"""
- def __init__(self, client, msg_ids, targets=None):
+ def __init__(self, client, msg_ids, fname=''):
self._client = client
self.msg_ids = msg_ids
- self._targets=targets
+ self._fname=fname
self._ready = False
self._success = None
def __repr__(self):
if self._ready:
return "<%s: finished>"%(self.__class__.__name__)
else:
- return "<%s: %r>"%(self.__class__.__name__,self.msg_ids)
+ return "<%s: %s>"%(self.__class__.__name__,self._fname)
def _reconstruct_result(self, res):
@@ -42,8 +42,6 @@ def _reconstruct_result(self, res):
"""
if len(res) == 1:
return res[0]
- elif self.targets is not None:
- return dict(zip(self._targets, res))
else:
return res
@@ -81,7 +79,7 @@ def wait(self, timeout=-1):
if self._ready:
try:
results = map(self._client.results.get, self.msg_ids)
- results = error.collect_exceptions(results, 'get')
+ results = error.collect_exceptions(results, self._fname)
self._result = self._reconstruct_result(results)
except Exception, e:
self._exception = e
@@ -104,9 +102,9 @@ class AsyncMapResult(AsyncResult):
This will properly reconstruct the gather.
"""
- def __init__(self, client, msg_ids, mapObject):
+ def __init__(self, client, msg_ids, mapObject, fname=''):
self._mapObject = mapObject
- AsyncResult.__init__(self, client, msg_ids)
+ AsyncResult.__init__(self, client, msg_ids, fname=fname)
def _reconstruct_result(self, res):
"""Perform the gather on the actual results."""
@@ -83,7 +83,6 @@ def defaultblock(f, self, *args, **kwargs):
self.block = saveblock
return ret
-
class AbortedTask(object):
"""A basic wrapper object describing an aborted task."""
def __init__(self, msg_id):
@@ -754,11 +753,11 @@ def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
msg_id = msg['msg_id']
self.outstanding.add(msg_id)
self.history.append(msg_id)
+ ar = AsyncResult(self, [msg_id], fname=f.__name__)
if block:
- self.barrier(msg_id)
- return self._maybe_raise(self.results[msg_id])
+ return ar.get()
else:
- return AsyncResult(self, [msg_id])
+ return ar
def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
after=None, follow=None):
@@ -779,17 +778,11 @@ def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
self.outstanding.add(msg_id)
self.history.append(msg_id)
msg_ids.append(msg_id)
+ ar = AsyncResult(self, msg_ids, fname=f.__name__)
if block:
- self.barrier(msg_ids)
- else:
- return AsyncResult(self, msg_ids, targets=targets)
- if len(msg_ids) == 1:
- return self._maybe_raise(self.results[msg_ids[0]])
+ return ar.get()
else:
- result = {}
- for target,mid in zip(targets, msg_ids):
- result[target] = self.results[mid]
- return error.collect_exceptions(result, f.__name__)
+ return ar
#--------------------------------------------------------------------------
# Map and decorators
@@ -849,7 +842,7 @@ def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
else:
r = self.push({key: partition}, targets=engineid, block=False)
msg_ids.extend(r.msg_ids)
- r = AsyncResult(self, msg_ids,targets)
+ r = AsyncResult(self, msg_ids, fname='scatter')
if block:
return r.get()
else:
@@ -867,7 +860,7 @@ def gather(self, key, dist='b', targets='all', block=None):
for index, engineid in enumerate(targets):
msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
- r = AsyncMapResult(self, msg_ids, mapObject)
+ r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
if block:
return r.get()
else:

0 comments on commit 6431092

Please sign in to comment.