diff --git a/docs/openapi.json b/docs/openapi.json index a774e09a..c0057989 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -123,6 +123,40 @@ } } }, + "/v1/shields": { + "get": { + "tags": [ + "shields" + ], + "summary": "Shields Endpoint Handler", + "description": "Handle requests to the /shields endpoint.\n\nProcess GET requests to the /shields endpoint, returning a list of available\nshields from the Llama Stack service.\n\nRaises:\n HTTPException: If unable to connect to the Llama Stack server or if\n shield retrieval fails for any reason.\n\nReturns:\n ShieldsResponse: An object containing the list of available shields.", + "operationId": "shields_endpoint_handler_v1_shields_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ShieldsResponse" + } + } + }, + "shields": [ + { + "identifier": "lightspeed_question_validity-shield", + "provider_resource_id": "lightspeed_question_validity-shield", + "provider_id": "lightspeed_question_validity", + "type": "shield", + "params": {} + } + ] + }, + "500": { + "description": "Connection to Llama Stack is broken" + } + } + } + }, "/v1/query": { "post": { "tags": [ @@ -1082,6 +1116,7 @@ "delete_conversation", "feedback", "get_models", + "get_shields", "get_metrics", "get_config", "info", @@ -2990,6 +3025,34 @@ "title": "ServiceConfiguration", "description": "Service configuration." 
}, + "ShieldsResponse": { + "properties": { + "shields": { + "items": { + "additionalProperties": true, + "type": "object" + }, + "type": "array", + "title": "Shields", + "description": "List of shields available", + "examples": [ + { + "identifier": "lightspeed_question_validity-shield", + "params": {}, + "provider_id": "lightspeed_question_validity", + "provider_resource_id": "lightspeed_question_validity-shield", + "type": "shield" + } + ] + } + }, + "type": "object", + "required": [ + "shields" + ], + "title": "ShieldsResponse", + "description": "Model representing a response to shields request." + }, "StatusResponse": { "properties": { "functionality": { diff --git a/src/app/endpoints/shields.py b/src/app/endpoints/shields.py new file mode 100644 index 00000000..ce632e40 --- /dev/null +++ b/src/app/endpoints/shields.py @@ -0,0 +1,96 @@ +"""Handler for REST API call to list available shields.""" + +import logging +from typing import Annotated, Any + +from fastapi import APIRouter, HTTPException, Request, status +from fastapi.params import Depends +from llama_stack_client import APIConnectionError + +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from client import AsyncLlamaStackClientHolder +from configuration import configuration +from models.config import Action +from models.responses import ShieldsResponse +from utils.endpoints import check_configuration_loaded + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["shields"]) + + +shields_responses: dict[int | str, dict[str, Any]] = { + 200: { + "shields": [ + { + "identifier": "lightspeed_question_validity-shield", + "provider_resource_id": "lightspeed_question_validity-shield", + "provider_id": "lightspeed_question_validity", + "type": "shield", + "params": {}, + } + ] + }, + 500: {"description": "Connection to Llama Stack is broken"}, +} + + +@router.get("/shields", responses=shields_responses) 
+@authorize(Action.GET_SHIELDS) +async def shields_endpoint_handler( + request: Request, + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], +) -> ShieldsResponse: + """ + Handle requests to the /shields endpoint. + + Process GET requests to the /shields endpoint, returning a list of available + shields from the Llama Stack service. + + Raises: + HTTPException: If unable to connect to the Llama Stack server or if + shield retrieval fails for any reason. + + Returns: + ShieldsResponse: An object containing the list of available shields. + """ + # Used only by the middleware + _ = auth + + # Nothing interesting in the request + _ = request + + check_configuration_loaded(configuration) + + llama_stack_configuration = configuration.llama_stack_configuration + logger.info("Llama stack config: %s", llama_stack_configuration) + + try: + # try to get Llama Stack client + client = AsyncLlamaStackClientHolder().get_client() + # retrieve shields + shields = await client.shields.list() + s = [dict(s) for s in shields] + return ShieldsResponse(shields=s) + + # connection to Llama Stack server + except APIConnectionError as e: + logger.error("Unable to connect to Llama Stack: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to connect to Llama Stack", + "cause": str(e), + }, + ) from e + # any other exception that can occur during shield listing + except Exception as e: + logger.error("Unable to retrieve list of shields: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to retrieve list of shields", + "cause": str(e), + }, + ) from e diff --git a/src/app/routers.py b/src/app/routers.py index bd4de2e5..42606cea 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -5,6 +5,7 @@ from app.endpoints import ( info, models, + shields, root, query, health, @@ -27,6 +28,7 @@ def include_routers(app: FastAPI) -> None: 
app.include_router(root.router) app.include_router(info.router, prefix="/v1") app.include_router(models.router, prefix="/v1") + app.include_router(shields.router, prefix="/v1") app.include_router(query.router, prefix="/v1") app.include_router(streaming_query.router, prefix="/v1") app.include_router(config.router, prefix="/v1") diff --git a/src/models/config.py b/src/models/config.py index 1598d16b..99850a50 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -358,6 +358,7 @@ class Action(str, Enum): DELETE_CONVERSATION = "delete_conversation" FEEDBACK = "feedback" GET_MODELS = "get_models" + GET_SHIELDS = "get_shields" GET_METRICS = "get_metrics" GET_CONFIG = "get_config" diff --git a/src/models/responses.py b/src/models/responses.py index 7345b189..09db643f 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -36,6 +36,24 @@ class ModelsResponse(BaseModel): ) +class ShieldsResponse(BaseModel): + """Model representing a response to shields request.""" + + shields: list[dict[str, Any]] = Field( + ..., + description="List of shields available", + examples=[ + { + "identifier": "lightspeed_question_validity-shield", + "provider_resource_id": "lightspeed_question_validity-shield", + "provider_id": "lightspeed_question_validity", + "type": "shield", + "params": {}, + } + ], + ) + + class RAGChunk(BaseModel): """Model representing a RAG chunk used in the response.""" diff --git a/tests/unit/app/endpoints/test_shields.py b/tests/unit/app/endpoints/test_shields.py new file mode 100644 index 00000000..1b05ce56 --- /dev/null +++ b/tests/unit/app/endpoints/test_shields.py @@ -0,0 +1,365 @@ +"""Unit tests for the /shields REST API endpoint.""" + +import pytest + +from fastapi import HTTPException, Request, status + +from llama_stack_client import APIConnectionError + +from app.endpoints.shields import shields_endpoint_handler +from configuration import AppConfig +from tests.unit.utils.auth_helpers import mock_authorization_resolvers + + 
+@pytest.mark.asyncio +async def test_shields_endpoint_handler_configuration_not_loaded(mocker): + """Test the shields endpoint handler if configuration is not loaded.""" + mock_authorization_resolvers(mocker) + + # simulate state when no configuration is loaded + mocker.patch( + "app.endpoints.shields.configuration", + return_value=mocker.Mock(), + ) + mocker.patch("app.endpoints.shields.configuration", None) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer invalid-token")], + } + ) + auth = ("user_id", "user_name", "token") + + with pytest.raises(HTTPException) as e: + await shields_endpoint_handler(request=request, auth=auth) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert e.value.detail["response"] == "Configuration is not loaded" + + +@pytest.mark.asyncio +async def test_shields_endpoint_handler_improper_llama_stack_configuration(mocker): + """Test the shields endpoint handler if Llama Stack configuration is not proper.""" + mock_authorization_resolvers(mocker) + + # configuration for tests + config_dict = { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "transcripts_enabled": False, + }, + "mcp_servers": [], + "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, + } + cfg = AppConfig() + cfg.init_from_dict(config_dict) + + mocker.patch( + "app.endpoints.shields.configuration", + return_value=None, + ) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer invalid-token")], + } + ) + auth = ("test_user", "token", {}) + with pytest.raises(HTTPException) as e: + await shields_endpoint_handler(request=request, auth=auth) + assert e.value.status_code == 
status.HTTP_500_INTERNAL_SERVER_ERROR + assert e.value.detail["response"] == "Llama stack is not configured" + + +@pytest.mark.asyncio +async def test_shields_endpoint_handler_configuration_loaded(mocker): + """Test the shields endpoint handler if configuration is loaded.""" + mock_authorization_resolvers(mocker) + + # configuration for tests + config_dict = { + "name": "foo", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "xyzzy", + "url": "http://x.y.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "feedback_enabled": False, + }, + "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, + } + cfg = AppConfig() + cfg.init_from_dict(config_dict) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer invalid-token")], + } + ) + auth = ("test_user", "token", {}) + + with pytest.raises(HTTPException) as e: + await shields_endpoint_handler(request=request, auth=auth) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert e.value.detail["response"] == "Unable to connect to Llama Stack" + + +@pytest.mark.asyncio +async def test_shields_endpoint_handler_unable_to_retrieve_shields_list(mocker): + """Test the shields endpoint handler if configuration is loaded.""" + mock_authorization_resolvers(mocker) + + # configuration for tests + config_dict = { + "name": "foo", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "xyzzy", + "url": "http://x.y.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "feedback_enabled": False, + }, + "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, + } + cfg = AppConfig() + 
cfg.init_from_dict(config_dict) + + # Mock the LlamaStack client + mock_client = mocker.AsyncMock() + mock_client.shields.list.return_value = [] + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") + mock_lsc.return_value = mock_client + mock_config = mocker.Mock() + mocker.patch("app.endpoints.shields.configuration", mock_config) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer invalid-token")], + } + ) + auth = ("test_user", "token", {}) + response = await shields_endpoint_handler(request=request, auth=auth) + assert response is not None + + +@pytest.mark.asyncio +async def test_shields_endpoint_llama_stack_connection_error(mocker): + """Test the shields endpoint when LlamaStack connection fails.""" + mock_authorization_resolvers(mocker) + + # configuration for tests + config_dict = { + "name": "foo", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "xyzzy", + "url": "http://x.y.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "feedback_enabled": False, + }, + "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, + } + + # mock AsyncLlamaStackClientHolder to raise APIConnectionError + # when shields.list() method is called + mock_client = mocker.AsyncMock() + mock_client.shields.list.side_effect = APIConnectionError(request=None) + mock_client_holder = mocker.patch( + "app.endpoints.shields.AsyncLlamaStackClientHolder" + ) + mock_client_holder.return_value.get_client.return_value = mock_client + + cfg = AppConfig() + cfg.init_from_dict(config_dict) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer invalid-token")], + } + ) + auth = ("test_user", "token", {}) + + with pytest.raises(HTTPException) as e: + await shields_endpoint_handler(request=request, 
auth=auth) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert e.value.detail["response"] == "Unable to connect to Llama Stack" + + +@pytest.mark.asyncio +async def test_shields_endpoint_handler_success_with_shields_data(mocker): + """Test the shields endpoint handler with successful response and shields data.""" + mock_authorization_resolvers(mocker) + + # configuration for tests + config_dict = { + "name": "foo", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "xyzzy", + "url": "http://x.y.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "feedback_enabled": False, + }, + "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, + } + cfg = AppConfig() + cfg.init_from_dict(config_dict) + + # Mock the LlamaStack client with sample shields data + mock_shields_data = [ + { + "identifier": "lightspeed_question_validity-shield", + "provider_resource_id": "lightspeed_question_validity-shield", + "provider_id": "lightspeed_question_validity", + "type": "shield", + "params": {}, + }, + { + "identifier": "content_filter-shield", + "provider_resource_id": "content_filter-shield", + "provider_id": "content_filter", + "type": "shield", + "params": {"threshold": 0.8}, + }, + ] + + mock_client = mocker.AsyncMock() + mock_client.shields.list.return_value = mock_shields_data + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") + mock_lsc.return_value = mock_client + mock_config = mocker.Mock() + mocker.patch("app.endpoints.shields.configuration", mock_config) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer invalid-token")], + } + ) + auth = ("test_user", "token", {}) + response = await shields_endpoint_handler(request=request, auth=auth) + + assert response is not None + assert 
hasattr(response, "shields") + assert len(response.shields) == 2 + assert response.shields[0]["identifier"] == "lightspeed_question_validity-shield" + assert response.shields[1]["identifier"] == "content_filter-shield" + + +@pytest.mark.asyncio +async def test_shields_endpoint_handler_general_exception(mocker): + """Test the shields endpoint handler when a general exception occurs.""" + mock_authorization_resolvers(mocker) + + # configuration for tests + config_dict = { + "name": "foo", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "xyzzy", + "url": "http://x.y.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "feedback_enabled": False, + }, + "customization": None, + "authorization": {"access_rules": []}, + "authentication": {"module": "noop"}, + } + cfg = AppConfig() + cfg.init_from_dict(config_dict) + + # Mock the LlamaStack client to raise a general exception + mock_client = mocker.AsyncMock() + mock_client.shields.list.side_effect = Exception("General error") + mock_client_holder = mocker.patch( + "app.endpoints.shields.AsyncLlamaStackClientHolder" + ) + mock_client_holder.return_value.get_client.return_value = mock_client + mock_config = mocker.Mock() + mocker.patch("app.endpoints.shields.configuration", mock_config) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer invalid-token")], + } + ) + auth = ("test_user", "token", {}) + + with pytest.raises(HTTPException) as e: + await shields_endpoint_handler(request=request, auth=auth) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert e.value.detail["response"] == "Unable to retrieve list of shields" + assert e.value.detail["cause"] == "General error" diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index aee36b7d..d1ef55a2 100644 --- a/tests/unit/app/test_routers.py +++ 
b/tests/unit/app/test_routers.py @@ -12,6 +12,7 @@ root, info, models, + shields, query, health, config, @@ -61,10 +62,11 @@ def test_include_routers() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 12 + assert len(app.routers) == 13 assert root.router in app.get_routers() assert info.router in app.get_routers() assert models.router in app.get_routers() + assert shields.router in app.get_routers() assert query.router in app.get_routers() assert streaming_query.router in app.get_routers() assert config.router in app.get_routers() @@ -81,10 +83,11 @@ def test_check_prefixes() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 12 + assert len(app.routers) == 13 assert app.get_router_prefix(root.router) == "" assert app.get_router_prefix(info.router) == "/v1" assert app.get_router_prefix(models.router) == "/v1" + assert app.get_router_prefix(shields.router) == "/v1" assert app.get_router_prefix(query.router) == "/v1" assert app.get_router_prefix(streaming_query.router) == "/v1" assert app.get_router_prefix(config.router) == "/v1"