Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/app/endpoints/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from llama_stack.providers.datatypes import HealthStatus

from fastapi import APIRouter, status, Response
from client import LlamaStackClientHolder
from client import AsyncLlamaStackClientHolder
from models.responses import (
LivenessResponse,
ReadinessResponse,
Expand All @@ -22,16 +22,17 @@
router = APIRouter(tags=["health"])


def get_providers_health_statuses() -> list[ProviderHealthStatus]:
async def get_providers_health_statuses() -> list[ProviderHealthStatus]:
"""Check health of all providers.

Returns:
List of provider health statuses.
"""
try:
client = LlamaStackClientHolder().get_client()
client = AsyncLlamaStackClientHolder().get_client()

providers = client.providers.list()
# providers = []
providers = await client.providers.list()
logger.debug("Found %d providers", len(providers))

health_results = [
Expand Down Expand Up @@ -69,9 +70,9 @@ def get_providers_health_statuses() -> list[ProviderHealthStatus]:


@router.get("/readiness", responses=get_readiness_responses)
def readiness_probe_get_method(response: Response) -> ReadinessResponse:
async def readiness_probe_get_method(response: Response) -> ReadinessResponse:
"""Ready status of service with provider health details."""
provider_statuses = get_providers_health_statuses()
provider_statuses = await get_providers_health_statuses()

# Check if any provider is unhealthy (not counting not_implemented as unhealthy)
unhealthy_providers = [
Expand Down
22 changes: 11 additions & 11 deletions tests/unit/app/endpoints/test_health.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from models.responses import ProviderHealthStatus, ReadinessResponse


def test_readiness_probe_fails_due_to_unhealthy_providers(mocker):
async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker):
"""Test the readiness endpoint handler fails when providers are unhealthy."""
# Mock get_providers_health_statuses to return an unhealthy provider
mock_get_providers_health_statuses = mocker.patch(
Expand All @@ -29,15 +29,15 @@ def test_readiness_probe_fails_due_to_unhealthy_providers(mocker):
# Mock the Response object
mock_response = Mock()

response = readiness_probe_get_method(mock_response)
response = await readiness_probe_get_method(mock_response)

assert response.ready is False
assert "test_provider" in response.reason
assert "Providers not healthy" in response.reason
assert mock_response.status_code == 503


def test_readiness_probe_success_when_all_providers_healthy(mocker):
async def test_readiness_probe_success_when_all_providers_healthy(mocker):
"""Test the readiness endpoint handler succeeds when all providers are healthy."""
# Mock get_providers_health_statuses to return healthy providers
mock_get_providers_health_statuses = mocker.patch(
Expand All @@ -59,7 +59,7 @@ def test_readiness_probe_success_when_all_providers_healthy(mocker):
# Mock the Response object
mock_response = Mock()

response = readiness_probe_get_method(mock_response)
response = await readiness_probe_get_method(mock_response)
assert response is not None
assert isinstance(response, ReadinessResponse)
assert response.ready is True
Expand Down Expand Up @@ -98,13 +98,13 @@ def test_provider_health_status_optional_fields(self):
class TestGetProvidersHealthStatuses:
"""Test cases for the get_providers_health_statuses function."""

def test_get_providers_health_statuses(self, mocker):
async def test_get_providers_health_statuses(self, mocker):
"""Test get_providers_health_statuses with healthy providers."""
# Mock the imports
mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client")
mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")

# Mock the client and its methods
mock_client = mocker.Mock()
mock_client = mocker.AsyncMock()
mock_lsc.return_value = mock_client

# Mock providers.list() to return providers with health
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_get_providers_health_statuses(self, mocker):
]

# Mock configuration
result = get_providers_health_statuses()
result = await get_providers_health_statuses()

assert len(result) == 3
assert result[0].provider_id == "provider1"
Expand All @@ -149,15 +149,15 @@ def test_get_providers_health_statuses(self, mocker):
assert result[2].status == HealthStatus.ERROR.value
assert result[2].message == "Connection failed"

def test_get_providers_health_statuses_connection_error(self, mocker):
async def test_get_providers_health_statuses_connection_error(self, mocker):
"""Test get_providers_health_statuses when connection fails."""
# Mock the imports
mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client")
mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")

# Mock get_llama_stack_client to raise an exception
mock_lsc.side_effect = Exception("Connection error")

result = get_providers_health_statuses()
result = await get_providers_health_statuses()

assert len(result) == 1
assert result[0].provider_id == "unknown"
Expand Down
Loading