diff --git a/httpx/_dispatch/asgi.py b/httpx/_dispatch/asgi.py index 5edca1ed8c..a86969bcca 100644 --- a/httpx/_dispatch/asgi.py +++ b/httpx/_dispatch/asgi.py @@ -1,9 +1,28 @@ +import typing from typing import Callable, Dict, List, Optional, Tuple import httpcore +import sniffio from .._content_streams import ByteStream +if typing.TYPE_CHECKING: # pragma: no cover + import asyncio + import trio + + Event = typing.Union[asyncio.Event, trio.Event] + + +def create_event() -> "Event": + if sniffio.current_async_library() == "trio": + import trio + + return trio.Event() + else: + import asyncio + + return asyncio.Event() + class ASGIDispatch(httpcore.AsyncHTTPTransport): """ @@ -76,8 +95,9 @@ async def request( status_code = None response_headers = None body_parts = [] + request_complete = False response_started = False - response_complete = False + response_complete = create_event() headers = [] if headers is None else headers stream = ByteStream(b"") if stream is None else stream @@ -85,14 +105,16 @@ async def request( request_body_chunks = stream.__aiter__() async def receive() -> dict: - nonlocal response_complete + nonlocal request_complete, response_complete - if response_complete: + if request_complete: + await response_complete.wait() return {"type": "http.disconnect"} try: body = await request_body_chunks.__anext__() except StopAsyncIteration: + request_complete = True return {"type": "http.request", "body": b"", "more_body": False} return {"type": "http.request", "body": body, "more_body": True} @@ -108,7 +130,7 @@ async def send(message: dict) -> None: response_started = True elif message["type"] == "http.response.body": - assert not response_complete + assert not response_complete.is_set() body = message.get("body", b"") more_body = message.get("more_body", False) @@ -116,7 +138,7 @@ async def send(message: dict) -> None: body_parts.append(body) if not more_body: - response_complete = True + response_complete.set() try: await self.app(scope, receive, send) @@ -124,7 +146,7 @@ async def send(message: dict) -> None: if self.raise_app_exceptions or not response_complete: raise - assert response_complete + assert response_complete.is_set() assert status_code is not None assert response_headers is not None diff --git a/requirements.txt b/requirements.txt index e5ac1a2ed9..dd2409067d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,6 +26,7 @@ pytest-asyncio pytest-trio pytest-cov trio +trio-typing trustme uvicorn seed-isort-config diff --git a/setup.py b/setup.py index 554fed604c..6a5b137eda 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ def get_packages(package): "idna==2.*", "rfc3986>=1.3,<2", "httpcore>=0.8.3", + "sniffio", ], classifiers=[ "Development Status :: 4 - Beta", diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 72a003936e..d225baf411 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -69,8 +69,7 @@ async def test_asgi_exc_after_response(): await client.get("http://www.example.org/") -@pytest.mark.asyncio -async def test_asgi_disconnect_after_response_complete(): +async def test_asgi_disconnect_after_response_complete(async_environment): disconnect = False async def read_body(scope, receive, send):