Skip to content

Commit

Permalink
add map/scatter/gather/ParallelFunction from kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk committed Apr 8, 2011
1 parent a4b0811 commit 15b7567
Show file tree
Hide file tree
Showing 4 changed files with 371 additions and 11 deletions.
193 changes: 185 additions & 8 deletions IPython/zmq/parallel/client.py
Expand Up @@ -28,6 +28,7 @@
from view import DirectView, LoadBalancedView
from dependency import Dependency, depend, require
import error
import map as Map

#--------------------------------------------------------------------------
# helpers for implementing old MEC API via client.apply
Expand Down Expand Up @@ -92,6 +93,18 @@ def remote_function(f):
return RemoteFunction(client, f, bound, block, targets)
return remote_function

def parallel(client, dist='b', bound=False, block=None, targets='all'):
"""Turn a function into a parallel remote function.
This method can be used for map:
>>> @parallel(client,block=True)
def func(a)
"""
def parallel_function(f):
return ParallelFunction(client, f, dist, bound, block, targets)
return parallel_function

#--------------------------------------------------------------------------
# Classes
#--------------------------------------------------------------------------
Expand Down Expand Up @@ -133,6 +146,103 @@ def __call__(self, *args, **kwargs):
block=self.block, targets=self.targets, bound=self.bound)


class ParallelFunction(RemoteFunction):
"""Class for mapping a function to sequences."""
def __init__(self, client, f, dist='b', bound=False, block=None, targets='all'):
super(ParallelFunction, self).__init__(client,f,bound,block,targets)
mapClass = Map.dists[dist]
self.mapObject = mapClass()

def __call__(self, *sequences):
len_0 = len(sequences[0])
for s in sequences:
if len(s)!=len_0:
raise ValueError('all sequences must have equal length')

if self.targets is None:
# load-balanced:
engines = [None]*len_0
else:
# multiplexed:
engines = self.client._build_targets(self.targets)[-1]

nparts = len(engines)
msg_ids = []
for index, engineid in enumerate(engines):
args = []
for seq in sequences:
args.append(self.mapObject.getPartition(seq, index, nparts))
mid = self.client.apply(self.func, args=args, block=False,
bound=self.bound,
targets=engineid)
msg_ids.append(mid)

if self.block:
dg = PendingMapResult(self.client, msg_ids, self.mapObject)
dg.wait()
return dg.result
else:
return dg


class PendingResult(object):
"""Class for representing results of non-blocking calls."""
def __init__(self, client, msg_ids):
self.client = client
self.msg_ids = msg_ids
self._result = None
self.done = False

def __repr__(self):
if self.done:
return "<%s: finished>"%(self.__class__.__name__)
else:
return "<%s: %r>"%(self.__class__.__name__,self.msg_ids)

@property
def result(self):
if self._result is not None:
return self._result
if not self.done:
self.wait(0)
if self.done:
results = map(self.client.results.get, self.msg_ids)
results = error.collect_exceptions(results, 'get_result')
self._result = self.reconstruct_result(results)
return self._result
else:
raise error.ResultNotCompleted

def reconstruct_result(self, res):
"""
Override me in subclasses for turning a list of results
into the expected form.
"""
if len(res) == 1:
return res[0]
else:
return res

def wait(self, timout=-1):
self.done = self.client.barrier(self.msg_ids)
return self.done

class PendingMapResult(PendingResult):
"""Class for representing results of non-blocking gathers.
This will properly reconstruct the gather.
"""

def __init__(self, client, msg_ids, mapObject):
self.mapObject = mapObject
PendingResult.__init__(self, client, msg_ids)

def reconstruct_result(self, res):
"""Perform the gather on the actual results."""
return self.mapObject.joinPartitions(res)



class AbortedTask(object):
"""A basic wrapper object describing an aborted task."""
def __init__(self, msg_id):
Expand Down Expand Up @@ -498,6 +608,17 @@ def __getitem__(self, key):
# Begin public methods
#--------------------------------------------------------------------------

@property
def remote(self):
"""property for convenient RemoteFunction generation.
>>> @client.remote
... def f():
import os
print (os.getpid())
"""
return remote(self, block=self.block)

def spin(self):
"""Flush any registration notifications and execution results
waiting in the ZMQ queue.
Expand Down Expand Up @@ -784,7 +905,7 @@ def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
self.barrier(msg_id)
return self._maybe_raise(self.results[msg_id])
else:
return msg_id
return PendingResult(self, [msg_id])

def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
after=None, follow=None):
Expand Down Expand Up @@ -814,10 +935,7 @@ def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
if block:
self.barrier(msg_ids)
else:
if len(msg_ids) == 1:
return msg_ids[0]
else:
return msg_ids
return PendingResult(self, msg_ids)
if len(msg_ids) == 1:
return self._maybe_raise(self.results[msg_ids[0]])
else:
Expand All @@ -826,20 +944,25 @@ def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
result[target] = self.results[mid]
return error.collect_exceptions(result, f.__name__)

@defaultblock
def map(self, f, sequences, targets=None, block=None, bound=False):
pf = ParallelFunction(self,f,block=block,bound=bound,targets=targets)
return pf(*sequences)

#--------------------------------------------------------------------------
# Data movement
#--------------------------------------------------------------------------

@defaultblock
def push(self, ns, targets=None, block=None):
def push(self, ns, targets='all', block=None):
"""Push the contents of `ns` into the namespace on `target`"""
if not isinstance(ns, dict):
raise TypeError("Must be a dict, not %s"%type(ns))
result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
return result

@defaultblock
def pull(self, keys, targets=None, block=True):
def pull(self, keys, targets='all', block=True):
"""Pull objects from `target`'s namespace by `keys`"""
if isinstance(keys, str):
pass
Expand All @@ -850,6 +973,48 @@ def pull(self, keys, targets=None, block=True):
result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
return result

@defaultblock
def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
"""
Partition a Python sequence and send the partitions to a set of engines.
"""
targets = self._build_targets(targets)[-1]
mapObject = Map.dists[dist]()
nparts = len(targets)
msg_ids = []
for index, engineid in enumerate(targets):
partition = mapObject.getPartition(seq, index, nparts)
if flatten and len(partition) == 1:
mid = self.push({key: partition[0]}, targets=engineid, block=False)
else:
mid = self.push({key: partition}, targets=engineid, block=False)
msg_ids.append(mid)
r = PendingResult(self, msg_ids)
if block:
r.wait()
return
else:
return r

@defaultblock
def gather(self, key, dist='b', targets='all', block=True):
"""
Gather a partitioned sequence on a set of engines as a single local seq.
"""

targets = self._build_targets(targets)[-1]
mapObject = Map.dists[dist]()
msg_ids = []
for index, engineid in enumerate(targets):
msg_ids.append(self.pull(key, targets=engineid,block=False))

r = PendingMapResult(self, msg_ids, mapObject)
if block:
r.wait()
return r.result
else:
return r

#--------------------------------------------------------------------------
# Query methods
#--------------------------------------------------------------------------
Expand Down Expand Up @@ -985,4 +1150,16 @@ def spin(self):
for stream in (self.queue_stream, self.notifier_stream,
self.task_stream, self.control_stream):
stream.flush()


__all__ = [ 'Client',
'depend',
'require',
'remote',
'parallel',
'RemoteFunction',
'ParallelFunction',
'DirectView',
'LoadBalancedView',
'PendingResult',
'PendingMapResult'
]
10 changes: 7 additions & 3 deletions IPython/zmq/parallel/error.py
Expand Up @@ -247,11 +247,15 @@ def raise_exception(self, excid=0):
et,ev,tb = sys.exc_info()


def collect_exceptions(rdict, method):
def collect_exceptions(rdict_or_list, method):
"""check a result dict for errors, and raise CompositeError if any exist.
Passthrough otherwise."""
elist = []
for r in rdict.values():
if isinstance(rdict_or_list, dict):
rlist = rdict_or_list.values()
else:
rlist = rdict_or_list
for r in rlist:
if isinstance(r, RemoteError):
en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
# Sometimes we could have CompositeError in our list. Just take
Expand All @@ -264,7 +268,7 @@ def collect_exceptions(rdict, method):
else:
elist.append((en, ev, etb, ei))
if len(elist)==0:
return rdict
return rdict_or_list
else:
msg = "one or more exceptions from call to method: %s" % (method)
# This silliness is needed so the debugger has access to the exception
Expand Down

0 comments on commit 15b7567

Please sign in to comment.