Skip to content

Commit

Permalink
fix: non-string required fields provide correct values (#1108)
Browse files Browse the repository at this point in the history
In generated unit tests checking behavior of required fields in REST transports, fields are given default values in accordance with the type of the field.
  • Loading branch information
software-dov committed Dec 13, 2021
1 parent 6a593f9 commit bc5f729
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 105 deletions.
78 changes: 70 additions & 8 deletions gapic/schema/wrappers.py
Expand Up @@ -32,7 +32,7 @@
import json
import re
from itertools import chain
from typing import (Any, cast, Dict, FrozenSet, Iterable, List, Mapping,
from typing import (Any, cast, Dict, FrozenSet, Iterator, Iterable, List, Mapping,
ClassVar, Optional, Sequence, Set, Tuple, Union)
from google.api import annotations_pb2 # type: ignore
from google.api import client_pb2
Expand Down Expand Up @@ -757,17 +757,79 @@ class HttpRule:
uri: str
body: Optional[str]

@property
def path_fields(self) -> List[Tuple[str, str]]:
def path_fields(self, method: "~.Method") -> List[Tuple[Field, str, str]]:
"""return list of (name, template) tuples extracted from uri."""
return [(match.group("name"), match.group("template"))
input = method.input
return [(input.get_field(*match.group("name").split(".")), match.group("name"), match.group("template"))
for match in path_template._VARIABLE_RE.finditer(self.uri)]

@property
def sample_request(self) -> str:
def sample_request(self, method: "~.Method") -> str:
"""return json dict for sample request matching the uri template."""
sample = utils.sample_from_path_fields(self.path_fields)
return json.dumps(sample)

def sample_from_path_fields(paths: List[Tuple["wrappers.Field", str, str]]) -> Dict[Any, Any]:
"""Construct a dict for a sample request object from a list of fields
and template patterns.
Args:
paths: a list of tuples, each with a (segmented) name and a pattern.
Returns:
A new nested dict with the templates instantiated.
"""

request: Dict[str, Any] = {}

def _sample_names() -> Iterator[str]:
sample_num: int = 0
while True:
sample_num += 1
yield "sample{}".format(sample_num)

def add_field(obj, path, value):
"""Insert a field into a nested dict and return the (outer) dict.
Keys and sub-dicts are inserted if necessary to create the path.
e.g. if obj, as passed in, is {}, path is "a.b.c", and value is
"hello", obj will be updated to:
{'a':
{'b':
{
'c': 'hello'
}
}
}
Args:
obj: a (possibly) nested dict (parsed json)
path: a segmented field name, e.g. "a.b.c"
where each part is a dict key.
value: the value of the new key.
Returns:
obj, possibly modified
Raises:
AttributeError if the path references a key that is
not a dict.: e.g. path='a.b', obj = {'a':'abc'}
"""

segments = path.split('.')
leaf = segments.pop()
subfield = obj
for segment in segments:
subfield = subfield.setdefault(segment, {})
subfield[leaf] = value
return obj

sample_names = _sample_names()
for field, path, template in paths:
sample_value = re.sub(
r"(\*\*|\*)",
lambda n: next(sample_names),
template or '*'
) if field.type == PrimitiveType.build(str) else field.mock_value_original_type
add_field(request, path, sample_value)

return request

sample = sample_from_path_fields(self.path_fields(method))
return sample

@classmethod
def try_parse_http_rule(cls, http_rule) -> Optional['HttpRule']:
Expand Down
Expand Up @@ -170,7 +170,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if method.input.required_fields %}
__{{ method.name | snake_case }}_required_fields_default_values = {
{% for req_field in method.input.required_fields if req_field.is_primitive %}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.default_value is string %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value }}{% endif %},{# default is str #}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %},{# default is str #}
{% endfor %}
}

Expand Down
Expand Up @@ -1134,11 +1134,11 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
)

# send a request that will satisfy transcoding
request_init = {{ method.http_options[0].sample_request}}
request_init = {{ method.http_options[0].sample_request(method) }}
{% for field in method.body_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
request_init["{{ field.name }}"] = {{ field.mock_value }}
request_init["{{ field.name }}"] = {{ field.mock_value_original_type }}
{% endif %}
{% endfor %}
request = request_type(request_init)
Expand Down Expand Up @@ -1221,10 +1221,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide

request_init = {}
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% if req_field.field_pb.default_value is string %}
{% if req_field.field_pb.type == 9 %}
request_init["{{ req_field.name }}"] = "{{ req_field.field_pb.default_value }}"
{% else %}
request_init["{{ req_field.name }}"] = {{ req_field.field_pb.default_value }}
request_init["{{ req_field.name }}"] = {{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}
{% endif %}{# default is str #}
{% endfor %}
request = request_type(request_init)
Expand Down Expand Up @@ -1324,10 +1324,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
{% for req_field in method.input.required_fields if req_field.is_primitive %}
(
"{{ req_field.name | camel_case }}",
{% if req_field.field_pb.default_value is string %}
{% if req_field.field_pb.type == 9 %}
"{{ req_field.field_pb.default_value }}",
{% else %}
{{ req_field.field_pb.default_value }},
{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }},
{% endif %}{# default is str #}
),
{% endfor %}
Expand All @@ -1346,11 +1346,11 @@ def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_typ
)

# send a request that will satisfy transcoding
request_init = {{ method.http_options[0].sample_request}}
request_init = {{ method.http_options[0].sample_request(method) }}
{% for field in method.body_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
request_init["{{ field.name }}"] = {{ field.mock_value }}
request_init["{{ field.name }}"] = {{ field.mock_value_original_type }}
{% endif %}
{% endfor %}
request = request_type(request_init)
Expand Down Expand Up @@ -1411,7 +1411,7 @@ def test_{{ method_name }}_rest_flattened(transport: str = 'rest'):
req.return_value = response_value

# get arguments that satisfy an http rule for this method
sample_request = {{ method.http_options[0].sample_request }}
sample_request = {{ method.http_options[0].sample_request(method) }}

# get truthy value for each flattened field
mock_args = dict(
Expand Down Expand Up @@ -1531,7 +1531,7 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
return_val.status_code = 200
req.side_effect = return_values

sample_request = {{ method.http_options[0].sample_request }}
sample_request = {{ method.http_options[0].sample_request(method) }}
{% for field in method.body_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
Expand Down
2 changes: 0 additions & 2 deletions gapic/utils/__init__.py
Expand Up @@ -29,7 +29,6 @@
from gapic.utils.reserved_names import RESERVED_NAMES
from gapic.utils.rst import rst
from gapic.utils.uri_conv import convert_uri_fieldnames
from gapic.utils.uri_sample import sample_from_path_fields


__all__ = (
Expand All @@ -44,7 +43,6 @@
'partition',
'RESERVED_NAMES',
'rst',
'sample_from_path_fields',
'sort_lines',
'to_snake_case',
'to_camel_case',
Expand Down
78 changes: 0 additions & 78 deletions gapic/utils/uri_sample.py

This file was deleted.

41 changes: 41 additions & 0 deletions tests/fragments/test_required_non_string.proto
@@ -0,0 +1,41 @@
// Copyright (C) 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package google.fragment;

import "google/api/client.proto";
import "google/api/field_behavior.proto";
import "google/api/annotations.proto";

service RestService {
option (google.api.default_host) = "my.example.com";

rpc MyMethod(MethodRequest) returns (MethodResponse) {
option (google.api.http) = {
get: "/restservice/v1/mass_kg/{mass_kg}/length_cm/{length_cm}"
};
}
}


message MethodRequest {
int32 mass_kg = 1 [(google.api.field_behavior) = REQUIRED];
float length_cm = 2 [(google.api.field_behavior) = REQUIRED];
}

message MethodResponse {
string name = 1;
}
51 changes: 45 additions & 6 deletions tests/unit/schema/wrappers/test_method.py
Expand Up @@ -470,19 +470,58 @@ def test_method_http_options_generate_sample():
http_rule = http_pb2.HttpRule(
get='/v1/{resource.id=projects/*/regions/*/id/**}/stuff',
)
method = make_method('DoSomething', http_rule=http_rule)
sample = method.http_options[0].sample_request
assert json.loads(sample) == {'resource': {

method = make_method(
'DoSomething',
make_message(
name="Input",
fields=[
make_field(
name="resource",
number=1,
type=11,
message=make_message(
"Resource",
fields=[
make_field(name="id", type=9),
],
),
),
],
),
http_rule=http_rule,
)
sample = method.http_options[0].sample_request(method)
assert sample == {'resource': {
'id': 'projects/sample1/regions/sample2/id/sample3'}}


def test_method_http_options_generate_sample_implicit_template():
http_rule = http_pb2.HttpRule(
get='/v1/{resource.id}/stuff',
)
method = make_method('DoSomething', http_rule=http_rule)
sample = method.http_options[0].sample_request
assert json.loads(sample) == {'resource': {
method = make_method(
'DoSomething',
make_message(
name="Input",
fields=[
make_field(
name="resource",
number=1,
message=make_message(
"Resource",
fields=[
make_field(name="id", type=9),
],
),
),
],
),
http_rule=http_rule,
)

sample = method.http_options[0].sample_request(method)
assert sample == {'resource': {
'id': 'sample1'}}


Expand Down

0 comments on commit bc5f729

Please sign in to comment.