Skip to content

Commit

Permalink
configure asyncio loop using loop_factory kwarg rather than using the…
Browse files Browse the repository at this point in the history
… set_event_loop_policy (#7969)

* backport asyncio.run's loop_factory kwarg from 3.12

* configure asyncio loop using loop_factory kwarg rather than using the set_event_loop_policy

* fix test_worker.py::test_io_loop_alternate_loop and test_scheduler.py::test_io_loop
  • Loading branch information
graingert committed Jul 13, 2023
1 parent 8e3e0f6 commit de50c90
Show file tree
Hide file tree
Showing 16 changed files with 155 additions and 55 deletions.
4 changes: 3 additions & 1 deletion distributed/cli/dask_scheduler.py
Expand Up @@ -13,6 +13,8 @@

from distributed import Scheduler
from distributed._signals import wait_for_signals
from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.preloading import validate_preload_argv
from distributed.proctitle import (
enable_proctitle_on_children,
Expand Down Expand Up @@ -246,7 +248,7 @@ async def wait_for_signals_and_close():
logger.info("Stopped scheduler at %r", scheduler.address)

try:
asyncio.run(run())
asyncio_run(run(), loop_factory=get_loop_factory())
finally:
logger.info("End scheduler")

Expand Down
4 changes: 3 additions & 1 deletion distributed/cli/dask_spec.py
Expand Up @@ -7,6 +7,8 @@
import click
import yaml

from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.deploy.spec import run_spec


Expand Down Expand Up @@ -39,7 +41,7 @@ async def run():
except KeyboardInterrupt:
await asyncio.gather(*(w.close() for w in servers.values()))

asyncio.run(run())
asyncio_run(run(), loop_factory=get_loop_factory())


if __name__ == "__main__":
Expand Down
4 changes: 3 additions & 1 deletion distributed/cli/dask_worker.py
Expand Up @@ -21,6 +21,8 @@
from distributed import Nanny
from distributed._signals import wait_for_signals
from distributed.comm import get_address_host_port
from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.deploy.utils import nprocesses_nthreads
from distributed.preloading import validate_preload_argv
from distributed.proctitle import (
Expand Down Expand Up @@ -443,7 +445,7 @@ async def wait_for_signals_and_close():
[task.result() for task in done]

try:
asyncio.run(run())
asyncio_run(run(), loop_factory=get_loop_factory())
except (TimeoutError, asyncio.TimeoutError):
# We already log the exception in nanny / worker. Don't do it again.
if not signal_fired:
Expand Down
6 changes: 5 additions & 1 deletion distributed/comm/tests/test_comms.py
Expand Up @@ -25,6 +25,8 @@
unparse_host_port,
)
from distributed.comm.registry import backends, get_backend
from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.metrics import time
from distributed.protocol import Serialized, deserialize, serialize, to_serialize
from distributed.utils import get_ip, get_ipv6, get_mp_context, wait_for
Expand Down Expand Up @@ -438,7 +440,9 @@ async def run_with_timeout():
t = asyncio.create_task(func(*args, **kwargs))
return await wait_for(t, timeout=10)

return await asyncio.to_thread(asyncio.run, run_with_timeout())
return await asyncio.to_thread(
asyncio_run, run_with_timeout(), loop_factory=get_loop_factory()
)


@gen_test()
Expand Down
85 changes: 84 additions & 1 deletion distributed/compatibility.py
Expand Up @@ -5,6 +5,8 @@
import random
import sys
import warnings
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar

import tornado

Expand Down Expand Up @@ -48,7 +50,7 @@ def randbytes(*args, **kwargs):
# takes longer than the interval
import datetime
import math
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable
from inspect import isawaitable

from tornado.ioloop import IOLoop
Expand Down Expand Up @@ -182,3 +184,84 @@ def _update_next(self, current_time: float) -> None:
# time.monotonic().
# https://github.com/tornadoweb/tornado/issues/2333
self._next_timeout += callback_time_sec


_T = TypeVar("_T")

if sys.version_info >= (3, 12):
asyncio_run = asyncio.run
elif sys.version_info >= (3, 11):

def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
# asyncio.run from Python 3.12
# https://docs.python.org/3/license.html#psf-license
with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
return runner.run(main)

else:
# modified version of asyncio.run from Python 3.10 to add loop_factory kwarg
# https://docs.python.org/3/license.html#psf-license
def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop"
)

if not asyncio.iscoroutine(main):
raise ValueError(f"a coroutine was expected, got {main!r}")

if loop_factory is None:
loop = asyncio.new_event_loop()
else:
loop = loop_factory()
try:
if loop_factory is None:
asyncio.set_event_loop(loop)
if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
finally:
if loop_factory is None:
asyncio.set_event_loop(None)
loop.close()

def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
22 changes: 11 additions & 11 deletions distributed/config.py
Expand Up @@ -4,6 +4,7 @@
import logging.config
import os
import sys
from collections.abc import Callable
from typing import Any

import yaml
Expand Down Expand Up @@ -177,7 +178,7 @@ def initialize_logging(config: dict[Any, Any]) -> None:
_initialize_logging_old_style(config)


def initialize_event_loop(config: dict[Any, Any]) -> None:
def get_loop_factory() -> Callable[[], asyncio.AbstractEventLoop] | None:
event_loop = dask.config.get("distributed.admin.event-loop")
if event_loop == "uvloop":
uvloop = import_required(
Expand All @@ -189,19 +190,18 @@ def initialize_event_loop(config: dict[Any, Any]) -> None:
" conda install uvloop\n"
" pip install uvloop",
)
uvloop.install()
elif event_loop in {"asyncio", "tornado"}:
return uvloop.new_event_loop
if event_loop in {"asyncio", "tornado"}:
if sys.platform == "win32":
# WindowsProactorEventLoopPolicy is not compatible with tornado 6
# ProactorEventLoop is not compatible with tornado 6
# fallback to the pre-3.8 default of Selector
# https://github.com/tornadoweb/tornado/issues/2608
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
else:
raise ValueError(
"Expected distributed.admin.event-loop to be in ('asyncio', 'tornado', 'uvloop'), got %s"
% dask.config.get("distributed.admin.event-loop")
)
return asyncio.SelectorEventLoop
return None
raise ValueError(
"Expected distributed.admin.event-loop to be in ('asyncio', 'tornado', 'uvloop'), got %s"
% dask.config.get("distributed.admin.event-loop")
)


initialize_logging(dask.config.config)
initialize_event_loop(dask.config.config)
5 changes: 3 additions & 2 deletions distributed/deploy/tests/test_local.py
Expand Up @@ -14,7 +14,8 @@
from dask.system import CPU_COUNT

from distributed import Client, LocalCluster, Nanny, Worker, get_client
from distributed.compatibility import LINUX
from distributed.compatibility import LINUX, asyncio_run
from distributed.config import get_loop_factory
from distributed.core import Status
from distributed.metrics import time
from distributed.system import MEMORY_LIMIT
Expand Down Expand Up @@ -670,7 +671,7 @@ async def amain():
box = cluster._cached_widget
assert isinstance(box, ipywidgets.Widget)

asyncio.run(amain())
asyncio_run(amain(), loop_factory=get_loop_factory())


def test_no_ipywidgets(loop, monkeypatch):
Expand Down
4 changes: 3 additions & 1 deletion distributed/nanny.py
Expand Up @@ -28,6 +28,8 @@
from distributed import preloading
from distributed.comm import get_address_host
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.core import (
AsyncTaskGroupClosedError,
CommClosedError,
Expand Down Expand Up @@ -996,7 +998,7 @@ def close_stop_q() -> None:
if silence_logs:
logger.setLevel(silence_logs)

asyncio.run(run())
asyncio_run(run(), loop_factory=get_loop_factory())


def _get_env_variables(config_key: str) -> dict[str, str]:
Expand Down
5 changes: 3 additions & 2 deletions distributed/tests/test_asyncprocess.py
Expand Up @@ -13,7 +13,8 @@
import pytest
from tornado.ioloop import IOLoop

from distributed.compatibility import LINUX, MACOS, WINDOWS
from distributed.compatibility import LINUX, MACOS, WINDOWS, asyncio_run
from distributed.config import get_loop_factory
from distributed.metrics import time
from distributed.process import AsyncProcess
from distributed.utils import get_mp_context, wait_for
Expand Down Expand Up @@ -389,7 +390,7 @@ async def run_with_timeout():
t = asyncio.create_task(parent_process_coroutine())
return await wait_for(t, timeout=10)

asyncio.run(run_with_timeout())
asyncio_run(run_with_timeout(), loop_factory=get_loop_factory())
raise RuntimeError("this should be unreachable due to os._exit")


Expand Down
16 changes: 7 additions & 9 deletions distributed/tests/test_scheduler.py
Expand Up @@ -1365,15 +1365,13 @@ async def test_update_graph_culls(s, a, b):
assert "z" not in s.tasks


def test_io_loop(loop):
async def main():
with pytest.warns(
DeprecationWarning, match=r"the loop kwarg to Scheduler is deprecated"
):
s = Scheduler(loop=loop, dashboard_address=":0", validate=True)
assert s.io_loop is IOLoop.current()

asyncio.run(main())
@gen_test()
async def test_io_loop(loop):
with pytest.warns(
DeprecationWarning, match=r"the loop kwarg to Scheduler is deprecated"
):
s = Scheduler(loop=loop, dashboard_address=":0", validate=True)
assert s.io_loop is IOLoop.current()


@gen_cluster(client=True)
Expand Down
13 changes: 8 additions & 5 deletions distributed/tests/test_utils.py
Expand Up @@ -23,7 +23,8 @@

import dask

from distributed.compatibility import MACOS, WINDOWS
from distributed.compatibility import MACOS, WINDOWS, asyncio_run
from distributed.config import get_loop_factory
from distributed.metrics import time
from distributed.utils import (
All,
Expand Down Expand Up @@ -134,7 +135,7 @@ def test_sync_closed_loop():
async def get_loop():
return IOLoop.current()

loop = asyncio.run(get_loop())
loop = asyncio_run(get_loop(), loop_factory=get_loop_factory())
loop.close()

with pytest.raises(RuntimeError) as exc_info:
Expand Down Expand Up @@ -399,7 +400,9 @@ def test_loop_runner(loop_in_thread):
async def make_looprunner_in_async_context():
return IOLoop.current(), LoopRunner()

loop, runner = asyncio.run(make_looprunner_in_async_context())
loop, runner = asyncio_run(
make_looprunner_in_async_context(), loop_factory=get_loop_factory()
)
with pytest.raises(
RuntimeError,
match=r"Accessing the loop property while the loop is not running is not supported",
Expand All @@ -423,7 +426,7 @@ async def make_io_loop_in_async_context():
return IOLoop.current()

# Explicit loop
loop = asyncio.run(make_io_loop_in_async_context())
loop = asyncio_run(make_io_loop_in_async_context(), loop_factory=get_loop_factory())
with pytest.raises(
RuntimeError,
match=r"Constructing LoopRunner\(loop=loop\) without a running loop is not supported",
Expand All @@ -449,7 +452,7 @@ async def make_io_loop_in_async_context():
LoopRunner(asynchronous=True)

# Explicit loop
loop = asyncio.run(make_io_loop_in_async_context())
loop = asyncio_run(make_io_loop_in_async_context(), loop_factory=get_loop_factory())
with pytest.raises(
RuntimeError,
match=r"Constructing LoopRunner\(loop=loop\) without a running loop is not supported",
Expand Down
9 changes: 5 additions & 4 deletions distributed/tests/test_utils_comm.py
@@ -1,12 +1,13 @@
from __future__ import annotations

import asyncio
from unittest import mock

import pytest

from dask.optimization import SubgraphCallable

from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.core import ConnectionPool
from distributed.utils_comm import (
WrappedKey,
Expand Down Expand Up @@ -81,7 +82,7 @@ async def coro():
async def f():
return await retry(coro, count=0, delay_min=-1, delay_max=-1)

assert asyncio.run(f()) is retval
assert asyncio_run(f(), loop_factory=get_loop_factory()) is retval
assert n_calls == 1


Expand All @@ -99,7 +100,7 @@ async def f():
return await retry(coro, count=0, delay_min=-1, delay_max=-1)

with pytest.raises(RuntimeError, match="RT_ERROR 1"):
asyncio.run(f())
asyncio_run(f(), loop_factory=get_loop_factory())

assert n_calls == 1

Expand Down Expand Up @@ -134,7 +135,7 @@ async def f():

with mock.patch("asyncio.sleep", my_sleep):
with pytest.raises(MyEx, match="RT_ERROR 6"):
asyncio.run(f())
asyncio_run(f(), loop_factory=get_loop_factory())

assert n_calls == 6
assert sleep_calls == [0.0, 1.0, 3.0, 6.0, 6.0]
Expand Down

0 comments on commit de50c90

Please sign in to comment.