Commit 7380319: Use comm.scatter for mesh distribution

inducer committed May 3, 2023
1 parent 2d37bfa commit 7380319
Showing 3 changed files with 44 additions and 35 deletions.
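For context, the pattern this commit adopts is mpi4py's collective scatter: the root rank supplies a list containing one payload per rank, every other rank passes None, and the call hands each rank exactly its own entry. A minimal, self-contained sketch of that idiom, independent of meshmode (the payload strings are purely illustrative):

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD

if comm.rank == 0:
    # The root builds one entry per rank; mpi4py pickles each object.
    payloads = [f"payload for rank {i}" for i in range(comm.size)]
else:
    # Non-root ranks contribute nothing, but must still make the call:
    # scatter is collective.
    payloads = None

mine = comm.scatter(payloads)  # each rank receives only its own entry
print(f"rank {comm.rank} got {mine!r}")
```

Run with, e.g., `mpiexec -n 4 python scatter_demo.py` (file name illustrative).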
examples/wave/wave-min-mpi.py (18 changes: 10 additions, 8 deletions)

@@ -49,7 +49,7 @@ class WaveTag:
 
 def main(ctx_factory, dim=2, order=4, visualize=False):
     comm = MPI.COMM_WORLD
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
     cl_ctx = cl.create_some_context()
     queue = cl.CommandQueue(cl_ctx)
@@ -60,10 +60,10 @@ def main(ctx_factory, dim=2, order=4, visualize=False):
             force_device_scalars=True,
             )
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
-    mesh_dist = MPIMeshDistributor(comm)
+    from meshmode.distributed import get_partition_by_pymetis, membership_list_to_map
+    from meshmode.mesh.processing import partition_mesh
 
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         from meshmode.mesh.generation import generate_regular_rect_mesh
         mesh = generate_regular_rect_mesh(
                 a=(-0.5,)*dim,
@@ -72,14 +72,16 @@ def main(ctx_factory, dim=2, order=4, visualize=False):
 
         logger.info("%d elements", mesh.nelements)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
-
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
+        local_mesh = comm.scatter(parts)
 
         del mesh
 
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     dcoll = DiscretizationCollection(actx, local_mesh, order=order)
 
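The replacement above splits mesh distribution into three explicit steps: get_partition_by_pymetis produces a flat per-element rank assignment, membership_list_to_map groups element indices by that assignment, and partition_mesh cuts the global mesh into per-rank pieces that comm.scatter then delivers. A rough numpy sketch of what the grouping step is assumed to do (an illustration of the semantics suggested by the name and by the diff, not meshmode's actual implementation):

```python
import numpy as np

def membership_list_to_map_sketch(membership_list):
    # Group element indices by their assigned part, e.g.
    # [0, 1, 0, 1, 1, 0] -> {0: array([0, 2, 5]), 1: array([1, 3, 4])}
    return {
        part: np.where(membership_list == part)[0]
        for part in np.unique(membership_list)
    }

part_per_element = np.array([0, 1, 0, 1, 1, 0])
print(membership_list_to_map_sketch(part_per_element))
```

The keys are the part numbers 0 through num_parts - 1, which is why the diff can index part_id_to_part[i] for each rank i when building the scatter list.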
examples/wave/wave-op-mpi.py (17 changes: 10 additions, 7 deletions)

@@ -184,7 +184,7 @@ def main(ctx_factory, dim=2, order=3,
     queue = cl.CommandQueue(cl_ctx)
 
     comm = MPI.COMM_WORLD
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
     from grudge.array_context import get_reasonable_array_context_class
     actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
@@ -195,12 +195,12 @@ def main(ctx_factory, dim=2, order=3,
             allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)),
             force_device_scalars=True)
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
-    mesh_dist = MPIMeshDistributor(comm)
+    from meshmode.distributed import get_partition_by_pymetis, membership_list_to_map
+    from meshmode.mesh.processing import partition_mesh
 
     nel_1d = 16
 
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         if use_nonaffine_mesh:
             from meshmode.mesh.generation import generate_warped_rect_mesh
             # FIXME: *generate_warped_rect_mesh* in meshmode warps a
@@ -218,14 +218,17 @@ def main(ctx_factory, dim=2, order=3,
 
         logger.info("%d elements", mesh.nelements)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
 
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        local_mesh = comm.scatter(parts)
 
         del mesh
 
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     from meshmode.discretization.poly_element import \
         QuadratureSimplexGroupFactory, \
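A side note on the comm.Get_size() to comm.size change that recurs in these hunks: mpi4py's Comm exposes both the MPI-style getters and equivalent Python properties, so the two spellings are interchangeable:

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
# The property forms are shorthand for the MPI-style getters.
assert comm.size == comm.Get_size()
assert comm.rank == comm.Get_rank()
```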
test/test_mpi_communication.py (44 changes: 24 additions, 20 deletions)

@@ -115,24 +115,26 @@ def _test_func_comparison_mpi_communication_entrypoint(actx):
 
     comm = actx.mpi_communicator
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
+    from meshmode.distributed import (
+        get_partition_by_pymetis, membership_list_to_map)
     from meshmode.mesh import BTAG_ALL
+    from meshmode.mesh.processing import partition_mesh
 
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
-    mesh_dist = MPIMeshDistributor(comm)
-
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         from meshmode.mesh.generation import generate_regular_rect_mesh
         mesh = generate_regular_rect_mesh(a=(-1,)*2,
                                           b=(1,)*2,
                                           nelements_per_axis=(2,)*2)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
-
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
+        local_mesh = comm.scatter(parts)
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     dcoll = DiscretizationCollection(actx, local_mesh, order=5)
 
@@ -188,28 +190,30 @@ def test_mpi_wave_op(actx_class, num_ranks):
 
 def _test_mpi_wave_op_entrypoint(actx, visualize=False):
     comm = actx.mpi_communicator
-    i_local_rank = comm.Get_rank()
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
-    mesh_dist = MPIMeshDistributor(comm)
+    from meshmode.distributed import (
+        get_partition_by_pymetis, membership_list_to_map)
+    from meshmode.mesh.processing import partition_mesh
 
     dim = 2
     order = 4
 
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         from meshmode.mesh.generation import generate_regular_rect_mesh
         mesh = generate_regular_rect_mesh(a=(-0.5,)*dim,
                                           b=(0.5,)*dim,
                                           nelements_per_axis=(16,)*dim)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
-
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
+        local_mesh = comm.scatter(parts)
 
         del mesh
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     dcoll = DiscretizationCollection(actx, local_mesh, order=order)
 
@@ -270,7 +274,7 @@ def rhs(t, w):
 
     final_t = 4
     nsteps = int(final_t/dt)
-    logger.info("[%04d] dt %.5e nsteps %4d", i_local_rank, dt, nsteps)
+    logger.info("[%04d] dt %.5e nsteps %4d", comm.rank, dt, nsteps)
 
     step = 0
 
@@ -308,7 +312,7 @@ def rhs(t, w):
 
     logmgr.tick_after()
     logmgr.close()
-    logger.info("Rank %d exiting", i_local_rank)
+    logger.info("Rank %d exiting", comm.rank)
 
     # }}}
 
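Assembled from the hunks above, the distribution flow that every rank now executes looks like this (a condensed sketch; make_global_mesh is a hypothetical stand-in for the mesh generators used in the actual files):

```python
from mpi4py import MPI
from meshmode.distributed import get_partition_by_pymetis, membership_list_to_map
from meshmode.mesh.processing import partition_mesh

comm = MPI.COMM_WORLD
num_parts = comm.size

if comm.rank == 0:
    mesh = make_global_mesh()  # hypothetical stand-in, see note above
    part_id_to_part = partition_mesh(
        mesh,
        membership_list_to_map(get_partition_by_pymetis(mesh, num_parts)))
    parts = [part_id_to_part[i] for i in range(num_parts)]
else:
    parts = None

# Collective on all ranks: rank 0 sends, every rank receives its own piece.
local_mesh = comm.scatter(parts)
```

Compared with the removed MPIMeshDistributor manager/worker exchange (send_mesh_parts on the manager, receive_mesh_part on the workers), this replaces bespoke point-to-point logic with a single collective call.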
