Array container support and batched (eager) communication (#154)
* Support array containers in core grudge DG ops

* Convert wave example to use array containers

* Update projection routine to support array containers

* Convert mpi communication

* Use DOFArray-based traversal routines

* Containerize distributed (eager) MPI communication

* Use array container in wave-op-mpi example

* Clean up wave-op-mpi demo

* Use non-recursive array traversal

* Use map_array_container in trace pair exchange

* Remove unused imports

* Remove sketchy assert; shape differs when evaluating fluxes vs storing state

* Refactor containerization routines

* Fix suppressed logging error

* Fix forward flattening map

* Fix type check in projection routine

* dofdesc-itize descriptors first in projection

* Make flattening methods private and provide short docs

* Clean up flatten/unflatten functions and document

* Carefully track ordering of data in flattened array

* Use container flatten/unflatten functions from arraycontext

* Containerize elementwise reductions and add unit tests

* Containerize nodal reductions and add unit tests

* Update documentation to reflect container support

* Catch scalars in op.norm

* Remove flat_norm and use op.norm instead

* Separate initialization and completion of send/recv

* Fix documentation in operator module

Co-authored-by: Andreas Klöckner <inform@tiker.net>

* Rename boundary communicator data/attributes

* Comment on nonblocking MPI communication process in communicator class

* Clean up comments related to rank-communication for trace pairs

* Clean up/clarify documentation related to array containers

* Update reduction unit tests

* Documentation updates

Co-authored-by: Andreas Klöckner <inform@tiker.net>

* Fix comment formatting

* Containerize grad/div functions

* Add checks for empty arrays in grad/div operators

* Fix gradient op for vector-valued components

* Correct div helper for handling empty and nested obj arrays

* Clear up documentation and add clarifying notes

* More shared code in array container support (#191)

* Refactor grad/weak grad to use the same helper, simplify div empty handling

* Factor out some redundant code in elementwise reductions

Co-authored-by: Andreas Klöckner <inform@tiker.net>
thomasgibson and inducer committed Dec 6, 2021
1 parent 95561d0 commit d48c9bc
Showing 7 changed files with 703 additions and 365 deletions.
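Several of the commit-message items above ("Use container flatten/unflatten functions from arraycontext", "Carefully track ordering of data in flattened array") refer to arraycontext's container flatten/unflatten helpers, which pack a nested array container into a single device array with a fixed leaf ordering and rebuild the container afterwards. The changed communication code itself is not among the hunks shown below; the following is only a rough, standalone sketch of the round trip, assuming a working PyOpenCL setup and using a plain object array in place of grudge's trace data:

import numpy as np
import pyopencl as cl
from arraycontext import PyOpenCLArrayContext, flatten, unflatten
from pytools.obj_array import make_obj_array

cl_ctx = cl.create_some_context()
queue = cl.CommandQueue(cl_ctx)
actx = PyOpenCLArrayContext(queue)

# A toy container: an object array holding two device arrays of different sizes.
state = make_obj_array([
    actx.from_numpy(np.arange(3, dtype=np.float64)),
    actx.from_numpy(np.arange(4, dtype=np.float64)),
])

flat = flatten(state, actx)              # one flat device array, leaves in a fixed order
restored = unflatten(state, flat, actx)  # same structure and leaf ordering rebuilt from flat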
examples/wave/wave-op-mpi.py: 87 changes (55 additions, 32 deletions)
@@ -31,11 +31,18 @@
 import pyopencl as cl
 import pyopencl.tools as cl_tools
 
-from arraycontext import thaw, freeze
+from arraycontext import (
+    thaw, freeze,
+    with_container_arithmetic,
+    dataclass_array_container
+)
 from grudge.array_context import PytatoPyOpenCLArrayContext, PyOpenCLArrayContext
 
-from pytools.obj_array import flat_obj_array
+from dataclasses import dataclass
+
+from pytools.obj_array import flat_obj_array, make_obj_array
 
+from meshmode.dof_array import DOFArray
 from meshmode.mesh import BTAG_ALL, BTAG_NONE  # noqa
 
 from grudge.discretization import DiscretizationCollection
@@ -51,42 +58,58 @@
 
 # {{{ wave equation bits
 
+@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True)
+@dataclass_array_container
+@dataclass(frozen=True)
+class WaveState:
+    u: DOFArray
+    v: np.ndarray  # [object array]
+
+    def __post_init__(self):
+        assert isinstance(self.v, np.ndarray) and self.v.dtype.char == "O"
+
+    @property
+    def array_context(self):
+        return self.u.array_context
+
+
 def wave_flux(dcoll, c, w_tpair):
-    u = w_tpair[0]
-    v = w_tpair[1:]
+    u = w_tpair.u
+    v = w_tpair.v
 
     normal = thaw(dcoll.normal(w_tpair.dd), u.int.array_context)
 
-    flux_weak = flat_obj_array(
-        np.dot(v.avg, normal),
-        normal*u.avg,
-        )
+    flux_weak = WaveState(
+        u=v.avg @ normal,
+        v=u.avg * normal
+    )
 
     # upwind
-    v_jump = np.dot(normal, v.ext-v.int)
-    flux_weak += flat_obj_array(
-        0.5*(u.ext-u.int),
-        0.5*normal*v_jump,
-        )
+    v_jump = v.diff @ normal
+    flux_weak += WaveState(
+        u=0.5 * u.diff,
+        v=0.5 * v_jump * normal,
+    )
 
     return op.project(dcoll, w_tpair.dd, "all_faces", c*flux_weak)
 
 
 def wave_operator(dcoll, c, w):
-    u = w[0]
-    v = w[1:]
+    u = w.u
+    v = w.v
 
-    dir_u = op.project(dcoll, "vol", BTAG_ALL, u)
-    dir_v = op.project(dcoll, "vol", BTAG_ALL, v)
-    dir_bval = flat_obj_array(dir_u, dir_v)
-    dir_bc = flat_obj_array(-dir_u, dir_v)
+    dir_w = op.project(dcoll, "vol", BTAG_ALL, w)
+    dir_u = dir_w.u
+    dir_v = dir_w.v
+    dir_bval = WaveState(u=dir_u, v=dir_v)
+    dir_bc = WaveState(u=-dir_u, v=dir_v)
 
     return (
         op.inverse_mass(
             dcoll,
-            flat_obj_array(
-                -c*op.weak_local_div(dcoll, v),
-                -c*op.weak_local_grad(dcoll, u)
+            WaveState(
+                u=-c*op.weak_local_div(dcoll, v),
+                v=-c*op.weak_local_grad(dcoll, u)
                 )
             + op.face_mass(
                 dcoll,
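The three decorators on WaveState are what make the arithmetic in wave_flux and wave_operator work: with_container_arithmetic generates field-wise +, -, and * (including scalar broadcasting, used in c*flux_weak and -dir_u), while dataclass_array_container registers the dataclass as an array container so routines such as op.project can traverse it; the hunk also shows that a trace pair of a WaveState can be taken apart field by field (w_tpair.u, w_tpair.v). Below is a minimal standalone sketch of the arithmetic behavior only, using plain numpy arrays in place of DOFArrays; ToyState is a hypothetical stand-in, not part of the commit, and details may vary across arraycontext versions:

from dataclasses import dataclass

import numpy as np
from arraycontext import dataclass_array_container, with_container_arithmetic
from pytools.obj_array import make_obj_array


@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True)
@dataclass_array_container
@dataclass(frozen=True)
class ToyState:  # hypothetical stand-in for WaveState
    u: np.ndarray
    v: np.ndarray  # object array of components


a = ToyState(u=np.ones(3), v=make_obj_array([np.zeros(3), np.ones(3)]))
b = ToyState(u=2*np.ones(3), v=make_obj_array([np.ones(3), 3*np.ones(3)]))

print((a + b).u)    # fields combine componentwise: [3. 3. 3.]
print((2.0 * a).u)  # scalars broadcast over every field: [2. 2. 2.]
print((-a).u)       # unary minus, as used for dir_bc above: [-1. -1. -1.]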
@@ -183,10 +206,10 @@ def main(ctx_factory, dim=2, order=3, visualize=False, lazy=False):
     dcoll = DiscretizationCollection(actx, local_mesh, order=order,
                                      mpi_communicator=comm)
 
-    fields = flat_obj_array(
-        bump(actx, dcoll),
-        [dcoll.zeros(actx) for i in range(dcoll.dim)]
-    )
+    fields = WaveState(
+        u=bump(actx, dcoll),
+        v=make_obj_array([dcoll.zeros(actx) for i in range(dcoll.dim)])
+    )
 
     c = 1
     dt = actx.to_numpy(0.45 * estimate_rk4_timestep(actx, dcoll, c))
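One detail in the hunk above: the velocity components are now wrapped in make_obj_array rather than passed as a plain list, because WaveState.v is asserted in __post_init__ to be a numpy object array. make_obj_array always builds a 1D dtype=object array, keeping the per-component DOFArrays as separate leaves, whereas np.array would try to stack them numerically. A tiny illustration with plain arrays:

import numpy as np
from pytools.obj_array import make_obj_array

components = [np.zeros(4), np.ones(4)]

v = make_obj_array(components)
print(v.dtype, v.shape)               # -> object (2,): two separate component arrays

stacked = np.array(components)
print(stacked.dtype, stacked.shape)   # -> float64 (2, 4): not what WaveState.v expects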
@@ -210,12 +233,12 @@ def rhs(t, w):
 
         fields = rk4_step(fields, t, dt, compiled_rhs)
 
-        l2norm = actx.to_numpy(op.norm(dcoll, fields[0], 2))
+        l2norm = actx.to_numpy(op.norm(dcoll, fields.u, 2))
 
         if istep % 10 == 0:
-            linfnorm = actx.to_numpy(op.norm(dcoll, fields[0], np.inf))
-            nodalmax = actx.to_numpy(op.nodal_max(dcoll, "vol", fields[0]))
-            nodalmin = actx.to_numpy(op.nodal_min(dcoll, "vol", fields[0]))
+            linfnorm = actx.to_numpy(op.norm(dcoll, fields.u, np.inf))
+            nodalmax = actx.to_numpy(op.nodal_max(dcoll, "vol", fields.u))
+            nodalmin = actx.to_numpy(op.nodal_min(dcoll, "vol", fields.u))
             if comm.rank == 0:
                 logger.info(f"step: {istep} t: {t} "
                             f"L2: {l2norm} "
@@ -227,8 +250,8 @@ def rhs(t, w):
                 comm,
                 f"fld-wave-eager-mpi-{{rank:03d}}-{istep:04d}.vtu",
                 [
-                    ("u", fields[0]),
-                    ("v", fields[1:]),
+                    ("u", fields.u),
+                    ("v", fields.v),
                 ]
             )
 
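The commit-message items about batched (eager) communication ("Separate initialization and completion of send/recv", "Comment on nonblocking MPI communication process in communicator class") concern a boundary-communication class in a file whose diff is not shown here. The underlying pattern is ordinary nonblocking MPI: post Isend/Irecv for the flattened boundary data up front, do communication-independent work, and complete the exchange only when the remote data is needed. A generic mpi4py sketch of that pattern follows; it is not grudge's actual communicator code, and the neighbor rank, tag, and buffer are made up for illustration:

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nbr = (rank + 1) % comm.Get_size()              # hypothetical neighbor rank

send_buf = np.full(8, rank, dtype=np.float64)   # stand-in for flattened trace data
recv_buf = np.empty_like(send_buf)

# Initiation: post both operations immediately so the exchange can overlap
# with local, communication-independent work (e.g. assembling interior fluxes).
reqs = [comm.Isend(send_buf, dest=nbr, tag=41),
        comm.Irecv(recv_buf, source=nbr, tag=41)]

# ... communication-independent local work goes here ...

# Completion: wait only when the remote data is actually needed, then
# unflatten recv_buf back into the array-container form of the state.
MPI.Request.Waitall(reqs)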