From 2e91d9533251b35ce41bcd1d86961a2e1b29bb58 Mon Sep 17 00:00:00 2001 From: Radovan Fuchs Date: Thu, 13 Nov 2025 15:00:08 +0100 Subject: [PATCH 1/3] add option to disable topic summary --- src/app/endpoints/query.py | 16 ++- src/models/requests.py | 8 ++ src/utils/endpoints.py | 16 ++- tests/unit/app/endpoints/test_query.py | 97 +++++++++++++++++ .../models/requests/test_query_request.py | 12 +++ tests/unit/utils/test_endpoints.py | 102 ++++++++++++++++++ 6 files changed, 243 insertions(+), 8 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 4993da1a..7ce869e9 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -318,9 +318,19 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 session.query(UserConversation).filter_by(id=conversation_id).first() ) if not existing_conversation: - topic_summary = await get_topic_summary_func( - query_request.query, client, llama_stack_model_id - ) + # Check if topic summary should be generated (default: True) + should_generate = query_request.generate_topic_summary + + if should_generate: + logger.debug("Generating topic summary for new conversation") + topic_summary = await get_topic_summary_func( + query_request.query, client, llama_stack_model_id + ) + else: + logger.debug( + "Topic summary generation disabled by request parameter" + ) + topic_summary = None # Convert RAG chunks to dictionary format once for reuse logger.info("Processing RAG chunks...") rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks] diff --git a/src/models/requests.py b/src/models/requests.py index 1033828e..c2e329b6 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -81,6 +81,7 @@ class QueryRequest(BaseModel): system_prompt: The optional system prompt. attachments: The optional attachments. no_tools: Whether to bypass all tools and MCP servers (default: False). + generate_topic_summary: Whether to generate topic summary for new conversations. media_type: The optional media type for response format (application/json or text/plain). Example: @@ -146,6 +147,12 @@ class QueryRequest(BaseModel): examples=[True, False], ) + generate_topic_summary: Optional[bool] = Field( + True, + description="Whether to generate topic summary for new conversations", + examples=[True, False], + ) + media_type: Optional[str] = Field( None, description="Media type for the response format", @@ -164,6 +171,7 @@ class QueryRequest(BaseModel): "model": "model-name", "system_prompt": "You are a helpful assistant", "no_tools": False, + "generate_topic_summary": True, "attachments": [ { "attachment_type": "log", diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index 80b3b6e5..cc926889 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -671,11 +671,17 @@ async def cleanup_after_streaming( session.query(UserConversation).filter_by(id=conversation_id).first() ) if not existing_conversation: - topic_summary = await get_topic_summary_func( - query_request.query, - client, - llama_stack_model_id, - ) + # Check if topic summary should be generated (default: True) + should_generate = query_request.generate_topic_summary + + if should_generate: + logger.debug("Generating topic summary for new conversation") + topic_summary = await get_topic_summary_func( + query_request.query, client, llama_stack_model_id + ) + else: + logger.debug("Topic summary generation disabled by request parameter") + topic_summary = None completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 9b10d970..180ec54f 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -2265,6 +2265,7 @@ async def test_get_topic_summary_create_turn_parameters(mocker: MockerFixture) - @pytest.mark.asyncio +<<<<<<< HEAD async def test_query_endpoint_quota_exceeded( mocker: MockerFixture, dummy_request: Request ) -> None: @@ -2305,3 +2306,99 @@ async def test_query_endpoint_quota_exceeded( assert isinstance(detail, dict) assert detail["response"] == "Model quota exceeded" # type: ignore assert "gpt-4-turbo" in detail["cause"] # type: ignore +======= +async def test_query_endpoint_generate_topic_summary_default_true( + mocker: MockerFixture, dummy_request: Request +) -> None: + """Test that topic summary is generated by default for new conversations.""" + mock_client = mocker.AsyncMock() + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") + mock_lsc.return_value = mock_client + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + ] + + mock_config = mocker.Mock() + mock_config.quota_limiters = [] + mocker.patch("app.endpoints.query.configuration", mock_config) + + summary = TurnSummary(llm_response="Test response", tool_calls=[]) + mocker.patch( + "app.endpoints.query.retrieve_response", + return_value=( + summary, + "00000000-0000-0000-0000-000000000000", + [], + TokenCounter(), + ), + ) + + mocker.patch( + "app.endpoints.query.select_model_and_provider_id", + return_value=("test_model", "test_model", "test_provider"), + ) + mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + + mock_get_topic_summary = mocker.patch( + "app.endpoints.query.get_topic_summary", return_value="Generated topic" + ) + mock_database_operations(mocker) + + await query_endpoint_handler( + request=dummy_request, + query_request=QueryRequest(query="test query"), + auth=("user123", "username", False, "auth_token_123"), + mcp_headers={}, + ) + + mock_get_topic_summary.assert_called_once() + + +@pytest.mark.asyncio +async def test_query_endpoint_generate_topic_summary_explicit_false( + mocker: MockerFixture, dummy_request: Request +) -> None: + """Test that topic summary is NOT generated when explicitly set to False.""" + mock_client = mocker.AsyncMock() + mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") + mock_lsc.return_value = mock_client + mock_client.models.list.return_value = [ + mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + ] + + mock_config = mocker.Mock() + mock_config.quota_limiters = [] + mocker.patch("app.endpoints.query.configuration", mock_config) + + summary = TurnSummary(llm_response="Test response", tool_calls=[]) + mocker.patch( + "app.endpoints.query.retrieve_response", + return_value=( + summary, + "00000000-0000-0000-0000-000000000000", + [], + TokenCounter(), + ), + ) + + mocker.patch( + "app.endpoints.query.select_model_and_provider_id", + return_value=("test_model", "test_model", "test_provider"), + ) + mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) + + mock_get_topic_summary = mocker.patch( + "app.endpoints.query.get_topic_summary", return_value="Generated topic" + ) + + mock_database_operations(mocker) + + await query_endpoint_handler( + request=dummy_request, + query_request=QueryRequest(query="test query", generate_topic_summary=False), + auth=("user123", "username", False, "auth_token_123"), + mcp_headers={}, + ) + + mock_get_topic_summary.assert_not_called() +>>>>>>> 81b4b90 (added unit tests for the extra logic) diff --git a/tests/unit/models/requests/test_query_request.py b/tests/unit/models/requests/test_query_request.py index f221e234..ef4f66c3 100644 --- a/tests/unit/models/requests/test_query_request.py +++ b/tests/unit/models/requests/test_query_request.py @@ -154,3 +154,15 @@ def test_validate_media_type(self, mocker: MockerFixture) -> None: # Media type is now fully supported, no warning expected mock_logger.warning.assert_not_called() + + def test_generate_topic_summary_explicit_false(self) -> None: + """Test that generate_topic_summary can be explicitly set to False.""" + qr = QueryRequest( + query="Tell me about Kubernetes", generate_topic_summary=False + ) + assert qr.generate_topic_summary is False + + def test_generate_topic_summary_explicit_true(self) -> None: + """Test that generate_topic_summary can be explicitly set to True.""" + qr = QueryRequest(query="Tell me about Kubernetes", generate_topic_summary=True) + assert qr.generate_topic_summary is True diff --git a/tests/unit/utils/test_endpoints.py b/tests/unit/utils/test_endpoints.py index 5d6caced..cfb3c359 100644 --- a/tests/unit/utils/test_endpoints.py +++ b/tests/unit/utils/test_endpoints.py @@ -1036,3 +1036,105 @@ def test_create_referenced_documents_invalid_urls(self) -> None: assert result[0].doc_title == "not-a-valid-url" assert result[1].doc_url == AnyUrl("https://example.com/doc1") assert result[1].doc_title == "doc1" + + +@pytest.mark.asyncio +async def test_cleanup_after_streaming_generate_topic_summary_default_true( + mocker: MockerFixture, +) -> None: + """Test that topic summary is generated by default for new conversations.""" + mock_is_transcripts_enabled = mocker.Mock(return_value=False) + mock_get_topic_summary = mocker.AsyncMock(return_value="Generated topic") + mock_store_transcript = mocker.Mock() + mock_persist_conversation = mocker.Mock() + mock_client = mocker.AsyncMock() + mock_config = mocker.Mock() + + mock_session = mocker.Mock() + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.__enter__ = mocker.Mock(return_value=mock_session) + mock_session.__exit__ = mocker.Mock(return_value=None) + mocker.patch("utils.endpoints.get_session", return_value=mock_session) + + mocker.patch( + "utils.endpoints.create_referenced_documents_with_metadata", return_value=[] + ) + mocker.patch("utils.endpoints.store_conversation_into_cache") + + query_request = QueryRequest(query="test query") + + await endpoints.cleanup_after_streaming( + user_id="test_user", + conversation_id="test_conv_id", + model_id="test_model", + provider_id="test_provider", + llama_stack_model_id="test_llama_model", + query_request=query_request, + summary=mocker.Mock(llm_response="test response", tool_calls=[]), + metadata_map={}, + started_at="2024-01-01T00:00:00Z", + client=mock_client, + config=mock_config, + skip_userid_check=False, + get_topic_summary_func=mock_get_topic_summary, + is_transcripts_enabled_func=mock_is_transcripts_enabled, + store_transcript_func=mock_store_transcript, + persist_user_conversation_details_func=mock_persist_conversation, + ) + + mock_get_topic_summary.assert_called_once_with( + "test query", mock_client, "test_llama_model" + ) + + mock_persist_conversation.assert_called_once() + assert mock_persist_conversation.call_args[1]["topic_summary"] == "Generated topic" + + +@pytest.mark.asyncio +async def test_cleanup_after_streaming_generate_topic_summary_explicit_false( + mocker: MockerFixture, +) -> None: + """Test that topic summary is NOT generated when explicitly set to False.""" + mock_is_transcripts_enabled = mocker.Mock(return_value=False) + mock_get_topic_summary = mocker.AsyncMock(return_value="Generated topic") + mock_store_transcript = mocker.Mock() + mock_persist_conversation = mocker.Mock() + mock_client = mocker.AsyncMock() + mock_config = mocker.Mock() + + mock_session = mocker.Mock() + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.__enter__ = mocker.Mock(return_value=mock_session) + mock_session.__exit__ = mocker.Mock(return_value=None) + mocker.patch("utils.endpoints.get_session", return_value=mock_session) + + mocker.patch( + "utils.endpoints.create_referenced_documents_with_metadata", return_value=[] + ) + mocker.patch("utils.endpoints.store_conversation_into_cache") + + query_request = QueryRequest(query="test query", generate_topic_summary=False) + + await endpoints.cleanup_after_streaming( + user_id="test_user", + conversation_id="test_conv_id", + model_id="test_model", + provider_id="test_provider", + llama_stack_model_id="test_llama_model", + query_request=query_request, + summary=mocker.Mock(llm_response="test response", tool_calls=[]), + metadata_map={}, + started_at="2024-01-01T00:00:00Z", + client=mock_client, + config=mock_config, + skip_userid_check=False, + get_topic_summary_func=mock_get_topic_summary, + is_transcripts_enabled_func=mock_is_transcripts_enabled, + store_transcript_func=mock_store_transcript, + persist_user_conversation_details_func=mock_persist_conversation, + ) + + mock_get_topic_summary.assert_not_called() + + mock_persist_conversation.assert_called_once() + assert mock_persist_conversation.call_args[1]["topic_summary"] is None From 263ed684cad1ba002d341aaa243f200fa57fad78 Mon Sep 17 00:00:00 2001 From: Radovan Fuchs Date: Wed, 19 Nov 2025 11:27:25 +0100 Subject: [PATCH 2/3] fix --- tests/unit/app/endpoints/test_query.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 180ec54f..085db4bd 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -2265,7 +2265,6 @@ async def test_get_topic_summary_create_turn_parameters(mocker: MockerFixture) - @pytest.mark.asyncio -<<<<<<< HEAD async def test_query_endpoint_quota_exceeded( mocker: MockerFixture, dummy_request: Request ) -> None: @@ -2306,7 +2305,6 @@ async def test_query_endpoint_quota_exceeded( assert isinstance(detail, dict) assert detail["response"] == "Model quota exceeded" # type: ignore assert "gpt-4-turbo" in detail["cause"] # type: ignore -======= async def test_query_endpoint_generate_topic_summary_default_true( mocker: MockerFixture, dummy_request: Request ) -> None: @@ -2401,4 +2399,3 @@ async def test_query_endpoint_generate_topic_summary_explicit_false( ) mock_get_topic_summary.assert_not_called() ->>>>>>> 81b4b90 (added unit tests for the extra logic) From be18bd1704c7a3179613d27362357eee3c8e1a48 Mon Sep 17 00:00:00 2001 From: Radovan Fuchs Date: Wed, 19 Nov 2025 11:29:26 +0100 Subject: [PATCH 3/3] fix --- tests/unit/app/endpoints/test_query.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 085db4bd..c7e6415b 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -2305,6 +2305,8 @@ async def test_query_endpoint_quota_exceeded( assert isinstance(detail, dict) assert detail["response"] == "Model quota exceeded" # type: ignore assert "gpt-4-turbo" in detail["cause"] # type: ignore + + async def test_query_endpoint_generate_topic_summary_default_true( mocker: MockerFixture, dummy_request: Request ) -> None: