Skip to content

Commit

Permalink
Accept AsyncIterables being passed to Response
Browse files Browse the repository at this point in the history
  • Loading branch information
mjsir911 committed May 20, 2024
1 parent 2fc6d4f commit a6bf4ae
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/quart/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AnyStr,
AsyncContextManager,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Dict,
Expand Down
17 changes: 13 additions & 4 deletions src/quart/wrappers/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,28 @@ async def __anext__(self) -> bytes:


class IterableBody(ResponseBody):
def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None:
def __init__(self, iterable: AsyncIterable[bytes] | Iterable) -> None:
self.iter: AsyncGenerator[bytes, None]
if isasyncgen(iterable):
self.iter = iterable
elif isgenerator(iterable):
self.iter = run_sync_iterable(iterable)
else:
elif isinstance(iterable, AsyncIterable):

async def _aiter() -> AsyncGenerator[bytes, None]:
for data in iterable: # type: ignore
async for data in iterable:
yield data

self.iter = _aiter()
elif isinstance(iterable, Iterable):

async def _aiter() -> AsyncGenerator[bytes, None]:
for data in iterable:
yield data

self.iter = _aiter()
else:
raise ValueError("unreachable?")

async def __aenter__(self) -> IterableBody:
return self
Expand Down Expand Up @@ -262,7 +271,7 @@ class Response(SansIOResponse):

def __init__(
self,
response: ResponseBody | AnyStr | Iterable | None = None,
response: ResponseBody | AnyStr | Iterable | AsyncIterable | None = None,
status: int | None = None,
headers: dict | Headers | None = None,
mimetype: str | None = None,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
g,
Quart,
render_template_string,
Response,
ResponseReturnValue,
session,
stream_template_string,
Expand Down Expand Up @@ -148,3 +149,11 @@ async def index() -> ResponseReturnValue:
test_client = app.test_client()
response = await test_client.get("/")
assert (await response.data) == b"42"

@app.get("/2")
async def index2() -> ResponseReturnValue:
return Response(await stream_template_string("{{ config }}", config=43))

test_client = app.test_client()
response = await test_client.get("/2")
assert (await response.data) == b"43"

0 comments on commit a6bf4ae

Please sign in to comment.