diff --git a/CHANGELOG.md b/CHANGELOG.md index 26877b525..012d8a403 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Unreleased +- [added] A new `auth.import_users()` API for importing users into Firebase + Auth in bulk. - [fixed] The `db.Reference.update()` function now accepts dictionaries with `None` values. This can be used to delete child keys from a reference. diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py new file mode 100644 index 000000000..852438725 --- /dev/null +++ b/firebase_admin/_auth_utils.py @@ -0,0 +1,183 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase auth utils.""" + +import json +import re + +import six +from six.moves import urllib + + +MAX_CLAIMS_PAYLOAD_SIZE = 1000 +RESERVED_CLAIMS = set([ + 'acr', 'amr', 'at_hash', 'aud', 'auth_time', 'azp', 'cnf', 'c_hash', 'exp', 'iat', + 'iss', 'jti', 'nbf', 'nonce', 'sub', 'firebase', +]) + + +def validate_uid(uid, required=False): + if uid is None and not required: + return None + if not isinstance(uid, six.string_types) or not uid or len(uid) > 128: + raise ValueError( + 'Invalid uid: "{0}". The uid must be a non-empty string with no more than 128 ' + 'characters.'.format(uid)) + return uid + +def validate_email(email, required=False): + if email is None and not required: + return None + if not isinstance(email, six.string_types) or not email: + raise ValueError( + 'Invalid email: "{0}". Email must be a non-empty string.'.format(email)) + parts = email.split('@') + if len(parts) != 2 or not parts[0] or not parts[1]: + raise ValueError('Malformed email address string: "{0}".'.format(email)) + return email + +def validate_phone(phone, required=False): + """Validates the specified phone number. + + Phone number vlidation is very lax here. Backend will enforce E.164 spec compliance, and + normalize accordingly. Here we check if the number starts with + sign, and contains at + least one alphanumeric character. + """ + if phone is None and not required: + return None + if not isinstance(phone, six.string_types) or not phone: + raise ValueError('Invalid phone number: "{0}". Phone number must be a non-empty ' + 'string.'.format(phone)) + if not phone.startswith('+') or not re.search('[a-zA-Z0-9]', phone): + raise ValueError('Invalid phone number: "{0}". Phone number must be a valid, E.164 ' + 'compliant identifier.'.format(phone)) + return phone + +def validate_password(password, required=False): + if password is None and not required: + return None + if not isinstance(password, six.string_types) or len(password) < 6: + raise ValueError( + 'Invalid password string. Password must be a string at least 6 characters long.') + return password + +def validate_bytes(value, label, required=False): + if value is None and not required: + return None + if not isinstance(value, six.binary_type) or not value: + raise ValueError('{0} must be a non-empty byte sequence.'.format(label)) + return value + +def validate_display_name(display_name, required=False): + if display_name is None and not required: + return None + if not isinstance(display_name, six.string_types) or not display_name: + raise ValueError( + 'Invalid display name: "{0}". Display name must be a non-empty ' + 'string.'.format(display_name)) + return display_name + +def validate_provider_id(provider_id, required=True): + if provider_id is None and not required: + return None + if not isinstance(provider_id, six.string_types) or not provider_id: + raise ValueError( + 'Invalid provider ID: "{0}". Provider ID must be a non-empty ' + 'string.'.format(provider_id)) + return provider_id + +def validate_photo_url(photo_url, required=False): + if photo_url is None and not required: + return None + if not isinstance(photo_url, six.string_types) or not photo_url: + raise ValueError( + 'Invalid photo URL: "{0}". Photo URL must be a non-empty ' + 'string.'.format(photo_url)) + try: + parsed = urllib.parse.urlparse(photo_url) + if not parsed.netloc: + raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) + return photo_url + except Exception: + raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) + +def validate_timestamp(timestamp, label, required=False): + if timestamp is None and not required: + return None + if isinstance(timestamp, bool): + raise ValueError('Boolean value specified as timestamp.') + try: + timestamp_int = int(timestamp) + except TypeError: + raise ValueError('Invalid type for timestamp value: {0}.'.format(timestamp)) + else: + if timestamp_int != timestamp: + raise ValueError('{0} must be a numeric value and a whole number.'.format(label)) + if timestamp_int <= 0: + raise ValueError('{0} timestamp must be a positive interger.'.format(label)) + return timestamp_int + +def validate_int(value, label, low=None, high=None): + """Validates that the given value represents an integer. + + There are several ways to represent an integer in Python (e.g. 2, 2L, 2.0). This method allows + for all such representations except for booleans. Booleans also behave like integers, but + always translate to 1 and 0. Passing a boolean to an API that expects integers is most likely + a developer error. + """ + if value is None or isinstance(value, bool): + raise ValueError('Invalid type for integer value: {0}.'.format(value)) + try: + val_int = int(value) + except TypeError: + raise ValueError('Invalid type for integer value: {0}.'.format(value)) + else: + if val_int != value: + # This will be True for non-numeric values like '2' and non-whole numbers like 2.5. + raise ValueError('{0} must be a numeric value and a whole number.'.format(label)) + if low is not None and val_int < low: + raise ValueError('{0} must not be smaller than {1}.'.format(label, low)) + if high is not None and val_int > high: + raise ValueError('{0} must not be larger than {1}.'.format(label, high)) + return val_int + +def validate_custom_claims(custom_claims, required=False): + """Validates the specified custom claims. + + Custom claims must be specified as a JSON string. The string must not exceed 1000 + characters, and the parsed JSON payload must not contain reserved JWT claims. + """ + if custom_claims is None and not required: + return None + claims_str = str(custom_claims) + if len(claims_str) > MAX_CLAIMS_PAYLOAD_SIZE: + raise ValueError( + 'Custom claims payload must not exceed {0} characters.'.format( + MAX_CLAIMS_PAYLOAD_SIZE)) + try: + parsed = json.loads(claims_str) + except Exception: + raise ValueError('Failed to parse custom claims string as JSON.') + + if not isinstance(parsed, dict): + raise ValueError('Custom claims must be parseable as a JSON object.') + invalid_claims = RESERVED_CLAIMS.intersection(set(parsed.keys())) + if len(invalid_claims) > 1: + joined = ', '.join(sorted(invalid_claims)) + raise ValueError('Claims "{0}" are reserved, and must not be set.'.format(joined)) + elif len(invalid_claims) == 1: + raise ValueError( + 'Claim "{0}" is reserved, and must not be set.'.format(invalid_claims.pop())) + return claims_str diff --git a/firebase_admin/_user_import.py b/firebase_admin/_user_import.py new file mode 100644 index 000000000..4e846773a --- /dev/null +++ b/firebase_admin/_user_import.py @@ -0,0 +1,403 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase user import sub module.""" + +import base64 +import json + +from firebase_admin import _auth_utils + + +def b64_encode(bytes_value): + return base64.urlsafe_b64encode(bytes_value).decode() + + +class UserProvider(object): + """Represents a user identity provider that can be associated with a Firebase user. + + One or more providers can be specified in a ``UserImportRecord`` when importing users via + ``auth.import_users()``. + + Args: + uid: User's unique ID assigned by the identity provider. + provider_id: ID of the identity provider. This can be a short domain name or the identifier + of an OpenID identity provider. + email: User's email address (optional). + display_name: User's display name (optional). + photo_url: User's photo URL (optional). + """ + + def __init__(self, uid, provider_id, email=None, display_name=None, photo_url=None): + self.uid = uid + self.provider_id = provider_id + self.email = email + self.display_name = display_name + self.photo_url = photo_url + + @property + def uid(self): + return self._uid + + @uid.setter + def uid(self, uid): + self._uid = _auth_utils.validate_uid(uid, required=True) + + @property + def provider_id(self): + return self._provider_id + + @provider_id.setter + def provider_id(self, provider_id): + self._provider_id = _auth_utils.validate_provider_id(provider_id, required=True) + + @property + def email(self): + return self._email + + @email.setter + def email(self, email): + self._email = _auth_utils.validate_email(email) + + @property + def display_name(self): + return self._display_name + + @display_name.setter + def display_name(self, display_name): + self._display_name = _auth_utils.validate_display_name(display_name) + + @property + def photo_url(self): + return self._photo_url + + @photo_url.setter + def photo_url(self, photo_url): + self._photo_url = _auth_utils.validate_photo_url(photo_url) + + def to_dict(self): + payload = { + 'rawId': self.uid, + 'providerId': self.provider_id, + 'displayName': self.display_name, + 'email': self.email, + 'photoUrl': self.photo_url, + } + return {k: v for k, v in payload.items() if v is not None} + + +class UserImportRecord(object): + """Represents a user account to be imported to Firebase Auth. + + Must specify the ``uid`` field at a minimum. A sequence of ``UserImportRecord`` objects can be + passed to the ``auth.import_users()`` function, in order to import those users into Firebase + Auth in bulk. If the ``password_hash`` is set on a user, a hash configuration must be + specified when calling ``import_users()``. + + Args: + uid: User's unique ID. Must be a non-empty string not longer than 128 characters. + email: User's email address (optional). + email_verified: A boolean indicating whether the user's email has been verified (optional). + display_name: User's display name (optional). + phone_number: User's phone number (optional). + photo_url: User's photo URL (optional). + disabled: A boolean indicating whether this user account has been disabled (optional). + user_metadata: An ``auth.UserMetadata`` instance with additional user metadata (optional). + provider_data: A list of ``auth.UserProvider`` instances (optional). + custom_claims: A ``dict`` of custom claims to be set on the user account (optional). + password_hash: User's password hash as a ``bytes`` sequence (optional). + password_salt: User's password salt as a ``bytes`` sequence (optional). + + Raises: + ValueError: If provided arguments are invalid. + """ + + def __init__(self, uid, email=None, email_verified=None, display_name=None, phone_number=None, + photo_url=None, disabled=None, user_metadata=None, provider_data=None, + custom_claims=None, password_hash=None, password_salt=None): + self.uid = uid + self.email = email + self.display_name = display_name + self.phone_number = phone_number + self.photo_url = photo_url + self.password_hash = password_hash + self.password_salt = password_salt + self.email_verified = email_verified + self.disabled = disabled + self.user_metadata = user_metadata + self.provider_data = provider_data + self.custom_claims = custom_claims + + @property + def uid(self): + return self._uid + + @uid.setter + def uid(self, uid): + self._uid = _auth_utils.validate_uid(uid, required=True) + + @property + def email(self): + return self._email + + @email.setter + def email(self, email): + self._email = _auth_utils.validate_email(email) + + @property + def display_name(self): + return self._display_name + + @display_name.setter + def display_name(self, display_name): + self._display_name = _auth_utils.validate_display_name(display_name) + + @property + def phone_number(self): + return self._phone_number + + @phone_number.setter + def phone_number(self, phone_number): + self._phone_number = _auth_utils.validate_phone(phone_number) + + @property + def photo_url(self): + return self._photo_url + + @photo_url.setter + def photo_url(self, photo_url): + self._photo_url = _auth_utils.validate_photo_url(photo_url) + + @property + def password_hash(self): + return self._password_hash + + @password_hash.setter + def password_hash(self, password_hash): + self._password_hash = _auth_utils.validate_bytes(password_hash, 'password_hash') + + @property + def password_salt(self): + return self._password_salt + + @password_salt.setter + def password_salt(self, password_salt): + self._password_salt = _auth_utils.validate_bytes(password_salt, 'password_salt') + + @property + def user_metadata(self): + return self._user_metadata + + @user_metadata.setter + def user_metadata(self, user_metadata): + created_at = user_metadata.creation_timestamp if user_metadata is not None else None + last_login_at = user_metadata.last_sign_in_timestamp if user_metadata is not None else None + self._created_at = _auth_utils.validate_timestamp(created_at, 'creation_timestamp') + self._last_login_at = _auth_utils.validate_timestamp( + last_login_at, 'last_sign_in_timestamp') + self._user_metadata = user_metadata + + @property + def provider_data(self): + return self._provider_data + + @provider_data.setter + def provider_data(self, provider_data): + if provider_data is not None: + try: + if any([not isinstance(p, UserProvider) for p in provider_data]): + raise ValueError('One or more provider data instances are invalid.') + except TypeError: + raise ValueError('provider_data must be iterable.') + self._provider_data = provider_data + + @property + def custom_claims(self): + return self._custom_claims + + @custom_claims.setter + def custom_claims(self, custom_claims): + json_claims = json.dumps(custom_claims) if isinstance( + custom_claims, dict) else custom_claims + self._custom_claims_str = _auth_utils.validate_custom_claims(json_claims) + self._custom_claims = custom_claims + + def to_dict(self): + """Returns a dict representation of the user. For internal use only.""" + payload = { + 'localId': self.uid, + 'email': self.email, + 'displayName': self.display_name, + 'phoneNumber': self.phone_number, + 'photoUrl': self.photo_url, + 'emailVerified': (bool(self.email_verified) + if self.email_verified is not None else None), + 'disabled': bool(self.disabled) if self.disabled is not None else None, + 'customAttributes': self._custom_claims_str, + 'createdAt': self._created_at, + 'lastLoginAt': self._last_login_at, + 'passwordHash': b64_encode(self.password_hash) if self.password_hash else None, + 'salt': b64_encode(self.password_salt) if self.password_salt else None, + } + if self.provider_data: + payload['providerUserInfo'] = [p.to_dict() for p in self.provider_data] + return {k: v for k, v in payload.items() if v is not None} + + +class UserImportHash(object): + """Represents a hash algorithm used to hash user passwords. + + An instance of this class must be specified when importing users with passwords via the + ``auth.import_users()`` API. + """ + + def __init__(self, name, data=None): + self._name = name + self._data = data + + def to_dict(self): + payload = {'hashAlgorithm': self._name} + if self._data: + payload.update(self._data) + return payload + + @classmethod + def _hmac(cls, name, key): + data = { + 'signerKey': b64_encode(_auth_utils.validate_bytes(key, 'key', required=True)) + } + return UserImportHash(name, data) + + @classmethod + def _basic_hash(cls, name, rounds): + data = {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 120000)} + return UserImportHash(name, data) + + @classmethod + def hmac_sha512(cls, key): + """Creates a new HMAC SHA512 algorithm instance. + + Args: + key: Signer key as a byte sequence. + + Returns: + UserImportHash: A new ``UserImportHash``. + """ + return cls._hmac('HMAC_SHA512', key) + + @classmethod + def hmac_sha256(cls, key): + return cls._hmac('HMAC_SHA256', key) + + @classmethod + def hmac_sha1(cls, key): + return cls._hmac('HMAC_SHA1', key) + + @classmethod + def hmac_md5(cls, key): + return cls._hmac('HMAC_MD5', key) + + @classmethod + def md5(cls, rounds): + return cls._basic_hash('MD5', rounds) + + @classmethod + def sha1(cls, rounds): + return cls._basic_hash('SHA1', rounds) + + @classmethod + def sha256(cls, rounds): + return cls._basic_hash('SHA256', rounds) + + @classmethod + def sha512(cls, rounds): + return cls._basic_hash('SHA512', rounds) + + @classmethod + def pbkdf_sha1(cls, rounds): + return cls._basic_hash('PBKDF_SHA1', rounds) + + @classmethod + def pbkdf_sha256(cls, rounds): + return cls._basic_hash('PBKDF2_SHA256', rounds) + + @classmethod + def scrypt(cls, key, rounds, memory_cost, salt_separator=None): + data = { + 'signerKey': b64_encode(_auth_utils.validate_bytes(key, 'key', required=True)), + 'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8), + 'memoryCost': _auth_utils.validate_int(memory_cost, 'memory_cost', 1, 14), + } + if salt_separator: + data['saltSeparator'] = b64_encode(_auth_utils.validate_bytes( + salt_separator, 'salt_separator')) + return UserImportHash('SCRYPT', data) + + @classmethod + def bcrypt(cls): + return UserImportHash('BCRYPT') + + @classmethod + def standard_scrypt(cls, memory_cost, parallelization, block_size, derived_key_length): + data = { + 'memoryCost': _auth_utils.validate_int(memory_cost, 'memory_cost', low=0), + 'parallelization': _auth_utils.validate_int(parallelization, 'parallelization', low=0), + 'blockSize': _auth_utils.validate_int(block_size, 'block_size', low=0), + 'dkLen': _auth_utils.validate_int(derived_key_length, 'derived_key_length', low=0), + } + return UserImportHash('STANDARD_SCRYPT', data) + + +class ErrorInfo(object): + """Represents an error encountered while importing a ``UserImportRecord``.""" + + def __init__(self, error): + self._index = error['index'] + self._reason = error['message'] + + @property + def index(self): + return self._index + + @property + def reason(self): + return self._reason + + +class UserImportResult(object): + """Represents the result of a bulk user import operation. + + See ``auth.import_users()`` API for more details. + """ + + def __init__(self, result, total): + errors = result.get('error', []) + self._success_count = total - len(errors) + self._failure_count = len(errors) + self._errors = [ErrorInfo(err) for err in errors] + + @property + def success_count(self): + """Returns the number of users successfully imported.""" + return self._success_count + + @property + def failure_count(self): + """Returns the number of users that failed to be imported.""" + return self._failure_count + + @property + def errors(self): + """Returns a list of ``auth.ErrorInfo`` instances describing the errors encountered.""" + return self._errors diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 120151769..42f9db365 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -15,11 +15,12 @@ """Firebase user management sub module.""" import json -import re import requests import six -from six.moves import urllib + +from firebase_admin import _auth_utils +from firebase_admin import _user_import INTERNAL_ERROR = 'INTERNAL_ERROR' @@ -27,15 +28,11 @@ USER_CREATE_ERROR = 'USER_CREATE_ERROR' USER_UPDATE_ERROR = 'USER_UPDATE_ERROR' USER_DELETE_ERROR = 'USER_DELETE_ERROR' +USER_IMPORT_ERROR = 'USER_IMPORT_ERROR' USER_DOWNLOAD_ERROR = 'LIST_USERS_ERROR' MAX_LIST_USERS_RESULTS = 1000 -MAX_CLAIMS_PAYLOAD_SIZE = 1000 -RESERVED_CLAIMS = set([ - 'acr', 'amr', 'at_hash', 'aud', 'auth_time', 'azp', 'cnf', 'c_hash', 'exp', 'iat', - 'iss', 'jti', 'nbf', 'nonce', 'sub', 'firebase', -]) - +MAX_IMPORT_USERS_SIZE = 1000 class _Unspecified(object): pass @@ -44,136 +41,336 @@ class _Unspecified(object): _UNSPECIFIED = _Unspecified() -class _Validator(object): - """A collection of data validation utilities.""" +class ApiCallError(Exception): + """Represents an Exception encountered while invoking the Firebase user management API.""" - @classmethod - def validate_uid(cls, uid, required=False): - if uid is None and not required: - return None - if not isinstance(uid, six.string_types) or not uid or len(uid) > 128: - raise ValueError( - 'Invalid uid: "{0}". The uid must be a non-empty string with no more than 128 ' - 'characters.'.format(uid)) - return uid + def __init__(self, code, message, error=None): + Exception.__init__(self, message) + self.code = code + self.detail = error - @classmethod - def validate_email(cls, email, required=False): - if email is None and not required: - return None - if not isinstance(email, six.string_types): - raise ValueError( - 'Invalid email: "{0}". Email must be a non-empty string.'.format(email)) - parts = email.split('@') - if len(parts) != 2 or not parts[0] or not parts[1]: - raise ValueError('Malformed email address string: "{0}".'.format(email)) - return email - - @classmethod - def validate_phone(cls, phone, required=False): - """Validates the specified phone number. - - Phone number vlidation is very lax here. Backend will enforce E.164 spec compliance, and - normalize accordingly. Here we check if the number starts with + sign, and contains at - least one alphanumeric character. + +class UserMetadata(object): + """Contains additional metadata associated with a user account.""" + + def __init__(self, creation_timestamp=None, last_sign_in_timestamp=None): + self._creation_timestamp = _auth_utils.validate_timestamp( + creation_timestamp, 'creation_timestamp') + self._last_sign_in_timestamp = _auth_utils.validate_timestamp( + last_sign_in_timestamp, 'last_sign_in_timestamp') + + @property + def creation_timestamp(self): + """ Creation timestamp in milliseconds since the epoch. + + Returns: + integer: The user creation timestamp in milliseconds since the epoch. """ - if phone is None and not required: - return None - if not isinstance(phone, six.string_types): - raise ValueError('Invalid phone number: "{0}". Phone number must be a non-empty ' - 'string.'.format(phone)) - if not phone.startswith('+') or not re.search('[a-zA-Z0-9]', phone): - raise ValueError('Invalid phone number: "{0}". Phone number must be a valid, E.164 ' - 'compliant identifier.'.format(phone)) - return phone - - @classmethod - def validate_password(cls, password, required=False): - if password is None and not required: - return None - if not isinstance(password, six.string_types) or len(password) < 6: - raise ValueError( - 'Invalid password string. Password must be a string at least 6 characters long.') - return password + return self._creation_timestamp - @classmethod - def validate_display_name(cls, display_name, required=False): - if display_name is None and not required: - return None - if not isinstance(display_name, six.string_types) or not display_name: - raise ValueError( - 'Invalid display name: "{0}". Display name must be a non-empty ' - 'string.'.format(display_name)) - return display_name + @property + def last_sign_in_timestamp(self): + """ Last sign in timestamp in milliseconds since the epoch. - @classmethod - def validate_photo_url(cls, photo_url, required=False): - if photo_url is None and not required: - return None - if not isinstance(photo_url, six.string_types) or not photo_url: - raise ValueError( - 'Invalid photo URL: "{0}". Photo URL must be a non-empty ' - 'string.'.format(photo_url)) - try: - parsed = urllib.parse.urlparse(photo_url) - if not parsed.netloc: - raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) - return photo_url - except Exception: - raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) - - @classmethod - def validate_timestamp(cls, timestamp, label, required=False): - if timestamp is None and not required: - return None - if isinstance(timestamp, bool): - raise ValueError('Boolean value specified as timestamp.') - try: - timestamp_int = int(timestamp) - if timestamp_int <= 0: - raise ValueError('{0} timestamp must be a positive interger.'.format(label)) - return timestamp_int - except TypeError: - raise ValueError('Invalid type for timestamp value: {0}.'.format(timestamp)) + Returns: + integer: The last sign in timestamp in milliseconds since the epoch. + """ + return self._last_sign_in_timestamp + + +class UserInfo(object): + """A collection of standard profile information for a user. + + Used to expose profile information returned by an identity provider. + """ + + @property + def uid(self): + """Returns the user ID of this user.""" + raise NotImplementedError + + @property + def display_name(self): + """Returns the display name of this user.""" + raise NotImplementedError + + @property + def email(self): + """Returns the email address associated with this user.""" + raise NotImplementedError + + @property + def phone_number(self): + """Returns the phone number associated with this user.""" + raise NotImplementedError + + @property + def photo_url(self): + """Returns the photo URL of this user.""" + raise NotImplementedError + + @property + def provider_id(self): + """Returns the ID of the identity provider. + + This can be a short domain name (e.g. google.com), or the identity of an OpenID + identity provider. + """ + raise NotImplementedError + + +class UserRecord(UserInfo): + """Contains metadata associated with a Firebase user account.""" + + def __init__(self, data): + super(UserRecord, self).__init__() + if not isinstance(data, dict): + raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) + if not data.get('localId'): + raise ValueError('User ID must not be None or empty.') + self._data = data + + @property + def uid(self): + """Returns the user ID of this user. + + Returns: + string: A user ID string. This value is never None or empty. + """ + return self._data.get('localId') + + @property + def display_name(self): + """Returns the display name of this user. + + Returns: + string: A display name string or None. + """ + return self._data.get('displayName') + + @property + def email(self): + """Returns the email address associated with this user. + + Returns: + string: An email address string or None. + """ + return self._data.get('email') + + @property + def phone_number(self): + """Returns the phone number associated with this user. + + Returns: + string: A phone number string or None. + """ + return self._data.get('phoneNumber') + + @property + def photo_url(self): + """Returns the photo URL of this user. + + Returns: + string: A URL string or None. + """ + return self._data.get('photoUrl') + + @property + def provider_id(self): + """Returns the provider ID of this user. + + Returns: + string: A constant provider ID value. + """ + return 'firebase' + + @property + def email_verified(self): + """Returns whether the email address of this user has been verified. + + Returns: + bool: True if the email has been verified, and False otherwise. + """ + return bool(self._data.get('emailVerified')) + + @property + def disabled(self): + """Returns whether this user account is disabled. - @classmethod - def validate_custom_claims(cls, custom_claims, required=False): - """Validates the specified custom claims. + Returns: + bool: True if the user account is disabled, and False otherwise. + """ + return bool(self._data.get('disabled')) + + @property + def tokens_valid_after_timestamp(self): + """Returns the time, in milliseconds since the epoch, before which tokens are invalid. + + Note: this is truncated to 1 second accuracy. + + Returns: + int: Timestamp in milliseconds since the epoch, truncated to the second. + All tokens issued before that time are considered revoked. + """ + valid_since = self._data.get('validSince') + if valid_since is not None: + return 1000 * int(valid_since) + return None + + @property + def user_metadata(self): + """Returns additional metadata associated with this user. - Custom claims must be specified as a JSON string. The string must not exceed 1000 - characters, and the parsed JSON payload must not contain reserved JWT claims. + Returns: + UserMetadata: A UserMetadata instance. Does not return None. """ - if custom_claims is None and not required: + def _int_or_none(key): + if key in self._data: + return int(self._data[key]) return None - claims_str = str(custom_claims) - if len(claims_str) > MAX_CLAIMS_PAYLOAD_SIZE: - raise ValueError( - 'Custom claims payload must not exceed {0} ' - 'characters.'.format(MAX_CLAIMS_PAYLOAD_SIZE)) - try: - parsed = json.loads(claims_str) - except Exception: - raise ValueError('Failed to parse custom claims string as JSON.') - - if not isinstance(parsed, dict): - raise ValueError('Custom claims must be parseable as a JSON object.') - invalid_claims = RESERVED_CLAIMS.intersection(set(parsed.keys())) - if len(invalid_claims) > 1: - joined = ', '.join(sorted(invalid_claims)) - raise ValueError('Claims "{0}" are reserved, and must not be set.'.format(joined)) - elif len(invalid_claims) == 1: - raise ValueError( - 'Claim "{0}" is reserved, and must not be set.'.format(invalid_claims.pop())) - return claims_str + return UserMetadata(_int_or_none('createdAt'), _int_or_none('lastLoginAt')) + @property + def provider_data(self): + """Returns a list of UserInfo instances. -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the Firebase user management API.""" + Each object represents an identity from an identity provider that is linked to this user. - def __init__(self, code, message, error=None): - Exception.__init__(self, message) - self.code = code - self.detail = error + Returns: + list: A list of UserInfo objects, which may be empty. + """ + providers = self._data.get('providerUserInfo', []) + return [ProviderUserInfo(entry) for entry in providers] + + @property + def custom_claims(self): + """Returns any custom claims set on this user account. + + Returns: + dict: A dictionary of claims or None. + """ + claims = self._data.get('customAttributes') + if claims: + parsed = json.loads(claims) + if parsed != {}: + return parsed + return None + + +class ExportedUserRecord(UserRecord): + """Contains metadata associated with a user including password hash and salt.""" + + def __init__(self, data): + super(ExportedUserRecord, self).__init__(data) + + @property + def password_hash(self): + """The user's password hash as a base64-encoded string. + + If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this + is the base64-encoded password hash of the user. If a different hashing algorithm was + used to create this user, as is typical when migrating from another Auth system, this + is an empty string. If no password is set, this is ``None``. + """ + return self._data.get('passwordHash') + + @property + def password_salt(self): + """The user's password salt as a base64-encoded string. + + If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this + is the base64-encoded password salt of the user. If a different hashing algorithm was + used to create this user, as is typical when migrating from another Auth system, this is + an empty string. If no password is set, this is ``None``. + """ + return self._data.get('salt') + + +class ListUsersPage(object): + """Represents a page of user records exported from a Firebase project. + + Provides methods for traversing the user accounts included in this page, as well as retrieving + subsequent pages of users. The iterator returned by ``iterate_all()`` can be used to iterate + through all users in the Firebase project starting from this page. + """ + + def __init__(self, download, page_token, max_results): + self._download = download + self._max_results = max_results + self._current = download(page_token, max_results) + + @property + def users(self): + """A list of ``ExportedUserRecord`` instances available in this page.""" + return [ExportedUserRecord(user) for user in self._current.get('users', [])] + + @property + def next_page_token(self): + """Page token string for the next page (empty string indicates no more pages).""" + return self._current.get('nextPageToken', '') + + @property + def has_next_page(self): + """A boolean indicating whether more pages are available.""" + return bool(self.next_page_token) + + def get_next_page(self): + """Retrieves the next page of user accounts, if available. + + Returns: + ListUsersPage: Next page of users, or None if this is the last page. + """ + if self.has_next_page: + return ListUsersPage(self._download, self.next_page_token, self._max_results) + return None + + def iterate_all(self): + """Retrieves an iterator for user accounts. + + Returned iterator will iterate through all the user accounts in the Firebase project + starting from this page. The iterator will never buffer more than one page of users + in memory at a time. + + Returns: + iterator: An iterator of ExportedUserRecord instances. + """ + return _UserIterator(self) + + +class ProviderUserInfo(UserInfo): + """Contains metadata regarding how a user is known by a particular identity provider.""" + + def __init__(self, data): + super(ProviderUserInfo, self).__init__() + if not isinstance(data, dict): + raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) + if not data.get('rawId'): + raise ValueError('User ID must not be None or empty.') + self._data = data + + @property + def uid(self): + return self._data.get('rawId') + + @property + def display_name(self): + return self._data.get('displayName') + + @property + def email(self): + return self._data.get('email') + + @property + def phone_number(self): + return self._data.get('phoneNumber') + + @property + def photo_url(self): + return self._data.get('photoUrl') + + @property + def provider_id(self): + return self._data.get('providerId') class UserManager(object): @@ -186,15 +383,15 @@ def get_user(self, **kwargs): """Gets the user data corresponding to the provided key.""" if 'uid' in kwargs: key, key_type = kwargs.pop('uid'), 'user ID' - payload = {'localId' : [_Validator.validate_uid(key, required=True)]} + payload = {'localId' : [_auth_utils.validate_uid(key, required=True)]} elif 'email' in kwargs: key, key_type = kwargs.pop('email'), 'email' - payload = {'email' : [_Validator.validate_email(key, required=True)]} + payload = {'email' : [_auth_utils.validate_email(key, required=True)]} elif 'phone_number' in kwargs: key, key_type = kwargs.pop('phone_number'), 'phone number' - payload = {'phoneNumber' : [_Validator.validate_phone(key, required=True)]} + payload = {'phoneNumber' : [_auth_utils.validate_phone(key, required=True)]} else: - raise ValueError('Unsupported keyword arguments: {0}.'.format(kwargs)) + raise TypeError('Unsupported keyword arguments: {0}.'.format(kwargs)) try: response = self._client.request('post', 'getAccountInfo', json=payload) @@ -232,12 +429,12 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None photo_url=None, password=None, disabled=None, email_verified=None): """Creates a new user account with the specified properties.""" payload = { - 'localId': _Validator.validate_uid(uid), - 'displayName': _Validator.validate_display_name(display_name), - 'email': _Validator.validate_email(email), - 'phoneNumber': _Validator.validate_phone(phone_number), - 'photoUrl': _Validator.validate_photo_url(photo_url), - 'password': _Validator.validate_password(password), + 'localId': _auth_utils.validate_uid(uid), + 'displayName': _auth_utils.validate_display_name(display_name), + 'email': _auth_utils.validate_email(email), + 'phoneNumber': _auth_utils.validate_phone(phone_number), + 'photoUrl': _auth_utils.validate_photo_url(photo_url), + 'password': _auth_utils.validate_password(password), 'emailVerified': bool(email_verified) if email_verified is not None else None, 'disabled': bool(disabled) if disabled is not None else None, } @@ -256,10 +453,10 @@ def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_ valid_since=None, custom_claims=_UNSPECIFIED): """Updates an existing user account with the specified properties""" payload = { - 'localId': _Validator.validate_uid(uid, required=True), - 'email': _Validator.validate_email(email), - 'password': _Validator.validate_password(password), - 'validSince': _Validator.validate_timestamp(valid_since, 'valid_since'), + 'localId': _auth_utils.validate_uid(uid, required=True), + 'email': _auth_utils.validate_email(email), + 'password': _auth_utils.validate_password(password), + 'validSince': _auth_utils.validate_timestamp(valid_since, 'valid_since'), 'emailVerified': bool(email_verified) if email_verified is not None else None, 'disableUser': bool(disabled) if disabled is not None else None, } @@ -269,12 +466,12 @@ def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_ if display_name is None: remove.append('DISPLAY_NAME') else: - payload['displayName'] = _Validator.validate_display_name(display_name) + payload['displayName'] = _auth_utils.validate_display_name(display_name) if photo_url is not _UNSPECIFIED: if photo_url is None: remove.append('PHOTO_URL') else: - payload['photoUrl'] = _Validator.validate_photo_url(photo_url) + payload['photoUrl'] = _auth_utils.validate_photo_url(photo_url) if remove: payload['deleteAttribute'] = remove @@ -282,14 +479,14 @@ def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_ if phone_number is None: payload['deleteProvider'] = ['phone'] else: - payload['phoneNumber'] = _Validator.validate_phone(phone_number) + payload['phoneNumber'] = _auth_utils.validate_phone(phone_number) if custom_claims is not _UNSPECIFIED: if custom_claims is None: custom_claims = {} json_claims = json.dumps(custom_claims) if isinstance( custom_claims, dict) else custom_claims - payload['customAttributes'] = _Validator.validate_custom_claims(json_claims) + payload['customAttributes'] = _auth_utils.validate_custom_claims(json_claims) payload = {k: v for k, v in payload.items() if v is not None} try: @@ -304,7 +501,7 @@ def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_ def delete_user(self, uid): """Deletes the user identified by the specified user ID.""" - _Validator.validate_uid(uid, required=True) + _auth_utils.validate_uid(uid, required=True) try: response = self._client.request('post', 'deleteAccount', json={'localId' : uid}) except requests.exceptions.RequestException as error: @@ -314,6 +511,32 @@ def delete_user(self, uid): if not response or not response.get('kind'): raise ApiCallError(USER_DELETE_ERROR, 'Failed to delete user: {0}.'.format(uid)) + def import_users(self, users, hash_alg=None): + """Imports the given list of users to Firebase Auth.""" + try: + if not users or len(users) > MAX_IMPORT_USERS_SIZE: + raise ValueError( + 'Users must be a non-empty list with no more than {0} elements.'.format( + MAX_IMPORT_USERS_SIZE)) + if any([not isinstance(u, _user_import.UserImportRecord) for u in users]): + raise ValueError('One or more user objects are invalid.') + except TypeError: + raise ValueError('users must be iterable') + + payload = {'users': [u.to_dict() for u in users]} + if any(['passwordHash' in u for u in payload['users']]): + if not isinstance(hash_alg, _user_import.UserImportHash): + raise ValueError('A UserImportHash is required to import users with passwords.') + payload.update(hash_alg.to_dict()) + try: + response = self._client.request('post', 'uploadAccount', json=payload) + except requests.exceptions.RequestException as error: + self._handle_http_error(USER_IMPORT_ERROR, 'Failed to import users.', error) + else: + if not isinstance(response, dict): + raise ApiCallError(USER_IMPORT_ERROR, 'Failed to import users.') + return response + def _handle_http_error(self, code, msg, error): if error.response is not None: msg += '\nServer response: {0}'.format(error.response.content.decode()) @@ -322,7 +545,7 @@ def _handle_http_error(self, code, msg, error): raise ApiCallError(code, msg, error) -class UserIterator(object): +class _UserIterator(object): """An iterator that allows iterating over user accounts, one at a time. This implementation loads a page of users into memory, and iterates on them. When the whole diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 1659385ae..b02ce769e 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -19,13 +19,13 @@ creating and managing user accounts in Firebase projects. """ -import json import time from google.auth import transport import firebase_admin from firebase_admin import _token_gen +from firebase_admin import _user_import from firebase_admin import _user_mgt from firebase_admin import _utils @@ -35,6 +35,47 @@ _SESSION_COOKIE_REVOKED = 'SESSION_COOKIE_REVOKED' +__all__ = [ + 'AuthError', + 'ErrorInfo', + 'ExportedUserRecord', + 'ListUsersPage', + 'UserImportHash', + 'UserImportRecord', + 'UserImportResult', + 'UserInfo', + 'UserMetadata', + 'UserProvider', + 'UserRecord', + + 'create_custom_token', + 'create_session_cookie', + 'create_user', + 'delete_user', + 'get_user', + 'get_user_by_email', + 'get_user_by_phone_number', + 'import_users', + 'list_users', + 'revoke_refresh_tokens', + 'set_custom_user_claims', + 'update_user', + 'verify_id_token', + 'verify_session_cookie', +] + +ErrorInfo = _user_import.ErrorInfo +ExportedUserRecord = _user_mgt.ExportedUserRecord +ListUsersPage = _user_mgt.ListUsersPage +UserImportHash = _user_import.UserImportHash +UserImportRecord = _user_import.UserImportRecord +UserImportResult = _user_import.UserImportResult +UserInfo = _user_mgt.UserInfo +UserMetadata = _user_mgt.UserMetadata +UserProvider = _user_import.UserProvider +UserRecord = _user_mgt.UserRecord + + def _get_auth_service(app): """Returns an _AuthService instance for an App. @@ -376,333 +417,20 @@ def delete_user(uid, app=None): except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) +def import_users(users, hash_alg=None, app=None): + user_manager = _get_auth_service(app).user_manager + try: + result = user_manager.import_users(users, hash_alg) + return UserImportResult(result, len(users)) + except _user_mgt.ApiCallError as error: + raise AuthError(error.code, str(error), error.detail) + def _check_jwt_revoked(verified_claims, error_code, label, app): user = get_user(verified_claims.get('uid'), app=app) if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: raise AuthError(error_code, 'The Firebase {0} has been revoked.'.format(label)) -class UserInfo(object): - """A collection of standard profile information for a user. - - Used to expose profile information returned by an identity provider. - """ - - @property - def uid(self): - """Returns the user ID of this user.""" - raise NotImplementedError - - @property - def display_name(self): - """Returns the display name of this user.""" - raise NotImplementedError - - @property - def email(self): - """Returns the email address associated with this user.""" - raise NotImplementedError - - @property - def phone_number(self): - """Returns the phone number associated with this user.""" - raise NotImplementedError - - @property - def photo_url(self): - """Returns the photo URL of this user.""" - raise NotImplementedError - - @property - def provider_id(self): - """Returns the ID of the identity provider. - - This can be a short domain name (e.g. google.com), or the identity of an OpenID - identity provider. - """ - raise NotImplementedError - - -class UserRecord(UserInfo): - """Contains metadata associated with a Firebase user account.""" - - def __init__(self, data): - super(UserRecord, self).__init__() - if not isinstance(data, dict): - raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) - if not data.get('localId'): - raise ValueError('User ID must not be None or empty.') - self._data = data - - @property - def uid(self): - """Returns the user ID of this user. - - Returns: - string: A user ID string. This value is never None or empty. - """ - return self._data.get('localId') - - @property - def display_name(self): - """Returns the display name of this user. - - Returns: - string: A display name string or None. - """ - return self._data.get('displayName') - - @property - def email(self): - """Returns the email address associated with this user. - - Returns: - string: An email address string or None. - """ - return self._data.get('email') - - @property - def phone_number(self): - """Returns the phone number associated with this user. - - Returns: - string: A phone number string or None. - """ - return self._data.get('phoneNumber') - - @property - def photo_url(self): - """Returns the photo URL of this user. - - Returns: - string: A URL string or None. - """ - return self._data.get('photoUrl') - - @property - def provider_id(self): - """Returns the provider ID of this user. - - Returns: - string: A constant provider ID value. - """ - return 'firebase' - - @property - def email_verified(self): - """Returns whether the email address of this user has been verified. - - Returns: - bool: True if the email has been verified, and False otherwise. - """ - return bool(self._data.get('emailVerified')) - - @property - def disabled(self): - """Returns whether this user account is disabled. - - Returns: - bool: True if the user account is disabled, and False otherwise. - """ - return bool(self._data.get('disabled')) - - @property - def tokens_valid_after_timestamp(self): - """Returns the time, in milliseconds since the epoch, before which tokens are invalid. - - Note: this is truncated to 1 second accuracy. - - Returns: - int: Timestamp in milliseconds since the epoch, truncated to the second. - All tokens issued before that time are considered revoked. - """ - valid_since = self._data.get('validSince') - if valid_since is not None: - return 1000 * int(valid_since) - return None - - @property - def user_metadata(self): - """Returns additional metadata associated with this user. - - Returns: - UserMetadata: A UserMetadata instance. Does not return None. - """ - return UserMetadata(self._data) - - @property - def provider_data(self): - """Returns a list of UserInfo instances. - - Each object represents an identity from an identity provider that is linked to this user. - - Returns: - list: A list of UserInfo objects, which may be empty. - """ - providers = self._data.get('providerUserInfo', []) - return [_ProviderUserInfo(entry) for entry in providers] - - @property - def custom_claims(self): - """Returns any custom claims set on this user account. - - Returns: - dict: A dictionary of claims or None. - """ - claims = self._data.get('customAttributes') - if claims: - parsed = json.loads(claims) - if parsed != {}: - return parsed - return None - - -class UserMetadata(object): - """Contains additional metadata associated with a user account.""" - - def __init__(self, data): - if not isinstance(data, dict): - raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) - self._data = data - - @property - def creation_timestamp(self): - """ Creation timestamp in milliseconds since the epoch. - - Returns: - integer: The user creation timestamp in milliseconds since the epoch. - """ - if 'createdAt' in self._data: - return int(self._data['createdAt']) - return None - - @property - def last_sign_in_timestamp(self): - """ Last sign in timestamp in milliseconds since the epoch. - - Returns: - integer: The last sign in timestamp in milliseconds since the epoch. - """ - if 'lastLoginAt' in self._data: - return int(self._data['lastLoginAt']) - return None - -class ExportedUserRecord(UserRecord): - """Contains metadata associated with a user including password hash and salt.""" - - def __init__(self, data): - super(ExportedUserRecord, self).__init__(data) - - @property - def password_hash(self): - """The user's password hash as a base64-encoded string. - - If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this - is the base64-encoded password hash of the user. If a different hashing algorithm was - used to create this user, as is typical when migrating from another Auth system, this - is an empty string. If no password is set, this is ``None``. - """ - return self._data.get('passwordHash') - - @property - def password_salt(self): - """The user's password salt as a base64-encoded string. - - If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this - is the base64-encoded password salt of the user. If a different hashing algorithm was - used to create this user, as is typical when migrating from another Auth system, this is - an empty string. If no password is set, this is ``None``. - """ - return self._data.get('salt') - - -class ListUsersPage(object): - """Represents a page of user records exported from a Firebase project. - - Provides methods for traversing the user accounts included in this page, as well as retrieving - subsequent pages of users. The iterator returned by ``iterate_all()`` can be used to iterate - through all users in the Firebase project starting from this page. - """ - - def __init__(self, download, page_token, max_results): - self._download = download - self._max_results = max_results - self._current = download(page_token, max_results) - - @property - def users(self): - """A list of ``ExportedUserRecord`` instances available in this page.""" - return [ExportedUserRecord(user) for user in self._current.get('users', [])] - - @property - def next_page_token(self): - """Page token string for the next page (empty string indicates no more pages).""" - return self._current.get('nextPageToken', '') - - @property - def has_next_page(self): - """A boolean indicating whether more pages are available.""" - return bool(self.next_page_token) - - def get_next_page(self): - """Retrieves the next page of user accounts, if available. - - Returns: - ListUsersPage: Next page of users, or None if this is the last page. - """ - if self.has_next_page: - return ListUsersPage(self._download, self.next_page_token, self._max_results) - return None - - def iterate_all(self): - """Retrieves an iterator for user accounts. - - Returned iterator will iterate through all the user accounts in the Firebase project - starting from this page. The iterator will never buffer more than one page of users - in memory at a time. - - Returns: - iterator: An iterator of ExportedUserRecord instances. - """ - return _user_mgt.UserIterator(self) - - -class _ProviderUserInfo(UserInfo): - """Contains metadata regarding how a user is known by a particular identity provider.""" - - def __init__(self, data): - super(_ProviderUserInfo, self).__init__() - if not isinstance(data, dict): - raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) - if not data.get('rawId'): - raise ValueError('User ID must not be None or empty.') - self._data = data - - @property - def uid(self): - return self._data.get('rawId') - - @property - def display_name(self): - return self._data.get('displayName') - - @property - def email(self): - return self._data.get('email') - - @property - def phone_number(self): - return self._data.get('phoneNumber') - - @property - def photo_url(self): - return self._data.get('photoUrl') - - @property - def provider_id(self): - return self._data.get('providerId') - - class AuthError(Exception): """Represents an Exception encountered while invoking the Firebase auth API.""" diff --git a/integration/test_auth.py b/integration/test_auth.py index d5b6eeb90..1faf785b8 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -13,6 +13,7 @@ # limitations under the License. """Integration tests for firebase_admin.auth module.""" +import base64 import datetime import random import time @@ -24,13 +25,21 @@ from firebase_admin import auth -_id_toolkit_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' +_verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' +_verify_password_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyPassword' def _sign_in(custom_token, api_key): body = {'token' : custom_token.decode(), 'returnSecureToken' : True} params = {'key' : api_key} - resp = requests.request('post', _id_toolkit_url, params=params, json=body) + resp = requests.request('post', _verify_token_url, params=params, json=body) + resp.raise_for_status() + return resp.json().get('idToken') + +def _sign_in_with_password(email, password, api_key): + body = {'email': email, 'password': password} + params = {'key' : api_key} + resp = requests.request('post', _verify_password_url, params=params, json=body) resp.raise_for_status() return resp.json().get('idToken') @@ -309,3 +318,38 @@ def test_verify_session_cookie_revoked(new_user, api_key): session_cookie = auth.create_session_cookie(id_token, expires_in=datetime.timedelta(days=1)) claims = auth.verify_session_cookie(session_cookie, check_revoked=True) assert claims['iat'] * 1000 >= user.tokens_valid_after_timestamp + +def test_import_users(): + uid, email = _random_id() + user = auth.UserImportRecord(uid=uid, email=email) + result = auth.import_users([user]) + try: + assert result.success_count == 1 + assert result.failure_count == 0 + saved_user = auth.get_user(uid) + assert saved_user.email == email + finally: + auth.delete_user(uid) + +def test_import_users_with_password(api_key): + uid, email = _random_id() + password_hash = base64.b64decode( + 'V358E8LdWJXAO7muq0CufVpEOXaj8aFiC7T/rcaGieN04q/ZPJ08WhJEHGjj9lz/2TT+/86N5VjVoc5DdBhBiw==') + user = auth.UserImportRecord( + uid=uid, email=email, password_hash=password_hash, password_salt=b'NaCl') + + scrypt_key = base64.b64decode( + 'jxspr8Ki0RYycVU8zykbdLGjFQ3McFUH0uiiTvC8pVMXAn210wjLNmdZJzxUECKbm0QsEmYUSDzZvpjeJ9WmXA==') + salt_separator = base64.b64decode('Bw==') + scrypt = auth.UserImportHash.scrypt( + key=scrypt_key, salt_separator=salt_separator, rounds=8, memory_cost=14) + result = auth.import_users([user], hash_alg=scrypt) + try: + assert result.success_count == 1 + assert result.failure_count == 0 + saved_user = auth.get_user(uid) + assert saved_user.email == email + id_token = _sign_in_with_password(email, 'password', api_key) + assert len(id_token) > 0 + finally: + auth.delete_user(uid) diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 6b0c3f640..cc5c1114d 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -21,12 +21,16 @@ import firebase_admin from firebase_admin import auth +from firebase_admin import _auth_utils +from firebase_admin import _user_import from firebase_admin import _user_mgt from tests import testutils INVALID_STRINGS = [None, '', 0, 1, True, False, list(), tuple(), dict()] INVALID_DICTS = [None, 'foo', 0, 1, True, False, list(), tuple()] +INVALID_INTS = [None, 'foo', '1', -1, 1.1, True, False, list(), tuple(), dict()] +INVALID_TIMESTAMPS = ['foo', '1', 0, -1, 1.1, True, False, list(), tuple(), dict()] MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') MOCK_LIST_USERS_RESPONSE = testutils.resource('list_users.json') @@ -91,16 +95,11 @@ def test_invalid_record(self, data): with pytest.raises(ValueError): auth.UserRecord(data) - @pytest.mark.parametrize('data', INVALID_DICTS) - def test_invalid_metadata(self, data): - with pytest.raises(ValueError): - auth.UserMetadata(data) - def test_metadata(self): - metadata = auth.UserMetadata({'createdAt' : 10, 'lastLoginAt' : 20}) + metadata = auth.UserMetadata(10, 20) assert metadata.creation_timestamp == 10 assert metadata.last_sign_in_timestamp == 20 - metadata = auth.UserMetadata({}) + metadata = auth.UserMetadata() assert metadata.creation_timestamp is None assert metadata.last_sign_in_timestamp is None @@ -150,15 +149,11 @@ def test_empty_custom_claims(self): @pytest.mark.parametrize('data', INVALID_DICTS + [{}, {'foo':'bar'}]) def test_invalid_provider(self, data): with pytest.raises(ValueError): - auth._ProviderUserInfo(data) + _user_mgt.ProviderUserInfo(data) class TestGetUser(object): - VALID_UID = 'testuser' - VALID_EMAIL = 'testuser@example.com' - VALID_PHONE = '+1234567890' - @pytest.mark.parametrize('arg', INVALID_STRINGS + ['a'*129]) def test_invalid_get_user(self, arg, user_mgt_app): with pytest.raises(ValueError): @@ -166,25 +161,25 @@ def test_invalid_get_user(self, arg, user_mgt_app): def test_get_user(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) - _check_user_record(auth.get_user(self.VALID_UID, user_mgt_app)) + _check_user_record(auth.get_user('testuser', user_mgt_app)) @pytest.mark.parametrize('arg', INVALID_STRINGS + ['not-an-email']) - def test_invalid_get_user_by_email(self, arg): + def test_invalid_get_user_by_email(self, arg, user_mgt_app): with pytest.raises(ValueError): - auth.get_user_by_email(arg) + auth.get_user_by_email(arg, app=user_mgt_app) def test_get_user_by_email(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) - _check_user_record(auth.get_user_by_email(self.VALID_EMAIL, user_mgt_app)) + _check_user_record(auth.get_user_by_email('testuser@example.com', user_mgt_app)) @pytest.mark.parametrize('arg', INVALID_STRINGS + ['not-a-phone']) - def test_invalid_get_user_by_phone(self, arg): + def test_invalid_get_user_by_phone(self, arg, user_mgt_app): with pytest.raises(ValueError): - auth.get_user_by_phone_number(arg) + auth.get_user_by_phone_number(arg, app=user_mgt_app) def test_get_user_by_phone(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) - _check_user_record(auth.get_user_by_phone_number(self.VALID_PHONE, user_mgt_app)) + _check_user_record(auth.get_user_by_phone_number('+1234567890', user_mgt_app)) def test_get_user_non_existing(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, '{"users":[]}') @@ -209,7 +204,7 @@ def test_get_user_by_email_http_error(self, user_mgt_app): def test_get_user_by_phone_http_error(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') with pytest.raises(auth.AuthError) as excinfo: - auth.get_user_by_phone_number(self.VALID_PHONE, user_mgt_app) + auth.get_user_by_phone_number('+1234567890', user_mgt_app) assert excinfo.value.code == _user_mgt.INTERNAL_ERROR assert '{"error":"test"}' in str(excinfo.value) @@ -326,7 +321,7 @@ def test_invalid_property(self, user_mgt_app): with pytest.raises(TypeError): auth.update_user('user', unsupported='arg', app=user_mgt_app) - @pytest.mark.parametrize('arg', ['foo', 0, -1, True, False, list(), tuple(), dict()]) + @pytest.mark.parametrize('arg', INVALID_TIMESTAMPS) def test_invalid_valid_since(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.update_user('user', valid_since=arg, app=user_mgt_app) @@ -375,11 +370,12 @@ def test_update_user_error(self, user_mgt_app): assert excinfo.value.code == _user_mgt.USER_UPDATE_ERROR assert '{"error":"test"}' in str(excinfo.value) - def test_update_user_valid_since(self, user_mgt_app): + @pytest.mark.parametrize('arg', [1, 1.0]) + def test_update_user_valid_since(self, user_mgt_app, arg): user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') - user_mgt.update_user('testuser', valid_since=1) + user_mgt.update_user('testuser', valid_since=arg) request = json.loads(recorder[0].body.decode()) - assert request == {'localId': 'testuser', 'validSince': 1} + assert request == {'localId': 'testuser', 'validSince': int(arg)} class TestSetCustomUserClaims(object): @@ -394,7 +390,7 @@ def test_invalid_custom_claims(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.set_custom_user_claims('user', arg, app=user_mgt_app) - @pytest.mark.parametrize('key', _user_mgt.RESERVED_CLAIMS) + @pytest.mark.parametrize('key', _auth_utils.RESERVED_CLAIMS) def test_single_reserved_claim(self, user_mgt_app, key): claims = {key : 'value'} with pytest.raises(ValueError) as excinfo: @@ -402,7 +398,7 @@ def test_single_reserved_claim(self, user_mgt_app, key): assert str(excinfo.value) == 'Claim "{0}" is reserved, and must not be set.'.format(key) def test_multiple_reserved_claims(self, user_mgt_app): - claims = {key : 'value' for key in _user_mgt.RESERVED_CLAIMS} + claims = {key : 'value' for key in _auth_utils.RESERVED_CLAIMS} with pytest.raises(ValueError) as excinfo: auth.set_custom_user_claims('user', claims, app=user_mgt_app) joined = ', '.join(sorted(claims.keys())) @@ -448,7 +444,7 @@ class TestDeleteUser(object): @pytest.mark.parametrize('arg', INVALID_STRINGS + ['a'*129]) def test_invalid_delete_user(self, user_mgt_app, arg): with pytest.raises(ValueError): - auth.get_user(arg, app=user_mgt_app) + auth.delete_user(arg, app=user_mgt_app) def test_delete_user(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, '{"kind":"deleteresponse"}') @@ -620,6 +616,329 @@ def _check_rpc_calls(self, recorder, expected=None): assert request == expected +class TestUserProvider(object): + + _INVALID_PROVIDERS = ( + [{'display_name': arg} for arg in INVALID_STRINGS[1:]] + + [{'email': arg} for arg in INVALID_STRINGS[1:] + ['not-an-email']] + + [{'photo_url': arg} for arg in INVALID_STRINGS[1:] + ['not-a-url']] + ) + + def test_uid_and_provider_id(self): + provider = auth.UserProvider(uid='test', provider_id='google.com') + expected = {'rawId': 'test', 'providerId': 'google.com'} + assert provider.to_dict() == expected + + def test_all_params(self): + provider = auth.UserProvider( + uid='test', provider_id='google.com', email='test@example.com', + display_name='Test Name', photo_url='https://test.com/user.png') + expected = { + 'rawId': 'test', + 'providerId': 'google.com', + 'email': 'test@example.com', + 'displayName': 'Test Name', + 'photoUrl': 'https://test.com/user.png' + } + assert provider.to_dict() == expected + + @pytest.mark.parametrize('arg', INVALID_STRINGS + ['a'*129]) + def test_invalid_uid(self, arg): + with pytest.raises(ValueError): + auth.UserProvider(uid=arg, provider_id='google.com') + + @pytest.mark.parametrize('arg', INVALID_STRINGS) + def test_invalid_provider_id(self, arg): + with pytest.raises(ValueError): + auth.UserProvider(uid='test', provider_id=arg) + + @pytest.mark.parametrize('arg', _INVALID_PROVIDERS) + def test_invalid_arg(self, arg): + with pytest.raises(ValueError): + auth.UserProvider(uid='test', provider_id='google.com', **arg) + + +class TestUserMetadata(object): + + _INVALID_ARGS = ( + [{'creation_timestamp': arg} for arg in INVALID_TIMESTAMPS] + + [{'last_sign_in_timestamp': arg} for arg in INVALID_TIMESTAMPS] + ) + + @pytest.mark.parametrize('arg', _INVALID_ARGS) + def test_invalid_args(self, arg): + with pytest.raises(ValueError): + auth.UserMetadata(**arg) + +class TestUserImportRecord(object): + + _INVALID_USERS = ( + [{'display_name': arg} for arg in INVALID_STRINGS[1:]] + + [{'email': arg} for arg in INVALID_STRINGS[1:] + ['not-an-email']] + + [{'photo_url': arg} for arg in INVALID_STRINGS[1:] + ['not-a-url']] + + [{'phone_number': arg} for arg in INVALID_STRINGS[1:] + ['not-a-phone']] + + [{'password_hash': arg} for arg in INVALID_STRINGS[1:] + [u'test']] + + [{'password_salt': arg} for arg in INVALID_STRINGS[1:] + [u'test']] + + [{'custom_claims': arg} for arg in INVALID_DICTS[1:] + ['"json"', {'key': 'a'*1000}]] + + [{'provider_data': arg} for arg in ['foo', 1, True]] + ) + + def test_uid(self): + user = auth.UserImportRecord(uid='test') + assert user.uid == 'test' + assert user.custom_claims is None + assert user.user_metadata is None + assert user.to_dict() == {'localId': 'test'} + + def test_all_params(self): + providers = [auth.UserProvider(uid='test', provider_id='google.com')] + metadata = auth.UserMetadata(100, 150) + user = auth.UserImportRecord( + uid='test', email='test@example.com', photo_url='https://test.com/user.png', + phone_number='+1234567890', display_name='name', user_metadata=metadata, + password_hash=b'password', password_salt=b'NaCl', custom_claims={'admin': True}, + email_verified=True, disabled=False, provider_data=providers) + expected = { + 'localId': 'test', + 'email': 'test@example.com', + 'photoUrl': 'https://test.com/user.png', + 'phoneNumber': '+1234567890', + 'displayName': 'name', + 'createdAt': 100, + 'lastLoginAt': 150, + 'passwordHash': _user_import.b64_encode(b'password'), + 'salt': _user_import.b64_encode(b'NaCl'), + 'customAttributes': json.dumps({'admin': True}), + 'emailVerified': True, + 'disabled': False, + 'providerUserInfo': [{'rawId': 'test', 'providerId': 'google.com'}], + } + assert user.to_dict() == expected + + @pytest.mark.parametrize('arg', INVALID_STRINGS + ['a'*129]) + def test_invalid_uid(self, arg): + with pytest.raises(ValueError): + auth.UserImportRecord(uid=arg) + + @pytest.mark.parametrize('args', _INVALID_USERS) + def test_invalid_args(self, args): + with pytest.raises(ValueError): + auth.UserImportRecord(uid='test', **args) + + @pytest.mark.parametrize('claims', [{}, {'admin': True}, '{"admin": true}']) + def test_custom_claims(self, claims): + user = auth.UserImportRecord(uid='test', custom_claims=claims) + assert user.custom_claims == claims + json_claims = json.dumps(claims) if isinstance(claims, dict) else claims + expected = {'localId': 'test', 'customAttributes': json_claims} + assert user.to_dict() == expected + + @pytest.mark.parametrize('email_verified', [True, False]) + def test_email_verified(self, email_verified): + user = auth.UserImportRecord(uid='test', email_verified=email_verified) + assert user.email_verified == email_verified + assert user.to_dict() == {'localId': 'test', 'emailVerified': email_verified} + + @pytest.mark.parametrize('disabled', [True, False]) + def test_disabled(self, disabled): + user = auth.UserImportRecord(uid='test', disabled=disabled) + assert user.disabled == disabled + assert user.to_dict() == {'localId': 'test', 'disabled': disabled} + + +class TestUserImportHash(object): + + @pytest.mark.parametrize('func,name', [ + (auth.UserImportHash.hmac_sha512, 'HMAC_SHA512'), + (auth.UserImportHash.hmac_sha256, 'HMAC_SHA256'), + (auth.UserImportHash.hmac_sha1, 'HMAC_SHA1'), + (auth.UserImportHash.hmac_md5, 'HMAC_MD5'), + ]) + def test_hmac(self, func, name): + hmac = func(key=b'key') + expected = { + 'hashAlgorithm': name, + 'signerKey': _user_import.b64_encode(b'key'), + } + assert hmac.to_dict() == expected + + @pytest.mark.parametrize('func', [ + auth.UserImportHash.hmac_sha512, auth.UserImportHash.hmac_sha256, + auth.UserImportHash.hmac_sha1, auth.UserImportHash.hmac_md5, + ]) + @pytest.mark.parametrize('key', INVALID_STRINGS) + def test_invalid_hmac(self, func, key): + with pytest.raises(ValueError): + func(key=key) + + @pytest.mark.parametrize('func,name', [ + (auth.UserImportHash.sha512, 'SHA512'), + (auth.UserImportHash.sha256, 'SHA256'), + (auth.UserImportHash.sha1, 'SHA1'), + (auth.UserImportHash.md5, 'MD5'), + ]) + def test_basic(self, func, name): + basic = func(rounds=10) + expected = { + 'hashAlgorithm': name, + 'rounds': 10, + } + assert basic.to_dict() == expected + + @pytest.mark.parametrize('func', [ + auth.UserImportHash.sha512, auth.UserImportHash.sha256, + auth.UserImportHash.sha1, auth.UserImportHash.md5, + ]) + @pytest.mark.parametrize('rounds', INVALID_INTS + [120001]) + def test_invalid_basic(self, func, rounds): + with pytest.raises(ValueError): + func(rounds=rounds) + + def test_scrypt(self): + scrypt = auth.UserImportHash.scrypt( + key=b'key', salt_separator=b'sep', rounds=8, memory_cost=14) + expected = { + 'hashAlgorithm': 'SCRYPT', + 'signerKey': _user_import.b64_encode(b'key'), + 'rounds': 8, + 'memoryCost': 14, + 'saltSeparator': _user_import.b64_encode(b'sep'), + } + assert scrypt.to_dict() == expected + + @pytest.mark.parametrize('arg', ( + [{'key': arg} for arg in INVALID_STRINGS] + + [{'rounds': arg} for arg in INVALID_INTS + [0, 9]] + + [{'memory_cost': arg} for arg in INVALID_INTS + [0, 15]] + + [{'salt_separator': arg} for arg in INVALID_STRINGS] + )) + def test_invalid_scrypt(self, arg): + params = {'key': 'key', 'rounds': 0, 'memory_cost': 14} + params.update(arg) + with pytest.raises(ValueError): + auth.UserImportHash.scrypt(**params) + + def test_bcrypt(self): + bcrypt = auth.UserImportHash.bcrypt() + assert bcrypt.to_dict() == {'hashAlgorithm': 'BCRYPT'} + + def test_standard_scrypt(self): + scrypt = auth.UserImportHash.standard_scrypt( + memory_cost=14, parallelization=2, block_size=10, derived_key_length=128) + expected = { + 'hashAlgorithm': 'STANDARD_SCRYPT', + 'memoryCost': 14, + 'parallelization': 2, + 'blockSize': 10, + 'dkLen': 128, + } + assert scrypt.to_dict() == expected + + @pytest.mark.parametrize('arg', ( + [{'memory_cost': arg} for arg in INVALID_INTS] + + [{'parallelization': arg} for arg in INVALID_INTS] + + [{'block_size': arg} for arg in INVALID_INTS] + + [{'derived_key_length': arg} for arg in INVALID_INTS] + )) + def test_invalid_standard_scrypt(self, arg): + params = { + 'memory_cost': 14, + 'parallelization': 2, + 'block_size': 10, + 'derived_key_length': 128, + } + params.update(arg) + with pytest.raises(ValueError): + auth.UserImportHash.standard_scrypt(**params) + + +class TestImportUsers(object): + + @pytest.mark.parametrize('arg', [None, list(), tuple(), dict(), 0, 1, 'foo']) + def test_invalid_users(self, user_mgt_app, arg): + with pytest.raises(Exception): + auth.import_users(arg, app=user_mgt_app) + + def test_too_many_users(self, user_mgt_app): + users = [auth.UserImportRecord(uid='test{0}'.format(i)) for i in range(1001)] + with pytest.raises(ValueError): + auth.import_users(users, app=user_mgt_app) + + def test_import_users(self, user_mgt_app): + _, recorder = _instrument_user_manager(user_mgt_app, 200, '{}') + users = [ + auth.UserImportRecord(uid='user1'), + auth.UserImportRecord(uid='user2'), + ] + result = auth.import_users(users, app=user_mgt_app) + assert result.success_count == 2 + assert result.failure_count is 0 + assert result.errors == [] + expected = {'users': [{'localId': 'user1'}, {'localId': 'user2'}]} + self._check_rpc_calls(recorder, expected) + + def test_import_users_error(self, user_mgt_app): + _, recorder = _instrument_user_manager(user_mgt_app, 200, """{"error": [ + {"index": 0, "message": "Some error occured in user1"}, + {"index": 2, "message": "Another error occured in user3"} + ]}""") + users = [ + auth.UserImportRecord(uid='user1'), + auth.UserImportRecord(uid='user2'), + auth.UserImportRecord(uid='user3'), + ] + result = auth.import_users(users, app=user_mgt_app) + assert result.success_count == 1 + assert result.failure_count == 2 + assert len(result.errors) == 2 + err = result.errors[0] + assert err.index == 0 + assert err.reason == 'Some error occured in user1' + err = result.errors[1] + assert err.index == 2 + assert err.reason == 'Another error occured in user3' + expected = {'users': [{'localId': 'user1'}, {'localId': 'user2'}, {'localId': 'user3'}]} + self._check_rpc_calls(recorder, expected) + + def test_import_users_missing_required_hash(self, user_mgt_app): + users = [ + auth.UserImportRecord(uid='user1', password_hash=b'password'), + auth.UserImportRecord(uid='user2'), + ] + with pytest.raises(ValueError): + auth.import_users(users, app=user_mgt_app) + + def test_import_users_with_hash(self, user_mgt_app): + _, recorder = _instrument_user_manager(user_mgt_app, 200, '{}') + users = [ + auth.UserImportRecord(uid='user1', password_hash=b'password'), + auth.UserImportRecord(uid='user2'), + ] + hash_alg = auth.UserImportHash.scrypt( + b'key', rounds=8, memory_cost=14, salt_separator=b'sep') + result = auth.import_users(users, hash_alg=hash_alg, app=user_mgt_app) + assert result.success_count == 2 + assert result.failure_count is 0 + assert result.errors == [] + expected = { + 'users': [ + {'localId': 'user1', 'passwordHash': _user_import.b64_encode(b'password')}, + {'localId': 'user2'} + ], + 'hashAlgorithm': 'SCRYPT', + 'signerKey': _user_import.b64_encode(b'key'), + 'rounds': 8, + 'memoryCost': 14, + 'saltSeparator': _user_import.b64_encode(b'sep'), + } + self._check_rpc_calls(recorder, expected) + + def _check_rpc_calls(self, recorder, expected): + assert len(recorder) == 1 + request = json.loads(recorder[0].body.decode()) + assert request == expected + + class TestRevokeRefreshTokkens(object): def test_revoke_refresh_tokens(self, user_mgt_app):