Skip to content

Commit

Permalink
FIX Close cleanly distributed Client at the end of unit tests (#1526)
Browse files Browse the repository at this point in the history
* Close cleanly distributed Client at the end of unit tests

* Fix compatibility with python3.8
  • Loading branch information
fcharras committed Dec 1, 2023
1 parent ebfe05d commit 6310841
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions joblib/test/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,8 +1295,9 @@ def test_parallel_unordered_generator_returns_fastest_first(backend, n_jobs):
def test_parallel_unordered_generator_returns_fastest_first_with_dask(
n_jobs, context
):
client = distributed.Client(n_workers=2, threads_per_worker=2) # noqa
with context("dask"):
with distributed.Client(
n_workers=2, threads_per_worker=2
), context("dask"):
_test_parallel_unordered_generator_returns_fastest_first(None, n_jobs)


Expand Down Expand Up @@ -1346,8 +1347,9 @@ def test_deadlock_with_generator(backend, return_as, n_jobs):
@parametrize("context", [parallel_config, parallel_backend])
@skipif(distributed is None, reason='This test requires dask')
def test_deadlock_with_generator_and_dask(context, return_as, n_jobs):
client = distributed.Client(n_workers=2, threads_per_worker=2) # noqa
with context("dask"):
with distributed.Client(
n_workers=2, threads_per_worker=2
), context("dask"):
_test_deadlock_with_generator(None, return_as, n_jobs)


Expand Down Expand Up @@ -1729,24 +1731,23 @@ def test_nested_parallelism_limit(context, backend):
@parametrize("context", [parallel_config, parallel_backend])
@skipif(distributed is None, reason='This test requires dask')
def test_nested_parallelism_with_dask(context):
client = distributed.Client(n_workers=2, threads_per_worker=2) # noqa

# 10 MB of data as argument to trigger implicit scattering
data = np.ones(int(1e7), dtype=np.uint8)
for i in range(2):
with distributed.Client(n_workers=2, threads_per_worker=2):
# 10 MB of data as argument to trigger implicit scattering
data = np.ones(int(1e7), dtype=np.uint8)
for i in range(2):
with context('dask'):
backend_types_and_levels = _recursive_backend_info(data=data)
assert len(backend_types_and_levels) == 4
assert all(name == 'DaskDistributedBackend'
for name, _ in backend_types_and_levels)

# No argument
with context('dask'):
backend_types_and_levels = _recursive_backend_info(data=data)
backend_types_and_levels = _recursive_backend_info()
assert len(backend_types_and_levels) == 4
assert all(name == 'DaskDistributedBackend'
for name, _ in backend_types_and_levels)

# No argument
with context('dask'):
backend_types_and_levels = _recursive_backend_info()
assert len(backend_types_and_levels) == 4
assert all(name == 'DaskDistributedBackend'
for name, _ in backend_types_and_levels)


def _recursive_parallel(nesting_limit=None):
"""A horrible function that does recursive parallel calls"""
Expand Down

0 comments on commit 6310841

Please sign in to comment.