anyio solution
florimondmanca committed Aug 8, 2020
1 parent ccc2d19 commit 111ed19
Showing 3 changed files with 123 additions and 213 deletions.
330 changes: 119 additions & 211 deletions httpx/_transports/asgi.py
@@ -1,116 +1,12 @@
import contextlib
from typing import (
TYPE_CHECKING,
AsyncIterator,
Awaitable,
Callable,
List,
Mapping,
Optional,
Tuple,
Union,
)
import sys
from typing import AsyncIterator, Callable, List, Mapping, Optional, Tuple

import httpcore
import sniffio

if TYPE_CHECKING: # pragma: no cover
import asyncio

import trio

Event = Union[asyncio.Event, trio.Event]


def create_event() -> "Event":
if sniffio.current_async_library() == "trio":
import trio

return trio.Event()
else:
import asyncio

return asyncio.Event()


async def create_background_task(
async_fn: Callable[[], Awaitable[None]]
) -> Callable[[], Awaitable[None]]:
if sniffio.current_async_library() == "trio":
import trio

nursery_manager = trio.open_nursery()
nursery = await nursery_manager.__aenter__()
nursery.start_soon(async_fn)

async def aclose() -> None:
await nursery_manager.__aexit__(None, None, None)

return aclose

else:
import asyncio

loop = asyncio.get_event_loop()
task = loop.create_task(async_fn())

async def aclose() -> None:
task.cancel()
# Task must be awaited in all cases to avoid debug warnings.
with contextlib.suppress(asyncio.CancelledError):
await task

return aclose


def create_channel(
capacity: int,
) -> Tuple[
Callable[[bytes], Awaitable[None]],
Callable[[], Awaitable[None]],
Callable[[], AsyncIterator[bytes]],
]:
"""
Create an in-memory channel to pass data chunks between tasks.
* `produce()`: send data through the channel, blocking if necessary.
* `consume()`: iterate over data in the channel.
* `aclose_produce()`: mark that no more data will be produced, causing
`consume()` to flush remaining data chunks then stop.
"""
if sniffio.current_async_library() == "trio":
import trio

send_channel, receive_channel = trio.open_memory_channel[bytes](capacity)

async def consume() -> AsyncIterator[bytes]:
async for chunk in receive_channel:
yield chunk

return send_channel.send, send_channel.aclose, consume

else:
import asyncio

queue: asyncio.Queue[bytes] = asyncio.Queue(capacity)
produce_closed = False

async def produce(chunk: bytes) -> None:
assert not produce_closed
await queue.put(chunk)

async def aclose_produce() -> None:
nonlocal produce_closed
await queue.put(b"") # Make sure (*) doesn't block forever.
produce_closed = True

async def consume() -> AsyncIterator[bytes]:
while True:
if produce_closed and queue.empty():
break
yield await queue.get() # (*)

return produce, aclose_produce, consume
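
For context, here is a minimal sketch of how these three pre-anyio helpers (removed by this commit) fit together, following the docstring above. It is illustrative only, not part of the diff, and assumes an asyncio event loop with create_channel and create_background_task defined as shown:

import asyncio

async def example() -> None:
    # The helpers below are the ones defined (and removed) in this hunk.
    produce, aclose_produce, consume = create_channel(1)

    async def producer() -> None:
        await produce(b"hello ")
        await produce(b"world")
        await aclose_produce()  # lets consume() flush remaining chunks, then stop

    aclose_task = await create_background_task(producer)
    chunks = [chunk async for chunk in consume()]
    await aclose_task()
    assert b"".join(chunks) == b"hello world"

asyncio.run(example())
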
try:
from contextlib import asynccontextmanager # type: ignore # Python 3.6.
except ImportError: # pragma: no cover # Python 3.6.
from async_generator import asynccontextmanager # type: ignore


class ASGITransport(httpcore.AsyncHTTPTransport):
@@ -153,6 +49,11 @@ def __init__(
root_path: str = "",
client: Tuple[str, int] = ("127.0.0.1", 123),
) -> None:
try:
import anyio # noqa
except ImportError:
raise ImportError("ASGITransport requires anyio. (Hint: pip install anyio)")

self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
@@ -166,111 +67,118 @@ async def request(
stream: httpcore.AsyncByteStream = None,
timeout: Mapping[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:

headers = [] if headers is None else headers
stream = httpcore.PlainByteStream(content=b"") if stream is None else stream

# ASGI scope.
scheme, host, port, full_path = url
path, _, query = full_path.partition(b"?")
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": method.decode(),
"headers": headers,
"scheme": scheme.decode("ascii"),
"path": path.decode("ascii"),
"query_string": query,
"server": (host.decode("ascii"), port),
"client": self.client,
"root_path": self.root_path,
}

# Request.
request_body_chunks = stream.__aiter__()
request_complete = False

# Response.
status_code: Optional[int] = None
response_headers: Optional[List[Tuple[bytes, bytes]]] = None
produce_body, aclose_body, consume_body = create_channel(1)
response_started_or_app_crashed = create_event()
response_complete = create_event()

# ASGI callables.

async def receive() -> dict:
nonlocal request_complete

if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}

try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers
if message["type"] == "http.response.start":
assert not response_started_or_app_crashed.is_set()
status_code = message["status"]
response_headers = message.get("headers", [])
response_started_or_app_crashed.set()

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and method != b"HEAD":
await produce_body(body)

if not more_body:
await aclose_body()
response_complete.set()

# Application wrapper.

app_exception: Optional[Exception] = None

async def run_app() -> None:
nonlocal app_exception
try:
await self.app(scope, receive, send)
except Exception as exc:
app_exception = exc
response_started_or_app_crashed.set()
await aclose_body() # Stop response body consumer once flushed (*).

# Response body iterator.

async def aiter_response_body() -> AsyncIterator[bytes]:
async for chunk in consume_body(): # (*)
yield chunk

if app_exception is not None and self.raise_app_exceptions:
raise app_exception

# Now we wire things up...

aclose = await create_background_task(run_app)

await response_started_or_app_crashed.wait()

if app_exception is not None:
await aclose()
if self.raise_app_exceptions or not response_complete.is_set():
raise app_exception
app_context = run_asgi(
self.app,
method,
url,
headers,
stream,
client=self.client,
root_path=self.root_path,
)

status_code, response_headers, response_body = await app_context.__aenter__()

assert status_code is not None
assert response_headers is not None
async def aclose() -> None:
await app_context.__aexit__(*sys.exc_info())

stream = httpcore.AsyncIteratorByteStream(
aiter_response_body(), aclose_func=aclose
)
stream = httpcore.AsyncIteratorByteStream(response_body, aclose_func=aclose)

return (b"HTTP/1.1", status_code, b"", response_headers, stream)


@asynccontextmanager
async def run_asgi(
app: Callable,
method: bytes,
url: Tuple[bytes, bytes, Optional[int], bytes],
headers: List[Tuple[bytes, bytes]],
stream: httpcore.AsyncByteStream,
*,
client: Tuple[str, int],
root_path: str,
) -> AsyncIterator[Tuple[int, List[Tuple[bytes, bytes]], AsyncIterator[bytes]]]:
import anyio

# ASGI scope.
scheme, host, port, full_path = url
path, _, query = full_path.partition(b"?")
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": method.decode(),
"headers": headers,
"scheme": scheme.decode("ascii"),
"path": path.decode("ascii"),
"query_string": query,
"server": (host.decode("ascii"), port),
"client": client,
"root_path": root_path,
}

# Request.
request_body_chunks = stream.__aiter__()
request_complete = False

# Response.
status_code: Optional[int] = None
response_headers: Optional[List[Tuple[bytes, bytes]]] = None
response_body_queue = anyio.create_queue(1)
response_started = anyio.create_event()
response_complete = anyio.create_event()

async def receive() -> dict:
nonlocal request_complete

if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}

try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers
if message["type"] == "http.response.start":
assert not response_started.is_set()
status_code = message["status"]
response_headers = message.get("headers", [])
await response_started.set()

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and method != b"HEAD":
await response_body_queue.put(body)

if not more_body:
await response_body_queue.put(None)
await response_complete.set()

async def body_iterator() -> AsyncIterator[bytes]:
while True:
chunk = await response_body_queue.get() # (*)
if chunk is None:
break
yield chunk

async with anyio.create_task_group() as task_group:
await task_group.spawn(app, scope, receive, send)

await response_started.wait()

assert status_code is not None
assert response_headers is not None

yield status_code, response_headers, body_iterator()
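
The replacement above leans on three anyio 1.x primitives: a task group to run the ASGI app in the background, an event to signal that the response has started, and a bounded queue to stream body chunks with backpressure. A minimal standalone sketch of that pattern, illustrative only (the run_app name is invented here, and later anyio releases rename these APIs to anyio.Event(), memory object streams, and task_group.start_soon()):

import anyio  # anyio 1.x API, matching the diff above

async def main() -> None:
    started = anyio.create_event()
    chunks = anyio.create_queue(1)  # bounded, so the producer applies backpressure

    async def run_app() -> None:
        await started.set()            # analogous to "http.response.start"
        await chunks.put(b"partial")   # analogous to "http.response.body"
        await chunks.put(None)         # sentinel: no more body

    async with anyio.create_task_group() as task_group:
        await task_group.spawn(run_app)
        await started.wait()           # block until the app has started responding
        while True:
            chunk = await chunks.get()
            if chunk is None:
                break
            print(chunk)

anyio.run(main)
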
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,6 +1,8 @@
-e .[http2]

# Optional
async_generator; python_version < '3.7'
anyio
brotlipy==0.7.*

# Documentation
4 changes: 2 additions & 2 deletions tests/test_asgi.py
@@ -134,7 +134,7 @@ async def test_asgi_streaming_exc():
@pytest.mark.usefixtures("async_environment")
async def test_asgi_streaming_exc_after_response():
client = httpx.AsyncClient(app=raise_exc_after_response)
async with client.stream("GET", "http://www.example.org/") as response:
with pytest.raises(ValueError):
with pytest.raises(ValueError):
async with client.stream("GET", "http://www.example.org/") as response:
async for _ in response.aiter_bytes():
pass # pragma: no cover
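
For reference, the transport in this commit is exercised the way the test above does it: pass an ASGI app to httpx.AsyncClient and stream the response. A hedged sketch of that usage, with the hello_world app invented for illustration:

import asyncio
import httpx

async def hello_world(scope, receive, send):
    # Minimal ASGI app, invented for this example.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"Hello, world!", "more_body": False})

async def example() -> None:
    async with httpx.AsyncClient(app=hello_world) as client:
        async with client.stream("GET", "http://www.example.org/") as response:
            body = b"".join([chunk async for chunk in response.aiter_bytes()])
    assert response.status_code == 200
    assert body == b"Hello, world!"

asyncio.run(example())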
