diff --git a/PyMPDATA/impl/traversals.py b/PyMPDATA/impl/traversals.py index 1b5f6da9..b565a4e9 100644 --- a/PyMPDATA/impl/traversals.py +++ b/PyMPDATA/impl/traversals.py @@ -18,7 +18,9 @@ class Traversals: """groups njit-ted traversals for a given grid, halo, jit_flags and threading settings""" - def __init__(self, *, grid, halo, jit_flags, n_threads, left_first, buffer_size): + def __init__( + self, *, grid, halo, jit_flags, n_threads, left_first: tuple, buffer_size + ): assert not (n_threads > 1 and len(grid) == 1) tmp = ( grid[OUTER] if len(grid) > 1 else 0, diff --git a/PyMPDATA/impl/traversals_halos_scalar.py b/PyMPDATA/impl/traversals_halos_scalar.py index 444ad5ed..c27a6ce0 100644 --- a/PyMPDATA/impl/traversals_halos_scalar.py +++ b/PyMPDATA/impl/traversals_halos_scalar.py @@ -16,17 +16,18 @@ from PyMPDATA.impl.traversals_common import make_common -def _make_fill_halos_scalar(*, jit_flags, halo, n_dims, chunker, spanner, left_first): +def _make_fill_halos_scalar( + *, jit_flags, halo, n_dims, chunker, spanner, left_first: tuple +): common = make_common(jit_flags, spanner, chunker) kwargs = { "jit_flags": jit_flags, "halo": halo, "n_dims": n_dims, - "left_first": left_first, } - mid3d = __make_mid3d(**kwargs) - outer = __make_outer(**kwargs) - inner = __make_inner(**kwargs) + mid3d = __make_mid3d(**kwargs, left_first=left_first[MID3D]) + outer = __make_outer(**kwargs, left_first=left_first[OUTER]) + inner = __make_inner(**kwargs, left_first=left_first[INNER]) @numba.njit(**jit_flags) # pylint: disable=too-many-arguments,too-many-branches diff --git a/PyMPDATA/impl/traversals_halos_vector.py b/PyMPDATA/impl/traversals_halos_vector.py index e0b75b97..61c6aa8f 100644 --- a/PyMPDATA/impl/traversals_halos_vector.py +++ b/PyMPDATA/impl/traversals_halos_vector.py @@ -18,7 +18,9 @@ from PyMPDATA.impl.traversals_common import make_common -def _make_fill_halos_vector(*, jit_flags, halo, n_dims, chunker, spanner, left_first): +def _make_fill_halos_vector( + *, jit_flags, halo, n_dims, chunker, spanner, left_first: tuple +): common = make_common(jit_flags, spanner, chunker) halos = ((halo - 1, halo, halo), (halo, halo - 1, halo), (halo, halo, halo - 1)) # pylint:disable=duplicate-code @@ -27,18 +29,17 @@ def _make_fill_halos_vector(*, jit_flags, halo, n_dims, chunker, spanner, left_f "halo": halo, "n_dims": n_dims, "halos": halos, - "left_first": left_first, } - outer_outer = __make_outer_outer(**kwargs) - outer_mid3d = __make_outer_mid3d(**kwargs) - outer_inner = __make_outer_inner(**kwargs) - mid3d_outer = __make_mid3d_outer(**kwargs) - mid3d_mid3d = __make_mid3d_mid3d(**kwargs) - mid3d_inner = __make_mid3d_inner(**kwargs) - inner_outer = __make_inner_outer(**kwargs) - inner_inner = __make_inner_inner(**kwargs) - inner_mid3d = __make_inner_mid3d(**kwargs) + outer_outer = __make_outer_outer(**kwargs, left_first=left_first[OUTER]) + outer_mid3d = __make_outer_mid3d(**kwargs, left_first=left_first[MID3D]) + outer_inner = __make_outer_inner(**kwargs, left_first=left_first[INNER]) + mid3d_outer = __make_mid3d_outer(**kwargs, left_first=left_first[OUTER]) + mid3d_mid3d = __make_mid3d_mid3d(**kwargs, left_first=left_first[MID3D]) + mid3d_inner = __make_mid3d_inner(**kwargs, left_first=left_first[INNER]) + inner_outer = __make_inner_outer(**kwargs, left_first=left_first[OUTER]) + inner_mid3d = __make_inner_mid3d(**kwargs, left_first=left_first[MID3D]) + inner_inner = __make_inner_inner(**kwargs, left_first=left_first[INNER]) @numba.njit(**jit_flags) # pylint: disable=too-many-arguments diff --git a/PyMPDATA/stepper.py b/PyMPDATA/stepper.py index 6ada5aa4..a906ece5 100644 --- a/PyMPDATA/stepper.py +++ b/PyMPDATA/stepper.py @@ -8,7 +8,7 @@ from numba.core.errors import NumbaExperimentalFeatureWarning from .impl.clock import clock -from .impl.enumerations import ARG_DATA, IMPL_BC, IMPL_META_AND_DATA +from .impl.enumerations import ARG_DATA, IMPL_BC, IMPL_META_AND_DATA, MAX_DIM_NUM from .impl.formulae_antidiff import make_antidiff from .impl.formulae_axpy import make_axpy from .impl.formulae_flux import make_flux_first_pass, make_flux_subsequent @@ -32,7 +32,7 @@ def __init__( non_unit_g_factor: bool = False, grid: (tuple, None) = None, n_threads: (int, None) = None, - left_first: bool = True, + left_first: (tuple, None) = None, buffer_size: int = 0 ): if n_dims is not None and grid is not None: @@ -47,6 +47,8 @@ def __init__( raise NotImplementedError() if n_threads is None: n_threads = numba.get_num_threads() + if left_first is None: + left_first = tuple([True] * MAX_DIM_NUM) self.__options = options self.__n_threads = 1 if n_dims == 1 else n_threads @@ -108,7 +110,7 @@ def __call__(self, *, n_steps, mu_coeff, post_step, post_iter, fields): @lru_cache() # pylint: disable=too-many-locals,too-many-statements,too-many-arguments def make_step_impl( - options, non_unit_g_factor, grid, n_threads, left_first, buffer_size + options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size ): """returns (and caches) an njit-ted stepping function and a traversals pair""" traversals = Traversals( diff --git a/tests/unit_tests/test_boundary_condition_extrapolated_1d.py b/tests/unit_tests/test_boundary_condition_extrapolated_1d.py index 6520405d..bb01cc08 100644 --- a/tests/unit_tests/test_boundary_condition_extrapolated_1d.py +++ b/tests/unit_tests/test_boundary_condition_extrapolated_1d.py @@ -8,6 +8,7 @@ from PyMPDATA import Options, ScalarField, VectorField from PyMPDATA.boundary_conditions import Extrapolated +from PyMPDATA.impl.enumerations import MAX_DIM_NUM from PyMPDATA.impl.traversals import Traversals JIT_FLAGS = Options().jit_flags @@ -25,7 +26,7 @@ class TestBoundaryConditionExtrapolated: np.array([1, 2, 3, 4], dtype=complex), ), ) - def test_1d_scalar(data, halo, n_threads=1, left_first=True): + def test_1d_scalar(data, halo, n_threads=1): # arrange boundary_conditions = (Extrapolated(),) field = ScalarField(data, halo, boundary_conditions) @@ -35,7 +36,7 @@ def test_1d_scalar(data, halo, n_threads=1, left_first=True): halo=halo, jit_flags=JIT_FLAGS, n_threads=n_threads, - left_first=left_first, + left_first=tuple([True] * MAX_DIM_NUM), buffer_size=0, ) field.assemble(traversals) @@ -70,7 +71,7 @@ def test_1d_scalar(data, halo, n_threads=1, left_first=True): @staticmethod @pytest.mark.parametrize("data", (np.array([0, 2, 3, 0], dtype=float),)) @pytest.mark.parametrize("halo", (2, 3, 4)) - def test_1d_vector(data, halo, n_threads=1, left_first=True): + def test_1d_vector(data, halo, n_threads=1): # arrange boundary_condition = (Extrapolated(),) field = VectorField((data,), halo, boundary_condition) @@ -80,7 +81,7 @@ def test_1d_vector(data, halo, n_threads=1, left_first=True): halo=halo, jit_flags=JIT_FLAGS, n_threads=n_threads, - left_first=left_first, + left_first=tuple([True] * MAX_DIM_NUM), buffer_size=0, ) field.assemble(traversals) diff --git a/tests/unit_tests/test_boundary_condition_polar_2d.py b/tests/unit_tests/test_boundary_condition_polar_2d.py index 465c91f4..a805ce1f 100644 --- a/tests/unit_tests/test_boundary_condition_polar_2d.py +++ b/tests/unit_tests/test_boundary_condition_polar_2d.py @@ -8,7 +8,7 @@ from PyMPDATA import Options, ScalarField, VectorField from PyMPDATA.boundary_conditions import Periodic, Polar -from PyMPDATA.impl.enumerations import INNER, OUTER +from PyMPDATA.impl.enumerations import INNER, MAX_DIM_NUM, OUTER from PyMPDATA.impl.traversals import Traversals JIT_FLAGS = Options().jit_flags @@ -18,7 +18,7 @@ class TestPolarBoundaryCondition: @staticmethod @pytest.mark.parametrize("halo", (1,)) @pytest.mark.parametrize("n_threads", (1, 2, 3)) - def test_scalar_2d(halo, n_threads, left_first=True): + def test_scalar_2d(halo, n_threads): # arrange data = np.array([[1, 6], [2, 7], [3, 8], [4, 9]], dtype=float) boundary_condition = ( @@ -32,7 +32,7 @@ def test_scalar_2d(halo, n_threads, left_first=True): halo=halo, jit_flags=JIT_FLAGS, n_threads=n_threads, - left_first=left_first, + left_first=tuple([True] * MAX_DIM_NUM), buffer_size=0, ) field.assemble(traversals) @@ -58,7 +58,7 @@ def test_scalar_2d(halo, n_threads, left_first=True): @staticmethod @pytest.mark.parametrize("halo", (1,)) @pytest.mark.parametrize("n_threads", (1, 2, 3)) - def test_vector_2d(halo, n_threads, left_first=True): + def test_vector_2d(halo, n_threads): # arrange grid = (4, 2) data = ( @@ -93,7 +93,7 @@ def test_vector_2d(halo, n_threads, left_first=True): halo=halo, jit_flags=JIT_FLAGS, n_threads=n_threads, - left_first=left_first, + left_first=tuple([True] * MAX_DIM_NUM), buffer_size=0, ) field.assemble(traversals) diff --git a/tests/unit_tests/test_formulae_upwind.py b/tests/unit_tests/test_formulae_upwind.py index 57d4b8b7..da3fb718 100644 --- a/tests/unit_tests/test_formulae_upwind.py +++ b/tests/unit_tests/test_formulae_upwind.py @@ -6,7 +6,7 @@ from PyMPDATA import Options, ScalarField, VectorField from PyMPDATA.boundary_conditions import Periodic -from PyMPDATA.impl.enumerations import IMPL_BC, IMPL_META_AND_DATA +from PyMPDATA.impl.enumerations import IMPL_BC, IMPL_META_AND_DATA, MAX_DIM_NUM from PyMPDATA.impl.formulae_upwind import make_upwind from PyMPDATA.impl.meta import _Impl from PyMPDATA.impl.traversals import Traversals @@ -24,7 +24,7 @@ def test_formulae_upwind(): halo=halo, jit_flags=options.jit_flags, n_threads=1, - left_first=True, + left_first=tuple([True] * MAX_DIM_NUM), buffer_size=0, ) upwind = make_upwind( diff --git a/tests/unit_tests/test_traversals.py b/tests/unit_tests/test_traversals.py index d1006dd7..c2e41ac9 100644 --- a/tests/unit_tests/test_traversals.py +++ b/tests/unit_tests/test_traversals.py @@ -107,11 +107,10 @@ def test_apply_scalar( halo: int, grid: tuple, loop: bool, - left_first: bool = True, ): if len(grid) == 1 and n_threads > 1: return - cmn = make_commons(grid, halo, n_threads, left_first) + cmn = make_commons(grid, halo, n_threads, tuple([True] * MAX_DIM_NUM)) # arrange sut = cmn.traversals.apply_scalar(loop=loop) @@ -172,10 +171,10 @@ def test_apply_scalar( @pytest.mark.parametrize("halo", (1, 2, 3)) @pytest.mark.parametrize("grid", ((3, 4, 5), (5, 6), (11,))) # pylint: disable-next=too-many-locals,redefined-outer-name - def test_apply_vector(n_threads, halo: int, grid: tuple, left_first: bool = True): + def test_apply_vector(n_threads, halo: int, grid: tuple): if len(grid) == 1 and n_threads > 1: return - cmn = make_commons(grid, halo, n_threads, left_first) + cmn = make_commons(grid, halo, n_threads, tuple([True] * MAX_DIM_NUM)) # arrange sut = cmn.traversals.apply_vector() diff --git a/tests/unit_tests/test_traversals_with_bc_periodic.py b/tests/unit_tests/test_traversals_with_bc_periodic.py index 2126f7dd..4c755834 100644 --- a/tests/unit_tests/test_traversals_with_bc_periodic.py +++ b/tests/unit_tests/test_traversals_with_bc_periodic.py @@ -8,7 +8,7 @@ from PyMPDATA import Options, ScalarField, VectorField from PyMPDATA.boundary_conditions import Periodic -from PyMPDATA.impl.meta import INNER, MID3D, OUTER +from PyMPDATA.impl.enumerations import INNER, MAX_DIM_NUM, MID3D, OUTER from PyMPDATA.impl.traversals import Traversals from tests.unit_tests.fixtures.n_threads import n_threads @@ -52,13 +52,13 @@ def indices(arg1, arg2=None, arg3=None): @lru_cache() # pylint: disable-next=redefined-outer-name -def make_traversals(grid, halo, n_threads, left_first): +def make_traversals(grid, halo, n_threads): return Traversals( grid=grid, halo=halo, jit_flags=JIT_FLAGS, n_threads=n_threads, - left_first=left_first, + left_first=tuple([True] * MAX_DIM_NUM), buffer_size=0, ) @@ -77,20 +77,19 @@ class TestPeriodicBoundaryCondition: @pytest.mark.parametrize("side", (LEFT, RIGHT)) @pytest.mark.parametrize("dim", DIMENSIONS) # pylint: disable-next=redefined-outer-name,too-many-arguments - def test_scalar(data, halo, side, n_threads, dim, left_first=True): + def test_scalar(data, halo, side, n_threads, dim): n_dims = len(data.shape) + if n_dims == 1 and dim != INNER: - return + pytest.skip() if n_dims == 2 and dim == MID3D: - return + pytest.skip() if n_dims == 1 and n_threads > 1: - return + pytest.skip() # arrange field = ScalarField(data, halo, tuple(Periodic() for _ in range(n_dims))) - traversals = make_traversals( - grid=field.grid, halo=halo, n_threads=n_threads, left_first=left_first - ) + traversals = make_traversals(grid=field.grid, halo=halo, n_threads=n_threads) field.assemble(traversals) meta_and_data, fill_halos = field.impl sut = traversals._code["fill_halos_scalar"] # pylint:disable=protected-access @@ -141,20 +140,19 @@ def test_scalar(data, halo, side, n_threads, dim, left_first=True): @pytest.mark.parametrize("comp", DIMENSIONS) @pytest.mark.parametrize("dim_offset", (0, 1, 2)) # pylint: disable=redefined-outer-name,too-many-arguments,too-many-branches - def test_vector(data, halo, side, n_threads, comp, dim_offset, left_first=True): + def test_vector(data, halo, side, n_threads, comp, dim_offset): n_dims = len(data) + if n_dims == 1 and n_threads > 1: - return + pytest.skip() if n_dims == 1 and (comp != INNER or dim_offset != 0): - return + pytest.skip() if n_dims == 2 and (comp == MID3D or dim_offset == 2): - return + pytest.skip() # arrange field = VectorField(data, halo, tuple(Periodic() for _ in range(n_dims))) - traversals = make_traversals( - grid=field.grid, halo=halo, n_threads=n_threads, left_first=left_first - ) + traversals = make_traversals(grid=field.grid, halo=halo, n_threads=n_threads) field.assemble(traversals) meta_and_data, fill_halos = field.impl meta_and_data = (