From 9ac608005bf49cf8366f0da0e17c341fd034c2d8 Mon Sep 17 00:00:00 2001 From: William OLLIVIER Date: Wed, 20 Jul 2022 10:40:29 +0100 Subject: [PATCH] Added type annotations for public API + flake8 fixes --- firebase_admin/__init__.py | 16 ++++++++++------ firebase_admin/_utils.py | 8 ++++---- firebase_admin/credentials.py | 9 ++++++--- firebase_admin/firestore.py | 14 ++++++++------ firebase_admin/messaging.py | 5 +++++ firebase_admin/ml.py | 20 ++++++++++---------- firebase_admin/storage.py | 8 +++++--- firebase_admin/tenant_mgt.py | 5 +++-- 8 files changed, 51 insertions(+), 34 deletions(-) diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 7e3b2eab0..89650fca4 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -17,6 +17,7 @@ import json import os import threading +from typing import Any, Callable, Dict, Optional from firebase_admin import credentials from firebase_admin.__about__ import __version__ @@ -31,7 +32,8 @@ _CONFIG_VALID_KEYS = ['databaseAuthVariableOverride', 'databaseURL', 'httpTimeout', 'projectId', 'storageBucket'] -def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): + +def initialize_app(credential: Optional[credentials.Base] = None, options: Optional[Dict[str, Any]] = None, name: str = _DEFAULT_APP_NAME) -> "App": """Initializes and returns a new App instance. Creates a new App instance using the specified options @@ -83,7 +85,7 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): 'you call initialize_app().').format(name)) -def delete_app(app): +def delete_app(app: "App"): """Gracefully deletes an App instance. Args: @@ -98,7 +100,7 @@ def delete_app(app): with _apps_lock: if _apps.get(app.name) is app: del _apps[app.name] - app._cleanup() # pylint: disable=protected-access + app._cleanup() # pylint: disable=protected-access return if app.name == _DEFAULT_APP_NAME: raise ValueError( @@ -111,7 +113,7 @@ def delete_app(app): 'second argument.').format(app.name)) -def get_app(name=_DEFAULT_APP_NAME): +def get_app(name: str = _DEFAULT_APP_NAME) -> "App": """Retrieves an App instance by name. Args: @@ -190,7 +192,7 @@ class App: common to all Firebase APIs. """ - def __init__(self, name, credential, options): + def __init__(self, name: str, credential: credentials.Base, options: Optional[Dict[str, Any]]): """Constructs a new App using the provided name and options. Args: @@ -265,7 +267,7 @@ def _lookup_project_id(self): App._validate_project_id(self._options.get('projectId')) return project_id - def _get_service(self, name, initializer): + def _get_service(self, name: str, initializer: Callable): """Returns the service instance identified by the given name. Services are functional entities exposed by the Admin SDK (e.g. auth, database). Each @@ -307,3 +309,5 @@ def _cleanup(self): if hasattr(service, 'close') and hasattr(service.close, '__call__'): service.close() self._services = None + +App.update diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index dcfb520d2..f56d9661d 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -15,6 +15,7 @@ """Internal utilities common to all modules.""" import json +from typing import Callable, Optional import google.auth import requests @@ -76,7 +77,7 @@ } -def _get_initialized_app(app): +def _get_initialized_app(app: Optional[firebase_admin.App]): """Returns a reference to an initialized App instance.""" if app is None: return firebase_admin.get_app() @@ -92,10 +93,9 @@ def _get_initialized_app(app): ' firebase_admin.App, but given "{0}".'.format(type(app))) - -def get_app_service(app, name, initializer): +def get_app_service(app: Optional[firebase_admin.App], name: str, initializer: Callable): app = _get_initialized_app(app) - return app._get_service(name, initializer) # pylint: disable=protected-access + return app._get_service(name, initializer) # pylint: disable=protected-access def handle_platform_error_from_requests(error, handle_func=None): diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 5477e1cf7..2d31369bd 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -16,11 +16,13 @@ import collections import json import pathlib +from typing import Any, Dict, Union import google.auth from google.auth.transport import requests from google.oauth2 import credentials from google.oauth2 import service_account +import google.auth.credentials _request = requests.Request() @@ -44,7 +46,7 @@ class Base: """Provides OAuth2 access tokens for accessing Firebase services.""" - def get_access_token(self): + def get_access_token(self) -> AccessTokenInfo: """Fetches a Google OAuth2 access token using this credential instance. Returns: @@ -54,7 +56,7 @@ def get_access_token(self): google_cred.refresh(_request) return AccessTokenInfo(google_cred.token, google_cred.expiry) - def get_credential(self): + def get_credential(self) -> google.auth.credentials.Credentials: """Returns the Google credential instance used for authentication.""" raise NotImplementedError @@ -64,7 +66,7 @@ class Certificate(Base): _CREDENTIAL_TYPE = 'service_account' - def __init__(self, cert): + def __init__(self, cert: Union[str, Dict[str, Any]]): """Initializes a credential from a Google service account certificate. Service account certificates can be downloaded as JSON files from the Firebase console. @@ -158,6 +160,7 @@ def _load_credential(self): if not self._g_credential: self._g_credential, self._project_id = google.auth.default(scopes=_scopes) + class RefreshToken(Base): """A credential initialized from an existing refresh token.""" diff --git a/firebase_admin/firestore.py b/firebase_admin/firestore.py index 32c9897d5..a4179f2ab 100644 --- a/firebase_admin/firestore.py +++ b/firebase_admin/firestore.py @@ -19,7 +19,7 @@ """ try: - from google.cloud import firestore # pylint: disable=import-error,no-name-in-module + from google.cloud import firestore # pylint: disable=import-error,no-name-in-module existing = globals().keys() for key, value in firestore.__dict__.items(): if not key.startswith('_') and key not in existing: @@ -28,13 +28,15 @@ raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' 'to install the "google-cloud-firestore" module.') -from firebase_admin import _utils +from firebase_admin import _utils, App +import google.auth.credentials +from typing import Optional _FIRESTORE_ATTRIBUTE = '_firestore' -def client(app=None): +def client(app: Optional[App] = None) -> firestore.Client: """Returns a client that can be used to interact with Google Cloud Firestore. Args: @@ -57,14 +59,14 @@ def client(app=None): class _FirestoreClient: """Holds a Google Cloud Firestore client instance.""" - def __init__(self, credentials, project): + def __init__(self, credentials: google.auth.credentials.Credentials, project: str): self._client = firestore.Client(credentials=credentials, project=project) - def get(self): + def get(self) -> firestore.Client: return self._client @classmethod - def from_app(cls, app): + def from_app(cls, app: App): """Creates a new _FirestoreClient for the specified app.""" credentials = app.credential.get_credential() project = app.project_id diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 46dd7d410..4cfea32c5 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -95,6 +95,7 @@ def _get_messaging_service(app): return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService) + def send(message, dry_run=False, app=None): """Sends the given message via Firebase Cloud Messaging (FCM). @@ -115,6 +116,7 @@ def send(message, dry_run=False, app=None): """ return _get_messaging_service(app).send(message, dry_run) + def send_all(messages, dry_run=False, app=None): """Sends the given list of messages via Firebase Cloud Messaging as a single batch. @@ -135,6 +137,7 @@ def send_all(messages, dry_run=False, app=None): """ return _get_messaging_service(app).send_all(messages, dry_run) + def send_multicast(multicast_message, dry_run=False, app=None): """Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM). @@ -166,6 +169,7 @@ def send_multicast(multicast_message, dry_run=False, app=None): ) for token in multicast_message.tokens] return _get_messaging_service(app).send_all(messages, dry_run) + def subscribe_to_topic(tokens, topic, app=None): """Subscribes a list of registration tokens to an FCM topic. @@ -185,6 +189,7 @@ def subscribe_to_topic(tokens, topic, app=None): return _get_messaging_service(app).make_topic_management_request( tokens, topic, 'iid/v1:batchAdd') + def unsubscribe_from_topic(tokens, topic, app=None): """Unsubscribes a list of registration tokens from an FCM topic. diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index bcc4b9390..2c46c7f0e 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -210,13 +210,13 @@ def from_dict(cls, data, app=None): tflite_format = TFLiteFormat.from_dict(tflite_format_data) model = Model(model_format=tflite_format) model._data = data_copy # pylint: disable=protected-access - model._app = app # pylint: disable=protected-access + model._app = app # pylint: disable=protected-access return model def _update_from_dict(self, data): copy = Model.from_dict(data) self.model_format = copy.model_format - self._data = copy._data # pylint: disable=protected-access + self._data = copy._data # pylint: disable=protected-access def __eq__(self, other): if isinstance(other, self.__class__): @@ -333,7 +333,7 @@ def model_format(self): def model_format(self, model_format): if model_format is not None: _validate_model_format(model_format) - self._model_format = model_format #Can be None + self._model_format = model_format # Can be None return self def as_dict(self, for_upload=False): @@ -369,7 +369,7 @@ def from_dict(cls, data): """Create an instance of the object from a dict.""" data_copy = dict(data) tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy)) - tflite_format._data = data_copy # pylint: disable=protected-access + tflite_format._data = data_copy # pylint: disable=protected-access return tflite_format def __eq__(self, other): @@ -401,7 +401,7 @@ def model_source(self, model_source): if model_source is not None: if not isinstance(model_source, TFLiteModelSource): raise TypeError('Model source must be a TFLiteModelSource object.') - self._model_source = model_source # Can be None + self._model_source = model_source # Can be None @property def size_bytes(self): @@ -481,7 +481,7 @@ def __init__(self, gcs_tflite_uri, app=None): def __eq__(self, other): if isinstance(other, self.__class__): - return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access + return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access return False def __ne__(self, other): @@ -766,7 +766,7 @@ def _validate_display_name(display_name): def _validate_tags(tags): if not isinstance(tags, list) or not \ - all(isinstance(tag, str) for tag in tags): + all(isinstance(tag, str) for tag in tags): raise TypeError('Tags must be a list of strings.') if not all(_TAG_PATTERN.match(tag) for tag in tags): raise ValueError('Tag format is invalid.') @@ -780,6 +780,7 @@ def _validate_gcs_tflite_uri(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri + def _validate_auto_ml_model(model): if not _AUTO_ML_MODEL_PATTERN.match(model): raise ValueError('Model resource name format is invalid.') @@ -800,7 +801,7 @@ def _validate_list_filter(list_filter): def _validate_page_size(page_size): if page_size is not None: - if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck + if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck # Specifically type() to disallow boolean which is a subtype of int raise TypeError('Page size must be a number or None.') if page_size < 1 or page_size > _MAX_PAGE_SIZE: @@ -855,7 +856,7 @@ def _exponential_backoff(self, current_attempt, stop_time): if stop_time is not None: max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds() - if max_seconds_left < 1: # allow a bit of time for rpc + if max_seconds_left < 1: # allow a bit of time for rpc raise exceptions.DeadlineExceededError('Polling max time exceeded.') wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) time.sleep(wait_time_seconds) @@ -916,7 +917,6 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds # If the operation is not complete or timed out, return a (locked) model instead return get_model(model_id).as_dict() - def create_model(self, model): _validate_model(model) try: diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index f3948371c..39da1e632 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -25,12 +25,14 @@ raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') -from firebase_admin import _utils +from firebase_admin import _utils, App +from typing import Optional _STORAGE_ATTRIBUTE = '_storage' -def bucket(name=None, app=None) -> storage.Bucket: + +def bucket(name: Optional[str] = None, app: Optional[App] = None) -> storage.Bucket: """Returns a handle to a Google Cloud Storage bucket. If the name argument is not provided, uses the 'storageBucket' option specified when @@ -67,7 +69,7 @@ def from_app(cls, app): # significantly speeds up the initialization of the storage client. return _StorageClient(credentials, app.project_id, default_bucket) - def bucket(self, name=None): + def bucket(self, name: Optional[str] = None): """Returns a handle to the specified Cloud Storage Bucket.""" bucket_name = name if name is not None else self._default_bucket if bucket_name is None: diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index 396a819fb..e30e65dd2 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -183,6 +183,7 @@ def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=Non FirebaseError: If an error occurs while retrieving the user accounts. """ tenant_mgt_service = _get_tenant_mgt_service(app) + def download(page_token, max_results): return tenant_mgt_service.list_tenants(page_token, max_results) return ListTenantsPage(download, page_token, max_results) @@ -206,7 +207,7 @@ class Tenant: def __init__(self, data): if not isinstance(data, dict): raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data)) - if not 'name' in data: + if 'name' not in data: raise ValueError('Tenant response missing required keys.') self._data = data @@ -256,7 +257,7 @@ def auth_for_tenant(self, tenant_id): client = auth.Client(self.app, tenant_id=tenant_id) self.tenant_clients[tenant_id] = client - return client + return client def get_tenant(self, tenant_id): """Gets the tenant corresponding to the given ``tenant_id``."""