From 35be3cc0c861b22298b37812b2ad8ab89cbe542f Mon Sep 17 00:00:00 2001 From: Billy Hu Date: Mon, 17 Jun 2024 16:47:37 -0700 Subject: [PATCH] Add TokenCache --- .../promptflow/azure/_utils/_token_cache.py | 37 +++++++++++++++++++ .../promptflow/azure/_utils/general.py | 4 +- .../unittests/test_utils.py | 31 ++++++++++++++++ .../promptflow/evals/evaluate/_eval_run.py | 9 ++--- 4 files changed, 74 insertions(+), 7 deletions(-) create mode 100644 src/promptflow-azure/promptflow/azure/_utils/_token_cache.py diff --git a/src/promptflow-azure/promptflow/azure/_utils/_token_cache.py b/src/promptflow-azure/promptflow/azure/_utils/_token_cache.py new file mode 100644 index 00000000000..d284281d4cc --- /dev/null +++ b/src/promptflow-azure/promptflow/azure/_utils/_token_cache.py @@ -0,0 +1,37 @@ +import time + +from promptflow.core._connection_provider._utils import get_arm_token + + +class SingletonMeta(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + + +class ArmTokenCache(metaclass=SingletonMeta): + DEFAULT_TTL_SECS = 1800 + + def __init__(self): + self._ttl_secs = self.DEFAULT_TTL_SECS + self._cache = {} + + def _is_token_valid(self, entry): + return time.time() < entry["expires_at"] + + def get_token(self, credential): + if credential in self._cache: + entry = self._cache[credential] + if self._is_token_valid(entry): + return entry["token"] + + token = self._fetch_token(credential) + self._cache[credential] = {"token": token, "expires_at": time.time() + self._ttl_secs} + return token + + def _fetch_token(self, credential): + return get_arm_token(credential=credential) diff --git a/src/promptflow-azure/promptflow/azure/_utils/general.py b/src/promptflow-azure/promptflow/azure/_utils/general.py index 8e5ab801e21..0f1a1feb639 100644 --- a/src/promptflow-azure/promptflow/azure/_utils/general.py +++ b/src/promptflow-azure/promptflow/azure/_utils/general.py @@ -6,6 +6,8 @@ from promptflow.core._connection_provider._utils import get_arm_token, get_token +from ._token_cache import ArmTokenCache + def is_arm_id(obj) -> bool: return isinstance(obj, str) and obj.startswith("azureml://") @@ -24,7 +26,7 @@ def get_aml_token(credential) -> str: def get_authorization(credential=None) -> str: - token = get_arm_token(credential=credential) + token = ArmTokenCache().get_token(credential=credential) return "Bearer " + token diff --git a/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_utils.py b/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_utils.py index 8cb2f342fee..3bae93a9105 100644 --- a/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_utils.py +++ b/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_utils.py @@ -1,10 +1,12 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +import time from unittest.mock import MagicMock, patch import pytest +from promptflow.azure._utils._token_cache import ArmTokenCache from promptflow.exceptions import UserErrorException @@ -50,3 +52,32 @@ def test_user_specified_azure_cli_credential(self): with patch.dict("os.environ", {EnvironmentVariables.PF_USE_AZURE_CLI_CREDENTIAL: "true"}): cred = get_credentials_for_cli() assert isinstance(cred, AzureCliCredential) + + @patch.object(ArmTokenCache, "_fetch_token") + def test_arm_token_cache_get_token(self, mock_fetch_token): + mock_fetch_token.return_value = "test_token" + credential = "test_credential" + + cache = ArmTokenCache() + + # Test that the token is fetched and cached + token1 = cache.get_token(credential) + assert token1 == "test_token", f"Expected 'test_token' but got {token1}" + assert credential in cache._cache, f"Expected '{credential}' to be in cache" + assert cache._cache[credential]["token"] == "test_token", "Expected token in cache to be 'test_token'" + + # Test that the cached token is returned if still valid + token2 = cache.get_token(credential) + assert token2 == "test_token", f"Expected 'test_token' but got {token2}" + assert ( + mock_fetch_token.call_count == 1 + ), f"Expected fetch token to be called once, but it was called {mock_fetch_token.call_count} times" + + # Test that a new token is fetched if the old one expires + expired_time = time.time() - 10 + cache._cache[credential]["expires_at"] = expired_time + + mock_fetch_token.return_value = "new_test_token" + token3 = cache.get_token(credential) + assert token3 == "new_test_token", f"Expected 'new_test_token' but got {token3}" + assert mock_fetch_token.call diff --git a/src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py b/src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py index 7ba60c72c50..0f015ebeb19 100644 --- a/src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py +++ b/src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py @@ -16,6 +16,7 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry +from promptflow.azure._utils._token_cache import ArmTokenCache from promptflow.evals._version import VERSION LOGGER = logging.getLogger(__name__) @@ -209,11 +210,7 @@ def get_metrics_url(self): return f"https://{self._url_base}" "/mlflow/v2.0" f"{self._get_scope()}" f"/api/2.0/mlflow/runs/log-metric" def _get_token(self): - """The simple method to get token from the MLClient.""" - # This behavior mimics how the authority is taken in azureml-mlflow. - # Note, that here we are taking authority for public cloud, however, - # it will not work for non-public clouds. - return self._ml_client._credential.get_token(EvalRun._SCOPE) + return ArmTokenCache().get_token(self._ml_client._credential) def request_with_retry( self, url: str, method: str, json_dict: Dict[str, Any], headers: Optional[Dict[str, str]] = None @@ -234,7 +231,7 @@ def request_with_retry( if headers is None: headers = {} headers["User-Agent"] = f"promptflow/{VERSION}" - headers["Authorization"] = f"Bearer {self._get_token().token}" + headers["Authorization"] = f"Bearer {self._get_token()}" retry = Retry( total=EvalRun._MAX_RETRIES, connect=EvalRun._MAX_RETRIES,