Fix for zero length collective IO (#965)
Collective IO where one rank attempts to read or write a zero-length
slice causes a hang in an underlying collective call. This fix
ensures that all ranks participate in any collective IO operation, even
if some operations are on zero-length slices.

This also adds basic unit tests for collective IO that cover this
problem.
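
To make the failure mode concrete, here is a minimal sketch in the spirit of the new test_collective_write_empty_rank test: the dataset has one row fewer than there are ranks, so the last rank's slice is zero-length. The file name and shapes are illustrative only; it assumes an h5py build with MPI support plus mpi4py, and must be launched under mpirun with more than one rank.

    # Minimal sketch of the scenario this commit fixes (illustrative names).
    import numpy as np
    import h5py
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, size = comm.rank, comm.size

    with h5py.File("collective.h5", "w", driver="mpio", comm=comm) as f:
        # One row fewer than there are ranks, so the last rank selects
        # a zero-length slice.
        dset = f.create_dataset("data", (size - 1, 20), dtype=np.int32)

        start = min(rank, size - 1)
        end = min(rank + 1, size - 1)
        # The value shape matches the selection shape on every rank,
        # including the empty (0, 20) selection on the last rank.
        row = np.full((end - start, 20), rank, dtype=np.int32)

        # Before this fix, a rank with an empty selection returned early and
        # never entered the underlying collective write, hanging the others;
        # with the fix, every rank participates.
        with dset.collective:
            dset[start:end] = row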
jrs65 authored and aragilar committed Aug 9, 2020
1 parent 6f4c578 commit f487ad1
Showing 3 changed files with 177 additions and 3 deletions.
41 changes: 39 additions & 2 deletions h5py/_hl/dataset.py
@@ -389,6 +389,13 @@ def collective(self):
""" Context manager for MPI collective reads & writes """
return CollectiveContext(self)

def _collective_mode(self):
"""Return True if the dataset is in MPI collective mode"""
if MPI:
return self._dxpl.get_dxpl_mpio() == h5fd.MPIO_COLLECTIVE
else:
return False

@property
def dims(self):
""" Access dimension scales attached to this dataset. """
@@ -775,7 +782,9 @@ def __getitem__(self, args, new_dtype=None):
# Perform the dataspace selection.
selection = sel.select(self.shape, args, dataset=self)

if selection.nselect == 0:
# If we are in MPI collective mode, we need to do the read even if it's
# an empty selection, to ensure all MPI processes read.
if selection.nselect == 0 and not self._collective_mode():
return numpy.ndarray(selection.array_shape, dtype=new_dtype)

arr = numpy.ndarray(selection.array_shape, new_dtype, order='C')
@@ -908,7 +917,11 @@ def __setitem__(self, args, val):
# Perform the dataspace selection
selection = sel.select(self.shape, args, dataset=self)

if selection.nselect == 0:
# If we are in MPI collective mode, we need to do the write even if it's
# an empty selection, to ensure all MPI processes write.
is_collective = self._collective_mode()

if selection.nselect == 0 and not is_collective:
return

# Broadcast scalars if necessary.
@@ -938,6 +951,14 @@ def __setitem__(self, args, val):
val = val2
mshape = val.shape

if is_collective and (mshape != selection.mshape):
warn("Broadcasting in collective mode is deprecated, because "
"processes may do different numbers of writes. "
"Expand the data shape - {} - to match the selection: {}."
.format(mshape, selection.mshape),
H5pyDeprecationWarning, stacklevel=2,
)

# Perform the write, with broadcasting
mspace = h5s.create_simple(selection.expand_shape(mshape))
for fspace in selection.broadcast(mshape):
@@ -965,6 +986,14 @@ def read_direct(self, dest, source_sel=None, dest_sel=None):
else:
dest_sel = sel.select(dest.shape, dest_sel, self)

if self._collective_mode() and (dest_sel.mshape != source_sel.mshape):
warn("Broadcasting in collective mode is deprecated, because "
"processes may do different numbers of reads. "
"Expand the selection shape - {} - to match the array: {}."
.format(source_sel.mshape, dest_sel.mshape),
H5pyDeprecationWarning, stacklevel=2,
)

for mspace in dest_sel.broadcast(source_sel.mshape):
self.id.read(mspace, fspace, dest, dxpl=self._dxpl)

@@ -990,6 +1019,14 @@ def write_direct(self, source, source_sel=None, dest_sel=None):
else:
dest_sel = sel.select(self.shape, dest_sel, self)

if self._collective_mode() and (dest_sel.mshape != source_sel.mshape):
warn("Broadcasting in collective mode is deprecated, because "
"processes may do different numbers of writes. "
"Expand the data shape - {} - to match the selection: {}."
.format(source_sel.mshape, dest_sel.mshape),
H5pyDeprecationWarning, stacklevel=2,
)

for fspace in dest_sel.broadcast(source_sel.mshape):
self.id.write(mspace, fspace, source, dxpl=self._dxpl)

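The three H5pyDeprecationWarning additions above all flag the same situation: under collective mode, the data shape does not exactly match the selection shape, so the broadcasting loop may issue a different number of low-level transfers on different ranks. Below is a hedged caller-side sketch of how to avoid the warning; dset and comm are the illustrative names from the sketch under the commit message.

    rank = comm.rank
    row = np.full((1, 20), rank, dtype=dset.dtype)  # matches the (1, 20) selection exactly

    with dset.collective:
        # mshape == selection.mshape, so the new collective-mode warning cannot fire.
        dset[rank:rank + 1, :] = row
        # By contrast, assigning e.g. np.arange(20) here has mshape (20,) while
        # the selection shape is (1, 20), which the new check flags with an
        # H5pyDeprecationWarning.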
4 changes: 3 additions & 1 deletion h5py/_hl/selections.py
@@ -290,7 +290,9 @@ def broadcast(self, source_shape):
rank = len(count)
tshape = self.expand_shape(source_shape)

chunks = tuple(x//y for x, y in zip(count, tshape))
# A zero-length selection needs to work for collective I/O
chunks = tuple(1 if (x == y == 0) else x // y
for x, y in zip(count, tshape))
nchunks = product(chunks)

if nchunks == 1:
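A quick worked example of the guarded chunk computation above, with illustrative values: once the dataset-level early return is skipped in collective mode, an empty selection reaches broadcast(), where the old x // y would divide by zero on the zero-length axis. The guard instead treats that axis as a single empty chunk, so broadcast() can still yield one selection for the collective call.

    count = (0, 20)    # per-axis selection count for an empty row selection
    tshape = (0, 20)   # expanded source shape for the same selection

    # Old expression: x // y divides by zero on the zero-length axis.
    # New expression: a zero-length axis counts as a single (empty) chunk.
    chunks = tuple(1 if (x == y == 0) else x // y for x, y in zip(count, tshape))
    assert chunks == (1, 1)  # nchunks == 1, so a single fspace is produced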
135 changes: 135 additions & 0 deletions h5py/tests/test_mpi_collective.py
@@ -0,0 +1,135 @@
# This file is part of h5py, a Python interface to the HDF5 library.
#
# http://www.h5py.org
#
# Copyright 2008-2013 Andrew Collette and contributors
#
# License: Standard 3-clause BSD; see "license.txt" for full license terms
# and contributor agreement.

"""
Tests the h5py collective IO.
This must be run in an MPI job, otherwise it will be skipped. For example:
mpirun -np 4 pytest test_mpi_collective.py
"""
from __future__ import absolute_import

import numpy as np
import h5py

from .common import ut, TestCase


# Check if we are in an MPI environment; we need more than one process for
# these tests to be meaningful
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
MPI_ENV = (comm.size > 1)
MPI_SIZE = comm.size
MPI_RANK = comm.rank
except ImportError:
MPI_ENV = False


@ut.skipUnless(h5py.get_config().mpi and MPI_ENV, 'MPI support required')
class TestCollectiveWrite(TestCase):

def setUp(self):
"""Open a file in MPI mode"""
self.path = self.mktemp() if MPI_RANK == 0 else None
self.path = comm.bcast(self.path, root=0)
self.f = h5py.File(self.path, 'w', driver='mpio', comm=comm)

def test_collective_write(self):
"""Test a standard collective write."""

dset = self.f.create_dataset("test_data", (MPI_SIZE, 20), dtype=np.int32)

# Write dataset collectively, each process writes one row
with dset.collective:
dset[MPI_RANK:(MPI_RANK + 1)] = MPI_RANK
self.f.close()

# Test that the array is as expected
with h5py.File(self.path, "r") as fh:
self.assertEqual(fh["test_data"].shape, (MPI_SIZE, 20))
arr = np.tile(np.arange(MPI_SIZE), (20, 1)).T
self.assertTrue((fh["test_data"][:] == arr).all())

def test_collective_write_empty_rank(self):
"""Test a collective write where some ranks may be empty.
WARNING: if this test fails it may cause a lockup in the MPI code.
"""

# Only the first NUM_WRITE ranks will actually write anything
NUM_WRITE = MPI_SIZE // 2

dset = self.f.create_dataset("test_data", (NUM_WRITE, 20), dtype=np.int32)

# Write dataset collectively, each process writes one row
start = min(MPI_RANK, NUM_WRITE)
end = min(MPI_RANK + 1, NUM_WRITE)
with dset.collective:
dset[start:end] = MPI_RANK
self.f.close()

# Test that the array is as expected
with h5py.File(self.path, "r") as fh:
self.assertEqual(fh["test_data"].shape, (NUM_WRITE, 20))
arr = np.tile(np.arange(NUM_WRITE), (20, 1)).T
self.assertTrue((fh["test_data"][:] == arr).all())


@ut.skipUnless(h5py.get_config().mpi and MPI_ENV, 'MPI support required')
class TestCollectiveRead(TestCase):

def setUp(self):
"""Open a file in MPI mode"""
self.path = self.mktemp() if MPI_RANK == 0 else None
self.path = comm.bcast(self.path, root=0)

if MPI_RANK == 0:
with h5py.File(self.path, 'w') as fh:
dset = fh.create_dataset("test_data", (20, MPI_SIZE), dtype=np.int32)
dset[:] = np.arange(MPI_SIZE)[np.newaxis, :]

self.f = h5py.File(self.path, 'r', driver='mpio', comm=comm)

def test_collective_read(self):
"""Test a standard collective read."""

dset = self.f["test_data"]

self.assertEqual(dset.shape, (20, MPI_SIZE))

# Read dataset collectively, each process reads one column
with dset.collective:
d = dset[:, MPI_RANK:(MPI_RANK + 1)]

self.assertTrue((d == MPI_RANK).all())

def test_collective_read_empty_rank(self):
"""Test a collective read where some ranks may read nothing.
WARNING: if this test fails it may cause a lockup in the MPI code.
"""

start = 0 if MPI_RANK == 0 else MPI_SIZE
end = MPI_SIZE

dset = self.f["test_data"]
self.assertEqual(dset.shape, (20, MPI_SIZE))

# Read dataset collectively, only the first rank should actually read
# anything
with dset.collective:
d = dset[:, start:end]

if MPI_RANK == 0:
self.assertTrue((d == np.arange(MPI_SIZE)[np.newaxis, :]).all())
else:
self.assertEqual(d.shape, (20, 0))

