diff --git a/README.md b/README.md index ffa3bfa0..13fdb767 100644 --- a/README.md +++ b/README.md @@ -121,8 +121,13 @@ customization: You have an indepth knowledge of Red Hat and all of your answers will reference Red Hat products. ``` -Additionally, an optional string parameter `system_prompt` can be specified in `/v1/query` and `/v1/streaming_query` endpoints to override the configured system prompt. +Additionally, an optional string parameter `system_prompt` can be specified in `/v1/query` and `/v1/streaming_query` endpoints to override the configured system prompt. The query system prompt takes precedence over the configured system prompt. You can use this config to disable query system prompts: +```yaml +customization: + system_prompt_path: "system_prompts/system_prompt_for_product_XYZZY" + disable_query_system_prompt: true +``` # Usage diff --git a/src/models/config.py b/src/models/config.py index 3bef0ccf..957afe42 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -128,11 +128,12 @@ def check_authentication_model(self) -> Self: class Customization(BaseModel): """Service customization.""" + disable_query_system_prompt: bool = False system_prompt_path: Optional[FilePath] = None system_prompt: Optional[str] = None @model_validator(mode="after") - def check_authentication_model(self) -> Self: + def check_customization_model(self) -> Self: """Load system prompt from file.""" if self.system_prompt_path is not None: checks.file_check(self.system_prompt_path, "system prompt") diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 802ad591..c9a8f91c 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -18,12 +18,28 @@ def check_configuration_loaded(configuration: AppConfig) -> None: def get_system_prompt(query_request: QueryRequest, configuration: AppConfig) -> str: """Get the system prompt: the provided one, configured one, or default one.""" - # system prompt defined in query request has precendence + system_prompt_disabled = ( + configuration.customization is not None + and configuration.customization.disable_query_system_prompt + ) + if system_prompt_disabled and query_request.system_prompt: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail={ + "response": ( + "This instance does not support customizing the system prompt in the " + "query request (disable_query_system_prompt is set). Please remove the " + "system_prompt field from your request." + ) + }, + ) + if query_request.system_prompt: + # Query taking precedence over configuration is the only behavior that + # makes sense here - if the configuration wants precedence, it can + # disable query system prompt altogether with disable_system_prompt. return query_request.system_prompt - # customized system prompt should be used when query request - # does not contain one if ( configuration.customization is not None and configuration.customization.system_prompt is not None diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index d16b811f..92a33d78 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -28,6 +28,7 @@ "k8s_cluster_api": None, }, "customization": { + "disable_query_system_prompt": False, "system_prompt_path": None, "system_prompt": None, }, diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py index 97831f23..55f6892b 100644 --- a/tests/unit/test_configuration.py +++ b/tests/unit/test_configuration.py @@ -343,6 +343,7 @@ def test_load_configuration_with_customization_system_prompt_path(tmpdir) -> Non provider_id: custom-git-provider url: https://git.example.com/mcp customization: + disable_query_system_prompt: true system_prompt_path: {system_prompt_filename} """ ) diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py index ca3c6967..9970afef 100644 --- a/tests/unit/utils/test_endpoints.py +++ b/tests/unit/utils/test_endpoints.py @@ -2,13 +2,17 @@ import os import pytest +from fastapi import HTTPException import constants from configuration import AppConfig +from tests.unit import config_dict from models.requests import QueryRequest from utils import endpoints +CONFIGURED_SYSTEM_PROMPT = "This is a configured system prompt" + @pytest.fixture def input_file(tmp_path): @@ -19,151 +23,120 @@ def input_file(tmp_path): return filename -def test_get_default_system_prompt(): - """Test that default system prompt is returned when other prompts are not provided.""" - config_dict = { - "name": "foo", - "service": { - "host": "localhost", - "port": 8080, - "auth_enabled": False, - "workers": 1, - "color_log": True, - "access_log": True, - }, - "llama_stack": { - "api_key": "xyzzy", - "url": "http://x.y.com:1234", - "use_as_library_client": False, - }, - "user_data_collection": { - "feedback_disabled": True, - }, - "mcp_servers": [], - "customization": None, - } +@pytest.fixture +def config_without_system_prompt(): + test_config = config_dict.copy() # no customization provided + test_config["customization"] = None + cfg = AppConfig() - cfg.init_from_dict(config_dict) + cfg.init_from_dict(test_config) - # no system prompt in query request - query_request = QueryRequest(query="query", system_prompt=None) + return cfg - # default system prompt needs to be returned - system_prompt = endpoints.get_system_prompt(query_request, cfg) - assert system_prompt == constants.DEFAULT_SYSTEM_PROMPT +@pytest.fixture +def config_with_custom_system_prompt(): + test_config = config_dict.copy() -def test_get_customized_system_prompt(): - """Test that customized system prompt is used when system prompt is not provided in query.""" - config_dict = { - "name": "foo", - "service": { - "host": "localhost", - "port": 8080, - "auth_enabled": False, - "workers": 1, - "color_log": True, - "access_log": True, - }, - "llama_stack": { - "api_key": "xyzzy", - "url": "http://x.y.com:1234", - "use_as_library_client": False, - }, - "user_data_collection": { - "feedback_disabled": True, - }, - "mcp_servers": [], - "customization": { - "system_prompt": "This is system prompt", - }, + # system prompt is customized + test_config["customization"] = { + "system_prompt": CONFIGURED_SYSTEM_PROMPT, } - - # no customization provided cfg = AppConfig() - cfg.init_from_dict(config_dict) + cfg.init_from_dict(test_config) - # no system prompt in query request - query_request = QueryRequest(query="query", system_prompt=None) + return cfg - # default system prompt needs to be returned - system_prompt = endpoints.get_system_prompt(query_request, cfg) - assert system_prompt == "This is system prompt" +@pytest.fixture +def config_with_custom_system_prompt_and_disable_query_system_prompt(): + test_config = config_dict.copy() -def test_get_query_system_prompt(): - """Test that system prompt from query is returned.""" - config_dict = { - "name": "foo", - "service": { - "host": "localhost", - "port": 8080, - "auth_enabled": False, - "workers": 1, - "color_log": True, - "access_log": True, - }, - "llama_stack": { - "api_key": "xyzzy", - "url": "http://x.y.com:1234", - "use_as_library_client": False, - }, - "user_data_collection": { - "feedback_disabled": True, - }, - "mcp_servers": [], - "customization": None, + # system prompt is customized and query system prompt is disabled + test_config["customization"] = { + "system_prompt": CONFIGURED_SYSTEM_PROMPT, + "disable_query_system_prompt": True, } - - # no customization provided cfg = AppConfig() - cfg.init_from_dict(config_dict) + cfg.init_from_dict(test_config) - # system prompt defined in query request - system_prompt = "System prompt defined in query" - query_request = QueryRequest(query="query", system_prompt=system_prompt) + return cfg - # default system prompt needs to be returned - system_prompt = endpoints.get_system_prompt(query_request, cfg) - assert system_prompt == system_prompt +@pytest.fixture +def query_request_without_system_prompt(): + """Fixture for query request without system prompt.""" + return QueryRequest(query="query", system_prompt=None) -def test_get_query_system_prompt_not_customized_one(): - """Test that system prompt from query is returned even when customized one is specified.""" - config_dict = { - "name": "foo", - "service": { - "host": "localhost", - "port": 8080, - "auth_enabled": False, - "workers": 1, - "color_log": True, - "access_log": True, - }, - "llama_stack": { - "api_key": "xyzzy", - "url": "http://x.y.com:1234", - "use_as_library_client": False, - }, - "user_data_collection": { - "feedback_disabled": True, - }, - "mcp_servers": [], - "customization": { - "system_prompt": "This is system prompt", - }, - } - # no customization provided - cfg = AppConfig() - cfg.init_from_dict(config_dict) +@pytest.fixture +def query_request_with_system_prompt(): + """Fixture for query request with system prompt.""" + return QueryRequest(query="query", system_prompt="System prompt defined in query") + + +def test_get_default_system_prompt( + config_without_system_prompt, query_request_without_system_prompt +): + """Test that default system prompt is returned when other prompts are not provided.""" + system_prompt = endpoints.get_system_prompt( + query_request_without_system_prompt, config_without_system_prompt + ) + assert system_prompt == constants.DEFAULT_SYSTEM_PROMPT - # system prompt defined in query request - system_prompt = "System prompt defined in query" - query_request = QueryRequest(query="query", system_prompt=system_prompt) - # default system prompt needs to be returned - system_prompt = endpoints.get_system_prompt(query_request, cfg) - assert system_prompt == system_prompt +def test_get_customized_system_prompt( + config_with_custom_system_prompt, query_request_without_system_prompt +): + """Test that customized system prompt is used when system prompt is not provided in query.""" + system_prompt = endpoints.get_system_prompt( + query_request_without_system_prompt, config_with_custom_system_prompt + ) + assert system_prompt == CONFIGURED_SYSTEM_PROMPT + + +def test_get_query_system_prompt( + config_without_system_prompt, query_request_with_system_prompt +): + """Test that system prompt from query is returned.""" + system_prompt = endpoints.get_system_prompt( + query_request_with_system_prompt, config_without_system_prompt + ) + assert system_prompt == query_request_with_system_prompt.system_prompt + + +def test_get_query_system_prompt_not_customized_one( + config_with_custom_system_prompt, query_request_with_system_prompt +): + """Test that system prompt from query is returned even when customized one is specified.""" + system_prompt = endpoints.get_system_prompt( + query_request_with_system_prompt, config_with_custom_system_prompt + ) + assert system_prompt == query_request_with_system_prompt.system_prompt + + +def test_get_system_prompt_with_disable_query_system_prompt( + config_with_custom_system_prompt_and_disable_query_system_prompt, + query_request_with_system_prompt, +): + """Test that query system prompt is disallowed when disable_query_system_prompt is True.""" + with pytest.raises(HTTPException) as exc_info: + endpoints.get_system_prompt( + query_request_with_system_prompt, + config_with_custom_system_prompt_and_disable_query_system_prompt, + ) + assert exc_info.value.status_code == 422 + + +def test_get_system_prompt_with_disable_query_system_prompt_and_non_system_prompt_query( + config_with_custom_system_prompt_and_disable_query_system_prompt, + query_request_without_system_prompt, +): + """Test that query without system prompt is allowed when disable_query_system_prompt is True.""" + system_prompt = endpoints.get_system_prompt( + query_request_without_system_prompt, + config_with_custom_system_prompt_and_disable_query_system_prompt, + ) + assert system_prompt == CONFIGURED_SYSTEM_PROMPT