Skip to content

Commit

Permalink
tests: improve conftest, dle
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Jul 26, 2021
1 parent 4d25407 commit c38526b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
19 changes: 10 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,44 +213,45 @@ def _R(expr):
def assert_structure(operator, exp_trees=None, exp_iters=None):
"""
Utility function that helps to check loop structure of IETs. Retrieves trees from an
Operator and check that the blocking structure is as expected. Trees and Iterations
are returned for further use in tests.
Operator and check that the blocking structure is as expected.
Examples
--------
To check that an Iteration tree has the following structure:
.. code-block:: python
for t
for time
for x
for y
for f
for y
we call(Note: `time` mapped to `t`):
we call:
.. code-block:: python
trees, iters = assert_structure(op, ['t,x,y', 't,f,y'], 't,x,y,f,y')`
assert_structure(op, ['t,x,y', 't,f,y'], 't,x,y,f,y')`
Notes
-----
`time` is mapped to `t`
"""
mapper = {'time': 't'}
trees = retrieve_iteration_tree(operator)
iters = FindNodes(Iteration).visit(operator)

if exp_trees is not None:
trees = retrieve_iteration_tree(operator)
exp_trees = [i.replace(',', '') for i in exp_trees] # 't,x,y' -> 'txy'
tree_struc = (["".join(mapper.get(i.dim.name, i.dim.name) for i in j)
for j in trees]) # Flatten every tree's dims as a string
assert tree_struc == exp_trees

if exp_iters is not None:
iters = FindNodes(Iteration).visit(operator)
exp_iters = exp_iters.replace(',', '') # 't,x,y' -> 'txy'
iter_struc = "".join(mapper.get(i.dim.name, i.dim.name) for i in iters)
assert iter_struc == exp_iters

return trees, iters


def assert_blocking(operator, exp_nests):
"""
Expand Down
30 changes: 16 additions & 14 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,24 @@ def test_composite_transformation(shape):
assert np.equal(wo_blocking.data, w_blocking.data).all()


@pytest.mark.parametrize("blockinner, openmp", [
(False, True), (False, True)
@pytest.mark.parametrize("blockinner, openmp, expected", [
(False, True, 't,x0_blk0,y0_blk0,x,y,z'), (False, False, 't,x0_blk0,y0_blk0,x,y,z'),
(True, True, 't,x0_blk0,y0_blk0,z0_blk0,x,y,z'),
(True, False, 't,x0_blk0,y0_blk0,z0_blk0,x,y,z')
])
def test_cache_blocking_structure(blockinner, openmp):
def test_cache_blocking_structure(blockinner, openmp, expected):
# Check code structure
_, op = _new_operator2((10, 31, 45), time_order=2,
opt=('blocking', {'openmp': True, 'blockinner': blockinner,
opt=('blocking', {'openmp': openmp, 'blockinner': blockinner,
'par-collapse-ncores': 1}))
if blockinner:
assert_structure(op, ['t,x0_blk0,y0_blk0,z0_blk0,x,y,z'])
else:
assert_structure(op, ['t,x0_blk0,y0_blk0,x,y,z'])

assert_structure(op, [expected])

# Check presence of openmp pragmas at the right place
trees = retrieve_iteration_tree(op)
assert len(trees[0][1].pragmas) == 1
assert 'omp for' in trees[0][1].pragmas[0].value
if openmp:
trees = retrieve_iteration_tree(op)
assert len(trees[0][1].pragmas) == 1
assert 'omp for' in trees[0][1].pragmas[0].value


def test_cache_blocking_structure_subdims():
Expand Down Expand Up @@ -473,12 +474,12 @@ def test_collapsing(self, eqns, expected, blocking):
op = Operator(eqns, opt=('blocking', 'simd', 'openmp',
{'blockinner': True, 'par-collapse-ncores': 1,
'par-collapse-work': 0}))
_, iterations = assert_structure(op, ['t,x0_blk0,y0_blk0,z0_blk0,x,y,z'])
assert_structure(op, ['t,x0_blk0,y0_blk0,z0_blk0,x,y,z'])
else:
op = Operator(eqns, opt=('simd', 'openmp', {'par-collapse-ncores': 1,
'par-collapse-work': 0}))
iterations = FindNodes(Iteration).visit(op)

iterations = FindNodes(Iteration).visit(op)
assert len(iterations) == len(expected)

# Check for presence of pragma omp + collapse clause
Expand Down Expand Up @@ -516,7 +517,8 @@ def test_collapsing_v2(self):

op = Operator(eq, opt=('advanced', {'openmp': True}))

_, iterations = assert_structure(op, ['co,ci,x,y'])
assert_structure(op, ['co,ci,x,y'])
iterations = FindNodes(Iteration).visit(op)
assert iterations[0].ncollapsed == 1
assert iterations[1].is_Vectorized
assert iterations[2].is_Sequential
Expand Down

0 comments on commit c38526b

Please sign in to comment.