diff --git a/src/app/endpoints/health.py b/src/app/endpoints/health.py index cf720bf6..65ff3904 100644 --- a/src/app/endpoints/health.py +++ b/src/app/endpoints/health.py @@ -11,7 +11,7 @@ from llama_stack.providers.datatypes import HealthStatus from fastapi import APIRouter, status, Response -from client import LlamaStackClientHolder +from client import AsyncLlamaStackClientHolder from models.responses import ( LivenessResponse, ReadinessResponse, @@ -22,16 +22,17 @@ router = APIRouter(tags=["health"]) -def get_providers_health_statuses() -> list[ProviderHealthStatus]: +async def get_providers_health_statuses() -> list[ProviderHealthStatus]: """Check health of all providers. Returns: List of provider health statuses. """ try: - client = LlamaStackClientHolder().get_client() + client = AsyncLlamaStackClientHolder().get_client() - providers = client.providers.list() + # providers = [] + providers = await client.providers.list() logger.debug("Found %d providers", len(providers)) health_results = [ @@ -69,9 +70,9 @@ def get_providers_health_statuses() -> list[ProviderHealthStatus]: @router.get("/readiness", responses=get_readiness_responses) -def readiness_probe_get_method(response: Response) -> ReadinessResponse: +async def readiness_probe_get_method(response: Response) -> ReadinessResponse: """Ready status of service with provider health details.""" - provider_statuses = get_providers_health_statuses() + provider_statuses = await get_providers_health_statuses() # Check if any provider is unhealthy (not counting not_implemented as unhealthy) unhealthy_providers = [ diff --git a/tests/unit/app/endpoints/test_health.py b/tests/unit/app/endpoints/test_health.py index 4583bdb8..02714ad2 100644 --- a/tests/unit/app/endpoints/test_health.py +++ b/tests/unit/app/endpoints/test_health.py @@ -12,7 +12,7 @@ from models.responses import ProviderHealthStatus, ReadinessResponse -def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): +async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): """Test the readiness endpoint handler fails when providers are unhealthy.""" # Mock get_providers_health_statuses to return an unhealthy provider mock_get_providers_health_statuses = mocker.patch( @@ -29,7 +29,7 @@ def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): # Mock the Response object mock_response = Mock() - response = readiness_probe_get_method(mock_response) + response = await readiness_probe_get_method(mock_response) assert response.ready is False assert "test_provider" in response.reason @@ -37,7 +37,7 @@ def test_readiness_probe_fails_due_to_unhealthy_providers(mocker): assert mock_response.status_code == 503 -def test_readiness_probe_success_when_all_providers_healthy(mocker): +async def test_readiness_probe_success_when_all_providers_healthy(mocker): """Test the readiness endpoint handler succeeds when all providers are healthy.""" # Mock get_providers_health_statuses to return healthy providers mock_get_providers_health_statuses = mocker.patch( @@ -59,7 +59,7 @@ def test_readiness_probe_success_when_all_providers_healthy(mocker): # Mock the Response object mock_response = Mock() - response = readiness_probe_get_method(mock_response) + response = await readiness_probe_get_method(mock_response) assert response is not None assert isinstance(response, ReadinessResponse) assert response.ready is True @@ -98,13 +98,13 @@ def test_provider_health_status_optional_fields(self): class TestGetProvidersHealthStatuses: """Test cases for the get_providers_health_statuses function.""" - def test_get_providers_health_statuses(self, mocker): + async def test_get_providers_health_statuses(self, mocker): """Test get_providers_health_statuses with healthy providers.""" # Mock the imports - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") # Mock the client and its methods - mock_client = mocker.Mock() + mock_client = mocker.AsyncMock() mock_lsc.return_value = mock_client # Mock providers.list() to return providers with health @@ -136,7 +136,7 @@ def test_get_providers_health_statuses(self, mocker): ] # Mock configuration - result = get_providers_health_statuses() + result = await get_providers_health_statuses() assert len(result) == 3 assert result[0].provider_id == "provider1" @@ -149,15 +149,15 @@ def test_get_providers_health_statuses(self, mocker): assert result[2].status == HealthStatus.ERROR.value assert result[2].message == "Connection failed" - def test_get_providers_health_statuses_connection_error(self, mocker): + async def test_get_providers_health_statuses_connection_error(self, mocker): """Test get_providers_health_statuses when connection fails.""" # Mock the imports - mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") # Mock get_llama_stack_client to raise an exception mock_lsc.side_effect = Exception("Connection error") - result = get_providers_health_statuses() + result = await get_providers_health_statuses() assert len(result) == 1 assert result[0].provider_id == "unknown"