Skip to content

Commit

Permalink
fix: snippetgen handling of repeated enum field (#1443)
Browse files Browse the repository at this point in the history
* fix: snippetgen handling of repeated enum field
  • Loading branch information
dizcology committed Nov 2, 2022
1 parent 1eeffcf commit 70d7882
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
13 changes: 10 additions & 3 deletions gapic/samplegen/samplegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ def _normal_request_setup(self, base_param_to_attrs, val, request, field):
elif attr.enum:
# A little bit hacky, but 'values' is a list, and this is the easiest
# way to verify that the value is a valid enum variant.
witness = any(e.name == val for e in attr.enum.values)
# Here val could be a list of a single enum value name.
witness = any(e.name in val for e in attr.enum.values)
if not witness:
raise types.InvalidEnumVariant(
"Invalid variant for enum {}: '{}'".format(attr, val)
Expand Down Expand Up @@ -974,8 +975,14 @@ def generate_request_object(api_schema: api.API, service: wrappers.Service, mess
{"field": field_name, "value": field.mock_value_original_type})
elif field.enum:
# Choose the last enum value in the list since index 0 is often "unspecified"
enum_value = field.enum.values[-1].name
if field.repeated:
field_value = [enum_value]
else:
field_value = enum_value

request.append(
{"field": field_name, "value": field.enum.values[-1].name})
{"field": field_name, "value": field_value})
else:
# This is a message type, recurse
# TODO(busunkim): Some real world APIs have
Expand Down Expand Up @@ -1023,7 +1030,7 @@ def generate_sample_specs(api_schema: api.API, *, opts) -> Generator[Dict[str, A
spec = {
"rpc": rpc_name,
"transport": transport,
# `request` and `response` is populated in `preprocess_sample`
# `request` and `response` are populated in `preprocess_sample`
"service": f"{api_schema.naming.proto_package}.{service_name}",
"region_tag": region_tag,
"description": f"Snippet for {utils.to_snake_case(rpc_name)}"
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/samplegen/test_samplegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def test_preprocess_sample():
]


def test_preprocess_sample_with_enum_field():
@pytest.mark.parametrize(
'repeated_enum,expected', [(False, "TYPE_2"), (True, ["TYPE_2"])])
def test_preprocess_sample_with_enum_field(repeated_enum, expected):
# Verify that the default response is added.
sample = {"service": "Mollusc", "rpc": "Classify"}

Expand All @@ -212,6 +214,7 @@ def test_preprocess_sample_with_enum_field():
"type": DummyField(
name="type",
required=True,
repeated=repeated_enum,
type=enum_factory("type", ["TYPE_1", "TYPE_2"]),
enum=enum_factory("type", ["TYPE_1", "TYPE_2"])
)
Expand Down Expand Up @@ -255,7 +258,7 @@ def test_preprocess_sample_with_enum_field():
assert sample["request"] == [
{
"field": "type",
"value": "TYPE_2"
"value": expected
}
]

Expand Down

0 comments on commit 70d7882

Please sign in to comment.