Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add typing to proto.Message based class attributes #1474

Merged
merged 10 commits into from
Nov 9, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import OrderedDict
import os
import re
from typing import Callable, Dict, Mapping, Optional, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union, cast
from typing import Callable, Dict, Mapping, MutableMapping, MutableSequence, Optional, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union, cast
import pkg_resources
{% if service.any_deprecated %}
import warnings
Expand Down Expand Up @@ -68,7 +68,7 @@ class {{ service.client_name }}Meta(type):
{% endif %}

def get_transport_class(cls,
label: str = None,
label: Optional[str] = None,
) -> Type[{{ service.name }}Transport]:
"""Returns an appropriate transport class.

Expand Down Expand Up @@ -340,17 +340,17 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
def {{ method.name|snake_case }}(self,
{% endif %}{# Extended Operations LRO #}
{% if not method.client_streaming %}
request: Union[{{ method.input.ident }}, dict] = None,
request: Union[{{ method.input.ident }}, dict, None] = None,
*,
{% for field in method.flattened_fields.values() %}
{{ field.name }}: {{ field.ident }} = None,
{{ field.name }}: Optional[{{ field.ident }}] = None,
{% endfor %}
{% else %}
requests: Iterator[{{ method.input.ident }}] = None,
requests: Optional[Iterator[{{ method.input.ident }}]] = None,
*,
{% endif %}
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
{% if not method.server_streaming %}
) -> {{ method.client_output.ident }}:
Expand All @@ -361,7 +361,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

Args:
{% if not method.client_streaming %}
request (Union[{{ method.input.ident.sphinx }}, dict]):
request (Union[{{ method.input.ident.sphinx }}, dict, None]):
The request object.{{ " " }}
{{- method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
{% for key, field in method.flattened_fields.items() %}
Expand Down Expand Up @@ -516,10 +516,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
{% if opts.add_iam_methods %}
def set_iam_policy(
self,
request: iam_policy_pb2.SetIamPolicyRequest = None,
request: Optional[iam_policy_pb2.SetIamPolicyRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> policy_pb2.Policy:
r"""Sets the IAM access control policy on the specified function.
Expand Down Expand Up @@ -633,10 +633,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

def get_iam_policy(
self,
request: iam_policy_pb2.GetIamPolicyRequest = None,
request: Optional[iam_policy_pb2.GetIamPolicyRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> policy_pb2.Policy:
r"""Gets the IAM access control policy for a function.
Expand Down Expand Up @@ -750,10 +750,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):

def test_iam_permissions(
self,
request: iam_policy_pb2.TestIamPermissionsRequest = None,
request: Optional[iam_policy_pb2.TestIamPermissionsRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> iam_policy_pb2.TestIamPermissionsResponse:
r"""Tests the specified IAM permissions against the IAM access control
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class {{ method.name }}Pager:
def get(self, key: str) -> Optional[{{ method.paged_result_field.type.fields.get('value').ident }}]:
return self._response.{{ method.paged_result_field.name }}.get(key)
{% else %}
def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterator') }}:
def __iter__(self) -> {{ method.paged_result_field.ident | replace('MutableSequence', 'Iterator') }}:
for page in self.pages:
yield from page.{{ method.paged_result_field.name }}
{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):

def __init__(self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: ga_credentials.Credentials = None,
credentials_file: str = None,
scopes: Sequence[str] = None,
channel: grpc.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
credentials: Optional[ga_credentials.Credentials] = None,
credentials_file: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
channel: Optional[grpc.Channel] = None,
api_mtls_endpoint: Optional[str] = None,
client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
Expand Down Expand Up @@ -186,8 +186,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
@classmethod
def create_channel(cls,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: ga_credentials.Credentials = None,
credentials_file: str = None,
credentials: Optional[ga_credentials.Credentials] = None,
credentials_file: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
**kwargs) -> grpc.Channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ class {{service.name}}RestTransport({{service.name}}Transport):
def __init__(self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: ga_credentials.Credentials=None,
credentials_file: str=None,
scopes: Sequence[str]=None,
client_cert_source_for_mtls: Callable[[
], Tuple[bytes, bytes]]=None,
credentials_file: Optional[str]=None,
scopes: Optional[Sequence[str]]=None,
client_cert_source_for_mtls: Optional[Callable[[
], Tuple[bytes, bytes]]]=None,
quota_project_id: Optional[str]=None,
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool]=False,
Expand Down Expand Up @@ -283,7 +283,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
def __call__(self,
request: {{method.input.ident}}, *,
retry: OptionalRetry=gapic_v1.method.DEFAULT,
timeout: float=None,
timeout: Optional[float]=None,
metadata: Sequence[Tuple[str, str]]=(),
){% if not method.void %} -> {% if not method.server_streaming %}{{method.output.ident}}{% else %}rest_streaming.ResponseIterator{% endif %}{% endif %}:
{% if method.http_options and not method.client_streaming %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

{% with p = proto.disambiguate('proto') %}
{% if proto.messages|length or proto.all_enums|length %}
from typing import MutableMapping, MutableSequence

import proto{% if p != 'proto' %} as {{ p }}{% endif %} # type: ignore
{% endif %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class {{ message.name }}({{ p }}.Message):
{% for field in message.fields.values() %}
{% if field.map %}
{% with key_field = field.message.fields['key'], value_field = field.message.fields['value'] %}
{{ field.name }} = {{ p }}.MapField(
{{ field.name }}: MutableMapping[{{ key_field.type.ident.rel(message.ident) }}, {{ value_field.type.ident.rel(message.ident) }}] = {{ p }}.MapField(
{{ p }}.{{ key_field.proto_type }},
{{ p }}.{{ value_field.proto_type }},
number={{ field.number }},
Expand All @@ -61,7 +61,7 @@ class {{ message.name }}({{ p }}.Message):
)
{% endwith %}
{% else %}
{{ field.name }} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
{{ field.name }}: {% if field.is_primitive %}{{ field.ident }}{% else %}{% if field.repeated %}MutableSequence[{% endif %}{{ field.type.ident.rel(message.ident) }}{% if field.repeated %}]{% endif %}{% endif %} = {{ p }}.{% if field.repeated %}Repeated{% endif %}Field(
{{ p }}.{{ field.proto_type }},
number={{ field.number }},
{% if field.proto3_optional %}
Expand Down
4 changes: 2 additions & 2 deletions gapic/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import pathlib
import typing
from typing import Any, DefaultDict, Dict, Mapping, Tuple
from typing import Any, DefaultDict, Dict, Mapping, Optional, Tuple
from hashlib import sha256
from collections import OrderedDict, defaultdict
from gapic.samplegen_utils.utils import coerce_response_name, is_valid_sample_cfg, render_format_string
Expand Down Expand Up @@ -362,7 +362,7 @@ def _get_file(
return {fn: cgr_file}

def _get_filename(
self, template_name: str, *, api_schema: api.API, context: dict = None,
self, template_name: str, *, api_schema: api.API, context: Optional[dict] = None,
) -> str:
"""Return the appropriate output filename for this template.

Expand Down
4 changes: 2 additions & 2 deletions gapic/samplegen/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import os
import time
from typing import Tuple
from typing import Optional, Tuple

from gapic.samplegen_utils import (types, yaml)
from gapic.utils import case
Expand Down Expand Up @@ -45,7 +45,7 @@ def generate(
api_schema,
*,
environment: yaml.Map = PYTHON3_ENVIRONMENT,
manifest_time: int = None
manifest_time: Optional[int] = None
) -> Tuple[str, yaml.Doc]:
"""Generate a samplegen manifest for use by sampletest

Expand Down
6 changes: 3 additions & 3 deletions gapic/schema/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def build(
file_to_generate: bool,
naming: api_naming.Naming,
opts: Options = Options(),
prior_protos: Mapping[str, 'Proto'] = None,
prior_protos: Optional[Mapping[str, 'Proto']] = None,
load_services: bool = True,
all_resources: Optional[Mapping[str, wrappers.MessageType]] = None,
) -> 'Proto':
Expand Down Expand Up @@ -243,7 +243,7 @@ def build(
file_descriptors: Sequence[descriptor_pb2.FileDescriptorProto],
package: str = '',
opts: Options = Options(),
prior_protos: Mapping[str, 'Proto'] = None,
prior_protos: Optional[Mapping[str, 'Proto']] = None,
) -> 'API':
"""Build the internal API schema based on the request.

Expand Down Expand Up @@ -631,7 +631,7 @@ def __init__(
file_to_generate: bool,
naming: api_naming.Naming,
opts: Options = Options(),
prior_protos: Mapping[str, Proto] = None,
prior_protos: Optional[Mapping[str, Proto]] = None,
load_services: bool = True,
all_resources: Optional[Mapping[str, wrappers.MessageType]] = None,
):
Expand Down
8 changes: 4 additions & 4 deletions gapic/schema/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,15 +370,15 @@ class FieldIdentifier:

def __str__(self) -> str:
if self.mapping:
return f'Mapping[{self.mapping[0].ident}, {self.mapping[1].ident}]'
return f'MutableMapping[{self.mapping[0].ident}, {self.mapping[1].ident}]'
if self.repeated:
return f'Sequence[{self.ident}]'
return f'MutableSequence[{self.ident}]'
return str(self.ident)

@property
def sphinx(self) -> str:
if self.mapping:
return f'Mapping[{self.mapping[0].ident.sphinx}, {self.mapping[1].ident.sphinx}]'
return f'MutableMapping[{self.mapping[0].ident.sphinx}, {self.mapping[1].ident.sphinx}]'
if self.repeated:
return f'Sequence[{self.ident.sphinx}]'
return f'MutableSequence[{self.ident.sphinx}]'
return self.ident.sphinx
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
{% macro client_method(method, name, snippet_index, api, service, full_extended_lro=False) %}
def {{ name }}(self,
{% if not method.client_streaming %}
request: Union[{{ method.input.ident }}, dict] = None,
request: Union[{{ method.input.ident }}, dict, None] = None,
*,
{% for field in method.flattened_fields.values() %}
{{ field.name }}: {{ field.ident }} = None,
{{ field.name }}: Optional[{{ field.ident }}] = None,
{% endfor %}
{% else %}
requests: Iterator[{{ method.input.ident }}] = None,
requests: Optional[Iterator[{{ method.input.ident }}]] = None,
*,
{% endif %}
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
{% if method.extended_lro and not full_extended_lro %}{# This is a hack to preserve backwards compatibility with the "unary" surfaces #}
) -> {{ method.extended_lro.operation_type.ident }}:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import OrderedDict
import functools
import re
from typing import Dict, Mapping, Optional, {% if service.any_server_streaming %}AsyncIterable, Awaitable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union
from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, {% if service.any_server_streaming %}AsyncIterable, Awaitable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union
import pkg_resources
{% if service.any_deprecated %}
import warnings
Expand Down Expand Up @@ -144,9 +144,9 @@ class {{ service.async_client_name }}:
get_transport_class = functools.partial(type({{ service.client_name }}).get_transport_class, type({{ service.client_name }}))

def __init__(self, *,
credentials: ga_credentials.Credentials = None,
credentials: Optional[ga_credentials.Credentials] = None,
transport: Union[str, {{ service.name }}Transport] = "grpc_asyncio",
client_options: ClientOptions = None,
client_options: Optional[ClientOptions] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiates the {{ (service.client_name|snake_case).replace("_", " ") }}.
Expand Down Expand Up @@ -195,17 +195,17 @@ class {{ service.async_client_name }}:
{%+ if not method.server_streaming %}async {% endif %}def {{ method_name }}(self,
{% endwith %}
{% if not method.client_streaming %}
request: Union[{{ method.input.ident }}, dict] = None,
request: Union[{{ method.input.ident }}, dict, None] = None,
*,
{% for field in method.flattened_fields.values() %}
{{ field.name }}: {{ field.ident }} = None,
{{ field.name }}: Optional[{{ field.ident }}] = None,
{% endfor %}
{% else %}
requests: AsyncIterator[{{ method.input.ident }}] = None,
requests: Optional[AsyncIterator[{{ method.input.ident }}]] = None,
*,
{% endif %}
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
{% if not method.server_streaming %}
) -> {{ method.client_output_async.ident }}:
Expand All @@ -224,7 +224,7 @@ class {{ service.async_client_name }}:

Args:
{% if not method.client_streaming %}
request (Union[{{ method.input.ident.sphinx }}, dict]):
request (Union[{{ method.input.ident.sphinx }}, dict, None]):
The request object.{{ " " }}
{{- method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
{% for key, field in method.flattened_fields.items() %}
Expand Down Expand Up @@ -387,10 +387,10 @@ class {{ service.async_client_name }}:
{% if opts.add_iam_methods %}
async def set_iam_policy(
self,
request: iam_policy_pb2.SetIamPolicyRequest = None,
request: Optional[iam_policy_pb2.SetIamPolicyRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> policy_pb2.Policy:
r"""Sets the IAM access control policy on the specified function.
Expand Down Expand Up @@ -501,10 +501,10 @@ class {{ service.async_client_name }}:

async def get_iam_policy(
self,
request: iam_policy_pb2.GetIamPolicyRequest = None,
request: Optional[iam_policy_pb2.GetIamPolicyRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> policy_pb2.Policy:
r"""Gets the IAM access control policy for a function.
Expand Down Expand Up @@ -617,10 +617,10 @@ class {{ service.async_client_name }}:

async def test_iam_permissions(
self,
request: iam_policy_pb2.TestIamPermissionsRequest = None,
request: Optional[iam_policy_pb2.TestIamPermissionsRequest] = None,
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> iam_policy_pb2.TestIamPermissionsResponse:
r"""Tests the specified permissions against the IAM access control
Expand Down