Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 67 additions & 23 deletions proto/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,34 +527,72 @@ def __init__(
# coerced.
marshal = self._meta.marshal
for key, value in mapping.items():
(key, pb_type) = self._get_pb_type_from_key(key)
if pb_type is None:
if ignore_unknown_fields:
continue

raise ValueError(
"Unknown field for {}: {}".format(self.__class__.__name__, key)
)

try:
pb_type = self._meta.fields[key].pb_type
except KeyError:
pb_value = marshal.to_proto(pb_type, value)
except ValueError:
# Underscores may be appended to field names
# that collide with python or proto-plus keywords.
# In case a key only exists with a `_` suffix, coerce the key
# to include the `_` suffix. Is not possible to
# to include the `_` suffix. It's not possible to
# natively define the same field with a trailing underscore in protobuf.
# See related issue
# https://github.com/googleapis/python-api-core/issues/227
if f"{key}_" in self._meta.fields:
key = f"{key}_"
pb_type = self._meta.fields[key].pb_type
else:
if ignore_unknown_fields:
continue

raise ValueError(
"Unknown field for {}: {}".format(self.__class__.__name__, key)
)

pb_value = marshal.to_proto(pb_type, value)
if isinstance(value, dict):
keys_to_update = [
item
for item in value
if not hasattr(pb_type, item) and hasattr(pb_type, f"{item}_")
]
for item in keys_to_update:
value[f"{item}_"] = value.pop(item)

pb_value = marshal.to_proto(pb_type, value)

if pb_value is not None:
params[key] = pb_value

# Create the internal protocol buffer.
super().__setattr__("_pb", self._meta.pb(**params))

def _get_pb_type_from_key(self, key):
"""Given a key, return the corresponding pb_type.

Args:
key(str): The name of the field.

Returns:
A tuple containing a key and pb_type. The pb_type will be
the composite type of the field, or the primitive type if a primitive.
If no corresponding field exists, return None.
"""

pb_type = None

try:
pb_type = self._meta.fields[key].pb_type
except KeyError:
# Underscores may be appended to field names
# that collide with python or proto-plus keywords.
# In case a key only exists with a `_` suffix, coerce the key
# to include the `_` suffix. It's not possible to
# natively define the same field with a trailing underscore in protobuf.
# See related issue
# https://github.com/googleapis/python-api-core/issues/227
if f"{key}_" in self._meta.fields:
key = f"{key}_"
pb_type = self._meta.fields[key].pb_type

return (key, pb_type)

def __dir__(self):
desc = type(self).pb().DESCRIPTOR
names = {f_name for f_name in self._meta.fields.keys()}
Expand Down Expand Up @@ -664,13 +702,14 @@ def __getattr__(self, key):
their Python equivalents. See the ``marshal`` module for
more details.
"""
try:
pb_type = self._meta.fields[key].pb_type
pb_value = getattr(self._pb, key)
marshal = self._meta.marshal
return marshal.to_python(pb_type, pb_value, absent=key not in self)
except KeyError as ex:
raise AttributeError(str(ex))
(key, pb_type) = self._get_pb_type_from_key(key)
if pb_type is None:
raise AttributeError(
"Unknown field for {}: {}".format(self.__class__.__name__, key)
)
pb_value = getattr(self._pb, key)
marshal = self._meta.marshal
return marshal.to_python(pb_type, pb_value, absent=key not in self)

def __ne__(self, other):
"""Return True if the messages are unequal, False otherwise."""
Expand All @@ -688,7 +727,12 @@ def __setattr__(self, key, value):
if key[0] == "_":
return super().__setattr__(key, value)
marshal = self._meta.marshal
pb_type = self._meta.fields[key].pb_type
(key, pb_type) = self._get_pb_type_from_key(key)
if pb_type is None:
raise AttributeError(
"Unknown field for {}: {}".format(self.__class__.__name__, key)
)

pb_value = marshal.to_proto(pb_type, value)

# Clear the existing field.
Expand Down
49 changes: 47 additions & 2 deletions tests/test_fields_mitigate_collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

import proto

import pytest

# Underscores may be appended to field names
# that collide with python or proto-plus keywords.
# In case a key only exists with a `_` suffix, coerce the key
# to include the `_` suffix. Is not possible to
# to include the `_` suffix. It's not possible to
# natively define the same field with a trailing underscore in protobuf.
# See related issue
# https://github.com/googleapis/python-api-core/issues/227
Expand All @@ -27,10 +27,55 @@ class TestMessage(proto.Message):
spam_ = proto.Field(proto.STRING, number=1)
eggs = proto.Field(proto.STRING, number=2)

class TextStream(proto.Message):
text_stream = proto.Field(TestMessage, number=1)

obj = TestMessage(spam_="has_spam")
obj.eggs = "has_eggs"
assert obj.spam_ == "has_spam"

# Test that `spam` is coerced to `spam_`
modified_obj = TestMessage({"spam": "has_spam", "eggs": "has_eggs"})
assert modified_obj.spam_ == "has_spam"

# Test get and set
modified_obj.spam = "no_spam"
assert modified_obj.spam == "no_spam"

modified_obj.spam_ = "yes_spam"
assert modified_obj.spam_ == "yes_spam"

modified_obj.spam = "maybe_spam"
assert modified_obj.spam_ == "maybe_spam"

modified_obj.spam_ = "maybe_not_spam"
assert modified_obj.spam == "maybe_not_spam"

# Try nested values
modified_obj = TextStream(
text_stream=TestMessage({"spam": "has_spam", "eggs": "has_eggs"})
)
assert modified_obj.text_stream.spam_ == "has_spam"

# Test get and set for nested values
modified_obj.text_stream.spam = "no_spam"
assert modified_obj.text_stream.spam == "no_spam"

modified_obj.text_stream.spam_ = "yes_spam"
assert modified_obj.text_stream.spam_ == "yes_spam"

modified_obj.text_stream.spam = "maybe_spam"
assert modified_obj.text_stream.spam_ == "maybe_spam"

modified_obj.text_stream.spam_ = "maybe_not_spam"
assert modified_obj.text_stream.spam == "maybe_not_spam"

with pytest.raises(AttributeError):
assert modified_obj.text_stream.attribute_does_not_exist == "n/a"

with pytest.raises(AttributeError):
modified_obj.text_stream.attribute_does_not_exist = "n/a"

# Try using dict
modified_obj = TextStream(text_stream={"spam": "has_spam", "eggs": "has_eggs"})
assert modified_obj.text_stream.spam_ == "has_spam"