Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ jobs:
/bin/sh -c "cd /package && COVERAGE_FILE=/package/${{ env.COVERAGE_DIR }}/.coverage.integration coverage run -m pytest -m v4 tests/integration -v"

- name: Run asyncio integration tests
id: integration_tests
id: asyncio_integration_tests
continue-on-error: true
run: |
docker run --rm \
Expand All @@ -90,7 +90,7 @@ jobs:
-e CONDUCTOR_SERVER_URL=${{ env.CONDUCTOR_SERVER_URL }} \
-v ${{ github.workspace }}/${{ env.COVERAGE_DIR }}:/package/${{ env.COVERAGE_DIR }}:rw \
conductor-sdk-test:latest \
/bin/sh -c "cd /package && COVERAGE_FILE=/package/${{ env.COVERAGE_DIR }}/.coverage.integration coverage run -m pytest -m v4 tests/integration -v"
/bin/sh -c "cd /package && COVERAGE_FILE=/package/${{ env.COVERAGE_DIR }}/.coverage.asyncio_integration coverage run -m pytest -m v4 tests/integration/async -v"

- name: Generate coverage report
id: coverage_report
Expand Down Expand Up @@ -124,4 +124,4 @@ jobs:

- name: Check test results
if: steps.unit_tests.outcome == 'failure' || steps.bc_tests.outcome == 'failure' || steps.serdeser_tests.outcome == 'failure'
run: exit 1
run: exit 1
61 changes: 45 additions & 16 deletions src/conductor/asyncio_client/adapters/api_client_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

import asyncio
import json
import logging
import re
import time
from typing import Dict, Optional

from conductor.asyncio_client.adapters.models import GenerateTokenRequest
Expand All @@ -15,6 +19,10 @@


class ApiClientAdapter(ApiClient):
def __init__(self, *args, **kwargs):
self._token_lock = asyncio.Lock()
super().__init__(*args, **kwargs)

async def call_api(
self,
method,
Expand All @@ -37,7 +45,9 @@ async def call_api(
"""

try:
logger.debug("HTTP request method: %s; url: %s; header_params: %s", method, url, header_params)
logger.debug(
"HTTP request method: %s; url: %s; header_params: %s", method, url, header_params
)
response_data = await self.rest_client.request(
method,
url,
Expand All @@ -46,9 +56,29 @@ async def call_api(
post_params=post_params,
_request_timeout=_request_timeout,
)
if response_data.status == 401 and url != self.configuration.host + "/token": # noqa: PLR2004 (Unauthorized status code)
logger.warning("HTTP response from: %s; status code: 401 - obtaining new token", url)
token = await self.refresh_authorization_token()
if (
response_data.status == 401 # noqa: PLR2004 (Unauthorized status code)
and url != self.configuration.host + "/token"
):
logger.warning(
"HTTP response from: %s; status code: 401 - obtaining new token", url
)
async with self._token_lock:
# The lock is intentionally broad (covers the whole block including the token state)
# to avoid race conditions: without it, other coroutines could mis-evaluate
# token state during a context switch and trigger redundant refreshes
token_expired = (
self.configuration.token_update_time > 0
and time.time()
>= self.configuration.token_update_time
+ self.configuration.auth_token_ttl_sec
)
invalid_token = not self.configuration._http_config.api_key.get("api_key")

if invalid_token or token_expired:
token = await self.refresh_authorization_token()
else:
token = self.configuration._http_config.api_key["api_key"]
header_params["X-Authorization"] = token
response_data = await self.rest_client.request(
method,
Expand All @@ -59,7 +89,9 @@ async def call_api(
_request_timeout=_request_timeout,
)
except ApiException as e:
logger.error("HTTP request failed url: %s status: %s; reason: %s", url, e.status, e.reason)
logger.error(
"HTTP request failed url: %s status: %s; reason: %s", url, e.status, e.reason
)
raise e

return response_data
Expand All @@ -82,12 +114,10 @@ def response_deserialize(
if (
not response_type
and isinstance(response_data.status, int)
and 100 <= response_data.status <= 599
and 100 <= response_data.status <= 599 # noqa: PLR2004
):
# if not found, look for '1XX', '2XX', etc.
response_type = response_types_map.get(
str(response_data.status)[0] + "XX", None
)
response_type = response_types_map.get(str(response_data.status)[0] + "XX", None)

# deserialize response data
response_text = None
Expand All @@ -104,12 +134,10 @@ def response_deserialize(
match = re.search(r"charset=([a-zA-Z\-\d]+)[\s;]?", content_type)
encoding = match.group(1) if match else "utf-8"
response_text = response_data.data.decode(encoding)
return_data = self.deserialize(
response_text, response_type, content_type
)
return_data = self.deserialize(response_text, response_type, content_type)
finally:
if not 200 <= response_data.status <= 299:
logger.error(f"Unexpected response status code: {response_data.status}")
if not 200 <= response_data.status <= 299: # noqa: PLR2004
logger.error("Unexpected response status code: %s", response_data.status)
raise ApiException.from_response(
http_resp=response_data,
body=response_text,
Expand All @@ -126,8 +154,9 @@ def response_deserialize(
async def refresh_authorization_token(self):
obtain_new_token_response = await self.obtain_new_token()
token = obtain_new_token_response.get("token")
self.configuration.api_key["api_key"] = token
logger.debug(f"New auth token been set")
self.configuration._http_config.api_key["api_key"] = token
self.configuration.token_update_time = time.time()
logger.debug("New auth token been set")
return token

async def obtain_new_token(self):
Expand Down
17 changes: 7 additions & 10 deletions src/conductor/asyncio_client/configuration/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
auth_key: Optional[str] = None,
auth_secret: Optional[str] = None,
debug: bool = False,
auth_token_ttl_min: int = 45,
# Worker properties
polling_interval: Optional[int] = None,
domain: Optional[str] = None,
Expand Down Expand Up @@ -136,10 +137,6 @@ def __init__(
if api_key is None:
api_key = {}

if self.auth_key and self.auth_secret:
# Use the auth_key as the API key for X-Authorization header
api_key["api_key"] = self.auth_key

self.__ui_host = os.getenv("CONDUCTOR_UI_SERVER_URL")
if self.__ui_host is None:
self.__ui_host = self.server_url.replace("/api", "")
Expand Down Expand Up @@ -182,6 +179,10 @@ def __init__(

self.is_logger_config_applied = False

# Orkes Conductor auth token properties
self.token_update_time = 0
self.auth_token_ttl_sec = auth_token_ttl_min * 60

def _get_env_float(self, env_var: str, default: float) -> float:
"""Get float value from environment variable with default fallback."""
try:
Expand Down Expand Up @@ -268,9 +269,7 @@ def _convert_property_value(self, property_name: str, value: str) -> Any:
# For other properties, return as string
return value

def set_worker_property(
self, task_type: str, property_name: str, value: Any
) -> None:
def set_worker_property(self, task_type: str, property_name: str, value: Any) -> None:
"""
Set worker property for a specific task type.

Expand Down Expand Up @@ -523,7 +522,5 @@ def ui_host(self):
def __getattr__(self, name: str) -> Any:
"""Delegate attribute access to underlying HTTP configuration."""
if "_http_config" not in self.__dict__ or self._http_config is None:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
return getattr(self._http_config, name)
Loading
Loading