Skip to content

Commit

Permalink
fix: add oneof fields to generated protoplus init (#485)
Browse files Browse the repository at this point in the history
Fixes: #484
  • Loading branch information
crwilcox committed Jul 7, 2020
1 parent 9076362 commit be5a847
Show file tree
Hide file tree
Showing 13 changed files with 193 additions and 17 deletions.
Expand Up @@ -43,6 +43,7 @@ class {{ message.name }}({{ p }}.Message):
{% else -%}
{{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
{{- p }}.{{ field.proto_type }}, number={{ field.number }}
{% if field.oneof %}, oneof='{{ field.oneof }}'{% endif %}
{%- if field.enum or field.message %},
{{ field.proto_type.lower() }}={{ field.type.ident.rel(message.ident) }},
{% endif %})
Expand Down
51 changes: 47 additions & 4 deletions gapic/schema/api.py
Expand Up @@ -34,6 +34,7 @@
from gapic.schema import wrappers
from gapic.schema import naming as api_naming
from gapic.utils import cached_property
from gapic.utils import nth
from gapic.utils import to_snake_case
from gapic.utils import RESERVED_NAMES

Expand Down Expand Up @@ -556,14 +557,42 @@ def _load_children(self,
answer[wrapped.name] = wrapped
return answer

def _get_oneofs(self,
oneof_pbs: Sequence[descriptor_pb2.OneofDescriptorProto],
address: metadata.Address, path: Tuple[int, ...],
) -> Dict[str, wrappers.Oneof]:
"""Return a dictionary of wrapped oneofs for the given message.
Args:
oneof_fields (Sequence[~.descriptor_pb2.OneofDescriptorProto]): A
sequence of protobuf field objects.
address (~.metadata.Address): An address object denoting the
location of these oneofs.
path (Tuple[int]): The source location path thus far, as
understood by ``SourceCodeInfo.Location``.
Returns:
Mapping[str, ~.wrappers.Oneof]: A ordered mapping of
:class:`~.wrappers.Oneof` objects.
"""
# Iterate over the oneofs and collect them into a dictionary.
answer = collections.OrderedDict(
(oneof_pb.name, wrappers.Oneof(oneof_pb=oneof_pb))
for i, oneof_pb in enumerate(oneof_pbs)
)

# Done; return the answer.
return answer

def _get_fields(self,
field_pbs: Sequence[descriptor_pb2.FieldDescriptorProto],
address: metadata.Address, path: Tuple[int, ...],
oneofs: Optional[Dict[str, wrappers.Oneof]] = None
) -> Dict[str, wrappers.Field]:
"""Return a dictionary of wrapped fields for the given message.
Args:
fields (Sequence[~.descriptor_pb2.FieldDescriptorProto]): A
field_pbs (Sequence[~.descriptor_pb2.FieldDescriptorProto]): A
sequence of protobuf field objects.
address (~.metadata.Address): An address object denoting the
location of these fields.
Expand All @@ -585,7 +614,13 @@ def _get_fields(self,
# first) and this will be None. This case is addressed in the
# `_load_message` method.
answer: Dict[str, wrappers.Field] = collections.OrderedDict()
for field_pb, i in zip(field_pbs, range(0, sys.maxsize)):
for i, field_pb in enumerate(field_pbs):
is_oneof = oneofs and field_pb.oneof_index > 0
oneof_name = nth(
(oneofs or {}).keys(),
field_pb.oneof_index
) if is_oneof else None

answer[field_pb.name] = wrappers.Field(
field_pb=field_pb,
enum=self.api_enums.get(field_pb.type_name.lstrip('.')),
Expand All @@ -594,6 +629,7 @@ def _get_fields(self,
address=address.child(field_pb.name, path + (i,)),
documentation=self.docs.get(path + (i,), self.EMPTY),
),
oneof=oneof_name,
)

# Done; return the answer.
Expand Down Expand Up @@ -779,19 +815,25 @@ def _load_message(self,
loader=self._load_message,
path=path + (3,),
)
# self._load_children(message.oneof_decl, loader=self._load_field,
# address=nested_addr, info=info.get(8, {}))

oneofs = self._get_oneofs(
message_pb.oneof_decl,
address=address,
path=path + (7,),
)

# Create a dictionary of all the fields for this message.
fields = self._get_fields(
message_pb.field,
address=address,
path=path + (2,),
oneofs=oneofs,
)
fields.update(self._get_fields(
message_pb.extension,
address=address,
path=path + (6,),
oneofs=oneofs,
))

# Create a message correspoding to this descriptor.
Expand All @@ -804,6 +846,7 @@ def _load_message(self,
address=address,
documentation=self.docs.get(path, self.EMPTY),
),
oneofs=oneofs,
)
return self.proto_messages[address.proto]

Expand Down
11 changes: 11 additions & 0 deletions gapic/schema/wrappers.py
Expand Up @@ -54,6 +54,7 @@ class Field:
meta: metadata.Metadata = dataclasses.field(
default_factory=metadata.Metadata,
)
oneof: Optional[str] = None

def __getattr__(self, name):
return getattr(self.field_pb, name)
Expand Down Expand Up @@ -206,6 +207,15 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Field':
)


@dataclasses.dataclass(frozen=True)
class Oneof:
"""Description of a field."""
oneof_pb: descriptor_pb2.OneofDescriptorProto

def __getattr__(self, name):
return getattr(self.oneof_pb, name)


@dataclasses.dataclass(frozen=True)
class MessageType:
"""Description of a message (defined with the ``message`` keyword)."""
Expand All @@ -220,6 +230,7 @@ class MessageType:
meta: metadata.Metadata = dataclasses.field(
default_factory=metadata.Metadata,
)
oneofs: Optional[Mapping[str, 'Oneof']] = None

def __getattr__(self, name):
return getattr(self.message_pb, name)
Expand Down
Expand Up @@ -38,14 +38,15 @@ class {{ message.name }}({{ p }}.Message):
{{- p }}.{{ key_field.proto_type }}, {{ p }}.{{ value_field.proto_type }}, number={{ field.number }}
{%- if value_field.enum or value_field.message %},
{{ value_field.proto_type.lower() }}={{ value_field.type.ident.rel(message.ident) }},
{% endif %})
{% endif %}) {# enum or message#}
{% endwith -%}
{% else -%}
{% else -%} {# field.map #}
{{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
{{- p }}.{{ field.proto_type }}, number={{ field.number }}
{% if field.oneof %}, oneof='{{ field.oneof }}'{% endif %}
{%- if field.enum or field.message %},
{{ field.proto_type.lower() }}={{ field.type.ident.rel(message.ident) }},
{% endif %})
{% endif -%}
{% endfor -%}
{% endif %}) {# enum or message #}
{% endif -%} {# field.map #}
{% endfor -%} {# for field in message.fields.values#}
{{ '\n\n' }}
2 changes: 1 addition & 1 deletion gapic/templates/noxfile.py.j2
Expand Up @@ -20,7 +20,7 @@ def unit(session):
'--cov-config=.coveragerc',
'--cov-report=term',
'--cov-report=html',
os.path.join('tests', 'unit', '{{ api.naming.versioned_module_name }}'),
os.path.join('tests', 'unit',)
)


Expand Down
Expand Up @@ -288,9 +288,9 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
call.return_value = iter([{{ method.output.ident }}()])
{% else -%}
call.return_value = {{ method.output.ident }}(
{%- for field in method.output.fields.values() | rejectattr('message') %}
{%- for field in method.output.fields.values() | rejectattr('message')%}{% if not (field.oneof and not field.proto3_optional) %}
{{ field.name }}={{ field.mock_value }},
{%- endfor %}
{% endif %}{%- endfor %}
)
{% endif -%}
{% if method.client_streaming %}
Expand Down Expand Up @@ -318,14 +318,15 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
assert isinstance(message, {{ method.output.ident }})
{% else -%}
assert isinstance(response, {{ method.client_output.ident }})
{% for field in method.output.fields.values() | rejectattr('message') -%}
{% for field in method.output.fields.values() | rejectattr('message') -%}{% if not (field.oneof and not field.proto3_optional) %}
{% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#}
assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6)
{% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #}
assert response.{{ field.name }} is {{ field.mock_value }}
{% else -%}
assert response.{{ field.name }} == {{ field.mock_value }}
{% endif -%}
{% endif -%} {# end oneof/optional #}
{% endfor %}
{% endif %}

Expand Down Expand Up @@ -368,8 +369,9 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio
{%- else -%}
grpc_helpers_async.FakeStreamUnaryCall
{%- endif -%}({{ method.output.ident }}(
{%- for field in method.output.fields.values() | rejectattr('message') %}
{%- for field in method.output.fields.values() | rejectattr('message') %}{% if not (field.oneof and not field.proto3_optional) %}
{{ field.name }}={{ field.mock_value }},
{%- endif %}
{%- endfor %}
))
{% endif -%}
Expand Down Expand Up @@ -400,14 +402,15 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio
assert isinstance(message, {{ method.output.ident }})
{% else -%}
assert isinstance(response, {{ method.client_output_async.ident }})
{% for field in method.output.fields.values() | rejectattr('message') -%}
{% for field in method.output.fields.values() | rejectattr('message') -%}{% if not (field.oneof and not field.proto3_optional) %}
{% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#}
assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6)
{% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #}
assert response.{{ field.name }} is {{ field.mock_value }}
{% else -%}
assert response.{{ field.name }} == {{ field.mock_value }}
{% endif -%}
{% endif -%} {# oneof/optional #}
{% endfor %}
{% endif %}

Expand Down
2 changes: 2 additions & 0 deletions gapic/utils/__init__.py
Expand Up @@ -15,6 +15,7 @@
from gapic.utils.cache import cached_property
from gapic.utils.case import to_snake_case
from gapic.utils.code import empty
from gapic.utils.code import nth
from gapic.utils.code import partition
from gapic.utils.doc import doc
from gapic.utils.filename import to_valid_filename
Expand All @@ -29,6 +30,7 @@
'cached_property',
'doc',
'empty',
'nth',
'partition',
'RESERVED_NAMES',
'rst',
Expand Down
15 changes: 14 additions & 1 deletion gapic/utils/code.py
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import (Callable, Iterable, List, Tuple, TypeVar)
from typing import (Callable, Iterable, List, Optional, Tuple, TypeVar)
import itertools


def empty(content: str) -> bool:
Expand Down Expand Up @@ -50,3 +51,15 @@ def partition(predicate: Callable[[T], bool],

# Returns trueList, falseList
return results[1], results[0]


def nth(iterable: Iterable[T], n: int, default: Optional[T] = None) -> Optional[T]:
"""Return the nth element of an iterable or a default value.
Args
iterable (Iterable(T)): An iterable on any type.
n (int): The 'index' of the lement to retrieve.
default (Optional(T)): An optional default elemnt if the iterable has
fewer than n elements.
"""
return next(itertools.islice(iterable, n, None), default)
13 changes: 12 additions & 1 deletion test_utils/test_utils.py
Expand Up @@ -200,6 +200,7 @@ def make_field(
message: wrappers.MessageType = None,
enum: wrappers.EnumType = None,
meta: metadata.Metadata = None,
oneof: str = None,
**kwargs
) -> wrappers.Field:
T = desc.FieldDescriptorProto.Type
Expand All @@ -223,11 +224,13 @@ def make_field(
number=number,
**kwargs
)

return wrappers.Field(
field_pb=field_pb,
enum=enum,
message=message,
meta=meta or metadata.Metadata(),
oneof=oneof,
)


Expand Down Expand Up @@ -322,20 +325,28 @@ def make_enum_pb2(
def make_message_pb2(
name: str,
fields: tuple = (),
oneof_decl: tuple = (),
**kwargs
) -> desc.DescriptorProto:
return desc.DescriptorProto(name=name, field=fields, **kwargs)
return desc.DescriptorProto(name=name, field=fields, oneof_decl=oneof_decl, **kwargs)


def make_field_pb2(name: str, number: int,
type: int = 11, # 11 == message
type_name: str = None,
oneof_index: int = None
) -> desc.FieldDescriptorProto:
return desc.FieldDescriptorProto(
name=name,
number=number,
type=type,
type_name=type_name,
oneof_index=oneof_index,
)

def make_oneof_pb2(name: str) -> desc.OneofDescriptorProto:
return desc.OneofDescriptorProto(
name=name,
)


Expand Down
40 changes: 40 additions & 0 deletions tests/unit/schema/test_api.py
Expand Up @@ -34,6 +34,7 @@
make_file_pb2,
make_message_pb2,
make_naming,
make_oneof_pb2,
)


Expand Down Expand Up @@ -239,6 +240,45 @@ def test_proto_keyword_fname():
}


def test_proto_oneof():
# Put together a couple of minimal protos.
fd = (
make_file_pb2(
name='dep.proto',
package='google.dep',
messages=(make_message_pb2(name='ImportedMessage', fields=()),),
),
make_file_pb2(
name='foo.proto',
package='google.example.v1',
messages=(
make_message_pb2(name='Foo', fields=()),
make_message_pb2(
name='Bar',
fields=(
make_field_pb2(name='imported_message', number=1,
type_name='.google.dep.ImportedMessage',
oneof_index=0),
make_field_pb2(
name='primitive', number=2, type=1, oneof_index=0),
),
oneof_decl=(
make_oneof_pb2(name="value_type"),
)
)
)
)
)

# Create an API with those protos.
api_schema = api.API.build(fd, package='google.example.v1')
proto = api_schema.protos['foo.proto']
assert proto.names == {'imported_message', 'Bar', 'primitive', 'Foo'}
oneofs = proto.messages["google.example.v1.Bar"].oneofs
assert len(oneofs) == 1
assert "value_type" in oneofs.keys()


def test_proto_names_import_collision():
# Put together a couple of minimal protos.
fd = (
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/schema/wrappers/test_field.py
Expand Up @@ -153,6 +153,13 @@ def test_mock_value_int():
assert field.mock_value == '728'


def test_oneof():
REP = descriptor_pb2.FieldDescriptorProto.Label.Value('LABEL_REPEATED')

field = make_field(oneof="oneof_name")
assert field.oneof == "oneof_name"


def test_mock_value_float():
field = make_field(name='foo_bar', type='TYPE_DOUBLE')
assert field.mock_value == '0.728'
Expand Down

0 comments on commit be5a847

Please sign in to comment.