Skip to content

Commit

Permalink
use asyncio.run(..., loop_factory) to avoid asyncio.set_event_loop_po…
Browse files Browse the repository at this point in the history
…licy
  • Loading branch information
graingert committed Oct 14, 2023
1 parent 40b99b8 commit 12c6f54
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 29 deletions.
11 changes: 6 additions & 5 deletions tests/test_auto_detection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import contextlib
import importlib

import pytest

from uvicorn.config import Config
from uvicorn.loops.auto import auto_loop_setup
from uvicorn.loops.auto import auto_loop_factory
from uvicorn.main import ServerState
from uvicorn.protocols.http.auto import AutoHTTPProtocol
from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol
Expand Down Expand Up @@ -37,10 +38,10 @@ async def app(scope, receive, send):


def test_loop_auto():
auto_loop_setup()
policy = asyncio.get_event_loop_policy()
assert isinstance(policy, asyncio.events.BaseDefaultEventLoopPolicy)
assert type(policy).__module__.startswith(expected_loop)
loop_factory = auto_loop_factory()
with contextlib.closing(loop_factory()) as loop:
assert isinstance(loop, asyncio.AbstractEventLoop)
assert type(loop).__module__.startswith(expected_loop)


@pytest.mark.anyio
Expand Down
86 changes: 86 additions & 0 deletions uvicorn/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

import asyncio
import sys
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar

_T = TypeVar("_T")

if sys.version_info >= (3, 12):
asyncio_run = asyncio.run
elif sys.version_info >= (3, 11):

def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
# asyncio.run from Python 3.12
# https://docs.python.org/3/license.html#psf-license
with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
return runner.run(main)

else:
# modified version of asyncio.run from Python 3.10 to add loop_factory kwarg
# https://docs.python.org/3/license.html#psf-license
def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop"
)

if not asyncio.iscoroutine(main):
raise ValueError(f"a coroutine was expected, got {main!r}")

if loop_factory is None:
loop = asyncio.new_event_loop()
else:
loop = loop_factory()
try:
if loop_factory is None:
asyncio.set_event_loop(loop)
if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
finally:
if loop_factory is None:
asyncio.set_event_loop(None)
loop.close()

def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
17 changes: 9 additions & 8 deletions uvicorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@
"on": "uvicorn.lifespan.on:LifespanOn",
"off": "uvicorn.lifespan.off:LifespanOff",
}
LOOP_SETUPS: Dict[LoopSetupType, Optional[str]] = {
LOOP_FACTORIES: Dict[LoopSetupType, Optional[str]] = {
"none": None,
"auto": "uvicorn.loops.auto:auto_loop_setup",
"asyncio": "uvicorn.loops.asyncio:asyncio_setup",
"uvloop": "uvicorn.loops.uvloop:uvloop_setup",
"auto": "uvicorn.loops.auto:auto_loop_factory",
"asyncio": "uvicorn.loops.asyncio:asyncio_loop_factory",
"uvloop": "uvicorn.loops.uvloop:uvloop_loop_factory",
}
INTERFACES: List[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]

Expand Down Expand Up @@ -507,10 +507,11 @@ def load(self) -> None:

self.loaded = True

def setup_event_loop(self) -> None:
loop_setup: Optional[Callable] = import_from_string(LOOP_SETUPS[self.loop])
if loop_setup is not None:
loop_setup(use_subprocess=self.use_subprocess)
def get_loop_factory(self) -> Union[Callable[[], asyncio.AbstractEventLoop], None]:
loop_factory: Optional[Callable] = import_from_string(LOOP_FACTORIES[self.loop])
if loop_factory is None:
return None
return loop_factory(use_subprocess=self.use_subprocess)

def bind_socket(self) -> socket.socket:
logger_args: List[Union[str, int]]
Expand Down
15 changes: 12 additions & 3 deletions uvicorn/loops/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from __future__ import annotations

import asyncio
import logging
import sys
from collections.abc import Callable
from typing import TypeVar

_T = TypeVar("_T")

logger = logging.getLogger("uvicorn.error")


def asyncio_setup(use_subprocess: bool = False) -> None:
if sys.platform == "win32" and use_subprocess:
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def asyncio_loop_factory(
use_subprocess: bool = False,
) -> Callable[[], asyncio.AbstractEventLoop]:
if sys.platform == "win32" and not use_subprocess:
return asyncio.ProactorEventLoop
return asyncio.SelectorEventLoop
18 changes: 13 additions & 5 deletions uvicorn/loops/auto.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
def auto_loop_setup(use_subprocess: bool = False) -> None:
from __future__ import annotations

import asyncio
from collections.abc import Callable


def auto_loop_factory(
use_subprocess: bool = False,
) -> Callable[[], asyncio.AbstractEventLoop]:
try:
import uvloop # noqa
except ImportError: # pragma: no cover
from uvicorn.loops.asyncio import asyncio_setup as loop_setup
from uvicorn.loops.asyncio import asyncio_loop_factory as loop_factory

loop_setup(use_subprocess=use_subprocess)
return loop_factory(use_subprocess=use_subprocess)
else: # pragma: no cover
from uvicorn.loops.uvloop import uvloop_setup
from uvicorn.loops.uvloop import uvloop_loop_factory

uvloop_setup(use_subprocess=use_subprocess)
return uvloop_loop_factory(use_subprocess=use_subprocess)
9 changes: 7 additions & 2 deletions uvicorn/loops/uvloop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable

import uvloop


def uvloop_setup(use_subprocess: bool = False) -> None:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
def uvloop_loop_factory(
use_subprocess: bool = False,
) -> Callable[[], asyncio.AbstractEventLoop]:
return uvloop.new_event_loop
4 changes: 2 additions & 2 deletions uvicorn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
LIFESPAN,
LOG_LEVELS,
LOGGING_CONFIG,
LOOP_SETUPS,
LOOP_FACTORIES,
SSL_PROTOCOL_VERSION,
WS_PROTOCOLS,
Config,
Expand All @@ -33,7 +33,7 @@
HTTP_CHOICES = click.Choice(list(HTTP_PROTOCOLS.keys()))
WS_CHOICES = click.Choice(list(WS_PROTOCOLS.keys()))
LIFESPAN_CHOICES = click.Choice(list(LIFESPAN.keys()))
LOOP_CHOICES = click.Choice([key for key in LOOP_SETUPS.keys() if key != "none"])
LOOP_CHOICES = click.Choice([key for key in LOOP_FACTORIES.keys() if key != "none"])
INTERFACE_CHOICES = click.Choice(INTERFACES)

STARTUP_FAILURE = 3
Expand Down
6 changes: 4 additions & 2 deletions uvicorn/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import click

from uvicorn._compat import asyncio_run
from uvicorn.config import Config

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,8 +58,9 @@ def __init__(self, config: Config) -> None:
self.last_notified = 0.0

def run(self, sockets: Optional[List[socket.socket]] = None) -> None:
self.config.setup_event_loop()
return asyncio.run(self.serve(sockets=sockets))
return asyncio_run(
self.serve(sockets=sockets), loop_factory=self.config.get_loop_factory()
)

async def serve(self, sockets: Optional[List[socket.socket]] = None) -> None:
process_id = os.getpid()
Expand Down
4 changes: 2 additions & 2 deletions uvicorn/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gunicorn.arbiter import Arbiter
from gunicorn.workers.base import Worker

from uvicorn._compat import asyncio_run
from uvicorn.config import Config
from uvicorn.main import Server

Expand Down Expand Up @@ -62,7 +63,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.config = Config(**config_kwargs)

def init_process(self) -> None:
self.config.setup_event_loop()
super(UvicornWorker, self).init_process()

def init_signals(self) -> None:
Expand Down Expand Up @@ -95,7 +95,7 @@ async def _serve(self) -> None:
sys.exit(Arbiter.WORKER_BOOT_ERROR)

def run(self) -> None:
return asyncio.run(self._serve())
return asyncio_run(self._serve(), loop_factory=self.config.get_loop_factory())

async def callback_notify(self) -> None:
self.notify()
Expand Down

0 comments on commit 12c6f54

Please sign in to comment.