
major cleanup of client code + purge_request implemented

1 parent 9715461 · commit 077c1af7d76b072fbdf39213493df7e8c85c6609 · committed by @minrk
Showing with 362 additions and 200 deletions.
  1. +362 −200 IPython/zmq/parallel/client.py
562 IPython/zmq/parallel/client.py
@@ -30,7 +30,7 @@ def _push(ns):
def _pull(keys):
g = globals()
- if isinstance(keys, (list,tuple)):
+ if isinstance(keys, (list,tuple, set)):
return map(g.get, keys)
else:
return g.get(keys)
@@ -41,14 +41,19 @@ def _clear():
def execute(code):
exec code in globals()
-# decorators for methods:
+#--------------------------------------------------------------------------
+# Decorators for Client methods
+#--------------------------------------------------------------------------
+
@decorator
-def spinfirst(f,self,*args,**kwargs):
+def spinfirst(f, self, *args, **kwargs):
+ """Call spin() to sync state prior to calling the method."""
self.spin()
return f(self, *args, **kwargs)
@decorator
def defaultblock(f, self, *args, **kwargs):
+ """Default to self.block; preserve self.block."""
block = kwargs.get('block',None)
block = self.block if block is None else block
saveblock = self.block
@@ -57,45 +62,46 @@ def defaultblock(f, self, *args, **kwargs):
self.block = saveblock
return ret
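
A hedged aside on the decorator above: because the wrapped call sits between the save and the restore, an exception raised by f would leave self.block modified. A try/finally variant (not what this commit does, shown only as a sketch) would restore it unconditionally:

    from decorator import decorator

    @decorator
    def defaultblock(f, self, *args, **kwargs):
        # default to self.block when no block kwarg is given, restore afterwards
        block = kwargs.get('block', None)
        block = self.block if block is None else block
        saveblock = self.block
        self.block = block
        try:
            return f(self, *args, **kwargs)
        finally:
            self.block = saveblock
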
+#--------------------------------------------------------------------------
+# Classes
+#--------------------------------------------------------------------------
+
+
class AbortedTask(object):
+ """A basic wrapper object describing an aborted task."""
def __init__(self, msg_id):
self.msg_id = msg_id
-# @decorator
-# def checktargets(f):
-# @wraps(f)
-# def checked_method(self, *args, **kwargs):
-# self._build_targets(kwargs['targets'])
-# return f(self, *args, **kwargs)
-# return checked_method
-
-# class _ZMQEventLoopThread(threading.Thread):
-#
-# def __init__(self, loop):
-# self.loop = loop
-# threading.Thread.__init__(self)
-#
-# def run(self):
-# self.loop.start()
-#
+class ControllerError(Exception):
+ def __init__(self, etype, evalue, tb):
+ self.etype = etype
+ self.evalue = evalue
+ self.traceback=tb
+
class Client(object):
"""A semi-synchronous client to the IPython ZMQ controller
+ Parameters
+ ----------
+
+ addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
+ The address of the controller's registration socket.
+
+
Attributes
----------
- ids : set
- a set of engine IDs
+ ids : set of int engine IDs
requesting the ids attribute always synchronizes
the registration state. To request ids without synchronization,
- use _ids
+ use semi-private _ids.
history : list of msg_ids
a list of msg_ids, keeping track of all the execution
- messages you have submitted
+ messages you have submitted in order.
outstanding : set of msg_ids
a set of msg_ids that have been submitted, but whose
- results have not been received
+ results have not yet been received.
results : dict
a dict of all our results, keyed by msg_id
@@ -111,44 +117,43 @@ class Client(object):
barrier : wait on one or more msg_ids
- execution methods: apply/apply_bound/apply_to
+ execution methods: apply/apply_bound/apply_to/apply_bound_to
legacy: execute, run
- query methods: queue_status, get_result
+ query methods: queue_status, get_result, purge
control methods: abort, kill
-
-
"""
_connected=False
_engines=None
- registration_socket=None
- query_socket=None
- control_socket=None
- notification_socket=None
- queue_socket=None
- task_socket=None
+ _addr='tcp://127.0.0.1:10101'
+ _registration_socket=None
+ _query_socket=None
+ _control_socket=None
+ _notification_socket=None
+ _mux_socket=None
+ _task_socket=None
block = False
outstanding=None
results = None
history = None
debug = False
- def __init__(self, addr, context=None, username=None, debug=False):
+ def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False):
if context is None:
context = zmq.Context()
self.context = context
- self.addr = addr
+ self._addr = addr
if username is None:
self.session = ss.StreamSession()
else:
self.session = ss.StreamSession(username)
- self.registration_socket = self.context.socket(zmq.PAIR)
- self.registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
- self.registration_socket.connect(addr)
+ self._registration_socket = self.context.socket(zmq.PAIR)
+ self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
+ self._registration_socket.connect(addr)
self._engines = {}
self._ids = set()
self.outstanding=set()
@@ -167,16 +172,21 @@ def __init__(self, addr, context=None, username=None, debug=False):
@property
def ids(self):
+ """Always up to date ids property."""
self._flush_notifications()
return self._ids
def _update_engines(self, engines):
+ """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
for k,v in engines.iteritems():
eid = int(k)
self._engines[eid] = bytes(v) # force not unicode
self._ids.add(eid)
def _build_targets(self, targets):
+ """Turn valid target IDs or 'all' into two lists:
+ (int_ids, uuids).
+ """
if targets is None:
targets = self._ids
elif isinstance(targets, str):
@@ -189,45 +199,50 @@ def _build_targets(self, targets):
return [self._engines[t] for t in targets], list(targets)
def _connect(self):
- """setup all our socket connections to the controller"""
+ """setup all our socket connections to the controller. This is called from
+ __init__."""
if self._connected:
return
self._connected=True
- self.session.send(self.registration_socket, 'connection_request')
- idents,msg = self.session.recv(self.registration_socket,mode=0)
+ self.session.send(self._registration_socket, 'connection_request')
+ idents,msg = self.session.recv(self._registration_socket,mode=0)
if self.debug:
pprint(msg)
msg = ss.Message(msg)
content = msg.content
if content.status == 'ok':
if content.queue:
- self.queue_socket = self.context.socket(zmq.PAIR)
- self.queue_socket.setsockopt(zmq.IDENTITY, self.session.session)
- self.queue_socket.connect(content.queue)
+ self._mux_socket = self.context.socket(zmq.PAIR)
+ self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
+ self._mux_socket.connect(content.queue)
if content.task:
- self.task_socket = self.context.socket(zmq.PAIR)
- self.task_socket.setsockopt(zmq.IDENTITY, self.session.session)
- self.task_socket.connect(content.task)
+ self._task_socket = self.context.socket(zmq.PAIR)
+ self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
+ self._task_socket.connect(content.task)
if content.notification:
- self.notification_socket = self.context.socket(zmq.SUB)
- self.notification_socket.connect(content.notification)
- self.notification_socket.setsockopt(zmq.SUBSCRIBE, "")
+ self._notification_socket = self.context.socket(zmq.SUB)
+ self._notification_socket.connect(content.notification)
+ self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
if content.query:
- self.query_socket = self.context.socket(zmq.PAIR)
- self.query_socket.setsockopt(zmq.IDENTITY, self.session.session)
- self.query_socket.connect(content.query)
+ self._query_socket = self.context.socket(zmq.PAIR)
+ self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
+ self._query_socket.connect(content.query)
if content.control:
- self.control_socket = self.context.socket(zmq.PAIR)
- self.control_socket.setsockopt(zmq.IDENTITY, self.session.session)
- self.control_socket.connect(content.control)
+ self._control_socket = self.context.socket(zmq.PAIR)
+ self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
+ self._control_socket.connect(content.control)
self._update_engines(dict(content.engines))
else:
self._connected = False
raise Exception("Failed to connect!")
- #### handlers and callbacks for incoming messages #######
+ #--------------------------------------------------------------------------
+ # handlers and callbacks for incoming messages
+ #--------------------------------------------------------------------------
+
def _register_engine(self, msg):
+ """Register a new engine, and update our connection info."""
content = msg['content']
eid = content['id']
d = {eid : content['queue']}
@@ -235,7 +250,7 @@ def _register_engine(self, msg):
self._ids.add(int(eid))
def _unregister_engine(self, msg):
- # print 'unregister',msg
+ """Unregister an engine that has died."""
content = msg['content']
eid = int(content['id'])
if eid in self._ids:
@@ -243,7 +258,7 @@ def _unregister_engine(self, msg):
self._engines.pop(eid)
def _handle_execute_reply(self, msg):
- # msg_id = msg['msg_id']
+ """Save the reply to an execute_request into our results."""
parent = msg['parent_header']
msg_id = parent['msg_id']
if msg_id not in self.outstanding:
@@ -253,8 +268,7 @@ def _handle_execute_reply(self, msg):
self.results[msg_id] = ss.unwrap_exception(msg['content'])
def _handle_apply_reply(self, msg):
- # pprint(msg)
- # msg_id = msg['msg_id']
+ """Save the reply to an apply_request into our results."""
parent = msg['parent_header']
msg_id = parent['msg_id']
if msg_id not in self.outstanding:
@@ -272,8 +286,9 @@ def _handle_apply_reply(self, msg):
self.results[msg_id] = ss.unwrap_exception(content)
def _flush_notifications(self):
- "flush incoming notifications of engine registrations"
- msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK)
+ """Flush notifications of engine registrations waiting
+ in ZMQ queue."""
+ msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
while msg is not None:
if self.debug:
pprint(msg)
@@ -284,10 +299,10 @@ def _flush_notifications(self):
raise Exception("Unhandled message type: %s"%msg.msg_type)
else:
handler(msg)
- msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK)
+ msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
def _flush_results(self, sock):
- "flush incoming task or queue results"
+ """Flush task or queue results waiting in ZMQ queue."""
msg = self.session.recv(sock, mode=zmq.NOBLOCK)
while msg is not None:
if self.debug:
@@ -302,16 +317,20 @@ def _flush_results(self, sock):
msg = self.session.recv(sock, mode=zmq.NOBLOCK)
def _flush_control(self, sock):
- "flush incoming control replies"
+ """Flush replies from the control channel waiting
+ in the ZMQ queue."""
msg = self.session.recv(sock, mode=zmq.NOBLOCK)
while msg is not None:
if self.debug:
pprint(msg)
msg = self.session.recv(sock, mode=zmq.NOBLOCK)
- ###### get/setitem ########
+ #--------------------------------------------------------------------------
+ # getitem
+ #--------------------------------------------------------------------------
def __getitem__(self, key):
+ """Dict access returns DirectView multiplexer objects."""
if isinstance(key, int):
if key not in self.ids:
raise IndexError("No such engine: %i"%key)
@@ -329,51 +348,75 @@ def __getitem__(self, key):
else:
raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
- ############ begin real methods #############
+ #--------------------------------------------------------------------------
+ # Begin public methods
+ #--------------------------------------------------------------------------
def spin(self):
- """flush incoming notifications and execution results."""
- if self.notification_socket:
+ """Flush any registration notifications and execution results
+ waiting in the ZMQ queue.
+ """
+ if self._notification_socket:
self._flush_notifications()
- if self.queue_socket:
- self._flush_results(self.queue_socket)
- if self.task_socket:
- self._flush_results(self.task_socket)
- if self.control_socket:
- self._flush_control(self.control_socket)
+ if self._mux_socket:
+ self._flush_results(self._mux_socket)
+ if self._task_socket:
+ self._flush_results(self._task_socket)
+ if self._control_socket:
+ self._flush_control(self._control_socket)
- @spinfirst
- def queue_status(self, targets=None, verbose=False):
- """fetch the status of engine queues
+ def barrier(self, msg_ids=None, timeout=-1):
+ """waits on one or more `msg_ids`, for up to `timeout` seconds.
Parameters
----------
- targets : int/str/list of ints/strs
- the engines on which to execute
- default : all
- verbose : bool
- whether to return lengths only, or lists of ids for each element
-
- """
- targets = self._build_targets(targets)[1]
- content = dict(targets=targets)
- self.session.send(self.query_socket, "queue_request", content=content)
- idents,msg = self.session.recv(self.query_socket, 0)
- if self.debug:
- pprint(msg)
- return msg['content']
+ msg_ids : int, str, or list of ints and/or strs
+ ints are indices to self.history
+ strs are msg_ids
+ default: wait on all outstanding messages
+ timeout : float
+ a time in seconds, after which to give up.
+ default is -1, which means no timeout
+ Returns
+ -------
+ True : when all msg_ids are done
+ False : timeout reached, some msg_ids still outstanding
+ """
+ tic = time.time()
+ if msg_ids is None:
+ theids = self.outstanding
+ else:
+ if isinstance(msg_ids, (int, str)):
+ msg_ids = [msg_ids]
+ theids = set()
+ for msg_id in msg_ids:
+ if isinstance(msg_id, int):
+ msg_id = self.history[msg_id]
+ theids.add(msg_id)
+ self.spin()
+ while theids.intersection(self.outstanding):
+ if timeout >= 0 and ( time.time()-tic ) > timeout:
+ break
+ time.sleep(1e-3)
+ self.spin()
+ return len(theids.intersection(self.outstanding)) == 0
+
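
A minimal usage sketch of barrier() (illustrative only; assumes a controller with at least one engine is already listening on the default address):

    from IPython.zmq.parallel.client import Client

    def double(x):
        return 2*x

    c = Client()                                 # tcp://127.0.0.1:10101 by default
    mid = c.apply(double, (21,), targets=0, block=False)   # returns a msg_id
    if c.barrier([mid], timeout=5):              # True once the reply has arrived
        print c.results[mid]                     # -> 42
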
+ #--------------------------------------------------------------------------
+ # Control methods
+ #--------------------------------------------------------------------------
+
@spinfirst
@defaultblock
def clear(self, targets=None, block=None):
- """clear the namespace in target(s)"""
+ """Clear the namespace in target(s)."""
targets = self._build_targets(targets)[0]
for t in targets:
- self.session.send(self.control_socket, 'clear_request', content={},ident=t)
+ self.session.send(self._control_socket, 'clear_request', content={},ident=t)
error = False
if self.block:
for i in range(len(targets)):
- idents,msg = self.session.recv(self.control_socket,0)
+ idents,msg = self.session.recv(self._control_socket,0)
if self.debug:
pprint(msg)
if msg['content']['status'] != 'ok':
@@ -385,18 +428,18 @@ def clear(self, targets=None, block=None):
@spinfirst
@defaultblock
def abort(self, msg_ids = None, targets=None, block=None):
- """abort the Queues of target(s)"""
+ """Abort the execution queues of target(s)."""
targets = self._build_targets(targets)[0]
if isinstance(msg_ids, basestring):
msg_ids = [msg_ids]
content = dict(msg_ids=msg_ids)
for t in targets:
- self.session.send(self.control_socket, 'abort_request',
+ self.session.send(self._control_socket, 'abort_request',
content=content, ident=t)
error = False
if self.block:
for i in range(len(targets)):
- idents,msg = self.session.recv(self.control_socket,0)
+ idents,msg = self.session.recv(self._control_socket,0)
if self.debug:
pprint(msg)
if msg['content']['status'] != 'ok':
@@ -410,21 +453,25 @@ def kill(self, targets=None, block=None):
"""Terminates one or more engine processes."""
targets = self._build_targets(targets)[0]
for t in targets:
- self.session.send(self.control_socket, 'kill_request', content={},ident=t)
+ self.session.send(self._control_socket, 'kill_request', content={},ident=t)
error = False
if self.block:
for i in range(len(targets)):
- idents,msg = self.session.recv(self.control_socket,0)
+ idents,msg = self.session.recv(self._control_socket,0)
if self.debug:
pprint(msg)
if msg['content']['status'] != 'ok':
error = msg['content']
if error:
return error
-
+
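
A brief sketch of the control methods that follow (illustrative; the engine ids are assumptions):

    from IPython.zmq.parallel.client import Client
    c = Client()
    c.abort(targets=[0])        # abort pending requests on engine 0
    c.clear(targets='all')      # clear every engine's namespace
    c.kill(targets=[1])         # terminate engine 1's process
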
+ #--------------------------------------------------------------------------
+ # Execution methods
+ #--------------------------------------------------------------------------
+
@defaultblock
def execute(self, code, targets='all', block=None):
- """executes `code` on `targets` in blocking or nonblocking manner.
+ """Executes `code` on `targets` in blocking or nonblocking manner.
Parameters
----------
@@ -434,7 +481,8 @@ def execute(self, code, targets='all', block=None):
the engines on which to execute
default : all
block : bool
- whether or not to wait until done
+ whether or not to wait until done to return
+ default: self.block
"""
# block = self.block if block is None else block
# saveblock = self.block
@@ -444,7 +492,7 @@ def execute(self, code, targets='all', block=None):
return result
def run(self, code, block=None):
- """runs `code` on an engine.
+ """Runs `code` on an engine.
Calls to this are load-balanced.
@@ -459,11 +507,96 @@ def run(self, code, block=None):
result = self.apply(execute, (code,), targets=None, block=block, bound=False)
return result
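
A short sketch of the legacy execute/run entry points (illustrative only):

    from IPython.zmq.parallel.client import Client
    c = Client()
    c.execute('a = 5', targets='all', block=True)          # run code in every engine's namespace
    c.run("print 'hello from some engine'", block=True)    # load-balanced, clean namespace
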
+ def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
+ after=None, follow=None):
+ """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
+
+ This is the central execution command for the client.
+
+ Parameters
+ ----------
+
+ f : function
+ The fuction to be called remotely
+ args : tuple/list
+ The positional arguments passed to `f`
+ kwargs : dict
+ The keyword arguments passed to `f`
+ bound : bool (default: True)
+ Whether to execute in the Engine(s) namespace, or in a clean
+ namespace not affecting the engine.
+ block : bool (default: self.block)
+ Whether to wait for the result, or return immediately.
+ False:
+ returns msg_id(s)
+ if multiple targets:
+ list of ids
+ True:
+ returns actual result(s) of f(*args, **kwargs)
+ if multiple targets:
+ dict of results, by engine ID
+ targets : int,list of ints, 'all', None
+ Specify the destination of the job.
+ if None:
+ Submit via Task queue for load-balancing.
+ if 'all':
+ Run on all active engines
+ if list:
+ Run on each specified engine
+ if int:
+ Run on single engine
+
+ after : Dependency or collection of msg_ids
+ Only for load-balanced execution (targets=None)
+ Specify a list of msg_ids as a time-based dependency.
+ This job will only be run *after* the dependencies
+ have been met.
+
+ follow : Dependency or collection of msg_ids
+ Only for load-balanced execution (targets=None)
+ Specify a list of msg_ids as a location-based dependency.
+ This job will only be run on an engine where this dependency
+ is met.
+
+ Returns
+ -------
+ if block is False:
+ if single target:
+ return msg_id
+ else:
+ return list of msg_ids
+ ? (should this be dict like block=True) ?
+ else:
+ if single target:
+ return result of f(*args, **kwargs)
+ else:
+ return dict of results, keyed by engine
+ """
+
+ # defaults:
+ block = block if block is not None else self.block
+ args = args if args is not None else []
+ kwargs = kwargs if kwargs is not None else {}
+
+ # enforce types of f, args, kwargs
+ if not callable(f):
+ raise TypeError("f must be callable, not %s"%type(f))
+ if not isinstance(args, (tuple, list)):
+ raise TypeError("args must be tuple or list, not %s"%type(args))
+ if not isinstance(kwargs, dict):
+ raise TypeError("kwargs must be dict, not %s"%type(kwargs))
+
+ options = dict(bound=bound, block=block, after=after, follow=follow)
+
+ if targets is None:
+ return self._apply_balanced(f, args, kwargs, **options)
+ else:
+ return self._apply_direct(f, args, kwargs, targets=targets, **options)
+
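
A sketch of the two dispatch paths the docstring above describes (illustrative; assumes engines 0 and 1 are registered):

    from IPython.zmq.parallel.client import Client

    def square(x):
        return x**2

    c = Client()
    # explicit targets -> _apply_direct via the MUX queue; blocking returns a dict keyed by engine id
    per_engine = c.apply(square, (4,), targets=[0, 1], block=True)
    # targets=None -> _apply_balanced via the task queue; non-blocking returns a msg_id
    mid = c.apply(square, (5,), targets=None, block=False)
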
def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
after=None, follow=None):
- """the underlying method for applying functions in a load balanced
- manner."""
- block = block if block is not None else self.block
+ """The underlying method for applying functions in a load balanced
+ manner, via the task queue."""
if isinstance(after, Dependency):
after = after.as_dict()
elif after is None:
@@ -476,7 +609,7 @@ def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
bufs = ss.pack_apply_message(f,args,kwargs)
content = dict(bound=bound)
- msg = self.session.send(self.task_socket, "apply_request",
+ msg = self.session.send(self._task_socket, "apply_request",
content=content, buffers=bufs, subheader=subheader)
msg_id = msg['msg_id']
self.outstanding.add(msg_id)
@@ -489,9 +622,8 @@ def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
after=None, follow=None):
- """Then underlying method for applying functions to specific engines."""
-
- block = block if block is not None else self.block
+ """Then underlying method for applying functions to specific engines
+ via the MUX queue."""
queues,targets = self._build_targets(targets)
bufs = ss.pack_apply_message(f,args,kwargs)
@@ -507,7 +639,7 @@ def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
content = dict(bound=bound)
msg_ids = []
for queue in queues:
- msg = self.session.send(self.queue_socket, "apply_request",
+ msg = self.session.send(self._mux_socket, "apply_request",
content=content, buffers=bufs,ident=queue, subheader=subheader)
msg_id = msg['msg_id']
self.outstanding.add(msg_id)
@@ -528,115 +660,145 @@ def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
result[target] = self.results[mid]
return result
- def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
- after=None, follow=None):
- """calls f(*args, **kwargs) on a remote engine(s), returning the result.
-
- if self.block is False:
- returns msg_id or list of msg_ids
- else:
- returns actual result of f(*args, **kwargs)
- """
- # enforce types of f,args,kwrags
- args = args if args is not None else []
- kwargs = kwargs if kwargs is not None else {}
- if not callable(f):
- raise TypeError("f must be callable, not %s"%type(f))
- if not isinstance(args, (tuple, list)):
- raise TypeError("args must be tuple or list, not %s"%type(args))
- if not isinstance(kwargs, dict):
- raise TypeError("kwargs must be dict, not %s"%type(kwargs))
-
- options = dict(bound=bound, block=block, after=after, follow=follow)
-
- if targets is None:
- return self._apply_balanced(f, args, kwargs, **options)
- else:
- return self._apply_direct(f, args, kwargs, targets=targets, **options)
+ #--------------------------------------------------------------------------
+ # Data movement
+ #--------------------------------------------------------------------------
+ @defaultblock
def push(self, ns, targets=None, block=None):
- """push the contents of `ns` into the namespace on `target`"""
+ """Push the contents of `ns` into the namespace on `target`"""
if not isinstance(ns, dict):
raise TypeError("Must be a dict, not %s"%type(ns))
- result = self.apply(_push, (ns,), targets=targets, block=block,bound=True)
+ result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
return result
- @spinfirst
+ @defaultblock
def pull(self, keys, targets=None, block=True):
- """pull objects from `target`'s namespace by `keys`"""
-
+ """Pull objects from `target`'s namespace by `keys`"""
+ if isinstance(keys, str):
+ pass
+ elif isinstance(keys, (list,tuple,set)):
+ for key in keys:
+ if not isinstance(key, str):
+ raise TypeError("keys must be str or a collection of strs, not %r"%key)
result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
return result
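
A minimal push/pull round trip (illustrative only):

    from IPython.zmq.parallel.client import Client
    c = Client()
    c.push(dict(a=10, b=30), targets=0, block=True)    # inject into engine 0's namespace
    print c.pull(('a', 'b'), targets=0, block=True)    # -> [10, 30]
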
- def barrier(self, msg_ids=None, timeout=-1):
- """waits on one or more `msg_ids`, for up to `timeout` seconds.
+ #--------------------------------------------------------------------------
+ # Query methods
+ #--------------------------------------------------------------------------
+
+ @spinfirst
+ def get_results(self, msg_ids, status_only=False):
+ """Returns the result of the execute or task request with `msg_ids`.
Parameters
----------
- msg_ids : int, str, or list of ints and/or strs
- ints are indices to self.history
- strs are msg_ids
- default: wait on all outstanding messages
- timeout : float
- a time in seconds, after which to give up.
- default is -1, which means no timeout
-
- Returns
- -------
- True : when all msg_ids are done
- False : timeout reached, msg_ids still outstanding
+ msg_ids : list of ints or msg_ids
+ if int:
+ Passed as index to self.history for convenience.
+ status_only : bool (default: False)
+ if False:
+ return the actual results
"""
- tic = time.time()
- if msg_ids is None:
- theids = self.outstanding
- else:
- if isinstance(msg_ids, (int, str)):
- msg_ids = [msg_ids]
- theids = set()
- for msg_id in msg_ids:
- if isinstance(msg_id, int):
- msg_id = self.history[msg_id]
- theids.add(msg_id)
- self.spin()
- while theids.intersection(self.outstanding):
- if timeout >= 0 and ( time.time()-tic ) > timeout:
- break
- time.sleep(1e-3)
- self.spin()
- return len(theids.intersection(self.outstanding)) == 0
-
- @spinfirst
- def get_results(self, msg_ids,status_only=False):
- """returns the result of the execute or task request with `msg_id`"""
if not isinstance(msg_ids, (list,tuple)):
msg_ids = [msg_ids]
theids = []
for msg_id in msg_ids:
if isinstance(msg_id, int):
msg_id = self.history[msg_id]
+ if not isinstance(msg_id, str):
+ raise TypeError("msg_ids must be str, not %r"%msg_id)
theids.append(msg_id)
- content = dict(msg_ids=theids, status_only=status_only)
- msg = self.session.send(self.query_socket, "result_request", content=content)
- zmq.select([self.query_socket], [], [])
- idents,msg = self.session.recv(self.query_socket, zmq.NOBLOCK)
+ completed = []
+ local_results = {}
+ for msg_id in list(theids):
+ if msg_id in self.results:
+ completed.append(msg_id)
+ local_results[msg_id] = self.results[msg_id]
+ theids.remove(msg_id)
+
+ if msg_ids: # some not locally cached
+ content = dict(msg_ids=theids, status_only=status_only)
+ msg = self.session.send(self._query_socket, "result_request", content=content)
+ zmq.select([self._query_socket], [], [])
+ idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
+ if self.debug:
+ pprint(msg)
+ content = msg['content']
+ if content['status'] != 'ok':
+ raise ss.unwrap_exception(content)
+ else:
+ content = dict(completed=[],pending=[])
+ if not status_only:
+ # load cached results into result:
+ content['completed'].extend(completed)
+ content.update(local_results)
+ # update cache with results:
+ for msg_id in msg_ids:
+ if msg_id in content['completed']:
+ self.results[msg_id] = content[msg_id]
+ return content
+
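
A sketch of fetching a result by msg_id, whether it is cached locally or still held by the controller (illustrative only):

    from IPython.zmq.parallel.client import Client

    def add(a, b):
        return a + b

    c = Client()
    mid = c.apply(add, (2, 3), targets=0, block=False)
    c.barrier([mid])
    reply = c.get_results([mid])            # served from self.results or a result_request
    print mid in reply['completed']         # -> True once the task is done
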
+ @spinfirst
+ def queue_status(self, targets=None, verbose=False):
+ """Fetch the status of engine queues.
+
+ Parameters
+ ----------
+ targets : int/str/list of ints/strs
+ the engines on which to execute
+ default : all
+ verbose : bool
+ whether to return lengths only, or lists of ids for each element
+ """
+ targets = self._build_targets(targets)[1]
+ content = dict(targets=targets, verbose=verbose)
+ self.session.send(self._query_socket, "queue_request", content=content)
+ idents,msg = self.session.recv(self._query_socket, 0)
if self.debug:
pprint(msg)
+ content = msg['content']
+ status = content.pop('status')
+ if status != 'ok':
+ raise ss.unwrap_exception(content)
+ return content
+
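
A sketch of queue_status (illustrative; the exact shape of the reply is determined by the controller):

    from IPython.zmq.parallel.client import Client
    c = Client()
    print c.queue_status()                             # per-engine queue lengths
    print c.queue_status(targets=[0], verbose=True)    # msg_id lists instead of counts
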
+ @spinfirst
+ def purge_results(self, msg_ids=[], targets=[]):
+ """Tell the controller to forget results.
- # while True:
- # try:
- # except zmq.ZMQError:
- # time.sleep(1e-3)
- # continue
- # else:
- # break
- return msg['content']
+ Individual results can be purged by msg_id, or the entire
+ history of specific targets can be purged.
+
+ Parameters
+ ----------
+ targets : int/str/list of ints/strs
+ the targets
+ default : None
+ """
+ if not targets and not msg_ids:
+ raise ValueError("must specify at least one of `targets` or `msg_ids`")
+ if targets:
+ targets = self._build_targets(targets)[1]
+ content = dict(targets=targets, msg_ids=msg_ids)
+ self.session.send(self._query_socket, "purge_request", content=content)
+ idents, msg = self.session.recv(self._query_socket, 0)
+ if self.debug:
+ pprint(msg)
+ content = msg['content']
+ if content['status'] != 'ok':
+ raise ss.unwrap_exception(content)
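
A sketch of the new purge_request path from the client side (illustrative only):

    from IPython.zmq.parallel.client import Client

    def noop():
        pass

    c = Client()
    mid = c.apply(noop, targets=0, block=False)
    c.barrier([mid])
    c.purge_results(msg_ids=[mid])     # forget this one result on the controller
    c.purge_results(targets=[0])       # or drop engine 0's entire history
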
class AsynClient(Client):
- """An Asynchronous client, using the Tornado Event Loop"""
+ """An Asynchronous client, using the Tornado Event Loop.
+ !!!unfinished!!!"""
io_loop = None
- queue_stream = None
- notifier_stream = None
+ _queue_stream = None
+ _notifier_stream = None
+ _task_stream = None
+ _control_stream = None
def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
Client.__init__(self, addr, context, username, debug)
@@ -644,10 +806,10 @@ def __init__(self, addr, context=None, username=None, debug=False, io_loop=None)
io_loop = ioloop.IOLoop.instance()
self.io_loop = io_loop
- self.queue_stream = zmqstream.ZMQStream(self.queue_socket, io_loop)
- self.control_stream = zmqstream.ZMQStream(self.control_socket, io_loop)
- self.task_stream = zmqstream.ZMQStream(self.task_socket, io_loop)
- self.notification_stream = zmqstream.ZMQStream(self.notification_socket, io_loop)
+ self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
+ self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
+ self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
+ self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
def spin(self):
for stream in (self.queue_stream, self.notifier_stream,
