Permalink
Browse files

testing fixes

  • Loading branch information...
1 parent 60c800a commit 4f574e163c57914da5f32b3b026a1feb54bb538c @minrk minrk committed Feb 28, 2011
@@ -35,7 +35,7 @@ class AsyncResult(object):
msg_ids = None
- def __init__(self, client, msg_ids, fname=''):
+ def __init__(self, client, msg_ids, fname='unknown'):
self._client = client
if isinstance(msg_ids, basestring):
msg_ids = [msg_ids]
@@ -265,7 +265,7 @@ def wait(self, timeout=-1):
else:
rdict = self._client.result_status(remote_ids, status_only=False)
pending = rdict['pending']
- while pending and time.time() < start+timeout:
+ while pending and (timeout < 0 or time.time() < start+timeout):
rdict = self._client.result_status(remote_ids, status_only=False)
pending = rdict['pending']
if pending:
@@ -360,16 +360,17 @@ def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
if cluster_dir is not None:
try:
self._cd = ClusterDir.find_cluster_dir(cluster_dir)
+ return
except ClusterDirError:
pass
elif profile is not None:
try:
self._cd = ClusterDir.find_cluster_dir_by_profile(
ipython_dir, profile)
+ return
except ClusterDirError:
pass
- else:
- self._cd = None
+ self._cd = None
@property
def ids(self):
@@ -489,9 +490,9 @@ def _unwrap_exception(self, content):
"""unwrap exception, and remap engineid to int."""
e = ss.unwrap_exception(content)
if e.engine_info:
- e_uuid = e.engine_info['engineid']
+ e_uuid = e.engine_info['engine_uuid']
eid = self._engines[e_uuid]
- e.engine_info['engineid'] = eid
+ e.engine_info['engine_id'] = eid
return e
def _register_engine(self, msg):
@@ -1338,11 +1339,11 @@ def result_status(self, msg_ids, status_only=True):
be lists of msg_ids that are incomplete or complete. If `status_only`
is False, then completed results will be keyed by their `msg_id`.
"""
- if not isinstance(indices_or_msg_ids, (list,tuple)):
- indices_or_msg_ids = [indices_or_msg_ids]
+ if not isinstance(msg_ids, (list,tuple)):
+ indices_or_msg_ids = [msg_ids]
theids = []
- for msg_id in indices_or_msg_ids:
+ for msg_id in msg_ids:
if isinstance(msg_id, int):
msg_id = self.history[msg_id]
if not isinstance(msg_id, basestring):
@@ -175,7 +175,7 @@ def __init__(self, ename, evalue, traceback, engine_info=None):
self.args=(ename, evalue)
def __repr__(self):
- engineid = self.engine_info.get('engineid', ' ')
+ engineid = self.engine_info.get('engine_id', ' ')
return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
def __str__(self):
@@ -702,7 +702,7 @@ def save_task_destination(self, idents, msg):
self.log.error("task::invalid task tracking message", exc_info=True)
return
content = msg['content']
- print (content)
+ # print (content)
msg_id = content['msg_id']
engine_uuid = content['engine_id']
eid = self.by_ident[engine_uuid]
@@ -728,7 +728,7 @@ def mia_task_request(self, idents, msg):
def save_iopub_message(self, topics, msg):
"""save an iopub message into the db"""
- print (topics)
+ # print (topics)
try:
msg = self.session.unpack_message(msg, content=True)
except:
@@ -12,33 +12,41 @@
import warnings
+from IPython.testing import decorators as testdec
+
import map as Map
from asyncresult import AsyncMapResult
#-----------------------------------------------------------------------------
# Decorators
#-----------------------------------------------------------------------------
+@testdec.skip_doctest
def remote(client, bound=True, block=None, targets=None, balanced=None):
"""Turn a function into a remote function.
This method can be used for map:
- >>> @remote(client,block=True)
- def func(a)
+ In [1]: @remote(client,block=True)
+ ...: def func(a):
+ ...: pass
"""
+
def remote_function(f):
return RemoteFunction(client, f, bound, block, targets, balanced)
return remote_function
+@testdec.skip_doctest
def parallel(client, dist='b', bound=True, block=None, targets='all', balanced=None):
"""Turn a function into a parallel remote function.
This method can be used for map:
- >>> @parallel(client,block=True)
- def func(a)
+ In [1]: @parallel(client,block=True)
+ ...: def func(a):
+ ...: pass
"""
+
def parallel_function(f):
return ParallelFunction(client, f, dist, bound, block, targets, balanced)
return parallel_function
@@ -104,7 +104,7 @@ def __init__(self, **kwargs):
self._initial_exec_lines()
def _wrap_exception(self, method=None):
- e_info = dict(engineid=self.ident, method=method)
+ e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
content=wrap_exception(e_info)
return content
@@ -29,7 +29,6 @@ def teardown():
p = processes.pop()
if p.poll() is None:
try:
- print 'terminating'
p.terminate()
except Exception, e:
print e
@@ -17,7 +17,7 @@
# simple tasks for use in apply tests
def segfault():
- """"""
+ """this will segfault"""
import ctypes
ctypes.memset(-1,0,1)
@@ -73,9 +73,10 @@ def connect_client(self):
def assertRaisesRemote(self, etype, f, *args, **kwargs):
try:
- f(*args, **kwargs)
- except error.CompositeError as e:
- e.raise_exception()
+ try:
+ f(*args, **kwargs)
+ except error.CompositeError as e:
+ e.raise_exception()
except error.RemoteError as e:
self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__))
else:
@@ -87,10 +88,11 @@ def setUp(self):
self.base_engine_count=len(self.client.ids)
self.engines=[]
- def tearDown(self):
- [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
- # while len(self.client.ids) > self.base_engine_count:
- # time.sleep(.1)
- del self.engines
- BaseZMQTestCase.tearDown(self)
+ # def tearDown(self):
+ # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
+ # [ e.wait() for e in self.engines ]
+ # while len(self.client.ids) > self.base_engine_count:
+ # time.sleep(.1)
+ # del self.engines
+ # BaseZMQTestCase.tearDown(self)
@@ -2,28 +2,35 @@
import nose.tools as nt
-from IPython.zmq.parallel.asyncresult import AsyncResult
+from IPython.zmq.parallel import client as clientmod
+from IPython.zmq.parallel import error
+from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
from IPython.zmq.parallel.view import LoadBalancedView, DirectView
-from clienttest import ClusterTestCase, segfault
+from clienttest import ClusterTestCase, segfault, wait
class TestClient(ClusterTestCase):
def test_ids(self):
- self.assertEquals(len(self.client.ids), 1)
+ n = len(self.client.ids)
self.add_engines(3)
- self.assertEquals(len(self.client.ids), 4)
+ self.assertEquals(len(self.client.ids), n+3)
+ self.assertTrue
def test_segfault(self):
+ """test graceful handling of engine death"""
self.add_engines(1)
eid = self.client.ids[-1]
- self.client[eid].apply(segfault)
+ ar = self.client.apply(segfault, block=False)
+ self.assertRaisesRemote(error.EngineError, ar.get)
+ eid = ar.engine_id
while eid in self.client.ids:
time.sleep(.01)
self.client.spin()
def test_view_indexing(self):
- self.add_engines(4)
+ """test index access for views"""
+ self.add_engines(2)
targets = self.client._build_targets('all')[-1]
v = self.client[:]
self.assertEquals(v.targets, targets)
@@ -60,17 +67,30 @@ def test_view_cache(self):
def test_targets(self):
"""test various valid targets arguments"""
- pass
+ build = self.client._build_targets
+ ids = self.client.ids
+ idents,targets = build(None)
+ self.assertEquals(ids, targets)
def test_clear(self):
"""test clear behavior"""
- # self.add_engines(4)
- # self.client.push()
+ self.add_engines(2)
+ self.client.block=True
+ self.client.push(dict(a=5))
+ self.client.pull('a')
+ id0 = self.client.ids[-1]
+ self.client.clear(targets=id0)
+ self.client.pull('a', targets=self.client.ids[:-1])
+ self.assertRaisesRemote(NameError, self.client.pull, 'a')
+ self.client.clear()
+ for i in self.client.ids:
+ self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i)
+
def test_push_pull(self):
"""test pushing and pulling"""
data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
- self.add_engines(4)
+ self.add_engines(2)
push = self.client.push
pull = self.client.pull
self.client.block=True
@@ -131,4 +151,15 @@ def geta():
v.execute('b=f()')
self.assertEquals(v['b'], 5)
+ def test_get_result(self):
+ """test getting results from the Hub."""
+ c = clientmod.Client(profile='iptest')
+ t = self.client.ids[-1]
+ ar = c.apply(wait, (1,), block=False, targets=t)
+ time.sleep(.25)
+ ahr = self.client.get_result(ar.msg_ids)
+ self.assertTrue(isinstance(ahr, AsyncHubResult))
+ self.assertEquals(ahr.get(), ar.get())
+ ar2 = self.client.get_result(ar.msg_ids)
+ self.assertFalse(isinstance(ar2, AsyncHubResult))
@@ -10,6 +10,7 @@
# Imports
#-----------------------------------------------------------------------------
+from IPython.testing import decorators as testdec
from IPython.utils.traitlets import HasTraits, Bool, List, Dict, Set, Int, Instance
from IPython.external.decorator import decorator
@@ -330,7 +331,7 @@ def parallel(self, dist='b', bound=True, block=None):
block = self.block if block is None else block
return parallel(self.client, bound=bound, targets=self._targets, block=block, balanced=self._balanced)
-
+@testdec.skip_doctest
class DirectView(View):
"""Direct Multiplexer View of one or more engines.
@@ -413,7 +414,7 @@ def update(self, ns):
return self.client.push(ns, targets=self._targets, block=self.block)
push = update
-
+
def get(self, key_s):
"""get object(s) by `key_s` from remote namespace
will return one object if it is a key.
@@ -430,26 +431,24 @@ def pull(self, key_s, block=True):
block = block if block is not None else self.block
return self.client.pull(key_s, block=block, targets=self._targets)
- def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None):
+ def scatter(self, key, seq, dist='b', flatten=False, block=None):
"""
Partition a Python sequence and send the partitions to a set of engines.
"""
block = block if block is not None else self.block
- targets = targets if targets is not None else self._targets
return self.client.scatter(key, seq, dist=dist, flatten=flatten,
- targets=targets, block=block)
+ targets=self._targets, block=block)
@sync_results
@save_ids
- def gather(self, key, dist='b', targets=None, block=None):
+ def gather(self, key, dist='b', block=None):
"""
Gather a partitioned sequence on a set of engines as a single local seq.
"""
block = block if block is not None else self.block
- targets = targets if targets is not None else self._targets
- return self.client.gather(key, dist=dist, targets=targets, block=block)
+ return self.client.gather(key, dist=dist, targets=self._targets, block=block)
def __getitem__(self, key):
return self.get(key)
@@ -496,7 +495,8 @@ def activate(self):
print "You must first load the parallelmagic extension " \
"by doing '%load_ext parallelmagic'"
-
+
+@testdec.skip_doctest
class LoadBalancedView(View):
"""An load-balancing View that only executes via the Task scheduler.

0 comments on commit 4f574e1

Please sign in to comment.