From d48c9bc2eca14763ea54e8f14b5f796b3d2c0753 Mon Sep 17 00:00:00 2001 From: "Thomas H. Gibson" Date: Mon, 6 Dec 2021 17:02:37 -0600 Subject: [PATCH] Array container support and batched (eager) communication (#154) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Support array containers in core grudge DG ops * Convert wave example to use array containers * Update projection routine to support array containers * Convert mpi communication * Use DOFArray-based traversal routines * Containerize distributed (eager) MPI communication * Use array container in wave-op-mpi example * Clean up wave-op-mpi demo * Use non-recursive array traversal * Use map_array_container in trace pair exchange * Remove unused imports * Remove sketchy assert; shape differs when evaluating fluxes vs storing state * Refactor containerization routines * Fix suppressed logging error * Fix forward flattening map * Fix type check in projection routine * dofdesc-itize descriptors first in projection * Make flattening methods private and provide short docs * Clean up flatten/unflatten functions and document * Carefully track ordering of data in flattened array * Use container flatten/unflatten functions from arraycontext * Containerize elementwise reductions and add unit tests * Containerize nodal reductions and add unit tests * Update documentation to reflect container support * Catch scalars in op.norm * Remove flat_norm and use op.norm instead * Separate initialization and completion of send/recv * Fix documentation in operator module Co-authored-by: Andreas Klöckner * Rename boundary communicator data/attributes * Comment on nonblocking MPI communication process in communicator class * Clean up comments related to rank-communication for trace pairs * Clean up/clarify documentation related to array containers * Update reduction unit tests * Documentation updates Co-authored-by: Andreas Klöckner * Fix comment formatting * Containerize grad/div functions * Add checks for empty arrays in grad/div operators * Fix gradient op for vector-valued components * Correct div helper for handling empty and nested obj arrays * Clear up documentation and add clarifying notes * More shared code in array container support (#191) * Refactor grad/weak grad to use the same helper, simplify div empty handling * Factor out some redundant code in elementwise reductions Co-authored-by: Andreas Klöckner --- examples/wave/wave-op-mpi.py | 87 +++++--- grudge/op.py | 368 +++++++++++++++++++-------------- grudge/projection.py | 30 +-- grudge/reductions.py | 182 +++++++++++----- grudge/trace_pair.py | 200 ++++++++++-------- test/test_mpi_communication.py | 2 +- test/test_reductions.py | 199 +++++++++++++++--- 7 files changed, 703 insertions(+), 365 deletions(-) diff --git a/examples/wave/wave-op-mpi.py b/examples/wave/wave-op-mpi.py index 9b93785a..f994d0b0 100644 --- a/examples/wave/wave-op-mpi.py +++ b/examples/wave/wave-op-mpi.py @@ -31,11 +31,18 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw, freeze +from arraycontext import ( + thaw, freeze, + with_container_arithmetic, + dataclass_array_container +) from grudge.array_context import PytatoPyOpenCLArrayContext, PyOpenCLArrayContext -from pytools.obj_array import flat_obj_array +from dataclasses import dataclass +from pytools.obj_array import flat_obj_array, make_obj_array + +from meshmode.dof_array import DOFArray from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa from grudge.discretization import DiscretizationCollection @@ -51,42 +58,58 @@ # {{{ wave equation bits +@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True) +@dataclass_array_container +@dataclass(frozen=True) +class WaveState: + u: DOFArray + v: np.ndarray # [object array] + + def __post_init__(self): + assert isinstance(self.v, np.ndarray) and self.v.dtype.char == "O" + + @property + def array_context(self): + return self.u.array_context + + def wave_flux(dcoll, c, w_tpair): - u = w_tpair[0] - v = w_tpair[1:] + u = w_tpair.u + v = w_tpair.v normal = thaw(dcoll.normal(w_tpair.dd), u.int.array_context) - flux_weak = flat_obj_array( - np.dot(v.avg, normal), - normal*u.avg, - ) + flux_weak = WaveState( + u=v.avg @ normal, + v=u.avg * normal + ) # upwind - v_jump = np.dot(normal, v.ext-v.int) - flux_weak += flat_obj_array( - 0.5*(u.ext-u.int), - 0.5*normal*v_jump, - ) + v_jump = v.diff @ normal + flux_weak += WaveState( + u=0.5 * u.diff, + v=0.5 * v_jump * normal, + ) return op.project(dcoll, w_tpair.dd, "all_faces", c*flux_weak) def wave_operator(dcoll, c, w): - u = w[0] - v = w[1:] + u = w.u + v = w.v - dir_u = op.project(dcoll, "vol", BTAG_ALL, u) - dir_v = op.project(dcoll, "vol", BTAG_ALL, v) - dir_bval = flat_obj_array(dir_u, dir_v) - dir_bc = flat_obj_array(-dir_u, dir_v) + dir_w = op.project(dcoll, "vol", BTAG_ALL, w) + dir_u = dir_w.u + dir_v = dir_w.v + dir_bval = WaveState(u=dir_u, v=dir_v) + dir_bc = WaveState(u=-dir_u, v=dir_v) return ( op.inverse_mass( dcoll, - flat_obj_array( - -c*op.weak_local_div(dcoll, v), - -c*op.weak_local_grad(dcoll, u) + WaveState( + u=-c*op.weak_local_div(dcoll, v), + v=-c*op.weak_local_grad(dcoll, u) ) + op.face_mass( dcoll, @@ -183,10 +206,10 @@ def main(ctx_factory, dim=2, order=3, visualize=False, lazy=False): dcoll = DiscretizationCollection(actx, local_mesh, order=order, mpi_communicator=comm) - fields = flat_obj_array( - bump(actx, dcoll), - [dcoll.zeros(actx) for i in range(dcoll.dim)] - ) + fields = WaveState( + u=bump(actx, dcoll), + v=make_obj_array([dcoll.zeros(actx) for i in range(dcoll.dim)]) + ) c = 1 dt = actx.to_numpy(0.45 * estimate_rk4_timestep(actx, dcoll, c)) @@ -210,12 +233,12 @@ def rhs(t, w): fields = rk4_step(fields, t, dt, compiled_rhs) - l2norm = actx.to_numpy(op.norm(dcoll, fields[0], 2)) + l2norm = actx.to_numpy(op.norm(dcoll, fields.u, 2)) if istep % 10 == 0: - linfnorm = actx.to_numpy(op.norm(dcoll, fields[0], np.inf)) - nodalmax = actx.to_numpy(op.nodal_max(dcoll, "vol", fields[0])) - nodalmin = actx.to_numpy(op.nodal_min(dcoll, "vol", fields[0])) + linfnorm = actx.to_numpy(op.norm(dcoll, fields.u, np.inf)) + nodalmax = actx.to_numpy(op.nodal_max(dcoll, "vol", fields.u)) + nodalmin = actx.to_numpy(op.nodal_min(dcoll, "vol", fields.u)) if comm.rank == 0: logger.info(f"step: {istep} t: {t} " f"L2: {l2norm} " @@ -227,8 +250,8 @@ def rhs(t, w): comm, f"fld-wave-eager-mpi-{{rank:03d}}-{istep:04d}.vtu", [ - ("u", fields[0]), - ("v", fields[1:]), + ("u", fields.u), + ("v", fields.v), ] ) diff --git a/grudge/op.py b/grudge/op.py index 473e099d..8b1b95ce 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -50,7 +50,12 @@ """ -from arraycontext import ArrayContext +from arraycontext import ArrayContext, map_array_container +from arraycontext.container import ArrayOrContainerT + +from functools import partial + +from meshmode.dof_array import DOFArray from meshmode.transform_metadata import FirstAxisIsElementsTag from grudge.discretization import DiscretizationCollection @@ -58,8 +63,6 @@ from pytools import keyed_memoize_in from pytools.obj_array import obj_array_vectorize, make_obj_array -from meshmode.dof_array import DOFArray - import numpy as np import grudge.dof_desc as dof_desc @@ -154,6 +157,88 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, # }}} +# {{{ common derivative "helpers" + +def _div_helper(dcoll, diff_func, *args): + if len(args) == 1: + vecs, = args + dd = dof_desc.DOFDesc("vol", dof_desc.DISCR_TAG_BASE) + elif len(args) == 2: + dd, vecs = args + else: + raise TypeError("invalid number of arguments") + + if not isinstance(vecs, np.ndarray): + # vecs is not an object array -> treat as array container + return map_array_container(partial(_div_helper, dcoll, diff_func), vecs) + + assert vecs.dtype == object + + if vecs.size: + sample_vec = vecs[(0,)*vecs.ndim] + + if isinstance(sample_vec, np.ndarray): + assert sample_vec.dtype == object + # vecs is an object array containing further object arrays + # -> treat as array container + return map_array_container(partial(_div_helper, dcoll, diff_func), vecs) + + if vecs.shape[-1] != dcoll.ambient_dim: + raise ValueError("last/innermost dimension of *vecs* argument doesn't match " + "ambient dimension") + + div_result_shape = vecs.shape[:-1] + + if len(div_result_shape) == 0: + return sum(diff_func(dd, i, vec_i) for i, vec_i in enumerate(vecs)) + else: + result = np.zeros(div_result_shape, dtype=object) + for idx in np.ndindex(div_result_shape): + result[idx] = sum( + diff_func(dd, i, vec_i) for i, vec_i in enumerate(vecs[idx])) + return result + + +def _grad_helper(dcoll, scalar_grad, *args, nested): + if len(args) == 1: + vec, = args + dd_in = dof_desc.DOFDesc("vol", dof_desc.DISCR_TAG_BASE) + elif len(args) == 2: + dd_in, vec = args + else: + raise TypeError("invalid number of arguments") + + if isinstance(vec, np.ndarray): + # Occasionally, data structures coming from *mirgecom* will + # contain empty object arrays as placeholders for fields. + # For example, species mass fractions is an empty object array when + # running in a single-species configuration. + # This hack here ensures that these empty arrays, at the very least, + # have their shape updated when applying the gradient operator + if vec.size == 0: + return vec.reshape(vec.shape + (dcoll.ambient_dim,)) + + # For containers with ndarray data (such as momentum/velocity), + # the gradient is matrix-valued, so we compute the gradient for + # each component. If requested (via 'not nested'), return a matrix of + # derivatives by stacking the results. + grad = obj_array_vectorize( + lambda el: _grad_helper( + dcoll, scalar_grad, dd_in, el, nested=nested), vec) + if nested: + return grad + else: + return np.stack(grad, axis=0) + + if not isinstance(vec, DOFArray): + return map_array_container( + partial(_grad_helper, scalar_grad, dcoll, dd_in, nested=nested), vec) + + return scalar_grad(dcoll, dd_in, vec) + +# }}} + + # {{{ Derivative operators def _reference_derivative_matrices(actx: ArrayContext, @@ -177,8 +262,23 @@ def get_ref_derivative_mats(grp): return get_ref_derivative_mats(out_element_group) +def _strong_scalar_grad(dcoll, dd_in, vec): + assert dd_in == dof_desc.as_dofdesc(dof_desc.DD_VOLUME) + + from grudge.geometry import inverse_surface_metric_derivative_mat + + discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) + actx = vec.array_context + + inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, + _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) + return _gradient_kernel(actx, discr, discr, + _reference_derivative_matrices, inverse_jac_mat, vec, + metric_in_matvec=False) + + def local_grad( - dcoll: DiscretizationCollection, vec, *, nested=False) -> np.ndarray: + dcoll: DiscretizationCollection, vec, *, nested=False) -> ArrayOrContainerT: r"""Return the element-local gradient of a function :math:`f` represented by *vec*: @@ -187,34 +287,20 @@ def local_grad( \nabla|_E f = \left( \partial_x|_E f, \partial_y|_E f, \partial_z|_E f \right) - :arg vec: a :class:`~meshmode.dof_array.DOFArray` or object array of - :class:`~meshmode.dof_array.DOFArray`\ s. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. :arg nested: return nested object arrays instead of a single multidimensional array if *vec* is non-scalar. :returns: an object array (possibly nested) of - :class:`~meshmode.dof_array.DOFArray`\ s. + :class:`~meshmode.dof_array.DOFArray`\ s or + :class:`~arraycontext.container.ArrayContainer`\ of object arrays. """ - if isinstance(vec, np.ndarray): - grad = obj_array_vectorize( - lambda el: local_grad(dcoll, el, nested=nested), vec) - if nested: - return grad - else: - return np.stack(grad, axis=0) - from grudge.geometry import inverse_surface_metric_derivative_mat + return _grad_helper(dcoll, _strong_scalar_grad, vec, nested=nested) - discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) - actx = vec.array_context - - inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, - _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) - return _gradient_kernel(actx, discr, discr, - _reference_derivative_matrices, inverse_jac_mat, vec, - metric_in_matvec=False) - -def local_d_dx(dcoll: DiscretizationCollection, xyz_axis, vec): +def local_d_dx( + dcoll: DiscretizationCollection, xyz_axis, vec) -> ArrayOrContainerT: r"""Return the element-local derivative along axis *xyz_axis* of a function :math:`f` represented by *vec*: @@ -224,9 +310,14 @@ def local_d_dx(dcoll: DiscretizationCollection, xyz_axis, vec): :arg xyz_axis: an integer indicating the axis along which the derivative is taken. - :arg vec: a :class:`~meshmode.dof_array.DOFArray`. - :returns: a :class:`~meshmode.dof_array.DOFArray`\ s. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. """ + if not isinstance(vec, DOFArray): + return map_array_container(partial(local_d_dx, dcoll, xyz_axis), vec) + discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) actx = vec.array_context @@ -240,30 +331,7 @@ def local_d_dx(dcoll: DiscretizationCollection, xyz_axis, vec): metric_in_matvec=False) -def _div_helper(dcoll: DiscretizationCollection, diff_func, vecs): - if not isinstance(vecs, np.ndarray): - raise TypeError("argument must be an object array") - assert vecs.dtype == object - - if isinstance(vecs[(0,)*vecs.ndim], np.ndarray): - div_shape = vecs.shape - else: - if vecs.shape[-1] != dcoll.ambient_dim: - raise ValueError("last dimension of *vecs* argument doesn't match " - "ambient dimension") - div_shape = vecs.shape[:-1] - - if len(div_shape) == 0: - return sum(diff_func(i, vec_i) for i, vec_i in enumerate(vecs)) - else: - result = np.zeros(div_shape, dtype=object) - for idx in np.ndindex(div_shape): - result[idx] = sum( - diff_func(i, vec_i) for i, vec_i in enumerate(vecs[idx])) - return result - - -def local_div(dcoll: DiscretizationCollection, vecs): +def local_div(dcoll: DiscretizationCollection, vecs) -> ArrayOrContainerT: r"""Return the element-local divergence of the vector function :math:`\mathbf{f}` represented by *vecs*: @@ -271,16 +339,18 @@ def local_div(dcoll: DiscretizationCollection, vecs): \nabla|_E \cdot \mathbf{f} = \sum_{i=1}^d \partial_{x_i}|_E \mathbf{f}_i - :arg vec: an object array of - a :class:`~meshmode.dof_array.DOFArray`\ s, - where the last axis of the array must have length - matching the volume dimension. - :returns: a :class:`~meshmode.dof_array.DOFArray`. + :arg vecs: an object array of + :class:`~meshmode.dof_array.DOFArray`\s or an + :class:`~arraycontext.container.ArrayContainer` object + with object array entries. The last axis of the array + must have length matching the volume dimension. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. """ - - return _div_helper(dcoll, - lambda i, subvec: local_d_dx(dcoll, i, subvec), - vecs) + return _div_helper( + dcoll, + lambda dd, i, subvec: local_d_dx(dcoll, i, subvec), + vecs) # }}} @@ -332,11 +402,28 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): in_element_group) -def weak_local_grad(dcoll: DiscretizationCollection, *args, nested=False): +def _weak_scalar_grad(dcoll, dd_in, vec): + from grudge.geometry import inverse_surface_metric_derivative_mat + + in_discr = dcoll.discr_from_dd(dd_in) + out_discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) + + actx = vec.array_context + inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in, + times_area_element=True, + _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) + + return _gradient_kernel(actx, out_discr, in_discr, + _reference_stiffness_transpose_matrix, inverse_jac_mat, vec, + metric_in_matvec=True) + + +def weak_local_grad( + dcoll: DiscretizationCollection, *args, nested=False) -> ArrayOrContainerT: r"""Return the element-local weak gradient of the volume function represented by *vec*. - May be called with ``(vecs)`` or ``(dd, vecs)``. + May be called with ``(vec)`` or ``(dd_in, vec)``. Specifically, the function returns an object array where the :math:`i`-th component is the weak derivative with respect to the :math:`i`-th coordinate @@ -344,51 +431,24 @@ def weak_local_grad(dcoll: DiscretizationCollection, *args, nested=False): information. For non-scalar :math:`f`, the function will return a nested object array containing the component-wise weak derivatives. - :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. + :arg dd_in: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. Defaults to the base volume discretization if not provided. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` or object array of - :class:`~meshmode.dof_array.DOFArray`\ s. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. :arg nested: return nested object arrays instead of a single multidimensional array if *vec* is non-scalar :returns: an object array (possibly nested) of - :class:`~meshmode.dof_array.DOFArray`\ s. + :class:`~meshmode.dof_array.DOFArray`\ s or + :class:`~arraycontext.container.ArrayContainer`\ of object arrays. """ - if len(args) == 1: - vec, = args - dd_in = dof_desc.DOFDesc("vol", dof_desc.DISCR_TAG_BASE) - elif len(args) == 2: - dd_in, vec = args - else: - raise TypeError("invalid number of arguments") + return _grad_helper(dcoll, _weak_scalar_grad, *args, nested=nested) - if isinstance(vec, np.ndarray): - grad = obj_array_vectorize( - lambda el: weak_local_grad(dcoll, dd_in, el, nested=nested), vec) - if nested: - return grad - else: - return np.stack(grad, axis=0) - - from grudge.geometry import inverse_surface_metric_derivative_mat - in_discr = dcoll.discr_from_dd(dd_in) - out_discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) - - actx = vec.array_context - inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in, - times_area_element=True, - _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) - - return _gradient_kernel(actx, out_discr, in_discr, - _reference_stiffness_transpose_matrix, inverse_jac_mat, vec, - metric_in_matvec=True) - - -def weak_local_d_dx(dcoll: DiscretizationCollection, *args): +def weak_local_d_dx(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT: r"""Return the element-local weak derivative along axis *xyz_axis* of the volume function represented by *vec*. - May be called with ``(xyz_axis, vecs)`` or ``(dd, xyz_axis, vecs)``. + May be called with ``(xyz_axis, vec)`` or ``(dd_in, xyz_axis, vec)``. Specifically, this function computes the volume contribution of the weak derivative in the :math:`i`-th component (specified by *xyz_axis*) @@ -405,10 +465,14 @@ def weak_local_d_dx(dcoll: DiscretizationCollection, *args): is the elemental mass matrix (see :func:`mass` for more information), and :math:`\mathbf{f}|_E` is a vector of coefficients for :math:`f` on :math:`E`. + :arg dd_in: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. + Defaults to the base volume discretization if not provided. :arg xyz_axis: an integer indicating the axis along which the derivative - is taken - :arg vec: a :class:`~meshmode.dof_array.DOFArray`. - :returns: a :class:`~meshmode.dof_array.DOFArray`\ s. + is taken. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. """ if len(args) == 2: xyz_axis, vec = args @@ -418,6 +482,12 @@ def weak_local_d_dx(dcoll: DiscretizationCollection, *args): else: raise TypeError("invalid number of arguments") + if not isinstance(vec, DOFArray): + return map_array_container( + partial(weak_local_d_dx, dcoll, dd_in, xyz_axis), + vec + ) + from grudge.geometry import inverse_surface_metric_derivative_mat in_discr = dcoll.discr_from_dd(dd_in) @@ -434,7 +504,7 @@ def weak_local_d_dx(dcoll: DiscretizationCollection, *args): metric_in_matvec=True) -def weak_local_div(dcoll: DiscretizationCollection, *args): +def weak_local_div(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT: r"""Return the element-local weak divergence of the vector volume function represented by *vecs*. @@ -455,23 +525,19 @@ def weak_local_div(dcoll: DiscretizationCollection, *args): :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. Defaults to the base volume discretization if not provided. - :arg vec: a object array of - a :class:`~meshmode.dof_array.DOFArray`\ s, - where the last axis of the array must have length - matching the volume dimension. - :returns: a :class:`~meshmode.dof_array.DOFArray`. + :arg vecs: an object array of + :class:`~meshmode.dof_array.DOFArray`\s or an + :class:`~arraycontext.container.ArrayContainer` object + with object array entries. The last axis of the array + must have length matching the volume dimension. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` like *vec*. """ - if len(args) == 1: - vecs, = args - dd = dof_desc.DOFDesc("vol", dof_desc.DISCR_TAG_BASE) - elif len(args) == 2: - dd, vecs = args - else: - raise TypeError("invalid number of arguments") - - return _div_helper(dcoll, - lambda i, subvec: weak_local_d_dx(dcoll, dd, i, subvec), - vecs) + return _div_helper( + dcoll, + lambda dd, i, subvec: weak_local_d_dx(dcoll, dd, i, subvec), + *args + ) # }}} @@ -517,11 +583,9 @@ def get_ref_mass_mat(out_grp, in_grp): def _apply_mass_operator( dcoll: DiscretizationCollection, dd_out, dd_in, vec): - if isinstance(vec, np.ndarray): - return obj_array_vectorize( - lambda vi: _apply_mass_operator(dcoll, - dd_out, - dd_in, vi), vec + if not isinstance(vec, DOFArray): + return map_array_container( + partial(_apply_mass_operator, dcoll, dd_out, dd_in), vec ) from grudge.geometry import area_element @@ -552,11 +616,11 @@ def _apply_mass_operator( ) -def mass(dcoll: DiscretizationCollection, *args): +def mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT: r"""Return the action of the DG mass matrix on a vector (or vectors) of :class:`~meshmode.dof_array.DOFArray`\ s, *vec*. In the case of - *vec* being an object array of :class:`~meshmode.dof_array.DOFArray`\ s, - the mass operator is applied in the Kronecker sense (component-wise). + *vec* being an :class:`~arraycontext.container.ArrayContainer`, + the mass operator is applied component-wise. May be called with ``(vec)`` or ``(dd, vec)``. @@ -572,11 +636,10 @@ def mass(dcoll: DiscretizationCollection, *args): :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. Defaults to the base volume discretization if not provided. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` or object array of - :class:`~meshmode.dof_array.DOFArray`\ s. - :returns: a :class:`~meshmode.dof_array.DOFArray` denoting the - application of the mass matrix, or an object array of - :class:`~meshmode.dof_array.DOFArray`\ s. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` like *vec*. """ if len(args) == 1: @@ -616,11 +679,9 @@ def get_ref_inv_mass_mat(grp): def _apply_inverse_mass_operator( dcoll: DiscretizationCollection, dd_out, dd_in, vec): - if isinstance(vec, np.ndarray): - return obj_array_vectorize( - lambda vi: _apply_inverse_mass_operator(dcoll, - dd_out, - dd_in, vi), vec + if not isinstance(vec, DOFArray): + return map_array_container( + partial(_apply_inverse_mass_operator, dcoll, dd_out, dd_in), vec ) from grudge.geometry import area_element @@ -655,12 +716,11 @@ def _apply_inverse_mass_operator( return DOFArray(actx, data=tuple(group_data)) -def inverse_mass(dcoll: DiscretizationCollection, vec): +def inverse_mass(dcoll: DiscretizationCollection, vec) -> ArrayOrContainerT: r"""Return the action of the DG mass matrix inverse on a vector (or vectors) of :class:`~meshmode.dof_array.DOFArray`\ s, *vec*. - In the case of *vec* being an object array of - :class:`~meshmode.dof_array.DOFArray`\ s, the inverse mass operator is - applied in the Kronecker sense (component-wise). + In the case of *vec* being an :class:`~arraycontext.container.ArrayContainer`, + the inverse mass operator is applied component-wise. For affine elements :math:`E`, the element-wise mass inverse is computed directly as the inverse of the (physical) mass matrix: @@ -686,11 +746,10 @@ def inverse_mass(dcoll: DiscretizationCollection, vec): where :math:`\widehat{\mathbf{M}}` is the reference mass matrix on :math:`\widehat{E}`. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` or object array of - :class:`~meshmode.dof_array.DOFArray`\ s. - :returns: a :class:`~meshmode.dof_array.DOFArray` denoting the - application of the inverse mass matrix, or an object array of - :class:`~meshmode.dof_array.DOFArray`\ s. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` like *vec*. """ return _apply_inverse_mass_operator( @@ -789,9 +848,9 @@ def get_ref_face_mass_mat(face_grp, vol_grp): def _apply_face_mass_operator(dcoll: DiscretizationCollection, dd, vec): - if isinstance(vec, np.ndarray): - return obj_array_vectorize( - lambda vi: _apply_face_mass_operator(dcoll, dd, vi), vec + if not isinstance(vec, DOFArray): + return map_array_container( + partial(_apply_face_mass_operator, dcoll, dd), vec ) from grudge.geometry import area_element @@ -833,11 +892,11 @@ def _apply_face_mass_operator(dcoll: DiscretizationCollection, dd, vec): ) -def face_mass(dcoll: DiscretizationCollection, *args): +def face_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT: r"""Return the action of the DG face mass matrix on a vector (or vectors) of :class:`~meshmode.dof_array.DOFArray`\ s, *vec*. In the case of - *vec* being an object array of :class:`~meshmode.dof_array.DOFArray`\ s, - the mass operator is applied in the Kronecker sense (component-wise). + *vec* being an arbitrary :class:`~arraycontext.container.ArrayContainer`, + the face mass operator is applied component-wise. May be called with ``(vec)`` or ``(dd, vec)``. @@ -862,11 +921,10 @@ def face_mass(dcoll: DiscretizationCollection, *args): :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. Defaults to the base ``"all_faces"`` discretization if not provided. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` or object array of - :class:`~meshmode.dof_array.DOFArray`\ s. - :returns: a :class:`~meshmode.dof_array.DOFArray` denoting the - application of the face mass matrix, or an object array of - :class:`~meshmode.dof_array.DOFArray`\ s. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` like *vec*. """ if len(args) == 1: diff --git a/grudge/projection.py b/grudge/projection.py index 3129253c..fdf95155 100644 --- a/grudge/projection.py +++ b/grudge/projection.py @@ -32,37 +32,41 @@ """ -import numpy as np +from functools import partial + +from arraycontext import map_array_container +from arraycontext.container import ArrayOrContainerT from grudge.discretization import DiscretizationCollection from grudge.dof_desc import as_dofdesc -from numbers import Number +from meshmode.dof_array import DOFArray -from pytools.obj_array import obj_array_vectorize +from numbers import Number -def project(dcoll: DiscretizationCollection, src, tgt, vec): +def project( + dcoll: DiscretizationCollection, src, tgt, vec) -> ArrayOrContainerT: """Project from one discretization to another, e.g. from the volume to the boundary, or from the base to the an overintegrated quadrature discretization. :arg src: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. :arg tgt: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` or a - :class:`~arraycontext.ArrayContainer`. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` like *vec*. """ src = as_dofdesc(src) tgt = as_dofdesc(tgt) - if src == tgt: + if isinstance(vec, Number) or src == tgt: return vec - if isinstance(vec, np.ndarray): - return obj_array_vectorize( - lambda el: project(dcoll, src, tgt, el), vec) - - if isinstance(vec, Number): - return vec + if not isinstance(vec, DOFArray): + return map_array_container( + partial(project, dcoll, src, tgt), vec + ) return dcoll.connection_from_dds(src, tgt)(vec) diff --git a/grudge/reductions.py b/grudge/reductions.py index b4897c14..c719be71 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -57,14 +57,19 @@ """ -from functools import reduce +from functools import reduce, partial -from arraycontext import make_loopy_program, DeviceScalar +from arraycontext import ( + make_loopy_program, + map_array_container, + serialize_container, + DeviceScalar +) +from arraycontext.container import ArrayOrContainerT from grudge.discretization import DiscretizationCollection from pytools import memoize_in -from pytools.obj_array import obj_array_vectorize from meshmode.dof_array import DOFArray @@ -78,10 +83,8 @@ def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> "DeviceScalar": r"""Return the vector p-norm of a function represented by its vector of degrees of freedom *vec*. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an object array of - a :class:`~meshmode.dof_array.DOFArray`\ s, - where the last axis of the array must have length - matching the volume dimension. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. :arg p: an integer denoting the order of the integral norm. Currently, only values of 2 or `numpy.inf` are supported. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. @@ -115,8 +118,9 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. - :arg vec: a :class:`~meshmode.dof_array.DOFArray`. - :returns: a scalar denoting the nodal sum. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer`. + :returns: a device scalar denoting the nodal sum. """ comm = dcoll.mpi_communicator if comm is None: @@ -135,12 +139,15 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. - :arg vec: a :class:`~meshmode.dof_array.DOFArray`. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. :returns: a scalar denoting the rank-local nodal sum. """ - if isinstance(vec, np.ndarray): - return sum(nodal_sum_loc(dcoll, dd, vec[idx]) - for idx in np.ndindex(vec.shape)) + if not isinstance(vec, DOFArray): + return sum( + nodal_sum_loc(dcoll, dd, comp) + for _, comp in serialize_container(vec) + ) actx = vec.array_context @@ -152,8 +159,9 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. - :arg vec: a :class:`~meshmode.dof_array.DOFArray`. - :returns: a scalar denoting the nodal minimum. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a device scalar denoting the nodal minimum. """ comm = dcoll.mpi_communicator if comm is None: @@ -173,12 +181,15 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. - :arg vec: a :class:`~meshmode.dof_array.DOFArray`. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. :returns: a scalar denoting the rank-local nodal minimum. """ - if isinstance(vec, np.ndarray): - return min(nodal_min_loc(dcoll, dd, vec[idx]) - for idx in np.ndindex(vec.shape)) + if not isinstance(vec, DOFArray): + return min( + nodal_min_loc(dcoll, dd, comp) + for _, comp in serialize_container(vec) + ) actx = vec.array_context @@ -192,8 +203,9 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. - :arg vec: a :class:`~meshmode.dof_array.DOFArray`. - :returns: a scalar denoting the nodal maximum. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a device scalar denoting the nodal maximum. """ comm = dcoll.mpi_communicator if comm is None: @@ -213,12 +225,15 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. - :arg vec: a :class:`~meshmode.dof_array.DOFArray`. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer`. :returns: a scalar denoting the rank-local nodal maximum. """ - if isinstance(vec, np.ndarray): - return max(nodal_max_loc(dcoll, dd, vec[idx]) - for idx in np.ndindex(vec.shape)) + if not isinstance(vec, DOFArray): + return max( + nodal_max_loc(dcoll, dd, comp) + for _, comp in serialize_container(vec) + ) actx = vec.array_context @@ -232,8 +247,9 @@ def integral(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": :class:`~meshmode.dof_array.DOFArray` of degrees of freedom. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` - :returns: a scalar denoting the evaluated integral. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a device scalar denoting the evaluated integral. """ from grudge.op import _apply_mass_operator @@ -250,13 +266,19 @@ def integral(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar": # {{{ Elementwise reductions def _apply_elementwise_reduction( - op_name: str, dcoll: DiscretizationCollection, *args) -> DOFArray: + op_name: str, dcoll: DiscretizationCollection, + *args) -> ArrayOrContainerT: r"""Returns a vector of DOFs with all entries on each element set to the reduction operation *op_name* over all degrees of freedom. - :arg \*args: Arguments for the reduction operator, such as *dd* and *vec*. - :returns: a :class:`~meshmode.dof_array.DOFArray` or object arrary of - :class:`~meshmode.dof_array.DOFArray`s. + May be called with ``(vec)`` or ``(dd, vec)``. + + :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. + Defaults to the base volume discretization if not provided. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer`. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer`. """ if len(args) == 1: vec, = args @@ -268,9 +290,9 @@ def _apply_elementwise_reduction( dd = dof_desc.as_dofdesc(dd) - if isinstance(vec, np.ndarray): - return obj_array_vectorize( - lambda vi: _apply_elementwise_reduction(op_name, dcoll, dd, vi), vec + if not isinstance(vec, DOFArray): + return map_array_container( + partial(_apply_elementwise_reduction, op_name, dcoll, dd), vec ) actx = vec.array_context @@ -317,67 +339,127 @@ def elementwise_prg(): ) -def elementwise_sum(dcoll: DiscretizationCollection, *args) -> DOFArray: +def elementwise_sum( + dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT: r"""Returns a vector of DOFs with all entries on each element set to the sum of DOFs on that element. - May be called with ``(dcoll, vec)`` or ``(dcoll, dd, vec)``. + May be called with ``(vec)`` or ``(dd, vec)``. + + The input *vec* can either be a :class:`~meshmode.dof_array.DOFArray` or + an :class:`~arraycontext.container.ArrayContainer` with + :class:`~meshmode.dof_array.DOFArray` entries. If the underlying + array context (see :class:`arraycontext.ArrayContext`) for *vec* + supports nonscalar broadcasting, all :class:`~meshmode.dof_array.DOFArray` + entries will contain a single value for each element. Otherwise, the + entries will have the same number of degrees of freedom as *vec*, but + set to the same value. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. Defaults to the base volume discretization if not provided. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` - :returns: a :class:`~meshmode.dof_array.DOFArray` whose entries + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` like *vec* whose entries denote the element-wise sum of *vec*. """ return _apply_elementwise_reduction("sum", dcoll, *args) -def elementwise_max(dcoll: DiscretizationCollection, *args) -> DOFArray: +def elementwise_max( + dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT: r"""Returns a vector of DOFs with all entries on each element set to the maximum over all DOFs on that element. - May be called with ``(dcoll, vec)`` or ``(dcoll, dd, vec)``. + May be called with ``(vec)`` or ``(dd, vec)``. + + The input *vec* can either be a :class:`~meshmode.dof_array.DOFArray` or + an :class:`~arraycontext.container.ArrayContainer` with + :class:`~meshmode.dof_array.DOFArray` entries. If the underlying + array context (see :class:`arraycontext.ArrayContext`) for *vec* + supports nonscalar broadcasting, all :class:`~meshmode.dof_array.DOFArray` + entries will contain a single value for each element. Otherwise, the + entries will have the same number of degrees of freedom as *vec*, but + set to the same value. :arg dcoll: a :class:`grudge.discretization.DiscretizationCollection`. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. Defaults to the base volume discretization if not provided. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` - :returns: a :class:`~meshmode.dof_array.DOFArray` whose entries + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer`. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` like *vec* whose entries denote the element-wise max of *vec*. """ return _apply_elementwise_reduction("max", dcoll, *args) -def elementwise_min(dcoll: DiscretizationCollection, *args) -> DOFArray: +def elementwise_min( + dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT: r"""Returns a vector of DOFs with all entries on each element set to the minimum over all DOFs on that element. - May be called with ``(dcoll, vec)`` or ``(dcoll, dd, vec)``. + May be called with ``(vec)`` or ``(dd, vec)``. + + The input *vec* can either be a :class:`~meshmode.dof_array.DOFArray` or + an :class:`~arraycontext.container.ArrayContainer` with + :class:`~meshmode.dof_array.DOFArray` entries. If the underlying + array context (see :class:`arraycontext.ArrayContext`) for *vec* + supports nonscalar broadcasting, all :class:`~meshmode.dof_array.DOFArray` + entries will contain a single value for each element. Otherwise, the + entries will have the same number of degrees of freedom as *vec*, but + set to the same value. :arg dcoll: a :class:`grudge.discretization.DiscretizationCollection`. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. Defaults to the base volume discretization if not provided. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` - :returns: a :class:`~meshmode.dof_array.DOFArray` whose entries + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` like *vec* whose entries denote the element-wise min of *vec*. """ return _apply_elementwise_reduction("min", dcoll, *args) -def elementwise_integral(dcoll: DiscretizationCollection, dd, vec) -> DOFArray: +def elementwise_integral( + dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT: """Numerically integrates a function represented by a :class:`~meshmode.dof_array.DOFArray` of degrees of freedom in each element of a discretization, given by *dd*. + May be called with ``(vec)`` or ``(dd, vec)``. + + The input *vec* can either be a :class:`~meshmode.dof_array.DOFArray` or + an :class:`~arraycontext.container.ArrayContainer` with + :class:`~meshmode.dof_array.DOFArray` entries. If the underlying + array context (see :class:`arraycontext.ArrayContext`) for *vec* + supports nonscalar broadcasting, all :class:`~meshmode.dof_array.DOFArray` + entries will contain a single value for each element. Otherwise, the + entries will have the same number of degrees of freedom as *vec*, but + set to the same value. + + :arg dcoll: a :class:`grudge.discretization.DiscretizationCollection`. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` - :returns: a :class:`~meshmode.dof_array.DOFArray` containing the + Defaults to the base volume discretization if not provided. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + :returns: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` like *vec* containing the elementwise integral if *vec*. """ - from grudge.op import _apply_mass_operator + if len(args) == 1: + vec, = args + dd = dof_desc.DOFDesc("vol", dof_desc.DISCR_TAG_BASE) + elif len(args) == 2: + dd, vec = args + else: + raise TypeError("invalid number of arguments") dd = dof_desc.as_dofdesc(dd) + from grudge.op import _apply_mass_operator + ones = dcoll.discr_from_dd(dd).zeros(vec.array_context) + 1.0 return elementwise_sum( dcoll, dd, vec * _apply_mass_operator(dcoll, dd, dd, ones) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 26dad69a..b634219b 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -49,19 +49,23 @@ from arraycontext import ( ArrayContainer, with_container_arithmetic, - dataclass_array_container + dataclass_array_container, + get_container_context_recursively, + flatten, to_numpy, + unflatten, from_numpy ) +from arraycontext.container import ArrayOrContainerT from dataclasses import dataclass from numbers import Number from pytools import memoize_on_first_arg -from pytools.obj_array import obj_array_vectorize, make_obj_array +from pytools.obj_array import obj_array_vectorize from grudge.discretization import DiscretizationCollection +from grudge.projection import project -from meshmode.dof_array import flatten, unflatten from meshmode.mesh import BTAG_PARTITION import numpy as np @@ -134,7 +138,7 @@ def __len__(self): @property def int(self): - """A class:`~meshmode.dof_array.DOFArray` or + """A :class:`~meshmode.dof_array.DOFArray` or :class:`~arraycontext.ArrayContainer` of them representing the interior value to be used for the flux. """ @@ -142,7 +146,7 @@ def int(self): @property def ext(self): - """A class:`~meshmode.dof_array.DOFArray` or + """A :class:`~meshmode.dof_array.DOFArray` or :class:`~arraycontext.ArrayContainer` of them representing the exterior value to be used for the flux. """ @@ -150,7 +154,7 @@ def ext(self): @property def avg(self): - """A class:`~meshmode.dof_array.DOFArray` or + """A :class:`~meshmode.dof_array.DOFArray` or :class:`~arraycontext.ArrayContainer` of them representing the average of the interior and exterior values. """ @@ -158,7 +162,7 @@ def avg(self): @property def diff(self): - """A class:`~meshmode.dof_array.DOFArray` or + """A :class:`~meshmode.dof_array.DOFArray` or :class:`~arraycontext.ArrayContainer` of them representing the difference (exterior - interior) of the pair values. """ @@ -173,13 +177,18 @@ def bdry_trace_pair( dcoll: DiscretizationCollection, dd, interior, exterior) -> TracePair: """Returns a trace pair defined on the exterior boundary. Input arguments are assumed to already be defined on the boundary denoted by *dd*. + If the input arguments *interior* and *exterior* are + :class:`~arraycontext.container.ArrayContainer` objects, they must both + have the same internal structure. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one, which describes the boundary discretization. - :arg interior: a :class:`~meshmode.dof_array.DOFArray` that contains data + :arg interior: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them that contains data already on the boundary representing the interior value to be used for the flux. - :arg exterior: a :class:`~meshmode.dof_array.DOFArray` that contains data + :arg exterior: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them that contains data that already lives on the boundary representing the exterior value to be used for the flux. :returns: a :class:`TracePair` on the boundary. @@ -193,22 +202,26 @@ def bv_trace_pair( argument is assumed to be defined on the volume discretization, and will therefore be restricted to the boundary *dd* prior to creating a :class:`TracePair`. + If the input arguments *interior* and *exterior* are + :class:`~arraycontext.container.ArrayContainer` objects, they must both + have the same internal structure. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one, which describes the boundary discretization. - :arg interior: a :class:`~meshmode.dof_array.DOFArray` that contains data + :arg interior: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` that contains data defined in the volume, which will be restricted to the boundary denoted by *dd*. The result will be used as the interior value for the flux. - :arg exterior: a :class:`~meshmode.dof_array.DOFArray` that contains data + :arg exterior: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` that contains data that already lives on the boundary representing the exterior value to be used for the flux. :returns: a :class:`TracePair` on the boundary. """ - from grudge.op import project - - interior = project(dcoll, "vol", dd, interior) - return bdry_trace_pair(dcoll, dd, interior, exterior) + return bdry_trace_pair( + dcoll, dd, project(dcoll, "vol", dd, interior), exterior + ) # }}} @@ -220,6 +233,10 @@ def local_interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair *dcoll* with a discretization tag specified by *discr_tag*. This does not include interior faces on different MPI ranks. + + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. + For certain applications, it may be useful to distinguish between rank-local and cross-rank trace pairs. For example, avoiding unnecessary communication of derived quantities (i.e. temperature) on partition @@ -227,13 +244,8 @@ def local_interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair user applications to distinguish between rank-local and cross-rank contributions can also help enable overlapping communication with computation. - - :arg vec: a :class:`~meshmode.dof_array.DOFArray` or object array of - :class:`~meshmode.dof_array.DOFArray`\ s. :returns: a :class:`TracePair` object. """ - from grudge.op import project - i = project(dcoll, "vol", "int_faces", vec) def get_opposite_face(el): @@ -266,8 +278,8 @@ def interior_trace_pairs(dcoll: DiscretizationCollection, vec) -> list: if those are needed in isolation. Similarly, :func:`cross_rank_trace_pairs` provides only the trace pairs defined on cross-rank boundaries. - :arg vec: a :class:`~meshmode.dof_array.DOFArray` or object array of - :class:`~meshmode.dof_array.DOFArray`\ s. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. :returns: a :class:`list` of :class:`TracePair` objects. """ return ( @@ -288,58 +300,77 @@ def connected_ranks(dcoll: DiscretizationCollection): class _RankBoundaryCommunication: base_tag = 1273 - def __init__(self, dcoll: DiscretizationCollection, - remote_rank, vol_field, tag=None): - self.tag = self.base_tag - if tag is not None: - self.tag += tag - - self.dcoll = dcoll - self.array_context = vol_field.array_context - self.remote_btag = BTAG_PARTITION(remote_rank) - self.bdry_discr = dcoll.discr_from_dd(self.remote_btag) + def __init__(self, + dcoll: DiscretizationCollection, + array_container: ArrayOrContainerT, + remote_rank, tag=None): + actx = get_container_context_recursively(array_container) + btag = BTAG_PARTITION(remote_rank) - from grudge.op import project + local_bdry_data = project(dcoll, "vol", btag, array_container) + comm = dcoll.mpi_communicator - self.local_dof_array = project(dcoll, "vol", self.remote_btag, vol_field) + self.dcoll = dcoll + self.array_context = actx + self.remote_btag = btag + self.bdry_discr = dcoll.discr_from_dd(btag) + self.local_bdry_data = local_bdry_data + self.local_bdry_data_np = \ + to_numpy(flatten(self.local_bdry_data, actx), actx) - local_data = self.array_context.to_numpy(flatten(self.local_dof_array)) - comm = self.dcoll.mpi_communicator + self.tag = self.base_tag + if tag is not None: + self.tag += tag - self.send_req = comm.Isend(local_data, remote_rank, tag=self.tag) - self.remote_data_host = np.empty_like(local_data) - self.recv_req = comm.Irecv(self.remote_data_host, remote_rank, self.tag) + # Here, we initialize both send and recieve operations through + # mpi4py `Request` (MPI_Request) instances for comm.Isend (MPI_Isend) + # and comm.Irecv (MPI_Irecv) respectively. These initiate non-blocking + # point-to-point communication requests and require explicit management + # via the use of wait (MPI_Wait, MPI_Waitall, MPI_Waitany, MPI_Waitsome), + # test (MPI_Test, MPI_Testall, MPI_Testany, MPI_Testsome), and cancel + # (MPI_Cancel). The rank-local data `self.local_bdry_data_np` will have its + # associated memory buffer sent across connected ranks and must not be + # modified at the Python level during this process. Completion of the + # requests is handled in :meth:`finish`. + # + # For more details on the mpi4py semantics, see: + # https://mpi4py.readthedocs.io/en/stable/overview.html#nonblocking-communications + # + # NOTE: mpi4py currently (2021-11-03) holds a reference to the send + # memory buffer for (i.e. `self.local_bdry_data_np`) until the send + # requests is complete, however it is not clear that this is documented + # behavior. We hold on to the buffer (via the instance attribute) + # as well, just in case. + self.send_req = comm.Isend(self.local_bdry_data_np, + remote_rank, + tag=self.tag) + self.remote_data_host_numpy = np.empty_like(self.local_bdry_data_np) + self.recv_req = comm.Irecv(self.remote_data_host_numpy, + remote_rank, + tag=self.tag) def finish(self): + # Wait for the nonblocking receive request to complete before + # accessing the data self.recv_req.Wait() + # Nonblocking receive is complete, we can now access the data and apply + # the boundary-swap connection actx = self.array_context - remote_dof_array = unflatten( - self.array_context, self.bdry_discr, - actx.from_numpy(self.remote_data_host) - ) - + remote_bdry_data_flat = from_numpy(self.remote_data_host_numpy, actx) + remote_bdry_data = unflatten(self.local_bdry_data, + remote_bdry_data_flat, actx) bdry_conn = self.dcoll.distributed_boundary_swap_connection( - dof_desc.as_dofdesc(dof_desc.DTAG_BOUNDARY(self.remote_btag)) - ) - swapped_remote_dof_array = bdry_conn(remote_dof_array) + dof_desc.as_dofdesc(dof_desc.DTAG_BOUNDARY(self.remote_btag))) + swapped_remote_bdry_data = bdry_conn(remote_bdry_data) + # Complete the nonblocking send request associated with communicating + # `self.local_bdry_data_np` self.send_req.Wait() return TracePair(self.remote_btag, - interior=self.local_dof_array, - exterior=swapped_remote_dof_array) - - -def _cross_rank_trace_pairs_scalar_field( - dcoll: DiscretizationCollection, vec, tag=None) -> list: - if isinstance(vec, Number): - return [TracePair(BTAG_PARTITION(remote_rank), interior=vec, exterior=vec) - for remote_rank in connected_ranks(dcoll)] - else: - rbcomms = [_RankBoundaryCommunication(dcoll, remote_rank, vec, tag=tag) - for remote_rank in connected_ranks(dcoll)] - return [rbcomm.finish() for rbcomm in rbcomms] + interior=self.local_bdry_data, + exterior=swapped_remote_bdry_data) def cross_rank_trace_pairs( @@ -357,38 +388,27 @@ def cross_rank_trace_pairs( components, respectively. Each of the TracePair components are structured like *ary*. - :arg ary: a single :class:`~meshmode.dof_array.DOFArray`, or an object - array of :class:`~meshmode.dof_array.DOFArray`\ s - of arbitrary shape. + If *ary* is a number, rather than a + :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them, it is assumed + that the same number is being communicated on every rank. + + :arg ary: a :class:`~meshmode.dof_array.DOFArray` or an + :class:`~arraycontext.container.ArrayContainer` of them. :returns: a :class:`list` of :class:`TracePair` objects. """ - if isinstance(ary, np.ndarray): - oshape = ary.shape - comm_vec = ary.flatten() - - n, = comm_vec.shape - result = {} - # FIXME: Batch this communication rather than - # doing it in sequence. - for ivec in range(n): - for rank_tpair in _cross_rank_trace_pairs_scalar_field( - dcoll, comm_vec[ivec]): - assert isinstance(rank_tpair.dd.domain_tag, dof_desc.DTAG_BOUNDARY) - assert isinstance(rank_tpair.dd.domain_tag.tag, BTAG_PARTITION) - result[rank_tpair.dd.domain_tag.tag.part_nr, ivec] = rank_tpair - - return [ - TracePair( - dd=dof_desc.as_dofdesc( - dof_desc.DTAG_BOUNDARY(BTAG_PARTITION(remote_rank))), - interior=make_obj_array([ - result[remote_rank, i].int for i in range(n)]).reshape(oshape), - exterior=make_obj_array([ - result[remote_rank, i].ext for i in range(n)]).reshape(oshape) - ) for remote_rank in connected_ranks(dcoll) - ] - else: - return _cross_rank_trace_pairs_scalar_field(dcoll, ary, tag=tag) + if isinstance(ary, Number): + # NOTE: Assumed that the same number is passed on every rank + return [TracePair(BTAG_PARTITION(remote_rank), interior=ary, exterior=ary) + for remote_rank in connected_ranks(dcoll)] + + # Initialize and post all sends/receives + rank_bdry_communcators = [ + _RankBoundaryCommunication(dcoll, ary, remote_rank, tag=tag) + for remote_rank in connected_ranks(dcoll) + ] + # Complete send/receives and return communicated data + return [rc.finish() for rc in rank_bdry_communcators] # }}} diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 4f627670..cfbcdd0d 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -99,7 +99,7 @@ def simple_mpi_communication_entrypoint(): for tpair in op.cross_rank_trace_pairs(dcoll, myfunc)) ) - (all_faces_func - bdry_faces_func) - error = flat_norm(hopefully_zero, ord=np.inf) + error = actx.to_numpy(flat_norm(hopefully_zero, ord=np.inf)) print(__file__) with np.printoptions(threshold=100000000, suppress=True): diff --git a/test/test_reductions.py b/test/test_reductions.py index 64d1eec3..67291e49 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -24,10 +24,19 @@ import numpy as np -from arraycontext import thaw +from dataclasses import dataclass + +from arraycontext import ( + thaw, + with_container_arithmetic, + dataclass_array_container, + pytest_generate_tests_for_array_contexts +) + +from meshmode.dof_array import DOFArray from grudge.array_context import PytestPyOpenCLArrayContextFactory -from arraycontext import pytest_generate_tests_for_array_contexts + pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory]) @@ -73,30 +82,20 @@ def h(x): h_ref = actx.to_numpy(flatten(fields[2])) concat_fields = np.concatenate([f_ref, g_ref, h_ref]) - for inner_grudge_op, np_op in [(op.nodal_sum, np.sum), + for grudge_op, np_op in [(op.nodal_sum, np.sum), (op.nodal_max, np.max), (op.nodal_min, np.min)]: - # FIXME: Remove this once all grudge reductions return device scalars - def grudge_op(dcoll, dd, vec): - res = inner_grudge_op(dcoll, dd, vec) - - from numbers import Number - if not isinstance(res, Number): - return actx.to_numpy(res) - else: - return res - # Componentwise reduction checks - assert np.isclose(grudge_op(dcoll, "vol", fields[0]), + assert np.isclose(actx.to_numpy(grudge_op(dcoll, "vol", fields[0])), np_op(f_ref), rtol=1e-13) - assert np.isclose(grudge_op(dcoll, "vol", fields[1]), + assert np.isclose(actx.to_numpy(grudge_op(dcoll, "vol", fields[1])), np_op(g_ref), rtol=1e-13) - assert np.isclose(grudge_op(dcoll, "vol", fields[2]), + assert np.isclose(actx.to_numpy(grudge_op(dcoll, "vol", fields[2])), np_op(h_ref), rtol=1e-13) # Test nodal reductions work on object arrays - assert np.isclose(grudge_op(dcoll, "vol", fields), + assert np.isclose(actx.to_numpy(grudge_op(dcoll, "vol", fields)), np_op(concat_fields), rtol=1e-13) @@ -118,11 +117,11 @@ def f(x): mins = [] maxs = [] sums = [] - for gidx, grp_f in enumerate(field): + for grp_f in field: min_res = np.empty(grp_f.shape) max_res = np.empty(grp_f.shape) sum_res = np.empty(grp_f.shape) - for eidx in range(dcoll._volume_discr.groups[gidx].nelements): + for eidx in range(dcoll.mesh.nelements): element_data = actx.to_numpy(grp_f[eidx]) min_res[eidx, :] = np.min(element_data) max_res[eidx, :] = np.max(element_data) @@ -131,8 +130,6 @@ def f(x): maxs.append(actx.from_numpy(max_res)) sums.append(actx.from_numpy(sum_res)) - from meshmode.dof_array import DOFArray, flat_norm - ref_mins = DOFArray(actx, data=tuple(mins)) ref_maxs = DOFArray(actx, data=tuple(maxs)) ref_sums = DOFArray(actx, data=tuple(sums)) @@ -141,9 +138,163 @@ def f(x): elem_maxs = op.elementwise_max(dcoll, field) elem_sums = op.elementwise_sum(dcoll, field) - assert flat_norm(elem_mins - ref_mins, ord=np.inf) < 1.e-15 - assert flat_norm(elem_maxs - ref_maxs, ord=np.inf) < 1.e-15 - assert flat_norm(elem_sums - ref_sums, ord=np.inf) < 1.e-15 + assert actx.to_numpy(op.norm(dcoll, elem_mins - ref_mins, np.inf)) < 1.e-15 + assert actx.to_numpy(op.norm(dcoll, elem_maxs - ref_maxs, np.inf)) < 1.e-15 + assert actx.to_numpy(op.norm(dcoll, elem_sums - ref_sums, np.inf)) < 1.e-15 + + +# {{{ Array container tests + +@with_container_arithmetic(bcast_obj_array=False, + eq_comparison=False, rel_comparison=False) +@dataclass_array_container +@dataclass(frozen=True) +class MyContainer: + name: str + mass: DOFArray + momentum: np.ndarray + enthalpy: DOFArray + + @property + def array_context(self): + return self.mass.array_context + + +def test_nodal_reductions_with_container(actx_factory): + actx = actx_factory() + + from mesh_data import BoxMeshBuilder + builder = BoxMeshBuilder(ambient_dim=2) + + mesh = builder.get_mesh(4, builder.mesh_order) + dcoll = DiscretizationCollection(actx, mesh, order=builder.order) + x = thaw(dcoll.nodes(), actx) + + def f(x): + return -actx.np.sin(10*x[0]) * actx.np.cos(2*x[1]) + + def g(x): + return actx.np.cos(2*x[0]) * actx.np.sin(10*x[1]) + + def h(x): + return -actx.np.tan(5*x[0]) * actx.np.tan(0.5*x[1]) + + mass = f(x) + g(x) + momentum = make_obj_array([f(x)/g(x), h(x)]) + enthalpy = h(x) - g(x) + + ary_container = MyContainer(name="container", + mass=mass, + momentum=momentum, + enthalpy=enthalpy) + + mass_ref = actx.to_numpy(flatten(mass)) + momentum_ref = np.concatenate([actx.to_numpy(mom_i) + for mom_i in flatten(momentum)]) + enthalpy_ref = actx.to_numpy(flatten(enthalpy)) + concat_fields = np.concatenate([mass_ref, momentum_ref, enthalpy_ref]) + + for grudge_op, np_op in [(op.nodal_sum, np.sum), + (op.nodal_max, np.max), + (op.nodal_min, np.min)]: + + assert np.isclose(actx.to_numpy(grudge_op(dcoll, "vol", ary_container)), + np_op(concat_fields), rtol=1e-13) + + # Check norm reduction + assert np.isclose(actx.to_numpy(op.norm(dcoll, ary_container, np.inf)), + np.linalg.norm(concat_fields, ord=np.inf), + rtol=1e-13) + + +def test_elementwise_reductions_with_container(actx_factory): + actx = actx_factory() + + from mesh_data import BoxMeshBuilder + builder = BoxMeshBuilder(ambient_dim=2) + + nelements = 4 + mesh = builder.get_mesh(nelements, builder.mesh_order) + dcoll = DiscretizationCollection(actx, mesh, order=builder.order) + x = thaw(dcoll.nodes(), actx) + + def f(x): + return actx.np.sin(x[0]) * actx.np.sin(x[1]) + + def g(x): + return actx.np.cos(x[0]) * actx.np.cos(x[1]) + + def h(x): + return actx.np.cos(x[0]) * actx.np.sin(x[1]) + + mass = 2*f(x) + 0.5*g(x) + momentum = make_obj_array([f(x)/g(x), h(x)]) + enthalpy = 3*h(x) - g(x) + + ary_container = MyContainer(name="container", + mass=mass, + momentum=momentum, + enthalpy=enthalpy) + + def _get_ref_data(field): + mins = [] + maxs = [] + sums = [] + for grp_f in field: + min_res = np.empty(grp_f.shape) + max_res = np.empty(grp_f.shape) + sum_res = np.empty(grp_f.shape) + for eidx in range(dcoll.mesh.nelements): + element_data = actx.to_numpy(grp_f[eidx]) + min_res[eidx, :] = np.min(element_data) + max_res[eidx, :] = np.max(element_data) + sum_res[eidx, :] = np.sum(element_data) + mins.append(actx.from_numpy(min_res)) + maxs.append(actx.from_numpy(max_res)) + sums.append(actx.from_numpy(sum_res)) + min_field = DOFArray(actx, data=tuple(mins)) + max_field = DOFArray(actx, data=tuple(maxs)) + sums_field = DOFArray(actx, data=tuple(sums)) + return min_field, max_field, sums_field + + min_mass, max_mass, sums_mass = _get_ref_data(mass) + min_enthalpy, max_enthalpy, sums_enthalpy = _get_ref_data(enthalpy) + min_mom_x, max_mom_x, sums_mom_x = _get_ref_data(momentum[0]) + min_mom_y, max_mom_y, sums_mom_y = _get_ref_data(momentum[1]) + min_momentum = make_obj_array([min_mom_x, min_mom_y]) + max_momentum = make_obj_array([max_mom_x, max_mom_y]) + sums_momentum = make_obj_array([sums_mom_x, sums_mom_y]) + + reference_min = MyContainer( + name="Reference min", + mass=min_mass, + momentum=min_momentum, + enthalpy=min_enthalpy + ) + + reference_max = MyContainer( + name="Reference max", + mass=max_mass, + momentum=max_momentum, + enthalpy=max_enthalpy + ) + + reference_sum = MyContainer( + name="Reference sums", + mass=sums_mass, + momentum=sums_momentum, + enthalpy=sums_enthalpy + ) + + elem_mins = op.elementwise_min(dcoll, ary_container) + elem_maxs = op.elementwise_max(dcoll, ary_container) + elem_sums = op.elementwise_sum(dcoll, ary_container) + + assert actx.to_numpy(op.norm(dcoll, elem_mins - reference_min, np.inf)) < 1.e-14 + assert actx.to_numpy(op.norm(dcoll, elem_maxs - reference_max, np.inf)) < 1.e-14 + assert actx.to_numpy(op.norm(dcoll, elem_sums - reference_sum, np.inf)) < 1.e-14 + +# }}} # You can test individual routines by typing