diff --git a/azure-quantum/azure/quantum/_authentication/__init__.py b/azure-quantum/azure/quantum/_authentication/__init__.py deleted file mode 100644 index 84646caa2..000000000 --- a/azure-quantum/azure/quantum/_authentication/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# coding=utf-8 -## -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -## - -from ._chained import * -from ._default import _DefaultAzureCredential -from ._token import _TokenFileCredential diff --git a/azure-quantum/azure/quantum/_authentication/_chained.py b/azure-quantum/azure/quantum/_authentication/_chained.py deleted file mode 100644 index 5ac7239d5..000000000 --- a/azure-quantum/azure/quantum/_authentication/_chained.py +++ /dev/null @@ -1,119 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -import logging - -import sys -from azure.core.exceptions import ClientAuthenticationError -from azure.identity import CredentialUnavailableError -from azure.core.credentials import AccessToken, TokenCredential - - -_LOGGER = logging.getLogger(__name__) - - - -def filter_credential_warnings(record): - """Suppress warnings from credentials other than DefaultAzureCredential""" - if record.levelno == logging.WARNING: - message = record.getMessage() - return "DefaultAzureCredential" in message - return True - - -def _get_error_message(history): - attempts = [] - for credential, error in history: - if error: - attempts.append(f"{credential.__class__.__name__}: {error}") - else: - attempts.append(credential.__class__.__name__) - return """ -Attempted credentials:\n\t{}""".format( - "\n\t".join(attempts) - ) - - -class _ChainedTokenCredential(object): - """ - Based on Azure.Identity.ChainedTokenCredential from: - https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/identity/azure-identity/azure/identity/_credentials/chained.py - - The key difference is that we don't stop attempting all credentials - if some of then failed or raised an exception. - We also don't log a warning unless all credential attempts have failed. - """ - - def __init__(self, *credentials: TokenCredential): - self._successful_credential = None - self.credentials = credentials - - def get_token(self, *scopes: str, **kwargs) -> AccessToken: # pylint:disable=unused-argument - """ - Request a token from each chained credential, in order, - returning the first token received. - This method is called automatically by Azure SDK clients. - - :param str scopes: desired scopes for the access token. - This method requires at least one scope. - - :raises ~azure.core.exceptions.ClientAuthenticationError: - no credential in the chain provided a token - """ - history = [] - - # Suppress warnings from credentials in Azure.Identity - azure_identity_logger = logging.getLogger("azure.identity") - handler = logging.StreamHandler(stream=sys.stdout) - handler.addFilter(filter_credential_warnings) - azure_identity_logger.addHandler(handler) - try: - for credential in self.credentials: - try: - token = credential.get_token(*scopes, **kwargs) - _LOGGER.info( - "%s acquired a token from %s", - self.__class__.__name__, - credential.__class__.__name__, - ) - self._successful_credential = credential - return token - except CredentialUnavailableError as ex: - # credential didn't attempt authentication because - # it lacks required data or state -> continue - history.append((credential, ex.message)) - _LOGGER.info( - "%s - %s is unavailable", - self.__class__.__name__, - credential.__class__.__name__, - ) - except Exception as ex: # pylint: disable=broad-except - # credential failed to authenticate, - # or something unexpectedly raised -> break - history.append((credential, str(ex))) - # instead of logging a warning, we just want to log an info - # since other credentials might succeed - _LOGGER.info( - '%s.get_token failed: %s raised unexpected error "%s"', - self.__class__.__name__, - credential.__class__.__name__, - ex, - exc_info=_LOGGER.isEnabledFor(logging.DEBUG), - ) - # here we do NOT want break and - # will continue to try other credentials - - finally: - # Re-enable warnings from credentials in Azure.Identity - azure_identity_logger.removeHandler(handler) - - # if all attempts failed, only then we log a warning and raise an error - attempts = _get_error_message(history) - message = ( - self.__class__.__name__ - + " failed to retrieve a token from the included credentials." - + attempts - ) - _LOGGER.warning(message) - raise ClientAuthenticationError(message=message) diff --git a/azure-quantum/azure/quantum/_authentication/_default.py b/azure-quantum/azure/quantum/_authentication/_default.py deleted file mode 100644 index 014754be6..000000000 --- a/azure-quantum/azure/quantum/_authentication/_default.py +++ /dev/null @@ -1,153 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -import logging -import re -from typing import Optional -import urllib3 -from azure.core.credentials import AccessToken -from azure.identity import ( - AzurePowerShellCredential, - EnvironmentCredential, - ManagedIdentityCredential, - AzureCliCredential, - VisualStudioCodeCredential, - InteractiveBrowserCredential, - DeviceCodeCredential, - _internal as AzureIdentityInternals, -) -from ._chained import _ChainedTokenCredential -from ._token import _TokenFileCredential -from azure.quantum._constants import ConnectionConstants - -_LOGGER = logging.getLogger(__name__) -WWW_AUTHENTICATE_REGEX = re.compile( - r""" - ^ - Bearer\sauthorization_uri=" - https://(?P[^/]*)/ - (?P[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}) - " - """, - re.VERBOSE | re.IGNORECASE) -WWW_AUTHENTICATE_HEADER_NAME = "WWW-Authenticate" - - -class _DefaultAzureCredential(_ChainedTokenCredential): - """ - Based on Azure.Identity.DefaultAzureCredential from: - https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/identity/azure-identity/azure/identity/_credentials/default.py - - The three key differences are: - 1) Inherit from _ChainedTokenCredential, which has - more aggressive error handling than ChainedTokenCredential - 2) Instantiate the internal credentials the first time the get_token gets called - such that we can get the tenant_id if it was not passed by the user (but we don't - want to do that in the constructor). - We automatically identify the user's tenant_id for a given subscription - so that users with MSA accounts don't need to pass it. - This is a mitigation for bug https://github.com/Azure/azure-sdk-for-python/issues/18975 - We need the following parameters to enable auto-detection of tenant_id - - subscription_id - - arm_endpoint (defaults to the production url "https://management.azure.com/") - 3) Add custom TokenFileCredential as first method to attempt, - which will look for a local access token. - """ - def __init__( - self, - arm_endpoint: str, - subscription_id: str, - client_id: Optional[str] = None, - tenant_id: Optional[str] = None, - authority: Optional[str] = None, - ): - if arm_endpoint is None: - raise ValueError("arm_endpoint is mandatory parameter") - if subscription_id is None: - raise ValueError("subscription_id is mandatory parameter") - - self.authority = self._authority_or_default( - authority=authority, - arm_endpoint=arm_endpoint) - self.tenant_id = tenant_id - self.subscription_id = subscription_id - self.arm_endpoint = arm_endpoint - self.client_id = client_id - # credentials will be created lazy on the first call to get_token - super(_DefaultAzureCredential, self).__init__() - - def _authority_or_default(self, authority: str, arm_endpoint: str): - if authority: - return AzureIdentityInternals.normalize_authority(authority) - if arm_endpoint == ConnectionConstants.ARM_DOGFOOD_ENDPOINT: - return ConnectionConstants.DOGFOOD_AUTHORITY - return ConnectionConstants.AUTHORITY - - def _initialize_credentials(self): - self._discover_tenant_id_( - arm_endpoint=self.arm_endpoint, - subscription_id=self.subscription_id) - credentials = [] - credentials.append(_TokenFileCredential()) - credentials.append(EnvironmentCredential()) - if self.client_id: - credentials.append(ManagedIdentityCredential(client_id=self.client_id)) - if self.authority and self.tenant_id: - credentials.append(VisualStudioCodeCredential(authority=self.authority, tenant_id=self.tenant_id)) - credentials.append(AzureCliCredential(tenant_id=self.tenant_id)) - credentials.append(AzurePowerShellCredential(tenant_id=self.tenant_id)) - credentials.append(InteractiveBrowserCredential(authority=self.authority, tenant_id=self.tenant_id)) - if self.client_id: - credentials.append(DeviceCodeCredential(authority=self.authority, client_id=self.client_id, tenant_id=self.tenant_id)) - self.credentials = credentials - - def get_token(self, *scopes: str, **kwargs) -> AccessToken: - """ - Request an access token for `scopes`. - This method is called automatically by Azure SDK clients. - - :param str scopes: desired scopes for the access token. - This method requires at least one scope. - - :raises ~azure.core.exceptions.ClientAuthenticationError:authentication failed. - The exception has a `message` attribute listing each authentication - attempt and its error message. - """ - # lazy-initialize the credentials - if self.credentials is None or len(self.credentials) == 0: - self._initialize_credentials() - - return super(_DefaultAzureCredential, self).get_token(*scopes, **kwargs) - - def _discover_tenant_id_(self, arm_endpoint:str, subscription_id:str): - """ - If the tenant_id was not given, try to obtain it - by calling the management endpoint for the subscription_id, - or by applying default values. - """ - if self.tenant_id: - return - - try: - url = ( - f"{arm_endpoint.rstrip('/')}/subscriptions/" - + f"{subscription_id}?api-version=2018-01-01" - + "&discover-tenant-id" # used by the test recording infrastructure - ) - http = urllib3.PoolManager() - response = http.request( - method="GET", - url=url, - ) - if WWW_AUTHENTICATE_HEADER_NAME in response.headers: - www_authenticate = response.headers[WWW_AUTHENTICATE_HEADER_NAME] - match = re.search(WWW_AUTHENTICATE_REGEX, www_authenticate) - if match: - self.tenant_id = match.group("tenant_id") - # pylint: disable=broad-exception-caught - except Exception as ex: - _LOGGER.error(ex) - - # apply default values - self.tenant_id = self.tenant_id or ConnectionConstants.MSA_TENANT_ID diff --git a/azure-quantum/azure/quantum/_authentication/_token.py b/azure-quantum/azure/quantum/_authentication/_token.py deleted file mode 100644 index 15ec4b135..000000000 --- a/azure-quantum/azure/quantum/_authentication/_token.py +++ /dev/null @@ -1,83 +0,0 @@ -## -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -## -import json -from json.decoder import JSONDecodeError -import logging -import os -import time - -from azure.identity import CredentialUnavailableError -from azure.core.credentials import AccessToken -from azure.quantum._constants import EnvironmentVariables - -_LOGGER = logging.getLogger(__name__) - - -class _TokenFileCredential(object): - """ - Implements a custom TokenCredential to use a local file as - the source for an AzureQuantum token. - - It will only use the local file if the AZURE_QUANTUM_TOKEN_FILE - environment variable is set, and references an existing json file - that contains the access_token and expires_on timestamp in milliseconds. - - If the environment variable is not set, the file does not exist, - or the token is invalid in any way (expired, for example), - then the credential will throw CredentialUnavailableError, - so that _ChainedTokenCredential can fallback to other methods. - """ - def __init__(self): - self.token_file = os.environ.get(EnvironmentVariables.QUANTUM_TOKEN_FILE) - if self.token_file: - _LOGGER.debug("Using provided token file location: %s", self.token_file) - else: - _LOGGER.debug("No token file location provided for %s environment variable.", - EnvironmentVariables.QUANTUM_TOKEN_FILE) - - def get_token(self, *scopes: str, **kwargs) -> AccessToken: # pylint:disable=unused-argument - """Request an access token for `scopes`. - This method is called automatically by Azure SDK clients. - This method only returns tokens for the https://quantum.microsoft.com/.default scope. - - :param str scopes: desired scopes for the access token. - - :raises ~azure.identity.CredentialUnavailableError - when failing to get the token. - The exception has a `message` attribute with the error message. - """ - if not self.token_file: - raise CredentialUnavailableError(message="Token file location not set.") - - if not os.path.isfile(self.token_file): - raise CredentialUnavailableError( - message=f"Token file at {self.token_file} does not exist.") - - try: - token = self._parse_token_file(self.token_file) - except JSONDecodeError as exception: - raise CredentialUnavailableError( - message="Failed to parse token file: Invalid JSON.") from exception - except KeyError as exception: - raise CredentialUnavailableError( - message="Failed to parse token file: Missing expected value: " - + str(exception)) from exception - except Exception as exception: - raise CredentialUnavailableError( - message="Failed to parse token file: " + str(exception)) from exception - - if token.expires_on <= time.time(): - raise CredentialUnavailableError( - message=f"Token already expired at {time.asctime(time.gmtime(token.expires_on))}") - - return token - - def _parse_token_file(self, path) -> AccessToken: - with open(path, mode="r", encoding="utf-8") as file: - data = json.load(file) - # Convert ms to seconds, since python time.time only handles epoch time in seconds - expires_on = int(data["expires_on"]) / 1000 - token = AccessToken(data["access_token"], expires_on) - return token diff --git a/azure-quantum/azure/quantum/_workspace_connection_params.py b/azure-quantum/azure/quantum/_workspace_connection_params.py index 291743192..797cf626b 100644 --- a/azure-quantum/azure/quantum/_workspace_connection_params.py +++ b/azure-quantum/azure/quantum/_workspace_connection_params.py @@ -14,7 +14,7 @@ ) from azure.core.credentials import AzureKeyCredential from azure.core.pipeline.policies import AzureKeyCredentialPolicy -from azure.quantum._authentication import _DefaultAzureCredential +from azure.identity import DefaultAzureCredential from azure.quantum._constants import ( EnvironmentKind, EnvironmentVariables, @@ -403,13 +403,10 @@ def _merge_connection_params( def get_credential_or_default(self) -> Any: """ Get the credential if one was set, - or defaults to a new _DefaultAzureCredential. + or defaults to a new DefaultAzureCredential. """ return (self.credential - or _DefaultAzureCredential( - subscription_id=self.subscription_id, - arm_endpoint=self.arm_endpoint, - tenant_id=self.tenant_id)) + or DefaultAzureCredential()) def get_auth_policy(self) -> Any: """ diff --git a/azure-quantum/tests/unit/test_authentication.py b/azure-quantum/tests/unit/test_authentication.py index e5b52dd91..5affdd9bf 100644 --- a/azure-quantum/tests/unit/test_authentication.py +++ b/azure-quantum/tests/unit/test_authentication.py @@ -18,15 +18,11 @@ API_KEY, ) from azure.identity import ( - CredentialUnavailableError, ClientSecretCredential, + DefaultAzureCredential, InteractiveBrowserCredential, ) from azure.quantum import Workspace -from azure.quantum._authentication import ( - _TokenFileCredential, - _DefaultAzureCredential, -) from azure.quantum._constants import ( EnvironmentVariables, ConnectionConstants, @@ -34,134 +30,6 @@ class TestWorkspace(QuantumTestBase): - def test_azure_quantum_token_credential_file_not_set(self): - credential = _TokenFileCredential() - with pytest.raises(CredentialUnavailableError) as exception: - credential.get_token(ConnectionConstants.DATA_PLANE_CREDENTIAL_SCOPE) - self.assertIn("Token file location not set.", str(exception.value)) - - def test_azure_quantum_token_credential_file_not_exists(self): - with patch.dict(os.environ, - {EnvironmentVariables.QUANTUM_TOKEN_FILE: "fake_file_path"}, - clear=True): - with patch('os.path.isfile') as mock_isfile: - mock_isfile.return_value = False - credential = _TokenFileCredential() - with pytest.raises(CredentialUnavailableError) as exception: - credential.get_token(ConnectionConstants.DATA_PLANE_CREDENTIAL_SCOPE) - self.assertIn("Token file at fake_file_path does not exist.", - str(exception.value)) - - def test_azure_quantum_token_credential_file_invalid_json(self): - tmpdir = self.create_temp_dir() - file = Path(tmpdir) / "token.json" - file.write_text("not a json") - with patch.dict(os.environ, - {EnvironmentVariables.QUANTUM_TOKEN_FILE: str(file.resolve())}, - clear=True): - credential = _TokenFileCredential() - with pytest.raises(CredentialUnavailableError) as exception: - credential.get_token(ConnectionConstants.DATA_PLANE_CREDENTIAL_SCOPE) - self.assertIn("Failed to parse token file: Invalid JSON.", - str(exception.value)) - - def test_azure_quantum_token_credential_file_missing_expires_on(self): - content = { - "access_token": "fake_token", - } - tmpdir = self.create_temp_dir() - file = Path(tmpdir) / "token.json" - file.write_text(json.dumps(content)) - with patch.dict(os.environ, - {EnvironmentVariables.QUANTUM_TOKEN_FILE: str(file.resolve())}, - clear=True): - credential = _TokenFileCredential() - with pytest.raises(CredentialUnavailableError) as exception: - credential.get_token(ConnectionConstants.DATA_PLANE_CREDENTIAL_SCOPE) - self.assertIn("Failed to parse token file: " + - "Missing expected value: 'expires_on'""", - str(exception.value)) - - def test_azure_quantum_token_credential_file_token_expired(self): - content = { - "access_token": "fake_token", - # Matches timestamp in error message below - "expires_on": 1628543125086 - } - tmpdir = self.create_temp_dir() - file = Path(tmpdir) / "token.json" - file.write_text(json.dumps(content)) - with patch.dict(os.environ, - {EnvironmentVariables.QUANTUM_TOKEN_FILE: str(file.resolve())}, - clear=True): - credential = _TokenFileCredential() - with pytest.raises(CredentialUnavailableError) as exception: - credential.get_token(ConnectionConstants.DATA_PLANE_CREDENTIAL_SCOPE) - self.assertIn("Token already expired at Mon Aug 9 21:05:25 2021", - str(exception.value)) - - def test_azure_quantum_token_credential_file_valid_token(self): - one_hour_ahead = time.time() + 60*60 - content = { - "access_token": "fake_token", - "expires_on": one_hour_ahead * 1000 # Convert to milliseconds - } - - tmpdir = self.create_temp_dir() - file = Path(tmpdir) / "token.json" - file.write_text(json.dumps(content)) - with patch.dict(os.environ, - {EnvironmentVariables.QUANTUM_TOKEN_FILE: str(file.resolve())}, - clear=True): - credential = _TokenFileCredential() - token = credential.get_token(ConnectionConstants.DATA_PLANE_CREDENTIAL_SCOPE) - self.assertEqual(token.token, "fake_token") - self.assertEqual(token.expires_on, pytest.approx(one_hour_ahead)) - - @pytest.mark.live_test - def test_workspace_auth_token_credential(self): - with patch.dict(os.environ): - self.clear_env_vars(os.environ) - connection_params = self.connection_params - - os.environ[EnvironmentVariables.AZURE_CLIENT_ID] = \ - connection_params.client_id - os.environ[EnvironmentVariables.AZURE_TENANT_ID] = \ - connection_params.tenant_id - - if self.in_recording and os.path.exists(self._client_certificate_path): - os.environ[EnvironmentVariables.AZURE_CLIENT_CERTIFICATE_PATH] = \ - self._client_certificate_path - os.environ[EnvironmentVariables.AZURE_CLIENT_SEND_CERTIFICATE_CHAIN] = \ - self._client_send_certificate_chain - else: - os.environ[EnvironmentVariables.AZURE_CLIENT_SECRET] = \ - self._client_secret - - credential = _DefaultAzureCredential( - subscription_id=connection_params.subscription_id, - arm_endpoint=connection_params.arm_endpoint, - tenant_id=connection_params.tenant_id) - - token = credential.get_token(ConnectionConstants.DATA_PLANE_CREDENTIAL_SCOPE) - content = { - "access_token": token.token, - "expires_on": token.expires_on * 1000 - } - tmpdir = self.create_temp_dir() - file = Path(tmpdir) / "token.json" - try: - file.write_text(json.dumps(content)) - with patch.dict(os.environ, - {EnvironmentVariables.QUANTUM_TOKEN_FILE: str(file.resolve())}, - clear=True): - credential = _TokenFileCredential() - workspace = self.create_workspace(credential=credential) - targets = workspace.get_targets() - self.assertGreater(len(targets), 1) - finally: - os.remove(file) - @pytest.mark.live_test def test_workspace_auth_client_secret_credential(self): client_secret = os.environ.get(EnvironmentVariables.AZURE_CLIENT_SECRET) @@ -199,10 +67,7 @@ def test_workspace_auth_default_credential(self): os.environ[EnvironmentVariables.AZURE_CLIENT_SECRET] = \ self._client_secret - credential = _DefaultAzureCredential( - subscription_id=connection_params.subscription_id, - arm_endpoint=connection_params.arm_endpoint, - tenant_id=connection_params.tenant_id) + credential = DefaultAzureCredential() workspace = self.create_workspace(credential=credential) targets = workspace.get_targets() @@ -223,10 +88,7 @@ def test_workspace_auth_interactive_credential(self): def _get_rp_credential(self): connection_params = self.connection_params # We have to use DefaultAzureCredential to avoid using ApiKeyCredential - credential = _DefaultAzureCredential( - subscription_id=connection_params.subscription_id, - arm_endpoint=connection_params.arm_endpoint, - tenant_id=connection_params.tenant_id) + credential = DefaultAzureCredential() scope = ConnectionConstants.ARM_CREDENTIAL_SCOPE token = credential.get_token(scope).token return token