Skip to content

Commit

Permalink
Add TokenCache
Browse files Browse the repository at this point in the history
  • Loading branch information
ninghu committed Jun 17, 2024
1 parent 32115d3 commit 35be3cc
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 7 deletions.
37 changes: 37 additions & 0 deletions src/promptflow-azure/promptflow/azure/_utils/_token_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import time

from promptflow.core._connection_provider._utils import get_arm_token


class SingletonMeta(type):
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
return cls._instances[cls]


class ArmTokenCache(metaclass=SingletonMeta):
DEFAULT_TTL_SECS = 1800

def __init__(self):
self._ttl_secs = self.DEFAULT_TTL_SECS
self._cache = {}

def _is_token_valid(self, entry):
return time.time() < entry["expires_at"]

def get_token(self, credential):
if credential in self._cache:
entry = self._cache[credential]
if self._is_token_valid(entry):
return entry["token"]

token = self._fetch_token(credential)
self._cache[credential] = {"token": token, "expires_at": time.time() + self._ttl_secs}
return token

def _fetch_token(self, credential):
return get_arm_token(credential=credential)
4 changes: 3 additions & 1 deletion src/promptflow-azure/promptflow/azure/_utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from promptflow.core._connection_provider._utils import get_arm_token, get_token

from ._token_cache import ArmTokenCache


def is_arm_id(obj) -> bool:
return isinstance(obj, str) and obj.startswith("azureml://")
Expand All @@ -24,7 +26,7 @@ def get_aml_token(credential) -> str:


def get_authorization(credential=None) -> str:
token = get_arm_token(credential=credential)
token = ArmTokenCache().get_token(credential=credential)
return "Bearer " + token


Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import time
from unittest.mock import MagicMock, patch

import pytest

from promptflow.azure._utils._token_cache import ArmTokenCache
from promptflow.exceptions import UserErrorException


Expand Down Expand Up @@ -50,3 +52,32 @@ def test_user_specified_azure_cli_credential(self):
with patch.dict("os.environ", {EnvironmentVariables.PF_USE_AZURE_CLI_CREDENTIAL: "true"}):
cred = get_credentials_for_cli()
assert isinstance(cred, AzureCliCredential)

@patch.object(ArmTokenCache, "_fetch_token")
def test_arm_token_cache_get_token(self, mock_fetch_token):
mock_fetch_token.return_value = "test_token"
credential = "test_credential"

cache = ArmTokenCache()

# Test that the token is fetched and cached
token1 = cache.get_token(credential)
assert token1 == "test_token", f"Expected 'test_token' but got {token1}"
assert credential in cache._cache, f"Expected '{credential}' to be in cache"
assert cache._cache[credential]["token"] == "test_token", "Expected token in cache to be 'test_token'"

# Test that the cached token is returned if still valid
token2 = cache.get_token(credential)
assert token2 == "test_token", f"Expected 'test_token' but got {token2}"
assert (
mock_fetch_token.call_count == 1
), f"Expected fetch token to be called once, but it was called {mock_fetch_token.call_count} times"

# Test that a new token is fetched if the old one expires
expired_time = time.time() - 10
cache._cache[credential]["expires_at"] = expired_time

mock_fetch_token.return_value = "new_test_token"
token3 = cache.get_token(credential)
assert token3 == "new_test_token", f"Expected 'new_test_token' but got {token3}"
assert mock_fetch_token.call
9 changes: 3 additions & 6 deletions src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from promptflow.azure._utils._token_cache import ArmTokenCache
from promptflow.evals._version import VERSION

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -209,11 +210,7 @@ def get_metrics_url(self):
return f"https://{self._url_base}" "/mlflow/v2.0" f"{self._get_scope()}" f"/api/2.0/mlflow/runs/log-metric"

def _get_token(self):
"""The simple method to get token from the MLClient."""
# This behavior mimics how the authority is taken in azureml-mlflow.
# Note, that here we are taking authority for public cloud, however,
# it will not work for non-public clouds.
return self._ml_client._credential.get_token(EvalRun._SCOPE)
return ArmTokenCache().get_token(self._ml_client._credential)

def request_with_retry(
self, url: str, method: str, json_dict: Dict[str, Any], headers: Optional[Dict[str, str]] = None
Expand All @@ -234,7 +231,7 @@ def request_with_retry(
if headers is None:
headers = {}
headers["User-Agent"] = f"promptflow/{VERSION}"
headers["Authorization"] = f"Bearer {self._get_token().token}"
headers["Authorization"] = f"Bearer {self._get_token()}"
retry = Retry(
total=EvalRun._MAX_RETRIES,
connect=EvalRun._MAX_RETRIES,
Expand Down

0 comments on commit 35be3cc

Please sign in to comment.