Skip to content

Commit

Permalink
Added type annotations for public API + flake8 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
guneemwelloeux committed Mar 25, 2023
1 parent 5c21b81 commit 2119ca3
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 34 deletions.
16 changes: 10 additions & 6 deletions firebase_admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import json
import os
import threading
from typing import Any, Callable, Dict, Optional

from firebase_admin import credentials
from firebase_admin.__about__ import __version__
Expand All @@ -31,7 +32,8 @@
_CONFIG_VALID_KEYS = ['databaseAuthVariableOverride', 'databaseURL', 'httpTimeout', 'projectId',
'storageBucket']

def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME):

def initialize_app(credential: Optional[credentials.Base] = None, options: Optional[Dict[str, Any]] = None, name: str = _DEFAULT_APP_NAME) -> "App":
"""Initializes and returns a new App instance.
Creates a new App instance using the specified options
Expand Down Expand Up @@ -83,7 +85,7 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME):
'you call initialize_app().').format(name))


def delete_app(app):
def delete_app(app: "App"):
"""Gracefully deletes an App instance.
Args:
Expand All @@ -98,7 +100,7 @@ def delete_app(app):
with _apps_lock:
if _apps.get(app.name) is app:
del _apps[app.name]
app._cleanup() # pylint: disable=protected-access
app._cleanup() # pylint: disable=protected-access
return
if app.name == _DEFAULT_APP_NAME:
raise ValueError(
Expand All @@ -111,7 +113,7 @@ def delete_app(app):
'second argument.').format(app.name))


def get_app(name=_DEFAULT_APP_NAME):
def get_app(name: str = _DEFAULT_APP_NAME) -> "App":
"""Retrieves an App instance by name.
Args:
Expand Down Expand Up @@ -190,7 +192,7 @@ class App:
common to all Firebase APIs.
"""

def __init__(self, name, credential, options):
def __init__(self, name: str, credential: credentials.Base, options: Optional[Dict[str, Any]]):
"""Constructs a new App using the provided name and options.
Args:
Expand Down Expand Up @@ -265,7 +267,7 @@ def _lookup_project_id(self):
App._validate_project_id(self._options.get('projectId'))
return project_id

def _get_service(self, name, initializer):
def _get_service(self, name: str, initializer: Callable):
"""Returns the service instance identified by the given name.
Services are functional entities exposed by the Admin SDK (e.g. auth, database). Each
Expand Down Expand Up @@ -307,3 +309,5 @@ def _cleanup(self):
if hasattr(service, 'close') and hasattr(service.close, '__call__'):
service.close()
self._services = None

App.update
8 changes: 4 additions & 4 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Internal utilities common to all modules."""

import json
from typing import Callable, Optional

import google.auth
import requests
Expand Down Expand Up @@ -76,7 +77,7 @@
}


def _get_initialized_app(app):
def _get_initialized_app(app: Optional[firebase_admin.App]):
"""Returns a reference to an initialized App instance."""
if app is None:
return firebase_admin.get_app()
Expand All @@ -92,10 +93,9 @@ def _get_initialized_app(app):
' firebase_admin.App, but given "{0}".'.format(type(app)))



def get_app_service(app, name, initializer):
def get_app_service(app: Optional[firebase_admin.App], name: str, initializer: Callable):
app = _get_initialized_app(app)
return app._get_service(name, initializer) # pylint: disable=protected-access
return app._get_service(name, initializer) # pylint: disable=protected-access


def handle_platform_error_from_requests(error, handle_func=None):
Expand Down
9 changes: 6 additions & 3 deletions firebase_admin/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import collections
import json
import pathlib
from typing import Any, Dict, Union

import google.auth
from google.auth.transport import requests
from google.oauth2 import credentials
from google.oauth2 import service_account
import google.auth.credentials


_request = requests.Request()
Expand All @@ -44,7 +46,7 @@
class Base:
"""Provides OAuth2 access tokens for accessing Firebase services."""

def get_access_token(self):
def get_access_token(self) -> AccessTokenInfo:
"""Fetches a Google OAuth2 access token using this credential instance.
Returns:
Expand All @@ -54,7 +56,7 @@ def get_access_token(self):
google_cred.refresh(_request)
return AccessTokenInfo(google_cred.token, google_cred.expiry)

def get_credential(self):
def get_credential(self) -> google.auth.credentials.Credentials:
"""Returns the Google credential instance used for authentication."""
raise NotImplementedError

Expand All @@ -64,7 +66,7 @@ class Certificate(Base):

_CREDENTIAL_TYPE = 'service_account'

def __init__(self, cert):
def __init__(self, cert: Union[str, Dict[str, Any]]):
"""Initializes a credential from a Google service account certificate.
Service account certificates can be downloaded as JSON files from the Firebase console.
Expand Down Expand Up @@ -158,6 +160,7 @@ def _load_credential(self):
if not self._g_credential:
self._g_credential, self._project_id = google.auth.default(scopes=_scopes)


class RefreshToken(Base):
"""A credential initialized from an existing refresh token."""

Expand Down
14 changes: 8 additions & 6 deletions firebase_admin/firestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

try:
from google.cloud import firestore # pylint: disable=import-error,no-name-in-module
from google.cloud import firestore # pylint: disable=import-error,no-name-in-module
existing = globals().keys()
for key, value in firestore.__dict__.items():
if not key.startswith('_') and key not in existing:
Expand All @@ -28,13 +28,15 @@
raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure '
'to install the "google-cloud-firestore" module.')

from firebase_admin import _utils
from firebase_admin import _utils, App
import google.auth.credentials
from typing import Optional


_FIRESTORE_ATTRIBUTE = '_firestore'


def client(app=None):
def client(app: Optional[App] = None) -> firestore.Client:
"""Returns a client that can be used to interact with Google Cloud Firestore.
Args:
Expand All @@ -57,14 +59,14 @@ def client(app=None):
class _FirestoreClient:
"""Holds a Google Cloud Firestore client instance."""

def __init__(self, credentials, project):
def __init__(self, credentials: google.auth.credentials.Credentials, project: str):
self._client = firestore.Client(credentials=credentials, project=project)

def get(self):
def get(self) -> firestore.Client:
return self._client

@classmethod
def from_app(cls, app):
def from_app(cls, app: App):
"""Creates a new _FirestoreClient for the specified app."""
credentials = app.credential.get_credential()
project = app.project_id
Expand Down
5 changes: 5 additions & 0 deletions firebase_admin/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
def _get_messaging_service(app):
return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService)


def send(message, dry_run=False, app=None):
"""Sends the given message via Firebase Cloud Messaging (FCM).
Expand All @@ -115,6 +116,7 @@ def send(message, dry_run=False, app=None):
"""
return _get_messaging_service(app).send(message, dry_run)


def send_all(messages, dry_run=False, app=None):
"""Sends the given list of messages via Firebase Cloud Messaging as a single batch.
Expand All @@ -135,6 +137,7 @@ def send_all(messages, dry_run=False, app=None):
"""
return _get_messaging_service(app).send_all(messages, dry_run)


def send_multicast(multicast_message, dry_run=False, app=None):
"""Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM).
Expand Down Expand Up @@ -166,6 +169,7 @@ def send_multicast(multicast_message, dry_run=False, app=None):
) for token in multicast_message.tokens]
return _get_messaging_service(app).send_all(messages, dry_run)


def subscribe_to_topic(tokens, topic, app=None):
"""Subscribes a list of registration tokens to an FCM topic.
Expand All @@ -185,6 +189,7 @@ def subscribe_to_topic(tokens, topic, app=None):
return _get_messaging_service(app).make_topic_management_request(
tokens, topic, 'iid/v1:batchAdd')


def unsubscribe_from_topic(tokens, topic, app=None):
"""Unsubscribes a list of registration tokens from an FCM topic.
Expand Down
20 changes: 10 additions & 10 deletions firebase_admin/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,13 @@ def from_dict(cls, data, app=None):
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
model = Model(model_format=tflite_format)
model._data = data_copy # pylint: disable=protected-access
model._app = app # pylint: disable=protected-access
model._app = app # pylint: disable=protected-access
return model

def _update_from_dict(self, data):
copy = Model.from_dict(data)
self.model_format = copy.model_format
self._data = copy._data # pylint: disable=protected-access
self._data = copy._data # pylint: disable=protected-access

def __eq__(self, other):
if isinstance(other, self.__class__):
Expand Down Expand Up @@ -334,7 +334,7 @@ def model_format(self):
def model_format(self, model_format):
if model_format is not None:
_validate_model_format(model_format)
self._model_format = model_format #Can be None
self._model_format = model_format # Can be None
return self

def as_dict(self, for_upload=False):
Expand Down Expand Up @@ -370,7 +370,7 @@ def from_dict(cls, data):
"""Create an instance of the object from a dict."""
data_copy = dict(data)
tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy))
tflite_format._data = data_copy # pylint: disable=protected-access
tflite_format._data = data_copy # pylint: disable=protected-access
return tflite_format

def __eq__(self, other):
Expand Down Expand Up @@ -405,7 +405,7 @@ def model_source(self, model_source):
if model_source is not None:
if not isinstance(model_source, TFLiteModelSource):
raise TypeError('Model source must be a TFLiteModelSource object.')
self._model_source = model_source # Can be None
self._model_source = model_source # Can be None

@property
def size_bytes(self):
Expand Down Expand Up @@ -485,7 +485,7 @@ def __init__(self, gcs_tflite_uri, app=None):

def __eq__(self, other):
if isinstance(other, self.__class__):
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
return False

def __ne__(self, other):
Expand Down Expand Up @@ -775,7 +775,7 @@ def _validate_display_name(display_name):

def _validate_tags(tags):
if not isinstance(tags, list) or not \
all(isinstance(tag, str) for tag in tags):
all(isinstance(tag, str) for tag in tags):
raise TypeError('Tags must be a list of strings.')
if not all(_TAG_PATTERN.match(tag) for tag in tags):
raise ValueError('Tag format is invalid.')
Expand All @@ -789,6 +789,7 @@ def _validate_gcs_tflite_uri(uri):
raise ValueError('GCS TFLite URI format is invalid.')
return uri


def _validate_auto_ml_model(model):
if not _AUTO_ML_MODEL_PATTERN.match(model):
raise ValueError('Model resource name format is invalid.')
Expand All @@ -809,7 +810,7 @@ def _validate_list_filter(list_filter):

def _validate_page_size(page_size):
if page_size is not None:
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
# Specifically type() to disallow boolean which is a subtype of int
raise TypeError('Page size must be a number or None.')
if page_size < 1 or page_size > _MAX_PAGE_SIZE:
Expand Down Expand Up @@ -864,7 +865,7 @@ def _exponential_backoff(self, current_attempt, stop_time):

if stop_time is not None:
max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds()
if max_seconds_left < 1: # allow a bit of time for rpc
if max_seconds_left < 1: # allow a bit of time for rpc
raise exceptions.DeadlineExceededError('Polling max time exceeded.')
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
time.sleep(wait_time_seconds)
Expand Down Expand Up @@ -925,7 +926,6 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()


def create_model(self, model):
_validate_model(model)
try:
Expand Down
8 changes: 5 additions & 3 deletions firebase_admin/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
'to install the "google-cloud-storage" module.')

from firebase_admin import _utils
from firebase_admin import _utils, App
from typing import Optional


_STORAGE_ATTRIBUTE = '_storage'

def bucket(name=None, app=None) -> storage.Bucket:

def bucket(name: Optional[str] = None, app: Optional[App] = None) -> storage.Bucket:
"""Returns a handle to a Google Cloud Storage bucket.
If the name argument is not provided, uses the 'storageBucket' option specified when
Expand Down Expand Up @@ -67,7 +69,7 @@ def from_app(cls, app):
# significantly speeds up the initialization of the storage client.
return _StorageClient(credentials, app.project_id, default_bucket)

def bucket(self, name=None):
def bucket(self, name: Optional[str] = None):
"""Returns a handle to the specified Cloud Storage Bucket."""
bucket_name = name if name is not None else self._default_bucket
if bucket_name is None:
Expand Down
5 changes: 3 additions & 2 deletions firebase_admin/tenant_mgt.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=Non
FirebaseError: If an error occurs while retrieving the user accounts.
"""
tenant_mgt_service = _get_tenant_mgt_service(app)

def download(page_token, max_results):
return tenant_mgt_service.list_tenants(page_token, max_results)
return ListTenantsPage(download, page_token, max_results)
Expand All @@ -206,7 +207,7 @@ class Tenant:
def __init__(self, data):
if not isinstance(data, dict):
raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data))
if not 'name' in data:
if 'name' not in data:
raise ValueError('Tenant response missing required keys.')

self._data = data
Expand Down Expand Up @@ -256,7 +257,7 @@ def auth_for_tenant(self, tenant_id):

client = auth.Client(self.app, tenant_id=tenant_id)
self.tenant_clients[tenant_id] = client
return client
return client

def get_tenant(self, tenant_id):
"""Gets the tenant corresponding to the given ``tenant_id``."""
Expand Down

0 comments on commit 2119ca3

Please sign in to comment.