Permalink
Browse files

Client -> HasTraits, update examples with API tweaks

  • Loading branch information...
1 parent 498a93f commit 154798bf001af5a3e6106564b5b74ae4d0338b4e @minrk minrk committed Feb 26, 2011
@@ -24,6 +24,8 @@
# from zmq.eventloop import ioloop, zmqstream
from IPython.utils.path import get_ipython_dir
+from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
+ Dict, List, Bool, Str, Set)
from IPython.external.decorator import decorator
from IPython.external.ssh import tunnel
@@ -147,7 +149,7 @@ def __setitem__(self, key, value):
raise KeyError(key)
-class Client(object):
+class Client(HasTraits):
"""A semi-synchronous client to the IPython ZMQ controller
Parameters
@@ -247,31 +249,41 @@ class Client(object):
"""
- _connected=False
- _ssh=False
- _engines=None
- _registration_socket=None
- _query_socket=None
- _control_socket=None
- _iopub_socket=None
- _notification_socket=None
- _mux_socket=None
- _task_socket=None
- _task_scheme=None
- block = False
- outstanding=None
- results = None
- history = None
- debug = False
- targets = None
+ block = Bool(False)
+ outstanding=Set()
+ results = Dict()
+ metadata = Dict()
+ history = List()
+ debug = Bool(False)
+ profile=CUnicode('default')
+
+ _ids = List()
+ _connected=Bool(False)
+ _ssh=Bool(False)
+ _context = Instance('zmq.Context')
+ _config = Dict()
+ _engines=Instance(ReverseDict, (), {})
+ _registration_socket=Instance('zmq.Socket')
+ _query_socket=Instance('zmq.Socket')
+ _control_socket=Instance('zmq.Socket')
+ _iopub_socket=Instance('zmq.Socket')
+ _notification_socket=Instance('zmq.Socket')
+ _mux_socket=Instance('zmq.Socket')
+ _task_socket=Instance('zmq.Socket')
+ _task_scheme=Str()
+ _balanced_views=Dict()
+ _direct_views=Dict()
+ _closed = False
def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
context=None, username=None, debug=False, exec_key=None,
sshserver=None, sshkey=None, password=None, paramiko=None,
):
+ super(Client, self).__init__(debug=debug, profile=profile)
if context is None:
context = zmq.Context()
- self.context = context
+ self._context = context
+
self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
if self._cd is not None:
@@ -325,20 +337,14 @@ def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipytho
self.session = ss.StreamSession(**key_arg)
else:
self.session = ss.StreamSession(username, **key_arg)
- self._registration_socket = self.context.socket(zmq.XREQ)
+ self._registration_socket = self._context.socket(zmq.XREQ)
self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
if self._ssh:
tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
else:
self._registration_socket.connect(url)
- self._engines = ReverseDict()
- self._ids = []
- self.outstanding=set()
- self.results = {}
- self.metadata = {}
- self.history = []
- self.debug = debug
- self.session.debug = debug
+
+ self.session.debug = self.debug
self._notification_handlers = {'registration_notification' : self._register_engine,
'unregistration_notification' : self._unregister_engine,
@@ -370,6 +376,14 @@ def ids(self):
"""Always up-to-date ids property."""
self._flush_notifications()
return self._ids
+
+ def close(self):
+ if self._closed:
+ return
+ snames = filter(lambda n: n.endswith('socket'), dir(self))
+ for socket in map(lambda name: getattr(self, name), snames):
+ socket.close()
+ self._closed = True
def _update_engines(self, engines):
"""Update our engines dict and _ids from a dict of the form: {id:uuid}."""
@@ -436,28 +450,28 @@ def connect_socket(s, url):
self._config['registration'] = dict(content)
if content.status == 'ok':
if content.mux:
- self._mux_socket = self.context.socket(zmq.PAIR)
+ self._mux_socket = self._context.socket(zmq.PAIR)
self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
connect_socket(self._mux_socket, content.mux)
if content.task:
self._task_scheme, task_addr = content.task
- self._task_socket = self.context.socket(zmq.PAIR)
+ self._task_socket = self._context.socket(zmq.PAIR)
self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
connect_socket(self._task_socket, task_addr)
if content.notification:
- self._notification_socket = self.context.socket(zmq.SUB)
+ self._notification_socket = self._context.socket(zmq.SUB)
connect_socket(self._notification_socket, content.notification)
self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
if content.query:
- self._query_socket = self.context.socket(zmq.PAIR)
+ self._query_socket = self._context.socket(zmq.PAIR)
self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
connect_socket(self._query_socket, content.query)
if content.control:
- self._control_socket = self.context.socket(zmq.PAIR)
+ self._control_socket = self._context.socket(zmq.PAIR)
self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
connect_socket(self._control_socket, content.control)
if content.iopub:
- self._iopub_socket = self.context.socket(zmq.SUB)
+ self._iopub_socket = self._context.socket(zmq.SUB)
self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
connect_socket(self._iopub_socket, content.iopub)
@@ -636,9 +650,13 @@ def _flush_iopub(self, sock):
msg = self.session.recv(sock, mode=zmq.NOBLOCK)
#--------------------------------------------------------------------------
- # getitem
+ # len, getitem
#--------------------------------------------------------------------------
+ def __len__(self):
+ """len(client) returns the number of engines."""
+ return len(self.ids)
+
def __getitem__(self, key):
"""index access returns DirectView multiplexer objects
@@ -929,8 +947,9 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None,
else:
return list of results, matching `targets`
"""
-
+ assert not self._closed, "cannot use me anymore, I'm closed!"
# defaults:
+ block = block if block is not None else self.block
args = args if args is not None else []
kwargs = kwargs if kwargs is not None else {}
@@ -955,7 +974,7 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None,
raise TypeError("kwargs must be dict, not %s"%type(kwargs))
options = dict(bound=bound, block=block, targets=targets)
-
+
if balanced:
return self._apply_balanced(f, args, kwargs, timeout=timeout,
after=after, follow=follow, **options)
@@ -966,16 +985,17 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None,
else:
return self._apply_direct(f, args, kwargs, **options)
- def _apply_balanced(self, f, args, kwargs, bound=True, block=None, targets=None,
+ def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
after=None, follow=None, timeout=None):
"""call f(*args, **kwargs) remotely in a load-balanced manner.
This is a private method, see `apply` for details.
Not to be called directly!
"""
- for kwarg in (bound, block, targets):
- assert kwarg is not None, "kwarg %r must be specified!"%kwarg
+ loc = locals()
+ for name in ('bound', 'block'):
+ assert loc[name] is not None, "kwarg %r must be specified!"%name
if self._task_socket is None:
msg = "Task farming is disabled"
@@ -1030,9 +1050,9 @@ def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
This is a private method, see `apply` for details.
Not to be called directly!
"""
-
- for kwarg in (bound, block, targets):
- assert kwarg is not None, "kwarg %r must be specified!"%kwarg
+ loc = locals()
+ for name in ('bound', 'block', 'targets'):
+ assert loc[name] is not None, "kwarg %r must be specified!"%name
idents,targets = self._build_targets(targets)
@@ -1058,35 +1078,65 @@ def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
return ar
#--------------------------------------------------------------------------
- # decorators
+ # construct a View object
#--------------------------------------------------------------------------
@defaultblock
- def parallel(self, bound=True, targets='all', block=None):
- """Decorator for making a ParallelFunction."""
- return parallel(self, bound=bound, targets=targets, block=block)
+ def remote(self, bound=True, block=None, targets=None, balanced=None):
+ """Decorator for making a RemoteFunction"""
+ return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
@defaultblock
- def remote(self, bound=True, targets='all', block=None):
- """Decorator for making a RemoteFunction."""
- return remote(self, bound=bound, targets=targets, block=block)
+ def parallel(self, dist='b', bound=True, block=None, targets=None, balanced=None):
+ """Decorator for making a ParallelFunction"""
+ return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
- def view(self, targets=None, balanced=False):
- """Method for constructing View objects"""
+ def _cache_view(self, targets, balanced):
+ """Cache views, so subsequent requests for the same targets reuse the existing object."""
+ if balanced:
+ view_class = LoadBalancedView
+ view_cache = self._balanced_views
+ else:
+ view_class = DirectView
+ view_cache = self._direct_views
+
+ # key the cache on str(targets), since targets is often a list (unhashable)
+ key = str(targets)
+ if key not in view_cache:
+ view_cache[key] = view_class(client=self, targets=targets)
+
+ return view_cache[key]
+
+ def view(self, targets=None, balanced=None):
+ """Method for constructing View objects.
+
+ If no arguments are specified, create a LoadBalancedView
+ using all engines. If only `targets` is specified, create
+ a DirectView. This method is the underlying implementation
+ of ``client.__getitem__``.
+
+ Parameters
+ ----------
+
+ targets : list, slice, int, etc. [default: use all engines]
+ The engines to use for the View
+ balanced : bool [default: False if targets is specified, True otherwise]
+ Whether to build a LoadBalancedView or a DirectView
+
+ """
+
+ balanced = (targets is None) if balanced is None else balanced
+
if targets is None:
if balanced:
- return LoadBalancedView(client=self)
+ return self._cache_view(None,True)
else:
targets = slice(None)
- if balanced:
- view_class = LoadBalancedView
- else:
- view_class = DirectView
if isinstance(targets, int):
if targets not in self.ids:
raise IndexError("No such engine: %i"%targets)
- return view_class(client=self, targets=targets)
+ return self._cache_view(targets, balanced)
if isinstance(targets, slice):
indices = range(len(self.ids))[targets]
@@ -1095,7 +1145,7 @@ def view(self, targets=None, balanced=False):
if isinstance(targets, (tuple, list, xrange)):
_,targets = self._build_targets(list(targets))
- return view_class(client=self, targets=targets)
+ return self._cache_view(targets, balanced)
else:
raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
@@ -10,14 +10,16 @@
# Imports
#-----------------------------------------------------------------------------
+import warnings
+
import map as Map
from asyncresult import AsyncMapResult
#-----------------------------------------------------------------------------
# Decorators
#-----------------------------------------------------------------------------
-def remote(client, bound=False, block=None, targets=None, balanced=None):
+def remote(client, bound=True, block=None, targets=None, balanced=None):
"""Turn a function into a remote function.
This method can be used for map:
@@ -29,7 +31,7 @@ def remote_function(f):
return RemoteFunction(client, f, bound, block, targets, balanced)
return remote_function
-def parallel(client, dist='b', bound=False, block=None, targets='all', balanced=None):
+def parallel(client, dist='b', bound=True, block=None, targets='all', balanced=None):
"""Turn a function into a parallel remote function.
This method can be used for map:
@@ -93,8 +95,10 @@ def __call__(self, *args, **kwargs):
class ParallelFunction(RemoteFunction):
"""Class for mapping a function to sequences."""
- def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None):
+ def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None, chunk_size=None):
super(ParallelFunction, self).__init__(client,f,bound,block,targets,balanced)
+ self.chunk_size = chunk_size
+
mapClass = Map.dists[dist]
self.mapObject = mapClass()
@@ -106,12 +110,18 @@ def __call__(self, *sequences):
raise ValueError(msg)
if self.balanced:
- targets = [self.targets]*len_0
+ if self.chunk_size:
+ nparts = len_0/self.chunk_size + int(len_0%self.chunk_size > 0)
+ else:
+ nparts = len_0
+ targets = [self.targets]*nparts
else:
+ if self.chunk_size:
+ warnings.warn("`chunk_size` is ignored when `balanced=False", UserWarning)
# multiplexed:
targets = self.client._build_targets(self.targets)[-1]
+ nparts = len(targets)
- nparts = len(targets)
msg_ids = []
# my_f = lambda *a: map(self.func, *a)
for index, t in enumerate(targets):
@@ -132,7 +142,7 @@ def __call__(self, *sequences):
else:
f=self.func
ar = self.client.apply(f, args=args, block=False, bound=self.bound,
- targets=targets, balanced=self.balanced)
+ targets=t, balanced=self.balanced)
msg_ids.append(ar.msg_ids[0])
Oops, something went wrong.

0 comments on commit 154798b

Please sign in to comment.