diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index c33c81e53..cb2c5d84b 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -81,7 +81,7 @@ jobs: /bin/sh -c "cd /package && COVERAGE_FILE=/package/${{ env.COVERAGE_DIR }}/.coverage.integration coverage run -m pytest -m v4 tests/integration -v" - name: Run asyncio integration tests - id: integration_tests + id: asyncio_integration_tests continue-on-error: true run: | docker run --rm \ @@ -90,7 +90,7 @@ jobs: -e CONDUCTOR_SERVER_URL=${{ env.CONDUCTOR_SERVER_URL }} \ -v ${{ github.workspace }}/${{ env.COVERAGE_DIR }}:/package/${{ env.COVERAGE_DIR }}:rw \ conductor-sdk-test:latest \ - /bin/sh -c "cd /package && COVERAGE_FILE=/package/${{ env.COVERAGE_DIR }}/.coverage.integration coverage run -m pytest -m v4 tests/integration -v" + /bin/sh -c "cd /package && COVERAGE_FILE=/package/${{ env.COVERAGE_DIR }}/.coverage.asyncio_integration coverage run -m pytest -m v4 tests/integration/async -v" - name: Generate coverage report id: coverage_report @@ -124,4 +124,4 @@ jobs: - name: Check test results if: steps.unit_tests.outcome == 'failure' || steps.bc_tests.outcome == 'failure' || steps.serdeser_tests.outcome == 'failure' - run: exit 1 \ No newline at end of file + run: exit 1 diff --git a/src/conductor/asyncio_client/adapters/api_client_adapter.py b/src/conductor/asyncio_client/adapters/api_client_adapter.py index 2c82a5ee4..2eb913449 100644 --- a/src/conductor/asyncio_client/adapters/api_client_adapter.py +++ b/src/conductor/asyncio_client/adapters/api_client_adapter.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +import asyncio import json import logging import re +import time from typing import Dict, Optional from conductor.asyncio_client.adapters.models import GenerateTokenRequest @@ -15,6 +19,10 @@ class ApiClientAdapter(ApiClient): + def __init__(self, *args, **kwargs): + self._token_lock = asyncio.Lock() + super().__init__(*args, **kwargs) + async def call_api( self, method, @@ -37,7 +45,9 @@ async def call_api( """ try: - logger.debug("HTTP request method: %s; url: %s; header_params: %s", method, url, header_params) + logger.debug( + "HTTP request method: %s; url: %s; header_params: %s", method, url, header_params + ) response_data = await self.rest_client.request( method, url, @@ -46,9 +56,29 @@ async def call_api( post_params=post_params, _request_timeout=_request_timeout, ) - if response_data.status == 401 and url != self.configuration.host + "/token": # noqa: PLR2004 (Unauthorized status code) - logger.warning("HTTP response from: %s; status code: 401 - obtaining new token", url) - token = await self.refresh_authorization_token() + if ( + response_data.status == 401 # noqa: PLR2004 (Unauthorized status code) + and url != self.configuration.host + "/token" + ): + logger.warning( + "HTTP response from: %s; status code: 401 - obtaining new token", url + ) + async with self._token_lock: + # The lock is intentionally broad (covers the whole block including the token state) + # to avoid race conditions: without it, other coroutines could mis-evaluate + # token state during a context switch and trigger redundant refreshes + token_expired = ( + self.configuration.token_update_time > 0 + and time.time() + >= self.configuration.token_update_time + + self.configuration.auth_token_ttl_sec + ) + invalid_token = not self.configuration._http_config.api_key.get("api_key") + + if invalid_token or token_expired: + token = await self.refresh_authorization_token() + else: + token = self.configuration._http_config.api_key["api_key"] header_params["X-Authorization"] = token response_data = await self.rest_client.request( method, @@ -59,7 +89,9 @@ async def call_api( _request_timeout=_request_timeout, ) except ApiException as e: - logger.error("HTTP request failed url: %s status: %s; reason: %s", url, e.status, e.reason) + logger.error( + "HTTP request failed url: %s status: %s; reason: %s", url, e.status, e.reason + ) raise e return response_data @@ -82,12 +114,10 @@ def response_deserialize( if ( not response_type and isinstance(response_data.status, int) - and 100 <= response_data.status <= 599 + and 100 <= response_data.status <= 599 # noqa: PLR2004 ): # if not found, look for '1XX', '2XX', etc. - response_type = response_types_map.get( - str(response_data.status)[0] + "XX", None - ) + response_type = response_types_map.get(str(response_data.status)[0] + "XX", None) # deserialize response data response_text = None @@ -104,12 +134,10 @@ def response_deserialize( match = re.search(r"charset=([a-zA-Z\-\d]+)[\s;]?", content_type) encoding = match.group(1) if match else "utf-8" response_text = response_data.data.decode(encoding) - return_data = self.deserialize( - response_text, response_type, content_type - ) + return_data = self.deserialize(response_text, response_type, content_type) finally: - if not 200 <= response_data.status <= 299: - logger.error(f"Unexpected response status code: {response_data.status}") + if not 200 <= response_data.status <= 299: # noqa: PLR2004 + logger.error("Unexpected response status code: %s", response_data.status) raise ApiException.from_response( http_resp=response_data, body=response_text, @@ -126,8 +154,9 @@ def response_deserialize( async def refresh_authorization_token(self): obtain_new_token_response = await self.obtain_new_token() token = obtain_new_token_response.get("token") - self.configuration.api_key["api_key"] = token - logger.debug(f"New auth token been set") + self.configuration._http_config.api_key["api_key"] = token + self.configuration.token_update_time = time.time() + logger.debug("New auth token been set") return token async def obtain_new_token(self): diff --git a/src/conductor/asyncio_client/configuration/configuration.py b/src/conductor/asyncio_client/configuration/configuration.py index 695193653..8177242f7 100644 --- a/src/conductor/asyncio_client/configuration/configuration.py +++ b/src/conductor/asyncio_client/configuration/configuration.py @@ -57,6 +57,7 @@ def __init__( auth_key: Optional[str] = None, auth_secret: Optional[str] = None, debug: bool = False, + auth_token_ttl_min: int = 45, # Worker properties polling_interval: Optional[int] = None, domain: Optional[str] = None, @@ -136,10 +137,6 @@ def __init__( if api_key is None: api_key = {} - if self.auth_key and self.auth_secret: - # Use the auth_key as the API key for X-Authorization header - api_key["api_key"] = self.auth_key - self.__ui_host = os.getenv("CONDUCTOR_UI_SERVER_URL") if self.__ui_host is None: self.__ui_host = self.server_url.replace("/api", "") @@ -182,6 +179,10 @@ def __init__( self.is_logger_config_applied = False + # Orkes Conductor auth token properties + self.token_update_time = 0 + self.auth_token_ttl_sec = auth_token_ttl_min * 60 + def _get_env_float(self, env_var: str, default: float) -> float: """Get float value from environment variable with default fallback.""" try: @@ -268,9 +269,7 @@ def _convert_property_value(self, property_name: str, value: str) -> Any: # For other properties, return as string return value - def set_worker_property( - self, task_type: str, property_name: str, value: Any - ) -> None: + def set_worker_property(self, task_type: str, property_name: str, value: Any) -> None: """ Set worker property for a specific task type. @@ -523,7 +522,5 @@ def ui_host(self): def __getattr__(self, name: str) -> Any: """Delegate attribute access to underlying HTTP configuration.""" if "_http_config" not in self.__dict__ or self._http_config is None: - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") return getattr(self._http_config, name) diff --git a/tests/unit/api_client/test_async_api_client.py b/tests/unit/api_client/test_async_api_client.py new file mode 100644 index 000000000..698081af1 --- /dev/null +++ b/tests/unit/api_client/test_async_api_client.py @@ -0,0 +1,221 @@ +import asyncio +import time + +import pytest + +from conductor.asyncio_client.adapters import ApiClient +from conductor.asyncio_client.configuration.configuration import Configuration +from conductor.asyncio_client.http.rest import RESTResponse + + +@pytest.fixture +def api_client(): + configuration = Configuration( + server_url="http://localhost:8080/api", + auth_key="test_key", + auth_secret="test_secret", + ) + return ApiClient(configuration) + + +@pytest.fixture +def mock_success_response(mocker): + response = mocker.Mock(spec=RESTResponse) + response.status = 200 + response.data = b'{"token": "test_token"}' + response.read = mocker.Mock() + return response + + +@pytest.fixture +def mock_401_response(mocker): + response = mocker.Mock(spec=RESTResponse) + response.status = 401 + response.data = b'{"message":"Token cannot be null or empty","error":"INVALID_TOKEN","timestamp":1758039192168}' + response.read = mocker.AsyncMock() + return response + + +@pytest.mark.asyncio +async def test_refresh_authorization_token_called_on_invalid_token( + mocker, api_client, mock_401_response, mock_success_response +): + api_client.configuration._http_config.api_key = {} + + api_client.rest_client = mocker.AsyncMock() + api_client.rest_client.request.side_effect = [ + mock_401_response, + mock_success_response, + ] + + mock_refresh = mocker.patch.object( + api_client, "refresh_authorization_token", new_callable=mocker.AsyncMock + ) + mock_refresh.return_value = "new_token" + + mock_obtain = mocker.patch.object( + api_client, "obtain_new_token", new_callable=mocker.AsyncMock + ) + mock_obtain.return_value = {"token": "new_token"} + + await api_client.call_api( + method="GET", url="http://localhost:8080/api/test", header_params={} + ) + + mock_refresh.assert_called_once() + + +@pytest.mark.asyncio +async def test_refresh_authorization_token_called_on_expired_token( + mocker, api_client, mock_401_response, mock_success_response +): + current_time = time.time() + api_client.configuration.token_update_time = current_time - 3600 + api_client.configuration.auth_token_ttl_sec = 1800 + api_client.configuration._http_config.api_key = {"api_key": "old_token"} + + api_client.rest_client = mocker.AsyncMock() + api_client.rest_client.request.side_effect = [ + mock_401_response, + mock_success_response, + ] + + mock_refresh = mocker.patch.object( + api_client, "refresh_authorization_token", new_callable=mocker.AsyncMock + ) + mock_refresh.return_value = "new_token" + + mock_obtain = mocker.patch.object( + api_client, "obtain_new_token", new_callable=mocker.AsyncMock + ) + mock_obtain.return_value = {"token": "new_token"} + + await api_client.call_api( + method="GET", url="http://localhost:8080/api/test", header_params={} + ) + + mock_refresh.assert_called_once() + + +@pytest.mark.asyncio +async def test_token_lock_prevents_concurrent_refresh( + mocker, api_client, mock_401_response, mock_success_response +): + api_client.configuration._http_config.api_key = {} + + refresh_calls = [] + + async def mock_refresh(): + refresh_calls.append(time.time()) + await asyncio.sleep(0.1) + return "new_token" + + mocker.patch.object( + api_client, "refresh_authorization_token", side_effect=mock_refresh + ) + + mock_obtain = mocker.patch.object( + api_client, "obtain_new_token", new_callable=mocker.AsyncMock + ) + mock_obtain.return_value = {"token": "new_token"} + + api_client.rest_client = mocker.AsyncMock() + api_client.rest_client.request.side_effect = [ + mock_401_response, + mock_success_response, + mock_401_response, + mock_success_response, + ] + + tasks = [ + api_client.call_api( + method="GET", + url="http://localhost:8080/api/test1", + header_params={}, + ), + api_client.call_api( + method="GET", + url="http://localhost:8080/api/test2", + header_params={}, + ), + ] + + await asyncio.gather(*tasks) + + assert len(refresh_calls) == 1 + + +@pytest.mark.asyncio +async def test_no_refresh_when_token_valid_and_not_expired( + mocker, api_client, mock_success_response +): + current_time = time.time() + api_client.configuration.token_update_time = current_time - 100 + api_client.configuration.auth_token_ttl_sec = 1800 + api_client.configuration._http_config.api_key = {"api_key": "valid_token"} + + api_client.rest_client = mocker.AsyncMock() + api_client.rest_client.request.return_value = mock_success_response + + mock_refresh = mocker.patch.object( + api_client, "refresh_authorization_token", new_callable=mocker.AsyncMock + ) + + await api_client.call_api( + method="GET", url="http://localhost:8080/api/test", header_params={} + ) + + mock_refresh.assert_not_called() + + +@pytest.mark.asyncio +async def test_no_refresh_for_token_endpoint(mocker, api_client, mock_401_response): + api_client.configuration._http_config.api_key = {} + + api_client.rest_client = mocker.AsyncMock() + api_client.rest_client.request.return_value = mock_401_response + + mock_refresh = mocker.patch.object( + api_client, "refresh_authorization_token", new_callable=mocker.AsyncMock + ) + + await api_client.call_api( + method="POST", url="http://localhost:8080/api/token", header_params={} + ) + + mock_refresh.assert_not_called() + + +@pytest.mark.asyncio +async def test_401_response_triggers_retry_with_new_token( + mocker, api_client, mock_401_response, mock_success_response +): + api_client.configuration._http_config.api_key = {} + + api_client.rest_client = mocker.AsyncMock() + api_client.rest_client.request.side_effect = [ + mock_401_response, + mock_success_response, + ] + + mock_refresh = mocker.patch.object( + api_client, "refresh_authorization_token", new_callable=mocker.AsyncMock + ) + mock_refresh.return_value = "new_token" + + mock_obtain = mocker.patch.object( + api_client, "obtain_new_token", new_callable=mocker.AsyncMock + ) + mock_obtain.return_value = {"token": "new_token"} + + header_params = {} + await api_client.call_api( + method="GET", + url="http://localhost:8080/api/test", + header_params=header_params, + ) + + assert api_client.rest_client.request.call_count == 2 + + second_call_args = api_client.rest_client.request.call_args_list[1] + assert second_call_args[1]["headers"]["X-Authorization"] == "new_token" diff --git a/tests/unit/asyncio_client/test_api_client_adapter.py b/tests/unit/asyncio_client/test_api_client_adapter.py index aecde3588..494b74d41 100644 --- a/tests/unit/asyncio_client/test_api_client_adapter.py +++ b/tests/unit/asyncio_client/test_api_client_adapter.py @@ -1,15 +1,15 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch from conductor.asyncio_client.adapters.api_client_adapter import ApiClientAdapter +from conductor.asyncio_client.configuration import Configuration from conductor.asyncio_client.http.exceptions import ApiException from conductor.asyncio_client.http.api_response import ApiResponse @pytest.fixture def mock_config(): - config = MagicMock() + config = Configuration() config.host = "http://test.com" - config.api_key = {"api_key": "test_token"} config.auth_key = "test_key" config.auth_secret = "test_secret" return config @@ -50,9 +50,7 @@ async def test_call_api_401_retry(adapter): adapter.rest_client.request = AsyncMock(return_value=mock_response) adapter.refresh_authorization_token = AsyncMock(return_value="new_token") - result = await adapter.call_api( - "GET", "http://test.com/api", {"X-Authorization": "old_token"} - ) + result = await adapter.call_api("GET", "http://test.com/api", {"X-Authorization": "old_token"}) assert result == mock_response assert adapter.rest_client.request.call_count == 2 @@ -215,9 +213,7 @@ async def test_obtain_new_token_with_patch(): client_adapter.configuration = MagicMock() client_adapter.configuration.auth_key = "test_key" client_adapter.configuration.auth_secret = "test_secret" - client_adapter.param_serialize = MagicMock( - return_value=("POST", "/token", {}, {}) - ) + client_adapter.param_serialize = MagicMock(return_value=("POST", "/token", {}, {})) mock_response = MagicMock() mock_response.data = b'{"token": "test_token"}' @@ -227,9 +223,7 @@ async def test_obtain_new_token_with_patch(): result = await client_adapter.obtain_new_token() assert result == {"token": "test_token"} - mock_generate_token.assert_called_once_with( - key_id="test_key", key_secret="test_secret" - ) + mock_generate_token.assert_called_once_with(key_id="test_key", key_secret="test_secret") def test_response_deserialize_encoding_detection(adapter): diff --git a/tests/unit/asyncio_client/test_configuration.py b/tests/unit/asyncio_client/test_configuration.py index a4be3d5c1..db4f427a2 100644 --- a/tests/unit/asyncio_client/test_configuration.py +++ b/tests/unit/asyncio_client/test_configuration.py @@ -50,6 +50,7 @@ def test_initialization_with_env_vars(monkeypatch): assert config.domain == "env_domain" assert config.polling_interval_seconds == 10 + def test_initialization_env_vars_override_params(monkeypatch): monkeypatch.setenv("CONDUCTOR_SERVER_URL", "https://env.com/api") monkeypatch.setenv("CONDUCTOR_AUTH_KEY", "env_key") @@ -146,7 +147,7 @@ def test_get_worker_property_value_poll_interval_seconds(): result = config.get_worker_property_value("poll_interval_seconds", "mytask") assert result == 0 - + def test_convert_property_value_polling_interval(): config = Configuration() result = config._convert_property_value("polling_interval", "250") @@ -378,12 +379,6 @@ def test_getattr_no_http_config(): _ = config.nonexistent_attr -def test_auth_setup_with_credentials(): - config = Configuration(auth_key="key", auth_secret="secret") - assert "api_key" in config.api_key - assert config.api_key["api_key"] == "key" - - def test_worker_properties_dict_initialization(): config = Configuration() assert isinstance(config._worker_properties, dict) @@ -398,9 +393,7 @@ def test_get_worker_property_value_unknown_property(): def test_get_poll_interval_with_task_type_none_value(): config = Configuration() - with patch.dict( - os.environ, {"CONDUCTOR_WORKER_MYTASK_POLLING_INTERVAL": "invalid"} - ): + with patch.dict(os.environ, {"CONDUCTOR_WORKER_MYTASK_POLLING_INTERVAL": "invalid"}): result = config.get_poll_interval("mytask") assert result == 100 diff --git a/tests/unit/configuration/test_async_configuration.py b/tests/unit/configuration/test_async_configuration.py new file mode 100644 index 000000000..d35d59ba8 --- /dev/null +++ b/tests/unit/configuration/test_async_configuration.py @@ -0,0 +1,25 @@ +import pytest + +from conductor.asyncio_client.configuration import Configuration + + +def test_initialization_default(monkeypatch): + monkeypatch.setenv("CONDUCTOR_SERVER_URL", "http://localhost:8080/api") + configuration = Configuration() + assert configuration.host == "http://localhost:8080/api" + + +def test_initialization_with_base_url(): + configuration = Configuration(server_url="https://play.orkes.io/api") + assert configuration.host == "https://play.orkes.io/api" + + +def test_missed_http_config(): + configuration = Configuration() + configuration._http_config = None + with pytest.raises(AttributeError) as ctx: + _ = configuration.api_key + assert ( + f"'{Configuration.__class__.__name__}' object has no attribute 'api_key'" + in ctx.value + )