Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for streaming responses to ASGITransport #2434

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 95 additions & 74 deletions httpx/_transports/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,82 +82,103 @@ async def handle_async_request(
self,
request: Request,
) -> Response:
assert isinstance(request.stream, AsyncByteStream)

# ASGI scope.
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"scheme": request.url.scheme,
"path": request.url.path,
"raw_path": request.url.raw_path,
"query_string": request.url.query,
"server": (request.url.host, request.url.port),
"client": self.client,
"root_path": self.root_path,
}

# Request.
request_body_chunks = request.stream.__aiter__()
request_complete = False

# Response.
status_code = None
response_headers = None
body_parts = []
response_started = False
response_complete = create_event()

# ASGI callables.

async def receive() -> dict:
nonlocal request_complete

if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}

try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers, response_started

if message["type"] == "http.response.start":
assert not response_started

status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and request.method != "HEAD":
body_parts.append(body)

if not more_body:
response_complete.set()
try:
import anyio # noqa
except ImportError: # pragma: no cover
raise ImportError("ASGITransport requires anyio. (Hint: pip install anyio)")

return await run_asgi(
request,
app=self.app,
raise_app_exceptions=self.raise_app_exceptions,
root_path=self.root_path,
client=self.client,
)


async def run_asgi(
request: Request,
app: typing.Callable,
raise_app_exceptions: bool,
root_path: str,
client: typing.Tuple[str, int],
) -> Response:
assert isinstance(request.stream, AsyncByteStream)

# ASGI scope.
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"scheme": request.url.scheme,
"path": request.url.path,
"raw_path": request.url.raw_path,
"query_string": request.url.query,
"server": (request.url.host, request.url.port),
"client": client,
"root_path": root_path,
}

# Request.
request_body_chunks = request.stream.__aiter__()
request_complete = False

# Response.
status_code = None
response_headers = None
body_parts = []
response_started = False
response_complete = create_event()

# ASGI callables.

async def receive() -> dict:
nonlocal request_complete

if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}

try:
await self.app(scope, receive, send)
except Exception: # noqa: PIE-786
if self.raise_app_exceptions or not response_complete.is_set():
raise
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers, response_started

if message["type"] == "http.response.start":
assert not response_started

status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and request.method != "HEAD":
body_parts.append(body)

if not more_body:
response_complete.set()

try:
await app(scope, receive, send)
except Exception: # noqa: PIE-786
if raise_app_exceptions or not response_complete.is_set():
raise

assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None
assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None

stream = ASGIResponseStream(body_parts)
stream = ASGIResponseStream(body_parts)

return Response(status_code, headers=response_headers, stream=stream)
return Response(status_code, headers=response_headers, stream=stream)
26 changes: 26 additions & 0 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,29 @@ async def read_body(scope, receive, send):

assert response.status_code == 200
assert disconnect


@pytest.mark.usefixtures("async_environment")
async def test_asgi_streaming():
client = httpx.AsyncClient(app=hello_world)
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
text = "".join([chunk async for chunk in response.aiter_text()])
assert text == "Hello, World!"


@pytest.mark.usefixtures("async_environment")
async def test_asgi_streaming_exc():
client = httpx.AsyncClient(app=raise_exc)
with pytest.raises(ValueError):
async with client.stream("GET", "http://www.example.org/"):
pass # pragma: no cover


@pytest.mark.usefixtures("async_environment")
async def test_asgi_streaming_exc_after_response():
client = httpx.AsyncClient(app=raise_exc_after_response)
with pytest.raises(ValueError):
async with client.stream("GET", "http://www.example.org/") as response:
async for _ in response.aiter_bytes():
pass # pragma: no cover