diff --git a/httpx/_client.py b/httpx/_client.py index 1f9f3beb56..5702f490b1 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -142,9 +142,8 @@ def __init__( self._response = response self._timer = timer - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - async for chunk in self._stream: - yield chunk + def __aiter__(self) -> typing.AsyncIterator[bytes]: + return self._stream.__aiter__() async def aclose(self) -> None: seconds = await self._timer.async_elapsed() diff --git a/httpx/_models.py b/httpx/_models.py index e0e5278cc0..f6fe5a172f 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -4,6 +4,7 @@ import typing import urllib.request from collections.abc import Mapping +from contextlib import aclosing from http.cookiejar import Cookie, CookieJar from ._content import ByteStream, UnattachedStream, encode_request, encode_response @@ -911,7 +912,7 @@ async def aread(self) -> bytes: async def aiter_bytes( self, chunk_size: typing.Optional[int] = None - ) -> typing.AsyncIterator[bytes]: + ) -> typing.AsyncGenerator[bytes, None]: """ A byte-iterator over the decoded response content. This allows us to handle gzip, deflate, and brotli encoded responses. @@ -924,19 +925,20 @@ async def aiter_bytes( decoder = self._get_content_decoder() chunker = ByteChunker(chunk_size=chunk_size) with request_context(request=self._request): - async for raw_bytes in self.aiter_raw(): - decoded = decoder.decode(raw_bytes) + async with aclosing(self.aiter_raw()) as stream: + async for raw_bytes in stream: + decoded = decoder.decode(raw_bytes) + for chunk in chunker.decode(decoded): + yield chunk + decoded = decoder.flush() for chunk in chunker.decode(decoded): + yield chunk # pragma: no cover + for chunk in chunker.flush(): yield chunk - decoded = decoder.flush() - for chunk in chunker.decode(decoded): - yield chunk # pragma: no cover - for chunk in chunker.flush(): - yield chunk async def aiter_text( self, chunk_size: typing.Optional[int] = None - ) -> typing.AsyncIterator[str]: + ) -> typing.AsyncGenerator[str, None]: """ A str-iterator over the decoded response content that handles both gzip, deflate, etc but also detects the content's @@ -945,28 +947,30 @@ async def aiter_text( decoder = TextDecoder(encoding=self.encoding or "utf-8") chunker = TextChunker(chunk_size=chunk_size) with request_context(request=self._request): - async for byte_content in self.aiter_bytes(): - text_content = decoder.decode(byte_content) + async with aclosing(self.aiter_bytes()) as stream: + async for byte_content in stream: + text_content = decoder.decode(byte_content) + for chunk in chunker.decode(text_content): + yield chunk + text_content = decoder.flush() for chunk in chunker.decode(text_content): yield chunk - text_content = decoder.flush() - for chunk in chunker.decode(text_content): - yield chunk - for chunk in chunker.flush(): - yield chunk + for chunk in chunker.flush(): + yield chunk - async def aiter_lines(self) -> typing.AsyncIterator[str]: + async def aiter_lines(self) -> typing.AsyncGenerator[str, None]: decoder = LineDecoder() with request_context(request=self._request): - async for text in self.aiter_text(): - for line in decoder.decode(text): + async with aclosing(self.aiter_text()) as stream: + async for text in stream: + for line in decoder.decode(text): + yield line + for line in decoder.flush(): yield line - for line in decoder.flush(): - yield line async def aiter_raw( self, chunk_size: typing.Optional[int] = None - ) -> typing.AsyncIterator[bytes]: + ) -> typing.AsyncGenerator[bytes, None]: """ A byte-iterator over the raw response content. """ diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index dfd274e7bf..888e28219e 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -232,12 +232,14 @@ def close(self) -> None: class AsyncResponseStream(AsyncByteStream): def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]): - self._httpcore_stream = httpcore_stream + self._httpcore_stream = httpcore_stream.__aiter__() + + def __aiter__(self) -> typing.AsyncIterator[bytes]: + return self - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __anext__(self) -> bytes: with map_httpcore_exceptions(): - async for part in self._httpcore_stream: - yield part + return await self._httpcore_stream.__anext__() async def aclose(self) -> None: if hasattr(self._httpcore_stream, "aclose"): diff --git a/pyproject.toml b/pyproject.toml index b11c02825b..b7cc224f78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ ] dependencies = [ "certifi", - "httpcore>=0.15.0,<0.17.0", + "httpcore==git+https://github.com/encode/httpcore.git@bug/async-early-stream-break", "idna", "sniffio", ] diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 5be0de3b12..2c681b553b 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -1,4 +1,5 @@ import typing +from contextlib import aclosing from datetime import timedelta import pytest @@ -76,6 +77,34 @@ async def test_stream_response(server): assert response.content == b"Hello, world!" +@pytest.mark.anyio +async def test_stream_iterator(server): + body = b"" + + async with httpx.AsyncClient() as client: + async with client.stream("GET", server.url) as response: + async for chunk in response.aiter_bytes(): + body += chunk + + assert response.status_code == 200 + assert body == b"Hello, world!" + + +@pytest.mark.anyio +async def test_stream_iterator_partial(server): + body = "" + + async with httpx.AsyncClient() as client: + async with client.stream("GET", server.url) as response: + async with aclosing(response.aiter_text(5)) as stream: + async for chunk in stream: + body += chunk + break + + assert response.status_code == 200 + assert body == "Hello" + + @pytest.mark.anyio async def test_access_content_stream_response(server): async with httpx.AsyncClient() as client: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 268cd10689..c35725ecfe 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -107,6 +107,19 @@ def test_stream_iterator(server): assert body == b"Hello, world!" +def test_stream_iterator_partial(server): + body = "" + + with httpx.Client() as client: + with client.stream("GET", server.url) as response: + for chunk in response.iter_text(5): + body += chunk + break + + assert response.status_code == 200 + assert body == "Hello" + + def test_raw_iterator(server): body = b""