Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for streaming in ASGIDispatch #2588

Draft
wants to merge 1 commit into
base: bug/async-early-stream-break
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 147 additions & 76 deletions httpx/_transports/asgi.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import typing
from contextlib import AsyncExitStack, asynccontextmanager

import sniffio

from .._models import Request, Response
from .._types import AsyncByteStream
from .base import AsyncBaseTransport

try:
import anyio
except ImportError: # pragma: no cover
anyio = None # type: ignore


if typing.TYPE_CHECKING: # pragma: no cover
import asyncio

Expand Down Expand Up @@ -35,12 +42,19 @@ def create_event() -> "Event":
return asyncio.Event()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks to me like we can drop this create_event() function and the preceding if typing.TYPE_CHECKING block?



class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: typing.List[bytes]) -> None:
self._body = body
class ASGIResponseByteStream(AsyncByteStream):
def __init__(
self, stream: typing.AsyncGenerator[bytes, None], app_context: AsyncExitStack
) -> None:
self._stream = stream
self._app_context = app_context

def __aiter__(self) -> typing.AsyncIterator[bytes]:
return self._stream.__aiter__()

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield b"".join(self._body)
async def aclose(self) -> None:
await self._stream.aclose()
await self._app_context.aclose()


class ASGITransport(AsyncBaseTransport):
Expand Down Expand Up @@ -83,6 +97,9 @@ def __init__(
root_path: str = "",
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
) -> None:
if anyio is None:
raise RuntimeError("ASGITransport requires anyio (Hint: pip install anyio)")

self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
Expand All @@ -92,82 +109,136 @@ async def handle_async_request(
self,
request: Request,
) -> Response:
assert isinstance(request.stream, AsyncByteStream)

# ASGI scope.
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"scheme": request.url.scheme,
"path": request.url.path,
"raw_path": request.url.raw_path,
"query_string": request.url.query,
"server": (request.url.host, request.url.port),
"client": self.client,
"root_path": self.root_path,
}

# Request.
request_body_chunks = request.stream.__aiter__()
request_complete = False

# Response.
status_code = None
response_headers = None
body_parts = []
response_started = False
response_complete = create_event()

# ASGI callables.

async def receive() -> typing.Dict[str, typing.Any]:
nonlocal request_complete

if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}

try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: typing.Dict[str, typing.Any]) -> None:
nonlocal status_code, response_headers, response_started

if message["type"] == "http.response.start":
assert not response_started

status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and request.method != "HEAD":
body_parts.append(body)

if not more_body:
response_complete.set()

exit_stack = AsyncExitStack()

(
status_code,
response_headers,
response_body,
) = await exit_stack.enter_async_context(
run_asgi(
self.app,
raise_app_exceptions=self.raise_app_exceptions,
root_path=self.root_path,
client=self.client,
request=request,
)
)

return Response(
status_code,
headers=response_headers,
stream=ASGIResponseByteStream(response_body, exit_stack),
)


@asynccontextmanager
async def run_asgi(
app: _ASGIApp,
raise_app_exceptions: bool,
client: typing.Tuple[str, int],
root_path: str,
request: Request,
) -> typing.AsyncIterator[
typing.Tuple[
int,
typing.Sequence[typing.Tuple[bytes, bytes]],
typing.AsyncGenerator[bytes, None],
]
]:
# ASGI scope.
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"scheme": request.url.scheme,
"path": request.url.path,
"raw_path": request.url.raw_path,
"query_string": request.url.query,
"server": (request.url.host, request.url.port),
"client": client,
"root_path": root_path,
}

# Request.
assert isinstance(request.stream, AsyncByteStream)
request_body_chunks = request.stream.__aiter__()
request_complete = False

# Response.
status_code = None
response_headers = None
response_started = anyio.Event()
response_complete = anyio.Event()

send_stream, receive_stream = anyio.create_memory_object_stream()
disconnected = anyio.Event()

async def watch_disconnect(cancel_scope: anyio.CancelScope) -> None:
await disconnected.wait()
cancel_scope.cancel()

async def run_app(cancel_scope: anyio.CancelScope) -> None:
try:
await self.app(scope, receive, send)
await app(scope, receive, send)
except Exception: # noqa: PIE-786
if self.raise_app_exceptions or not response_complete.is_set():
if raise_app_exceptions or not response_complete.is_set():
raise

assert response_complete.is_set()
# ASGI callables.

async def receive() -> typing.Dict[str, typing.Any]:
nonlocal request_complete

if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}

try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: _Message) -> None:
nonlocal status_code, response_headers

if disconnected.is_set():
return

if message["type"] == "http.response.start":
assert not response_started.is_set()

status_code = message["status"]
response_headers = message.get("headers", [])
response_started.set()

elif message["type"] == "http.response.body":
assert response_started.is_set()
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and request.method != "HEAD":
await send_stream.send(body)

if not more_body:
response_complete.set()

async with anyio.create_task_group() as tg:
tg.start_soon(watch_disconnect, tg.cancel_scope)
tg.start_soon(run_app, tg.cancel_scope)

await response_started.wait()
assert status_code is not None
assert response_headers is not None

stream = ASGIResponseStream(body_parts)
async def stream() -> typing.AsyncGenerator[bytes, None]:
async for chunk in receive_stream:
yield chunk

return Response(status_code, headers=response_headers, stream=stream)
yield (status_code, response_headers, stream())
disconnected.set()
52 changes: 52 additions & 0 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from contextlib import aclosing

import pytest

Expand Down Expand Up @@ -56,6 +57,20 @@ async def echo_headers(scope, receive, send):
await send({"type": "http.response.body", "body": output})


async def hello_world_endlessly(scope, receive, send):
status = 200
output = b"Hello, World!"
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]

await send({"type": "http.response.start", "status": status, "headers": headers})

k = 0
while True:
body = b"%d: %s\n" % (k, output)
await send({"type": "http.response.body", "body": body, "more_body": True})
k += 1


async def raise_exc(scope, receive, send):
raise RuntimeError()

Expand Down Expand Up @@ -191,3 +206,40 @@ async def read_body(scope, receive, send):

assert response.status_code == 200
assert disconnect


@pytest.mark.anyio
async def test_asgi_streaming():
client = httpx.AsyncClient(app=hello_world_endlessly)
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
lines = []

async with aclosing(response.aiter_lines()) as stream:
async for line in stream:
if line.startswith("3: "):
break
lines.append(line)

assert lines == [
"0: Hello, World!\n",
"1: Hello, World!\n",
"2: Hello, World!\n",
]


@pytest.mark.anyio
async def test_asgi_streaming_exc():
client = httpx.AsyncClient(app=raise_exc)
with pytest.raises(RuntimeError):
async with client.stream("GET", "http://www.example.org/"):
pass # pragma: no cover


@pytest.mark.anyio
async def test_asgi_streaming_exc_after_response():
client = httpx.AsyncClient(app=raise_exc_after_response)
with pytest.raises(RuntimeError):
async with client.stream("GET", "http://www.example.org/") as response:
async for _ in response.aiter_bytes():
pass # pragma: no cover