Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream interface #1550

Merged
merged 9 commits into from Apr 13, 2021
4 changes: 2 additions & 2 deletions docs/advanced.md
Expand Up @@ -1070,7 +1070,7 @@ class HelloWorldTransport(httpx.BaseTransport):
def handle_request(self, method, url, headers, stream, extensions):
message = {"text": "Hello, world!"}
content = json.dumps(message).encode("utf-8")
stream = [content]
stream = httpx.ByteStream(content)
headers = [(b"content-type", b"application/json")]
extensions = {}
return 200, headers, stream, extensions
Expand Down Expand Up @@ -1131,7 +1131,7 @@ class HTTPSRedirectTransport(httpx.BaseTransport):
location = b"https://%s%s" % (host, path)
else:
location = b"https://%s:%d%s" % (host, port, path)
stream = [b""]
stream = httpx.ByteStream(b"")
headers = [(b"location", location)]
extensions = {}
return 303, headers, stream, extensions
Expand Down
11 changes: 10 additions & 1 deletion httpx/__init__.py
Expand Up @@ -3,6 +3,7 @@
from ._auth import Auth, BasicAuth, DigestAuth
from ._client import AsyncClient, Client
from ._config import Limits, Proxy, Timeout, create_ssl_context
from ._content import ByteStream
from ._exceptions import (
CloseError,
ConnectError,
Expand Down Expand Up @@ -36,7 +37,12 @@
from ._models import URL, Cookies, Headers, QueryParams, Request, Response
from ._status_codes import StatusCode, codes
from ._transports.asgi import ASGITransport
from ._transports.base import AsyncBaseTransport, BaseTransport
from ._transports.base import (
AsyncBaseTransport,
AsyncByteStream,
BaseTransport,
SyncByteStream,
)
from ._transports.default import AsyncHTTPTransport, HTTPTransport
from ._transports.mock import MockTransport
from ._transports.wsgi import WSGITransport
Expand All @@ -47,11 +53,13 @@
"__version__",
"ASGITransport",
"AsyncBaseTransport",
"AsyncByteStream",
"AsyncClient",
"AsyncHTTPTransport",
"Auth",
"BaseTransport",
"BasicAuth",
"ByteStream",
"Client",
"CloseError",
"codes",
Expand Down Expand Up @@ -97,6 +105,7 @@
"stream",
"StreamConsumed",
"StreamError",
"SyncByteStream",
"Timeout",
"TimeoutException",
"TooManyRedirects",
Expand Down
16 changes: 9 additions & 7 deletions httpx/_client.py
Expand Up @@ -26,12 +26,16 @@
from ._models import URL, Cookies, Headers, QueryParams, Request, Response
from ._status_codes import codes
from ._transports.asgi import ASGITransport
from ._transports.base import AsyncBaseTransport, BaseTransport
from ._transports.base import (
AsyncBaseTransport,
AsyncByteStream,
BaseTransport,
SyncByteStream,
)
from ._transports.default import AsyncHTTPTransport, HTTPTransport
from ._transports.wsgi import WSGITransport
from ._types import (
AuthTypes,
ByteStream,
CertTypes,
CookieTypes,
HeaderTypes,
Expand Down Expand Up @@ -509,7 +513,7 @@ def _redirect_headers(self, request: Request, url: URL, method: str) -> Headers:

def _redirect_stream(
self, request: Request, method: str
) -> typing.Optional[ByteStream]:
) -> typing.Optional[typing.Union[SyncByteStream, AsyncByteStream]]:
"""
Return the body that should be used for the redirect request.
"""
Expand Down Expand Up @@ -880,8 +884,7 @@ def _send_single_request(self, request: Request, timeout: Timeout) -> Response:

def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed())
if "close" in extensions:
extensions["close"]()
stream.close()

response = Response(
status_code,
Expand Down Expand Up @@ -1524,8 +1527,7 @@ async def _send_single_request(

async def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed())
if "aclose" in extensions:
await extensions["aclose"]()
await stream.aclose()

response = Response(
status_code,
Expand Down
105 changes: 41 additions & 64 deletions httpx/_content.py
Expand Up @@ -14,92 +14,69 @@

from ._exceptions import StreamConsumed
from ._multipart import MultipartStream
from ._types import (
ByteStream,
RequestContent,
RequestData,
RequestFiles,
ResponseContent,
)
from ._transports.base import AsyncByteStream, SyncByteStream
from ._types import RequestContent, RequestData, RequestFiles, ResponseContent
from ._utils import primitive_value_to_str


class PlainByteStream:
"""
Request content encoded as plain bytes.
"""

def __init__(self, body: bytes) -> None:
self._body = body
class ByteStream(AsyncByteStream, SyncByteStream):
def __init__(self, stream: bytes) -> None:
self._stream = stream

def __iter__(self) -> Iterator[bytes]:
yield self._body
yield self._stream

async def __aiter__(self) -> AsyncIterator[bytes]:
yield self._body
yield self._stream


class GeneratorStream:
"""
Request content encoded as plain bytes, using a byte generator.
"""

def __init__(self, generator: Iterable[bytes]) -> None:
self._generator = generator
class IteratorByteStream(SyncByteStream):
def __init__(self, stream: Iterable[bytes]):
self._stream = stream
self._is_stream_consumed = False
self._is_generator = inspect.isgenerator(stream)

def __iter__(self) -> Iterator[bytes]:
if self._is_stream_consumed:
if self._is_stream_consumed and self._is_generator:
raise StreamConsumed()

self._is_stream_consumed = True
for part in self._generator:
for part in self._stream:
yield part


class AsyncGeneratorStream:
"""
Request content encoded as plain bytes, using an async byte iterator.
"""

def __init__(self, agenerator: AsyncIterable[bytes]) -> None:
self._agenerator = agenerator
class AsyncIteratorByteStream(AsyncByteStream):
def __init__(self, stream: AsyncIterable[bytes]):
self._stream = stream
self._is_stream_consumed = False
self._is_generator = inspect.isasyncgen(stream)

async def __aiter__(self) -> AsyncIterator[bytes]:
if self._is_stream_consumed:
if self._is_stream_consumed and self._is_generator:
raise StreamConsumed()

self._is_stream_consumed = True
async for part in self._agenerator:
async for part in self._stream:
yield part


def encode_content(
content: Union[str, bytes, ByteStream]
) -> Tuple[Dict[str, str], ByteStream]:
if isinstance(content, (str, bytes)):
content: Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:

if isinstance(content, (bytes, str)):
body = content.encode("utf-8") if isinstance(content, str) else content
content_length = str(len(body))
headers = {"Content-Length": content_length} if body else {}
stream = PlainByteStream(body)
return headers, stream
return headers, ByteStream(body)

elif isinstance(content, (Iterable, AsyncIterable)):
elif isinstance(content, Iterable):
headers = {"Transfer-Encoding": "chunked"}
return headers, IteratorByteStream(content) # type: ignore

# Generators should be wrapped in GeneratorStream/AsyncGeneratorStream
# which will raise `StreamConsumed` if the stream is accessed more
# than once. (Eg. Following HTTP 307 or HTTP 308 redirects.)
if inspect.isgenerator(content):
generator_stream = GeneratorStream(content) # type: ignore
return headers, generator_stream
if inspect.isasyncgen(content):
agenerator_stream = AsyncGeneratorStream(content) # type: ignore
return headers, agenerator_stream

# Other iterables may be passed through as-is.
return headers, content # type: ignore
elif isinstance(content, AsyncIterable):
headers = {"Transfer-Encoding": "chunked"}
return headers, AsyncIteratorByteStream(content)

raise TypeError(f"Unexpected type for 'content', {type(content)!r}")

Expand All @@ -117,39 +94,39 @@ def encode_urlencoded_data(
content_length = str(len(body))
content_type = "application/x-www-form-urlencoded"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, PlainByteStream(body)
return headers, ByteStream(body)


def encode_multipart_data(
data: dict, files: RequestFiles, boundary: bytes = None
) -> Tuple[Dict[str, str], ByteStream]:
stream = MultipartStream(data=data, files=files, boundary=boundary)
headers = stream.get_headers()
return headers, stream
) -> Tuple[Dict[str, str], MultipartStream]:
multipart = MultipartStream(data=data, files=files, boundary=boundary)
headers = multipart.get_headers()
return headers, multipart


def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
body = text.encode("utf-8")
content_length = str(len(body))
content_type = "text/plain; charset=utf-8"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, PlainByteStream(body)
return headers, ByteStream(body)


def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
body = html.encode("utf-8")
content_length = str(len(body))
content_type = "text/html; charset=utf-8"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, PlainByteStream(body)
return headers, ByteStream(body)


def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
body = json_dumps(json).encode("utf-8")
content_length = str(len(body))
content_type = "application/json"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, PlainByteStream(body)
return headers, ByteStream(body)


def encode_request(
Expand All @@ -158,7 +135,7 @@ def encode_request(
files: RequestFiles = None,
json: Any = None,
boundary: bytes = None,
) -> Tuple[Dict[str, str], ByteStream]:
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
"""
Handles encoding the given `content`, `data`, `files`, and `json`,
returning a two-tuple of (<headers>, <stream>).
Expand All @@ -182,15 +159,15 @@ def encode_request(
elif json is not None:
return encode_json(json)

return {}, PlainByteStream(b"")
return {}, ByteStream(b"")


def encode_response(
content: ResponseContent = None,
text: str = None,
html: str = None,
json: Any = None,
) -> Tuple[Dict[str, str], ByteStream]:
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
"""
Handles encoding the given `content`, returning a two-tuple of
(<headers>, <stream>).
Expand All @@ -204,4 +181,4 @@ def encode_response(
elif json is not None:
return encode_json(json)

return {}, PlainByteStream(b"")
return {}, ByteStream(b"")
18 changes: 9 additions & 9 deletions httpx/_models.py
Expand Up @@ -11,7 +11,7 @@
import rfc3986
import rfc3986.exceptions

from ._content import PlainByteStream, encode_request, encode_response
from ._content import ByteStream, encode_request, encode_response
from ._decoders import (
SUPPORTED_DECODERS,
ByteChunker,
Expand All @@ -33,8 +33,8 @@
request_context,
)
from ._status_codes import codes
from ._transports.base import AsyncByteStream, SyncByteStream
from ._types import (
ByteStream,
CookieTypes,
HeaderTypes,
PrimitiveData,
Expand Down Expand Up @@ -798,7 +798,7 @@ def __init__(
data: RequestData = None,
files: RequestFiles = None,
json: typing.Any = None,
stream: ByteStream = None,
stream: typing.Union[SyncByteStream, AsyncByteStream] = None,
):
if isinstance(method, bytes):
self.method = method.decode("ascii").upper()
Expand Down Expand Up @@ -872,7 +872,7 @@ def read(self) -> bytes:
# If a streaming request has been read entirely into memory, then
# we can replace the stream with a raw bytes implementation,
# to ensure that any non-replayable streams can still be used.
self.stream = PlainByteStream(self._content)
self.stream = ByteStream(self._content)
return self._content

async def aread(self) -> bytes:
Expand All @@ -885,7 +885,7 @@ async def aread(self) -> bytes:
# If a streaming request has been read entirely into memory, then
# we can replace the stream with a raw bytes implementation,
# to ensure that any non-replayable streams can still be used.
self.stream = PlainByteStream(self._content)
self.stream = ByteStream(self._content)
return self._content

def __repr__(self) -> str:
Expand All @@ -904,7 +904,7 @@ def __init__(
text: str = None,
html: str = None,
json: typing.Any = None,
stream: ByteStream = None,
stream: typing.Union[SyncByteStream, AsyncByteStream] = None,
request: Request = None,
extensions: dict = None,
history: typing.List["Response"] = None,
Expand Down Expand Up @@ -1222,7 +1222,7 @@ def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
if not isinstance(self.stream, typing.Iterable):
if not isinstance(self.stream, SyncByteStream):
raise RuntimeError("Attempted to call a sync iterator on an async stream.")

self.is_stream_consumed = True
Expand Down Expand Up @@ -1318,8 +1318,8 @@ async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
if not isinstance(self.stream, typing.AsyncIterable):
raise RuntimeError("Attempted to call a async iterator on a sync stream.")
if not isinstance(self.stream, AsyncByteStream):
raise RuntimeError("Attempted to call an async iterator on a sync stream.")

self.is_stream_consumed = True
self._num_bytes_downloaded = 0
Expand Down