Skip to content

Commit

Permalink
test: Update for Open MPI
Browse files Browse the repository at this point in the history
  • Loading branch information
dalcinl committed Apr 21, 2024
1 parent 83bd54f commit 88f6064
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion test/test_grequest.py
Expand Up @@ -46,14 +46,16 @@ def testConstructor(self):
greq.Complete()
greq.Wait()

@unittest.skipMPI('openmpi') # TODO: open-mpi/ompi#11681
@unittest.skipMPI('openmpi(<=5.0.2)')
def testExceptionHandling(self):
ctx = GReqCtx()

def raise_mpi(*args):
raise MPI.Exception(MPI.ERR_BUFFER)

def raise_rte(*args):
raise ValueError(42)

def check_exc(exception, is_mpi, stderr):
output = stderr.getvalue()
header = 'Traceback (most recent call last):\n'
Expand All @@ -72,8 +74,50 @@ def check_exc(exception, is_mpi, stderr):
(raise_mpi, True),
(raise_rte, False),
):
if is_mpi:
_query, _free = raise_mpi, raise_rte
else:
_query, _free = raise_rte, raise_mpi

q_called = 0
def query(*args):
nonlocal q_called
q_called +=1
return _query(*args)

f_called = 0
def free(*args):
nonlocal f_called
f_called +=1
return _free(*args)

greq = MPI.Grequest.Start(query, free, ctx.cancel)
self.assertFalse(greq.Get_status())
self.assertEqual(q_called, 0)
self.assertEqual(f_called, 0)
greq.Complete()
self.assertEqual(q_called, 0)
self.assertEqual(f_called, 0)
with self.assertRaises(MPI.Exception) as exc_cm:
with unittest.capture_stderr() as stderr:
greq.Get_status()
check_exc(exc_cm.exception, is_mpi, stderr)
self.assertEqual(q_called, 1)
self.assertEqual(f_called, 0)
with self.assertRaises(MPI.Exception) as exc_cm:
with unittest.capture_stderr() as stderr:
greq.Wait()
if greq:
greq.Free()
self.assertEqual(q_called, 2)
self.assertEqual(f_called, 1)
#
greq = MPI.Grequest.Start(raise_fn, ctx.free, ctx.cancel)
self.assertFalse(greq.Get_status())
greq.Complete()
with self.assertRaises(MPI.Exception) as exc_cm:
with unittest.capture_stderr() as stderr:
greq.Get_status()
with self.assertRaises(MPI.Exception) as exc_cm:
with unittest.capture_stderr() as stderr:
greq.Wait()
Expand All @@ -82,7 +126,9 @@ def check_exc(exception, is_mpi, stderr):
check_exc(exc_cm.exception, is_mpi, stderr)
#
greq = MPI.Grequest.Start(ctx.query, raise_fn, ctx.cancel)
self.assertFalse(greq.Get_status())
greq.Complete()
self.assertTrue(greq.Get_status())
with self.assertRaises(MPI.Exception) as exc_cm:
with unittest.capture_stderr() as stderr:
greq.Wait()
Expand All @@ -91,10 +137,13 @@ def check_exc(exception, is_mpi, stderr):
check_exc(exc_cm.exception, is_mpi, stderr)
#
greq = MPI.Grequest.Start(ctx.query, ctx.free, raise_fn)
self.assertFalse(greq.Get_status())
with self.assertRaises(MPI.Exception) as exc_cm:
with unittest.capture_stderr() as stderr:
greq.Cancel()
self.assertFalse(greq.Get_status())
greq.Complete()
self.assertTrue(greq.Get_status())
greq.Wait()
if greq:
greq.Free()
Expand Down

0 comments on commit 88f6064

Please sign in to comment.