Skip to content

Commit

Permalink
add dimension dependent halo fill order (#407)
Browse files Browse the repository at this point in the history
Co-authored-by: Sylwester Arabas <sylwester.arabas@uj.edu.pl>
  • Loading branch information
Delcior and slayoo committed Aug 26, 2023
1 parent 14f0a56 commit 7a4848f
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 52 deletions.
4 changes: 3 additions & 1 deletion PyMPDATA/impl/traversals.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
class Traversals:
"""groups njit-ted traversals for a given grid, halo, jit_flags and threading settings"""

def __init__(self, *, grid, halo, jit_flags, n_threads, left_first, buffer_size):
def __init__(
self, *, grid, halo, jit_flags, n_threads, left_first: tuple, buffer_size
):
assert not (n_threads > 1 and len(grid) == 1)
tmp = (
grid[OUTER] if len(grid) > 1 else 0,
Expand Down
11 changes: 6 additions & 5 deletions PyMPDATA/impl/traversals_halos_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@
from PyMPDATA.impl.traversals_common import make_common


def _make_fill_halos_scalar(*, jit_flags, halo, n_dims, chunker, spanner, left_first):
def _make_fill_halos_scalar(
*, jit_flags, halo, n_dims, chunker, spanner, left_first: tuple
):
common = make_common(jit_flags, spanner, chunker)
kwargs = {
"jit_flags": jit_flags,
"halo": halo,
"n_dims": n_dims,
"left_first": left_first,
}
mid3d = __make_mid3d(**kwargs)
outer = __make_outer(**kwargs)
inner = __make_inner(**kwargs)
mid3d = __make_mid3d(**kwargs, left_first=left_first[MID3D])
outer = __make_outer(**kwargs, left_first=left_first[OUTER])
inner = __make_inner(**kwargs, left_first=left_first[INNER])

@numba.njit(**jit_flags)
# pylint: disable=too-many-arguments,too-many-branches
Expand Down
23 changes: 12 additions & 11 deletions PyMPDATA/impl/traversals_halos_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from PyMPDATA.impl.traversals_common import make_common


def _make_fill_halos_vector(*, jit_flags, halo, n_dims, chunker, spanner, left_first):
def _make_fill_halos_vector(
*, jit_flags, halo, n_dims, chunker, spanner, left_first: tuple
):
common = make_common(jit_flags, spanner, chunker)
halos = ((halo - 1, halo, halo), (halo, halo - 1, halo), (halo, halo, halo - 1))
# pylint:disable=duplicate-code
Expand All @@ -27,18 +29,17 @@ def _make_fill_halos_vector(*, jit_flags, halo, n_dims, chunker, spanner, left_f
"halo": halo,
"n_dims": n_dims,
"halos": halos,
"left_first": left_first,
}

outer_outer = __make_outer_outer(**kwargs)
outer_mid3d = __make_outer_mid3d(**kwargs)
outer_inner = __make_outer_inner(**kwargs)
mid3d_outer = __make_mid3d_outer(**kwargs)
mid3d_mid3d = __make_mid3d_mid3d(**kwargs)
mid3d_inner = __make_mid3d_inner(**kwargs)
inner_outer = __make_inner_outer(**kwargs)
inner_inner = __make_inner_inner(**kwargs)
inner_mid3d = __make_inner_mid3d(**kwargs)
outer_outer = __make_outer_outer(**kwargs, left_first=left_first[OUTER])
outer_mid3d = __make_outer_mid3d(**kwargs, left_first=left_first[MID3D])
outer_inner = __make_outer_inner(**kwargs, left_first=left_first[INNER])
mid3d_outer = __make_mid3d_outer(**kwargs, left_first=left_first[OUTER])
mid3d_mid3d = __make_mid3d_mid3d(**kwargs, left_first=left_first[MID3D])
mid3d_inner = __make_mid3d_inner(**kwargs, left_first=left_first[INNER])
inner_outer = __make_inner_outer(**kwargs, left_first=left_first[OUTER])
inner_mid3d = __make_inner_mid3d(**kwargs, left_first=left_first[MID3D])
inner_inner = __make_inner_inner(**kwargs, left_first=left_first[INNER])

@numba.njit(**jit_flags)
# pylint: disable=too-many-arguments
Expand Down
8 changes: 5 additions & 3 deletions PyMPDATA/stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numba.core.errors import NumbaExperimentalFeatureWarning

from .impl.clock import clock
from .impl.enumerations import ARG_DATA, IMPL_BC, IMPL_META_AND_DATA
from .impl.enumerations import ARG_DATA, IMPL_BC, IMPL_META_AND_DATA, MAX_DIM_NUM
from .impl.formulae_antidiff import make_antidiff
from .impl.formulae_axpy import make_axpy
from .impl.formulae_flux import make_flux_first_pass, make_flux_subsequent
Expand All @@ -32,7 +32,7 @@ def __init__(
non_unit_g_factor: bool = False,
grid: (tuple, None) = None,
n_threads: (int, None) = None,
left_first: bool = True,
left_first: (tuple, None) = None,
buffer_size: int = 0
):
if n_dims is not None and grid is not None:
Expand All @@ -47,6 +47,8 @@ def __init__(
raise NotImplementedError()
if n_threads is None:
n_threads = numba.get_num_threads()
if left_first is None:
left_first = tuple([True] * MAX_DIM_NUM)

self.__options = options
self.__n_threads = 1 if n_dims == 1 else n_threads
Expand Down Expand Up @@ -108,7 +110,7 @@ def __call__(self, *, n_steps, mu_coeff, post_step, post_iter, fields):
@lru_cache()
# pylint: disable=too-many-locals,too-many-statements,too-many-arguments
def make_step_impl(
options, non_unit_g_factor, grid, n_threads, left_first, buffer_size
options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size
):
"""returns (and caches) an njit-ted stepping function and a traversals pair"""
traversals = Traversals(
Expand Down
9 changes: 5 additions & 4 deletions tests/unit_tests/test_boundary_condition_extrapolated_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from PyMPDATA import Options, ScalarField, VectorField
from PyMPDATA.boundary_conditions import Extrapolated
from PyMPDATA.impl.enumerations import MAX_DIM_NUM
from PyMPDATA.impl.traversals import Traversals

JIT_FLAGS = Options().jit_flags
Expand All @@ -25,7 +26,7 @@ class TestBoundaryConditionExtrapolated:
np.array([1, 2, 3, 4], dtype=complex),
),
)
def test_1d_scalar(data, halo, n_threads=1, left_first=True):
def test_1d_scalar(data, halo, n_threads=1):
# arrange
boundary_conditions = (Extrapolated(),)
field = ScalarField(data, halo, boundary_conditions)
Expand All @@ -35,7 +36,7 @@ def test_1d_scalar(data, halo, n_threads=1, left_first=True):
halo=halo,
jit_flags=JIT_FLAGS,
n_threads=n_threads,
left_first=left_first,
left_first=tuple([True] * MAX_DIM_NUM),
buffer_size=0,
)
field.assemble(traversals)
Expand Down Expand Up @@ -70,7 +71,7 @@ def test_1d_scalar(data, halo, n_threads=1, left_first=True):
@staticmethod
@pytest.mark.parametrize("data", (np.array([0, 2, 3, 0], dtype=float),))
@pytest.mark.parametrize("halo", (2, 3, 4))
def test_1d_vector(data, halo, n_threads=1, left_first=True):
def test_1d_vector(data, halo, n_threads=1):
# arrange
boundary_condition = (Extrapolated(),)
field = VectorField((data,), halo, boundary_condition)
Expand All @@ -80,7 +81,7 @@ def test_1d_vector(data, halo, n_threads=1, left_first=True):
halo=halo,
jit_flags=JIT_FLAGS,
n_threads=n_threads,
left_first=left_first,
left_first=tuple([True] * MAX_DIM_NUM),
buffer_size=0,
)
field.assemble(traversals)
Expand Down
10 changes: 5 additions & 5 deletions tests/unit_tests/test_boundary_condition_polar_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from PyMPDATA import Options, ScalarField, VectorField
from PyMPDATA.boundary_conditions import Periodic, Polar
from PyMPDATA.impl.enumerations import INNER, OUTER
from PyMPDATA.impl.enumerations import INNER, MAX_DIM_NUM, OUTER
from PyMPDATA.impl.traversals import Traversals

JIT_FLAGS = Options().jit_flags
Expand All @@ -18,7 +18,7 @@ class TestPolarBoundaryCondition:
@staticmethod
@pytest.mark.parametrize("halo", (1,))
@pytest.mark.parametrize("n_threads", (1, 2, 3))
def test_scalar_2d(halo, n_threads, left_first=True):
def test_scalar_2d(halo, n_threads):
# arrange
data = np.array([[1, 6], [2, 7], [3, 8], [4, 9]], dtype=float)
boundary_condition = (
Expand All @@ -32,7 +32,7 @@ def test_scalar_2d(halo, n_threads, left_first=True):
halo=halo,
jit_flags=JIT_FLAGS,
n_threads=n_threads,
left_first=left_first,
left_first=tuple([True] * MAX_DIM_NUM),
buffer_size=0,
)
field.assemble(traversals)
Expand All @@ -58,7 +58,7 @@ def test_scalar_2d(halo, n_threads, left_first=True):
@staticmethod
@pytest.mark.parametrize("halo", (1,))
@pytest.mark.parametrize("n_threads", (1, 2, 3))
def test_vector_2d(halo, n_threads, left_first=True):
def test_vector_2d(halo, n_threads):
# arrange
grid = (4, 2)
data = (
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_vector_2d(halo, n_threads, left_first=True):
halo=halo,
jit_flags=JIT_FLAGS,
n_threads=n_threads,
left_first=left_first,
left_first=tuple([True] * MAX_DIM_NUM),
buffer_size=0,
)
field.assemble(traversals)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_formulae_upwind.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from PyMPDATA import Options, ScalarField, VectorField
from PyMPDATA.boundary_conditions import Periodic
from PyMPDATA.impl.enumerations import IMPL_BC, IMPL_META_AND_DATA
from PyMPDATA.impl.enumerations import IMPL_BC, IMPL_META_AND_DATA, MAX_DIM_NUM
from PyMPDATA.impl.formulae_upwind import make_upwind
from PyMPDATA.impl.meta import _Impl
from PyMPDATA.impl.traversals import Traversals
Expand All @@ -24,7 +24,7 @@ def test_formulae_upwind():
halo=halo,
jit_flags=options.jit_flags,
n_threads=1,
left_first=True,
left_first=tuple([True] * MAX_DIM_NUM),
buffer_size=0,
)
upwind = make_upwind(
Expand Down
7 changes: 3 additions & 4 deletions tests/unit_tests/test_traversals.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,10 @@ def test_apply_scalar(
halo: int,
grid: tuple,
loop: bool,
left_first: bool = True,
):
if len(grid) == 1 and n_threads > 1:
return
cmn = make_commons(grid, halo, n_threads, left_first)
cmn = make_commons(grid, halo, n_threads, tuple([True] * MAX_DIM_NUM))

# arrange
sut = cmn.traversals.apply_scalar(loop=loop)
Expand Down Expand Up @@ -172,10 +171,10 @@ def test_apply_scalar(
@pytest.mark.parametrize("halo", (1, 2, 3))
@pytest.mark.parametrize("grid", ((3, 4, 5), (5, 6), (11,)))
# pylint: disable-next=too-many-locals,redefined-outer-name
def test_apply_vector(n_threads, halo: int, grid: tuple, left_first: bool = True):
def test_apply_vector(n_threads, halo: int, grid: tuple):
if len(grid) == 1 and n_threads > 1:
return
cmn = make_commons(grid, halo, n_threads, left_first)
cmn = make_commons(grid, halo, n_threads, tuple([True] * MAX_DIM_NUM))

# arrange
sut = cmn.traversals.apply_vector()
Expand Down
32 changes: 15 additions & 17 deletions tests/unit_tests/test_traversals_with_bc_periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from PyMPDATA import Options, ScalarField, VectorField
from PyMPDATA.boundary_conditions import Periodic
from PyMPDATA.impl.meta import INNER, MID3D, OUTER
from PyMPDATA.impl.enumerations import INNER, MAX_DIM_NUM, MID3D, OUTER
from PyMPDATA.impl.traversals import Traversals
from tests.unit_tests.fixtures.n_threads import n_threads

Expand Down Expand Up @@ -52,13 +52,13 @@ def indices(arg1, arg2=None, arg3=None):

@lru_cache()
# pylint: disable-next=redefined-outer-name
def make_traversals(grid, halo, n_threads, left_first):
def make_traversals(grid, halo, n_threads):
return Traversals(
grid=grid,
halo=halo,
jit_flags=JIT_FLAGS,
n_threads=n_threads,
left_first=left_first,
left_first=tuple([True] * MAX_DIM_NUM),
buffer_size=0,
)

Expand All @@ -77,20 +77,19 @@ class TestPeriodicBoundaryCondition:
@pytest.mark.parametrize("side", (LEFT, RIGHT))
@pytest.mark.parametrize("dim", DIMENSIONS)
# pylint: disable-next=redefined-outer-name,too-many-arguments
def test_scalar(data, halo, side, n_threads, dim, left_first=True):
def test_scalar(data, halo, side, n_threads, dim):
n_dims = len(data.shape)

if n_dims == 1 and dim != INNER:
return
pytest.skip()
if n_dims == 2 and dim == MID3D:
return
pytest.skip()
if n_dims == 1 and n_threads > 1:
return
pytest.skip()

# arrange
field = ScalarField(data, halo, tuple(Periodic() for _ in range(n_dims)))
traversals = make_traversals(
grid=field.grid, halo=halo, n_threads=n_threads, left_first=left_first
)
traversals = make_traversals(grid=field.grid, halo=halo, n_threads=n_threads)
field.assemble(traversals)
meta_and_data, fill_halos = field.impl
sut = traversals._code["fill_halos_scalar"] # pylint:disable=protected-access
Expand Down Expand Up @@ -141,20 +140,19 @@ def test_scalar(data, halo, side, n_threads, dim, left_first=True):
@pytest.mark.parametrize("comp", DIMENSIONS)
@pytest.mark.parametrize("dim_offset", (0, 1, 2))
# pylint: disable=redefined-outer-name,too-many-arguments,too-many-branches
def test_vector(data, halo, side, n_threads, comp, dim_offset, left_first=True):
def test_vector(data, halo, side, n_threads, comp, dim_offset):
n_dims = len(data)

if n_dims == 1 and n_threads > 1:
return
pytest.skip()
if n_dims == 1 and (comp != INNER or dim_offset != 0):
return
pytest.skip()
if n_dims == 2 and (comp == MID3D or dim_offset == 2):
return
pytest.skip()

# arrange
field = VectorField(data, halo, tuple(Periodic() for _ in range(n_dims)))
traversals = make_traversals(
grid=field.grid, halo=halo, n_threads=n_threads, left_first=left_first
)
traversals = make_traversals(grid=field.grid, halo=halo, n_threads=n_threads)
field.assemble(traversals)
meta_and_data, fill_halos = field.impl
meta_and_data = (
Expand Down

0 comments on commit 7a4848f

Please sign in to comment.