Skip to content

Commit

Permalink
Fix to incorrect rank-view indexing calculations in mpi_array.decompo…
Browse files Browse the repository at this point in the history
…sition.CartesianDecomposition.calculate_rank_view_slices.
  • Loading branch information
Shane-J-Latham committed Jun 5, 2017
1 parent 2ee1111 commit 52abddc
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 29 deletions.
75 changes: 66 additions & 9 deletions mpi_array/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,26 @@
import array_split.split # noqa: F401
from array_split import ARRAY_BOUNDS
from array_split.split import convert_halo_to_array_form, shape_factors as _shape_factors
import mpi_array.logging as _logging
import numpy as _np


__author__ = "Shane J. Latham"
__license__ = _license()
__copyright__ = _copyright()
__version__ = _pkg_resources.resource_string("mpi_array", "version.txt").decode()


def mpi_version():
    """
    Return the version of the MPI standard implemented by the runtime.

    :rtype: :obj:`int`
    :return: MPI major version number (e.g. ``3``).
    """
    # Delegates to mpi4py's module-level constant.
    return _mpi.VERSION


class SharedMemInfo(object):
"""
Info on possible shared memory allocation for a specified MPI communicator.
Expand All @@ -59,7 +71,7 @@ def __init__(self, comm=None, shared_mem_comm=None):
if comm is None:
comm = _mpi.COMM_WORLD
if shared_mem_comm is None:
if _mpi.VERSION >= 3:
if mpi_version() >= 3:
shared_mem_comm = comm.Split_type(_mpi.COMM_TYPE_SHARED, key=comm.rank)
else:
shared_mem_comm = comm.Split(comm.rank, key=comm.rank)
Expand All @@ -68,9 +80,7 @@ def __init__(self, comm=None, shared_mem_comm=None):

# Count the number of self._shared_mem_comm rank-0 processes
# to work out how many communicators comm was split into.
is_rank_zero = 0
if self._shared_mem_comm.rank == 0:
is_rank_zero = 1
is_rank_zero = int(self._shared_mem_comm.rank == 0)
self._num_shared_mem_nodes = comm.allreduce(is_rank_zero, _mpi.SUM)

@property
Expand Down Expand Up @@ -303,6 +313,14 @@ def calc_intersection(self, other):

return intersection_extent

def to_slice(self):
    """
    Returns ":obj:`tuple` of :obj:`slice`" equivalent of this indexing extent.

    :rtype: :obj:`tuple` of :obj:`slice` elements
    :return: Tuple of slice equivalent to this indexing extent.
    """
    # BUG FIX: the original body was empty and implicitly returned None,
    # contradicting the documented contract.  Build one slice per dimension
    # from the per-axis start/stop indices (self.start and self.stop are
    # per-dimension index sequences, as used by calculate_rank_view_slices).
    return tuple(
        slice(self.start[i], self.stop[i])
        for i in range(len(self.start))
    )
def __repr__(self):
"""
Stringize.
Expand Down Expand Up @@ -715,6 +733,8 @@ def __init__(
self._shape = None
self._mem_alloc_topology = mem_alloc_topology
self._shape_decomp = None
self._rank_logger = None
self._root_logger = None

self.recalculate(shape, halo)

Expand All @@ -731,6 +751,7 @@ def calculate_rank_view_slices(self):
halo=0,
array_start=self._lndarray_extent.start_n
)

split = shape_splitter.calculate_split()
rank_extent_n = IndexingExtent(split.flatten()[self.shared_mem_comm.rank])
rank_extent_h = \
Expand All @@ -747,15 +768,16 @@ def calculate_rank_view_slices(self):

# Convert rank_extent_n and rank_extent_h from global-indices
# to local-indices
halo_lo = self._lndarray_extent.halo[:, self.LO]
rank_extent_n = \
IndexingExtent(
start=rank_extent_n.start - self._lndarray_extent.start_n,
stop=rank_extent_n.stop - self._lndarray_extent.start_n,
start=rank_extent_n.start - self._lndarray_extent.start_n + halo_lo,
stop=rank_extent_n.stop - self._lndarray_extent.start_n + halo_lo,
)
rank_extent_h = \
IndexingExtent(
start=rank_extent_h.start - self._lndarray_extent.start_n,
stop=rank_extent_h.stop - self._lndarray_extent.start_n,
start=rank_extent_h.start - self._lndarray_extent.start_n + halo_lo,
stop=rank_extent_h.stop - self._lndarray_extent.start_n + halo_lo,
)
rank_h_relative_extent_n = \
IndexingExtent(
Expand Down Expand Up @@ -844,6 +866,13 @@ def recalculate(self, new_shape, new_halo):
) # noqa: E123

self._lndarray_extent = self.shared_mem_comm.bcast(self._lndarray_extent, 0)

self._lndarray_view_slice_n = \
IndexingExtent(
start=self._lndarray_extent.halo[:, self.LO],
stop=self._lndarray_extent.halo[:, self.LO] + self._lndarray_extent.shape_n
).to_slice()

self.calculate_rank_view_slices()

def alloc_local_buffer(self, dtype):
Expand All @@ -867,7 +896,7 @@ def alloc_local_buffer(self, dtype):
rank_shape = self.shape
num_rank_bytes = int(_np.product(rank_shape) * dtype.itemsize)

if _mpi.VERSION >= 3:
if (mpi_version() >= 3) and (self.shared_mem_comm.size > 1):
self._shared_mem_win = \
_mpi.Win.Allocate_shared(num_rank_bytes, dtype.itemsize, comm=self.shared_mem_comm)
buffer, itemsize = self._shared_mem_win.Shared_query(0)
Expand Down Expand Up @@ -1016,6 +1045,34 @@ def rank_view_relative_slice_n(self):
"""
return self._rank_view_relative_slice_n

@property
def lndarray_view_slice_n(self):
    """
    Indexing slice which can be used to generate a view of
    :obj:`mpi_array.local.lndarray` which has the halo removed.
    """
    # Value is pre-computed elsewhere (in recalculate); just expose it.
    return self._lndarray_view_slice_n

@property
def rank_logger(self):
    """
    A :obj:`logging.Logger` for :attr:`rank_comm` communicator ranks.
    """
    # Lazily construct the logger on first access and cache it.
    logger = self._rank_logger
    if logger is None:
        logger = _logging.get_rank_logger(self.__class__.__name__, comm=self.rank_comm)
        self._rank_logger = logger
    return logger

@property
def root_logger(self):
    """
    A :obj:`logging.Logger` for rank 0 of the :attr:`rank_comm` communicator.
    """
    # Lazily construct the logger on first access and cache it.
    logger = self._root_logger
    if logger is None:
        logger = _logging.get_root_logger(self.__class__.__name__, comm=self.rank_comm)
        self._root_logger = logger
    return logger


if (_sys.version_info[0] >= 3) and (_sys.version_info[1] >= 5):
# Set docstring for properties.
Expand Down
14 changes: 14 additions & 0 deletions mpi_array/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,20 @@ def rank_view_h(self):
"""
return self[self.decomp.rank_view_slice_h]

@property
def view_n(self):
    """
    View of entire array without halo.
    """
    # Index self with the decomposition's halo-stripped slice.
    no_halo_slice = self.decomp.lndarray_view_slice_n
    return self[no_halo_slice]

@property
def view_h(self):
    """
    The entire :obj:`lndarray` view including halo (i.e. :samp:`{self}`).
    """
    # The full array already contains the halo; no slicing required.
    return self

def __reduce_ex__(self, protocol):
"""
Pickle *reference* to shared memory.
Expand Down
61 changes: 41 additions & 20 deletions mpi_array/local_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,27 +256,48 @@ def test_views_2d(self):

lshape = _np.array((4, 3), dtype="int64")
gshape = lshape * _shape_factors(_mpi.COMM_WORLD.size, lshape.size)
decomp = CartesianDecomposition(shape=gshape, halo=2)

lary = mpi_array.local.ones(decomp=decomp, dtype="int64")
self.assertEqual(_np.dtype("int64"), lary.dtype)
rank_logger = _logging.get_rank_logger(self.id(), comm=decomp.rank_comm)
rank_logger.info("========================================================")
rank_logger.info("rank_view_slice_n = %s" % (lary.decomp.rank_view_slice_n,))
rank_logger.info("rank_view_slice_h = %s" % (lary.decomp.rank_view_slice_h,))
rank_logger.info(
"rank_view_relative_slice_n = %s" % (lary.decomp.rank_view_relative_slice_n,)
)

lary.rank_view_n[...] = lary.decomp.rank_comm.rank
lary.decomp.shared_mem_comm.barrier()
if lary.decomp.shared_mem_comm.size > 1:
self.assertTrue(_np.any(lary.rank_view_h != lary.decomp.rank_comm.rank))
self.assertSequenceEqual(
lary.rank_view_h[lary.decomp.rank_view_relative_slice_n].tolist(),
lary.rank_view_n.tolist()
)
rank_logger.info("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
mats = \
[
None,
MemAllocTopology(
ndims=gshape.size,
rank_comm=_mpi.COMM_WORLD,
shared_mem_comm=_mpi.COMM_SELF
)
]
for mat in mats:
decomp = CartesianDecomposition(shape=gshape, halo=2, mem_alloc_topology=mat)

lary = mpi_array.local.ones(decomp=decomp, dtype="int64")
self.assertEqual(_np.dtype("int64"), lary.dtype)
rank_logger = _logging.get_rank_logger(self.id(), comm=decomp.rank_comm)
rank_logger.info(
(
"\n========================================================\n" +
"lndarray_extent = %s\n" +
"rank_view_slice_n = %s\n" +
"rank_view_slice_h = %s\n" +
"rank_view_relative_slice_n = %s\n" +
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
)
%
(
lary.decomp._lndarray_extent,
lary.decomp.rank_view_slice_n,
lary.decomp.rank_view_slice_h,
lary.decomp.rank_view_relative_slice_n,
)
)

lary.rank_view_n[...] = lary.decomp.rank_comm.rank
lary.decomp.shared_mem_comm.barrier()
if lary.decomp.shared_mem_comm.size > 1:
self.assertTrue(_np.any(lary.rank_view_h != lary.decomp.rank_comm.rank))
self.assertSequenceEqual(
lary.rank_view_h[lary.decomp.rank_view_relative_slice_n].tolist(),
lary.rank_view_n.tolist()
)


_unittest.main(__name__)
Expand Down

0 comments on commit 52abddc

Please sign in to comment.