Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conformance/test/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ async def send_unary_request(
except ConnectError as e:
test_response.response.error.code = _convert_code(e.code)
test_response.response.error.message = e.message
test_response.response.error.details.extend(e.details)
test_response.response.error.details.extend(d._any for d in e.details)
except (asyncio.CancelledError, Exception) as e:
traceback.print_tb(e.__traceback__, file=sys.stderr)
test_response.error.message = str(e)
Expand Down
36 changes: 14 additions & 22 deletions src/connectrpc/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Protocol, TypeVar, cast

from google.protobuf import symbol_database
from google.protobuf.any_pb2 import Any
from google.protobuf.json_format import MessageToDict

from ._compression import Compression
from .code import Code
from .errors import ConnectError
from .errors import ConnectError, ErrorDetail

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
Expand Down Expand Up @@ -91,7 +90,7 @@ def from_http_status(status: HTTPStatus) -> ExtendedHTTPStatus:
class ConnectWireError:
code: Code
message: str
details: Sequence[Any]
details: Sequence[ErrorDetail]

@staticmethod
def from_exception(exc: Exception) -> ConnectWireError:
Expand Down Expand Up @@ -122,7 +121,7 @@ def from_dict(
else:
code = _http_status_code_to_error.get(http_status, Code.UNKNOWN)
message = data.get("message", "")
details: Sequence[Any] = ()
details: Sequence[ErrorDetail] = ()
details_json = cast("list[dict[str, str]] | None", data.get("details"))
if details_json:
details = []
Expand All @@ -133,9 +132,11 @@ def from_dict(
# Ignore malformed details
continue
details.append(
Any(
type_url="type.googleapis.com/" + detail_type,
value=b64decode(detail_value + "==="),
ErrorDetail(
Any(
type_url="type.googleapis.com/" + detail_type,
value=b64decode(detail_value + "==="),
)
)
)
return ConnectWireError(code, message, details)
Expand All @@ -161,26 +162,17 @@ def to_dict(self) -> dict:
if self.details:
details: list[dict] = []
for detail in self.details:
if detail.type_url.startswith("type.googleapis.com/"):
detail_type = detail.type_url[len("type.googleapis.com/") :]
else:
detail_type = detail.type_url
detail_dict: dict = {
"type": detail_type,
"type": detail.type_name,
# Connect requires unpadded base64
"value": b64encode(detail.value).decode("utf-8").rstrip("="),
"value": b64encode(detail.message_bytes)
.decode("utf-8")
.rstrip("="),
}
# Try to produce debug info, but expect failure when we don't
# have descriptors for the message type.
debug = None
try:
msg_instance = symbol_database.Default().GetSymbol(detail_type)()
if detail.Unpack(msg_instance):
debug = MessageToDict(msg_instance)
except Exception:
debug = None
if debug is not None:
detail_dict["debug"] = debug
if (debug := detail.value()) is not None:
detail_dict["debug"] = MessageToDict(debug)
details.append(detail_dict)
data["details"] = details
return data
Expand Down
4 changes: 3 additions & 1 deletion src/connectrpc/_protocol_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def end(self, user_trailers: Headers, error: ConnectWireError | None) -> Headers
trailers["grpc-message"] = message
if error.details:
grpc_status = Status(
code=int(status), message=error.message, details=error.details
code=int(status),
message=error.message,
details=[d._any for d in error.details], # noqa: SLF001
)
grpc_status_bin = (
b64encode(grpc_status.SerializeToString()).decode().rstrip("=")
Expand Down
68 changes: 61 additions & 7 deletions src/connectrpc/errors.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,73 @@
from __future__ import annotations

__all__ = ["ConnectError"]
__all__ = ["ConnectError", "ErrorDetail"]


from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar, overload

from google.protobuf import symbol_database
from google.protobuf.any_pb2 import Any
from google.protobuf.message import Message

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence

from google.protobuf.message import Message

from .code import Code

T = TypeVar("T", bound=Message)


class ErrorDetail:
"""A self-describing Protobuf message attached to a [ConnectError][].

Error details are sent over the network to clients, which can then work with
strongly-typed data rather than trying to parse a complex error message. For
example, you might use details to send a localized error message or retry
parameters to a client.
"""

def __init__(self, message: Message) -> None:
if isinstance(message, Any):
self._message = None
self._any = message
return
self._message = message
self._any = pack_any(message)

@property
def type_name(self) -> str:
"""The fully-qualified name of the details Protobuf message (for example, acme.foo.v1.FooDetail)."""
return self._any.type_url.removeprefix("type.googleapis.com/")

@property
def message_bytes(self) -> bytes:
"""The Protobuf message serialized as bytes."""
return self._any.value

@overload
def value(self) -> Message | None: ...

@overload
def value(self, typ: type[T], /) -> T | None: ...

def value(self, desc: type[Message] | None = None) -> Message | None:
"""The details message as a Protobuf message, or None if it cannot be deserialized."""
if self._message:
return self._message
if isinstance(desc, type):
msg = desc()
if self._any.Unpack(msg):
return msg
return None
try:
detail_type = self._any.type_url.removeprefix("type.googleapis.com/")
msg_instance = symbol_database.Default().GetSymbol(detail_type)()
if self._any.Unpack(msg_instance):
return msg_instance
return None
except Exception:
return None


class ConnectError(Exception):
"""An exception in a Connect RPC.
Expand All @@ -25,7 +79,7 @@ class ConnectError(Exception):
"""

def __init__(
self, code: Code, message: str, details: Iterable[Message] = ()
self, code: Code, message: str, details: Iterable[Message | ErrorDetail] = ()
) -> None:
"""
Creates a new Connect error.
Expand All @@ -40,7 +94,7 @@ def __init__(
self._message = message

self._details = (
[m if isinstance(m, Any) else pack_any(m) for m in details]
[m if isinstance(m, ErrorDetail) else ErrorDetail(m) for m in details]
if details
else ()
)
Expand All @@ -54,7 +108,7 @@ def message(self) -> str:
return self._message

@property
def details(self) -> Sequence[Any]:
def details(self) -> Sequence[ErrorDetail]:
return self._details


Expand Down
24 changes: 12 additions & 12 deletions test/test_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from connectrpc._protocol import ConnectWireError
from connectrpc.code import Code
from connectrpc.errors import ConnectError, pack_any
from connectrpc.errors import ConnectError, ErrorDetail

from .haberdasher_connect import (
Haberdasher,
Expand All @@ -32,7 +32,7 @@ def make_hat(self, request, ctx) -> NoReturn:
"Resource exhausted",
details=[
Struct(fields={"animal": Value(string_value="bear")}),
pack_any(Struct(fields={"color": Value(string_value="red")})),
ErrorDetail(Struct(fields={"color": Value(string_value="red")})),
],
)

Expand All @@ -47,11 +47,11 @@ def make_hat(self, request, ctx) -> NoReturn:
assert exc_info.value.code == Code.RESOURCE_EXHAUSTED
assert exc_info.value.message == "Resource exhausted"
assert len(exc_info.value.details) == 2
s0 = Struct()
assert exc_info.value.details[0].Unpack(s0)
s0 = exc_info.value.details[0].value(Struct)
assert s0 is not None
assert s0.fields["animal"].string_value == "bear"
s1 = Struct()
assert exc_info.value.details[1].Unpack(s1)
s1 = exc_info.value.details[1].value(Struct)
assert s1 is not None
assert s1.fields["color"].string_value == "red"


Expand All @@ -64,7 +64,7 @@ async def make_hat(self, request, ctx) -> NoReturn:
"Resource exhausted",
details=[
Struct(fields={"animal": Value(string_value="bear")}),
pack_any(Struct(fields={"color": Value(string_value="red")})),
ErrorDetail(Struct(fields={"color": Value(string_value="red")})),
],
)

Expand All @@ -78,11 +78,11 @@ async def make_hat(self, request, ctx) -> NoReturn:
assert exc_info.value.code == Code.RESOURCE_EXHAUSTED
assert exc_info.value.message == "Resource exhausted"
assert len(exc_info.value.details) == 2
s0 = Struct()
assert exc_info.value.details[0].Unpack(s0)
s0 = exc_info.value.details[0].value(Struct)
assert s0 is not None
assert s0.fields["animal"].string_value == "bear"
s1 = Struct()
assert exc_info.value.details[1].Unpack(s1)
s1 = exc_info.value.details[1].value(Struct)
assert s1 is not None
assert s1.fields["color"].string_value == "red"


Expand Down Expand Up @@ -124,7 +124,7 @@ def test_error_detail_debug_field_absent_for_unknown_type() -> None:
type_url="type.googleapis.com/completely.Unknown.Message", value=b"\x08\x01"
)
wire_error = ConnectWireError(
code=Code.INTERNAL, message="test", details=[unknown_detail]
code=Code.INTERNAL, message="test", details=[ErrorDetail(unknown_detail)]
)
data = wire_error.to_dict()
assert len(data["details"]) == 1
Expand Down
Loading