Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions src/betterproto2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@

from typing_extensions import Self

try:
import pydantic
import pydantic_core
except ImportError:
pydantic = None
pydantic_core = None

import betterproto2.validators as validators
from betterproto2.message_pool import MessagePool
from betterproto2.utils import unwrap
Expand Down Expand Up @@ -697,6 +704,26 @@ def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
cls._betterproto_meta = ProtoClassMetadata(cls)
return cls._betterproto_meta

def _is_pydantic(self) -> bool:
"""
Check if the message is a pydantic dataclass.
"""
return pydantic is not None and pydantic.dataclasses.is_pydantic_dataclass(type(self))

def _validate(self) -> None:
"""
Manually validate the message using pydantic.

This is useful since pydantic does not revalidate the message when fields are changed. This function doesn't
validate the fields recursively.
"""
if not self._is_pydantic():
raise TypeError("Validation is only available for pydantic dataclasses.")

dict = self.__dict__.copy()
del dict["_unknown_fields"]
pydantic_core.SchemaValidator(self.__pydantic_core_schema__).validate_python(dict) # type: ignore

def dump(self, stream: SupportsWrite[bytes], delimit: bool = False) -> None:
"""
Dumps the binary encoded Protobuf message to the stream.
Expand All @@ -720,6 +747,9 @@ def __bytes__(self) -> bytes:
"""
Get the binary encoded Protobuf representation of this message instance.
"""
if self._is_pydantic():
self._validate()

with BytesIO() as stream:
for field_name, meta in self._betterproto.meta_by_field_name.items():
value = getattr(self, field_name)
Expand Down Expand Up @@ -822,13 +852,17 @@ def _postprocess_single(self, wire_type: int, meta: FieldMetadata, field_name: s
"""Adjusts values after parsing."""
if wire_type == WIRE_VARINT:
if meta.proto_type in (TYPE_INT32, TYPE_INT64):
bits = int(meta.proto_type[3:])
bits = 32 if meta.proto_type == TYPE_INT32 else 64
value = value & ((1 << bits) - 1)
signbit = 1 << (bits - 1)
value = int((value ^ signbit) - signbit)
elif meta.proto_type in (TYPE_UINT32, TYPE_UINT64):
bits = 32 if meta.proto_type == TYPE_UINT32 else 64
value = value & ((1 << bits) - 1)
elif meta.proto_type in (TYPE_SINT32, TYPE_SINT64):
# Undo zig-zag encoding
value = (value >> 1) ^ (-(value & 1))
bits = 32 if meta.proto_type == TYPE_SINT32 else 64
value = value & ((1 << bits) - 1)
value = (value >> 1) ^ (-(value & 1)) # Undo zig-zag encoding
elif meta.proto_type == TYPE_BOOL:
# Booleans use a varint encoding, so convert it to true/false.
value = value > 0
Expand Down Expand Up @@ -947,6 +981,9 @@ def load(
" or the expected size may have been incorrect."
)

if self._is_pydantic():
self._validate()

return self

@classmethod
Expand Down Expand Up @@ -1017,6 +1054,9 @@ def to_dict(
Dict[:class:`str`, Any]
The JSON serializable dict representation of this object.
"""
if self._is_pydantic():
self._validate()

kwargs = { # For recursive calls
"output_format": output_format,
"casing": casing,
Expand Down
23 changes: 23 additions & 0 deletions tests/test_encoding_decoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
def test_int_overflow():
"""Make sure that overflows in encoded values are handled correctly."""
from tests.output_betterproto_pydantic.encoding_decoding import Overflow32, Overflow64

b = bytes(Overflow64(uint=2**50 + 42))
msg = Overflow32.parse(b)
assert msg.uint == 42

b = bytes(Overflow64(int=2**50 + 42))
msg = Overflow32.parse(b)
assert msg.int == 42

b = bytes(Overflow64(int=2**50 - 42))
msg = Overflow32.parse(b)
assert msg.int == -42

b = bytes(Overflow64(sint=2**50 + 42))
msg = Overflow32.parse(b)
assert msg.sint == 42

b = bytes(Overflow64(sint=-(2**50) - 42))
msg = Overflow32.parse(b)
assert msg.sint == -42
23 changes: 23 additions & 0 deletions tests/test_manual_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pydantic
import pytest


def test_manual_validation():
from tests.output_betterproto_pydantic.manual_validation import Msg

msg = Msg()

msg.x = 12
msg._validate()

msg.x = 2**50 # This is an invalid int32 value
with pytest.raises(pydantic.ValidationError):
msg._validate()


def test_manual_validation_non_pydantic():
from tests.output_betterproto.manual_validation import Msg

# Validation is not available for non-pydantic messages
with pytest.raises(TypeError):
Msg()._validate()
Loading