Skip to content

Commit

Permalink
fix: numerous small fixes to allow bigtable-admin (#660)
Browse files Browse the repository at this point in the history
Includes:
* tweaked logic around defining recursive message types
* more sophisticated logic for generating unit tests using recursive
message types
* flattened map-y fields are handled properly
* fixed a corner case where a method has a third-party request object
and flattened fields
  • Loading branch information
software-dov committed Oct 19, 2020
1 parent d2bc4ae commit 09692c4
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 34 deletions.
Expand Up @@ -333,7 +333,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
request = {{ method.input.ident }}(**request)
{% if method.flattened_fields -%}{# Cross-package req and flattened fields #}
elif not request:
request = {{ method.input.ident }}()
request = {{ method.input.ident }}({% if method.input.ident.package != method.ident.package %}{% for f in method.flattened_fields.values() %}{{ f.name }}={{ f.name }}, {% endfor %}{% endif %})
{% endif -%}{# Cross-package req and flattened fields #}
{%- else %}
# Minor optimization to avoid making a copy if the user passes
Expand All @@ -344,16 +344,22 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
request = {{ method.input.ident }}(request)
{% endif %} {# different request package #}
{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
{% if method.flattened_fields and method.input.ident.package == method.ident.package -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
{# Map-y fields can be _updated_, however #}
{%- for key, field in method.flattened_fields.items() if field.map and method.input.ident.package == method.ident.package %}

if {{ field.name }}:
request.{{ key }}.update({{ field.name }})
{%- endfor %}
{# And list-y fields can be _extended_ -#}
{%- for key, field in method.flattened_fields.items() if field.repeated and not field.map and method.input.ident.package == method.ident.package %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
Expand Down
Expand Up @@ -297,6 +297,10 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc', request_type={{ m
for message in response:
assert isinstance(message, {{ method.output.ident }})
{% else -%}
{% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %}
{# Cheeser assertion to force code coverage for bad paginated methods #}
assert response.raw_page is response
{% endif %}
assert isinstance(response, {{ method.client_output.ident }})
{% for field in method.output.fields.values() | rejectattr('message') -%}{% if not field.oneof or field.proto3_optional %}
{% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#}
Expand Down
15 changes: 9 additions & 6 deletions gapic/schema/metadata.py
Expand Up @@ -242,12 +242,15 @@ def rel(self, address: 'Address') -> str:
# It is possible that a field references a message that has
# not yet been declared. If so, send its name enclosed in quotes
# (a string) instead.
if self.module_path > address.module_path or self == address:
return f"'{'.'.join(self.parent + (self.name,))}'"

# This is a message in the same module, already declared.
# Send its name.
return '.'.join(self.parent + (self.name,))
#
# Note: this is a conservative construction; it generates a stringy
# identifier all the time when it may be possible to use a regular
# module lookup.
# On the other hand, there's no reason _not_ to use a stringy
# identifier. It is guaranteed to work all the time because
# it bumps name resolution until a time when all types in a module
# are guaranteed to be fully defined.
return f"'{'.'.join(self.parent + (self.name,))}'"

# Return the usual `module.Name`.
return str(self)
Expand Down
13 changes: 5 additions & 8 deletions gapic/schema/wrappers.py
Expand Up @@ -55,9 +55,6 @@ class Field:
)
oneof: Optional[str] = None

# Arbitrary cap set via heuristic rule of thumb.
MAX_MOCK_DEPTH: int = 20

def __getattr__(self, name):
return getattr(self.field_pb, name)

Expand Down Expand Up @@ -93,17 +90,16 @@ def map(self) -> bool:

@utils.cached_property
def mock_value(self) -> str:
depth = 0
visited_fields: Set["Field"] = set()
stack = [self]
answer = "{}"
while stack:
expr = stack.pop()
answer = answer.format(expr.inner_mock(stack, depth))
depth += 1
answer = answer.format(expr.inner_mock(stack, visited_fields))

return answer

def inner_mock(self, stack, depth):
def inner_mock(self, stack, visited_fields):
"""Return a repr of a valid, usually truthy mock value."""
# For primitives, send a truthy value computed from the
# field name.
Expand Down Expand Up @@ -137,10 +133,11 @@ def inner_mock(self, stack, depth):
and isinstance(self.type, MessageType)
and len(self.type.fields)
# Nested message types need to terminate eventually
and depth < self.MAX_MOCK_DEPTH
and self not in visited_fields
):
sub = next(iter(self.type.fields.values()))
stack.append(sub)
visited_fields.add(self)
# Don't do the recursive rendering here, just set up
# where the nested value should go with the double {}.
answer = f'{self.type.ident}({sub.name}={{}})'
Expand Down
Expand Up @@ -169,7 +169,8 @@ class {{ service.async_client_name }}:
{% if method.flattened_fields -%}
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]):
has_flattened_params = any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}])
if request is not None and has_flattened_params:
raise ValueError('If the `request` argument is set, then none of '
'the individual field arguments should be set.')

Expand All @@ -181,23 +182,29 @@ class {{ service.async_client_name }}:
request = {{ method.input.ident }}(**request)
{% if method.flattened_fields -%}{# Cross-package req and flattened fields #}
elif not request:
request = {{ method.input.ident }}()
request = {{ method.input.ident }}({% if method.input.ident.package != method.ident.package %}{% for f in method.flattened_fields.values() %}{{ f.name }}={{ f.name }}, {% endfor %}{% endif %})
{% endif -%}{# Cross-package req and flattened fields #}
{%- else %}
request = {{ method.input.ident }}(request)
{% endif %} {# different request package #}

{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
{% if method.flattened_fields and method.input.ident.package == method.ident.package -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
{%- for key, field in method.flattened_fields.items() if not field.repeated and method.input.ident.package == method.ident.package %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
{# Map-y fields can be _updated_, however #}
{%- for key, field in method.flattened_fields.items() if field.map and method.input.ident.package == method.ident.package %}

if {{ field.name }}:
request.{{ key }}.update({{ field.name }})
{%- endfor %}
{# And list-y fields can be _extended_ -#}
{%- for key, field in method.flattened_fields.items() if field.repeated and not field.map and method.input.ident.package == method.ident.package %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
Expand Down
Expand Up @@ -345,7 +345,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
request = {{ method.input.ident }}(**request)
{% if method.flattened_fields -%}{# Cross-package req and flattened fields #}
elif not request:
request = {{ method.input.ident }}()
request = {{ method.input.ident }}({% if method.input.ident.package != method.ident.package %}{% for f in method.flattened_fields.values() %}{{ f.name }}={{ f.name }}, {% endfor %}{% endif %})
{% endif -%}{# Cross-package req and flattened fields #}
{%- else %}
# Minor optimization to avoid making a copy if the user passes
Expand All @@ -357,16 +357,22 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
{% endif %} {# different request package #}

{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
{% if method.flattened_fields and method.input.ident.package == method.ident.package -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated or method.input.ident.package != method.ident.package) %}
{%- for key, field in method.flattened_fields.items() if not field.repeated and method.input.ident.package == method.ident.package %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if field.repeated %}
{# Map-y fields can be _updated_, however #}
{%- for key, field in method.flattened_fields.items() if field.map and method.input.ident.package == method.ident.package %}

if {{ field.name }}:
request.{{ key }}.update({{ field.name }})
{%- endfor %}
{# And list-y fields can be _extended_ -#}
{%- for key, field in method.flattened_fields.items() if field.repeated and not field.map and method.input.ident.package == method.ident.package %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
Expand Down
Expand Up @@ -398,6 +398,10 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc', request_type={{ m
for message in response:
assert isinstance(message, {{ method.output.ident }})
{% else -%}
{% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %}
{# Cheeser assertion to force code coverage for bad paginated methods #}
assert response.raw_page is response
{% endif %}
assert isinstance(response, {{ method.client_output.ident }})
{% for field in method.output.fields.values() | rejectattr('message') -%}{% if not field.oneof or field.proto3_optional %}
{% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#}
Expand All @@ -417,15 +421,15 @@ def test_{{ method.name|snake_case }}_from_dict():


@pytest.mark.asyncio
async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio'):
async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio', request_type={{ method.input.ident }}):
client = {{ service.async_client_name }}(
credentials=credentials.AnonymousCredentials(),
transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = {{ method.input.ident }}()
request = request_type()
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -474,7 +478,7 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio
{% if method.client_streaming %}
assert next(args[0]) == request
{% else %}
assert args[0] == request
assert args[0] == {{ method.input.ident }}()
{% endif %}

# Establish that the response is the type that we expect.
Expand All @@ -500,6 +504,11 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio
{% endif %}


@pytest.mark.asyncio
async def test_{{ method.name|snake_case }}_async_from_dict():
await test_{{ method.name|snake_case }}_async(request_type=dict)


{% if method.field_headers and not method.client_streaming %}
def test_{{ method.name|snake_case }}_field_headers():
client = {{ service.client_name }}(
Expand Down Expand Up @@ -592,7 +601,7 @@ async def test_{{ method.name|snake_case }}_field_headers_async():
{% endif %}

{% if method.ident.package != method.input.ident.package %}
def test_{{ method.name|snake_case }}_from_dict():
def test_{{ method.name|snake_case }}_from_dict_foreign():
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/schema/test_metadata.py
Expand Up @@ -70,7 +70,7 @@ def test_address_rel():
addr = metadata.Address(package=('foo', 'bar'), module='baz', name='Bacon')
assert addr.rel(
metadata.Address(package=('foo', 'bar'), module='baz'),
) == 'Bacon'
) == "'Bacon'"


def test_address_rel_other():
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/schema/wrappers/test_field.py
Expand Up @@ -19,6 +19,7 @@
from google.api import field_behavior_pb2
from google.protobuf import descriptor_pb2

from gapic.schema import api
from gapic.schema import metadata
from gapic.schema import wrappers

Expand Down Expand Up @@ -250,6 +251,35 @@ def test_mock_value_message():
assert field.mock_value == 'bogus.Message(foo=324)'


def test_mock_value_recursive():
# The elaborate setup is an unfortunate requirement.
file_pb = descriptor_pb2.FileDescriptorProto(
name="turtle.proto",
package="animalia.chordata.v2",
message_type=(
descriptor_pb2.DescriptorProto(
# It's turtles all the way down ;)
name="Turtle",
field=(
descriptor_pb2.FieldDescriptorProto(
name="turtle",
type="TYPE_MESSAGE",
type_name=".animalia.chordata.v2.Turtle",
number=1,
),
),
),
),
)
my_api = api.API.build([file_pb], package="animalia.chordata.v2")
turtle_field = my_api.messages["animalia.chordata.v2.Turtle"].fields["turtle"]

# If not handled properly, this will run forever and eventually OOM.
actual = turtle_field.mock_value
expected = "ac_turtle.Turtle(turtle=ac_turtle.Turtle(turtle=turtle.Turtle(turtle=None)))"
assert actual == expected


def test_field_name_kword_disambiguation():
from_field = make_field(
name="from",
Expand Down

0 comments on commit 09692c4

Please sign in to comment.