Permalink
Browse files

pyzmq-2.1.3 related testing adjustments

  • Loading branch information...
minrk committed Mar 24, 2011
1 parent 35ad828 commit ee9089a75d324555babc345f3720d25ee53ee264
@@ -16,3 +16,8 @@
# from .remotefunction import *
# from .view import *
+import zmq
+
+if zmq.__version__ < '2.1.3':
+ raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3, you appear to have %s"%zmq.__version__)
+
@@ -284,7 +284,7 @@ def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipytho
):
super(Client, self).__init__(debug=debug, profile=profile)
if context is None:
- context = zmq.Context()
+ context = zmq.Context.instance()
self._context = context
@@ -976,6 +976,10 @@ def apply(self, f, args=None, kwargs=None, bound=False, block=None,
returns actual result(s) of f(*args, **kwargs)
if multiple targets:
list of results, matching `targets`
+ track : bool
+ whether to track non-copying sends.
+ [default False]
+
targets : int,list of ints, 'all', None
Specify the destination of the job.
if None:
@@ -986,34 +990,37 @@ def apply(self, f, args=None, kwargs=None, bound=False, block=None,
Run on each specified engine
if int:
Run on single engine
-
+ Note:
+ that if `balanced=True`, and `targets` is specified,
+ then the load-balancing will be limited to balancing
+ among `targets`.
+
balanced : bool, default None
whether to load-balance. This will default to True
if targets is unspecified, or False if targets is specified.
-
- The following arguments are only used when balanced is True:
+
+ If `balanced` and `targets` are both specified, the task will
+ be assigne to *one* of the targets by the scheduler.
+
+ The following arguments are only used when balanced is True:
+
after : Dependency or collection of msg_ids
Only for load-balanced execution (targets=None)
Specify a list of msg_ids as a time-based dependency.
This job will only be run *after* the dependencies
have been met.
-
+
follow : Dependency or collection of msg_ids
Only for load-balanced execution (targets=None)
Specify a list of msg_ids as a location-based dependency.
This job will only be run on an engine where this dependency
is met.
-
+
timeout : float/int or None
Only for load-balanced execution (targets=None)
Specify an amount of time (in seconds) for the scheduler to
wait for dependencies to be met before failing with a
DependencyTimeout.
- track : bool
- whether to track non-copying sends.
- [default False]
-
- after,follow,timeout only used if `balanced=True`.
Returns
-------
@@ -1022,7 +1029,7 @@ def apply(self, f, args=None, kwargs=None, bound=False, block=None,
return AsyncResult wrapping msg_ids
output of AsyncResult.get() is identical to that of `apply(...block=True)`
else:
- if single target:
+ if single target (or balanced):
return result of `f(*args, **kwargs)`
else:
return list of results, matching `targets`
@@ -69,8 +69,9 @@ def wait_on_engines(self, timeout=5):
def connect_client(self):
"""connect a client with my Context, and track its sockets for cleanup"""
c = Client(profile='iptest',context=self.context)
- for name in filter(lambda n:n.endswith('socket'), dir(c)):
- self.sockets.append(getattr(c, name))
+
+ # for name in filter(lambda n:n.endswith('socket'), dir(c)):
+ # self.sockets.append(getattr(c, name))
return c
def assertRaisesRemote(self, etype, f, *args, **kwargs):
@@ -100,6 +101,6 @@ def tearDown(self):
BaseZMQTestCase.tearDown(self)
# this will be superfluous when pyzmq merges PR #88
self.context.term()
- print tempfile.TemporaryFile().fileno(),
- sys.stdout.flush()
+ # print tempfile.TemporaryFile().fileno(),
+ # sys.stdout.flush()
@@ -1,7 +1,6 @@
import time
from tempfile import mktemp
-import nose.tools as nt
import zmq
from IPython.zmq.parallel import client as clientmod
@@ -65,7 +64,7 @@ def test_view_indexing(self):
v = self.client[-1]
self.assert_(isinstance(v, DirectView))
self.assertEquals(v.targets, targets[-1])
- nt.assert_raises(TypeError, lambda : self.client[None])
+ self.assertRaises(TypeError, lambda : self.client[None])
def test_view_cache(self):
"""test that multiple view requests return the same object"""
@@ -179,6 +178,7 @@ def test_get_result(self):
"""test getting results from the Hub."""
c = clientmod.Client(profile='iptest')
self.add_engines(1)
+ t = c.ids[-1]
ar = c.apply(wait, (1,), block=False, targets=t)
# give the monitor time to notice the message
time.sleep(.25)
@@ -2,8 +2,6 @@
from unittest import TestCase
-import nose.tools as nt
-
from IPython.testing.parametric import parametric
from IPython.utils import newserialized as ns
from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
@@ -14,12 +12,12 @@ class CanningTestCase(TestCase):
def test_canning(self):
d = dict(a=5,b=6)
cd = can(d)
- nt.assert_true(isinstance(cd, dict))
+ self.assertTrue(isinstance(cd, dict))
def test_canned_function(self):
f = lambda : 7
cf = can(f)
- nt.assert_true(isinstance(cf, CannedFunction))
+ self.assertTrue(isinstance(cf, CannedFunction))
@parametric
def test_can_roundtrip(cls):
@@ -32,17 +30,17 @@ def test_can_roundtrip(cls):
return map(cls.run_roundtrip, objs)
@classmethod
- def run_roundtrip(cls, obj):
+ def run_roundtrip(self, obj):
o = uncan(can(obj))
- nt.assert_equals(obj, o)
+ assert o == obj, "failed assertion: %r == %r"%(o,obj)
def test_serialized_interfaces(self):
us = {'a':10, 'b':range(10)}
s = ns.serialize(us)
uus = ns.unserialize(s)
- nt.assert_true(isinstance(s, ns.SerializeIt))
- nt.assert_equals(uus, us)
+ self.assertTrue(isinstance(s, ns.SerializeIt))
+ self.assertEquals(uus, us)
def test_pickle_serialized(self):
obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
@@ -51,16 +49,16 @@ def test_pickle_serialized(self):
firstData = originalSer.getData()
firstTD = originalSer.getTypeDescriptor()
firstMD = originalSer.getMetadata()
- nt.assert_equals(firstTD, 'pickle')
- nt.assert_equals(firstMD, {})
+ self.assertEquals(firstTD, 'pickle')
+ self.assertEquals(firstMD, {})
unSerialized = ns.UnSerializeIt(originalSer)
secondObj = unSerialized.getObject()
for k, v in secondObj.iteritems():
- nt.assert_equals(obj[k], v)
+ self.assertEquals(obj[k], v)
secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
- nt.assert_equals(firstData, secondSer.getData())
- nt.assert_equals(firstTD, secondSer.getTypeDescriptor() )
- nt.assert_equals(firstMD, secondSer.getMetadata())
+ self.assertEquals(firstData, secondSer.getData())
+ self.assertEquals(firstTD, secondSer.getTypeDescriptor() )
+ self.assertEquals(firstMD, secondSer.getMetadata())
@skip_without('numpy')
def test_ndarray_serialized(self):
@@ -69,21 +67,21 @@ def test_ndarray_serialized(self):
unSer1 = ns.UnSerialized(a)
ser1 = ns.SerializeIt(unSer1)
td = ser1.getTypeDescriptor()
- nt.assert_equals(td, 'ndarray')
+ self.assertEquals(td, 'ndarray')
md = ser1.getMetadata()
- nt.assert_equals(md['shape'], a.shape)
- nt.assert_equals(md['dtype'], a.dtype.str)
+ self.assertEquals(md['shape'], a.shape)
+ self.assertEquals(md['dtype'], a.dtype.str)
buff = ser1.getData()
- nt.assert_equals(buff, numpy.getbuffer(a))
+ self.assertEquals(buff, numpy.getbuffer(a))
s = ns.Serialized(buff, td, md)
final = ns.unserialize(s)
- nt.assert_equals(numpy.getbuffer(a), numpy.getbuffer(final))
- nt.assert_true((a==final).all())
- nt.assert_equals(a.dtype.str, final.dtype.str)
- nt.assert_equals(a.shape, final.shape)
+ self.assertEquals(numpy.getbuffer(a), numpy.getbuffer(final))
+ self.assertTrue((a==final).all())
+ self.assertEquals(a.dtype.str, final.dtype.str)
+ self.assertEquals(a.shape, final.shape)
# test non-copying:
a[2] = 1e9
- nt.assert_true((a==final).all())
+ self.assertTrue((a==final).all())

0 comments on commit ee9089a

Please sign in to comment.