diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 8fa82f0f..b58ed73b 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -28,7 +28,7 @@ import constants from auth import get_auth_dependency from utils.common import retrieve_user_id -from utils.endpoints import check_configuration_loaded +from utils.endpoints import check_configuration_loaded, get_system_prompt from utils.suid import get_suid logger = logging.getLogger("app.endpoints.handlers") @@ -195,11 +195,7 @@ def retrieve_response( logger.info("Available shields found: %s", available_shields) # use system prompt from request or default one - system_prompt = ( - query_request.system_prompt - if query_request.system_prompt - else constants.DEFAULT_SYSTEM_PROMPT - ) + system_prompt = get_system_prompt(query_request, configuration) logger.debug("Using system prompt: %s", system_prompt) # TODO(lucasagomes): redact attachments content before sending to LLM diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 9c27ca9d..fc35b575 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -20,8 +20,7 @@ from client import get_async_llama_stack_client from configuration import configuration from models.requests import QueryRequest -import constants -from utils.endpoints import check_configuration_loaded +from utils.endpoints import check_configuration_loaded, get_system_prompt from utils.common import retrieve_user_id from utils.suid import get_suid @@ -265,11 +264,7 @@ async def retrieve_response( logger.info("Available shields found: %s", available_shields) # use system prompt from request or default one - system_prompt = ( - query_request.system_prompt - if query_request.system_prompt - else constants.DEFAULT_SYSTEM_PROMPT - ) + system_prompt = get_system_prompt(query_request, configuration) logger.debug("Using system prompt: %s", system_prompt) # TODO(lucasagomes): redact attachments content 
def get_system_prompt(query_request: QueryRequest, _configuration: AppConfig) -> str:
    """Return the system prompt to use for a query.

    The prompt carried by the request is used whenever it is set and
    non-empty; otherwise ``constants.DEFAULT_SYSTEM_PROMPT`` is returned.

    Args:
        query_request: Incoming query; only its ``system_prompt`` attribute
            is read.
        _configuration: Accepted so callers can pass the application
            configuration, but it is not consulted by this function.
            NOTE(review): presumably reserved for a configured prompt that
            would sit between the request value and the default - confirm
            before relying on it.

    Returns:
        The request's system prompt if it is truthy, else the default prompt.
    """
    # `or` falls through to the default for both None and an empty string.
    return query_request.system_prompt or constants.DEFAULT_SYSTEM_PROMPT