Permalink
Browse files

add message tracking to client, add/improve tests

  • Loading branch information...
1 parent 864a845 commit 62f8971bbbc8c0fd8dab94599715f45dc2e744a2 @minrk minrk committed Mar 16, 2011
@@ -34,13 +34,17 @@ class AsyncResult(object):
"""
msg_ids = None
+ _targets = None
+ _tracker = None
- def __init__(self, client, msg_ids, fname='unknown'):
+ def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
self._client = client
if isinstance(msg_ids, basestring):
msg_ids = [msg_ids]
self.msg_ids = msg_ids
self._fname=fname
+ self._targets = targets
+ self._tracker = tracker
self._ready = False
self._success = None
self._single_result = len(msg_ids) == 1
@@ -169,6 +173,19 @@ def result_dict(self):
def __dict__(self):
return self.get_dict(0)
+
+ def abort(self):
+ """abort my tasks."""
+ assert not self.ready(), "Can't abort, I am already done!"
+ return self.client.abort(self.msg_ids, targets=self._targets, block=True)
+
+ @property
+ def sent(self):
+ """check whether my messages have been sent"""
+ if self._tracker is None:
+ return True
+ else:
+ return self._tracker.done
#-------------------------------------
# dict-access
@@ -356,6 +356,9 @@ def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipytho
'apply_reply' : self._handle_apply_reply}
self._connect(sshserver, ssh_kwargs)
+ def __del__(self):
+ """cleanup sockets, but _not_ context."""
+ self.close()
def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
if ipython_dir is None:
@@ -387,7 +390,8 @@ def close(self):
return
snames = filter(lambda n: n.endswith('socket'), dir(self))
for socket in map(lambda name: getattr(self, name), snames):
- socket.close()
+ if isinstance(socket, zmq.Socket) and not socket.closed:
+ socket.close()
self._closed = True
def _update_engines(self, engines):
@@ -550,7 +554,6 @@ def _handle_stranded_msgs(self, eid, uuid):
outstanding = self._outstanding_dict[uuid]
for msg_id in list(outstanding):
- print msg_id
if msg_id in self.results:
                # we already have this result; skip it
continue
@@ -796,7 +799,7 @@ def clear(self, targets=None, block=None):
if msg['content']['status'] != 'ok':
error = self._unwrap_exception(msg['content'])
if error:
- return error
+ raise error
@spinfirst
@@ -840,7 +843,7 @@ def abort(self, jobs=None, targets=None, block=None):
if msg['content']['status'] != 'ok':
error = self._unwrap_exception(msg['content'])
if error:
- return error
+ raise error
@spinfirst
@defaultblock
@@ -945,7 +948,8 @@ def _build_dependency(self, dep):
@defaultblock
def apply(self, f, args=None, kwargs=None, bound=True, block=None,
targets=None, balanced=None,
- after=None, follow=None, timeout=None):
+ after=None, follow=None, timeout=None,
+ track=False):
"""Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
This is the central execution command for the client.
@@ -1003,6 +1007,9 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None,
Specify an amount of time (in seconds) for the scheduler to
wait for dependencies to be met before failing with a
DependencyTimeout.
+ track : bool
+ whether to track non-copying sends.
+ [default False]
after,follow,timeout only used if `balanced=True`.
@@ -1044,7 +1051,7 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None,
if not isinstance(kwargs, dict):
raise TypeError("kwargs must be dict, not %s"%type(kwargs))
- options = dict(bound=bound, block=block, targets=targets)
+ options = dict(bound=bound, block=block, targets=targets, track=track)
if balanced:
return self._apply_balanced(f, args, kwargs, timeout=timeout,
@@ -1057,15 +1064,15 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None,
return self._apply_direct(f, args, kwargs, **options)
def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
- after=None, follow=None, timeout=None):
+ after=None, follow=None, timeout=None, track=None):
"""call f(*args, **kwargs) remotely in a load-balanced manner.
This is a private method, see `apply` for details.
Not to be called directly!
"""
loc = locals()
- for name in ('bound', 'block'):
+ for name in ('bound', 'block', 'track'):
assert loc[name] is not None, "kwarg %r must be specified!"%name
if self._task_socket is None:
@@ -1101,13 +1108,13 @@ def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
content = dict(bound=bound)
msg = self.session.send(self._task_socket, "apply_request",
- content=content, buffers=bufs, subheader=subheader)
+ content=content, buffers=bufs, subheader=subheader, track=track)
msg_id = msg['msg_id']
self.outstanding.add(msg_id)
self.history.append(msg_id)
self.metadata[msg_id]['submitted'] = datetime.now()
-
- ar = AsyncResult(self, [msg_id], fname=f.__name__)
+ tracker = None if track is False else msg['tracker']
+ ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
if block:
try:
return ar.get()
@@ -1116,15 +1123,16 @@ def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
else:
return ar
- def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
+ def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
+ track=None):
"""Then underlying method for applying functions to specific engines
via the MUX queue.
This is a private method, see `apply` for details.
Not to be called directly!
"""
loc = locals()
- for name in ('bound', 'block', 'targets'):
+ for name in ('bound', 'block', 'targets', 'track'):
assert loc[name] is not None, "kwarg %r must be specified!"%name
idents,targets = self._build_targets(targets)
@@ -1134,15 +1142,22 @@ def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
bufs = util.pack_apply_message(f,args,kwargs)
msg_ids = []
+ trackers = []
for ident in idents:
msg = self.session.send(self._mux_socket, "apply_request",
- content=content, buffers=bufs, ident=ident, subheader=subheader)
+ content=content, buffers=bufs, ident=ident, subheader=subheader,
+ track=track)
+ if track:
+ trackers.append(msg['tracker'])
msg_id = msg['msg_id']
self.outstanding.add(msg_id)
self._outstanding_dict[ident].add(msg_id)
self.history.append(msg_id)
msg_ids.append(msg_id)
- ar = AsyncResult(self, msg_ids, fname=f.__name__)
+
+ tracker = None if track is False else zmq.MessageTracker(*trackers)
+ ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
+
if block:
try:
return ar.get()
@@ -1230,11 +1245,11 @@ def view(self, targets=None, balanced=None):
#--------------------------------------------------------------------------
@defaultblock
- def push(self, ns, targets='all', block=None):
+ def push(self, ns, targets='all', block=None, track=False):
"""Push the contents of `ns` into the namespace on `target`"""
if not isinstance(ns, dict):
raise TypeError("Must be a dict, not %s"%type(ns))
- result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False)
+ result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False, track=track)
if not block:
return result
@@ -1251,24 +1266,33 @@ def pull(self, keys, targets='all', block=None):
return result
@defaultblock
- def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
+ def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
"""
Partition a Python sequence and send the partitions to a set of engines.
"""
targets = self._build_targets(targets)[-1]
mapObject = Map.dists[dist]()
nparts = len(targets)
msg_ids = []
+ trackers = []
for index, engineid in enumerate(targets):
partition = mapObject.getPartition(seq, index, nparts)
if flatten and len(partition) == 1:
- r = self.push({key: partition[0]}, targets=engineid, block=False)
+ r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
else:
- r = self.push({key: partition}, targets=engineid, block=False)
+ r = self.push({key: partition}, targets=engineid, block=False, track=track)
msg_ids.extend(r.msg_ids)
- r = AsyncResult(self, msg_ids, fname='scatter')
+ if track:
+ trackers.append(r._tracker)
+
+ if track:
+ tracker = zmq.MessageTracker(*trackers)
+ else:
+ tracker = None
+
+ r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
if block:
- r.get()
+ r.wait()
else:
return r
@@ -179,7 +179,7 @@ def check_key(self, msg_or_header):
return header.get('key', None) == self.key
- def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
+ def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
"""Build and send a message via stream or socket.
Parameters
@@ -191,20 +191,42 @@ def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, sub
Normally, msg_or_type will be a msg_type unless a message is being sent more
than once.
+ content : dict or None
+ the content of the message (ignored if msg_or_type is a message)
+ buffers : list or None
+ the already-serialized buffers to be appended to the message
+ parent : Message or dict or None
+ the parent or parent header describing the parent of this message
+ subheader : dict or None
+ extra header keys for this message's header
+ ident : bytes or list of bytes
+ the zmq.IDENTITY routing path
+ track : bool
+ whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
+
Returns
-------
- (msg,sent) : tuple
- msg : Message
- the nice wrapped dict-like object containing the headers
+ msg : message dict
+ the constructed message
+ (msg,tracker) : (message dict, MessageTracker)
+ if track=True, then a 2-tuple will be returned, the first element being the constructed
+ message, and the second being the MessageTracker
"""
+
+ if not isinstance(stream, (zmq.Socket, ZMQStream)):
+ raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
+ elif track and isinstance(stream, ZMQStream):
+ raise TypeError("ZMQStream cannot track messages")
+
if isinstance(msg_or_type, (Message, dict)):
# we got a Message, not a msg_type
# don't build a new Message
msg = msg_or_type
content = msg['content']
else:
msg = self.msg(msg_or_type, content, parent, subheader)
+
buffers = [] if buffers is None else buffers
to_send = []
if isinstance(ident, list):
@@ -222,7 +244,7 @@ def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, sub
content = self.none
elif isinstance(content, dict):
content = self.pack(content)
- elif isinstance(content, str):
+ elif isinstance(content, bytes):
# content is already packed, as in a relayed message
pass
else:
@@ -231,16 +253,29 @@ def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, sub
flag = 0
if buffers:
flag = zmq.SNDMORE
- stream.send_multipart(to_send, flag, copy=False)
+ _track = False
+ else:
+ _track=track
+ if track:
+ tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
+ else:
+ tracker = stream.send_multipart(to_send, flag, copy=False)
for b in buffers[:-1]:
stream.send(b, flag, copy=False)
if buffers:
- stream.send(buffers[-1], copy=False)
+ if track:
+ tracker = stream.send(buffers[-1], copy=False, track=track)
+ else:
+ tracker = stream.send(buffers[-1], copy=False)
+
# omsg = Message(msg)
if self.debug:
pprint.pprint(msg)
pprint.pprint(to_send)
pprint.pprint(buffers)
+
+ msg['tracker'] = tracker
+
return msg
def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
@@ -250,7 +285,7 @@ def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
----------
msg : list of sendable buffers"""
to_send = []
- if isinstance(ident, str):
+ if isinstance(ident, bytes):
ident = [ident]
if ident is not None:
to_send.extend(ident)
@@ -1,24 +1,26 @@
"""toplevel setup/teardown for parallel tests."""
+import tempfile
import time
-from subprocess import Popen, PIPE
+from subprocess import Popen, PIPE, STDOUT
from IPython.zmq.parallel.ipcluster import launch_process
from IPython.zmq.parallel.entry_point import select_random_ports
processes = []
+blackhole = tempfile.TemporaryFile()
# nose setup/teardown
def setup():
- cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=PIPE, stdin=PIPE, stderr=PIPE)
+ cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT)
processes.append(cp)
time.sleep(.5)
add_engine()
- time.sleep(3)
+ time.sleep(2)
def add_engine(profile='iptest'):
- ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=PIPE, stdin=PIPE, stderr=PIPE)
+ ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT)
# ep.start()
processes.append(ep)
return ep
@@ -88,7 +88,9 @@ def setUp(self):
self.base_engine_count=len(self.client.ids)
self.engines=[]
- # def tearDown(self):
+ def tearDown(self):
+ self.client.close()
+ BaseZMQTestCase.tearDown(self)
# [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
# [ e.wait() for e in self.engines ]
# while len(self.client.ids) > self.base_engine_count:
Oops, something went wrong.

0 comments on commit 62f8971

Please sign in to comment.