Skip to content

Commit

Permalink
ASGI: Wait for response to complete before sending disconnect message (
Browse files Browse the repository at this point in the history
…#919)

* asgi: Wait for response to complete before sending disconnect message

* Dial back type checking + remove concurrency module

* Remove somewhat redundant comment
  • Loading branch information
JayH5 committed May 12, 2020
1 parent 560b119 commit d568ecd
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
34 changes: 28 additions & 6 deletions httpx/_dispatch/asgi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
import typing
from typing import Callable, Dict, List, Optional, Tuple

import httpcore
import sniffio

from .._content_streams import ByteStream

if typing.TYPE_CHECKING: # pragma: no cover
import asyncio
import trio

Event = typing.Union[asyncio.Event, trio.Event]


def create_event() -> "Event":
if sniffio.current_async_library() == "trio":
import trio

return trio.Event()
else:
import asyncio

return asyncio.Event()


class ASGIDispatch(httpcore.AsyncHTTPTransport):
"""
Expand Down Expand Up @@ -76,23 +95,26 @@ async def request(
status_code = None
response_headers = None
body_parts = []
request_complete = False
response_started = False
response_complete = False
response_complete = create_event()

headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream

request_body_chunks = stream.__aiter__()

async def receive() -> dict:
nonlocal response_complete
nonlocal request_complete, response_complete

if response_complete:
if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}

try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

Expand All @@ -108,23 +130,23 @@ async def send(message: dict) -> None:
response_started = True

elif message["type"] == "http.response.body":
assert not response_complete
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and method != b"HEAD":
body_parts.append(body)

if not more_body:
response_complete = True
response_complete.set()

try:
await self.app(scope, receive, send)
except Exception:
if self.raise_app_exceptions or not response_complete:
raise

assert response_complete
assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pytest-asyncio
pytest-trio
pytest-cov
trio
trio-typing
trustme
uvicorn
seed-isort-config
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def get_packages(package):
"idna==2.*",
"rfc3986>=1.3,<2",
"httpcore>=0.8.3",
"sniffio",
],
classifiers=[
"Development Status :: 4 - Beta",
Expand Down
3 changes: 1 addition & 2 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ async def test_asgi_exc_after_response():
await client.get("http://www.example.org/")


@pytest.mark.asyncio
async def test_asgi_disconnect_after_response_complete():
async def test_asgi_disconnect_after_response_complete(async_environment):
disconnect = False

async def read_body(scope, receive, send):
Expand Down

0 comments on commit d568ecd

Please sign in to comment.