Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions betterproto2/src/betterproto2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def __bytes__(self) -> bytes:
# Default (zero) values are not serialized.
continue

if isinstance(value, list):
if meta.repeated:
if meta.proto_type in PACKED_TYPES:
# Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then
Expand All @@ -802,9 +802,8 @@ def __bytes__(self) -> bytes:
or b"\n\x00"
)

elif isinstance(value, dict):
elif meta.map_meta:
for k, v in value.items():
assert meta.map_meta
sk = _serialize_single(1, meta.map_meta[0].proto_type, k)
sv = _serialize_single(2, meta.map_meta[1].proto_type, v, unwrap=meta.map_meta[1].unwrap)
stream.write(_serialize_single(meta.number, meta.proto_type, sk + sv))
Expand Down Expand Up @@ -944,8 +943,10 @@ def load(

meta = proto_meta.meta_by_field_name[field_name]

is_packed_repeated = parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES

value: Any
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
if is_packed_repeated:
# This is a packed repeated field.
pos = 0
value = []
Expand All @@ -969,8 +970,8 @@ def load(
if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map.
current[value.key] = value.value
elif isinstance(current, list):
if isinstance(value, list):
elif meta.repeated:
if is_packed_repeated:
current.extend(value)
else:
current.append(value)
Expand Down Expand Up @@ -1142,7 +1143,12 @@ def _from_dict_init(cls, mapping: Mapping[str, Any] | Any, *, ignore_unknown_fie
raise KeyError(f"Unknown field '{field_name}' in message {cls.__name__}.") from None

if value is None:
continue
name, module = field_cls.__name__, field_cls.__module__

# Edge case: None shouldn't be ignored for google.protobuf.Value
# See https://protobuf.dev/programming-guides/json/
if not (module.endswith("google.protobuf") and name == "Value"):
continue

if meta.proto_type == TYPE_MESSAGE:
if meta.repeated:
Expand Down
21 changes: 0 additions & 21 deletions betterproto2/tests/inputs/config.py

This file was deleted.

19 changes: 10 additions & 9 deletions betterproto2/tests/test_any.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
def test_any() -> None:
# TODO using a custom message pool will no longer be necessary when the well-known types will be compiled as well
from tests.outputs.any.any import Person
from tests.outputs.any.google.protobuf import Any

person = Person(first_name="John", last_name="Smith")

any = Any()
any.pack(person)
any = Any.pack(person)

new_any = Any.parse(bytes(any))

Expand All @@ -19,25 +17,28 @@ def test_any_to_dict() -> None:

person = Person(first_name="John", last_name="Smith")

any = Any()

# TODO test with include defautl value
assert any.to_dict() == {"@type": ""}
assert Any().to_dict() == {"@type": ""}

# Pack an object inside
any.pack(person)
any = Any.pack(person)

assert any.to_dict() == {
"@type": "type.googleapis.com/any.Person",
"firstName": "John",
"lastName": "Smith",
}

assert Any.from_dict(any.to_dict()) == any
assert Any.parse(bytes(any)) == any

# Pack again in another Any
any2 = Any()
any2.pack(any)
any2 = Any.pack(any)

assert any2.to_dict() == {
"@type": "type.googleapis.com/google.protobuf.Any",
"value": {"@type": "type.googleapis.com/any.Person", "firstName": "John", "lastName": "Smith"},
}

assert Any.from_dict(any2.to_dict()) == any2
assert Any.parse(bytes(any2)) == any2
2 changes: 0 additions & 2 deletions betterproto2/tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,11 @@ def reset_sys_path():
["googletypes_struct/googletypes_struct.json"],
"googletypes_struct.googletypes_struct",
"googletypes_struct_reference.googletypes_struct_pb2",
xfail=True,
),
TestCase(
["googletypes_value/googletypes_value.json"],
"googletypes_value.googletypes_value",
"googletypes_value_reference.googletypes_value_pb2",
xfail=True,
),
TestCase(["int32/int32.json"], "int32.int32", "int32_reference.int32_pb2"),
TestCase(["map/map.json"], "map.map", "map_reference.map_pb2"),
Expand Down
28 changes: 9 additions & 19 deletions betterproto2/tests/test_pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,15 @@ def complex_msg():
fe=Fe(abc="1"),
nested_data=NestedData(
struct_foo={
"foo": google.Struct(
fields={
"hello": google.Value(list_value=google.ListValue(values=[google.Value(string_value="world")]))
"foo": google.Struct.from_dict(
{
"hello": [["world"]],
}
),
},
map_str_any_bar={
"key": google.Any(value=b"value"),
},
),
mapping={
"message": google.Any(value=bytes(Fi(abc="hi"))),
"string": google.Any(value=b"howdy"),
"message": google.Any.pack(Fi(abc="hi")),
},
)

Expand All @@ -40,9 +36,8 @@ def test_pickling_complex_message():
assert msg == deser
assert msg.fe.abc == "1"
assert msg.is_set("fi") is not True
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
assert msg.mapping["string"].value.decode() == "howdy"
assert msg.nested_data.struct_foo["foo"].fields["hello"].list_value.values[0].string_value == "world"
assert msg.mapping["message"] == google.Any.pack(Fi(abc="hi"))
assert msg.nested_data.struct_foo["foo"].to_dict()["hello"][0][0] == "world"


def test_recursive_message_defaults():
Expand All @@ -51,11 +46,7 @@ def test_recursive_message_defaults():
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
msg = unpickled(msg)

# set values are as expected
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))

# lazy initialized works modifies the message
assert msg != RecursiveMessage(name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude"))
msg.child = RecursiveMessage(child=RecursiveMessage(name="jude"))
assert msg == RecursiveMessage(
name="bob",
Expand Down Expand Up @@ -104,7 +95,6 @@ def use_cache():
msg = use_cache()
assert use_cache.calls == 1 # The message is only ever built once
assert msg.fe.abc == "1"
assert msg.is_set("fi") is not True
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
assert msg.mapping["string"].value.decode() == "howdy"
assert msg.nested_data.struct_foo["foo"].fields["hello"].list_value.values[0].string_value == "world"
assert not msg.is_set("fi")
assert msg.mapping["message"] == google.Any.pack(Fi(abc="hi"))
assert msg.nested_data.struct_foo["foo"].to_dict()["hello"][0][0] == "world"
50 changes: 50 additions & 0 deletions betterproto2/tests/test_struct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
def test_struct_to_dict():
from tests.outputs.google.google.protobuf import Struct

struct = Struct.from_dict(
{
"null_field": None,
"number_field": 12,
"string_field": "test",
"bool_field": True,
"struct_field": {"x": "abc"},
"list_field": [42, False, None],
}
)

assert struct.to_dict() == {
"null_field": None,
"number_field": 12,
"string_field": "test",
"bool_field": True,
"struct_field": {"x": "abc"},
"list_field": [42, False, None],
}

assert Struct.from_dict(struct.to_dict()) == struct


def test_listvalue_to_dict():
from tests.outputs.google.google.protobuf import ListValue

list_value = ListValue.from_dict([42, False, {}])

assert list_value.to_dict() == [42, False, {}]
assert ListValue.from_dict(list_value.to_dict()) == list_value


def test_nullvalue():
from tests.outputs.google.google.protobuf import NullValue, Value

null_value = NullValue.NULL_VALUE

assert bytes(Value(null_value=null_value)) == b"\x08\x00"


def test_value_to_dict():
from tests.outputs.google.google.protobuf import Value

value = Value.from_dict([1, 2, False])

assert value.to_dict() == [1, 2, False]
assert Value.from_dict(value.to_dict()) == value
6 changes: 5 additions & 1 deletion betterproto2_compiler/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ keywords = [
requires-python = ">=3.10,<4.0"
dependencies = [
# TODO use the version from the current repo?
"betterproto2[grpclib]>=0.7.0,<0.8",
# "betterproto2[grpclib]>=0.7.0,<0.8",
"betterproto2[grpclib]",
"ruff~=0.9.3",
"jinja2>=3.0.3",
"typing-extensions>=4.7.1,<5",
"strenum>=0.4.15,<0.5 ; python_version == '3.10'",
]

[tool.uv.sources]
"betterproto2" = { path = "../betterproto2" }

[project.urls]
Documentation = "https://betterproto.github.io/python-betterproto2/"
Repository = "https://github.com/betterproto/python-betterproto2"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
UInt32Value,
UInt64Value,
)
from .struct import ListValue, Struct, Value
from .timestamp import Timestamp

# For each (package, message name), lists the methods that should be added to the message definition.
# The source code of the method is read from the `known_types` folder. If imports are needed, they can be directly added
# to the template file: they will automatically be removed if not necessary.
KNOWN_METHODS: dict[tuple[str, str], list[Callable]] = {
("google.protobuf", "Any"): [Any.pack, Any.unpack, Any.to_dict],
("google.protobuf", "Any"): [Any.pack, Any.unpack, Any.to_dict, Any.from_dict],
("google.protobuf", "Timestamp"): [
Timestamp.from_datetime,
Timestamp.to_datetime,
Expand Down Expand Up @@ -92,6 +93,18 @@
BytesValue.from_wrapped,
BytesValue.to_wrapped,
],
("google.protobuf", "Struct"): [
Struct.from_dict,
Struct.to_dict,
],
("google.protobuf", "ListValue"): [
ListValue.from_dict,
ListValue.to_dict,
],
("google.protobuf", "Value"): [
Value.from_dict,
Value.to_dict,
],
}

# A wrapped type is the type of a message that is automatically replaced by a known Python type.
Expand Down
27 changes: 24 additions & 3 deletions betterproto2_compiler/src/betterproto2_compiler/known_types/any.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@


class Any(VanillaAny):
def pack(self, message: betterproto2.Message, message_pool: "betterproto2.MessagePool | None" = None) -> None:
@classmethod
def pack(cls, message: betterproto2.Message, message_pool: "betterproto2.MessagePool | None" = None) -> "Any":
"""
Pack the given message in the `Any` object.

Expand All @@ -17,8 +18,10 @@ def pack(self, message: betterproto2.Message, message_pool: "betterproto2.Messag
"""
message_pool = message_pool or default_message_pool

self.type_url = message_pool.type_to_url[type(message)]
self.value = bytes(message)
type_url = message_pool.type_to_url[type(message)]
value = bytes(message)

return cls(type_url=type_url, value=value)

def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message | None:
"""
Expand Down Expand Up @@ -54,3 +57,21 @@ def to_dict(self, **kwargs) -> dict[str, typing.Any]:
output["value"] = value.to_dict(**kwargs)

return output

# TODO typing
@classmethod
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
value = dict(value) # Make a copy

type_url = value.pop("@type", None)
msg_cls = default_message_pool.url_to_type.get(type_url, None)

if not msg_cls:
raise TypeError(f"Can't unpack unregistered type: {type_url}")

if not msg_cls.to_dict == betterproto2.Message.to_dict:
value = value["value"]

return cls(
type_url=type_url, value=bytes(msg_cls.from_dict(value, ignore_unknown_fields=ignore_unknown_fields))
)
Loading
Loading