From 2246418ecd98729c3ea3342472b3e93d3b7de1d5 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 1 Jun 2017 15:07:30 -0700 Subject: [PATCH] Implementing a Database API for Python Admin SDK (#31) * Implemented get_token() method for App * Adding db client * Fleshed out the DatabaseReference type * Implemented the full DB API * Added unit tests for DB queries * Support for service cleanup; More tests * Further API cleanup * Code cleanup and more test cases * Fixing test for Python 3 * More python3 fixes * Get/set priority * Implementing query filters * Implemented query filters * Implemented a Query abstraction * Adding integration tests for DB API * Adding license headers to new files; Using the same testutils from unit tests * Added integration tests for comlpex queries * More integration tests for complex queries * Some documentation and tests * Improved error handling; More integration tests * Updated API docs * Implement support for sorting query results * Support for sorting lists * Fixed the sorting implementation and updated test cases * Braking index ties by comparing their keys * Updated integration tests to check for result order * Updated docstrings * Added newlines at the end of test data files * Updated documentation; Fixed a bug in service cleanup logic; Other minor nits. --- firebase_admin/__init__.py | 77 +++- firebase_admin/auth.py | 22 +- firebase_admin/db.py | 637 ++++++++++++++++++++++++++++++++ firebase_admin/utils.py | 34 ++ integration/__init__.py | 16 + integration/conftest.py | 50 +++ integration/test_db.py | 250 +++++++++++++ setup.cfg | 3 + tests/data/dinosaurs.json | 78 ++++ tests/data/dinosaurs_index.json | 12 + tests/test_app.py | 22 ++ tests/test_db.py | 591 +++++++++++++++++++++++++++++ 12 files changed, 1770 insertions(+), 22 deletions(-) create mode 100644 firebase_admin/db.py create mode 100644 firebase_admin/utils.py create mode 100644 integration/__init__.py create mode 100644 integration/conftest.py create mode 100644 integration/test_db.py create mode 100644 tests/data/dinosaurs.json create mode 100644 tests/data/dinosaurs_index.json create mode 100644 tests/test_db.py diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 3f340a95..ed9b11d0 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. """Firebase Admin SDK for Python.""" +import datetime import threading import six @@ -26,8 +27,10 @@ _apps = {} _apps_lock = threading.RLock() +_clock = datetime.datetime.utcnow _DEFAULT_APP_NAME = '[DEFAULT]' +_CLOCK_SKEW_SECONDS = 300 def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): @@ -91,6 +94,7 @@ def delete_app(app): with _apps_lock: if _apps.get(app.name) is app: del _apps[app.name] + app._cleanup() # pylint: disable=protected-access return if app.name == _DEFAULT_APP_NAME: raise ValueError( @@ -145,12 +149,16 @@ def __init__(self, options): 'must be a dictionary.'.format(type(options))) self._options = options + def get(self, key): + """Returns the option identified by the provided key.""" + return self._options.get(key) + class App(object): """The entry point for Firebase Python SDK. - Represents a Firebase app, while holding the configuration and state - common to all Firebase APIs. + Represents a Firebase app, while holding the configuration and state + common to all Firebase APIs. """ def __init__(self, name, credential, options): @@ -174,6 +182,9 @@ def __init__(self, name, credential, options): 'with a valid credential instance.') self._credential = credential self._options = _AppOptions(options) + self._token = None + self._lock = threading.RLock() + self._services = {} @property def name(self): @@ -186,3 +197,65 @@ def credential(self): @property def options(self): return self._options + + def get_token(self): + """Returns an OAuth2 bearer token. + + This method may return a cached token. But it handles cache invalidation, and therefore + is guaranteed to always return unexpired tokens. + + Returns: + string: An unexpired OAuth2 token. + """ + if not self._token_valid(): + self._token = self._credential.get_access_token() + return self._token.access_token + + def _token_valid(self): + if self._token is None: + return False + skewed_expiry = self._token.expiry - datetime.timedelta(seconds=_CLOCK_SKEW_SECONDS) + return _clock() < skewed_expiry + + def _get_service(self, name, initializer): + """Returns the service instance identified by the given name. + + Services are functional entities exposed by the Admin SDK (e.g. auth, database). Each + service instance is associated with exactly one App. If the named service + instance does not exist yet, _get_service() calls the provided initializer function to + create the service instance. The created instance will be cached, so that subsequent + calls would always fetch it from the cache. + + Args: + name: Name of the service to retrieve. + initializer: A function that can be used to initialize a service for the first time. + + Returns: + object: The specified service instance. + + Raises: + ValueError: If the provided name is invalid, or if the App is already deleted. + """ + if not name or not isinstance(name, six.string_types): + raise ValueError( + 'Illegal name argument: "{0}". Name must be a non-empty string.'.format(name)) + with self._lock: + if self._services is None: + raise ValueError( + 'Service requested from deleted Firebase App: "{0}".'.format(self._name)) + if name not in self._services: + self._services[name] = initializer(self) + return self._services[name] + + def _cleanup(self): + """Cleans up any services associated with this App. + + Checks whether each service contains a close() method, and calls it if available. + This is to be called when an App is being deleted, thus ensuring graceful termination of + any services started by the App. + """ + with self._lock: + for service in self._services.values(): + if hasattr(service, 'close') and hasattr(service.close, '__call__'): + service.close() + self._services = None diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index cfba9c93..8f9208cc 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -27,8 +27,8 @@ import google.oauth2.id_token import six -import firebase_admin from firebase_admin import credentials +from firebase_admin import utils _auth_lock = threading.Lock() @@ -39,20 +39,6 @@ GCLOUD_PROJECT_ENV_VAR = 'GCLOUD_PROJECT' -def _get_initialized_app(app): - if app is None: - return firebase_admin.get_app() - elif isinstance(app, firebase_admin.App): - initialized_app = firebase_admin.get_app(app.name) - if app is not initialized_app: - raise ValueError('Illegal app argument. App instance not ' - 'initialized via the firebase module.') - return app - else: - raise ValueError('Illegal app argument. Argument must be of type ' - ' firebase_admin.App, but given "{0}".'.format(type(app))) - - def _get_token_generator(app): """Returns a _TokenGenerator instance for an App. @@ -69,11 +55,7 @@ def _get_token_generator(app): Raises: ValueError: If the app argument is invalid. """ - app = _get_initialized_app(app) - with _auth_lock: - if not hasattr(app, _AUTH_ATTRIBUTE): - setattr(app, _AUTH_ATTRIBUTE, _TokenGenerator(app)) - return getattr(app, _AUTH_ATTRIBUTE) + return utils.get_app_service(app, _AUTH_ATTRIBUTE, _TokenGenerator) def create_custom_token(uid, developer_claims=None, app=None): diff --git a/firebase_admin/db.py b/firebase_admin/db.py new file mode 100644 index 00000000..ac37f613 --- /dev/null +++ b/firebase_admin/db.py @@ -0,0 +1,637 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase Realtime Database module. + +This module contains functions and classes that facilitate interacting with the Firebase Realtime +Database. It supports basic data manipulation operations, as well as complex queries such as +limit queries and range queries. However, it does not support realtime update notifications. This +module uses the Firebase REST API underneath. +""" + +import collections +import json +import numbers + +import requests +import six +from six.moves import urllib + +from firebase_admin import utils + +_DB_ATTRIBUTE = '_database' +_INVALID_PATH_CHARACTERS = '[].#$' +_RESERVED_FILTERS = ('$key', '$value', '$priority') + + +def get_reference(path='/', app=None): + """Returns a database Reference representing the node at the specified path. + + If no path is specified, this function returns a Reference that represents the database root. + + Args: + path: Path to a node in the Firebase realtime database (optional). + app: An App instance (optional). + + Returns: + Reference: A newly initialized Reference. + + Raises: + ValueError: If the specified path or app is invalid. + """ + client = utils.get_app_service(app, _DB_ATTRIBUTE, _Client.from_app) + return Reference(client=client, path=path) + +def _parse_path(path): + """Parses a path string into a set of segments.""" + if not isinstance(path, six.string_types): + raise ValueError('Invalid path: "{0}". Path must be a string.'.format(path)) + if any(ch in path for ch in _INVALID_PATH_CHARACTERS): + raise ValueError( + 'Invalid path: "{0}". Path contains illegal characters.'.format(path)) + return [seg for seg in path.split('/') if seg] + + +class Reference(object): + """Reference represents a node in the Firebase realtime database.""" + + def __init__(self, **kwargs): + """Creates a new Reference using the provided parameters. + + This method is for internal use only. Use db.get_reference() to obtain an instance of + Reference. + """ + self._client = kwargs.get('client') + if 'segments' in kwargs: + self._segments = kwargs.get('segments') + else: + self._segments = _parse_path(kwargs.get('path')) + self._pathurl = '/' + '/'.join(self._segments) + + @property + def key(self): + if self._segments: + return self._segments[-1] + return None + + @property + def path(self): + return self._pathurl + + @property + def parent(self): + if self._segments: + return Reference(client=self._client, segments=self._segments[:-1]) + return None + + def child(self, path): + """Returns a Reference to the specified child node. + + The path may point to an immediate child of the current Reference, or a deeply nested + child. Child paths must not begin with '/'. + + Args: + path: Path to the child node. + + Returns: + Reference: A database Reference representing the specified child node. + + Raises: + ValueError: If the child path is not a string, not well-formed or begins with '/'. + """ + if not path or not isinstance(path, six.string_types): + raise ValueError( + 'Invalid path argument: "{0}". Path must be a non-empty string.'.format(path)) + if path.startswith('/'): + raise ValueError( + 'Invalid path argument: "{0}". Child path must not start with "/"'.format(path)) + full_path = self._pathurl + '/' + path + return Reference(client=self._client, path=full_path) + + def get_value(self): + """Returns the value at the current location of the database. + + Returns: + object: Decoded JSON value of the current database Reference. + + Raises: + ApiCallError: If an error occurs while communicating with the remote database server. + """ + return self._client.request('get', self._add_suffix()) + + def get_priority(self): + """Returns the priority of this node, if specified. + + Returns: + object: A priority value or None. + + Raises: + ApiCallError: If an error occurs while communicating with the remote database server. + """ + return self._client.request('get', self._add_suffix('/.priority.json')) + + def set_value(self, value, priority=None): + """Sets the data at this location to the given value. + + The value must be JSON-serializable and not None. If a priority is specified, the node will + be assigned that priority along with the value. + + Args: + value: JSON-serialable value to be set at this location. + priority: A numeric or alphanumeric priority value (optional). + + Raises: + ValueError: If the value is None or priority is invalid. + TypeError: If the value is not JSON-serializable. + ApiCallError: If an error occurs while communicating with the remote database server. + """ + if value is None: + raise ValueError('Value must not be None.') + if priority is not None: + Reference._check_priority(priority) + if isinstance(value, dict): + value = dict(value) + value['.priority'] = priority + else: + value = {'.value' : value, '.priority' : priority} + params = {'print' : 'silent'} + self._client.request_oneway('put', self._add_suffix(), json=value, params=params) + + def push(self, value=''): + """Creates a new child node. + + The optional value argument can be used to provide an initial value for the child node. If + no value is provided, child node will have empty string as the default value. + + Args: + value: JSON-serializable initial value for the child node (optional). + + Returns: + Reference: A Reference representing the newly created child node. + + Raises: + ValueError: If the value is None. + TypeError: If the value is not JSON-serializable. + ApiCallError: If an error occurs while communicating with the remote database server. + """ + if value is None: + raise ValueError('Value must not be None.') + output = self._client.request('post', self._add_suffix(), json=value) + push_id = output.get('name') + return self.child(push_id) + + def update_children(self, value): + """Updates the specified child keys of this Reference to the provided values. + + Args: + value: A dictionary containing the child keys to update, and their new values. + + Raises: + ValueError: If value is empty or not a dictionary. + ApiCallError: If an error occurs while communicating with the remote database server. + """ + if not value or not isinstance(value, dict): + raise ValueError('Value argument must be a non-empty dictionary.') + if None in value.keys() or None in value.values(): + raise ValueError('Dictionary must not contain None keys or values.') + params = {'print':'silent'} + self._client.request_oneway('patch', self._add_suffix(), json=value, params=params) + + def delete(self): + """Deleted this node from the database. + + Raises: + ApiCallError: If an error occurs while communicating with the remote database server. + """ + self._client.request_oneway('delete', self._add_suffix()) + + def order_by_child(self, path): + """Returns a Query that orders data by child values. + + Returned Query can be used to set additional parameters, and execute complex database + queries (e.g. limit queries, range queries). + + Args: + path: Path to a valid child of the current Reference. + + Returns: + Query: A database Query instance. + + Raises: + ValueError: If the child path is not a string, not well-formed or None. + """ + if path in _RESERVED_FILTERS: + raise ValueError('Illegal child path: {0}'.format(path)) + return Query(order_by=path, client=self._client, pathurl=self._add_suffix()) + + def order_by_key(self): + """Creates a Query that orderes data by key. + + Returned Query can be used to set additional parameters, and execute complex database + queries (e.g. limit queries, range queries). + + Returns: + Query: A database Query instance. + """ + return Query(order_by='$key', client=self._client, pathurl=self._add_suffix()) + + def order_by_value(self): + """Creates a Query that orderes data by value. + + Returned Query can be used to set additional parameters, and execute complex database + queries (e.g. limit queries, range queries). + + Returns: + Query: A database Query instance. + """ + return Query(order_by='$value', client=self._client, pathurl=self._add_suffix()) + + def order_by_priority(self): + """Creates a Query that orderes data by priority. + + Returned Query can be used to set additional parameters, and execute complex database + queries (e.g. limit queries, range queries). Due to a limitation of the + underlying REST API, the order-by-priority constraint can only be enforced during + the execution time of the Query. When the Query returns results, the actual results + will be returned as an unordered collection. + + Returns: + Query: A database Query instance. + """ + return Query(order_by='$priority', client=self._client, pathurl=self._add_suffix()) + + def _add_suffix(self, suffix='.json'): + return self._pathurl + suffix + + @classmethod + def _check_priority(cls, priority): + if isinstance(priority, six.string_types) and priority.isalnum(): + return + if isinstance(priority, numbers.Number): + return + raise ValueError('Illegal priority value: "{0}". Priority values must be numeric or ' + 'alphanumeric.'.format(priority)) + + +class Query(object): + """Represents a complex query that can be executed on a Reference. + + Complex queries can consist of up to 2 components: a required ordering constraint, and an + optional filtering constraint. At the server, data is first sorted according to the given + ordering constraint (e.g. order by child). Then the filtering constraint (e.g. limit, range) + is applied on the sorted data to produce the final result. Despite the ordering constraint, + the final result is returned by the server as an unordered collection. Therefore the Query + interface performs another round of sorting at the client-side before returning the results + to the caller. This client-side sorted results are returned to the user as a Python + OrderedDict. However, client-side sorting is not feasible for order-by-priority queries. + Therefore for such queries results are returned as a regular unordered dict. + """ + + def __init__(self, **kwargs): + order_by = kwargs.pop('order_by') + if not order_by or not isinstance(order_by, six.string_types): + raise ValueError('order_by field must be a non-empty string') + if order_by not in _RESERVED_FILTERS: + if order_by.startswith('/'): + raise ValueError('Invalid path argument: "{0}". Child path must not start ' + 'with "/"'.format(order_by)) + segments = _parse_path(order_by) + order_by = '/'.join(segments) + self._client = kwargs.pop('client') + self._pathurl = kwargs.pop('pathurl') + self._order_by = order_by + self._params = {'orderBy' : json.dumps(order_by)} + if kwargs: + raise ValueError('Unexpected keyword arguments: {0}'.format(kwargs)) + + def set_limit_first(self, limit): + """Creates a query with limit, and anchors it to the start of the window. + + Args: + limit: The maximum number of child nodes to return. + + Returns: + Query: The updated Query instance. + + Raises: + ValueError: If the value is not an integer, or set_limit_last() was called previously. + """ + if not isinstance(limit, int): + raise ValueError('Limit must be an integer.') + if 'limitToLast' in self._params: + raise ValueError('Cannot set both first and last limits.') + self._params['limitToFirst'] = limit + return self + + def set_limit_last(self, limit): + """Creates a query with limit, and anchors it to the end of the window. + + Args: + limit: The maximum number of child nodes to return. + + Returns: + Query: The updated Query instance. + + Raises: + ValueError: If the value is not an integer, or set_limit_first() was called previously. + """ + if not isinstance(limit, int): + raise ValueError('Limit must be an integer.') + if 'limitToFirst' in self._params: + raise ValueError('Cannot set both first and last limits.') + self._params['limitToLast'] = limit + return self + + def set_start_at(self, start): + """Sets the lower bound for a range query. + + The Query will only return child nodes with a value greater than or equal to the specified + value. + + Args: + start: JSON-serializable value to start at, inclusive. + + Returns: + Query: The updated Query instance. + + Raises: + ValueError: If the value is empty or None. + """ + if not start: + raise ValueError('Start value must not be empty or None.') + self._params['startAt'] = json.dumps(start) + return self + + def set_end_at(self, end): + """Sets the upper bound for a range query. + + The Query will only return child nodes with a value less than or equal to the specified + value. + + Args: + end: JSON-serializable value to end at, inclusive. + + Returns: + Query: The updated Query instance. + + Raises: + ValueError: If the value is empty or None. + """ + if not end: + raise ValueError('End value must not be empty or None.') + self._params['endAt'] = json.dumps(end) + return self + + def set_equal_to(self, value): + """Sets an equals constraint on the Query. + + The Query will only return child nodes whose value is equal to the specified value. + + Args: + value: JSON-serializable value to query for. + + Returns: + Query: The updated Query instance. + + Raises: + ValueError: If the value is empty or None. + """ + if not value: + raise ValueError('Equal to value must not be empty or None.') + self._params['equalTo'] = json.dumps(value) + return self + + @property + def querystr(self): + params = [] + for key in sorted(self._params): + params.append('{0}={1}'.format(key, self._params[key])) + return '&'.join(params) + + def run(self): + """Executes this Query and returns the results. + + The results will be returned as a sorted list or an OrderedDict, except in the case of + order-by-priority queries. + + Returns: + object: Decoded JSON result of the Query. + + Raises: + ApiCallError: If an error occurs while communicating with the remote database server. + """ + result = self._client.request('get', '{0}?{1}'.format(self._pathurl, self.querystr)) + if isinstance(result, (dict, list)) and self._order_by != '$priority': + return _Sorter(result, self._order_by).get() + return result + + +class ApiCallError(Exception): + """Represents an Exception encountered while invoking the Firebase database server API.""" + + def __init__(self, message, error): + Exception.__init__(self, message) + self.detail = error + + +class _Sorter(object): + """Helper class for sorting query results.""" + + def __init__(self, results, order_by): + if isinstance(results, dict): + self.dict_input = True + entries = [_SortEntry(k, v, order_by) for k, v in results.items()] + elif isinstance(results, list): + self.dict_input = False + entries = [_SortEntry(k, v, order_by) for k, v in enumerate(results)] + else: + raise ValueError('Sorting not supported for "{0}" object.'.format(type(results))) + self.sort_entries = sorted(entries) + + def get(self): + if self.dict_input: + return collections.OrderedDict([(e.key, e.value) for e in self.sort_entries]) + else: + return [e.value for e in self.sort_entries] + + +class _SortEntry(object): + """A wrapper that is capable of sorting items in a dictionary.""" + + _type_none = 0 + _type_bool_false = 1 + _type_bool_true = 2 + _type_numeric = 3 + _type_string = 4 + _type_object = 5 + + def __init__(self, key, value, order_by): + self._key = key + self._value = value + if order_by == '$key' or order_by == '$priority': + self._index = key + elif order_by == '$value': + self._index = value + else: + self._index = _SortEntry._extract_child(value, order_by) + self._index_type = _SortEntry._get_index_type(self._index) + + @property + def key(self): + return self._key + + @property + def index(self): + return self._index + + @property + def index_type(self): + return self._index_type + + @property + def value(self): + return self._value + + @classmethod + def _get_index_type(cls, index): + """Assigns an integer code to the type of the index. + + The index type determines how differently typed values are sorted. This ordering is based + on https://firebase.google.com/docs/database/rest/retrieve-data#section-rest-ordered-data + """ + if index is None: + return cls._type_none + elif isinstance(index, bool) and not index: + return cls._type_bool_false + elif isinstance(index, bool) and index: + return cls._type_bool_true + elif isinstance(index, (int, float)): + return cls._type_numeric + elif isinstance(index, six.string_types): + return cls._type_string + else: + return cls._type_object + + @classmethod + def _extract_child(cls, value, path): + segments = path.split('/') + current = value + for segment in segments: + if isinstance(current, dict): + current = current.get(segment) + else: + return None + return current + + def _compare(self, other): + """Compares two _SortEntry instances. + + If the indices have the same numeric or string type, compare them directly. Ties are + broken by comparing the keys. If the indices have the same type, but are neither numeric + nor string, compare the keys. In all other cases compare based on the ordering provided + by index types. + """ + self_key, other_key = self.index_type, other.index_type + if self_key == other_key: + if self_key in (self._type_numeric, self._type_string) and self.index != other.index: + self_key, other_key = self.index, other.index + else: + self_key, other_key = self.key, other.key + + if self_key < other_key: + return -1 + elif self_key > other_key: + return 1 + else: + return 0 + + def __lt__(self, other): + return self._compare(other) < 0 + + def __le__(self, other): + return self._compare(other) <= 0 + + def __gt__(self, other): + return self._compare(other) > 0 + + def __ge__(self, other): + return self._compare(other) >= 0 + + def __eq__(self, other): + return self._compare(other) is 0 + + +class _Client(object): + """HTTP client used to make REST calls. + + _Client maintains an HTTP session, and handles authenticating HTTP requests along with + marshalling and unmarshalling of JSON data. + """ + + def __init__(self, url=None, auth=None, session=None): + self._url = url + self._auth = auth + self._session = session + + @classmethod + def from_app(cls, app): + """Created a new _Client for a given App""" + url = app.options.get('dbURL') + if not url or not isinstance(url, six.string_types): + raise ValueError( + 'Invalid dbURL option: "{0}". dbURL must be a non-empty URL string.'.format(url)) + parsed = urllib.parse.urlparse(url) + if parsed.scheme != 'https': + raise ValueError( + 'Invalid dbURL option: "{0}". dbURL must be an HTTPS URL.'.format(url)) + elif not parsed.netloc.endswith('.firebaseio.com'): + raise ValueError( + 'Invalid dbURL option: "{0}". dbURL must be a valid URL to a Firebase Realtime ' + 'Database instance.'.format(url)) + return _Client('https://{0}'.format(parsed.netloc), _OAuth(app), requests.Session()) + + def request(self, method, urlpath, **kwargs): + return self._do_request(method, urlpath, **kwargs).json() + + def request_oneway(self, method, urlpath, **kwargs): + self._do_request(method, urlpath, **kwargs) + + def _do_request(self, method, urlpath, **kwargs): + try: + resp = self._session.request(method, self._url + urlpath, auth=self._auth, **kwargs) + resp.raise_for_status() + return resp + except requests.exceptions.RequestException as error: + raise ApiCallError(self._extract_error_message(error), error) + + def _extract_error_message(self, error): + if error.response is not None: + data = json.loads(error.response.content) + if isinstance(data, dict): + return '{0}\nReason: {1}'.format(error, data.get('error', 'unknown')) + return str(error) + + def close(self): + self._session.close() + self._auth = None + self._url = None + + +class _OAuth(requests.auth.AuthBase): + def __init__(self, app): + self._app = app + + def __call__(self, req): + req.headers['Authorization'] = 'Bearer {0}'.format(self._app.get_token()) + return req diff --git a/firebase_admin/utils.py b/firebase_admin/utils.py new file mode 100644 index 00000000..4243d38a --- /dev/null +++ b/firebase_admin/utils.py @@ -0,0 +1,34 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal utilities common to all modules.""" + +import firebase_admin + +def _get_initialized_app(app): + if app is None: + return firebase_admin.get_app() + elif isinstance(app, firebase_admin.App): + initialized_app = firebase_admin.get_app(app.name) + if app is not initialized_app: + raise ValueError('Illegal app argument. App instance not ' + 'initialized via the firebase module.') + return app + else: + raise ValueError('Illegal app argument. Argument must be of type ' + ' firebase_admin.App, but given "{0}".'.format(type(app))) + +def get_app_service(app, name, initializer): + app = _get_initialized_app(app) + return app._get_service(name, initializer) # pylint: disable=protected-access diff --git a/integration/__init__.py b/integration/__init__.py new file mode 100644 index 00000000..81707da0 --- /dev/null +++ b/integration/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Enables exclusion of the tests module from the distribution. diff --git a/integration/conftest.py b/integration/conftest.py new file mode 100644 index 00000000..ba4bb63c --- /dev/null +++ b/integration/conftest.py @@ -0,0 +1,50 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""pytest configuration and global fixtures for integration tests.""" +import json + +import pytest + +import firebase_admin +from firebase_admin import credentials + + +def pytest_addoption(parser): + parser.addoption( + '--cert', action='store', help='Service account certificate file for integration tests.') + +def _get_cert_path(request): + cert = request.config.getoption('--cert') + if cert: + return cert + raise ValueError('Service account certificate not specified. Make sure to specify the ' + '"--cert" command-line option.') + +@pytest.fixture(autouse=True, scope='session') +def default_app(request): + """Initializes the default Firebase App instance used for all integration tests. + + This fixture is attached to the session scope, which ensures that it runs only once during + a test session. It is also marked as autouse, and therefore runs automatically without + test cases having to call it explicitly. + """ + cert_path = _get_cert_path(request) + with open(cert_path) as cert: + project_id = json.load(cert).get('project_id') + if not project_id: + raise ValueError('Failed to determine project ID from service account certificate.') + cred = credentials.Certificate(cert_path) + ops = {'dbURL' : 'https://{0}.firebaseio.com'.format(project_id)} + return firebase_admin.initialize_app(cred, ops) diff --git a/integration/test_db.py b/integration/test_db.py new file mode 100644 index 00000000..fd66b2b9 --- /dev/null +++ b/integration/test_db.py @@ -0,0 +1,250 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.db module.""" +import collections +import json + +import pytest + +from firebase_admin import db +from tests import testutils + +def _update_rules(): + with open(testutils.resource_filename('dinosaurs_index.json')) as index_file: + index = json.load(index_file) + client = db.get_reference()._client + rules = client.request('get', '/.settings/rules.json') + existing = rules.get('rules', dict()).get('_adminsdk') + if existing != index: + rules['rules']['_adminsdk'] = index + client.request('put', '/.settings/rules.json', json=rules) + +@pytest.fixture(scope='module') +def testdata(): + with open(testutils.resource_filename('dinosaurs.json')) as dino_file: + return json.load(dino_file) + +@pytest.fixture(scope='module') +def testref(): + """Adds the necessary DB indices, and sets the initial values. + + This fixture is attached to the module scope, and therefore is guaranteed to run only once + during the execution of this test module. + + Returns: + Reference: A reference to the test dinosaur database. + """ + _update_rules() + ref = db.get_reference('_adminsdk/python/dinodb') + ref.set_value(testdata()) + return ref + + +class TestReferenceAttributes(object): + """Test cases for attributes exposed by db.Reference class.""" + + def test_ref_attributes(self, testref): + assert testref.key == 'dinodb' + assert testref.path == '/_adminsdk/python/dinodb' + + def test_child(self, testref): + child = testref.child('dinosaurs') + assert child.key == 'dinosaurs' + assert child.path == '/_adminsdk/python/dinodb/dinosaurs' + + def test_parent(self, testref): + parent = testref.parent + assert parent.key == 'python' + assert parent.path == '/_adminsdk/python' + + +class TestReadOperations(object): + """Test cases for reading node values.""" + + def test_get_value(self, testref, testdata): + value = testref.get_value() + assert isinstance(value, dict) + assert testdata == value + + def test_get_child_value(self, testref, testdata): + value = testref.child('dinosaurs').get_value() + assert isinstance(value, dict) + assert testdata['dinosaurs'] == value + + def test_get_grandchild_value(self, testref, testdata): + value = testref.child('dinosaurs').child('lambeosaurus').get_value() + assert isinstance(value, dict) + assert testdata['dinosaurs']['lambeosaurus'] == value + + def test_get_nonexisting_child_value(self, testref): + assert testref.child('none_existing').get_value() is None + + +class TestWriteOperations(object): + """Test cases for creating and updating node values.""" + + def test_push(self, testref): + python = testref.parent + ref = python.child('users').push() + assert ref.path == '/_adminsdk/python/users/' + ref.key + assert ref.get_value() == '' + + def test_push_with_value(self, testref): + python = testref.parent + value = {'name' : 'Luis Alvarez', 'since' : 1911} + ref = python.child('users').push(value) + assert ref.path == '/_adminsdk/python/users/' + ref.key + assert ref.get_value() == value + + def test_set_primitive_value(self, testref): + python = testref.parent + ref = python.child('users').push() + ref.set_value('value') + assert ref.get_value() == 'value' + + def test_set_complex_value(self, testref): + python = testref.parent + value = {'name' : 'Mary Anning', 'since' : 1799} + ref = python.child('users').push() + ref.set_value(value) + assert ref.get_value() == value + + def test_set_primitive_value_with_priority(self, testref): + python = testref.parent + ref = python.child('users').push() + ref.set_value('value', 1) + assert ref.get_value() == 'value' + assert ref.get_priority() == 1 + + def test_set_complex_value_with_priority(self, testref): + python = testref.parent + value = {'name' : 'Barnum Brown', 'since' : 1873} + ref = python.child('users').push() + ref.set_value(value, 2) + assert ref.get_value() == value + assert ref.get_priority() == 2 + + def test_update_children(self, testref): + python = testref.parent + value = {'name' : 'Robert Bakker', 'since' : 1945} + ref = python.child('users').push() + ref.update_children(value) + assert ref.get_value() == value + + def test_update_children_with_existing_values(self, testref): + python = testref.parent + ref = python.child('users').push({'name' : 'Edwin Colbert', 'since' : 1900}) + ref.update_children({'since' : 1905}) + assert ref.get_value() == {'name' : 'Edwin Colbert', 'since' : 1905} + + def test_delete(self, testref): + python = testref.parent + ref = python.child('users').push('foo') + assert ref.get_value() == 'foo' + ref.delete() + assert ref.get_value() is None + + +class TestAdvancedQueries(object): + """Test cases for advanced interactions via the db.Query interface.""" + + height_sorted = [ + 'linhenykus', 'pterodactyl', 'lambeosaurus', + 'triceratops', 'stegosaurus', 'bruhathkayosaurus', + ] + + def test_order_by_key(self, testref): + value = testref.child('dinosaurs').order_by_key().run() + assert isinstance(value, collections.OrderedDict) + assert list(value.keys()) == [ + 'bruhathkayosaurus', 'lambeosaurus', 'linhenykus', + 'pterodactyl', 'stegosaurus', 'triceratops' + ] + + def test_order_by_value(self, testref): + value = testref.child('scores').order_by_value().run() + assert list(value.keys()) == [ + 'stegosaurus', 'lambeosaurus', 'triceratops', + 'bruhathkayosaurus', 'linhenykus', 'pterodactyl', + ] + + def test_order_by_child(self, testref): + value = testref.child('dinosaurs').order_by_child('height').run() + assert list(value.keys()) == self.height_sorted + + def test_limit_first(self, testref): + value = testref.child('dinosaurs').order_by_child('height').set_limit_first(2).run() + assert list(value.keys()) == self.height_sorted[:2] + + def test_limit_first_all(self, testref): + value = testref.child('dinosaurs').order_by_child('height').set_limit_first(10).run() + assert list(value.keys()) == self.height_sorted + + def test_limit_last(self, testref): + value = testref.child('dinosaurs').order_by_child('height').set_limit_last(2).run() + assert list(value.keys()) == self.height_sorted[-2:] + + def test_limit_last_all(self, testref): + value = testref.child('dinosaurs').order_by_child('height').set_limit_last(10).run() + assert list(value.keys()) == self.height_sorted + + def test_start_at(self, testref): + value = testref.child('dinosaurs').order_by_child('height').set_start_at(3.5).run() + assert list(value.keys()) == self.height_sorted[-2:] + + def test_end_at(self, testref): + value = testref.child('dinosaurs').order_by_child('height').set_end_at(3.5).run() + assert list(value.keys()) == self.height_sorted[:4] + + def test_start_and_end_at(self, testref): + value = testref.child('dinosaurs').order_by_child('height') \ + .set_start_at(2.5).set_end_at(5).run() + assert list(value.keys()) == self.height_sorted[-3:-1] + + def test_equal_to(self, testref): + value = testref.child('dinosaurs').order_by_child('height').set_equal_to(0.6).run() + assert list(value.keys()) == self.height_sorted[:2] + + def test_order_by_nested_child(self, testref): + value = testref.child('dinosaurs').order_by_child('ratings/pos').set_start_at(4).run() + assert len(value) == 3 + assert 'pterodactyl' in value + assert 'stegosaurus' in value + assert 'triceratops' in value + + def test_filter_by_key(self, testref): + value = testref.child('dinosaurs').order_by_key().set_limit_first(2).run() + assert len(value) == 2 + assert 'bruhathkayosaurus' in value + assert 'lambeosaurus' in value + + def test_filter_by_value(self, testref): + value = testref.child('scores').order_by_value().set_limit_last(2).run() + assert len(value) == 2 + assert 'pterodactyl' in value + assert 'linhenykus' in value + + def test_order_by_priority(self, testref): + python = testref.parent + museums = python.child('museums').push() + values = {'Berlin' : 1, 'Chicago' : 2, 'Brussels' : 3} + for name, priority in values.items(): + ref = museums.push() + ref.set_value(name, priority) + result = museums.order_by_priority().set_limit_last(2).run() + assert isinstance(result, dict) + assert len(result) == 2 + assert 'Brussels' in result.values() + assert 'Chicago' in result.values() diff --git a/setup.cfg b/setup.cfg index 2a9acf13..a038cfa0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,5 @@ [bdist_wheel] universal = 1 + +[tool:pytest] +testpaths = tests \ No newline at end of file diff --git a/tests/data/dinosaurs.json b/tests/data/dinosaurs.json new file mode 100644 index 00000000..9d7afaab --- /dev/null +++ b/tests/data/dinosaurs.json @@ -0,0 +1,78 @@ +{ + "dinosaurs": { + "bruhathkayosaurus": { + "appeared": -70000000, + "height": 25, + "length": 44, + "order": "saurischia", + "vanished": -70000000, + "weight": 135000, + "ratings": { + "pos": 1 + } + }, + "lambeosaurus": { + "appeared": -76000000, + "height": 2.1, + "length": 12.5, + "order": "ornithischia", + "vanished": -75000000, + "weight": 5000, + "ratings": { + "pos": 2 + } + }, + "linhenykus": { + "appeared": -85000000, + "height": 0.6, + "length": 1, + "order": "theropoda", + "vanished": -75000000, + "weight": 3, + "ratings": { + "pos": 3 + } + }, + "pterodactyl": { + "appeared": -150000000, + "height": 0.6, + "length": 0.8, + "order": "pterosauria", + "vanished": -148500000, + "weight": 2, + "ratings": { + "pos": 4 + } + }, + "stegosaurus": { + "appeared": -155000000, + "height": 4, + "length": 9, + "order": "ornithischia", + "vanished": -150000000, + "weight": 2500, + "ratings": { + "pos": 5 + } + }, + "triceratops": { + "appeared": -68000000, + "height": 3, + "length": 8, + "order": "ornithischia", + "vanished": -66000000, + "weight": 11000, + "ratings": { + "pos": 6 + } + } + }, + "scores": { + "bruhathkayosaurus": 55, + "lambeosaurus": 21, + "linhenykus": 80, + "pterodactyl": 93, + "stegosaurus": 5, + "triceratops": 22 + } +} diff --git a/tests/data/dinosaurs_index.json b/tests/data/dinosaurs_index.json new file mode 100644 index 00000000..01accf91 --- /dev/null +++ b/tests/data/dinosaurs_index.json @@ -0,0 +1,12 @@ +{ + "python": { + "dinodb": { + "dinosaurs": { + ".indexOn": ["height", "ratings/pos"] + }, + "scores": { + ".indexOn": ".value" + } + } + } +} diff --git a/tests/test_app.py b/tests/test_app.py index 6e977374..e219ef6d 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -13,6 +13,8 @@ # limitations under the License. """Tests for firebase_admin.App.""" +import datetime +import json import os import pytest @@ -159,3 +161,23 @@ def test_app_delete(self, init_app): firebase_admin.get_app(init_app.name) with pytest.raises(ValueError): firebase_admin.delete_app(init_app) + + def test_get_token(self, init_app): + mock_response = {'access_token': 'mock_access_token_1', 'expires_in': 3600} + credentials._request = testutils.MockRequest(200, json.dumps(mock_response)) + + assert init_app.get_token() == 'mock_access_token_1' + + mock_response = {'access_token': 'mock_access_token_2', 'expires_in': 3600} + credentials._request = testutils.MockRequest(200, json.dumps(mock_response)) + + expiry = init_app._token.expiry + # should return same token from cache + firebase_admin._clock = lambda: expiry - datetime.timedelta( + seconds=firebase_admin._CLOCK_SKEW_SECONDS + 1) + assert init_app.get_token() == 'mock_access_token_1' + + # should return new token from RPC call + firebase_admin._clock = lambda: expiry - datetime.timedelta( + seconds=firebase_admin._CLOCK_SKEW_SECONDS) + assert init_app.get_token() == 'mock_access_token_2' diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 00000000..ff5c87c3 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,591 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for firebase_admin.db.""" +import collections +import datetime +import json + +import pytest +from requests import adapters +from requests import models +import six + +import firebase_admin +from firebase_admin import credentials +from firebase_admin import db +from tests import testutils + + +class MockAdapter(adapters.HTTPAdapter): + def __init__(self, data, status, recorder): + adapters.HTTPAdapter.__init__(self) + self._data = data + self._status = status + self._recorder = recorder + + def send(self, request, **kwargs): # pylint: disable=unused-argument + self._recorder.append(request) + resp = models.Response() + resp.status_code = self._status + resp.raw = six.BytesIO(self._data.encode()) + return resp + + +class MockCredential(credentials.Base): + def get_access_token(self): + expiry = datetime.datetime.utcnow() + datetime.timedelta(hours=24) + return credentials.AccessTokenInfo('mock-token', expiry) + + def get_credential(self): + return None + +class _Object(object): + pass + + +class TestReferencePath(object): + """Test cases for Reference paths.""" + + # path => (fullstr, key, parent) + valid_paths = { + '/' : ('/', None, None), + '' : ('/', None, None), + '/foo' : ('/foo', 'foo', '/'), + 'foo' : ('/foo', 'foo', '/'), + '/foo/bar' : ('/foo/bar', 'bar', '/foo'), + 'foo/bar' : ('/foo/bar', 'bar', '/foo'), + '/foo/bar/' : ('/foo/bar', 'bar', '/foo'), + } + + invalid_paths = [ + None, True, False, 0, 1, dict(), list(), tuple(), _Object(), + 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', + ] + + valid_children = { + 'foo': ('/test/foo', 'foo', '/test'), + 'foo/bar' : ('/test/foo/bar', 'bar', '/test/foo'), + 'foo/bar/' : ('/test/foo/bar', 'bar', '/test/foo'), + } + + invalid_children = [ + None, '', '/foo', '/foo/bar', True, False, 0, 1, dict(), list(), tuple(), + 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', _Object() + ] + + @pytest.mark.parametrize('path, expected', valid_paths.items()) + def test_valid_path(self, path, expected): + ref = db.Reference(path=path) + fullstr, key, parent = expected + assert ref.path == fullstr + assert ref.key == key + if parent is None: + assert ref.parent is None + else: + assert ref.parent.path == parent + + @pytest.mark.parametrize('path', invalid_paths) + def test_invalid_key(self, path): + with pytest.raises(ValueError): + db.Reference(path=path) + + @pytest.mark.parametrize('child, expected', valid_children.items()) + def test_valid_child(self, child, expected): + fullstr, key, parent = expected + childref = db.Reference(path='/test').child(child) + assert childref.path == fullstr + assert childref.key == key + assert childref.parent.path == parent + + @pytest.mark.parametrize('child', invalid_children) + def test_invalid_child(self, child): + parent = db.Reference(path='/test') + with pytest.raises(ValueError): + parent.child(child) + + +class TestReference(object): + """Test cases for database queries via References.""" + + test_url = 'https://test.firebaseio.com' + valid_values = [ + '', 'foo', 0, 1, 100, 1.2, True, False, [], [1, 2], {}, {'foo' : 'bar'} + ] + + @classmethod + def setup_class(cls): + firebase_admin.initialize_app(MockCredential(), {'dbURL' : cls.test_url}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + def instrument(self, ref, payload, status=200): + recorder = [] + adapter = MockAdapter(payload, status, recorder) + ref._client._session.mount(self.test_url, adapter) + return recorder + + @pytest.mark.parametrize('data', valid_values) + def test_get_value(self, data): + ref = db.get_reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + assert ref.get_value() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('data', valid_values) + def test_order_by_query(self, data): + ref = db.get_reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + query = ref.order_by_child('foo') + query_str = 'orderBy=%22foo%22' + assert query.run() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('data', valid_values) + def test_limit_query(self, data): + ref = db.get_reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + query = ref.order_by_child('foo') + query.set_limit_first(100) + query_str = 'limitToFirst=100&orderBy=%22foo%22' + assert query.run() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('data', valid_values) + def test_range_query(self, data): + ref = db.get_reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + query = ref.order_by_child('foo') + query.set_start_at(100) + query.set_end_at(200) + query_str = 'endAt=200&orderBy=%22foo%22&startAt=100' + assert query.run() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_get_priority(self): + ref = db.get_reference('/test') + recorder = self.instrument(ref, json.dumps('10')) + assert ref.get_priority() == '10' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test/.priority.json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('data', valid_values) + def test_set_value(self, data): + ref = db.get_reference('/test') + recorder = self.instrument(ref, '') + data = {'foo' : 'bar'} + ref.set_value(data) + assert len(recorder) == 1 + assert recorder[0].method == 'PUT' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_set_primitive_value_with_priority(self): + ref = db.get_reference('/test') + recorder = self.instrument(ref, '') + ref.set_value('foo', '10') + assert len(recorder) == 1 + assert recorder[0].method == 'PUT' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + assert json.loads(recorder[0].body.decode()) == {'.value' : 'foo', '.priority' : '10'} + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('priority', [10, 10.0, True, False, 'foo', 'foo123']) + def test_set_value_with_priority(self, priority): + ref = db.get_reference('/test') + recorder = self.instrument(ref, '') + data = {'foo' : 'bar'} + ref.set_value(data, priority) + assert len(recorder) == 1 + assert recorder[0].method == 'PUT' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + data['.priority'] = priority + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_set_none_value(self): + ref = db.get_reference('/test') + self.instrument(ref, '') + with pytest.raises(ValueError): + ref.set_value(None) + + @pytest.mark.parametrize('value', [ + _Object(), {'foo': _Object()}, [_Object()] + ]) + def test_set_non_json_value(self, value): + ref = db.get_reference('/test') + self.instrument(ref, '') + with pytest.raises(TypeError): + ref.set_value(value) + + @pytest.mark.parametrize('priority', [ + '', list(), tuple(), dict(), _Object(), {'foo': _Object()} + ]) + def test_set_invalid_priority(self, priority): + ref = db.get_reference('/test') + self.instrument(ref, '') + with pytest.raises(ValueError): + ref.set_value('', priority) + + def test_update_children(self): + ref = db.get_reference('/test') + data = {'foo' : 'bar'} + recorder = self.instrument(ref, json.dumps(data)) + ref.update_children(data) + assert len(recorder) == 1 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_update_children_default(self): + ref = db.get_reference('/test') + recorder = self.instrument(ref, '') + with pytest.raises(ValueError): + ref.update_children({}) + assert len(recorder) is 0 + + @pytest.mark.parametrize('update', [ + None, {}, {None:'foo'}, {'foo': None}, '', 'foo', 0, 1, list(), tuple() + ]) + def test_set_invalid_update(self, update): + ref = db.get_reference('/test') + self.instrument(ref, '') + with pytest.raises(ValueError): + ref.update_children(update) + + @pytest.mark.parametrize('data', valid_values) + def test_push(self, data): + ref = db.get_reference('/test') + recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) + child = ref.push(data) + assert isinstance(child, db.Reference) + assert child.key == 'testkey' + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_push_default(self): + ref = db.get_reference('/test') + recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) + assert ref.push().key == 'testkey' + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert json.loads(recorder[0].body.decode()) == '' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_push_none_value(self): + ref = db.get_reference('/test') + self.instrument(ref, '') + with pytest.raises(ValueError): + ref.push(None) + + def test_delete(self): + ref = db.get_reference('/test') + recorder = self.instrument(ref, '') + ref.delete() + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_get_root_reference(self): + ref = db.get_reference() + assert ref.key is None + assert ref.path == '/' + + @pytest.mark.parametrize('path, expected', TestReferencePath.valid_paths.items()) + def test_get_reference(self, path, expected): + ref = db.get_reference(path) + fullstr, key, parent = expected + assert ref.path == fullstr + assert ref.key == key + if parent is None: + assert ref.parent is None + else: + assert ref.parent.path == parent + + +class TestDatabseInitialization(object): + """Test cases for database initialization.""" + + def teardown_method(self): + testutils.cleanup_apps() + + def test_no_app(self): + with pytest.raises(ValueError): + db.get_reference() + + def test_no_db_url(self): + firebase_admin.initialize_app(credentials.Base()) + with pytest.raises(ValueError): + db.get_reference() + + @pytest.mark.parametrize('url', [ + 'https://test.firebaseio.com', 'https://test.firebaseio.com/' + ]) + def test_valid_db_url(self, url): + firebase_admin.initialize_app(credentials.Base(), {'dbURL' : url}) + ref = db.get_reference() + assert ref._client._url == 'https://test.firebaseio.com' + + @pytest.mark.parametrize('url', [ + None, '', 'foo', 'http://test.firebaseio.com', 'https://google.com', + True, False, 1, 0, dict(), list(), tuple(), + ]) + def test_invalid_db_url(self, url): + firebase_admin.initialize_app(credentials.Base(), {'dbURL' : url}) + with pytest.raises(ValueError): + db.get_reference() + + def test_app_delete(self): + app = firebase_admin.initialize_app( + credentials.Base(), {'dbURL' : 'https://test.firebaseio.com'}) + ref = db.get_reference() + assert ref is not None + assert ref._client._auth is not None + firebase_admin.delete_app(app) + assert ref._client._auth is None + with pytest.raises(ValueError): + db.get_reference() + + +@pytest.fixture(params=['foo', '$key', '$value', '$priority']) +def initquery(request): + ref = db.Reference(path='foo') + if request.param == '$key': + return ref.order_by_key(), request.param + elif request.param == '$value': + return ref.order_by_value(), request.param + elif request.param == '$priority': + return ref.order_by_priority(), request.param + else: + return ref.order_by_child(request.param), request.param + + +class TestQuery(object): + """Test cases for db.Query class.""" + + valid_paths = { + 'foo' : 'foo', + 'foo/bar' : 'foo/bar', + 'foo/bar/' : 'foo/bar' + } + + ref = db.Reference(path='foo') + + @pytest.mark.parametrize('path', [ + '', None, '/', '/foo', 0, 1, True, False, dict(), list(), tuple(), + '$foo', '.foo', '#foo', '[foo', 'foo]', '$key', '$value', '$priority' + ]) + def test_invalid_path(self, path): + with pytest.raises(ValueError): + self.ref.order_by_child(path) + + @pytest.mark.parametrize('path, expected', valid_paths.items()) + def test_order_by_valid_path(self, path, expected): + query = self.ref.order_by_child(path) + assert query.querystr == 'orderBy="{0}"'.format(expected) + + @pytest.mark.parametrize('path, expected', valid_paths.items()) + def test_filter_by_valid_path(self, path, expected): + query = self.ref.order_by_child(path) + query.set_equal_to(10) + assert query.querystr == 'equalTo=10&orderBy="{0}"'.format(expected) + + def test_order_by_key(self): + query = self.ref.order_by_key() + assert query.querystr == 'orderBy="$key"' + + def test_key_filter(self): + query = self.ref.order_by_key() + query.set_equal_to(10) + assert query.querystr == 'equalTo=10&orderBy="$key"' + + def test_order_by_value(self): + query = self.ref.order_by_value() + assert query.querystr == 'orderBy="$value"' + + def test_value_filter(self): + query = self.ref.order_by_value() + query.set_equal_to(10) + assert query.querystr == 'equalTo=10&orderBy="$value"' + + def test_order_by_priority(self): + query = self.ref.order_by_priority() + assert query.querystr == 'orderBy="$priority"' + + def test_priority_filter(self): + query = self.ref.order_by_priority() + query.set_equal_to(10) + assert query.querystr == 'equalTo=10&orderBy="$priority"' + + def test_multiple_limits(self): + query = self.ref.order_by_child('foo') + query.set_limit_first(1) + with pytest.raises(ValueError): + query.set_limit_last(2) + + query = self.ref.order_by_child('foo') + query.set_limit_last(2) + with pytest.raises(ValueError): + query.set_limit_first(1) + + def test_start_at_none(self): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.set_start_at(None) + + def test_end_at_none(self): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.set_end_at(None) + + def test_equal_to_none(self): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.set_equal_to(None) + + def test_range_query(self, initquery): + query, order_by = initquery + query.set_start_at(1) + query.set_equal_to(2) + query.set_end_at(3) + assert query.querystr == 'endAt=3&equalTo=2&orderBy="{0}"&startAt=1'.format(order_by) + + def test_limit_first_query(self, initquery): + query, order_by = initquery + query.set_limit_first(1) + assert query.querystr == 'limitToFirst=1&orderBy="{0}"'.format(order_by) + + def test_limit_last_query(self, initquery): + query, order_by = initquery + query.set_limit_last(1) + assert query.querystr == 'limitToLast=1&orderBy="{0}"'.format(order_by) + + def test_all_in(self, initquery): + query, order_by = initquery + query.set_start_at(1) + query.set_equal_to(2) + query.set_end_at(3) + query.set_limit_first(10) + expected = 'endAt=3&equalTo=2&limitToFirst=10&orderBy="{0}"&startAt=1'.format(order_by) + assert query.querystr == expected + + +class TestSorter(object): + """Test cases for db._Sorter class.""" + + value_test_cases = [ + ({'k1' : 1, 'k2' : 2, 'k3' : 3}, ['k1', 'k2', 'k3']), + ({'k1' : 3, 'k2' : 2, 'k3' : 1}, ['k3', 'k2', 'k1']), + ({'k1' : 3, 'k2' : 1, 'k3' : 2}, ['k2', 'k3', 'k1']), + ({'k1' : 3, 'k2' : 1, 'k3' : 1}, ['k2', 'k3', 'k1']), + ({'k1' : 1, 'k2' : 2, 'k3' : 1}, ['k1', 'k3', 'k2']), + ({'k1' : 'foo', 'k2' : 'bar', 'k3' : 'baz'}, ['k2', 'k3', 'k1']), + ({'k1' : 'foo', 'k2' : 'bar', 'k3' : 10}, ['k3', 'k2', 'k1']), + ({'k1' : 'foo', 'k2' : 'bar', 'k3' : None}, ['k3', 'k2', 'k1']), + ({'k1' : 5, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), + ({'k1' : False, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), + ({'k1' : False, 'k2' : 1, 'k3' : None}, ['k3', 'k1', 'k2']), + ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo'}, ['k3', 'k1', 'k2', 'k4']), + ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, + ['k3', 'k5', 'k1', 'k2', 'k4', 'k6']), + ({'k1' : True, 'k2' : 0, 'k3' : 'foo', 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, + ['k5', 'k1', 'k2', 'k3', 'k4', 'k6']), + ] + + list_test_cases = [ + ([], []), + ([1, 2, 3], [1, 2, 3]), + ([3, 2, 1], [1, 2, 3]), + ([1, 3, 2], [1, 2, 3]), + (['foo', 'bar', 'baz'], ['bar', 'baz', 'foo']), + (['foo', 1, False, None, 0, True], [None, False, True, 0, 1, 'foo']), + ] + + @pytest.mark.parametrize('result, expected', value_test_cases) + def test_order_by_value(self, result, expected): + ordered = db._Sorter(result, '$value').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', list_test_cases) + def test_order_by_value_with_list(self, result, expected): + ordered = db._Sorter(result, '$value').get() + assert isinstance(ordered, list) + assert ordered == expected + + @pytest.mark.parametrize('value', [None, False, True, 0, 1, 'foo']) + def test_invalid_sort(self, value): + with pytest.raises(ValueError): + db._Sorter(value, '$value') + + @pytest.mark.parametrize('result, expected', [ + ({'k1' : 1, 'k2' : 2, 'k3' : 3}, ['k1', 'k2', 'k3']), + ({'k3' : 3, 'k2' : 2, 'k1' : 1}, ['k1', 'k2', 'k3']), + ({'k1' : 3, 'k3' : 1, 'k2' : 2}, ['k1', 'k2', 'k3']), + ]) + def test_order_by_key(self, result, expected): + ordered = db._Sorter(result, '$key').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', value_test_cases) + def test_order_by_child(self, result, expected): + nested = {} + for key, val in result.items(): + nested[key] = {'child' : val} + ordered = db._Sorter(nested, 'child').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', value_test_cases) + def test_order_by_grand_child(self, result, expected): + nested = {} + for key, val in result.items(): + nested[key] = {'child' : {'grandchild' : val}} + ordered = db._Sorter(nested, 'child/grandchild').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', [ + ({'k1': {'child': 1}, 'k2': {}}, ['k2', 'k1']), + ({'k1': {'child': 1}, 'k2': {'child': 0}}, ['k2', 'k1']), + ({'k1': {'child': 1}, 'k2': {'child': {}}, 'k3': {}}, ['k3', 'k1', 'k2']), + ]) + def test_child_path_resolution(self, result, expected): + ordered = db._Sorter(result, 'child').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected