Skip to content

Commit

Permalink
fix: Fix rest transport logic (#1039)
Browse files Browse the repository at this point in the history
* fix: Fix rest transport logic

This includes
1) Do not include asyncio tests in the generated tests, because rest transport does not have asynio client.
2) Generate body field mock values for generated tests (otherwise grpc transcodding logic would fail).
3) Make `always_use_jwt_access=True` default for rest clients (grpc already does that) to match expected calls in generated tests.
4) Fix mypy errors for `AuthorizedSession` by ignoring it
5) Include operations_v1 conditionally, only if the client has lro

There are few more fixes left, which are expected to be fixed in separate PRs.

1) `message->to_dict->message` roundrtip problem for int64 types is expected to be fixed by googleapis/proto-plus-python#267
2) builtins conflicts (`license_` vs `license` as field name) is expected to be fixed by a TBD PR

* fix integration tests
  • Loading branch information
vam-google committed Oct 25, 2021
1 parent a0e25c8 commit 50d61af
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 12 deletions.
13 changes: 12 additions & 1 deletion gapic/schema/wrappers.py
Expand Up @@ -925,9 +925,21 @@ def query_params(self) -> Set[str]:

return set(self.input.fields) - params

@property
def body_fields(self) -> Mapping[str, Field]:
bindings = self.http_options
if bindings and bindings[0].body and bindings[0].body != "*":
return self._fields_mapping([bindings[0].body])
return {}

# TODO(yon-mg): refactor as there may be more than one method signature
@utils.cached_property
def flattened_fields(self) -> Mapping[str, Field]:
signatures = self.options.Extensions[client_pb2.method_signature]
return self._fields_mapping(signatures)

# TODO(yon-mg): refactor as there may be more than one method signature
def _fields_mapping(self, signatures) -> Mapping[str, Field]:
"""Return the signature defined for this method."""
cross_pkg_request = self.input.ident.package != self.ident.package

Expand All @@ -946,7 +958,6 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]:

yield name, field

signatures = self.options.Extensions[client_pb2.method_signature]
answer: Dict[str, Field] = collections.OrderedDict(
name_and_field
for sig in signatures
Expand Down
Expand Up @@ -306,9 +306,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
{% if "grpc" in opts.transport %}
always_use_jwt_access=True,
{% endif %}
)


Expand Down
@@ -1,14 +1,16 @@
from google.auth.transport.requests import AuthorizedSession
from google.auth.transport.requests import AuthorizedSession # type: ignore
import json # type: ignore
import grpc # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.api_core import exceptions as core_exceptions # type: ignore
from google.api_core import retry as retries # type: ignore
from google.api_core import rest_helpers # type: ignore
from google.api_core import path_template # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.api_core import rest_helpers # type: ignore
from google.api_core import path_template # type: ignore
from google.api_core import gapic_v1 # type: ignore
{% if service.has_lro %}
from google.api_core import operations_v1
{% endif %}
from requests import __version__ as requests_version
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import warnings
Expand Down
Expand Up @@ -1106,7 +1106,14 @@ def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type
)

# send a request that will satisfy transcoding
request = request_type({{ method.http_options[0].sample_request}})
request_init = {{ method.http_options[0].sample_request}}
{% for field in method.body_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
request_init["{{ field.name }}"] = {{ field.mock_value }}
{% endif %}
{% endfor %}
request = request_type(request_init)
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -2419,6 +2426,7 @@ async def test_test_iam_permissions_from_dict_async():

{% endif %}

{% if 'grpc' in opts.transport %}
@pytest.mark.asyncio
async def test_transport_close_async():
client = {{ service.async_client_name }}(
Expand All @@ -2429,6 +2437,7 @@ async def test_transport_close_async():
async with client:
close.assert_not_called()
close.assert_called_once()
{% endif %}

def test_transport_close():
transports = {
Expand Down
44 changes: 43 additions & 1 deletion tests/unit/schema/wrappers/test_method.py
Expand Up @@ -330,6 +330,35 @@ def test_method_path_params_no_http_rule():
assert method.path_params == []


def test_body_fields():
http_rule = http_pb2.HttpRule(
post='/v1/{arms_shape=arms/*}/squids',
body='mantle'
)

mantle_stuff = make_field(name='mantle_stuff', type=9)
message = make_message('Mantle', fields=(mantle_stuff,))
mantle = make_field('mantle', type=11, type_name='Mantle', message=message)
arms_shape = make_field('arms_shape', type=9)
input_message = make_message('Squid', fields=(mantle, arms_shape))
method = make_method(
'PutSquid', input_message=input_message, http_rule=http_rule)
assert set(method.body_fields) == {'mantle'}
mock_value = method.body_fields['mantle'].mock_value
assert mock_value == "baz.Mantle(mantle_stuff='mantle_stuff_value')"


def test_body_fields_no_body():
http_rule = http_pb2.HttpRule(
post='/v1/{arms_shape=arms/*}/squids',
)

method = make_method(
'PutSquid', http_rule=http_rule)

assert not method.body_fields


def test_method_http_options():
verbs = [
'get',
Expand Down Expand Up @@ -363,7 +392,7 @@ def test_method_http_options_no_http_rule():
assert method.path_params == []


def test_method_http_options_body():
def test_method_http_options_body_star():
http_rule = http_pb2.HttpRule(
post='/v1/{parent=projects/*}/topics',
body='*'
Expand All @@ -376,6 +405,19 @@ def test_method_http_options_body():
}]


def test_method_http_options_body_field():
http_rule = http_pb2.HttpRule(
post='/v1/{parent=projects/*}/topics',
body='body_field'
)
method = make_method('DoSomething', http_rule=http_rule)
assert [dataclasses.asdict(http) for http in method.http_options] == [{
'method': 'post',
'uri': '/v1/{parent=projects/*}/topics',
'body': 'body_field'
}]


def test_method_http_options_additional_bindings():
http_rule = http_pb2.HttpRule(
post='/v1/{parent=projects/*}/topics',
Expand Down

0 comments on commit 50d61af

Please sign in to comment.