Skip to content

Commit

Permalink
pyzmq-2.1.3 related testing adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk committed Apr 8, 2011
1 parent 35ad828 commit ee9089a
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 41 deletions.
5 changes: 5 additions & 0 deletions IPython/zmq/parallel/__init__.py
Expand Up @@ -16,3 +16,8 @@
# from .remotefunction import * # from .remotefunction import *
# from .view import * # from .view import *


import zmq

if zmq.__version__ < '2.1.3':
raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3, you appear to have %s"%zmq.__version__)

31 changes: 19 additions & 12 deletions IPython/zmq/parallel/client.py
Expand Up @@ -284,7 +284,7 @@ def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipytho
): ):
super(Client, self).__init__(debug=debug, profile=profile) super(Client, self).__init__(debug=debug, profile=profile)
if context is None: if context is None:
context = zmq.Context() context = zmq.Context.instance()
self._context = context self._context = context




Expand Down Expand Up @@ -976,6 +976,10 @@ def apply(self, f, args=None, kwargs=None, bound=False, block=None,
returns actual result(s) of f(*args, **kwargs) returns actual result(s) of f(*args, **kwargs)
if multiple targets: if multiple targets:
list of results, matching `targets` list of results, matching `targets`
track : bool
whether to track non-copying sends.
[default False]
targets : int,list of ints, 'all', None targets : int,list of ints, 'all', None
Specify the destination of the job. Specify the destination of the job.
if None: if None:
Expand All @@ -986,34 +990,37 @@ def apply(self, f, args=None, kwargs=None, bound=False, block=None,
Run on each specified engine Run on each specified engine
if int: if int:
Run on single engine Run on single engine
Note
that if `balanced=True`, and `targets` is specified,
then the load-balancing will be limited to balancing
among `targets`.
balanced : bool, default None balanced : bool, default None
whether to load-balance. This will default to True whether to load-balance. This will default to True
if targets is unspecified, or False if targets is specified. if targets is unspecified, or False if targets is specified.
The following arguments are only used when balanced is True: If `balanced` and `targets` are both specified, the task will
be assigned to *one* of the targets by the scheduler.
The following arguments are only used when balanced is True:
after : Dependency or collection of msg_ids after : Dependency or collection of msg_ids
Only for load-balanced execution (targets=None) Only for load-balanced execution (targets=None)
Specify a list of msg_ids as a time-based dependency. Specify a list of msg_ids as a time-based dependency.
This job will only be run *after* the dependencies This job will only be run *after* the dependencies
have been met. have been met.
follow : Dependency or collection of msg_ids follow : Dependency or collection of msg_ids
Only for load-balanced execution (targets=None) Only for load-balanced execution (targets=None)
Specify a list of msg_ids as a location-based dependency. Specify a list of msg_ids as a location-based dependency.
This job will only be run on an engine where this dependency This job will only be run on an engine where this dependency
is met. is met.
timeout : float/int or None timeout : float/int or None
Only for load-balanced execution (targets=None) Only for load-balanced execution (targets=None)
Specify an amount of time (in seconds) for the scheduler to Specify an amount of time (in seconds) for the scheduler to
wait for dependencies to be met before failing with a wait for dependencies to be met before failing with a
DependencyTimeout. DependencyTimeout.
track : bool
whether to track non-copying sends.
[default False]
after,follow,timeout only used if `balanced=True`.
Returns Returns
------- -------
Expand All @@ -1022,7 +1029,7 @@ def apply(self, f, args=None, kwargs=None, bound=False, block=None,
return AsyncResult wrapping msg_ids return AsyncResult wrapping msg_ids
output of AsyncResult.get() is identical to that of `apply(...block=True)` output of AsyncResult.get() is identical to that of `apply(...block=True)`
else: else:
if single target: if single target (or balanced):
return result of `f(*args, **kwargs)` return result of `f(*args, **kwargs)`
else: else:
return list of results, matching `targets` return list of results, matching `targets`
Expand Down
9 changes: 5 additions & 4 deletions IPython/zmq/parallel/tests/clienttest.py
Expand Up @@ -69,8 +69,9 @@ def wait_on_engines(self, timeout=5):
def connect_client(self): def connect_client(self):
"""connect a client with my Context, and track its sockets for cleanup""" """connect a client with my Context, and track its sockets for cleanup"""
c = Client(profile='iptest',context=self.context) c = Client(profile='iptest',context=self.context)
for name in filter(lambda n:n.endswith('socket'), dir(c)):
self.sockets.append(getattr(c, name)) # for name in filter(lambda n:n.endswith('socket'), dir(c)):
# self.sockets.append(getattr(c, name))
return c return c


def assertRaisesRemote(self, etype, f, *args, **kwargs): def assertRaisesRemote(self, etype, f, *args, **kwargs):
Expand Down Expand Up @@ -100,6 +101,6 @@ def tearDown(self):
BaseZMQTestCase.tearDown(self) BaseZMQTestCase.tearDown(self)
# this will be superfluous when pyzmq merges PR #88 # this will be superfluous when pyzmq merges PR #88
self.context.term() self.context.term()
print tempfile.TemporaryFile().fileno(), # print tempfile.TemporaryFile().fileno(),
sys.stdout.flush() # sys.stdout.flush()


4 changes: 2 additions & 2 deletions IPython/zmq/parallel/tests/test_client.py
@@ -1,7 +1,6 @@
import time import time
from tempfile import mktemp from tempfile import mktemp


import nose.tools as nt
import zmq import zmq


from IPython.zmq.parallel import client as clientmod from IPython.zmq.parallel import client as clientmod
Expand Down Expand Up @@ -65,7 +64,7 @@ def test_view_indexing(self):
v = self.client[-1] v = self.client[-1]
self.assert_(isinstance(v, DirectView)) self.assert_(isinstance(v, DirectView))
self.assertEquals(v.targets, targets[-1]) self.assertEquals(v.targets, targets[-1])
nt.assert_raises(TypeError, lambda : self.client[None]) self.assertRaises(TypeError, lambda : self.client[None])


def test_view_cache(self): def test_view_cache(self):
"""test that multiple view requests return the same object""" """test that multiple view requests return the same object"""
Expand Down Expand Up @@ -179,6 +178,7 @@ def test_get_result(self):
"""test getting results from the Hub.""" """test getting results from the Hub."""
c = clientmod.Client(profile='iptest') c = clientmod.Client(profile='iptest')
self.add_engines(1) self.add_engines(1)
t = c.ids[-1]
ar = c.apply(wait, (1,), block=False, targets=t) ar = c.apply(wait, (1,), block=False, targets=t)
# give the monitor time to notice the message # give the monitor time to notice the message
time.sleep(.25) time.sleep(.25)
Expand Down
44 changes: 21 additions & 23 deletions IPython/zmq/parallel/tests/test_newserialized.py
Expand Up @@ -2,8 +2,6 @@


from unittest import TestCase from unittest import TestCase


import nose.tools as nt

from IPython.testing.parametric import parametric from IPython.testing.parametric import parametric
from IPython.utils import newserialized as ns from IPython.utils import newserialized as ns
from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
Expand All @@ -14,12 +12,12 @@ class CanningTestCase(TestCase):
def test_canning(self): def test_canning(self):
d = dict(a=5,b=6) d = dict(a=5,b=6)
cd = can(d) cd = can(d)
nt.assert_true(isinstance(cd, dict)) self.assertTrue(isinstance(cd, dict))


def test_canned_function(self): def test_canned_function(self):
f = lambda : 7 f = lambda : 7
cf = can(f) cf = can(f)
nt.assert_true(isinstance(cf, CannedFunction)) self.assertTrue(isinstance(cf, CannedFunction))


@parametric @parametric
def test_can_roundtrip(cls): def test_can_roundtrip(cls):
Expand All @@ -32,17 +30,17 @@ def test_can_roundtrip(cls):
return map(cls.run_roundtrip, objs) return map(cls.run_roundtrip, objs)


@classmethod @classmethod
def run_roundtrip(cls, obj): def run_roundtrip(self, obj):
o = uncan(can(obj)) o = uncan(can(obj))
nt.assert_equals(obj, o) assert o == obj, "failed assertion: %r == %r"%(o,obj)


def test_serialized_interfaces(self): def test_serialized_interfaces(self):


us = {'a':10, 'b':range(10)} us = {'a':10, 'b':range(10)}
s = ns.serialize(us) s = ns.serialize(us)
uus = ns.unserialize(s) uus = ns.unserialize(s)
nt.assert_true(isinstance(s, ns.SerializeIt)) self.assertTrue(isinstance(s, ns.SerializeIt))
nt.assert_equals(uus, us) self.assertEquals(uus, us)


def test_pickle_serialized(self): def test_pickle_serialized(self):
obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L} obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
Expand All @@ -51,16 +49,16 @@ def test_pickle_serialized(self):
firstData = originalSer.getData() firstData = originalSer.getData()
firstTD = originalSer.getTypeDescriptor() firstTD = originalSer.getTypeDescriptor()
firstMD = originalSer.getMetadata() firstMD = originalSer.getMetadata()
nt.assert_equals(firstTD, 'pickle') self.assertEquals(firstTD, 'pickle')
nt.assert_equals(firstMD, {}) self.assertEquals(firstMD, {})
unSerialized = ns.UnSerializeIt(originalSer) unSerialized = ns.UnSerializeIt(originalSer)
secondObj = unSerialized.getObject() secondObj = unSerialized.getObject()
for k, v in secondObj.iteritems(): for k, v in secondObj.iteritems():
nt.assert_equals(obj[k], v) self.assertEquals(obj[k], v)
secondSer = ns.SerializeIt(ns.UnSerialized(secondObj)) secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
nt.assert_equals(firstData, secondSer.getData()) self.assertEquals(firstData, secondSer.getData())
nt.assert_equals(firstTD, secondSer.getTypeDescriptor() ) self.assertEquals(firstTD, secondSer.getTypeDescriptor() )
nt.assert_equals(firstMD, secondSer.getMetadata()) self.assertEquals(firstMD, secondSer.getMetadata())


@skip_without('numpy') @skip_without('numpy')
def test_ndarray_serialized(self): def test_ndarray_serialized(self):
Expand All @@ -69,21 +67,21 @@ def test_ndarray_serialized(self):
unSer1 = ns.UnSerialized(a) unSer1 = ns.UnSerialized(a)
ser1 = ns.SerializeIt(unSer1) ser1 = ns.SerializeIt(unSer1)
td = ser1.getTypeDescriptor() td = ser1.getTypeDescriptor()
nt.assert_equals(td, 'ndarray') self.assertEquals(td, 'ndarray')
md = ser1.getMetadata() md = ser1.getMetadata()
nt.assert_equals(md['shape'], a.shape) self.assertEquals(md['shape'], a.shape)
nt.assert_equals(md['dtype'], a.dtype.str) self.assertEquals(md['dtype'], a.dtype.str)
buff = ser1.getData() buff = ser1.getData()
nt.assert_equals(buff, numpy.getbuffer(a)) self.assertEquals(buff, numpy.getbuffer(a))
s = ns.Serialized(buff, td, md) s = ns.Serialized(buff, td, md)
final = ns.unserialize(s) final = ns.unserialize(s)
nt.assert_equals(numpy.getbuffer(a), numpy.getbuffer(final)) self.assertEquals(numpy.getbuffer(a), numpy.getbuffer(final))
nt.assert_true((a==final).all()) self.assertTrue((a==final).all())
nt.assert_equals(a.dtype.str, final.dtype.str) self.assertEquals(a.dtype.str, final.dtype.str)
nt.assert_equals(a.shape, final.shape) self.assertEquals(a.shape, final.shape)
# test non-copying: # test non-copying:
a[2] = 1e9 a[2] = 1e9
nt.assert_true((a==final).all()) self.assertTrue((a==final).all())






0 comments on commit ee9089a

Please sign in to comment.