diff --git a/proto/message.py b/proto/message.py index de9280da..13a7635c 100644 --- a/proto/message.py +++ b/proto/message.py @@ -527,34 +527,72 @@ def __init__( # coerced. marshal = self._meta.marshal for key, value in mapping.items(): + (key, pb_type) = self._get_pb_type_from_key(key) + if pb_type is None: + if ignore_unknown_fields: + continue + + raise ValueError( + "Unknown field for {}: {}".format(self.__class__.__name__, key) + ) + try: - pb_type = self._meta.fields[key].pb_type - except KeyError: + pb_value = marshal.to_proto(pb_type, value) + except ValueError: # Underscores may be appended to field names # that collide with python or proto-plus keywords. # In case a key only exists with a `_` suffix, coerce the key - # to include the `_` suffix. Is not possible to + # to include the `_` suffix. It's not possible to # natively define the same field with a trailing underscore in protobuf. # See related issue # https://github.com/googleapis/python-api-core/issues/227 - if f"{key}_" in self._meta.fields: - key = f"{key}_" - pb_type = self._meta.fields[key].pb_type - else: - if ignore_unknown_fields: - continue - - raise ValueError( - "Unknown field for {}: {}".format(self.__class__.__name__, key) - ) - - pb_value = marshal.to_proto(pb_type, value) + if isinstance(value, dict): + keys_to_update = [ + item + for item in value + if not hasattr(pb_type, item) and hasattr(pb_type, f"{item}_") + ] + for item in keys_to_update: + value[f"{item}_"] = value.pop(item) + + pb_value = marshal.to_proto(pb_type, value) + if pb_value is not None: params[key] = pb_value # Create the internal protocol buffer. super().__setattr__("_pb", self._meta.pb(**params)) + def _get_pb_type_from_key(self, key): + """Given a key, return the corresponding pb_type. + + Args: + key(str): The name of the field. + + Returns: + A tuple containing a key and pb_type. The pb_type will be + the composite type of the field, or the primitive type if a primitive. + If no corresponding field exists, return None. + """ + + pb_type = None + + try: + pb_type = self._meta.fields[key].pb_type + except KeyError: + # Underscores may be appended to field names + # that collide with python or proto-plus keywords. + # In case a key only exists with a `_` suffix, coerce the key + # to include the `_` suffix. It's not possible to + # natively define the same field with a trailing underscore in protobuf. + # See related issue + # https://github.com/googleapis/python-api-core/issues/227 + if f"{key}_" in self._meta.fields: + key = f"{key}_" + pb_type = self._meta.fields[key].pb_type + + return (key, pb_type) + def __dir__(self): desc = type(self).pb().DESCRIPTOR names = {f_name for f_name in self._meta.fields.keys()} @@ -664,13 +702,14 @@ def __getattr__(self, key): their Python equivalents. See the ``marshal`` module for more details. """ - try: - pb_type = self._meta.fields[key].pb_type - pb_value = getattr(self._pb, key) - marshal = self._meta.marshal - return marshal.to_python(pb_type, pb_value, absent=key not in self) - except KeyError as ex: - raise AttributeError(str(ex)) + (key, pb_type) = self._get_pb_type_from_key(key) + if pb_type is None: + raise AttributeError( + "Unknown field for {}: {}".format(self.__class__.__name__, key) + ) + pb_value = getattr(self._pb, key) + marshal = self._meta.marshal + return marshal.to_python(pb_type, pb_value, absent=key not in self) def __ne__(self, other): """Return True if the messages are unequal, False otherwise.""" @@ -688,7 +727,12 @@ def __setattr__(self, key, value): if key[0] == "_": return super().__setattr__(key, value) marshal = self._meta.marshal - pb_type = self._meta.fields[key].pb_type + (key, pb_type) = self._get_pb_type_from_key(key) + if pb_type is None: + raise AttributeError( + "Unknown field for {}: {}".format(self.__class__.__name__, key) + ) + pb_value = marshal.to_proto(pb_type, value) # Clear the existing field. diff --git a/tests/test_fields_mitigate_collision.py b/tests/test_fields_mitigate_collision.py index 0dca71df..117af48a 100644 --- a/tests/test_fields_mitigate_collision.py +++ b/tests/test_fields_mitigate_collision.py @@ -13,12 +13,12 @@ # limitations under the License. import proto - +import pytest # Underscores may be appended to field names # that collide with python or proto-plus keywords. # In case a key only exists with a `_` suffix, coerce the key -# to include the `_` suffix. Is not possible to +# to include the `_` suffix. It's not possible to # natively define the same field with a trailing underscore in protobuf. # See related issue # https://github.com/googleapis/python-api-core/issues/227 @@ -27,6 +27,9 @@ class TestMessage(proto.Message): spam_ = proto.Field(proto.STRING, number=1) eggs = proto.Field(proto.STRING, number=2) + class TextStream(proto.Message): + text_stream = proto.Field(TestMessage, number=1) + obj = TestMessage(spam_="has_spam") obj.eggs = "has_eggs" assert obj.spam_ == "has_spam" @@ -34,3 +37,45 @@ class TestMessage(proto.Message): # Test that `spam` is coerced to `spam_` modified_obj = TestMessage({"spam": "has_spam", "eggs": "has_eggs"}) assert modified_obj.spam_ == "has_spam" + + # Test get and set + modified_obj.spam = "no_spam" + assert modified_obj.spam == "no_spam" + + modified_obj.spam_ = "yes_spam" + assert modified_obj.spam_ == "yes_spam" + + modified_obj.spam = "maybe_spam" + assert modified_obj.spam_ == "maybe_spam" + + modified_obj.spam_ = "maybe_not_spam" + assert modified_obj.spam == "maybe_not_spam" + + # Try nested values + modified_obj = TextStream( + text_stream=TestMessage({"spam": "has_spam", "eggs": "has_eggs"}) + ) + assert modified_obj.text_stream.spam_ == "has_spam" + + # Test get and set for nested values + modified_obj.text_stream.spam = "no_spam" + assert modified_obj.text_stream.spam == "no_spam" + + modified_obj.text_stream.spam_ = "yes_spam" + assert modified_obj.text_stream.spam_ == "yes_spam" + + modified_obj.text_stream.spam = "maybe_spam" + assert modified_obj.text_stream.spam_ == "maybe_spam" + + modified_obj.text_stream.spam_ = "maybe_not_spam" + assert modified_obj.text_stream.spam == "maybe_not_spam" + + with pytest.raises(AttributeError): + assert modified_obj.text_stream.attribute_does_not_exist == "n/a" + + with pytest.raises(AttributeError): + modified_obj.text_stream.attribute_does_not_exist = "n/a" + + # Try using dict + modified_obj = TextStream(text_stream={"spam": "has_spam", "eggs": "has_eggs"}) + assert modified_obj.text_stream.spam_ == "has_spam"