From 2c46aa9cbf13efe36a8b4b2210fb9e7ef9a8064c Mon Sep 17 00:00:00 2001
From: Lucas Alvares Gomes
Date: Mon, 26 May 2025 13:49:37 +0100
Subject: [PATCH] Address /query endpoint compatibility

This patch adds compatibility with the /query endpoint from OLS-service.
For the request itself, everything should now be identical with the
exception of the "media_type" field, which has not been included because
the current implementation does not yet support streaming responses.

The response to a query is still missing a number of fields. These should
be added as more features land in Lightspeed Core; at the moment, only the
basic response is returned. TODOs were left in the code pointing to those
gaps.

Signed-off-by: Lucas Alvares Gomes
---
 .gitignore                             |   1 +
 pdm.lock                               |  16 ++-
 pyproject.toml                         |   1 +
 src/app/endpoints/query.py             | 123 +++++++++++++----
 src/constants.py                       |  21 +++
 src/models/requests.py                 | 127 ++++++++++++++++++
 src/models/responses.py                |  34 ++++-
 tests/unit/app/endpoints/test_query.py | 174 +++++++++++++++++++++++++
 tests/unit/models/test_requests.py     | 125 ++++++++++++++++++
 tests/unit/models/test_responses.py    |  20 +++
 10 files changed, 616 insertions(+), 26 deletions(-)
 create mode 100644 src/constants.py
 create mode 100644 src/models/requests.py
 create mode 100644 tests/unit/app/endpoints/test_query.py
 create mode 100644 tests/unit/models/test_requests.py
 create mode 100644 tests/unit/models/test_responses.py

diff --git a/.gitignore b/.gitignore
index 3b58d13f..4ae1304e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,6 +50,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 cover/
+tests/test_results/
 
 # Translations
 *.mo
diff --git a/pdm.lock b/pdm.lock
index 92f68419..5de43914 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "dev"]
 strategy = ["inherit_metadata"]
 lock_version = "4.5.0"
-content_hash = "sha256:db82049ff8c8d98dacd64aa05d47b871510a406ed837a638ea81b30eeece7ab5"
+content_hash = "sha256:f3dde2e916169abc41f23400e9488ae43f2eeb203c4dbb2e505ba28b9a853677"
 
 [[metadata.targets]]
 requires_python = ">=3.11.1,<=3.12.10"
@@ -1061,6 +1061,20 @@ files = [
     {file = "pytest_cov-6.1.1.tar.gz", hash = "sha256:46935f7aaefba760e716c2ebfbe1c216240b9592966e7da99ea8292d4d3e2a0a"},
 ]
 
+[[package]]
+name = "pytest-mock"
+version = "3.14.1"
+requires_python = ">=3.8"
+summary = "Thin-wrapper around the mock package for easier use with pytest"
+groups = ["dev"]
+dependencies = [
+    "pytest>=6.2.5",
+]
+files = [
+    {file = "pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0"},
+    {file = "pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e"},
+]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"
diff --git a/pyproject.toml b/pyproject.toml
index b5f751c9..e41a71a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ dev = [
     "black>=25.1.0",
     "pytest>=8.3.2",
     "pytest-cov>=5.0.0",
+    "pytest-mock>=3.14.0",
 ]
 
 [tool.pdm.scripts]
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index abe2eaa9..bea2162d 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -7,46 +7,82 @@
 from llama_stack_client import LlamaStackClient  # type: ignore
 from llama_stack_client.types import UserMessage  # type: ignore
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, Request, HTTPException, status
 from client import get_llama_stack_client
 from configuration import configuration
from models.responses import QueryResponse +from models.requests import QueryRequest, Attachment +import constants logger = logging.getLogger("app.endpoints.handlers") -router = APIRouter(tags=["models"]) +router = APIRouter(tags=["query"]) query_response: dict[int | str, dict[str, Any]] = { 200: { - "query": "User query", - "answer": "LLM ansert", + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "response": "LLM ansert", }, } @router.post("/query", responses=query_response) -def query_endpoint_handler(request: Request, query: str) -> QueryResponse: +def query_endpoint_handler( + request: Request, query_request: QueryRequest +) -> QueryResponse: llama_stack_config = configuration.llama_stack_configuration logger.info("LLama stack config: %s", llama_stack_config) - client = get_llama_stack_client(llama_stack_config) - - # retrieve list of available models - models = client.models.list() - - # select the first LLM - llm = next(m for m in models if m.model_type == "llm") - model_id = llm.identifier - - logger.info("Model: %s", model_id) - - response = retrieve_response(client, model_id, query) - - return QueryResponse(query=query, response=response) + model_id = select_model_id(client, query_request) + response = retrieve_response(client, model_id, query_request) + return QueryResponse( + conversation_id=query_request.conversation_id, response=response + ) -def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> str: +def select_model_id(client: LlamaStackClient, query_request: QueryRequest) -> str: + """Select the model ID based on the request or available models.""" + models = client.models.list() + model_id = query_request.model + provider_id = query_request.provider + + # TODO(lucasagomes): support default model selection via configuration + if not model_id: + logger.info("No model specified in request, using the first available LLM") + try: + return next(m for m in models if m.model_type == "llm").identifier + except (StopIteration, AttributeError): + message = "No LLM model found in available models" + logger.error(message) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "response": constants.UNABLE_TO_PROCESS_RESPONSE, + "cause": message, + }, + ) + + logger.info(f"Searching for model: {model_id}, provider: {provider_id}") + if not any( + m.identifier == model_id and m.provider_id == provider_id for m in models + ): + message = f"Model {model_id} from provider {provider_id} not found in available models" + logger.error(message) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "response": constants.UNABLE_TO_PROCESS_RESPONSE, + "cause": message, + }, + ) + + return model_id + + +def retrieve_response( + client: LlamaStackClient, model_id: str, query_request: QueryRequest +) -> str: available_shields = [shield.identifier for shield in client.shields.list()] if not available_shields: @@ -54,18 +90,61 @@ def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> s else: logger.info(f"Available shields found: {available_shields}") + # use system prompt from request or default one + system_prompt = ( + query_request.system_prompt + if query_request.system_prompt + else constants.DEFAULT_SYSTEM_PROMPT + ) + logger.debug(f"Using system prompt: {system_prompt}") + + # TODO(lucasagomes): redact attachments content before sending to LLM + # if attachments are provided, validate them + if query_request.attachments: + validate_attachments_metadata(query_request.attachments) + agent 
= Agent( client, model=model_id, - instructions="You are a helpful assistant", + instructions=system_prompt, input_shields=available_shields if available_shields else [], tools=[], ) session_id = agent.create_session("chat_session") response = agent.create_turn( - messages=[UserMessage(role="user", content=prompt)], + messages=[UserMessage(role="user", content=query_request.query)], session_id=session_id, + documents=query_request.get_documents(), stream=False, ) return str(response.output_message.content) + + +def validate_attachments_metadata(attachments: list[Attachment]) -> None: + """Validate the attachments metadata provided in the request. + Raises HTTPException if any attachment has an improper type or content type. + """ + for attachment in attachments: + if attachment.attachment_type not in constants.ATTACHMENT_TYPES: + message = ( + f"Attachment with improper type {attachment.attachment_type} detected" + ) + logger.error(message) + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail={ + "response": constants.UNABLE_TO_PROCESS_RESPONSE, + "cause": message, + }, + ) + if attachment.content_type not in constants.ATTACHMENT_CONTENT_TYPES: + message = f"Attachment with improper content type {attachment.content_type} detected" + logger.error(message) + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail={ + "response": constants.UNABLE_TO_PROCESS_RESPONSE, + "cause": message, + }, + ) diff --git a/src/constants.py b/src/constants.py new file mode 100644 index 00000000..5699962e --- /dev/null +++ b/src/constants.py @@ -0,0 +1,21 @@ +UNABLE_TO_PROCESS_RESPONSE = "Unable to process this request" + +# Supported attachment types +ATTACHMENT_TYPES = frozenset( + { + "alert", + "api object", + "configuration", + "error message", + "event", + "log", + "stack trace", + } +) + +# Supported attachment content types +ATTACHMENT_CONTENT_TYPES = frozenset( + {"text/plain", "application/json", "application/yaml", "application/xml"} +) + +DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant" diff --git a/src/models/requests.py b/src/models/requests.py new file mode 100644 index 00000000..bddd6b79 --- /dev/null +++ b/src/models/requests.py @@ -0,0 +1,127 @@ +from pydantic import BaseModel, model_validator +from llama_stack_client.types.agents.turn_create_params import Document +from typing import Optional, Self + + +class Attachment(BaseModel): + """Model representing an attachment that can be send from UI as part of query. + + List of attachments can be optional part of 'query' request. + + Attributes: + attachment_type: The attachment type, like "log", "configuration" etc. 
+ content_type: The content type as defined in MIME standard + content: The actual attachment content + + YAML attachments with **kind** and **metadata/name** attributes will + be handled as resources with specified name: + ``` + kind: Pod + metadata: + name: private-reg + ``` + """ + + attachment_type: str + content_type: str + content: str + + # provides examples for /docs endpoint + model_config = { + "json_schema_extra": { + "examples": [ + { + "attachment_type": "log", + "content_type": "text/plain", + "content": "this is attachment", + }, + { + "attachment_type": "configuration", + "content_type": "application/yaml", + "content": "kind: Pod\n metadata:\n name: private-reg", + }, + { + "attachment_type": "configuration", + "content_type": "application/yaml", + "content": "foo: bar", + }, + ] + } + } + + +# TODO(lucasagomes): add media_type when needed, current implementation +# does not support streaming response, so this is not used +class QueryRequest(BaseModel): + """Model representing a request for the LLM (Language Model). + + Attributes: + query: The query string. + conversation_id: The optional conversation ID (UUID). + provider: The optional provider. + model: The optional model. + attachments: The optional attachments. + + Example: + ```python + query_request = QueryRequest(query="Tell me about Kubernetes") + ``` + """ + + query: str + conversation_id: Optional[str] = None + provider: Optional[str] = None + model: Optional[str] = None + system_prompt: Optional[str] = None + attachments: Optional[list[Attachment]] = None + + # provides examples for /docs endpoint + model_config = { + "extra": "forbid", + "json_schema_extra": { + "examples": [ + { + "query": "write a deployment yaml for the mongodb image", + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "provider": "openai", + "model": "model-name", + "system_prompt": "You are a helpful assistant", + "attachments": [ + { + "attachment_type": "log", + "content_type": "text/plain", + "content": "this is attachment", + }, + { + "attachment_type": "configuration", + "content_type": "application/yaml", + "content": "kind: Pod\n metadata:\n name: private-reg", + }, + { + "attachment_type": "configuration", + "content_type": "application/yaml", + "content": "foo: bar", + }, + ], + } + ] + }, + } + + def get_documents(self) -> list[Document]: + """Returns the list of documents from the attachments.""" + if not self.attachments: + return [] + return [ + Document(content=att.content, mime_type=att.content_type) + for att in self.attachments + ] + + @model_validator(mode="after") + def validate_provider_and_model(self) -> Self: + """Perform validation on the provider and model.""" + if self.model and not self.provider: + raise ValueError("Provider must be specified if model is specified") + if self.provider and not self.model: + raise ValueError("Model must be specified if provider is specified") + return self diff --git a/src/models/responses.py b/src/models/responses.py index 5aad89ce..b1398269 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -1,5 +1,5 @@ from pydantic import BaseModel -from typing import Any +from typing import Any, Optional class ModelsResponse(BaseModel): @@ -8,12 +8,40 @@ class ModelsResponse(BaseModel): models: list[dict[str, Any]] +# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now +# we are keeping it simple. The missing fields are: +# - referenced_documents: The optional URLs and titles for the documents used +# to generate the response. 
+# - truncated: Set to True if conversation history was truncated to be within context window. +# - input_tokens: Number of tokens sent to LLM +# - output_tokens: Number of tokens received from LLM +# - available_quotas: Quota available as measured by all configured quota limiters +# - tool_calls: List of tool requests. +# - tool_results: List of tool results. +# See LLMResponse in ols-service for more details. class QueryResponse(BaseModel): - """Model representing LLM response to a query.""" + """Model representing LLM response to a query. - query: str + Attributes: + conversation_id: The optional conversation ID (UUID). + response: The response. + """ + + conversation_id: Optional[str] = None response: str + # provides examples for /docs endpoint + model_config = { + "json_schema_extra": { + "examples": [ + { + "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "response": "Operator Lifecycle Manager (OLM) helps users install...", + } + ] + } + } + class InfoResponse(BaseModel): """Model representing a response to a info request. diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py new file mode 100644 index 00000000..db9698e7 --- /dev/null +++ b/tests/unit/app/endpoints/test_query.py @@ -0,0 +1,174 @@ +from fastapi import HTTPException, status +import pytest + +from app.endpoints.query import ( + query_endpoint_handler, + select_model_id, + retrieve_response, + validate_attachments_metadata, +) +from models.requests import QueryRequest, Attachment +from llama_stack_client.types import UserMessage # type: ignore + + +def test_query_endpoint_handler(mocker): + """Test the query endpoint handler.""" + mock_client = mocker.Mock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"), + ] + + mocker.patch( + "app.endpoints.query.configuration", + return_value=mocker.Mock(), + ) + mocker.patch("app.endpoints.query.get_llama_stack_client", return_value=mock_client) + mocker.patch("app.endpoints.query.retrieve_response", return_value="LLM answer") + mocker.patch("app.endpoints.query.select_model_id", return_value="fake_model_id") + + query_request = QueryRequest(query="What is OpenStack?") + + response = query_endpoint_handler(None, query_request) + + assert response.response == "LLM answer" + + +def test_select_model_id(mocker): + """Test the select_model_id function.""" + mock_client = mocker.Mock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"), + ] + + query_request = QueryRequest( + query="What is OpenStack?", model="model1", provider="provider1" + ) + + model_id = select_model_id(mock_client, query_request) + + assert model_id == "model1" + + +def test_select_model_id_no_model(mocker): + """Test the select_model_id function when no model is specified.""" + mock_client = mocker.Mock() + mock_client.models.list.return_value = [ + mocker.Mock( + identifier="not_llm_type", model_type="embedding", provider_id="provider1" + ), + mocker.Mock( + identifier="first_model", model_type="llm", provider_id="provider1" + ), + mocker.Mock( + identifier="second_model", model_type="llm", provider_id="provider2" + ), + ] + + query_request = QueryRequest(query="What is OpenStack?") + + model_id = select_model_id(mock_client, query_request) + + # Assert return the 
first available LLM model + assert model_id == "first_model" + + +def test_select_model_id_invalid_model(mocker): + """Test the select_model_id function with an invalid model.""" + mock_client = mocker.Mock() + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + ] + + query_request = QueryRequest( + query="What is OpenStack?", model="invalid_model", provider="provider1" + ) + + with pytest.raises(Exception) as exc_info: + select_model_id(mock_client, query_request) + + assert ( + "Model invalid_model from provider provider1 not found in available models" + in str(exc_info.value) + ) + + +def test_validate_attachments_metadata(): + """Test the validate_attachments_metadata function.""" + attachments = [ + Attachment( + attachment_type="log", + content_type="text/plain", + content="this is attachment", + ), + Attachment( + attachment_type="configuration", + content_type="application/yaml", + content="kind: Pod\n metadata:\n name: private-reg", + ), + ] + + # If no exception is raised, the test passes + validate_attachments_metadata(attachments) + + +def test_validate_attachments_metadata_invalid_type(): + """Test the validate_attachments_metadata function with invalid attachment type.""" + attachments = [ + Attachment( + attachment_type="invalid_type", + content_type="text/plain", + content="this is attachment", + ), + ] + + with pytest.raises(HTTPException) as exc_info: + validate_attachments_metadata(attachments) + assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert ( + "Attachment with improper type invalid_type detected" + in exc_info.value.detail["cause"] + ) + + +def test_validate_attachments_metadata_invalid_content_type(): + """Test the validate_attachments_metadata function with invalid attachment type.""" + attachments = [ + Attachment( + attachment_type="log", + content_type="text/invalid_content_type", + content="this is attachment", + ), + ] + + with pytest.raises(HTTPException) as exc_info: + validate_attachments_metadata(attachments) + assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert ( + "Attachment with improper content type text/invalid_content_type detected" + in exc_info.value.detail["cause"] + ) + + +def test_retrieve_response(mocker): + """Test the retrieve_response function.""" + mock_agent = mocker.Mock() + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client = mocker.Mock() + mock_client.shields.list.return_value = [] + + mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + + response = retrieve_response(mock_client, model_id, query_request) + + assert response == "LLM answer" + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], + session_id=mocker.ANY, + documents=[], + stream=False, + ) diff --git a/tests/unit/models/test_requests.py b/tests/unit/models/test_requests.py new file mode 100644 index 00000000..a9db072f --- /dev/null +++ b/tests/unit/models/test_requests.py @@ -0,0 +1,125 @@ +import pytest + +from models.requests import QueryRequest, Attachment + + +class TestAttachment: + """Test cases for the Attachment model.""" + + def test_constructor(self) -> None: + """Test the Attachment with custom values.""" + a = Attachment( + attachment_type="configuration", + content_type="application/yaml", + content="kind: 
Pod\n metadata:\n name: private-reg", + ) + assert a.attachment_type == "configuration" + assert a.content_type == "application/yaml" + assert a.content == "kind: Pod\n metadata:\n name: private-reg" + + +class TestQueryRequest: + """Test cases for the QueryRequest model.""" + + def test_constructor(self) -> None: + """Test the QueryRequest constructor.""" + qr = QueryRequest(query="Tell me about Kubernetes") + + assert qr.query == "Tell me about Kubernetes" + assert qr.conversation_id is None + assert qr.provider is None + assert qr.model is None + assert qr.system_prompt is None + assert qr.attachments is None + + def test_with_attachments(self) -> None: + """Test the QueryRequest with attachments.""" + attachments = [ + Attachment( + attachment_type="log", + content_type="text/plain", + content="this is attachment", + ), + Attachment( + attachment_type="configuration", + content_type="application/yaml", + content="kind: Pod\n metadata:\n name: private-reg", + ), + ] + qr = QueryRequest( + query="Tell me about Kubernetes", + attachments=attachments, + ) + assert len(qr.attachments) == 2 + assert qr.attachments[0].attachment_type == "log" + assert qr.attachments[0].content_type == "text/plain" + assert qr.attachments[0].content == "this is attachment" + assert qr.attachments[1].attachment_type == "configuration" + assert qr.attachments[1].content_type == "application/yaml" + assert ( + qr.attachments[1].content == "kind: Pod\n metadata:\n name: private-reg" + ) + + def test_with_optional_fields(self) -> None: + """Test the QueryRequest with optional fields.""" + qr = QueryRequest( + query="Tell me about Kubernetes", + conversation_id="123e4567-e89b-12d3-a456-426614174000", + provider="OpenAI", + model="gpt-3.5-turbo", + system_prompt="You are a helpful assistant", + ) + assert qr.query == "Tell me about Kubernetes" + assert qr.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert qr.provider == "OpenAI" + assert qr.model == "gpt-3.5-turbo" + assert qr.system_prompt == "You are a helpful assistant" + assert qr.attachments is None + + def test_get_documents(self) -> None: + """Test the get_documents method.""" + attachments = [ + Attachment( + attachment_type="log", + content_type="text/plain", + content="this is attachment", + ), + Attachment( + attachment_type="configuration", + content_type="application/yaml", + content="kind: Pod\n metadata:\n name: private-reg", + ), + ] + qr = QueryRequest( + query="Tell me about Kubernetes", + attachments=attachments, + ) + documents = qr.get_documents() + assert len(documents) == 2 + assert documents[0]["content"] == "this is attachment" + assert documents[0]["mime_type"] == "text/plain" + assert documents[1]["content"] == "kind: Pod\n metadata:\n name: private-reg" + assert documents[1]["mime_type"] == "application/yaml" + + def test_validate_provider_and_model(self) -> None: + """Test the validate_provider_and_model method.""" + qr = QueryRequest( + query="Tell me about Kubernetes", + provider="OpenAI", + model="gpt-3.5-turbo", + ) + validated_qr = qr.validate_provider_and_model() + assert validated_qr.provider == "OpenAI" + assert validated_qr.model == "gpt-3.5-turbo" + + # Test with missing provider + with pytest.raises( + ValueError, match="Provider must be specified if model is specified" + ): + QueryRequest(query="Tell me about Kubernetes", model="gpt-3.5-turbo") + + # Test with missing model + with pytest.raises( + ValueError, match="Model must be specified if provider is specified" + ): + QueryRequest(query="Tell me about 
Kubernetes", provider="OpenAI") diff --git a/tests/unit/models/test_responses.py b/tests/unit/models/test_responses.py new file mode 100644 index 00000000..b9db5f29 --- /dev/null +++ b/tests/unit/models/test_responses.py @@ -0,0 +1,20 @@ +from models.responses import QueryResponse + + +class TestQueryResponse: + """Test cases for the QueryResponse model.""" + + def test_constructor(self) -> None: + """Test the QueryResponse constructor.""" + qr = QueryResponse( + conversation_id="123e4567-e89b-12d3-a456-426614174000", + response="LLM answer", + ) + assert qr.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert qr.response == "LLM answer" + + def test_optional_conversation_id(self) -> None: + """Test the QueryResponse with default conversation ID.""" + qr = QueryResponse(response="LLM answer") + assert qr.conversation_id is None + assert qr.response == "LLM answer"
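
For reviewers, a minimal sketch of how a client could exercise the new /query endpoint once this patch is applied. The base URL/port and the use of httpx are illustrative assumptions rather than part of the change; the payload mirrors the QueryRequest examples above, and the reply follows the new QueryResponse model (validation failures come back as HTTP 400/422 with a detail object carrying "response" and "cause").

```python
# Illustrative sketch only: the host, port and httpx dependency are assumptions.
import httpx

payload = {
    "query": "write a deployment yaml for the mongodb image",
    # provider and model must be supplied together (enforced by QueryRequest)
    "provider": "openai",
    "model": "model-name",
    "system_prompt": "You are a helpful assistant",
    "attachments": [
        {
            # attachment_type must be in constants.ATTACHMENT_TYPES,
            # content_type in constants.ATTACHMENT_CONTENT_TYPES
            "attachment_type": "log",
            "content_type": "text/plain",
            "content": "this is attachment",
        }
    ],
}

resp = httpx.post("http://localhost:8080/query", json=payload, timeout=60.0)
resp.raise_for_status()
data = resp.json()
# QueryResponse shape: {"conversation_id": <uuid or null>, "response": "<LLM answer>"}
print(data["response"])
```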