Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,19 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
session.query(UserConversation).filter_by(id=conversation_id).first()
)
if not existing_conversation:
topic_summary = await get_topic_summary_func(
query_request.query, client, llama_stack_model_id
)
# Check if topic summary should be generated (default: True)
should_generate = query_request.generate_topic_summary

if should_generate:
logger.debug("Generating topic summary for new conversation")
topic_summary = await get_topic_summary_func(
query_request.query, client, llama_stack_model_id
)
else:
logger.debug(
"Topic summary generation disabled by request parameter"
)
topic_summary = None
# Convert RAG chunks to dictionary format once for reuse
logger.info("Processing RAG chunks...")
rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks]
Expand Down
8 changes: 8 additions & 0 deletions src/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class QueryRequest(BaseModel):
system_prompt: The optional system prompt.
attachments: The optional attachments.
no_tools: Whether to bypass all tools and MCP servers (default: False).
generate_topic_summary: Whether to generate topic summary for new conversations.
media_type: The optional media type for response format (application/json or text/plain).

Example:
Expand Down Expand Up @@ -146,6 +147,12 @@ class QueryRequest(BaseModel):
examples=[True, False],
)

generate_topic_summary: Optional[bool] = Field(
True,
description="Whether to generate topic summary for new conversations",
examples=[True, False],
)

media_type: Optional[str] = Field(
None,
description="Media type for the response format",
Expand All @@ -164,6 +171,7 @@ class QueryRequest(BaseModel):
"model": "model-name",
"system_prompt": "You are a helpful assistant",
"no_tools": False,
"generate_topic_summary": True,
"attachments": [
{
"attachment_type": "log",
Expand Down
16 changes: 11 additions & 5 deletions src/utils/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,11 +671,17 @@ async def cleanup_after_streaming(
session.query(UserConversation).filter_by(id=conversation_id).first()
)
if not existing_conversation:
topic_summary = await get_topic_summary_func(
query_request.query,
client,
llama_stack_model_id,
)
# Check if topic summary should be generated (default: True)
should_generate = query_request.generate_topic_summary

if should_generate:
logger.debug("Generating topic summary for new conversation")
topic_summary = await get_topic_summary_func(
query_request.query, client, llama_stack_model_id
)
else:
logger.debug("Topic summary generation disabled by request parameter")
topic_summary = None

completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")

Expand Down
96 changes: 96 additions & 0 deletions tests/unit/app/endpoints/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2305,3 +2305,99 @@ async def test_query_endpoint_quota_exceeded(
assert isinstance(detail, dict)
assert detail["response"] == "Model quota exceeded" # type: ignore
assert "gpt-4-turbo" in detail["cause"] # type: ignore


@pytest.mark.asyncio
async def test_query_endpoint_generate_topic_summary_default_true(
    mocker: MockerFixture, dummy_request: Request
) -> None:
    """Test that topic summary is generated by default for new conversations.

    The ``generate_topic_summary`` field defaults to True, so when the request
    omits it, ``get_topic_summary`` must be invoked for a new conversation.

    Fix: added the missing ``@pytest.mark.asyncio`` decorator for consistency
    with the sibling test below — without it the coroutine may never be
    awaited (depending on pytest-asyncio mode) and the test would silently
    pass without running its assertions.
    """
    mock_client = mocker.AsyncMock()
    mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")
    mock_lsc.return_value = mock_client
    mock_client.models.list.return_value = [
        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
    ]

    # Minimal configuration: no quota limiters so the handler proceeds.
    mock_config = mocker.Mock()
    mock_config.quota_limiters = []
    mocker.patch("app.endpoints.query.configuration", mock_config)

    summary = TurnSummary(llm_response="Test response", tool_calls=[])
    mocker.patch(
        "app.endpoints.query.retrieve_response",
        return_value=(
            summary,
            "00000000-0000-0000-0000-000000000000",
            [],
            TokenCounter(),
        ),
    )

    mocker.patch(
        "app.endpoints.query.select_model_and_provider_id",
        return_value=("test_model", "test_model", "test_provider"),
    )
    mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)

    mock_get_topic_summary = mocker.patch(
        "app.endpoints.query.get_topic_summary", return_value="Generated topic"
    )
    mock_database_operations(mocker)

    # No generate_topic_summary in the request -> field defaults to True.
    await query_endpoint_handler(
        request=dummy_request,
        query_request=QueryRequest(query="test query"),
        auth=("user123", "username", False, "auth_token_123"),
        mcp_headers={},
    )

    mock_get_topic_summary.assert_called_once()


@pytest.mark.asyncio
async def test_query_endpoint_generate_topic_summary_explicit_false(
    mocker: MockerFixture, dummy_request: Request
) -> None:
    """Test that topic summary is NOT generated when explicitly set to False."""
    # Stub out the llama-stack client and its model listing.
    client_mock = mocker.AsyncMock()
    client_mock.models.list.return_value = [
        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
    ]
    mocker.patch(
        "client.AsyncLlamaStackClientHolder.get_client"
    ).return_value = client_mock

    # Configuration with no quota limiters so the handler runs to completion.
    config_mock = mocker.Mock()
    config_mock.quota_limiters = []
    mocker.patch("app.endpoints.query.configuration", config_mock)

    # Canned LLM turn so no real inference happens.
    turn_summary = TurnSummary(llm_response="Test response", tool_calls=[])
    mocker.patch(
        "app.endpoints.query.retrieve_response",
        return_value=(
            turn_summary,
            "00000000-0000-0000-0000-000000000000",
            [],
            TokenCounter(),
        ),
    )
    mocker.patch(
        "app.endpoints.query.select_model_and_provider_id",
        return_value=("test_model", "test_model", "test_provider"),
    )
    mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)

    topic_summary_mock = mocker.patch(
        "app.endpoints.query.get_topic_summary", return_value="Generated topic"
    )
    mock_database_operations(mocker)

    # Explicitly opt out of topic-summary generation.
    await query_endpoint_handler(
        request=dummy_request,
        query_request=QueryRequest(query="test query", generate_topic_summary=False),
        auth=("user123", "username", False, "auth_token_123"),
        mcp_headers={},
    )

    topic_summary_mock.assert_not_called()
12 changes: 12 additions & 0 deletions tests/unit/models/requests/test_query_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,15 @@ def test_validate_media_type(self, mocker: MockerFixture) -> None:

# Media type is now fully supported, no warning expected
mock_logger.warning.assert_not_called()

def test_generate_topic_summary_explicit_false(self) -> None:
    """Test that generate_topic_summary can be explicitly set to False."""
    request = QueryRequest(
        query="Tell me about Kubernetes",
        generate_topic_summary=False,
    )
    # The explicit False must be preserved, not coerced back to the default.
    assert request.generate_topic_summary is False

def test_generate_topic_summary_explicit_true(self) -> None:
    """Test that generate_topic_summary can be explicitly set to True."""
    request = QueryRequest(
        query="Tell me about Kubernetes",
        generate_topic_summary=True,
    )
    assert request.generate_topic_summary is True
102 changes: 102 additions & 0 deletions tests/unit/utils/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,3 +1036,105 @@ def test_create_referenced_documents_invalid_urls(self) -> None:
assert result[0].doc_title == "not-a-valid-url"
assert result[1].doc_url == AnyUrl("https://example.com/doc1")
assert result[1].doc_title == "doc1"


@pytest.mark.asyncio
async def test_cleanup_after_streaming_generate_topic_summary_default_true(
    mocker: MockerFixture,
) -> None:
    """Test that topic summary is generated by default for new conversations."""
    # Injected collaborator stubs for cleanup_after_streaming.
    transcripts_enabled_stub = mocker.Mock(return_value=False)
    topic_summary_stub = mocker.AsyncMock(return_value="Generated topic")
    store_transcript_stub = mocker.Mock()
    persist_stub = mocker.Mock()
    client_stub = mocker.AsyncMock()
    config_stub = mocker.Mock()

    # DB session context manager whose conversation lookup finds nothing,
    # i.e. this is a brand-new conversation.
    session_stub = mocker.Mock()
    session_stub.query.return_value.filter_by.return_value.first.return_value = None
    session_stub.__enter__ = mocker.Mock(return_value=session_stub)
    session_stub.__exit__ = mocker.Mock(return_value=None)
    mocker.patch("utils.endpoints.get_session", return_value=session_stub)

    mocker.patch(
        "utils.endpoints.create_referenced_documents_with_metadata", return_value=[]
    )
    mocker.patch("utils.endpoints.store_conversation_into_cache")

    # Default request: generate_topic_summary not set, so it defaults to True.
    await endpoints.cleanup_after_streaming(
        user_id="test_user",
        conversation_id="test_conv_id",
        model_id="test_model",
        provider_id="test_provider",
        llama_stack_model_id="test_llama_model",
        query_request=QueryRequest(query="test query"),
        summary=mocker.Mock(llm_response="test response", tool_calls=[]),
        metadata_map={},
        started_at="2024-01-01T00:00:00Z",
        client=client_stub,
        config=config_stub,
        skip_userid_check=False,
        get_topic_summary_func=topic_summary_stub,
        is_transcripts_enabled_func=transcripts_enabled_stub,
        store_transcript_func=store_transcript_stub,
        persist_user_conversation_details_func=persist_stub,
    )

    # The summary helper ran exactly once with the expected arguments ...
    topic_summary_stub.assert_called_once_with(
        "test query", client_stub, "test_llama_model"
    )
    # ... and its result was forwarded to the persistence layer.
    persist_stub.assert_called_once()
    assert persist_stub.call_args[1]["topic_summary"] == "Generated topic"


@pytest.mark.asyncio
async def test_cleanup_after_streaming_generate_topic_summary_explicit_false(
    mocker: MockerFixture,
) -> None:
    """Test that topic summary is NOT generated when explicitly set to False."""
    # Injected collaborator stubs for cleanup_after_streaming.
    transcripts_enabled_stub = mocker.Mock(return_value=False)
    topic_summary_stub = mocker.AsyncMock(return_value="Generated topic")
    store_transcript_stub = mocker.Mock()
    persist_stub = mocker.Mock()
    client_stub = mocker.AsyncMock()
    config_stub = mocker.Mock()

    # DB session whose lookup returns None -> a new conversation, which is
    # the only path where a topic summary would ever be generated.
    session_stub = mocker.Mock()
    session_stub.query.return_value.filter_by.return_value.first.return_value = None
    session_stub.__enter__ = mocker.Mock(return_value=session_stub)
    session_stub.__exit__ = mocker.Mock(return_value=None)
    mocker.patch("utils.endpoints.get_session", return_value=session_stub)

    mocker.patch(
        "utils.endpoints.create_referenced_documents_with_metadata", return_value=[]
    )
    mocker.patch("utils.endpoints.store_conversation_into_cache")

    # Request explicitly opts out of summary generation.
    await endpoints.cleanup_after_streaming(
        user_id="test_user",
        conversation_id="test_conv_id",
        model_id="test_model",
        provider_id="test_provider",
        llama_stack_model_id="test_llama_model",
        query_request=QueryRequest(query="test query", generate_topic_summary=False),
        summary=mocker.Mock(llm_response="test response", tool_calls=[]),
        metadata_map={},
        started_at="2024-01-01T00:00:00Z",
        client=client_stub,
        config=config_stub,
        skip_userid_check=False,
        get_topic_summary_func=topic_summary_stub,
        is_transcripts_enabled_func=transcripts_enabled_stub,
        store_transcript_func=store_transcript_stub,
        persist_user_conversation_details_func=persist_stub,
    )

    # The summary helper was bypassed entirely ...
    topic_summary_stub.assert_not_called()
    # ... and the persisted conversation carries no topic summary.
    persist_stub.assert_called_once()
    assert persist_stub.call_args[1]["topic_summary"] is None
Loading