Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[MASTER]
disable = locally-disabled,file-ignored,no-else-return
enable = useless-suppression
good-names = i,j,k,_,fn,tb
good-names = i,j,k,_,fs,fn,tb
typealias-rgx=[_A-Z][_a-zA-Z0-9]*
reports = no
extension-pkg-allow-list =
Expand Down
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Release 4.0.0 [2023-XX-XX]

* Enhancements:

+ `mpi4py.futures`: Support for parallel tasks.

+ `mpi4py.futures`: Report exception tracebacks in workers.

+ `mpi4py.util.pkl5`: Add support for collective communication.
Expand Down
222 changes: 214 additions & 8 deletions demo/futures/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,7 @@ class ThenTest(unittest.TestCase):

assert_ = unittest.TestCase.assertTrue

def test_not_done(self):
def test_cancel_base(self):

base_f = ThenableFuture()
new_f = base_f.then()
Expand All @@ -1224,10 +1224,14 @@ def test_not_done(self):
self.assertTrue(not base_f.done())
self.assertTrue(not new_f.done())

base_f._invoke_callbacks()
base_f.cancel()
self.assertTrue(base_f.done())
self.assertTrue(new_f.done())

self.assertTrue(base_f.cancelled())
self.assertTrue(new_f.cancelled())

def test_cancel(self):
def test_cancel_new(self):

base_f = ThenableFuture()
new_f = base_f.then()
Expand All @@ -1236,11 +1240,12 @@ def test_cancel(self):
self.assertTrue(not base_f.done())
self.assertTrue(not new_f.done())

base_f.cancel()
self.assertTrue(base_f.done())
new_f.cancel()
self.assertTrue(not base_f.done())
self.assertTrue(new_f.done())

self.assertTrue(base_f.cancelled())
base_f.set_result(1)
self.assertTrue(base_f.done())
self.assertTrue(new_f.cancelled())

def test_then_multiple(self):
Expand Down Expand Up @@ -1476,7 +1481,29 @@ def transform(value):
self.assertTrue(not new_f.exception())
self.assertTrue(new_f.result() == 5)

def test_detect_circular_chains(self):
def test_chained_failure_callback_and_success(self):

def transform(exc):
self.assertIsInstance(exc, RuntimeError)
f = ThenableFuture()
f.set_result(5)
return f

base_f = ThenableFuture()
new_f = base_f.catch(transform)

self.assertTrue(base_f is not new_f)
self.assertTrue(not base_f.done())
self.assertTrue(not new_f.done())

base_f.set_exception(RuntimeError())
self.assertTrue(base_f.done())
self.assertTrue(new_f.done())

self.assertTrue(not new_f.exception())
self.assertTrue(new_f.result() == 5)

def test_detect_cycle_chain(self):

f1 = ThenableFuture()
f2 = ThenableFuture()
Expand Down Expand Up @@ -1507,11 +1534,190 @@ def transform(a):
with self.assertRaises(RuntimeError) as catcher:
new_f.result()
self.assertTrue(
'Circular future chain detected'
'chain cycle detected'
in catcher.exception.args[0],
)

def test_detect_self_chain(self):

base_f = ThenableFuture()
new_f = base_f.then(lambda arg: new_f)

self.assertTrue(base_f is not new_f)
self.assertTrue(not base_f.done())
self.assertTrue(not new_f.done())

base_f.set_result(1)
self.assertTrue(base_f.done())
self.assertTrue(new_f.done())

self.assertTrue(new_f.exception())
with self.assertRaises(RuntimeError) as catcher:
new_f.result()
self.assertTrue(
'chain cycle detected'
in catcher.exception.args[0],
)


class CollectTest(unittest.TestCase):

def test_empty(self):
future = futures.collect([])
self.assertFalse(future.cancelled())
self.assertFalse(future.running())
self.assertTrue(future.done())
self.assertEqual(future.result(), [])

def test_item_success(self):
fs = [futures.Future() for _ in range(5)]
future = futures.collect(fs)
self.assertFalse(future.cancelled())
self.assertFalse(future.running())
self.assertFalse(future.done())
for i in range(5):
fs[i].set_result(i)
self.assertFalse(future.cancelled())
self.assertFalse(future.running())
self.assertTrue(future.done())
self.assertEqual(future.result(), list(range(5)))

def test_item_failure(self):
fs = [futures.Future() for _ in range(5)]
future = futures.collect(fs)
for i in range(2, 4):
fs[i].set_result(i)
fs[-1].set_exception(RuntimeError())
self.assertFalse(future.cancelled())
self.assertFalse(future.running())
self.assertTrue(future.done())
self.assertIsInstance(future.exception(), RuntimeError)
for i in range(0, 2):
self.assertTrue(fs[i].cancelled())
for i in range(2, 4):
self.assertFalse(fs[i].cancelled())
self.assertFalse(fs[-1].cancelled())

def test_item_done(self):
fs = [futures.Future() for _ in range(5)]
for i in range(5):
fs[i].set_result(i)
future = futures.collect(fs)
self.assertFalse(future.cancelled())
self.assertFalse(future.running())
self.assertTrue(future.done())
self.assertEqual(future.result(), list(range(5)))

def test_item_cancel(self):
fs = [futures.Future() for _ in range(5)]
future = futures.collect(fs)
for i in range(2, 4):
fs[i].set_result(i)
fs[-1].cancel()
self.assertTrue(future.cancelled())
self.assertFalse(future.running())
self.assertTrue(future.done())
for i in range(0, 2):
self.assertTrue(fs[i].cancelled())
for i in range(2, 4):
self.assertFalse(fs[i].cancelled())
self.assertTrue(fs[-1].cancelled())

def test_cancel(self):
fs = [futures.Future() for _ in range(5)]
future = futures.collect(fs)
future.cancel()
for f in fs:
self.assertTrue(f.cancelled())

def test_cancel_pending(self):
class MyFuture(futures.Future):
def cancel(self):
pass
fs = [MyFuture() for _ in range(5)]
future = futures.collect(fs)
self.assertIs(type(future), MyFuture)
super(MyFuture, future).cancel()
for f in fs:
self.assertFalse(f.cancelled())
f.set_result(None)


class ComposeTest(unittest.TestCase):

def test_result(self):
base = futures.Future()
future = futures.compose(base)
self.assertIs(type(future), type(base))
self.assertFalse(future.cancelled())
self.assertFalse(future.running())
self.assertFalse(future.done())
base.set_result(42)
self.assertFalse(future.cancelled())
self.assertFalse(future.running())
self.assertTrue(future.done())
self.assertEqual(future.result(), 42)

def test_except(self):
base = futures.Future()
future = futures.compose(base)
self.assertIs(type(future), type(base))
self.assertFalse(future.cancelled())
self.assertFalse(future.running())
self.assertFalse(future.done())
base.set_exception(RuntimeError(42))
self.assertFalse(future.cancelled())
self.assertFalse(future.running())
self.assertTrue(future.done())
self.assertIs(type(future.exception()), RuntimeError)
self.assertEqual(future.exception().args, (42,))

def test_cancel_new(self):
base = futures.Future()
future = futures.compose(base)
base.cancel()
self.assertTrue(future.cancelled())

def test_cancel_old(self):
base = futures.Future()
future = futures.compose(base)
future.cancel()
self.assertTrue(base.cancelled())

def test_result_hook(self):
base = futures.Future()
future = futures.compose(base, int)
base.set_result('42')
self.assertEqual(future.result(), 42)

def test_result_hook_failure(self):
base = futures.Future()
future = futures.compose(base, resulthook=lambda x: 1/0)
base.set_result(42)
self.assertIs(type(future.exception()), ZeroDivisionError)

def test_except_hook(self):
base = futures.Future()
future = futures.compose(base, excepthook=lambda exc: exc.args[0])
base.set_exception(RuntimeError(42))
self.assertEqual(future.result(), 42)

def test_except_hook_except(self):
base = futures.Future()
future = futures.compose(
base, excepthook=lambda exc: RuntimeError(exc.args[0])
)
base.set_exception(ValueError(42))
self.assertIs(type(future.exception()), RuntimeError)
self.assertEqual(future.exception().args, (42,))

def test_except_hook_failure(self):
base = futures.Future()
future = futures.compose(base, excepthook=lambda exc: 1/0)
base.set_exception(ValueError(42))
self.assertIs(type(future.exception()), ZeroDivisionError)


SKIP_POOL_TEST = False
name, version = MPI.get_vendor()
if name == 'Open MPI':
Expand Down
11 changes: 11 additions & 0 deletions docs/source/mpi4py.futures.rst
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,17 @@ rules, cf. highlighted lines in example :ref:`cpi-py` :
the parallel task.


Utilities
---------

The :mod:`mpi4py.futures` package provides additional utilities for handling
:class:`~concurrent.futures.Future` instances.

.. autofunction:: mpi4py.futures.collect

.. autofunction:: mpi4py.futures.compose


Examples
--------

Expand Down
1 change: 1 addition & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ package = {
'_base.py',
'_core.py',
'pool.py',
'util.py',
'server.py',
'aplus.py',
],
Expand Down
21 changes: 13 additions & 8 deletions src/mpi4py/futures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
InvalidStateError,
BrokenExecutor,
)

from .pool import MPIPoolExecutor
from .pool import MPICommExecutor

from .pool import ThreadPoolExecutor
from .pool import ProcessPoolExecutor

from .pool import get_comm_workers
from .pool import (
MPIPoolExecutor,
MPICommExecutor,
ThreadPoolExecutor,
ProcessPoolExecutor,
get_comm_workers,
)
from .util import (
collect,
compose,
)

__all__ = [
'Future',
Expand All @@ -42,4 +45,6 @@
'ThreadPoolExecutor',
'ProcessPoolExecutor',
'get_comm_workers',
'collect',
'compose',
]
12 changes: 6 additions & 6 deletions src/mpi4py/futures/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,14 @@ from ._base import (
from .pool import (
MPIPoolExecutor as MPIPoolExecutor,
MPICommExecutor as MPICommExecutor,
)

from .pool import (
ThreadPoolExecutor as ThreadPoolExecutor,
ProcessPoolExecutor as ProcessPoolExecutor,
)

from .pool import (
get_comm_workers as get_comm_workers,
)
from .util import (
collect as collect,
compose as compose,
)

__all__: list[str] = [
'Future',
Expand All @@ -43,4 +41,6 @@ __all__: list[str] = [
'ThreadPoolExecutor',
'ProcessPoolExecutor',
'get_comm_workers',
'collect',
'compose',
]
Loading