Skip to content

Commit

Permalink
fix: tweak oneof detection (#505)
Browse files Browse the repository at this point in the history
Oneof detection and assignment to fields is tricky.
This patch fixes detection of oneof fields,
fixes uses in generated clients
and tweaks generated tests to use them correctly.
  • Loading branch information
software-dov committed Jul 10, 2020
1 parent ffccce7 commit 1632e25
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 6 deletions.
2 changes: 1 addition & 1 deletion gapic/schema/api.py
Expand Up @@ -615,7 +615,7 @@ def _get_fields(self,
# `_load_message` method.
answer: Dict[str, wrappers.Field] = collections.OrderedDict()
for i, field_pb in enumerate(field_pbs):
is_oneof = oneofs and field_pb.oneof_index > 0
is_oneof = oneofs and field_pb.HasField('oneof_index')
oneof_name = nth(
(oneofs or {}).keys(),
field_pb.oneof_index
Expand Down
22 changes: 22 additions & 0 deletions gapic/schema/wrappers.py
Expand Up @@ -239,6 +239,15 @@ def __hash__(self):
# Identity is sufficiently unambiguous.
return hash(self.ident)

def oneof_fields(self, include_optional=False):
oneof_fields = collections.defaultdict(list)
for field in self.fields.values():
# Only include proto3 optional oneofs if explicitly looked for.
if field.oneof and not field.proto3_optional or include_optional:
oneof_fields[field.oneof].append(field)

return oneof_fields

@utils.cached_property
def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]:
answer = tuple(
Expand Down Expand Up @@ -583,6 +592,15 @@ def client_output(self):
def client_output_async(self):
return self._client_output(enable_asyncio=True)

def flattened_oneof_fields(self, include_optional=False):
oneof_fields = collections.defaultdict(list)
for field in self.flattened_fields.values():
# Only include proto3 optional oneofs if explicitly looked for.
if field.oneof and not field.proto3_optional or include_optional:
oneof_fields[field.oneof].append(field)

return oneof_fields

def _client_output(self, enable_asyncio: bool):
"""Return the output from the client layer.
Expand Down Expand Up @@ -685,6 +703,10 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]:

return answer

@utils.cached_property
def flattened_field_to_key(self):
return {field.name: key for key, field in self.flattened_fields.items()}

@utils.cached_property
def legacy_flattened_fields(self) -> Mapping[str, Field]:
"""Return the legacy flattening interface: top level fields only,
Expand Down
Expand Up @@ -288,9 +288,15 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
call.return_value = iter([{{ method.output.ident }}()])
{% else -%}
call.return_value = {{ method.output.ident }}(
{%- for field in method.output.fields.values() | rejectattr('message')%}{% if not (field.oneof and not field.proto3_optional) %}
{%- for field in method.output.fields.values() | rejectattr('message')%}{% if not field.oneof or field.proto3_optional %}
{{ field.name }}={{ field.mock_value }},
{% endif %}{%- endfor %}
{#- This is a hack to only pick one field #}
{%- for oneof_fields in method.output.oneof_fields().values() %}
{% with field = oneof_fields[0] %}
{{ field.name }}={{ field.mock_value }},
{%- endwith %}
{%- endfor %}
)
{% endif -%}
{% if method.client_streaming %}
Expand Down Expand Up @@ -567,9 +573,15 @@ def test_{{ method.name|snake_case }}_flattened():
# request object values.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
{% for key, field in method.flattened_fields.items() -%}
{% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %}
assert args[0].{{ key }} == {{ field.mock_value }}
{% endfor %}
{% endif %}{% endfor %}
{%- for oneofs in method.flattened_oneof_fields().values() %}
{%- with field = oneofs[-1] %}
assert args[0].{{ method.flattened_field_to_key[field.name] }} == {{ field.mock_value }}
{%- endwith %}
{%- endfor %}



def test_{{ method.name|snake_case }}_flattened_error():
Expand Down Expand Up @@ -640,9 +652,14 @@ async def test_{{ method.name|snake_case }}_flattened_async():
# request object values.
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
{% for key, field in method.flattened_fields.items() -%}
{% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %}
assert args[0].{{ key }} == {{ field.mock_value }}
{% endfor %}
{% endif %}{% endfor %}
{%- for oneofs in method.flattened_oneof_fields().values() %}
{%- with field = oneofs[-1] %}
assert args[0].{{ method.flattened_field_to_key[field.name] }} == {{ field.mock_value }}
{%- endwith %}
{%- endfor %}


@pytest.mark.asyncio
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/schema/wrappers/test_message.py
Expand Up @@ -235,3 +235,26 @@ def test_field_map():
entry_field = make_field('foos', message=entry_msg, repeated=True)
assert entry_msg.map
assert entry_field.map


def test_oneof_fields():
mass_kg = make_field(name="mass_kg", oneof="mass", type=5)
mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5)
length_m = make_field(name="length_m", oneof="length", type=5)
length_f = make_field(name="length_f", oneof="length", type=5)
color = make_field(name="color", type=5)
request = make_message(
name="CreateMolluscReuqest",
fields=(
mass_kg,
mass_lbs,
length_m,
length_f,
color,
),
)
actual_oneofs = request.oneof_fields()
expected_oneofs = {
"mass": [mass_kg, mass_lbs],
"length": [length_m, length_f],
}
56 changes: 56 additions & 0 deletions tests/unit/schema/wrappers/test_method.py
Expand Up @@ -364,3 +364,59 @@ def test_method_legacy_flattened_fields():
])

assert method.legacy_flattened_fields == expected


def test_flattened_oneof_fields():
mass_kg = make_field(name="mass_kg", oneof="mass", type=5)
mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5)

length_m = make_field(name="length_m", oneof="length", type=5)
length_f = make_field(name="length_f", oneof="length", type=5)

color = make_field(name="color", type=5)
mantle = make_field(
name="mantle",
message=make_message(
name="Mantle",
fields=(
make_field(name="color", type=5),
mass_kg,
mass_lbs,
),
),
)
request = make_message(
name="CreateMolluscReuqest",
fields=(
length_m,
length_f,
color,
mantle,
),
)
method = make_method(
name="CreateMollusc",
input_message=request,
signatures=[
"length_m,",
"length_f,",
"mantle.mass_kg,",
"mantle.mass_lbs,",
"color",
]
)

expected = {"mass": [mass_kg, mass_lbs], "length": [length_m, length_f]}
actual = method.flattened_oneof_fields()
assert expected == actual

# Check this method too becasue the setup is a lot of work.
expected = {
"color": "color",
"length_m": "length_m",
"length_f": "length_f",
"mass_kg": "mantle.mass_kg",
"mass_lbs": "mantle.mass_lbs",
}
actual = method.flattened_field_to_key
assert expected == actual

0 comments on commit 1632e25

Please sign in to comment.