From 111ed196f5d8c9b1c77d90c71d3e1e63f241372b Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Sat, 8 Aug 2020 14:21:36 +0200 Subject: [PATCH] anyio solution --- httpx/_transports/asgi.py | 330 ++++++++++++++------------------------ requirements.txt | 2 + tests/test_asgi.py | 4 +- 3 files changed, 123 insertions(+), 213 deletions(-) diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py index 7a141f9b48..e9c85c3111 100644 --- a/httpx/_transports/asgi.py +++ b/httpx/_transports/asgi.py @@ -1,116 +1,12 @@ -import contextlib -from typing import ( - TYPE_CHECKING, - AsyncIterator, - Awaitable, - Callable, - List, - Mapping, - Optional, - Tuple, - Union, -) +import sys +from typing import AsyncIterator, Callable, List, Mapping, Optional, Tuple import httpcore -import sniffio -if TYPE_CHECKING: # pragma: no cover - import asyncio - - import trio - - Event = Union[asyncio.Event, trio.Event] - - -def create_event() -> "Event": - if sniffio.current_async_library() == "trio": - import trio - - return trio.Event() - else: - import asyncio - - return asyncio.Event() - - -async def create_background_task( - async_fn: Callable[[], Awaitable[None]] -) -> Callable[[], Awaitable[None]]: - if sniffio.current_async_library() == "trio": - import trio - - nursery_manager = trio.open_nursery() - nursery = await nursery_manager.__aenter__() - nursery.start_soon(async_fn) - - async def aclose() -> None: - await nursery_manager.__aexit__(None, None, None) - - return aclose - - else: - import asyncio - - loop = asyncio.get_event_loop() - task = loop.create_task(async_fn()) - - async def aclose() -> None: - task.cancel() - # Task must be awaited in all cases to avoid debug warnings. - with contextlib.suppress(asyncio.CancelledError): - await task - - return aclose - - -def create_channel( - capacity: int, -) -> Tuple[ - Callable[[bytes], Awaitable[None]], - Callable[[], Awaitable[None]], - Callable[[], AsyncIterator[bytes]], -]: - """ - Create an in-memory channel to pass data chunks between tasks. - - * `produce()`: send data through the channel, blocking if necessary. - * `consume()`: iterate over data in the channel. - * `aclose_produce()`: mark that no more data will be produced, causing - `consume()` to flush remaining data chunks then stop. - """ - if sniffio.current_async_library() == "trio": - import trio - - send_channel, receive_channel = trio.open_memory_channel[bytes](capacity) - - async def consume() -> AsyncIterator[bytes]: - async for chunk in receive_channel: - yield chunk - - return send_channel.send, send_channel.aclose, consume - - else: - import asyncio - - queue: asyncio.Queue[bytes] = asyncio.Queue(capacity) - produce_closed = False - - async def produce(chunk: bytes) -> None: - assert not produce_closed - await queue.put(chunk) - - async def aclose_produce() -> None: - nonlocal produce_closed - await queue.put(b"") # Make sure (*) doesn't block forever. - produce_closed = True - - async def consume() -> AsyncIterator[bytes]: - while True: - if produce_closed and queue.empty(): - break - yield await queue.get() # (*) - - return produce, aclose_produce, consume +try: + from contextlib import asynccontextmanager # type: ignore # Python 3.6. +except ImportError: # pragma: no cover # Python 3.6. + from async_generator import asynccontextmanager # type: ignore class ASGITransport(httpcore.AsyncHTTPTransport): @@ -153,6 +49,11 @@ def __init__( root_path: str = "", client: Tuple[str, int] = ("127.0.0.1", 123), ) -> None: + try: + import anyio # noqa + except ImportError: + raise ImportError("ASGITransport requires anyio. (Hint: pip install anyio)") + self.app = app self.raise_app_exceptions = raise_app_exceptions self.root_path = root_path @@ -166,111 +67,118 @@ async def request( stream: httpcore.AsyncByteStream = None, timeout: Mapping[str, Optional[float]] = None, ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]: + headers = [] if headers is None else headers stream = httpcore.PlainByteStream(content=b"") if stream is None else stream - # ASGI scope. - scheme, host, port, full_path = url - path, _, query = full_path.partition(b"?") - scope = { - "type": "http", - "asgi": {"version": "3.0"}, - "http_version": "1.1", - "method": method.decode(), - "headers": headers, - "scheme": scheme.decode("ascii"), - "path": path.decode("ascii"), - "query_string": query, - "server": (host.decode("ascii"), port), - "client": self.client, - "root_path": self.root_path, - } - - # Request. - request_body_chunks = stream.__aiter__() - request_complete = False - - # Response. - status_code: Optional[int] = None - response_headers: Optional[List[Tuple[bytes, bytes]]] = None - produce_body, aclose_body, consume_body = create_channel(1) - response_started_or_app_crashed = create_event() - response_complete = create_event() - - # ASGI callables. - - async def receive() -> dict: - nonlocal request_complete - - if request_complete: - await response_complete.wait() - return {"type": "http.disconnect"} - - try: - body = await request_body_chunks.__anext__() - except StopAsyncIteration: - request_complete = True - return {"type": "http.request", "body": b"", "more_body": False} - return {"type": "http.request", "body": body, "more_body": True} - - async def send(message: dict) -> None: - nonlocal status_code, response_headers - if message["type"] == "http.response.start": - assert not response_started_or_app_crashed.is_set() - status_code = message["status"] - response_headers = message.get("headers", []) - response_started_or_app_crashed.set() - - elif message["type"] == "http.response.body": - assert not response_complete.is_set() - body = message.get("body", b"") - more_body = message.get("more_body", False) - - if body and method != b"HEAD": - await produce_body(body) - - if not more_body: - await aclose_body() - response_complete.set() - - # Application wrapper. - - app_exception: Optional[Exception] = None - - async def run_app() -> None: - nonlocal app_exception - try: - await self.app(scope, receive, send) - except Exception as exc: - app_exception = exc - response_started_or_app_crashed.set() - await aclose_body() # Stop response body consumer once flushed (*). - - # Response body iterator. - - async def aiter_response_body() -> AsyncIterator[bytes]: - async for chunk in consume_body(): # (*) - yield chunk - - if app_exception is not None and self.raise_app_exceptions: - raise app_exception - - # Now we wire things up... - - aclose = await create_background_task(run_app) - - await response_started_or_app_crashed.wait() - - if app_exception is not None: - await aclose() - if self.raise_app_exceptions or not response_complete.is_set(): - raise app_exception + app_context = run_asgi( + self.app, + method, + url, + headers, + stream, + client=self.client, + root_path=self.root_path, + ) + + status_code, response_headers, response_body = await app_context.__aenter__() - assert status_code is not None - assert response_headers is not None + async def aclose() -> None: + await app_context.__aexit__(*sys.exc_info()) - stream = httpcore.AsyncIteratorByteStream( - aiter_response_body(), aclose_func=aclose - ) + stream = httpcore.AsyncIteratorByteStream(response_body, aclose_func=aclose) return (b"HTTP/1.1", status_code, b"", response_headers, stream) + + +@asynccontextmanager +async def run_asgi( + app: Callable, + method: bytes, + url: Tuple[bytes, bytes, Optional[int], bytes], + headers: List[Tuple[bytes, bytes]], + stream: httpcore.AsyncByteStream, + *, + client: str, + root_path: str, +) -> AsyncIterator[Tuple[int, List[Tuple[bytes, bytes]], AsyncIterator[bytes]]]: + import anyio + + # ASGI scope. + scheme, host, port, full_path = url + path, _, query = full_path.partition(b"?") + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method.decode(), + "headers": headers, + "scheme": scheme.decode("ascii"), + "path": path.decode("ascii"), + "query_string": query, + "server": (host.decode("ascii"), port), + "client": client, + "root_path": root_path, + } + + # Request. + request_body_chunks = stream.__aiter__() + request_complete = False + + # Response. + status_code: Optional[int] = None + response_headers: Optional[List[Tuple[bytes, bytes]]] = None + response_body_queue = anyio.create_queue(1) + response_started = anyio.create_event() + response_complete = anyio.create_event() + + async def receive() -> dict: + nonlocal request_complete + + if request_complete: + await response_complete.wait() + return {"type": "http.disconnect"} + + try: + body = await request_body_chunks.__anext__() + except StopAsyncIteration: + request_complete = True + return {"type": "http.request", "body": b"", "more_body": False} + return {"type": "http.request", "body": body, "more_body": True} + + async def send(message: dict) -> None: + nonlocal status_code, response_headers + if message["type"] == "http.response.start": + assert not response_started.is_set() + status_code = message["status"] + response_headers = message.get("headers", []) + await response_started.set() + + elif message["type"] == "http.response.body": + assert not response_complete.is_set() + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body and method != b"HEAD": + await response_body_queue.put(body) + + if not more_body: + await response_body_queue.put(None) + await response_complete.set() + + async def body_iterator() -> AsyncIterator[bytes]: + while True: + chunk = await response_body_queue.get() # (*) + if chunk is None: + break + yield chunk + + async with anyio.create_task_group() as task_group: + await task_group.spawn(app, scope, receive, send) + + await response_started.wait() + + assert status_code is not None + assert response_headers is not None + + yield status_code, response_headers, body_iterator() diff --git a/requirements.txt b/requirements.txt index a901dbeaa8..23d23bcdc1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ -e .[http2] # Optional +async_generator; python_version < '3.7' +anyio brotlipy==0.7.* # Documentation diff --git a/tests/test_asgi.py b/tests/test_asgi.py index b3b81bdec7..35d08b95ff 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -134,7 +134,7 @@ async def test_asgi_streaming_exc(): @pytest.mark.usefixtures("async_environment") async def test_asgi_streaming_exc_after_response(): client = httpx.AsyncClient(app=raise_exc_after_response) - async with client.stream("GET", "http://www.example.org/") as response: - with pytest.raises(ValueError): + with pytest.raises(ValueError): + async with client.stream("GET", "http://www.example.org/") as response: async for _ in response.aiter_bytes(): pass # pragma: no cover