Skip to content

Commit

Permalink
address the comment
Browse files Browse the repository at this point in the history
  • Loading branch information
ninghu committed Jun 19, 2024
1 parent d008a41 commit 4aa135f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
15 changes: 11 additions & 4 deletions src/promptflow-azure/promptflow/azure/_utils/_token_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import time

import jwt

from promptflow.core._connection_provider._utils import get_arm_token


Expand All @@ -14,14 +19,14 @@ def __call__(cls, *args, **kwargs):


class ArmTokenCache(metaclass=SingletonMeta):
DEFAULT_TTL_SECS = 1800
TOKEN_REFRESH_THRESHOLD_SECS = 300

def __init__(self):
self._ttl_secs = self.DEFAULT_TTL_SECS
self._cache = {}

def _is_token_valid(self, entry):
return time.time() < entry["expires_at"]
current_time = time.time()
return (entry["expires_at"] - current_time) >= self.TOKEN_REFRESH_THRESHOLD_SECS

def get_token(self, credential):
if credential in self._cache:
Expand All @@ -30,7 +35,9 @@ def get_token(self, credential):
return entry["token"]

token = self._fetch_token(credential)
self._cache[credential] = {"token": token, "expires_at": time.time() + self._ttl_secs}
decoded_token = jwt.decode(token, options={"verify_signature": False, "verify_aud": False})
expiration_time = decoded_token.get("exp", time.time())
self._cache[credential] = {"token": token, "expires_at": expiration_time}
return token

def _fetch_token(self, credential):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from unittest.mock import MagicMock, patch

import jwt
import pytest

from promptflow.azure._utils._token_cache import ArmTokenCache
Expand Down Expand Up @@ -55,29 +56,35 @@ def test_user_specified_azure_cli_credential(self):

@patch.object(ArmTokenCache, "_fetch_token")
def test_arm_token_cache_get_token(self, mock_fetch_token):
mock_fetch_token.return_value = "test_token"
expiration_time = time.time() + 3600 # 1 hour in the future
mock_token = jwt.encode({"exp": expiration_time}, "secret", algorithm="HS256")
mock_fetch_token.return_value = mock_token
credential = "test_credential"

cache = ArmTokenCache()

# Test that the token is fetched and cached
token1 = cache.get_token(credential)
assert token1 == "test_token", f"Expected 'test_token' but got {token1}"
assert token1 == mock_token, f"Expected '{mock_token}' but got {token1}"
assert credential in cache._cache, f"Expected '{credential}' to be in cache"
assert cache._cache[credential]["token"] == "test_token", "Expected token in cache to be 'test_token'"
assert cache._cache[credential]["token"] == mock_token, "Expected token in cache to be the mock token"

# Test that the cached token is returned if still valid
token2 = cache.get_token(credential)
assert token2 == "test_token", f"Expected 'test_token' but got {token2}"
assert token2 == mock_token, f"Expected '{mock_token}' but got {token2}"
assert (
mock_fetch_token.call_count == 1
), f"Expected fetch token to be called once, but it was called {mock_fetch_token.call_count} times"

# Test that a new token is fetched if the old one expires
expired_time = time.time() - 10
expired_time = time.time() - 10 # Set the token as expired
cache._cache[credential]["expires_at"] = expired_time

mock_fetch_token.return_value = "new_test_token"
new_expiration_time = time.time() + 3600
new_mock_token = jwt.encode({"exp": new_expiration_time}, "secret", algorithm="HS256")
mock_fetch_token.return_value = new_mock_token
token3 = cache.get_token(credential)
assert token3 == "new_test_token", f"Expected 'new_test_token' but got {token3}"
assert mock_fetch_token.call
assert token3 == new_mock_token, f"Expected '{new_mock_token}' but got {token3}"
assert (
mock_fetch_token.call_count == 2
), f"Expected fetch token to be called twice, but it was called {mock_fetch_token.call_count} times"

0 comments on commit 4aa135f

Please sign in to comment.