From 50d61afd30b021835fe898e41b783f4d04acff09 Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi <25311427+vam-google@users.noreply.github.com> Date: Mon, 25 Oct 2021 16:05:12 -0700 Subject: [PATCH] fix: Fix rest transport logic (#1039) * fix: Fix rest transport logic This includes 1) Do not include asyncio tests in the generated tests, because rest transport does not have asynio client. 2) Generate body field mock values for generated tests (otherwise grpc transcodding logic would fail). 3) Make `always_use_jwt_access=True` default for rest clients (grpc already does that) to match expected calls in generated tests. 4) Fix mypy errors for `AuthorizedSession` by ignoring it 5) Include operations_v1 conditionally, only if the client has lro There are few more fixes left, which are expected to be fixed in separate PRs. 1) `message->to_dict->message` roundrtip problem for int64 types is expected to be fixed by https://github.com/googleapis/proto-plus-python/pull/267 2) builtins conflicts (`license_` vs `license` as field name) is expected to be fixed by a TBD PR * fix integration tests --- gapic/schema/wrappers.py | 13 +++++- .../%sub/services/%service/client.py.j2 | 2 - .../services/%service/transports/rest.py.j2 | 16 ++++--- .../%name_%version/%sub/test_%service.py.j2 | 11 ++++- tests/unit/schema/wrappers/test_method.py | 44 ++++++++++++++++++- 5 files changed, 74 insertions(+), 12 deletions(-) diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index aecda19f44..24422244a9 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -925,9 +925,21 @@ def query_params(self) -> Set[str]: return set(self.input.fields) - params + @property + def body_fields(self) -> Mapping[str, Field]: + bindings = self.http_options + if bindings and bindings[0].body and bindings[0].body != "*": + return self._fields_mapping([bindings[0].body]) + return {} + # TODO(yon-mg): refactor as there may be more than one method signature @utils.cached_property def flattened_fields(self) -> Mapping[str, Field]: + signatures = self.options.Extensions[client_pb2.method_signature] + return self._fields_mapping(signatures) + + # TODO(yon-mg): refactor as there may be more than one method signature + def _fields_mapping(self, signatures) -> Mapping[str, Field]: """Return the signature defined for this method.""" cross_pkg_request = self.input.ident.package != self.ident.package @@ -946,7 +958,6 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: yield name, field - signatures = self.options.Extensions[client_pb2.method_signature] answer: Dict[str, Field] = collections.OrderedDict( name_and_field for sig in signatures diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index 02bfe76135..809f728dd1 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -306,9 +306,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, - {% if "grpc" in opts.transport %} always_use_jwt_access=True, - {% endif %} ) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index 994e30d9e4..d85695b76f 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -1,14 +1,16 @@ -from google.auth.transport.requests import AuthorizedSession +from google.auth.transport.requests import AuthorizedSession # type: ignore import json # type: ignore import grpc # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore from google.api_core import exceptions as core_exceptions # type: ignore -from google.api_core import retry as retries # type: ignore -from google.api_core import rest_helpers # type: ignore -from google.api_core import path_template # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import rest_helpers # type: ignore +from google.api_core import path_template # type: ignore +from google.api_core import gapic_v1 # type: ignore +{% if service.has_lro %} from google.api_core import operations_v1 +{% endif %} from requests import __version__ as requests_version from typing import Callable, Dict, Optional, Sequence, Tuple, Union import warnings diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index d4ec2c3142..53ef176529 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -1106,7 +1106,14 @@ def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type ) # send a request that will satisfy transcoding - request = request_type({{ method.http_options[0].sample_request}}) + request_init = {{ method.http_options[0].sample_request}} + {% for field in method.body_fields.values() %} + {% if not field.oneof or field.proto3_optional %} + {# ignore oneof fields that might conflict with sample_request #} + request_init["{{ field.name }}"] = {{ field.mock_value }} + {% endif %} + {% endfor %} + request = request_type(request_init) {% if method.client_streaming %} requests = [request] {% endif %} @@ -2419,6 +2426,7 @@ async def test_test_iam_permissions_from_dict_async(): {% endif %} +{% if 'grpc' in opts.transport %} @pytest.mark.asyncio async def test_transport_close_async(): client = {{ service.async_client_name }}( @@ -2429,6 +2437,7 @@ async def test_transport_close_async(): async with client: close.assert_not_called() close.assert_called_once() +{% endif %} def test_transport_close(): transports = { diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index c6a81d9128..d377375036 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -330,6 +330,35 @@ def test_method_path_params_no_http_rule(): assert method.path_params == [] +def test_body_fields(): + http_rule = http_pb2.HttpRule( + post='/v1/{arms_shape=arms/*}/squids', + body='mantle' + ) + + mantle_stuff = make_field(name='mantle_stuff', type=9) + message = make_message('Mantle', fields=(mantle_stuff,)) + mantle = make_field('mantle', type=11, type_name='Mantle', message=message) + arms_shape = make_field('arms_shape', type=9) + input_message = make_message('Squid', fields=(mantle, arms_shape)) + method = make_method( + 'PutSquid', input_message=input_message, http_rule=http_rule) + assert set(method.body_fields) == {'mantle'} + mock_value = method.body_fields['mantle'].mock_value + assert mock_value == "baz.Mantle(mantle_stuff='mantle_stuff_value')" + + +def test_body_fields_no_body(): + http_rule = http_pb2.HttpRule( + post='/v1/{arms_shape=arms/*}/squids', + ) + + method = make_method( + 'PutSquid', http_rule=http_rule) + + assert not method.body_fields + + def test_method_http_options(): verbs = [ 'get', @@ -363,7 +392,7 @@ def test_method_http_options_no_http_rule(): assert method.path_params == [] -def test_method_http_options_body(): +def test_method_http_options_body_star(): http_rule = http_pb2.HttpRule( post='/v1/{parent=projects/*}/topics', body='*' @@ -376,6 +405,19 @@ def test_method_http_options_body(): }] +def test_method_http_options_body_field(): + http_rule = http_pb2.HttpRule( + post='/v1/{parent=projects/*}/topics', + body='body_field' + ) + method = make_method('DoSomething', http_rule=http_rule) + assert [dataclasses.asdict(http) for http in method.http_options] == [{ + 'method': 'post', + 'uri': '/v1/{parent=projects/*}/topics', + 'body': 'body_field' + }] + + def test_method_http_options_additional_bindings(): http_rule = http_pb2.HttpRule( post='/v1/{parent=projects/*}/topics',