Skip to content

Commit

Permalink
Stream interface (#1550)
Browse files Browse the repository at this point in the history
* Add SyncByteStream, AsyncByteStream to interface

* request.stream and response.stream as httpx.SyncByteStream/httpx.AsyncByteStream

* Update httpx/_transports/base.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>

* Update httpx/_transports/default.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>

* Move response classes in transports to module level

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
  • Loading branch information
tomchristie and florimondmanca committed Apr 13, 2021
1 parent 535df6c commit 110ce85
Show file tree
Hide file tree
Showing 13 changed files with 221 additions and 155 deletions.
4 changes: 2 additions & 2 deletions docs/advanced.md
Expand Up @@ -1070,7 +1070,7 @@ class HelloWorldTransport(httpx.BaseTransport):
def handle_request(self, method, url, headers, stream, extensions):
    # Respond with a fixed JSON body, wrapped in httpx.ByteStream so it
    # satisfies the transport API's stream interface.
    message = {"text": "Hello, world!"}
    content = json.dumps(message).encode("utf-8")
    stream = httpx.ByteStream(content)
    headers = [(b"content-type", b"application/json")]
    extensions = {}
    return 200, headers, stream, extensions
Expand Down Expand Up @@ -1131,7 +1131,7 @@ class HTTPSRedirectTransport(httpx.BaseTransport):
location = b"https://%s%s" % (host, path)
else:
location = b"https://%s:%d%s" % (host, port, path)
stream = [b""]
stream = httpx.ByteStream(b"")
headers = [(b"location", location)]
extensions = {}
return 303, headers, stream, extensions
Expand Down
11 changes: 10 additions & 1 deletion httpx/__init__.py
Expand Up @@ -3,6 +3,7 @@
from ._auth import Auth, BasicAuth, DigestAuth
from ._client import AsyncClient, Client
from ._config import Limits, Proxy, Timeout, create_ssl_context
from ._content import ByteStream
from ._exceptions import (
CloseError,
ConnectError,
Expand Down Expand Up @@ -36,7 +37,12 @@
from ._models import URL, Cookies, Headers, QueryParams, Request, Response
from ._status_codes import StatusCode, codes
from ._transports.asgi import ASGITransport
from ._transports.base import AsyncBaseTransport, BaseTransport
from ._transports.base import (
AsyncBaseTransport,
AsyncByteStream,
BaseTransport,
SyncByteStream,
)
from ._transports.default import AsyncHTTPTransport, HTTPTransport
from ._transports.mock import MockTransport
from ._transports.wsgi import WSGITransport
Expand All @@ -47,11 +53,13 @@
"__version__",
"ASGITransport",
"AsyncBaseTransport",
"AsyncByteStream",
"AsyncClient",
"AsyncHTTPTransport",
"Auth",
"BaseTransport",
"BasicAuth",
"ByteStream",
"Client",
"CloseError",
"codes",
Expand Down Expand Up @@ -97,6 +105,7 @@
"stream",
"StreamConsumed",
"StreamError",
"SyncByteStream",
"Timeout",
"TimeoutException",
"TooManyRedirects",
Expand Down
16 changes: 9 additions & 7 deletions httpx/_client.py
Expand Up @@ -26,12 +26,16 @@
from ._models import URL, Cookies, Headers, QueryParams, Request, Response
from ._status_codes import codes
from ._transports.asgi import ASGITransport
from ._transports.base import AsyncBaseTransport, BaseTransport
from ._transports.base import (
AsyncBaseTransport,
AsyncByteStream,
BaseTransport,
SyncByteStream,
)
from ._transports.default import AsyncHTTPTransport, HTTPTransport
from ._transports.wsgi import WSGITransport
from ._types import (
AuthTypes,
ByteStream,
CertTypes,
CookieTypes,
HeaderTypes,
Expand Down Expand Up @@ -509,7 +513,7 @@ def _redirect_headers(self, request: Request, url: URL, method: str) -> Headers:

def _redirect_stream(
self, request: Request, method: str
) -> typing.Optional[ByteStream]:
) -> typing.Optional[typing.Union[SyncByteStream, AsyncByteStream]]:
"""
Return the body that should be used for the redirect request.
"""
Expand Down Expand Up @@ -880,8 +884,7 @@ def _send_single_request(self, request: Request, timeout: Timeout) -> Response:

def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed())
if "close" in extensions:
extensions["close"]()
stream.close()

response = Response(
status_code,
Expand Down Expand Up @@ -1524,8 +1527,7 @@ async def _send_single_request(

async def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed())
if "aclose" in extensions:
await extensions["aclose"]()
await stream.aclose()

response = Response(
status_code,
Expand Down
105 changes: 41 additions & 64 deletions httpx/_content.py
Expand Up @@ -14,92 +14,69 @@

from ._exceptions import StreamConsumed
from ._multipart import MultipartStream
from ._types import (
ByteStream,
RequestContent,
RequestData,
RequestFiles,
ResponseContent,
)
from ._transports.base import AsyncByteStream, SyncByteStream
from ._types import RequestContent, RequestData, RequestFiles, ResponseContent
from ._utils import primitive_value_to_str


class ByteStream(AsyncByteStream, SyncByteStream):
    """
    Request or response content consisting of a fixed byte string.

    Implements both the sync and async stream interfaces, since iterating
    over in-memory bytes involves no I/O.
    """

    def __init__(self, stream: bytes) -> None:
        self._stream = stream

    def __iter__(self) -> Iterator[bytes]:
        # The entire body is yielded as a single chunk.
        yield self._stream

    async def __aiter__(self) -> AsyncIterator[bytes]:
        yield self._stream


class IteratorByteStream(SyncByteStream):
    """
    Request content wrapping a synchronous iterable of bytes.

    Plain iterables may be replayed, but generators are single-use:
    re-iterating a consumed generator raises `StreamConsumed`.
    (Eg. when following an HTTP 307/308 redirect requires resending
    the request body.)
    """

    def __init__(self, stream: Iterable[bytes]):
        self._stream = stream
        self._is_stream_consumed = False
        # Only generators are strictly single-use; other iterables may
        # legitimately be iterated more than once.
        self._is_generator = inspect.isgenerator(stream)

    def __iter__(self) -> Iterator[bytes]:
        if self._is_stream_consumed and self._is_generator:
            raise StreamConsumed()

        self._is_stream_consumed = True
        for part in self._stream:
            yield part


class AsyncIteratorByteStream(AsyncByteStream):
    """
    Request content wrapping an asynchronous iterable of bytes.

    Plain async iterables may be replayed, but async generators are
    single-use: re-iterating a consumed async generator raises
    `StreamConsumed`. (Eg. when following an HTTP 307/308 redirect
    requires resending the request body.)
    """

    def __init__(self, stream: AsyncIterable[bytes]):
        self._stream = stream
        self._is_stream_consumed = False
        # Only async generators are strictly single-use; other async
        # iterables may legitimately be iterated more than once.
        self._is_generator = inspect.isasyncgen(stream)

    async def __aiter__(self) -> AsyncIterator[bytes]:
        if self._is_stream_consumed and self._is_generator:
            raise StreamConsumed()

        self._is_stream_consumed = True
        async for part in self._stream:
            yield part


def encode_content(
    content: Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
    """
    Encode raw request/response content, returning a two-tuple of
    (<headers>, <stream>).

    * `str`/`bytes` content is sent with a `Content-Length` header.
    * Byte iterables and async byte iterables are sent chunked, with a
      `Transfer-Encoding: chunked` header.

    Raises `TypeError` for any other content type.
    """

    if isinstance(content, (bytes, str)):
        body = content.encode("utf-8") if isinstance(content, str) else content
        content_length = str(len(body))
        # An empty body omits the Content-Length header entirely.
        headers = {"Content-Length": content_length} if body else {}
        return headers, ByteStream(body)

    elif isinstance(content, Iterable):
        headers = {"Transfer-Encoding": "chunked"}
        return headers, IteratorByteStream(content)  # type: ignore

    elif isinstance(content, AsyncIterable):
        headers = {"Transfer-Encoding": "chunked"}
        return headers, AsyncIteratorByteStream(content)

    raise TypeError(f"Unexpected type for 'content', {type(content)!r}")

Expand All @@ -117,39 +94,39 @@ def encode_urlencoded_data(
content_length = str(len(body))
content_type = "application/x-www-form-urlencoded"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, PlainByteStream(body)
return headers, ByteStream(body)


def encode_multipart_data(
    data: dict, files: RequestFiles, boundary: bytes = None
) -> Tuple[Dict[str, str], MultipartStream]:
    """
    Encode `multipart/form-data` content, returning a two-tuple of
    (<headers>, <stream>). The multipart stream itself determines the
    appropriate `Content-Type` and `Content-Length` headers.
    """
    multipart = MultipartStream(data=data, files=files, boundary=boundary)
    headers = multipart.get_headers()
    return headers, multipart


def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
    """
    Encode plain-text content as UTF-8, returning a two-tuple of
    (<headers>, <stream>).
    """
    body = text.encode("utf-8")
    content_length = str(len(body))
    content_type = "text/plain; charset=utf-8"
    headers = {"Content-Length": content_length, "Content-Type": content_type}
    return headers, ByteStream(body)


def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
    """
    Encode HTML content as UTF-8, returning a two-tuple of
    (<headers>, <stream>).
    """
    body = html.encode("utf-8")
    content_length = str(len(body))
    content_type = "text/html; charset=utf-8"
    headers = {"Content-Length": content_length, "Content-Type": content_type}
    return headers, ByteStream(body)


def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
    """
    Encode content as a UTF-8 JSON body, returning a two-tuple of
    (<headers>, <stream>).
    """
    body = json_dumps(json).encode("utf-8")
    content_length = str(len(body))
    content_type = "application/json"
    headers = {"Content-Length": content_length, "Content-Type": content_type}
    return headers, ByteStream(body)


def encode_request(
Expand All @@ -158,7 +135,7 @@ def encode_request(
files: RequestFiles = None,
json: Any = None,
boundary: bytes = None,
) -> Tuple[Dict[str, str], ByteStream]:
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
"""
Handles encoding the given `content`, `data`, `files`, and `json`,
returning a two-tuple of (<headers>, <stream>).
Expand All @@ -182,15 +159,15 @@ def encode_request(
elif json is not None:
return encode_json(json)

return {}, PlainByteStream(b"")
return {}, ByteStream(b"")


def encode_response(
content: ResponseContent = None,
text: str = None,
html: str = None,
json: Any = None,
) -> Tuple[Dict[str, str], ByteStream]:
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
"""
Handles encoding the given `content`, returning a two-tuple of
(<headers>, <stream>).
Expand All @@ -204,4 +181,4 @@ def encode_response(
elif json is not None:
return encode_json(json)

return {}, PlainByteStream(b"")
return {}, ByteStream(b"")
18 changes: 9 additions & 9 deletions httpx/_models.py
Expand Up @@ -11,7 +11,7 @@
import rfc3986
import rfc3986.exceptions

from ._content import PlainByteStream, encode_request, encode_response
from ._content import ByteStream, encode_request, encode_response
from ._decoders import (
SUPPORTED_DECODERS,
ByteChunker,
Expand All @@ -33,8 +33,8 @@
request_context,
)
from ._status_codes import codes
from ._transports.base import AsyncByteStream, SyncByteStream
from ._types import (
ByteStream,
CookieTypes,
HeaderTypes,
PrimitiveData,
Expand Down Expand Up @@ -798,7 +798,7 @@ def __init__(
data: RequestData = None,
files: RequestFiles = None,
json: typing.Any = None,
stream: ByteStream = None,
stream: typing.Union[SyncByteStream, AsyncByteStream] = None,
):
if isinstance(method, bytes):
self.method = method.decode("ascii").upper()
Expand Down Expand Up @@ -872,7 +872,7 @@ def read(self) -> bytes:
# If a streaming request has been read entirely into memory, then
# we can replace the stream with a raw bytes implementation,
# to ensure that any non-replayable streams can still be used.
self.stream = PlainByteStream(self._content)
self.stream = ByteStream(self._content)
return self._content

async def aread(self) -> bytes:
Expand All @@ -885,7 +885,7 @@ async def aread(self) -> bytes:
# If a streaming request has been read entirely into memory, then
# we can replace the stream with a raw bytes implementation,
# to ensure that any non-replayable streams can still be used.
self.stream = PlainByteStream(self._content)
self.stream = ByteStream(self._content)
return self._content

def __repr__(self) -> str:
Expand All @@ -904,7 +904,7 @@ def __init__(
text: str = None,
html: str = None,
json: typing.Any = None,
stream: ByteStream = None,
stream: typing.Union[SyncByteStream, AsyncByteStream] = None,
request: Request = None,
extensions: dict = None,
history: typing.List["Response"] = None,
Expand Down Expand Up @@ -1222,7 +1222,7 @@ def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
if not isinstance(self.stream, typing.Iterable):
if not isinstance(self.stream, SyncByteStream):
raise RuntimeError("Attempted to call a sync iterator on an async stream.")

self.is_stream_consumed = True
Expand Down Expand Up @@ -1318,8 +1318,8 @@ async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
if not isinstance(self.stream, typing.AsyncIterable):
raise RuntimeError("Attempted to call a async iterator on a sync stream.")
if not isinstance(self.stream, AsyncByteStream):
raise RuntimeError("Attempted to call an async iterator on an sync stream.")

self.is_stream_consumed = True
self._num_bytes_downloaded = 0
Expand Down

0 comments on commit 110ce85

Please sign in to comment.