Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize default values in oneofs when calling to_dict() or to_json() #110

Merged
merged 12 commits into from
Jul 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 63 additions & 20 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,18 +583,20 @@ def __bytes__(self) -> bytes:
# Being selected in a a group means this field is the one that is
# currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value.
selected_in_group = False
if meta.group and self._group_current[meta.group] == field_name:
selected_in_group = True
selected_in_group = (
meta.group and self._group_current[meta.group] == field_name
)

serialize_empty = False
if isinstance(value, Message) and value._serialized_on_wire:
# Empty messages can still be sent on the wire if they were
# set (or received empty).
serialize_empty = True
# Empty messages can still be sent on the wire if they were
# set (or received empty).
serialize_empty = isinstance(value, Message) and value._serialized_on_wire

include_default_value_for_oneof = self._include_default_value_for_oneof(
field_name=field_name, meta=meta
)

if value == self._get_field_default(field_name) and not (
selected_in_group or serialize_empty
selected_in_group or serialize_empty or include_default_value_for_oneof
):
# Default (zero) values are not serialized. Two exceptions are
# if this is the selected oneof item or if we know we have to
Expand Down Expand Up @@ -623,6 +625,17 @@ def __bytes__(self) -> bytes:
sv = _serialize_single(2, meta.map_types[1], v)
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else:
# If we have an empty string and we're including the default value for
# a oneof, make sure we serialize it. This ensures that the byte string
# output isn't simply an empty string. This also ensures that round trip
# serialization will keep `which_one_of` calls consistent.
if (
isinstance(value, str)
and value == ""
and include_default_value_for_oneof
):
serialize_empty = True

output += _serialize_single(
meta.number,
meta.proto_type,
Expand Down Expand Up @@ -726,6 +739,13 @@ def _postprocess_single(

return value

def _include_default_value_for_oneof(
self, field_name: str, meta: FieldMetadata
) -> bool:
return (
meta.group is not None and self._group_current.get(meta.group) == field_name
)

def parse(self: T, data: bytes) -> T:
"""
Parse the binary encoded Protobuf into this message instance. This
Expand Down Expand Up @@ -804,10 +824,22 @@ def to_dict(
cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
if value != DATETIME_ZERO or include_default_values:
if (
value != DATETIME_ZERO
or include_default_values
or self._include_default_value_for_oneof(
field_name=field_name, meta=meta
)
):
output[cased_name] = _Timestamp.timestamp_to_json(value)
elif isinstance(value, timedelta):
if value != timedelta(0) or include_default_values:
if (
value != timedelta(0)
or include_default_values
or self._include_default_value_for_oneof(
field_name=field_name, meta=meta
)
):
output[cased_name] = _Duration.delta_to_json(value)
elif meta.wraps:
if value is not None or include_default_values:
Expand All @@ -817,19 +849,28 @@ def to_dict(
value = [i.to_dict(casing, include_default_values) for i in value]
if value or include_default_values:
output[cased_name] = value
else:
if value._serialized_on_wire or include_default_values:
output[cased_name] = value.to_dict(
casing, include_default_values
)
elif meta.proto_type == TYPE_MAP:
elif (
value._serialized_on_wire
or include_default_values
or self._include_default_value_for_oneof(
field_name=field_name, meta=meta
)
):
output[cased_name] = value.to_dict(casing, include_default_values,)
elif meta.proto_type == "map":
for k in value:
if hasattr(value[k], "to_dict"):
value[k] = value[k].to_dict(casing, include_default_values)

if value or include_default_values:
output[cased_name] = value
elif value != self._get_field_default(field_name) or include_default_values:
elif (
value != self._get_field_default(field_name)
or include_default_values
or self._include_default_value_for_oneof(
field_name=field_name, meta=meta
)
):
if meta.proto_type in INT_64_TYPES:
if field_is_repeated:
output[cased_name] = [str(n) for n in value]
Expand Down Expand Up @@ -888,6 +929,8 @@ def from_dict(self: T, value: dict) -> T:
elif meta.wraps:
setattr(self, field_name, value[key])
else:
# NOTE: `from_dict` mutates the underlying message, so no
# assignment here is necessary.
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name)
Expand All @@ -913,8 +956,8 @@ def from_dict(self: T, value: dict) -> T:
elif isinstance(v, str):
v = enum_cls.from_string(v)

if v is not None:
setattr(self, field_name, v)
if v is not None:
setattr(self, field_name, v)
return self

def to_json(self, indent: Union[None, int, str] = None) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
syntax = "proto3";

message Foo{
int64 bar = 1;
}

message Test{
oneof group{
string string = 1;
int64 integer = 2;
Foo foo = 3;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest

from google.protobuf import json_format
import betterproto
from tests.output_betterproto.google_impl_behavior_equivalence import (
Test,
Foo,
)
from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
Test as ReferenceTest,
Foo as ReferenceFoo,
)


def test_oneof_serializes_similar_to_google_oneof():

tests = [
(Test(string="abc"), ReferenceTest(string="abc")),
(Test(integer=2), ReferenceTest(integer=2)),
(Test(foo=Foo(bar=1)), ReferenceTest(foo=ReferenceFoo(bar=1))),
# Default values should also behave the same within oneofs
(Test(string=""), ReferenceTest(string="")),
(Test(integer=0), ReferenceTest(integer=0)),
(Test(foo=Foo(bar=0)), ReferenceTest(foo=ReferenceFoo(bar=0))),
]
for message, message_reference in tests:
# NOTE: As of July 2020, MessageToJson inserts newlines in the output string so,
# just compare dicts
assert message.to_dict() == json_format.MessageToDict(message_reference)


def test_bytes_are_the_same_for_oneof():

message = Test(string="")
message_reference = ReferenceTest(string="")

message_bytes = bytes(message)
message_reference_bytes = message_reference.SerializeToString()

assert message_bytes == message_reference_bytes

message2 = Test().parse(message_reference_bytes)
message_reference2 = ReferenceTest()
message_reference2.ParseFromString(message_reference_bytes)

assert message == message2
assert message_reference == message_reference2

# None of these fields were explicitly set BUT they should not actually be null
# themselves
assert isinstance(message.foo, Foo)
assert isinstance(message2.foo, Foo)

assert isinstance(message_reference.foo, ReferenceFoo)
assert isinstance(message_reference2.foo, ReferenceFoo)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
syntax = "proto3";

import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";

message Message{
int64 value = 1;
}

message NestedMessage{
int64 id = 1;
oneof value_type{
Message wrapped_message_value = 2;
}
}

message Test{
oneof value_type {
bool bool_value = 1;
int64 int64_value = 2;
google.protobuf.Timestamp timestamp_value = 3;
google.protobuf.Duration duration_value = 4;
Message wrapped_message_value = 5;
NestedMessage wrapped_nested_message_value = 6;
google.protobuf.BoolValue wrapped_bool_value = 7;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
import datetime

import betterproto
from tests.output_betterproto.oneof_default_value_serialization import (
Test,
Message,
NestedMessage,
)


def assert_round_trip_serialization_works(message: Test) -> None:
assert betterproto.which_one_of(message, "value_type") == betterproto.which_one_of(
Test().from_json(message.to_json()), "value_type"
)


def test_oneof_default_value_serialization_works_for_all_values():
"""
Serialization from message with oneof set to default -> JSON -> message should keep
default value field intact.
"""

test_cases = [
Test(bool_value=False),
Test(int64_value=0),
Test(
timestamp_value=datetime.datetime(
year=1970,
month=1,
day=1,
hour=0,
minute=0,
tzinfo=datetime.timezone.utc,
)
),
Test(duration_value=datetime.timedelta(0)),
Test(wrapped_message_value=Message(value=0)),
# NOTE: Do NOT use betterproto.BoolValue here, it will cause JSON serialization
# errors.
# TODO: Do we want to allow use of BoolValue directly within a wrapped field or
# should we simply hard fail here?
Test(wrapped_bool_value=False),
]
for message in test_cases:
assert_round_trip_serialization_works(message)


def test_oneof_no_default_values_passed():
message = Test()
assert (
betterproto.which_one_of(message, "value_type")
== betterproto.which_one_of(Test().from_json(message.to_json()), "value_type")
== ("", None)
)


def test_oneof_nested_oneof_messages_are_serialized_with_defaults():
"""
Nested messages with oneofs should also be handled
"""
message = Test(
wrapped_nested_message_value=NestedMessage(
id=0, wrapped_message_value=Message(value=0)
)
)
assert (
betterproto.which_one_of(message, "value_type")
== betterproto.which_one_of(Test().from_json(message.to_json()), "value_type")
== (
"wrapped_nested_message_value",
NestedMessage(id=0, wrapped_message_value=Message(value=0)),
)
)
16 changes: 9 additions & 7 deletions tests/inputs/oneof_enum/test_oneof_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,36 @@
from tests.util import get_test_case_json_data


@pytest.mark.xfail
def test_which_one_of_returns_enum_with_default_value():
"""
returns first field when it is enum and set with default value
"""
message = Test()
message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json"))
assert message.move is None

assert message.move == Move(
x=0, y=0
) # Proto3 will default this as there is no null
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS)


@pytest.mark.xfail
def test_which_one_of_returns_enum_with_non_default_value():
"""
returns first field when it is enum and set with non default value
"""
message = Test()
message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json"))
assert message.move is None
assert message.signal == Signal.PASS
assert message.move == Move(
x=0, y=0
) # Proto3 will default this as there is no null
assert message.signal == Signal.RESIGN
assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN)


@pytest.mark.xfail
def test_which_one_of_returns_second_field_when_set():
message = Test()
message.from_json(get_test_case_json_data("oneof_enum"))
assert message.move == Move(x=2, y=3)
assert message.signal == 0
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))
Loading