Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

add message tracking to client, add/improve tests

  • Loading branch information...
commit 62f8971bbbc8c0fd8dab94599715f45dc2e744a2 1 parent 864a845
@minrk minrk authored
View
19 IPython/zmq/parallel/asyncresult.py
@@ -34,13 +34,17 @@ class AsyncResult(object):
"""
msg_ids = None
+ _targets = None
+ _tracker = None
- def __init__(self, client, msg_ids, fname='unknown'):
+ def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
self._client = client
if isinstance(msg_ids, basestring):
msg_ids = [msg_ids]
self.msg_ids = msg_ids
self._fname=fname
+ self._targets = targets
+ self._tracker = tracker
self._ready = False
self._success = None
self._single_result = len(msg_ids) == 1
@@ -169,6 +173,19 @@ def result_dict(self):
def __dict__(self):
return self.get_dict(0)
+
+ def abort(self):
+ """abort my tasks."""
+ assert not self.ready(), "Can't abort, I am already done!"
+ return self.client.abort(self.msg_ids, targets=self._targets, block=True)
+
+ @property
+ def sent(self):
+ """check whether my messages have been sent"""
+ if self._tracker is None:
+ return True
+ else:
+ return self._tracker.done
#-------------------------------------
# dict-access
View
68 IPython/zmq/parallel/client.py
@@ -356,6 +356,9 @@ def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipytho
'apply_reply' : self._handle_apply_reply}
self._connect(sshserver, ssh_kwargs)
+ def __del__(self):
+ """cleanup sockets, but _not_ context."""
+ self.close()
def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
if ipython_dir is None:
@@ -387,7 +390,8 @@ def close(self):
return
snames = filter(lambda n: n.endswith('socket'), dir(self))
for socket in map(lambda name: getattr(self, name), snames):
- socket.close()
+ if isinstance(socket, zmq.Socket) and not socket.closed:
+ socket.close()
self._closed = True
def _update_engines(self, engines):
@@ -550,7 +554,6 @@ def _handle_stranded_msgs(self, eid, uuid):
outstanding = self._outstanding_dict[uuid]
for msg_id in list(outstanding):
- print msg_id
if msg_id in self.results:
# we already have this result; skip it
continue
@@ -796,7 +799,7 @@ def clear(self, targets=None, block=None):
if msg['content']['status'] != 'ok':
error = self._unwrap_exception(msg['content'])
if error:
- return error
+ raise error
@spinfirst
@@ -840,7 +843,7 @@ def abort(self, jobs=None, targets=None, block=None):
if msg['content']['status'] != 'ok':
error = self._unwrap_exception(msg['content'])
if error:
- return error
+ raise error
@spinfirst
@defaultblock
@@ -945,7 +948,8 @@ def _build_dependency(self, dep):
@defaultblock
def apply(self, f, args=None, kwargs=None, bound=True, block=None,
targets=None, balanced=None,
- after=None, follow=None, timeout=None):
+ after=None, follow=None, timeout=None,
+ track=False):
"""Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
This is the central execution command for the client.
@@ -1003,6 +1007,9 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None,
Specify an amount of time (in seconds) for the scheduler to
wait for dependencies to be met before failing with a
DependencyTimeout.
+ track : bool
+ whether to track non-copying sends.
+ [default False]
after,follow,timeout only used if `balanced=True`.
@@ -1044,7 +1051,7 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None,
if not isinstance(kwargs, dict):
raise TypeError("kwargs must be dict, not %s"%type(kwargs))
- options = dict(bound=bound, block=block, targets=targets)
+ options = dict(bound=bound, block=block, targets=targets, track=track)
if balanced:
return self._apply_balanced(f, args, kwargs, timeout=timeout,
@@ -1057,7 +1064,7 @@ def apply(self, f, args=None, kwargs=None, bound=True, block=None,
return self._apply_direct(f, args, kwargs, **options)
def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
- after=None, follow=None, timeout=None):
+ after=None, follow=None, timeout=None, track=None):
"""call f(*args, **kwargs) remotely in a load-balanced manner.
This is a private method, see `apply` for details.
@@ -1065,7 +1072,7 @@ def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
"""
loc = locals()
- for name in ('bound', 'block'):
+ for name in ('bound', 'block', 'track'):
assert loc[name] is not None, "kwarg %r must be specified!"%name
if self._task_socket is None:
@@ -1101,13 +1108,13 @@ def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
content = dict(bound=bound)
msg = self.session.send(self._task_socket, "apply_request",
- content=content, buffers=bufs, subheader=subheader)
+ content=content, buffers=bufs, subheader=subheader, track=track)
msg_id = msg['msg_id']
self.outstanding.add(msg_id)
self.history.append(msg_id)
self.metadata[msg_id]['submitted'] = datetime.now()
-
- ar = AsyncResult(self, [msg_id], fname=f.__name__)
+ tracker = None if track is False else msg['tracker']
+ ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
if block:
try:
return ar.get()
@@ -1116,7 +1123,8 @@ def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
else:
return ar
- def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
+ def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
+ track=None):
"""The underlying method for applying functions to specific engines
via the MUX queue.
@@ -1124,7 +1132,7 @@ def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
Not to be called directly!
"""
loc = locals()
- for name in ('bound', 'block', 'targets'):
+ for name in ('bound', 'block', 'targets', 'track'):
assert loc[name] is not None, "kwarg %r must be specified!"%name
idents,targets = self._build_targets(targets)
@@ -1134,15 +1142,22 @@ def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
bufs = util.pack_apply_message(f,args,kwargs)
msg_ids = []
+ trackers = []
for ident in idents:
msg = self.session.send(self._mux_socket, "apply_request",
- content=content, buffers=bufs, ident=ident, subheader=subheader)
+ content=content, buffers=bufs, ident=ident, subheader=subheader,
+ track=track)
+ if track:
+ trackers.append(msg['tracker'])
msg_id = msg['msg_id']
self.outstanding.add(msg_id)
self._outstanding_dict[ident].add(msg_id)
self.history.append(msg_id)
msg_ids.append(msg_id)
- ar = AsyncResult(self, msg_ids, fname=f.__name__)
+
+ tracker = None if track is False else zmq.MessageTracker(*trackers)
+ ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
+
if block:
try:
return ar.get()
@@ -1230,11 +1245,11 @@ def view(self, targets=None, balanced=None):
#--------------------------------------------------------------------------
@defaultblock
- def push(self, ns, targets='all', block=None):
+ def push(self, ns, targets='all', block=None, track=False):
"""Push the contents of `ns` into the namespace on `target`"""
if not isinstance(ns, dict):
raise TypeError("Must be a dict, not %s"%type(ns))
- result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False)
+ result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False, track=track)
if not block:
return result
@@ -1251,7 +1266,7 @@ def pull(self, keys, targets='all', block=None):
return result
@defaultblock
- def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
+ def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
"""
Partition a Python sequence and send the partitions to a set of engines.
"""
@@ -1259,16 +1274,25 @@ def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
mapObject = Map.dists[dist]()
nparts = len(targets)
msg_ids = []
+ trackers = []
for index, engineid in enumerate(targets):
partition = mapObject.getPartition(seq, index, nparts)
if flatten and len(partition) == 1:
- r = self.push({key: partition[0]}, targets=engineid, block=False)
+ r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
else:
- r = self.push({key: partition}, targets=engineid, block=False)
+ r = self.push({key: partition}, targets=engineid, block=False, track=track)
msg_ids.extend(r.msg_ids)
- r = AsyncResult(self, msg_ids, fname='scatter')
+ if track:
+ trackers.append(r._tracker)
+
+ if track:
+ tracker = zmq.MessageTracker(*trackers)
+ else:
+ tracker = None
+
+ r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
if block:
- r.get()
+ r.wait()
else:
return r
View
51 IPython/zmq/parallel/streamsession.py
@@ -179,7 +179,7 @@ def check_key(self, msg_or_header):
return header.get('key', None) == self.key
- def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
+ def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
"""Build and send a message via stream or socket.
Parameters
@@ -191,13 +191,34 @@ def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, sub
Normally, msg_or_type will be a msg_type unless a message is being sent more
than once.
+ content : dict or None
+ the content of the message (ignored if msg_or_type is a message)
+ buffers : list or None
+ the already-serialized buffers to be appended to the message
+ parent : Message or dict or None
+ the parent or parent header describing the parent of this message
+ subheader : dict or None
+ extra header keys for this message's header
+ ident : bytes or list of bytes
+ the zmq.IDENTITY routing path
+ track : bool
+ whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
+
Returns
-------
- (msg,sent) : tuple
- msg : Message
- the nice wrapped dict-like object containing the headers
+ msg : message dict
+ the constructed message
+ if track=True, the MessageTracker for the send is attached to the returned
+ message as msg['tracker'] (it is None when track=False)
"""
+
+ if not isinstance(stream, (zmq.Socket, ZMQStream)):
+ raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
+ elif track and isinstance(stream, ZMQStream):
+ raise TypeError("ZMQStream cannot track messages")
+
if isinstance(msg_or_type, (Message, dict)):
# we got a Message, not a msg_type
# don't build a new Message
@@ -205,6 +226,7 @@ def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, sub
content = msg['content']
else:
msg = self.msg(msg_or_type, content, parent, subheader)
+
buffers = [] if buffers is None else buffers
to_send = []
if isinstance(ident, list):
@@ -222,7 +244,7 @@ def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, sub
content = self.none
elif isinstance(content, dict):
content = self.pack(content)
- elif isinstance(content, str):
+ elif isinstance(content, bytes):
# content is already packed, as in a relayed message
pass
else:
@@ -231,16 +253,29 @@ def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, sub
flag = 0
if buffers:
flag = zmq.SNDMORE
- stream.send_multipart(to_send, flag, copy=False)
+ _track = False
+ else:
+ _track=track
+ if track:
+ tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
+ else:
+ tracker = stream.send_multipart(to_send, flag, copy=False)
for b in buffers[:-1]:
stream.send(b, flag, copy=False)
if buffers:
- stream.send(buffers[-1], copy=False)
+ if track:
+ tracker = stream.send(buffers[-1], copy=False, track=track)
+ else:
+ tracker = stream.send(buffers[-1], copy=False)
+
# omsg = Message(msg)
if self.debug:
pprint.pprint(msg)
pprint.pprint(to_send)
pprint.pprint(buffers)
+
+ msg['tracker'] = tracker
+
return msg
def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
@@ -250,7 +285,7 @@ def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
----------
msg : list of sendable buffers"""
to_send = []
- if isinstance(ident, str):
+ if isinstance(ident, bytes):
ident = [ident]
if ident is not None:
to_send.extend(ident)
View
10 IPython/zmq/parallel/tests/__init__.py
@@ -1,24 +1,26 @@
"""toplevel setup/teardown for parallel tests."""
+import tempfile
import time
-from subprocess import Popen, PIPE
+from subprocess import Popen, PIPE, STDOUT
from IPython.zmq.parallel.ipcluster import launch_process
from IPython.zmq.parallel.entry_point import select_random_ports
processes = []
+blackhole = tempfile.TemporaryFile()
# nose setup/teardown
def setup():
- cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=PIPE, stdin=PIPE, stderr=PIPE)
+ cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT)
processes.append(cp)
time.sleep(.5)
add_engine()
- time.sleep(3)
+ time.sleep(2)
def add_engine(profile='iptest'):
- ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=PIPE, stdin=PIPE, stderr=PIPE)
+ ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT)
# ep.start()
processes.append(ep)
return ep
View
4 IPython/zmq/parallel/tests/clienttest.py
@@ -88,7 +88,9 @@ def setUp(self):
self.base_engine_count=len(self.client.ids)
self.engines=[]
- # def tearDown(self):
+ def tearDown(self):
+ self.client.close()
+ BaseZMQTestCase.tearDown(self)
# [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
# [ e.wait() for e in self.engines ]
# while len(self.client.ids) > self.base_engine_count:
View
91 IPython/zmq/parallel/tests/test_client.py
@@ -2,6 +2,7 @@
from tempfile import mktemp
import nose.tools as nt
+import zmq
from IPython.zmq.parallel import client as clientmod
from IPython.zmq.parallel import error
@@ -18,10 +19,9 @@ def test_ids(self):
self.assertEquals(len(self.client.ids), n+3)
self.assertTrue
- def test_segfault(self):
- """test graceful handling of engine death"""
+ def test_segfault_task(self):
+ """test graceful handling of engine death (balanced)"""
self.add_engines(1)
- eid = self.client.ids[-1]
ar = self.client.apply(segfault, block=False)
self.assertRaisesRemote(error.EngineError, ar.get)
eid = ar.engine_id
@@ -29,6 +29,17 @@ def test_segfault(self):
time.sleep(.01)
self.client.spin()
+ def test_segfault_mux(self):
+ """test graceful handling of engine death (direct)"""
+ self.add_engines(1)
+ eid = self.client.ids[-1]
+ ar = self.client[eid].apply_async(segfault)
+ self.assertRaisesRemote(error.EngineError, ar.get)
+ eid = ar.engine_id
+ while eid in self.client.ids:
+ time.sleep(.01)
+ self.client.spin()
+
def test_view_indexing(self):
"""test index access for views"""
self.add_engines(2)
@@ -91,13 +102,14 @@ def test_clear(self):
def test_push_pull(self):
"""test pushing and pulling"""
data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
+ t = self.client.ids[-1]
self.add_engines(2)
push = self.client.push
pull = self.client.pull
self.client.block=True
nengines = len(self.client)
- push({'data':data}, targets=0)
- d = pull('data', targets=0)
+ push({'data':data}, targets=t)
+ d = pull('data', targets=t)
self.assertEquals(d, data)
push({'data':data})
d = pull('data')
@@ -119,15 +131,16 @@ def testf(x):
return 2.0*x
self.add_engines(4)
+ t = self.client.ids[-1]
self.client.block=True
push = self.client.push
pull = self.client.pull
execute = self.client.execute
- push({'testf':testf}, targets=0)
- r = pull('testf', targets=0)
+ push({'testf':testf}, targets=t)
+ r = pull('testf', targets=t)
self.assertEqual(r(1.0), testf(1.0))
- execute('r = testf(10)', targets=0)
- r = pull('r', targets=0)
+ execute('r = testf(10)', targets=t)
+ r = pull('r', targets=t)
self.assertEquals(r, testf(10))
ar = push({'testf':testf}, block=False)
ar.get()
@@ -135,8 +148,8 @@ def testf(x):
rlist = ar.get()
for r in rlist:
self.assertEqual(r(1.0), testf(1.0))
- execute("def g(x): return x*x", targets=0)
- r = pull(('testf','g'),targets=0)
+ execute("def g(x): return x*x", targets=t)
+ r = pull(('testf','g'),targets=t)
self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
def test_push_function_globals(self):
@@ -173,7 +186,7 @@ def test_ids_list(self):
ids.remove(ids[-1])
self.assertNotEquals(ids, self.client._ids)
- def test_arun_newline(self):
+ def test_run_newline(self):
"""test that run appends newline to files"""
tmpfile = mktemp()
with open(tmpfile, 'w') as f:
@@ -184,4 +197,56 @@ def test_arun_newline(self):
v.run(tmpfile, block=True)
self.assertEquals(v.apply_sync_bound(lambda : g()), 5)
-
+ def test_apply_tracked(self):
+ """test tracking for apply"""
+ # self.add_engines(1)
+ t = self.client.ids[-1]
+ self.client.block=False
+ def echo(n=1024*1024, **kwargs):
+ return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs)
+ ar = echo(1)
+ self.assertTrue(ar._tracker is None)
+ self.assertTrue(ar.sent)
+ ar = echo(track=True)
+ self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
+ self.assertEquals(ar.sent, ar._tracker.done)
+ ar._tracker.wait()
+ self.assertTrue(ar.sent)
+
+ def test_push_tracked(self):
+ t = self.client.ids[-1]
+ ns = dict(x='x'*1024*1024)
+ ar = self.client.push(ns, targets=t, block=False)
+ self.assertTrue(ar._tracker is None)
+ self.assertTrue(ar.sent)
+
+ ar = self.client.push(ns, targets=t, block=False, track=True)
+ self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
+ self.assertEquals(ar.sent, ar._tracker.done)
+ ar._tracker.wait()
+ self.assertTrue(ar.sent)
+ ar.get()
+
+ def test_scatter_tracked(self):
+ t = self.client.ids
+ x='x'*1024*1024
+ ar = self.client.scatter('x', x, targets=t, block=False)
+ self.assertTrue(ar._tracker is None)
+ self.assertTrue(ar.sent)
+
+ ar = self.client.scatter('x', x, targets=t, block=False, track=True)
+ self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
+ self.assertEquals(ar.sent, ar._tracker.done)
+ ar._tracker.wait()
+ self.assertTrue(ar.sent)
+ ar.get()
+
+ def test_remote_reference(self):
+ v = self.client[-1]
+ v['a'] = 123
+ ra = clientmod.Reference('a')
+ b = v.apply_sync_bound(lambda x: x, ra)
+ self.assertEquals(b, 123)
+ self.assertRaisesRemote(NameError, v.apply_sync, lambda x: x, ra)
+
+
View
23 IPython/zmq/parallel/tests/test_streamsession.py
@@ -4,7 +4,7 @@
import zmq
from zmq.tests import BaseZMQTestCase
-
+from zmq.eventloop.zmqstream import ZMQStream
# from IPython.zmq.tests import SessionTestCase
from IPython.zmq.parallel import streamsession as ss
@@ -31,7 +31,7 @@ def test_msg(self):
def test_args(self):
"""initialization arguments for StreamSession"""
- s = ss.StreamSession()
+ s = self.session
self.assertTrue(s.pack is ss.default_packer)
self.assertTrue(s.unpack is ss.default_unpacker)
self.assertEquals(s.username, os.environ.get('USER', 'username'))
@@ -46,7 +46,24 @@ def test_args(self):
self.assertEquals(s.session, u)
self.assertEquals(s.username, 'carrot')
-
+ def test_tracking(self):
+ """test tracking messages"""
+ a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
+ s = self.session
+ stream = ZMQStream(a)
+ msg = s.send(a, 'hello', track=False)
+ self.assertTrue(msg['tracker'] is None)
+ msg = s.send(a, 'hello', track=True)
+ self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
+ M = zmq.Message(b'hi there', track=True)
+ msg = s.send(a, 'hello', buffers=[M], track=True)
+ t = msg['tracker']
+ self.assertTrue(isinstance(t, zmq.MessageTracker))
+ self.assertRaises(zmq.NotDone, t.wait, .1)
+ del M
+ t.wait(1) # should now succeed: M was released, so the send can complete
+
+
# def test_rekey(self):
# """rekeying dict around json str keys"""
# d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
Please sign in to comment.
Something went wrong with that request. Please try again.