Fix unclosed generator when running on trio #2587

Closed · wants to merge 2 commits
5 changes: 2 additions & 3 deletions httpx/_client.py
@@ -142,9 +142,8 @@ def __init__(
self._response = response
self._timer = timer

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async for chunk in self._stream:
yield chunk
def __aiter__(self) -> typing.AsyncIterator[bytes]:
return self._stream.__aiter__()

async def aclose(self) -> None:
seconds = await self._timer.async_elapsed()
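The `_client.py` change above swaps a wrapping async generator for plain delegation to the underlying stream's `__aiter__`. The wrapper was a second async generator object; if the consumer stopped iterating early, that wrapper could be left suspended and only finalized by the garbage collector, which is what surfaces as the unclosed-generator error on trio. A minimal sketch of the two shapes (the class names `Wrapping` and `Delegating` are illustrative, not from the diff):

```python
import typing


class Wrapping:
    """Old shape: re-yields chunks through an extra async generator."""

    def __init__(self, stream: typing.AsyncIterable[bytes]) -> None:
        self._stream = stream

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # Calling this creates a new async generator that must itself be closed.
        async for chunk in self._stream:
            yield chunk


class Delegating:
    """New shape: hands back the underlying iterator, no extra generator."""

    def __init__(self, stream: typing.AsyncIterable[bytes]) -> None:
        self._stream = stream

    def __aiter__(self) -> typing.AsyncIterator[bytes]:
        return self._stream.__aiter__()
```

With the delegating form, closing the response stream closes the only real iterator; there is no intermediate generator left behind for trio to warn about.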
48 changes: 26 additions & 22 deletions httpx/_models.py
@@ -4,6 +4,7 @@
import typing
import urllib.request
from collections.abc import Mapping
from contextlib import aclosing
from http.cookiejar import Cookie, CookieJar

from ._content import ByteStream, UnattachedStream, encode_request, encode_response
@@ -911,7 +912,7 @@ async def aread(self) -> bytes:

async def aiter_bytes(
self, chunk_size: typing.Optional[int] = None
) -> typing.AsyncIterator[bytes]:
) -> typing.AsyncGenerator[bytes, None]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
@@ -924,19 +925,20 @@ async def aiter_bytes(
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with request_context(request=self._request):
async for raw_bytes in self.aiter_raw():
decoded = decoder.decode(raw_bytes)
async with aclosing(self.aiter_raw()) as stream:
async for raw_bytes in stream:
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
yield chunk
decoded = decoder.flush()
for chunk in chunker.decode(decoded):
yield chunk # pragma: no cover
for chunk in chunker.flush():
yield chunk
decoded = decoder.flush()
for chunk in chunker.decode(decoded):
yield chunk # pragma: no cover
for chunk in chunker.flush():
yield chunk

async def aiter_text(
self, chunk_size: typing.Optional[int] = None
) -> typing.AsyncIterator[str]:
) -> typing.AsyncGenerator[str, None]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
@@ -945,28 +947,30 @@ async def aiter_text(
decoder = TextDecoder(encoding=self.encoding or "utf-8")
chunker = TextChunker(chunk_size=chunk_size)
with request_context(request=self._request):
async for byte_content in self.aiter_bytes():
text_content = decoder.decode(byte_content)
async with aclosing(self.aiter_bytes()) as stream:
async for byte_content in stream:
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
yield chunk
text_content = decoder.flush()
for chunk in chunker.decode(text_content):
yield chunk
text_content = decoder.flush()
for chunk in chunker.decode(text_content):
yield chunk
for chunk in chunker.flush():
yield chunk
for chunk in chunker.flush():
yield chunk

async def aiter_lines(self) -> typing.AsyncIterator[str]:
async def aiter_lines(self) -> typing.AsyncGenerator[str, None]:
decoder = LineDecoder()
with request_context(request=self._request):
async for text in self.aiter_text():
for line in decoder.decode(text):
async with aclosing(self.aiter_text()) as stream:
async for text in stream:
for line in decoder.decode(text):
yield line
for line in decoder.flush():
yield line
for line in decoder.flush():
yield line

async def aiter_raw(
self, chunk_size: typing.Optional[int] = None
) -> typing.AsyncIterator[bytes]:
) -> typing.AsyncGenerator[bytes, None]:
"""
A byte-iterator over the raw response content.
"""
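The `_models.py` changes wrap each nested iterator in `contextlib.aclosing` (added in Python 3.10), which guarantees the inner async generator's `aclose()` runs when the `async with` block exits, even if the consumer breaks out of the loop early, instead of deferring cleanup to garbage collection. A standalone sketch of that behaviour, independent of the httpx internals:

```python
import asyncio
import typing
from contextlib import aclosing


async def numbers() -> typing.AsyncGenerator[int, None]:
    try:
        for i in range(5):
            yield i
    finally:
        # Runs deterministically when the generator is closed.
        print("numbers() closed")


async def main() -> None:
    async with aclosing(numbers()) as gen:
        async for n in gen:
            if n == 1:
                break  # aclosing() still awaits gen.aclose() on exit.


asyncio.run(main())
```

In `aiter_bytes`, `aiter_text`, and `aiter_lines` the same pattern keeps the whole chain (`aiter_raw` → `aiter_bytes` → `aiter_text`) closed promptly when a caller stops early, as the new `test_stream_iterator_partial` tests do.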
10 changes: 6 additions & 4 deletions httpx/_transports/default.py
@@ -232,12 +232,14 @@ def close(self) -> None:

class AsyncResponseStream(AsyncByteStream):
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]):
self._httpcore_stream = httpcore_stream
self._httpcore_stream = httpcore_stream.__aiter__()

def __aiter__(self) -> typing.AsyncIterator[bytes]:
return self

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __anext__(self) -> bytes:
with map_httpcore_exceptions():
async for part in self._httpcore_stream:
yield part
return await self._httpcore_stream.__anext__()

async def aclose(self) -> None:
if hasattr(self._httpcore_stream, "aclose"):
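In `_transports/default.py`, `AsyncResponseStream` now implements the async-iterator protocol directly: `__aiter__` returns the object itself and each `__anext__` call pulls one chunk from the iterator obtained up front from the httpcore stream, with the exception mapping applied per call. Here is the same protocol in isolation, as a hedged sketch (class name and details are illustrative, not the httpx implementation):

```python
import typing


class PassthroughStream:
    """Async iterator over another iterable, without a generator layer."""

    def __init__(self, source: typing.AsyncIterable[bytes]) -> None:
        # Obtain the real iterator once, up front.
        self._iterator = source.__aiter__()

    def __aiter__(self) -> typing.AsyncIterator[bytes]:
        return self

    async def __anext__(self) -> bytes:
        # Per-chunk work (httpcore exception translation in the real class)
        # happens around this single await.
        return await self._iterator.__anext__()

    async def aclose(self) -> None:
        aclose = getattr(self._iterator, "aclose", None)
        if aclose is not None:
            await aclose()
```

Because no intermediate async generator is created, breaking out of `async for` over this stream leaves nothing that needs separate finalization; `StopAsyncIteration` propagates straight from the underlying iterator.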
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ classifiers = [
]
dependencies = [
"certifi",
"httpcore>=0.15.0,<0.17.0",
"httpcore==git+https://github.com/encode/httpcore.git@bug/async-early-stream-break",
"idna",
"sniffio",
]
29 changes: 29 additions & 0 deletions tests/client/test_async_client.py
@@ -1,4 +1,5 @@
import typing
from contextlib import aclosing
from datetime import timedelta

import pytest
@@ -76,6 +77,34 @@ async def test_stream_response(server):
assert response.content == b"Hello, world!"


@pytest.mark.anyio
async def test_stream_iterator(server):
body = b""

async with httpx.AsyncClient() as client:
async with client.stream("GET", server.url) as response:
async for chunk in response.aiter_bytes():
body += chunk

assert response.status_code == 200
assert body == b"Hello, world!"


@pytest.mark.anyio
async def test_stream_iterator_partial(server):
body = ""

async with httpx.AsyncClient() as client:
async with client.stream("GET", server.url) as response:
async with aclosing(response.aiter_text(5)) as stream:
async for chunk in stream:
body += chunk
break

assert response.status_code == 200
assert body == "Hello"


@pytest.mark.anyio
async def test_access_content_stream_response(server):
async with httpx.AsyncClient() as client:
13 changes: 13 additions & 0 deletions tests/client/test_client.py
@@ -107,6 +107,19 @@ def test_stream_iterator(server):
assert body == b"Hello, world!"


def test_stream_iterator_partial(server):
body = ""

with httpx.Client() as client:
with client.stream("GET", server.url) as response:
for chunk in response.iter_text(5):
body += chunk
break

assert response.status_code == 200
assert body == "Hello"


def test_raw_iterator(server):
body = b""
