Skip to content

Commit

Permalink
fix: update paging implementation to handle unconventional pagination (
Browse files Browse the repository at this point in the history
…#750)

* fix: update paging implementation to handle unconventional pagination

* fix: typing errors, mypy cli update

* fix: mypy cli flag

* fix: delete __init__.py, remove -p mypy flag

* fix: clearing up statements, tests, minor bug in filter usage

* fix: wrong generated type hints
  • Loading branch information
yon-mg committed Feb 4, 2021
1 parent 4077b45 commit eaac3e6
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 19 deletions.
17 changes: 13 additions & 4 deletions gapic/schema/wrappers.py
Expand Up @@ -866,13 +866,22 @@ def paged_result_field(self) -> Optional[Field]:
"""Return the response pagination field if the method is paginated."""
# If the request field lacks any of the expected pagination fields,
# then the method is not paginated.
for page_field in ((self.input, int, 'page_size'),
(self.input, str, 'page_token'),

# The request must have page_token and next_page_token as they keep track of pages
for source, source_type, name in ((self.input, str, 'page_token'),
(self.output, str, 'next_page_token')):
field = page_field[0].fields.get(page_field[2], None)
if not field or field.type != page_field[1]:
field = source.fields.get(name, None)
if not field or field.type != source_type:
return None

# The request must have max_results or page_size
page_fields = (self.input.fields.get('max_results', None),
self.input.fields.get('page_size', None))
page_field_size = next(
(field for field in page_fields if field), None)
if not page_field_size or page_field_size.type != int:
return None

# Return the first repeated field.
for field in self.output.fields.values():
if field.repeated:
Expand Down
Expand Up @@ -6,7 +6,7 @@
{# This lives within the loop in order to ensure that this template
is empty if there are no paged methods.
-#}
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional

{% filter sort_lines -%}
{% for method in service.methods.values() | selectattr('paged_result_field') -%}
Expand Down Expand Up @@ -68,14 +68,25 @@ class {{ method.name }}Pager:
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

{% if method.paged_result_field.map %}
def __iter__(self) -> Iterable[Tuple[str, {{ method.paged_result_field.type.fields.get('value').ident }}]]:
for page in self.pages:
yield from page.{{ method.paged_result_field.name}}.items()

def get(self, key: str) -> Optional[{{ method.paged_result_field.type.fields.get('value').ident }}]:
return self._response.items.get(key)
{% else %}
def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}:
for page in self.pages:
yield from page.{{ method.paged_result_field.name }}
{% endif %}

def __repr__(self) -> str:
return '{0}<{1!r}>'.format(self.__class__.__name__, self._response)


{# TODO(yon-mg): remove on rest async transport impl #}
{% if 'grpc' in opts.transport %}
class {{ method.name }}AsyncPager:
"""A pager for iterating through ``{{ method.name|snake_case }}`` requests.

Expand Down Expand Up @@ -138,5 +149,6 @@ class {{ method.name }}AsyncPager:
def __repr__(self) -> str:
return '{0}<{1!r}>'.format(self.__class__.__name__, self._response)

{% endif %}
{% endfor %}
{% endblock %}
Expand Up @@ -184,11 +184,9 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
# TODO(yon-mg): handle nested fields corerctly rather than using only top level fields
# not required for GCE
query_params = {
{% filter sort_lines -%}
{%- for field in method.query_params %}
{%- for field in method.query_params | sort%}
'{{ field|camel_case }}': request.{{ field }},
{%- endfor %}
{% endfilter -%}
}
# TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here
# discards default values
Expand Down
Expand Up @@ -1020,7 +1020,7 @@ def test_{{ method.name|snake_case }}_raw_page_lro():
assert response.raw_page is response
{% endif %} {#- method.paged_result_field #}

{% endfor -%} {#- method in methods #}
{% endfor -%} {#- method in methods for grpc #}

{% for method in service.methods.values() if 'rest' in opts.transport -%}
def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type={{ method.input.ident }}):
Expand Down Expand Up @@ -1162,7 +1162,126 @@ def test_{{ method.name|snake_case }}_rest_flattened_error():
)


{% endfor -%}
{% if method.paged_result_field %}
def test_{{ method.name|snake_case }}_pager():
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
)

# Mock the http request call within the method and fake a response.
with mock.patch.object(Session, 'request') as req:
# Set the response as a series of pages
{% if method.paged_result_field.map%}
response = (
{{ method.output.ident }}(
{{ method.paged_result_field.name }}={
'a':{{ method.paged_result_field.type.fields.get('value').ident }}(),
'b':{{ method.paged_result_field.type.fields.get('value').ident }}(),
'c':{{ method.paged_result_field.type.fields.get('value').ident }}(),
},
next_page_token='abc',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}={},
next_page_token='def',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}={
'g':{{ method.paged_result_field.type.fields.get('value').ident }}(),
},
next_page_token='ghi',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}={
'h':{{ method.paged_result_field.type.fields.get('value').ident }}(),
'i':{{ method.paged_result_field.type.fields.get('value').ident }}(),
},
),
)
{% else %}
response = (
{{ method.output.ident }}(
{{ method.paged_result_field.name }}=[
{{ method.paged_result_field.type.ident }}(),
{{ method.paged_result_field.type.ident }}(),
{{ method.paged_result_field.type.ident }}(),
],
next_page_token='abc',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}=[],
next_page_token='def',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}=[
{{ method.paged_result_field.type.ident }}(),
],
next_page_token='ghi',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}=[
{{ method.paged_result_field.type.ident }}(),
{{ method.paged_result_field.type.ident }}(),
],
),
)
{% endif %}
# Two responses for two calls
response = response + response

# Wrap the values into proper Response objs
response = tuple({{ method.output.ident }}.to_json(x) for x in response)
return_values = tuple(Response() for i in response)
for return_val, response_val in zip(return_values, response):
return_val._content = response_val.encode('UTF-8')
return_val.status_code = 200
req.side_effect = return_values

metadata = ()
{% if method.field_headers -%}
metadata = tuple(metadata) + (
gapic_v1.routing_header.to_grpc_metadata((
{%- for field_header in method.field_headers %}
{%- if not method.client_streaming %}
('{{ field_header }}', ''),
{%- endif %}
{%- endfor %}
)),
)
{% endif -%}
pager = client.{{ method.name|snake_case }}(request={})

assert pager._metadata == metadata

{% if method.paged_result_field.map %}
assert isinstance(pager.get('a'), {{ method.paged_result_field.type.fields.get('value').ident }})
assert pager.get('h') is None
{% endif %}

results = list(pager)
assert len(results) == 6
{% if method.paged_result_field.map %}
assert all(
isinstance(i, tuple)
for i in results)
for result in results:
assert isinstance(result, tuple)
assert tuple(type(t) for t in result) == (str, {{ method.paged_result_field.type.fields.get('value').ident }})

assert pager.get('a') is None
assert isinstance(pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }})
{% else %}
assert all(isinstance(i, {{ method.paged_result_field.type.ident }})
for i in results)
{% endif %}

pages = list(client.{{ method.name|snake_case }}(request={}).pages)
for page_, token in zip(pages, ['abc','def','ghi', '']):
assert page_.raw_page.next_page_token == token


{% endif %} {# paged methods #}
{% endfor -%} {#- method in methods for rest #}
def test_credentials_transport_error():
# It is an error to provide credentials and a transport instance.
transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport(
Expand Down
50 changes: 41 additions & 9 deletions tests/unit/schema/wrappers/test_method.py
Expand Up @@ -66,19 +66,38 @@ def test_method_client_output_empty():

def test_method_client_output_paged():
paged = make_field(name='foos', message=make_message('Foo'), repeated=True)
parent = make_field(name='parent', type=9) # str
page_size = make_field(name='page_size', type=5) # int
page_token = make_field(name='page_token', type=9) # str

input_msg = make_message(name='ListFoosRequest', fields=(
make_field(name='parent', type=9), # str
make_field(name='page_size', type=5), # int
make_field(name='page_token', type=9), # str
parent,
page_size,
page_token,
))
output_msg = make_message(name='ListFoosResponse', fields=(
paged,
make_field(name='next_page_token', type=9), # str
))
method = make_method('ListFoos',
input_message=input_msg,
output_message=output_msg,
)
method = make_method(
'ListFoos',
input_message=input_msg,
output_message=output_msg,
)
assert method.paged_result_field == paged
assert method.client_output.ident.name == 'ListFoosPager'

max_results = make_field(name='max_results', type=5) # int
input_msg = make_message(name='ListFoosRequest', fields=(
parent,
max_results,
page_token,
))
method = make_method(
'ListFoos',
input_message=input_msg,
output_message=output_msg,
)
assert method.paged_result_field == paged
assert method.client_output.ident.name == 'ListFoosPager'

Expand Down Expand Up @@ -123,6 +142,19 @@ def test_method_paged_result_field_no_page_field():
)
assert method.paged_result_field is None

method = make_method(
name='Foo',
input_message=make_message(
name='FooRequest',
fields=(make_field(name='page_token', type=9),) # str
),
output_message=make_message(
name='FooResponse',
fields=(make_field(name='next_page_token', type=9),) # str
)
)
assert method.paged_result_field is None


def test_method_paged_result_ref_types():
input_msg = make_message(
Expand All @@ -139,7 +171,7 @@ def test_method_paged_result_ref_types():
name='ListMolluscsResponse',
fields=(
make_field(name='molluscs', message=mollusc_msg, repeated=True),
make_field(name='next_page_token', type=9)
make_field(name='next_page_token', type=9) # str
),
module='mollusc'
)
Expand Down Expand Up @@ -207,7 +239,7 @@ def test_flattened_ref_types():


def test_method_paged_result_primitive():
paged = make_field(name='squids', type=9, repeated=True)
paged = make_field(name='squids', type=9, repeated=True) # str
input_msg = make_message(
name='ListSquidsRequest',
fields=(
Expand Down

0 comments on commit eaac3e6

Please sign in to comment.