From 11a3f1227e69a29fc9dc56f348f982cb5372a1d1 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 13 Jul 2023 09:48:10 +0200 Subject: [PATCH] Stop `body_stream` in case `more_body=False` (#2194) --- starlette/middleware/base.py | 2 ++ tests/middleware/test_base.py | 65 +++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 2ff0e047b..170a805a7 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -170,6 +170,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: body = message.get("body", b"") if body: yield body + if not message.get("more_body", False): + break if app_exc is not None: raise app_exc diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index f7dcf521c..cf4780cce 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -265,6 +265,71 @@ async def send(message): assert background_task_run.is_set() +@pytest.mark.anyio +async def test_do_not_block_on_background_tasks(): + request_body_sent = False + response_complete = anyio.Event() + events: List[Union[str, Message]] = [] + + async def sleep_and_set(): + events.append("Background task started") + await anyio.sleep(0.1) + events.append("Background task finished") + + async def endpoint_with_background_task(_): + return PlainTextResponse( + content="Hello", background=BackgroundTask(sleep_and_set) + ) + + async def passthrough( + request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: + return await call_next(request) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], + routes=[Route("/", endpoint_with_background_task)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + } + + async def receive() -> Message: + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request", "body": b"", "more_body": False} + await response_complete.wait() + return {"type": "http.disconnect"} + + async def send(message: Message): + if message["type"] == "http.response.body": + events.append(message) + if not message.get("more_body", False): + response_complete.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(app, scope, receive, send) + tg.start_soon(app, scope, receive, send) + + # Without the fix, the background tasks would start and finish before the + # last http.response.body is sent. + assert events == [ + {"body": b"Hello", "more_body": True, "type": "http.response.body"}, + {"body": b"", "more_body": False, "type": "http.response.body"}, + {"body": b"Hello", "more_body": True, "type": "http.response.body"}, + {"body": b"", "more_body": False, "type": "http.response.body"}, + "Background task started", + "Background task started", + "Background task finished", + "Background task finished", + ] + + @pytest.mark.anyio async def test_run_context_manager_exit_even_if_client_disconnects(): # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042