diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/types/_message.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/types/_message.py.j2 index a8119827b8..15bf3ea4ab 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/types/_message.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/types/_message.py.j2 @@ -43,6 +43,7 @@ class {{ message.name }}({{ p }}.Message): {% else -%} {{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field( {{- p }}.{{ field.proto_type }}, number={{ field.number }} + {% if field.oneof %}, oneof='{{ field.oneof }}'{% endif %} {%- if field.enum or field.message %}, {{ field.proto_type.lower() }}={{ field.type.ident.rel(message.ident) }}, {% endif %}) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index cc9b9cf1f0..df3e1daa8e 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -34,6 +34,7 @@ from gapic.schema import wrappers from gapic.schema import naming as api_naming from gapic.utils import cached_property +from gapic.utils import nth from gapic.utils import to_snake_case from gapic.utils import RESERVED_NAMES @@ -556,14 +557,42 @@ def _load_children(self, answer[wrapped.name] = wrapped return answer + def _get_oneofs(self, + oneof_pbs: Sequence[descriptor_pb2.OneofDescriptorProto], + address: metadata.Address, path: Tuple[int, ...], + ) -> Dict[str, wrappers.Oneof]: + """Return a dictionary of wrapped oneofs for the given message. + + Args: + oneof_fields (Sequence[~.descriptor_pb2.OneofDescriptorProto]): A + sequence of protobuf field objects. + address (~.metadata.Address): An address object denoting the + location of these oneofs. + path (Tuple[int]): The source location path thus far, as + understood by ``SourceCodeInfo.Location``. + + Returns: + Mapping[str, ~.wrappers.Oneof]: A ordered mapping of + :class:`~.wrappers.Oneof` objects. + """ + # Iterate over the oneofs and collect them into a dictionary. + answer = collections.OrderedDict( + (oneof_pb.name, wrappers.Oneof(oneof_pb=oneof_pb)) + for i, oneof_pb in enumerate(oneof_pbs) + ) + + # Done; return the answer. + return answer + def _get_fields(self, field_pbs: Sequence[descriptor_pb2.FieldDescriptorProto], address: metadata.Address, path: Tuple[int, ...], + oneofs: Optional[Dict[str, wrappers.Oneof]] = None ) -> Dict[str, wrappers.Field]: """Return a dictionary of wrapped fields for the given message. Args: - fields (Sequence[~.descriptor_pb2.FieldDescriptorProto]): A + field_pbs (Sequence[~.descriptor_pb2.FieldDescriptorProto]): A sequence of protobuf field objects. address (~.metadata.Address): An address object denoting the location of these fields. @@ -585,7 +614,13 @@ def _get_fields(self, # first) and this will be None. This case is addressed in the # `_load_message` method. answer: Dict[str, wrappers.Field] = collections.OrderedDict() - for field_pb, i in zip(field_pbs, range(0, sys.maxsize)): + for i, field_pb in enumerate(field_pbs): + is_oneof = oneofs and field_pb.oneof_index > 0 + oneof_name = nth( + (oneofs or {}).keys(), + field_pb.oneof_index + ) if is_oneof else None + answer[field_pb.name] = wrappers.Field( field_pb=field_pb, enum=self.api_enums.get(field_pb.type_name.lstrip('.')), @@ -594,6 +629,7 @@ def _get_fields(self, address=address.child(field_pb.name, path + (i,)), documentation=self.docs.get(path + (i,), self.EMPTY), ), + oneof=oneof_name, ) # Done; return the answer. @@ -779,19 +815,25 @@ def _load_message(self, loader=self._load_message, path=path + (3,), ) - # self._load_children(message.oneof_decl, loader=self._load_field, - # address=nested_addr, info=info.get(8, {})) + + oneofs = self._get_oneofs( + message_pb.oneof_decl, + address=address, + path=path + (7,), + ) # Create a dictionary of all the fields for this message. fields = self._get_fields( message_pb.field, address=address, path=path + (2,), + oneofs=oneofs, ) fields.update(self._get_fields( message_pb.extension, address=address, path=path + (6,), + oneofs=oneofs, )) # Create a message correspoding to this descriptor. @@ -804,6 +846,7 @@ def _load_message(self, address=address, documentation=self.docs.get(path, self.EMPTY), ), + oneofs=oneofs, ) return self.proto_messages[address.proto] diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 6f7e041f16..1061620378 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -54,6 +54,7 @@ class Field: meta: metadata.Metadata = dataclasses.field( default_factory=metadata.Metadata, ) + oneof: Optional[str] = None def __getattr__(self, name): return getattr(self.field_pb, name) @@ -206,6 +207,15 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Field': ) +@dataclasses.dataclass(frozen=True) +class Oneof: + """Description of a field.""" + oneof_pb: descriptor_pb2.OneofDescriptorProto + + def __getattr__(self, name): + return getattr(self.oneof_pb, name) + + @dataclasses.dataclass(frozen=True) class MessageType: """Description of a message (defined with the ``message`` keyword).""" @@ -220,6 +230,7 @@ class MessageType: meta: metadata.Metadata = dataclasses.field( default_factory=metadata.Metadata, ) + oneofs: Optional[Mapping[str, 'Oneof']] = None def __getattr__(self, name): return getattr(self.message_pb, name) diff --git a/gapic/templates/%namespace/%name_%version/%sub/types/_message.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/types/_message.py.j2 index e9586108c0..5a1eb5fcc6 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/types/_message.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/types/_message.py.j2 @@ -38,14 +38,15 @@ class {{ message.name }}({{ p }}.Message): {{- p }}.{{ key_field.proto_type }}, {{ p }}.{{ value_field.proto_type }}, number={{ field.number }} {%- if value_field.enum or value_field.message %}, {{ value_field.proto_type.lower() }}={{ value_field.type.ident.rel(message.ident) }}, - {% endif %}) + {% endif %}) {# enum or message#} {% endwith -%} - {% else -%} + {% else -%} {# field.map #} {{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field( {{- p }}.{{ field.proto_type }}, number={{ field.number }} + {% if field.oneof %}, oneof='{{ field.oneof }}'{% endif %} {%- if field.enum or field.message %}, {{ field.proto_type.lower() }}={{ field.type.ident.rel(message.ident) }}, - {% endif %}) - {% endif -%} - {% endfor -%} + {% endif %}) {# enum or message #} + {% endif -%} {# field.map #} + {% endfor -%} {# for field in message.fields.values#} {{ '\n\n' }} diff --git a/gapic/templates/noxfile.py.j2 b/gapic/templates/noxfile.py.j2 index d31a325e2f..5fde488f00 100644 --- a/gapic/templates/noxfile.py.j2 +++ b/gapic/templates/noxfile.py.j2 @@ -20,7 +20,7 @@ def unit(session): '--cov-config=.coveragerc', '--cov-report=term', '--cov-report=html', - os.path.join('tests', 'unit', '{{ api.naming.versioned_module_name }}'), + os.path.join('tests', 'unit',) ) diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index c21846a4ac..f561d927e1 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -288,9 +288,9 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'): call.return_value = iter([{{ method.output.ident }}()]) {% else -%} call.return_value = {{ method.output.ident }}( - {%- for field in method.output.fields.values() | rejectattr('message') %} + {%- for field in method.output.fields.values() | rejectattr('message')%}{% if not (field.oneof and not field.proto3_optional) %} {{ field.name }}={{ field.mock_value }}, - {%- endfor %} + {% endif %}{%- endfor %} ) {% endif -%} {% if method.client_streaming %} @@ -318,7 +318,7 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'): assert isinstance(message, {{ method.output.ident }}) {% else -%} assert isinstance(response, {{ method.client_output.ident }}) - {% for field in method.output.fields.values() | rejectattr('message') -%} + {% for field in method.output.fields.values() | rejectattr('message') -%}{% if not (field.oneof and not field.proto3_optional) %} {% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#} assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6) {% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #} @@ -326,6 +326,7 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'): {% else -%} assert response.{{ field.name }} == {{ field.mock_value }} {% endif -%} + {% endif -%} {# end oneof/optional #} {% endfor %} {% endif %} @@ -368,8 +369,9 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio {%- else -%} grpc_helpers_async.FakeStreamUnaryCall {%- endif -%}({{ method.output.ident }}( - {%- for field in method.output.fields.values() | rejectattr('message') %} + {%- for field in method.output.fields.values() | rejectattr('message') %}{% if not (field.oneof and not field.proto3_optional) %} {{ field.name }}={{ field.mock_value }}, + {%- endif %} {%- endfor %} )) {% endif -%} @@ -400,7 +402,7 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio assert isinstance(message, {{ method.output.ident }}) {% else -%} assert isinstance(response, {{ method.client_output_async.ident }}) - {% for field in method.output.fields.values() | rejectattr('message') -%} + {% for field in method.output.fields.values() | rejectattr('message') -%}{% if not (field.oneof and not field.proto3_optional) %} {% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#} assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6) {% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #} @@ -408,6 +410,7 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio {% else -%} assert response.{{ field.name }} == {{ field.mock_value }} {% endif -%} + {% endif -%} {# oneof/optional #} {% endfor %} {% endif %} diff --git a/gapic/utils/__init__.py b/gapic/utils/__init__.py index 315c575e2b..905fcbdec2 100644 --- a/gapic/utils/__init__.py +++ b/gapic/utils/__init__.py @@ -15,6 +15,7 @@ from gapic.utils.cache import cached_property from gapic.utils.case import to_snake_case from gapic.utils.code import empty +from gapic.utils.code import nth from gapic.utils.code import partition from gapic.utils.doc import doc from gapic.utils.filename import to_valid_filename @@ -29,6 +30,7 @@ 'cached_property', 'doc', 'empty', + 'nth', 'partition', 'RESERVED_NAMES', 'rst', diff --git a/gapic/utils/code.py b/gapic/utils/code.py index 27458a999c..15f327983c 100644 --- a/gapic/utils/code.py +++ b/gapic/utils/code.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import (Callable, Iterable, List, Tuple, TypeVar) +from typing import (Callable, Iterable, List, Optional, Tuple, TypeVar) +import itertools def empty(content: str) -> bool: @@ -50,3 +51,15 @@ def partition(predicate: Callable[[T], bool], # Returns trueList, falseList return results[1], results[0] + + +def nth(iterable: Iterable[T], n: int, default: Optional[T] = None) -> Optional[T]: + """Return the nth element of an iterable or a default value. + + Args + iterable (Iterable(T)): An iterable on any type. + n (int): The 'index' of the lement to retrieve. + default (Optional(T)): An optional default elemnt if the iterable has + fewer than n elements. + """ + return next(itertools.islice(iterable, n, None), default) diff --git a/test_utils/test_utils.py b/test_utils/test_utils.py index 697d08e8df..89fd735142 100644 --- a/test_utils/test_utils.py +++ b/test_utils/test_utils.py @@ -200,6 +200,7 @@ def make_field( message: wrappers.MessageType = None, enum: wrappers.EnumType = None, meta: metadata.Metadata = None, + oneof: str = None, **kwargs ) -> wrappers.Field: T = desc.FieldDescriptorProto.Type @@ -223,11 +224,13 @@ def make_field( number=number, **kwargs ) + return wrappers.Field( field_pb=field_pb, enum=enum, message=message, meta=meta or metadata.Metadata(), + oneof=oneof, ) @@ -322,20 +325,28 @@ def make_enum_pb2( def make_message_pb2( name: str, fields: tuple = (), + oneof_decl: tuple = (), **kwargs ) -> desc.DescriptorProto: - return desc.DescriptorProto(name=name, field=fields, **kwargs) + return desc.DescriptorProto(name=name, field=fields, oneof_decl=oneof_decl, **kwargs) def make_field_pb2(name: str, number: int, type: int = 11, # 11 == message type_name: str = None, + oneof_index: int = None ) -> desc.FieldDescriptorProto: return desc.FieldDescriptorProto( name=name, number=number, type=type, type_name=type_name, + oneof_index=oneof_index, + ) + +def make_oneof_pb2(name: str) -> desc.OneofDescriptorProto: + return desc.OneofDescriptorProto( + name=name, ) diff --git a/tests/unit/schema/test_api.py b/tests/unit/schema/test_api.py index 8dc1760cd8..b3f023054c 100644 --- a/tests/unit/schema/test_api.py +++ b/tests/unit/schema/test_api.py @@ -34,6 +34,7 @@ make_file_pb2, make_message_pb2, make_naming, + make_oneof_pb2, ) @@ -239,6 +240,45 @@ def test_proto_keyword_fname(): } +def test_proto_oneof(): + # Put together a couple of minimal protos. + fd = ( + make_file_pb2( + name='dep.proto', + package='google.dep', + messages=(make_message_pb2(name='ImportedMessage', fields=()),), + ), + make_file_pb2( + name='foo.proto', + package='google.example.v1', + messages=( + make_message_pb2(name='Foo', fields=()), + make_message_pb2( + name='Bar', + fields=( + make_field_pb2(name='imported_message', number=1, + type_name='.google.dep.ImportedMessage', + oneof_index=0), + make_field_pb2( + name='primitive', number=2, type=1, oneof_index=0), + ), + oneof_decl=( + make_oneof_pb2(name="value_type"), + ) + ) + ) + ) + ) + + # Create an API with those protos. + api_schema = api.API.build(fd, package='google.example.v1') + proto = api_schema.protos['foo.proto'] + assert proto.names == {'imported_message', 'Bar', 'primitive', 'Foo'} + oneofs = proto.messages["google.example.v1.Bar"].oneofs + assert len(oneofs) == 1 + assert "value_type" in oneofs.keys() + + def test_proto_names_import_collision(): # Put together a couple of minimal protos. fd = ( diff --git a/tests/unit/schema/wrappers/test_field.py b/tests/unit/schema/wrappers/test_field.py index 7104c735be..3cdcaf9bcf 100644 --- a/tests/unit/schema/wrappers/test_field.py +++ b/tests/unit/schema/wrappers/test_field.py @@ -153,6 +153,13 @@ def test_mock_value_int(): assert field.mock_value == '728' +def test_oneof(): + REP = descriptor_pb2.FieldDescriptorProto.Label.Value('LABEL_REPEATED') + + field = make_field(oneof="oneof_name") + assert field.oneof == "oneof_name" + + def test_mock_value_float(): field = make_field(name='foo_bar', type='TYPE_DOUBLE') assert field.mock_value == '0.728' diff --git a/tests/unit/schema/wrappers/test_oneof.py b/tests/unit/schema/wrappers/test_oneof.py new file mode 100644 index 0000000000..90fe2546ce --- /dev/null +++ b/tests/unit/schema/wrappers/test_oneof.py @@ -0,0 +1,35 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +import pytest + +from google.api import field_behavior_pb2 +from google.protobuf import descriptor_pb2 + +from gapic.schema import metadata +from gapic.schema import wrappers + +from test_utils.test_utils import ( + make_oneof_pb2, +) + + +def test_wrapped_oneof(): + oneof_pb = make_oneof_pb2("oneof_name") + wrapped = wrappers.Oneof(oneof_pb=oneof_pb) + + assert wrapped.oneof_pb == oneof_pb + assert wrapped.name == oneof_pb.name diff --git a/tests/unit/utils/test_code.py b/tests/unit/utils/test_code.py index 1069443f7b..5f18679d6f 100644 --- a/tests/unit/utils/test_code.py +++ b/tests/unit/utils/test_code.py @@ -34,3 +34,12 @@ def test_empty_whitespace_comments(): def test_empty_code(): assert not code.empty('import this') + + +def test_nth(): + # list + assert code.nth([i * i for i in range(20)], 4) == 16 + # generator + assert code.nth((i * i for i in range(20)), 4) == 16 + # default + assert code.nth((i * i for i in range(20)), 30, 2112) == 2112