diff --git a/starlette/responses.py b/starlette/responses.py index 3d5b3e43c..d24484c3c 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -157,6 +157,8 @@ def delete_cookie( ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + await send( { "type": "http.response.start", diff --git a/starlette/routing.py b/starlette/routing.py index 0aa90aa25..2589d2021 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -6,9 +6,11 @@ import types import typing import warnings -from contextlib import asynccontextmanager +from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum +from anyio import Event, create_task_group + from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.convertors import CONVERTOR_TYPES, Convertor @@ -17,7 +19,7 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose @@ -572,6 +574,91 @@ def __call__(self: _T, app: object) -> _T: return self +class LifespanException(Exception): + def __init__(self, message: str): + self.message = message + + +@asynccontextmanager +async def _app_lifespan(scope: Scope, app: ASGIApp) -> typing.AsyncIterator[None]: + startup_sent = Event() + startup_done = Event() + + shutdown_init = Event() + shutdown_sent = Event() + shutdown_done = Event() + + lifespan_supported = False + exception = None + + async def receive() -> Message: + nonlocal lifespan_supported + + lifespan_supported = True + + if not startup_sent.is_set(): + startup_sent.set() + return {"type": "lifespan.startup"} + + elif startup_done.is_set() and not shutdown_sent.is_set(): + await shutdown_init.wait() + shutdown_sent.set() + return {"type": "lifespan.shutdown"} + + else: + raise RuntimeError("unexpected receive") + + async def send(message: Message) -> None: + nonlocal exception, lifespan_supported + + lifespan_supported = True + + if startup_sent.is_set() and not startup_done.is_set(): + if message["type"] == "lifespan.startup.complete": + pass + elif message["type"] == "lifespan.startup.failed": + exception = message.get("message", "") + else: + raise ValueError(f"unexpected type: {message['type']}") + startup_done.set() + + elif shutdown_sent.is_set() and not shutdown_done.is_set(): + if message["type"] == "lifespan.shutdown.complete": + pass + elif message["type"] == "lifespan.shutdown.failed": + exception = message.get("message", "") + else: + raise ValueError(f"unexpected type: {message['type']}") + shutdown_done.set() + + else: + raise RuntimeError("unexpected send") + + async def coro_app(scope: Scope, receive: Receive, send: Send) -> None: + await app(scope, receive, send) + if exception is None and not shutdown_done.is_set(): + raise RuntimeError("lifespan returned unexpectedly") + + try: + async with create_task_group() as tg: + tg.start_soon(coro_app, {**scope, "app": app}, receive, send) + await startup_done.wait() + if exception is not None: + raise LifespanException(exception) + try: + yield + finally: + shutdown_init.set() + await shutdown_done.wait() + if exception is not None: + raise LifespanException(exception) + except Exception: + if lifespan_supported: + raise + else: + yield + + class Router: def __init__( self, @@ -659,6 +746,14 @@ async def shutdown(self) -> None: else: handler() + @asynccontextmanager + async def mount_lifespans(self, scope: Scope) -> typing.AsyncIterator[None]: + async with AsyncExitStack() as stack: + for route in self.routes: + if isinstance(route, Mount): + await stack.enter_async_context(_app_lifespan(scope, route.app)) + yield + async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: """ Handle ASGI lifespan messages, which allows us to manage application @@ -668,16 +763,19 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: app = scope.get("app") await receive() try: - async with self.lifespan_context(app): + async with self.lifespan_context(app), self.mount_lifespans(scope): await send({"type": "lifespan.startup.complete"}) started = True await receive() - except BaseException: - exc_text = traceback.format_exc() + except BaseException as e: + if isinstance(e, LifespanException): + message = e.message + else: + message = traceback.format_exc() if started: - await send({"type": "lifespan.shutdown.failed", "message": exc_text}) + await send({"type": "lifespan.shutdown.failed", "message": message}) else: - await send({"type": "lifespan.startup.failed", "message": exc_text}) + await send({"type": "lifespan.startup.failed", "message": message}) raise else: await send({"type": "lifespan.shutdown.complete"}) diff --git a/tests/test_routing.py b/tests/test_routing.py index 09beb8bb9..976a62a84 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,6 +1,8 @@ import functools import typing import uuid +from traceback import format_exc +from unittest.mock import MagicMock import pytest @@ -9,7 +11,15 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse, Response -from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute +from starlette.routing import ( + Host, + LifespanException, + Mount, + NoMatchFound, + Route, + Router, + WebSocketRoute, +) from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect @@ -1022,6 +1032,273 @@ def test_host_named_repr() -> None: assert repr(route).startswith("Host(host='example.com', name='app', app=") +@pytest.fixture +def lifespan(): + startup = MagicMock() + shutdown = MagicMock() + + async def lifespan(scope: Scope, receive: Receive, send: Send): + assert scope["type"] == "lifespan" + assert scope["app"] == lifespan + + message = await receive() + assert message["type"] == "lifespan.startup" + try: + startup() + except Exception: + await send( + { + "type": "lifespan.startup.failed", + "message": format_exc(), + } + ) + return + await send({"type": "lifespan.startup.complete"}) + + message = await receive() + assert message["type"] == "lifespan.shutdown" + try: + shutdown() + except Exception: + await send( + { + "type": "lifespan.shutdown.failed", + "message": format_exc(), + } + ) + return + await send({"type": "lifespan.shutdown.complete"}) + + return lifespan, startup, shutdown + + +def test_sub_lifespan(test_client_factory, lifespan): + lifespan, startup, shutdown = lifespan + app = Router(routes=[Mount("/sub", lifespan)]) + + startup.assert_not_called() + shutdown.assert_not_called() + + with test_client_factory(app): + startup.assert_called_once_with() + shutdown.assert_not_called() + + startup.assert_called_once_with() + shutdown.assert_called_once_with() + + +def test_sub_lifespan_startup_fails(test_client_factory, lifespan): + lifespan, startup, shutdown = lifespan + app = Router(routes=[Mount("/sub", lifespan)]) + + startup.assert_not_called() + shutdown.assert_not_called() + + startup.side_effect = Exception("crash") + + client = test_client_factory(app) + with pytest.raises(LifespanException): + client.__enter__() + + startup.assert_called_once_with() + shutdown.assert_not_called() + + +def test_sub_lifespan_shutdown_fails(test_client_factory, lifespan): + lifespan, startup, shutdown = lifespan + app = Router(routes=[Mount("/sub", lifespan)]) + + startup.assert_not_called() + shutdown.assert_not_called() + + shutdown.side_effect = Exception("crash") + + client = test_client_factory(app) + client.__enter__() + + startup.assert_called_once_with() + shutdown.assert_not_called() + + with pytest.raises(LifespanException): + client.__exit__(None, None, None) + + startup.assert_called_once_with() + shutdown.assert_called_once_with() + + +def test_lifespan_receive_during_startup(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + message = await receive() + assert message["type"] == "lifespan.startup" + await receive() + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + client.__enter__() + + +def test_lifespan_receive_during_shutdown(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + message = await receive() + assert message["type"] == "lifespan.startup" + await send({"type": "lifespan.startup.complete"}) + message = await receive() + assert message["type"] == "lifespan.shutdown" + await receive() + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + client.__enter__() + with pytest.raises(RuntimeError): + client.__exit__(None, None, None) + + +def test_lifespan_receive_after_shutdown(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + message = await receive() + assert message["type"] == "lifespan.startup" + await send({"type": "lifespan.startup.complete"}) + message = await receive() + assert message["type"] == "lifespan.shutdown" + await send({"type": "lifespan.shutdown.complete"}) + await receive() + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + client.__enter__() + with pytest.raises(RuntimeError): + client.__exit__(None, None, None) + + +def test_lifespan_send_before_startup(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + await send({"type": "lifespan.startup.complete"}) + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + client.__enter__() + + +def test_lifespan_send_wrong_type_during_startup(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + message = await receive() + assert message["type"] == "lifespan.startup" + await send({"type": "lifespan.shutdown.complete"}) + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + with pytest.raises(ValueError): + client.__enter__() + + +def test_lifespan_send_between_startup_and_shutdown(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + message = await receive() + assert message["type"] == "lifespan.startup" + await send({"type": "lifespan.startup.complete"}) + await send({"type": "lifespan.shutdown.complete"}) + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + client.__enter__() + + +def test_lifespan_send_wrong_type_during_shutdown(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + message = await receive() + assert message["type"] == "lifespan.startup" + await send({"type": "lifespan.startup.complete"}) + message = await receive() + assert message["type"] == "lifespan.shutdown" + await send({"type": "lifespan.startup.complete"}) + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + client.__enter__() + with pytest.raises(ValueError): + client.__exit__(None, None, None) + + +def test_lifespan_send_after_shutdown(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + message = await receive() + assert message["type"] == "lifespan.startup" + await send({"type": "lifespan.startup.complete"}) + message = await receive() + assert message["type"] == "lifespan.shutdown" + await send({"type": "lifespan.shutdown.complete"}) + await send({"type": "lifespan.shutdown.complete"}) + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + client.__enter__() + with pytest.raises(RuntimeError): + client.__exit__(None, None, None) + + +def test_lifespan_return_before_startup(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + pass + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + client.__enter__() + client.__exit__(None, None, None) + + +def test_lifespan_return_during_startup(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + message = await receive() + assert message["type"] == "lifespan.startup" + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + client.__enter__() + + +def test_lifespan_return_between_startup_and_shutdown(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + message = await receive() + assert message["type"] == "lifespan.startup" + await send({"type": "lifespan.startup.complete"}) + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + client.__enter__() + + +def test_lifespan_return_during_shutdown(test_client_factory): + async def lifespan(scope: Scope, receive: Receive, send: Send): + message = await receive() + assert message["type"] == "lifespan.startup" + await send({"type": "lifespan.startup.complete"}) + message = await receive() + assert message["type"] == "lifespan.shutdown" + + app = Router(routes=[Mount("/sub", lifespan)]) + + client = test_client_factory(app) + client.__enter__() + with pytest.raises(RuntimeError): + client.__exit__(None, None, None) + + def test_decorator_deprecations() -> None: router = Router()