Skip to content

Commit

Permalink
MPI-4.1: Add lowercase Request.get_status_{any|all|some}
Browse files Browse the repository at this point in the history
  • Loading branch information
dalcinl committed Nov 16, 2023
1 parent 6d6d9e1 commit d2dfecd
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/mpi4py/MPI.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -455,13 +455,19 @@ class Request:
@classmethod
def testany(cls, requests: Sequence[Request], status: Status | None = None) -> tuple[int, bool, Any | None]: ...
@classmethod
def get_status_any(cls, requests: Sequence[Request], status: Status | None = None) -> tuple[int, bool]: ...
@classmethod
def waitall(cls, requests: Sequence[Request], statuses: list[Status] | None = None) -> list[Any]: ...
@classmethod
def testall(cls, requests: Sequence[Request], statuses: list[Status] | None = None) -> tuple[bool, list[Any] | None]: ...
@classmethod
def get_status_all(cls, requests: Sequence[Request], statuses: list[Status] | None = None) -> bool: ...
@classmethod
def waitsome(cls, requests: Sequence[Request], statuses: list[Status] | None = None) -> tuple[list[int] | None, list[Any] | None]: ...
@classmethod
def testsome(cls, requests: Sequence[Request], statuses: list[Status] | None = None) -> tuple[list[int] | None, list[Any] | None]: ...
@classmethod
def get_status_some(cls, requests: Sequence[Request], statuses: list[Status] | None = None) -> list[int] | None: ...
def cancel(self) -> None: ...
handle: int

Expand Down
33 changes: 33 additions & 0 deletions src/mpi4py/MPI/Request.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,17 @@ cdef class Request:
return (index, <bint>flag, msg)
#
@classmethod
def get_status_any(
cls,
requests: Sequence[Request],
Status status: Status | None = None,
) -> tuple[int, bool]:
"""
Non-destructive test for the completion of any requests.
"""
return Request.Get_status_any(requests, status)
#
@classmethod
def waitall(
cls,
requests: Sequence[Request],
Expand All @@ -382,6 +393,17 @@ cdef class Request:
return (<bint>flag, msg)
#
@classmethod
def get_status_all(
cls,
requests: Sequence[Request],
statuses: list[Status] | None = None,
) -> bool:
"""
Non-destructive test for the completion of all requests.
"""
return Request.Get_status_all(requests, statuses)
#
@classmethod
def waitsome(
cls,
requests: Sequence[Request],
Expand All @@ -403,6 +425,17 @@ cdef class Request:
"""
return PyMPI_testsome(requests, statuses)
#
@classmethod
def get_status_some(
cls,
requests: Sequence[Request],
statuses: list[Status] | None = None,
) -> list[int] | None:
"""
Non-destructive test for completion of some requests.
"""
return Request.Get_status_some(requests, statuses)
#
def cancel(self) -> None:
"""
Cancel a request.
Expand Down
10 changes: 10 additions & 0 deletions src/mpi4py/util/pkl5.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,16 @@ def wait(self, status=None):
"""Wait for a request to complete."""
return _test(self, MPI.Request.Waitall, status)[1]

@classmethod
def get_status_all(cls, requests, statuses=None):
"""Non-destructive test for the completion of all requests."""
arglist = [requests]
if statuses is not None:
ns, nr = len(statuses), len(requests)
statuses += [Status() for _ in range(ns, nr)]
arglist.append(statuses)
return all(map(Request.get_status, *arglist))

@classmethod
def testall(cls, requests, statuses=None):
"""Test for the completion of all requests."""
Expand Down
8 changes: 7 additions & 1 deletion src/mpi4py/util/pkl5.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ class Request(tuple[MPI.Request, ...]):
status: Status | None = None,
) -> Any: ...
@classmethod
def get_status_all(
cls,
requests: Sequence[Request],
statuses: list[Status] | None = None,
) -> bool: ...
@classmethod
def testall(
cls,
requests: Sequence[Request],
Expand All @@ -61,7 +67,7 @@ class Request(tuple[MPI.Request, ...]):
def waitall(
cls,
requests: Sequence[Request],
statuses: list[Status] | None = None,
statuses: list[Status] | None = None,
) -> list[Any]: ...

class Message(tuple[MPI.Message, ...]):
Expand Down
10 changes: 10 additions & 0 deletions test/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def testGetStatusAny(self):
self.assertEqual(status.source, MPI.ANY_SOURCE)
self.assertEqual(status.tag, MPI.ANY_TAG)
self.assertEqual(status.error, MPI.SUCCESS)
with self.catchNotImplementedError(4, 1):
index, flag = MPI.Request.get_status_any(self.REQUESTS)
self.assertEqual(index, MPI.UNDEFINED)
self.assertTrue(flag)

def testWaitall(self):
MPI.Request.Waitall(self.REQUESTS)
Expand Down Expand Up @@ -139,6 +143,9 @@ def testGetStatusAll(self):
self.assertEqual(status.source, MPI.ANY_SOURCE)
self.assertEqual(status.tag, MPI.ANY_TAG)
self.assertEqual(status.error, MPI.SUCCESS)
with self.catchNotImplementedError(4, 1):
flag = MPI.Request.get_status_all(self.REQUESTS)
self.assertTrue(flag)

def testWaitsome(self):
ret = MPI.Request.Waitsome(self.REQUESTS)
Expand Down Expand Up @@ -171,6 +178,9 @@ def testGetStatusSome(self):
self.assertIsNone(indices)
indices = MPI.Request.Get_status_some(self.REQUESTS, statuses)
self.assertIsNone(indices)
with self.catchNotImplementedError(4, 1):
indices = MPI.Request.get_status_some(self.REQUESTS)
self.assertIsNone(indices)


if __name__ == '__main__':
Expand Down
27 changes: 27 additions & 0 deletions test/test_util_pkl5.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,33 @@ def testISSendCancel(self):
req.Free()
self.assertFalse(req)

def testGetStatusAll(self):
comm = self.COMM
size = comm.Get_size()
rank = comm.Get_rank()
requests = []
for smess in messages:
req = comm.issend(smess, rank)
requests.append(req)
with self.catchNotImplementedError(4, 1):
flag = self.RequestType.get_status_all(requests)
self.assertFalse(flag)
comm.barrier()
for smess in messages:
rmess = comm.recv(None, rank)
self.assertEqual(rmess, smess)
with self.catchNotImplementedError(4, 1):
flag = False
statuses = []
while not flag:
flag = self.RequestType.get_status_all(requests, statuses)
self.assertEqual(len(statuses), len(requests))
for status in statuses:
self.assertIsInstance(status, MPI.Status)
flag, obj = self.RequestType.testall(requests)
self.assertTrue(flag)
self.assertEqual(obj, [None]*len(messages))

def testTestAll(self):
comm = self.COMM
size = comm.Get_size()
Expand Down

0 comments on commit d2dfecd

Please sign in to comment.