Skip to content

Commit

Permalink
Make testing.py classmethods into attributes.
Browse files Browse the repository at this point in the history
`get_ipcluster_size()` -> `ipcluster_size`
`get_comm_size()` -> `comm_size`
  • Loading branch information
bgrant committed Mar 11, 2014
1 parent ba15e04 commit 19a7170
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 38 deletions.
Expand Up @@ -155,9 +155,7 @@ def setUp(self):

class TestDapThreeBlockDims(DapTestMixin, MpiTestCase):

@classmethod
def get_comm_size(cls):
return 12
comm_size = 12

def setUp(self):
self.larr = distarray.local.LocalArray((53, 77, 99),
Expand Down Expand Up @@ -195,9 +193,7 @@ def setUp(self):

class TestDapLopsided(DapTestMixin, MpiTestCase):

@classmethod
def get_comm_size(cls):
return 2
comm_size = 2

def setUp(self):
if self.comm.Get_rank() == 0:
Expand Down
4 changes: 2 additions & 2 deletions distarray/local/tests/paralleltest_functions.py
Expand Up @@ -45,14 +45,14 @@ def f(*global_inds):
class TestCreationFuncs(MpiTestCase):

def test_zeros(self):
size = self.get_comm_size()
size = self.comm_size
nrows = size * 3
a = dla.zeros((nrows, 20), comm=self.comm)
expected = np.zeros((nrows // size, 20))
assert_array_equal(a.local_array, expected)

def test_ones(self):
size = self.get_comm_size()
size = self.comm_size
nrows = size * 3
a = dla.ones((nrows, 20), comm=self.comm)
expected = np.ones((nrows // size, 20))
Expand Down
8 changes: 2 additions & 6 deletions distarray/local/tests/paralleltest_io.py
Expand Up @@ -127,9 +127,7 @@ def test_flat_file_save_load_with_file_object(self):

class TestNpyFileLoad(MpiTestCase):

@classmethod
def get_comm_size(self):
return 2
comm_size = 2

def setUp(self):
self.rank = self.comm.Get_rank()
Expand Down Expand Up @@ -220,9 +218,7 @@ def tearDown(self):

class TestHdf5FileLoad(MpiTestCase):

@classmethod
def get_comm_size(cls):
return 2
comm_size = 2

def setUp(self):
self.rank = self.comm.Get_rank()
Expand Down
8 changes: 2 additions & 6 deletions distarray/local/tests/paralleltest_localarray.py
Expand Up @@ -182,9 +182,7 @@ def test_block_cyclic(self):

class TestGridShape(MpiTestCase):

@classmethod
def get_comm_size(cls):
return 12
comm_size = 12

def test_grid_shape(self):
"""Test various ways of setting the grid_shape."""
Expand All @@ -204,9 +202,7 @@ class TestDistMatrix(MpiTestCase):

"""Test the dist_matrix."""

@classmethod
def get_comm_size(cls):
return 12
comm_size = 12

@unittest.skip("Plot test.")
def test_plot_dist_matrix(self):
Expand Down
21 changes: 9 additions & 12 deletions distarray/testing.py
Expand Up @@ -95,20 +95,19 @@ class MpiTestCase(unittest.TestCase):

"""Base test class for MPI test cases.
Overload `get_comm_size` to change the default comm size (default is 4).
Overload the `comm_size` class attribute to change the default
(default is 4).
"""

@classmethod
def get_comm_size(cls):
return 4
comm_size = 4

@classmethod
def setUpClass(cls):
try:
cls.comm = create_comm_of_size(cls.get_comm_size())
cls.comm = create_comm_of_size(cls.comm_size)
except InvalidCommSizeError:
msg = "Must run with comm size >= {}."
raise unittest.SkipTest(msg.format(cls.get_comm_size()))
raise unittest.SkipTest(msg.format(cls.comm_size))

@classmethod
def tearDownClass(cls):
Expand All @@ -120,19 +119,17 @@ class IpclusterTestCase(unittest.TestCase):

"""Base test class for test cases needing an ipcluster.
Overload `get_ipcluster_size` to change the default (default is 4).
Overload the `ipcluster_size` class attribute to change the default (default is 4).
"""

@classmethod
def get_ipcluster_size(cls):
return 4
ipcluster_size = 4

@classmethod
def setUpClass(cls):
cls.client = Client()
if len(cls.client) < cls.get_ipcluster_size():
if len(cls.client) < cls.ipcluster_size:
errmsg = 'Tests need an ipcluster with at least {} engines running.'
raise unittest.SkipTest(errmsg.format(cls.get_ipcluster_size()))
raise unittest.SkipTest(errmsg.format(cls.ipcluster_size))

def tearDown(self):
self.client.clear(block=True)
Expand Down
8 changes: 2 additions & 6 deletions distarray/tests/test_distributed_io.py
Expand Up @@ -128,9 +128,7 @@ def test_save_load_with_prefix(self):

class TestNpyFileLoad(IpclusterTestCase):

@classmethod
def get_ipcluster_size(cls):
return 2
ipcluster_size = 2

def setUp(self):
self.dac = Context(self.client, targets=[0, 1])
Expand Down Expand Up @@ -230,9 +228,7 @@ def test_save_two_datasets(self):

class TestHdf5FileLoad(IpclusterTestCase):

@classmethod
def get_ipcluster_size(cls):
return 2
ipcluster_size = 2

def setUp(self):
self.h5py = import_or_skip('h5py')
Expand Down

0 comments on commit 19a7170

Please sign in to comment.