Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push httpcore exception wrapping out of client into transport #1524

Merged
merged 3 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
)
from ._decoders import SUPPORTED_DECODERS
from ._exceptions import (
HTTPCORE_EXC_MAP,
InvalidURL,
RemoteProtocolError,
TooManyRedirects,
map_exceptions,
request_context,
)
from ._models import URL, Cookies, Headers, QueryParams, Request, Response
from ._status_codes import codes
Expand Down Expand Up @@ -849,7 +848,7 @@ def _send_single_request(self, request: Request, timeout: Timeout) -> Response:
timer = Timer()
timer.sync_start()

with map_exceptions(HTTPCORE_EXC_MAP, request=request):
with request_context(request=request):
(status_code, headers, stream, extensions) = transport.handle_request(
request.method.encode(),
request.url.raw,
Expand All @@ -860,13 +859,13 @@ def _send_single_request(self, request: Request, timeout: Timeout) -> Response:

def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed())
if hasattr(stream, "close"):
stream.close() # type: ignore
if "close" in extensions:
extensions["close"]()

response = Response(
status_code,
headers=headers,
stream=stream, # type: ignore
stream=stream,
extensions=extensions,
request=request,
on_close=on_close,
Expand Down Expand Up @@ -1483,7 +1482,7 @@ async def _send_single_request(
timer = Timer()
await timer.async_start()

with map_exceptions(HTTPCORE_EXC_MAP, request=request):
with request_context(request=request):
(
status_code,
headers,
Expand All @@ -1499,14 +1498,13 @@ async def _send_single_request(

async def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed())
if hasattr(stream, "aclose"):
with map_exceptions(HTTPCORE_EXC_MAP, request=request):
await stream.aclose() # type: ignore
if "aclose" in extensions:
await extensions["aclose"]()

response = Response(
status_code,
headers=headers,
stream=stream, # type: ignore
stream=stream,
extensions=extensions,
request=request,
on_close=on_close,
Expand Down
14 changes: 8 additions & 6 deletions httpx/_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import typing
import zlib

from ._exceptions import DecodingError

try:
import brotli
except ImportError: # pragma: nocover
Expand Down Expand Up @@ -54,13 +56,13 @@ def decode(self, data: bytes) -> bytes:
if was_first_attempt:
self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
return self.decode(data)
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc

def flush(self) -> bytes:
try:
return self.decompressor.flush()
except zlib.error as exc: # pragma: nocover
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc


class GZipDecoder(ContentDecoder):
Expand All @@ -77,13 +79,13 @@ def decode(self, data: bytes) -> bytes:
try:
return self.decompressor.decompress(data)
except zlib.error as exc:
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc

def flush(self) -> bytes:
try:
return self.decompressor.flush()
except zlib.error as exc: # pragma: nocover
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc


class BrotliDecoder(ContentDecoder):
Expand Down Expand Up @@ -118,7 +120,7 @@ def decode(self, data: bytes) -> bytes:
try:
return self._decompress(data)
except brotli.error as exc:
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc

def flush(self) -> bytes:
if not self.seen_data:
Expand All @@ -128,7 +130,7 @@ def flush(self) -> bytes:
self.decompressor.finish()
return b""
except brotli.error as exc: # pragma: nocover
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc


class MultiDecoder(ContentDecoder):
Expand Down
74 changes: 25 additions & 49 deletions httpx/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import contextlib
import typing

import httpcore

if typing.TYPE_CHECKING:
from ._models import Request, Response # pragma: nocover

Expand All @@ -58,25 +56,33 @@ class HTTPError(Exception):
```
"""

def __init__(self, message: str, *, request: "Request") -> None:
def __init__(self, message: str) -> None:
super().__init__(message)
self.request = request


class RequestError(HTTPError):
"""
Base class for all exceptions that may occur when issuing a `.request()`.
"""

def __init__(self, message: str, *, request: "Request") -> None:
super().__init__(message, request=request)
def __init__(self, message: str, *, request: "Request" = None) -> None:
super().__init__(message)
self._request = request

@property
def request(self) -> "Request":
if self._request is None:
raise RuntimeError("The .request property has not been set.")
return self._request

@request.setter
def request(self, request: "Request") -> None:
self._request = request


class TransportError(RequestError):
"""
Base class for all exceptions that occur at the level of the Transport API.

All of these exceptions also have an equivelent mapping in `httpcore`.
"""


Expand Down Expand Up @@ -219,7 +225,8 @@ class HTTPStatusError(HTTPError):
def __init__(
self, message: str, *, request: "Request", response: "Response"
) -> None:
super().__init__(message, request=request)
super().__init__(message)
self.request = request
self.response = response


Expand Down Expand Up @@ -318,45 +325,14 @@ def __init__(self) -> None:


@contextlib.contextmanager
def map_exceptions(
mapping: typing.Mapping[typing.Type[Exception], typing.Type[Exception]],
**kwargs: typing.Any,
) -> typing.Iterator[None]:
def request_context(request: "Request" = None) -> typing.Iterator[None]:
"""
A context manager that can be used to attach the given request context
to any `RequestError` exceptions that are raised within the block.
"""
try:
yield
except Exception as exc:
mapped_exc = None

for from_exc, to_exc in mapping.items():
if not isinstance(exc, from_exc):
continue
# We want to map to the most specific exception we can find.
# Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to
# `httpx.ReadTimeout`, not just `httpx.TimeoutException`.
if mapped_exc is None or issubclass(to_exc, mapped_exc):
mapped_exc = to_exc

if mapped_exc is None:
raise

message = str(exc)
raise mapped_exc(message, **kwargs) from exc # type: ignore


HTTPCORE_EXC_MAP = {
httpcore.TimeoutException: TimeoutException,
httpcore.ConnectTimeout: ConnectTimeout,
httpcore.ReadTimeout: ReadTimeout,
httpcore.WriteTimeout: WriteTimeout,
httpcore.PoolTimeout: PoolTimeout,
httpcore.NetworkError: NetworkError,
httpcore.ConnectError: ConnectError,
httpcore.ReadError: ReadError,
httpcore.WriteError: WriteError,
httpcore.CloseError: CloseError,
httpcore.ProxyError: ProxyError,
httpcore.UnsupportedProtocol: UnsupportedProtocol,
httpcore.ProtocolError: ProtocolError,
httpcore.LocalProtocolError: LocalProtocolError,
httpcore.RemoteProtocolError: RemoteProtocolError,
}
except RequestError as exc:
if request is not None:
exc.request = request
raise exc
38 changes: 13 additions & 25 deletions httpx/_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import cgi
import contextlib
import datetime
import email.message
import json as jsonlib
Expand All @@ -24,16 +23,14 @@
TextDecoder,
)
from ._exceptions import (
HTTPCORE_EXC_MAP,
CookieConflict,
DecodingError,
HTTPStatusError,
InvalidURL,
RequestNotRead,
ResponseClosed,
ResponseNotRead,
StreamConsumed,
map_exceptions,
request_context,
)
from ._status_codes import codes
from ._types import (
Expand Down Expand Up @@ -1145,17 +1142,6 @@ def num_bytes_downloaded(self) -> int:
def __repr__(self) -> str:
return f"<Response [{self.status_code} {self.reason_phrase}]>"

@contextlib.contextmanager
def _wrap_decoder_errors(self) -> typing.Iterator[None]:
# If the response has an associated request instance, we want decoding
# errors to be raised as proper `httpx.DecodingError` exceptions.
try:
yield
except ValueError as exc:
if self._request is None:
raise exc
raise DecodingError(message=str(exc), request=self.request) from exc

def read(self) -> bytes:
"""
Read and return the response content.
Expand All @@ -1176,7 +1162,7 @@ def iter_bytes(self, chunk_size: int = None) -> typing.Iterator[bytes]:
else:
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
with request_context(request=self._request):
for raw_bytes in self.iter_raw():
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
Expand All @@ -1195,7 +1181,7 @@ def iter_text(self, chunk_size: int = None) -> typing.Iterator[str]:
"""
decoder = TextDecoder(encoding=self.encoding)
chunker = TextChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
with request_context(request=self._request):
for byte_content in self.iter_bytes():
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
Expand All @@ -1208,7 +1194,7 @@ def iter_text(self, chunk_size: int = None) -> typing.Iterator[str]:

def iter_lines(self) -> typing.Iterator[str]:
decoder = LineDecoder()
with self._wrap_decoder_errors():
with request_context(request=self._request):
for text in self.iter_text():
for line in decoder.decode(text):
yield line
Expand All @@ -1230,7 +1216,7 @@ def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]:
self._num_bytes_downloaded = 0
chunker = ByteChunker(chunk_size=chunk_size)

with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
with request_context(request=self._request):
for raw_stream_bytes in self.stream:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
Expand All @@ -1249,7 +1235,8 @@ def close(self) -> None:
if not self.is_closed:
self.is_closed = True
if self._on_close is not None:
self._on_close(self)
with request_context(request=self._request):
self._on_close(self)

async def aread(self) -> bytes:
"""
Expand All @@ -1271,7 +1258,7 @@ async def aiter_bytes(self, chunk_size: int = None) -> typing.AsyncIterator[byte
else:
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
with request_context(request=self._request):
async for raw_bytes in self.aiter_raw():
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
Expand All @@ -1290,7 +1277,7 @@ async def aiter_text(self, chunk_size: int = None) -> typing.AsyncIterator[str]:
"""
decoder = TextDecoder(encoding=self.encoding)
chunker = TextChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
with request_context(request=self._request):
async for byte_content in self.aiter_bytes():
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
Expand All @@ -1303,7 +1290,7 @@ async def aiter_text(self, chunk_size: int = None) -> typing.AsyncIterator[str]:

async def aiter_lines(self) -> typing.AsyncIterator[str]:
decoder = LineDecoder()
with self._wrap_decoder_errors():
with request_context(request=self._request):
async for text in self.aiter_text():
for line in decoder.decode(text):
yield line
Expand All @@ -1325,7 +1312,7 @@ async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]
self._num_bytes_downloaded = 0
chunker = ByteChunker(chunk_size=chunk_size)

with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
with request_context(request=self._request):
async for raw_stream_bytes in self.stream:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
Expand All @@ -1344,7 +1331,8 @@ async def aclose(self) -> None:
if not self.is_closed:
self.is_closed = True
if self._on_close is not None:
await self._on_close(self)
with request_context(request=self._request):
await self._on_close(self)


class Cookies(MutableMapping):
Expand Down
8 changes: 6 additions & 2 deletions httpx/_transports/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def handle_request(
try:
body = b''.join([part for part in stream])
finally:
if hasattr(stream 'close'):
stream.close()
if 'close' in extensions:
extensions['close']()
print(status_code, headers, body)

Arguments:
Expand Down Expand Up @@ -86,6 +86,10 @@ def handle_request(
eg. the leading response bytes were b"HTTP/1.1 200 <CRLF>".
http_version: The HTTP version, as a string. Eg. "HTTP/1.1".
When no http_version key is included, "HTTP/1.1" may be assumed.
close: A callback which should be invoked to release any network
resources.
aclose: An async callback which should be invoked to release any
network resources.
"""
raise NotImplementedError(
"The 'handle_request' method must be implemented."
Expand Down
Loading