1 change: 1 addition & 0 deletions .gitignore
@@ -50,6 +50,7 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
tests/test_results/

# Translations
*.mo
16 changes: 15 additions & 1 deletion pdm.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ dev = [
"black>=25.1.0",
"pytest>=8.3.2",
"pytest-cov>=5.0.0",
"pytest-mock>=3.14.0",
]

[tool.pdm.scripts]
123 changes: 101 additions & 22 deletions src/app/endpoints/query.py
@@ -7,65 +7,144 @@
from llama_stack_client import LlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore

from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, HTTPException, status

from client import get_llama_stack_client
from configuration import configuration
from models.responses import QueryResponse
from models.requests import QueryRequest, Attachment
import constants

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["models"])
router = APIRouter(tags=["query"])


query_response: dict[int | str, dict[str, Any]] = {
200: {
"query": "User query",
"answer": "LLM ansert",
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "LLM ansert",
},
}


@router.post("/query", responses=query_response)
def query_endpoint_handler(request: Request, query: str) -> QueryResponse:
def query_endpoint_handler(
request: Request, query_request: QueryRequest
) -> QueryResponse:
llama_stack_config = configuration.llama_stack_configuration
logger.info("LLama stack config: %s", llama_stack_config)

client = get_llama_stack_client(llama_stack_config)

# retrieve list of available models
models = client.models.list()

# select the first LLM
llm = next(m for m in models if m.model_type == "llm")
model_id = llm.identifier

logger.info("Model: %s", model_id)

response = retrieve_response(client, model_id, query)

return QueryResponse(query=query, response=response)
model_id = select_model_id(client, query_request)
response = retrieve_response(client, model_id, query_request)
return QueryResponse(
conversation_id=query_request.conversation_id, response=response
)


def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> str:
def select_model_id(client: LlamaStackClient, query_request: QueryRequest) -> str:
"""Select the model ID based on the request or available models."""
models = client.models.list()
model_id = query_request.model
provider_id = query_request.provider

# TODO(lucasagomes): support default model selection via configuration
if not model_id:
logger.info("No model specified in request, using the first available LLM")
try:
return next(m for m in models if m.model_type == "llm").identifier
except (StopIteration, AttributeError):
message = "No LLM model found in available models"
logger.error(message)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"response": constants.UNABLE_TO_PROCESS_RESPONSE,
"cause": message,
},
)

logger.info(f"Searching for model: {model_id}, provider: {provider_id}")
if not any(
m.identifier == model_id and m.provider_id == provider_id for m in models
):
message = f"Model {model_id} from provider {provider_id} not found in available models"
logger.error(message)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"response": constants.UNABLE_TO_PROCESS_RESPONSE,
"cause": message,
},
)

return model_id


def retrieve_response(
client: LlamaStackClient, model_id: str, query_request: QueryRequest
) -> str:

available_shields = [shield.identifier for shield in client.shields.list()]
if not available_shields:
logger.info("No available shields. Disabling safety")
else:
logger.info(f"Available shields found: {available_shields}")

# use system prompt from request or default one
system_prompt = (
query_request.system_prompt
if query_request.system_prompt
else constants.DEFAULT_SYSTEM_PROMPT
)
logger.debug(f"Using system prompt: {system_prompt}")

# TODO(lucasagomes): redact attachments content before sending to LLM
# if attachments are provided, validate them
if query_request.attachments:
validate_attachments_metadata(query_request.attachments)

agent = Agent(
client,
model=model_id,
instructions="You are a helpful assistant",
instructions=system_prompt,
input_shields=available_shields if available_shields else [],
tools=[],
)
session_id = agent.create_session("chat_session")
response = agent.create_turn(
messages=[UserMessage(role="user", content=prompt)],
messages=[UserMessage(role="user", content=query_request.query)],
session_id=session_id,
documents=query_request.get_documents(),
stream=False,
)

return str(response.output_message.content)


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
"""Validate the attachments metadata provided in the request.
Raises HTTPException if any attachment has an improper type or content type.
"""
for attachment in attachments:
if attachment.attachment_type not in constants.ATTACHMENT_TYPES:
message = (
f"Attachment with improper type {attachment.attachment_type} detected"
)
logger.error(message)
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail={
"response": constants.UNABLE_TO_PROCESS_RESPONSE,
"cause": message,
},
)
if attachment.content_type not in constants.ATTACHMENT_CONTENT_TYPES:
message = f"Attachment with improper content type {attachment.content_type} detected"
logger.error(message)
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail={
"response": constants.UNABLE_TO_PROCESS_RESPONSE,
"cause": message,
},
)
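
Since this PR adds `pytest-mock` to the dev dependencies, the new `select_model_id` helper can be exercised without a running llama-stack server. The sketch below is not part of the PR; the import paths and the fake model objects are assumptions based on the src layout shown in the diff above.

```python
from types import SimpleNamespace

import pytest
from fastapi import HTTPException

from app.endpoints.query import select_model_id
from models.requests import QueryRequest


def test_select_model_id_picks_first_llm(mocker):
    # Fake llama-stack model entry; only the attributes read by
    # select_model_id are provided.
    llm = SimpleNamespace(identifier="llm-1", provider_id="openai", model_type="llm")
    client = mocker.Mock()
    client.models.list.return_value = [llm]

    # No model in the request, so the first available LLM is used.
    assert select_model_id(client, QueryRequest(query="hello")) == "llm-1"


def test_select_model_id_unknown_model_returns_400(mocker):
    client = mocker.Mock()
    client.models.list.return_value = []

    # A model/provider pair that is not in the model list raises HTTP 400.
    with pytest.raises(HTTPException) as exc_info:
        select_model_id(
            client, QueryRequest(query="hello", provider="openai", model="missing")
        )

    assert exc_info.value.status_code == 400
```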
21 changes: 21 additions & 0 deletions src/constants.py
@@ -0,0 +1,21 @@
UNABLE_TO_PROCESS_RESPONSE = "Unable to process this request"

# Supported attachment types
ATTACHMENT_TYPES = frozenset(
{
"alert",
"api object",
"configuration",
"error message",
"event",
"log",
"stack trace",
}
)

# Supported attachment content types
ATTACHMENT_CONTENT_TYPES = frozenset(
{"text/plain", "application/json", "application/yaml", "application/xml"}
)

DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant"
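
These frozensets are what `validate_attachments_metadata` checks against; matching is exact and case-sensitive, so clients must send the lowercase type strings and full MIME strings listed here. A quick illustration, not part of the PR:

```python
from constants import ATTACHMENT_TYPES, ATTACHMENT_CONTENT_TYPES

assert "log" in ATTACHMENT_TYPES
assert "Log" not in ATTACHMENT_TYPES  # matching is case-sensitive
assert "application/yaml" in ATTACHMENT_CONTENT_TYPES
assert "text/markdown" not in ATTACHMENT_CONTENT_TYPES  # would be rejected with 422
```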
127 changes: 127 additions & 0 deletions src/models/requests.py
@@ -0,0 +1,127 @@
from pydantic import BaseModel, model_validator
from llama_stack_client.types.agents.turn_create_params import Document
from typing import Optional, Self


class Attachment(BaseModel):
"""Model representing an attachment that can be send from UI as part of query.

List of attachments can be optional part of 'query' request.

Attributes:
attachment_type: The attachment type, like "log", "configuration" etc.
content_type: The content type as defined in MIME standard
content: The actual attachment content

YAML attachments with **kind** and **metadata/name** attributes will
be handled as resources with the specified name:
```
kind: Pod
metadata:
name: private-reg
```
"""

attachment_type: str
content_type: str
content: str

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
"examples": [
{
"attachment_type": "log",
"content_type": "text/plain",
"content": "this is attachment",
},
{
"attachment_type": "configuration",
"content_type": "application/yaml",
"content": "kind: Pod\n metadata:\n name: private-reg",
},
{
"attachment_type": "configuration",
"content_type": "application/yaml",
"content": "foo: bar",
},
]
}
}


# TODO(lucasagomes): add media_type when needed, current implementation
# does not support streaming response, so this is not used
class QueryRequest(BaseModel):
"""Model representing a request for the LLM (Language Model).

Attributes:
query: The query string.
conversation_id: The optional conversation ID (UUID).
provider: The optional provider.
model: The optional model.
attachments: The optional attachments.

Example:
```python
query_request = QueryRequest(query="Tell me about Kubernetes")
```
"""

query: str
conversation_id: Optional[str] = None
provider: Optional[str] = None
model: Optional[str] = None
system_prompt: Optional[str] = None
attachments: Optional[list[Attachment]] = None

# provides examples for /docs endpoint
model_config = {
"extra": "forbid",
"json_schema_extra": {
"examples": [
{
"query": "write a deployment yaml for the mongodb image",
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"provider": "openai",
"model": "model-name",
"system_prompt": "You are a helpful assistant",
"attachments": [
{
"attachment_type": "log",
"content_type": "text/plain",
"content": "this is attachment",
},
{
"attachment_type": "configuration",
"content_type": "application/yaml",
"content": "kind: Pod\n metadata:\n name: private-reg",
},
{
"attachment_type": "configuration",
"content_type": "application/yaml",
"content": "foo: bar",
},
],
}
]
},
}

def get_documents(self) -> list[Document]:
"""Returns the list of documents from the attachments."""
if not self.attachments:
return []
return [
Document(content=att.content, mime_type=att.content_type)
for att in self.attachments
]

@model_validator(mode="after")
def validate_provider_and_model(self) -> Self:
"""Perform validation on the provider and model."""
if self.model and not self.provider:
raise ValueError("Provider must be specified if model is specified")
if self.provider and not self.model:
raise ValueError("Model must be specified if provider is specified")
return self
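
To show how the new request model behaves, here is a minimal usage sketch; it is not part of the PR, and the import paths follow the src layout used in this diff:

```python
from pydantic import ValidationError

from models.requests import Attachment, QueryRequest

request = QueryRequest(
    query="write a deployment yaml for the mongodb image",
    provider="openai",
    model="model-name",
    attachments=[
        Attachment(
            attachment_type="log",
            content_type="text/plain",
            content="this is attachment",
        )
    ],
)

# Attachments are converted into llama-stack Document objects that
# retrieve_response() forwards to Agent.create_turn(documents=...).
assert len(request.get_documents()) == 1

# The model/provider validator rejects one without the other.
try:
    QueryRequest(query="hi", model="model-name")
except ValidationError as err:
    print(err)  # surfaces "Provider must be specified if model is specified"
```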