diff --git a/docs/openapi.json b/docs/openapi.json
index 7df656f5..9413565a 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -96,6 +96,9 @@
                 "model_type": "llm"
               }
             ]
+          },
+          "503": {
+            "description": "Connection to Llama Stack is broken"
           }
         }
       }
diff --git a/src/app/endpoints/models.py b/src/app/endpoints/models.py
index 6f7a1934..6df1a162 100644
--- a/src/app/endpoints/models.py
+++ b/src/app/endpoints/models.py
@@ -3,11 +3,13 @@
 import logging
 from typing import Any
 
-from fastapi import APIRouter, Request
+from llama_stack_client import APIConnectionError
+from fastapi import APIRouter, HTTPException, Request, status
 
 from client import get_llama_stack_client
 from configuration import configuration
 from models.responses import ModelsResponse
+from utils.endpoints import check_configuration_loaded
 
 logger = logging.getLogger(__name__)
 router = APIRouter(tags=["models"])
@@ -36,16 +38,42 @@
             },
         ]
     },
+    503: {"description": "Connection to Llama Stack is broken"},
 }
 
 
 @router.get("/models", responses=models_responses)
 def models_endpoint_handler(_request: Request) -> ModelsResponse:
     """Handle requests to the /models endpoint."""
-    llama_stack_config = configuration.llama_stack_configuration
-    logger.info("LLama stack config: %s", llama_stack_config)
+    check_configuration_loaded(configuration)
 
-    client = get_llama_stack_client(llama_stack_config)
-    models = client.models.list()
-    m = [dict(m) for m in models]
-    return ModelsResponse(models=m)
+    llama_stack_configuration = configuration.llama_stack_configuration
+    logger.info("Llama Stack config: %s", llama_stack_configuration)
+
+    try:
+        # try to get Llama Stack client
+        client = get_llama_stack_client(llama_stack_configuration)
+        # retrieve models
+        models = client.models.list()
+        m = [dict(m) for m in models]
+        return ModelsResponse(models=m)
+    # connection to Llama Stack server is broken
+    except APIConnectionError as e:
+        logger.error("Unable to connect to Llama Stack: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail={
+                "response": "Unable to connect to Llama Stack",
+                "cause": str(e),
+            },
+        ) from e
+    # any other exception that can occur during model listing
+    except Exception as e:
+        logger.error("Unable to retrieve list of models: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail={
+                "response": "Unable to retrieve list of models",
+                "cause": str(e),
+            },
+        ) from e
diff --git a/tests/unit/app/endpoints/test_models.py b/tests/unit/app/endpoints/test_models.py
new file mode 100644
index 00000000..c9943ea2
--- /dev/null
+++ b/tests/unit/app/endpoints/test_models.py
@@ -0,0 +1,132 @@
+import pytest
+
+from unittest.mock import Mock
+from fastapi import HTTPException, status
+
+from app.endpoints.models import models_endpoint_handler
+from configuration import AppConfig
+
+
+def test_models_endpoint_handler_configuration_not_loaded(mocker):
+    """Test the models endpoint handler if configuration is not loaded."""
+    # simulate state when no configuration is loaded
+    mocker.patch(
+        "app.endpoints.models.configuration",
+        return_value=mocker.Mock(),
+    )
+    mocker.patch("app.endpoints.models.configuration", None)
+
+    request = None
+    with pytest.raises(HTTPException) as e:
+        models_endpoint_handler(request)
+        assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+        assert e.detail["response"] == "Configuration is not loaded"
+
+
+def test_models_endpoint_handler_improper_llama_stack_configuration(mocker):
+    """Test the models endpoint handler if Llama Stack configuration is not proper."""
+    # configuration for tests
+    config_dict = {
+        "name": "test",
+        "service": {
+            "host": "localhost",
+            "port": 8080,
+            "auth_enabled": False,
+            "workers": 1,
+            "color_log": True,
+            "access_log": True,
+        },
+        "llama_stack": {
+            "api_key": "test-key",
+            "url": "http://test.com:1234",
+            "use_as_library_client": False,
+        },
+        "user_data_collection": {
+            "transcripts_disabled": True,
+        },
+        "mcp_servers": [],
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(config_dict)
+
+    mocker.patch(
+        "app.endpoints.models.configuration",
+        return_value=None,
+    )
+
+    request = None
+    with pytest.raises(HTTPException) as e:
+        models_endpoint_handler(request)
+        assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+        assert e.detail["response"] == "LLama stack is not configured"
+
+
+def test_models_endpoint_handler_configuration_loaded(mocker):
+    """Test the models endpoint handler if configuration is loaded."""
+    # configuration for tests
+    config_dict = {
+        "name": "foo",
+        "service": {
+            "host": "localhost",
+            "port": 8080,
+            "auth_enabled": False,
+            "workers": 1,
+            "color_log": True,
+            "access_log": True,
+        },
+        "llama_stack": {
+            "api_key": "xyzzy",
+            "url": "http://x.y.com:1234",
+            "use_as_library_client": False,
+        },
+        "user_data_collection": {
+            "feedback_disabled": True,
+        },
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(config_dict)
+
+    with pytest.raises(HTTPException) as e:
+        request = None
+        models_endpoint_handler(request)
+        assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+        assert e.detail["response"] == "Unable to connect to Llama Stack"
+
+
+def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker):
+    """Test the models endpoint handler if the model list can be retrieved."""
+    # configuration for tests
+    config_dict = {
+        "name": "foo",
+        "service": {
+            "host": "localhost",
+            "port": 8080,
+            "auth_enabled": False,
+            "workers": 1,
+            "color_log": True,
+            "access_log": True,
+        },
+        "llama_stack": {
+            "api_key": "xyzzy",
+            "url": "http://x.y.com:1234",
+            "use_as_library_client": False,
+        },
+        "user_data_collection": {
+            "feedback_disabled": True,
+        },
+    }
+    cfg = AppConfig()
+    cfg.init_from_dict(config_dict)
+
+    # Mock the LlamaStack client
+    mock_client = Mock()
+    mock_client.models.list.return_value = []
+
+    # patch get_llama_stack_client so the handler receives the mocked client
+    mocker.patch(
+        "app.endpoints.models.get_llama_stack_client", return_value=mock_client
+    )
+
+    request = None
+    response = models_endpoint_handler(request)
+    assert response is not None
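
Note: `check_configuration_loaded` is imported from `utils.endpoints` but its implementation is not part of this diff. Judging from the behaviour the new unit tests assert (HTTP 500 with a `"Configuration is not loaded"` response when no configuration has been loaded), the helper presumably looks roughly like the sketch below; this is an assumption about its shape, not the actual implementation.

```python
# Hypothetical sketch of utils/endpoints.py::check_configuration_loaded.
# The real helper is outside this diff; its shape is inferred from the
# assertions in test_models_endpoint_handler_configuration_not_loaded.
from typing import Any

from fastapi import HTTPException, status


def check_configuration_loaded(configuration: Any) -> None:
    """Abort with HTTP 500 when the service configuration is not loaded."""
    if configuration is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"response": "Configuration is not loaded"},
        )
```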
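
Because both error branches raise `HTTPException` with a dict `detail`, FastAPI serialises the error body as `{"detail": {"response": ..., "cause": ...}}`. The sketch below is a hedged illustration (not part of the PR) of how that payload surfaces over HTTP: it reuses the `AppConfig` setup from the new unit tests, mounts the router on a throwaway FastAPI app, and patches `get_llama_stack_client` so that `models.list()` fails and the generic `except Exception` branch is taken.

```python
# Hedged illustration: error payload seen by an HTTP client when model
# listing fails. Assumes the repo's modules are importable exactly as in
# the new unit tests, and that AppConfig.init_from_dict populates the shared
# `configuration` singleton the handler reads (as those tests appear to rely on).
from unittest.mock import Mock, patch

from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.endpoints.models import router
from configuration import AppConfig


def demo_models_error_payload() -> None:
    # same configuration shape the new unit tests use
    cfg = AppConfig()
    cfg.init_from_dict(
        {
            "name": "demo",
            "service": {
                "host": "localhost",
                "port": 8080,
                "auth_enabled": False,
                "workers": 1,
                "color_log": True,
                "access_log": True,
            },
            "llama_stack": {
                "api_key": "xyzzy",
                "url": "http://x.y.com:1234",
                "use_as_library_client": False,
            },
            "user_data_collection": {"feedback_disabled": True},
        }
    )

    # client whose models.list() blows up, driving the generic except branch
    failing_client = Mock()
    failing_client.models.list.side_effect = RuntimeError("boom")

    app = FastAPI()
    app.include_router(router)

    with patch(
        "app.endpoints.models.get_llama_stack_client", return_value=failing_client
    ):
        response = TestClient(app).get("/models")

    # FastAPI wraps the dict passed to HTTPException under the "detail" key
    assert response.status_code == 500
    assert response.json()["detail"]["response"] == "Unable to retrieve list of models"
    assert "boom" in response.json()["detail"]["cause"]
```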