Skip to content

Commit

Permalink
feat: forward compatible diregapic LRO support (#1085)
Browse files Browse the repository at this point in the history
Detect whether a method fulfills the criteria for DIREGAPIC LRO.
If so, fudge the name of the generated method by adding the suffix
'_primitive'. This change is made for both the synchronous and async
client variants. Any generated unit tests are changed to use and
reference the fudged name.

The names of the corresponding transport method is NOT changed.
  • Loading branch information
software-dov committed Nov 17, 2021
1 parent a03bc22 commit aa7f4d5
Show file tree
Hide file tree
Showing 13 changed files with 1,159 additions and 56 deletions.
2 changes: 1 addition & 1 deletion BUILD.bazel
Expand Up @@ -51,7 +51,7 @@ toolchain(

py_binary(
name = "gapic_plugin",
srcs = glob(["gapic/**/*.py"]),
srcs = glob(["gapic/**/*.py", "google/**/*.py"]),
data = [":pandoc_binary"] + glob([
"gapic/**/*.j2",
"gapic/**/.*.j2",
Expand Down
46 changes: 46 additions & 0 deletions gapic/schema/wrappers.py
Expand Up @@ -41,6 +41,7 @@
from google.api import resource_pb2
from google.api_core import exceptions
from google.api_core import path_template
from google.cloud import extended_operations_pb2 as ex_ops_pb2
from google.protobuf import descriptor_pb2 # type: ignore
from google.protobuf.json_format import MessageToDict # type: ignore

Expand Down Expand Up @@ -344,6 +345,39 @@ def oneof_fields(self, include_optional=False):

return oneof_fields

@utils.cached_property
def is_diregapic_operation(self) -> bool:
if not self.name == "Operation":
return False

name, status, error_code, error_message = False, False, False, False
duplicate_msg = f"Message '{self.name}' has multiple fields with the same operation response mapping: {{}}"
for f in self.field:
maybe_op_mapping = f.options.Extensions[ex_ops_pb2.operation_field]
OperationResponseMapping = ex_ops_pb2.OperationResponseMapping

if maybe_op_mapping == OperationResponseMapping.NAME:
if name:
raise TypeError(duplicate_msg.format("name"))
name = True

if maybe_op_mapping == OperationResponseMapping.STATUS:
if status:
raise TypeError(duplicate_msg.format("status"))
status = True

if maybe_op_mapping == OperationResponseMapping.ERROR_CODE:
if error_code:
raise TypeError(duplicate_msg.format("error_code"))
error_code = True

if maybe_op_mapping == OperationResponseMapping.ERROR_MESSAGE:
if error_message:
raise TypeError(duplicate_msg.format("error_message"))
error_message = True

return name and status and error_code and error_message

@utils.cached_property
def required_fields(self) -> Sequence['Field']:
required_fields = [
Expand Down Expand Up @@ -765,6 +799,10 @@ class Method:
def __getattr__(self, name):
return getattr(self.method_pb, name)

@property
def is_operation_polling_method(self):
return self.output.is_diregapic_operation and self.options.Extensions[ex_ops_pb2.operation_polling_method]

@utils.cached_property
def client_output(self):
return self._client_output(enable_asyncio=False)
Expand Down Expand Up @@ -838,6 +876,10 @@ def _client_output(self, enable_asyncio: bool):
# Return the usual output.
return self.output

@property
def operation_service(self) -> Optional[str]:
return self.options.Extensions[ex_ops_pb2.operation_service]

@property
def is_deprecated(self) -> bool:
"""Returns true if the method is deprecated, false otherwise."""
Expand Down Expand Up @@ -1172,6 +1214,10 @@ class Service:
def __getattr__(self, name):
return getattr(self.service_pb, name)

@property
def custom_polling_method(self) -> Optional[Method]:
return next((m for m in self.methods.values() if m.is_operation_polling_method), None)

@property
def client_name(self) -> str:
"""Returns the name of the generated client class"""
Expand Down
Expand Up @@ -150,7 +150,9 @@ class {{ service.async_client_name }}:
)

{% for method in service.methods.values() %}
{%+ if not method.server_streaming %}async {% endif %}def {{ method.name|snake_case }}(self,
{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}
{%+ if not method.server_streaming %}async {% endif %}def {{ method_name }}(self,
{% endwith %}
{% if not method.client_streaming %}
request: Union[{{ method.input.ident }}, dict] = None,
*,
Expand Down
Expand Up @@ -315,7 +315,11 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):


{% for method in service.methods.values() %}
{% if method.operation_service %}{# DIREGAPIC LRO #}
def {{ method.name|snake_case }}_unary(self,
{% else %}
def {{ method.name|snake_case }}(self,
{% endif %}{# DIREGAPIC LRO #}
{% if not method.client_streaming %}
request: Union[{{ method.input.ident }}, dict] = None,
*,
Expand Down

0 comments on commit aa7f4d5

Please sign in to comment.