diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 7cf5b380a..0ee67bc62 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -99,7 +99,6 @@ from utils.tool_formatter import translate_vector_store_ids_to_user_facing from utils.types import ( RAGContext, - ResponseInput, ResponsesApiParams, ShieldModerationBlocked, ShieldModerationResult, @@ -236,19 +235,16 @@ async def responses_endpoint_handler( - 500: Internal Server Error - Configuration not loaded or other server errors - 503: Service Unavailable - Unable to connect to Llama Stack backend """ + original_request = responses_request # read-only request + updated_request = responses_request.model_copy(deep=True) + _ = responses_request + # Known LLS bug: https://redhat.atlassian.net/browse/LCORE-1583 - if responses_request.reasoning is not None: + if original_request.reasoning is not None: logger.warning("reasoning is not yet supported in LCORE and will be ignored") - responses_request.reasoning = None - - responses_request = responses_request.model_copy(deep=True) + updated_request.reasoning = None check_configuration_loaded(configuration) - client_instructions = responses_request.instructions - responses_request.instructions = get_system_prompt( - responses_request.instructions, field_name="instructions" - ) - instructions_substituted = client_instructions is None started_at = datetime.now(UTC) rh_identity_context = get_rh_identity_context(request) user_id, _, _, token = auth @@ -260,40 +256,41 @@ async def responses_endpoint_handler( # Enforce RBAC: optionally disallow overriding model in requests validate_model_provider_override( - responses_request.model, + original_request.model, None, # provider specified as model prefix request.state.authorized_actions, ) + updated_request.instructions = get_system_prompt( + original_request.instructions, field_name="instructions" + ) + response_context = await resolve_response_context( user_id=user_id, others_allowed=( Action.READ_OTHERS_CONVERSATIONS in request.state.authorized_actions ), - conversation_id=responses_request.conversation, - previous_response_id=responses_request.previous_response_id, - generate_topic_summary=responses_request.generate_topic_summary, + conversation_id=original_request.conversation, + previous_response_id=original_request.previous_response_id, + generate_topic_summary=original_request.generate_topic_summary, ) - responses_request.conversation = response_context.conversation - responses_request.generate_topic_summary = response_context.generate_topic_summary + updated_request.conversation = response_context.conversation + updated_request.generate_topic_summary = response_context.generate_topic_summary client = AsyncLlamaStackClientHolder().get_client() # LCORE-specific: Automatically select model if not provided in request # This extends the base LLS API which requires model to be specified. - client_model = responses_request.model - if not responses_request.model: - responses_request.model = await select_model_for_responses( - client, response_context.user_conversation - ) - model_substituted = not client_model - if not await check_model_configured(client, responses_request.model): - _, model_id = extract_provider_and_model_from_model_id(responses_request.model) + updated_request.model = await select_model_for_responses( + original_request.model, client, response_context.user_conversation + ) + if not await check_model_configured(client, updated_request.model): + _, model_id = extract_provider_and_model_from_model_id(updated_request.model) error_response = NotFoundResponse(resource="model", resource_id=model_id) raise HTTPException(**error_response.model_dump()) # Handle Azure token refresh if needed if ( - responses_request.model.startswith("azure") + updated_request.model.startswith("azure") and AzureEntraIDManager().is_entra_id_configured and AzureEntraIDManager().is_token_expired and AzureEntraIDManager().refresh_token() @@ -301,79 +298,66 @@ async def responses_endpoint_handler( client = await update_azure_token(client) input_text = ( - responses_request.input - if isinstance(responses_request.input, str) - else extract_text_from_response_items(responses_request.input) + original_request.input + if isinstance(original_request.input, str) + else extract_text_from_response_items(original_request.input) ) - attachments_text = extract_attachments_text(responses_request.input) + attachments_text = extract_attachments_text(original_request.input) moderation_result = await run_shield_moderation( client, input_text + "\n\n" + attachments_text, - responses_request.shield_ids, - ) - - # Extract vector store IDs for Inline RAG context before resolving tool choice. - vector_store_ids: Optional[list[str]] = ( - extract_vector_store_ids_from_tools(responses_request.tools) - if responses_request.tools is not None - else None + original_request.shield_ids, ) filter_server_tools = ( request.headers.get("X-LCS-Merge-Server-Tools", "").lower() == "true" ) + resolver = ( + resolve_client_tool_choice if filter_server_tools else resolve_tool_choice + ) + updated_request.tools, updated_request.tool_choice = await resolver( + original_request.tools, + original_request.tool_choice, + token, + mcp_headers, + request.headers, + ) - if filter_server_tools: - responses_request.tools, responses_request.tool_choice = ( - await resolve_client_tool_choice( - responses_request.tools, - responses_request.tool_choice, - auth[1], - mcp_headers, - request.headers, - ) - ) - else: - responses_request.tools, responses_request.tool_choice = ( - await resolve_tool_choice( - responses_request.tools, - responses_request.tool_choice, - auth[1], - mcp_headers, - request.headers, - ) - ) - + # Extract vector store IDs for Inline RAG context from the original request + vector_store_ids: Optional[list[str]] = ( + extract_vector_store_ids_from_tools(original_request.tools) + if original_request.tools is not None + else None + ) # Build RAG context from Inline RAG sources inline_rag_context = await build_rag_context( client, moderation_result.decision, input_text, vector_store_ids, - responses_request.solr, + original_request.solr, ) if moderation_result.decision == "passed": - responses_request.input = append_inline_rag_context_to_responses_input( - responses_request.input, inline_rag_context.context_text + updated_request.input = append_inline_rag_context_to_responses_input( + original_request.input, inline_rag_context.context_text ) response_handler = ( handle_streaming_response - if responses_request.stream + if original_request.stream else handle_non_streaming_response ) return await response_handler( client=client, - request=responses_request, + original_request=original_request, + updated_request=updated_request, auth=auth, input_text=input_text, started_at=started_at, moderation_result=moderation_result, inline_rag_context=inline_rag_context, filter_server_tools=filter_server_tools, - instructions_substituted=instructions_substituted, - model_substituted=model_substituted, background_tasks=background_tasks, rh_identity_context=rh_identity_context, ) @@ -381,15 +365,14 @@ async def responses_endpoint_handler( async def handle_streaming_response( client: AsyncLlamaStackClient, - request: ResponsesRequest, + original_request: ResponsesRequest, + updated_request: ResponsesRequest, auth: AuthTuple, input_text: str, started_at: datetime, moderation_result: ShieldModerationResult, inline_rag_context: RAGContext, filter_server_tools: bool = False, - instructions_substituted: bool = False, - model_substituted: bool = False, background_tasks: Optional[BackgroundTasks] = None, rh_identity_context: tuple[str, str] = ("", ""), ) -> StreamingResponse: @@ -404,14 +387,12 @@ async def handle_streaming_response( moderation_result: Result of shield moderation check inline_rag_context: Inline RAG context to be used for the response filter_server_tools: Whether to filter server-deployed MCP tool events from the stream - instructions_substituted: Whether the server substituted the instructions - model_substituted: Whether the server substituted the model background_tasks: FastAPI background task manager for telemetry events rh_identity_context: Tuple of (org_id, system_id) from RH identity Returns: StreamingResponse with SSE-formatted events """ - api_params = ResponsesApiParams.model_validate(request.model_dump()) + api_params = ResponsesApiParams.model_validate(updated_request.model_dump()) turn_summary = TurnSummary() # Handle blocked response if moderation_result.decision == "blocked": @@ -423,7 +404,7 @@ async def handle_streaming_response( generator = shield_violation_generator( moderation_result, api_params.conversation, - request.echoed_params(), + updated_request.echoed_params(), started_at, available_quotas, ) @@ -431,7 +412,7 @@ async def handle_streaming_response( await append_turn_items_to_conversation( client=client, conversation_id=api_params.conversation, - user_input=request.input, + user_input=updated_request.input, llm_output=[moderation_result.refusal_response], ) _queue_responses_splunk_event( @@ -451,14 +432,13 @@ async def handle_streaming_response( ) generator = response_generator( stream=cast(AsyncIterator[OpenAIResponseObjectStream], response), - user_input=request.input, + original_request=original_request, + updated_request=updated_request, api_params=api_params, user_id=auth[0], turn_summary=turn_summary, inline_rag_context=inline_rag_context, filter_server_tools=filter_server_tools, - instructions_substituted=instructions_substituted, - model_substituted=model_substituted, ) except RuntimeError as e: # library mode wraps 413 into runtime error if is_context_length_error(str(e)): @@ -517,7 +497,7 @@ async def handle_streaming_response( input_text=input_text, started_at=started_at, api_params=api_params, - generate_topic_summary=request.generate_topic_summary or False, + generate_topic_summary=updated_request.generate_topic_summary or False, background_tasks=background_tasks, rh_identity_context=rh_identity_context, shield_blocked=(moderation_result.decision == "blocked"), @@ -629,41 +609,21 @@ async def shield_violation_generator( def _sanitize_response_dict( response_dict: dict[str, Any], configured_mcp_labels: set[str], - instructions_substituted: bool = False, - model_substituted: bool = False, + original_request: ResponsesRequest, ) -> None: """Sanitize a serialized response object in-place to remove internal details. Strips fields that expose server-side implementation details from the - response object before it is forwarded to the client: - - - ``instructions``: when the server substituted its own system prompt - (because the client sent ``None`` or a different value was resolved), - the value is replaced with a placeholder slug to avoid leaking the - actual prompt. When the client provided their own instructions and - they were used as-is, the value is left unchanged. - - ``tools``: server-deployed MCP tool definitions are removed; client- - provided tools (those whose ``server_label`` is not in - ``configured_mcp_labels``) are preserved. - - ``output``: server-deployed MCP output items (``mcp_list_tools``, - ``mcp_call``, ``mcp_approval_request``) are stripped so clients only - see item types they understand (``message``, ``function_call``, etc.). - - ``model``: the provider routing prefix (everything before the last - ``/``) is stripped only when the server selected the model - (``model_substituted=True``). When the client specified the model, - it is echoed back unchanged. + response object before it is forwarded to the client. Args: response_dict: Mutable dict produced by ``model_dump`` on a response object. Modified in-place. configured_mcp_labels: Set of ``server_label`` values that identify server-deployed MCP servers. - instructions_substituted: Whether the server substituted the - instructions (True) or the client provided them (False). - model_substituted: Whether the server substituted the model - (True) or the client provided it (False). + original_request: Original request object """ - if instructions_substituted: + if original_request.instructions is None: response_dict["instructions"] = SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER # else: leave instructions as-is (echo back client's value) @@ -681,7 +641,7 @@ def _sanitize_response_dict( if not _is_server_mcp_output_item(item, configured_mcp_labels) ] - if model_substituted: + if original_request.model is None: model = response_dict.get("model") if model and "/" in model: response_dict["model"] = model.rsplit("/", 1)[-1] @@ -796,27 +756,25 @@ def _populate_turn_summary( async def response_generator( stream: AsyncIterator[OpenAIResponseObjectStream], - user_input: ResponseInput, + original_request: ResponsesRequest, + updated_request: ResponsesRequest, api_params: ResponsesApiParams, user_id: str, turn_summary: TurnSummary, inline_rag_context: RAGContext, filter_server_tools: bool = False, - instructions_substituted: bool = False, - model_substituted: bool = False, ) -> AsyncIterator[str]: """Generate SSE-formatted streaming response with LCORE-enriched events. Args: stream: The streaming response from Llama Stack - user_input: User input to the response + original_request: Original request object + updated_request: Updated request object api_params: ResponsesApiParams user_id: User ID for quota retrieval turn_summary: TurnSummary to populate during streaming inline_rag_context: Inline RAG context to be used for the response filter_server_tools: Whether to filter server-deployed MCP tool events from the stream - instructions_substituted: Whether the server substituted the instructions - model_substituted: Whether the server substituted the model Yields: SSE-formatted strings for streaming events, ending with [DONE] """ @@ -853,8 +811,7 @@ async def response_generator( _sanitize_response_dict( chunk_dict["response"], configured_mcp_labels, - instructions_substituted, - model_substituted, + original_request, ) tools = chunk_dict["response"].get("tools") if tools is not None: @@ -916,7 +873,10 @@ async def response_generator( # Explicitly append the turn to conversation if context passed by previous response if api_params.store and api_params.previous_response_id and latest_response_object: await append_turn_items_to_conversation( - client, api_params.conversation, user_input, latest_response_object.output + client, + api_params.conversation, + updated_request.input, + latest_response_object.output, ) yield "data: [DONE]\n\n" @@ -1001,15 +961,14 @@ async def generate_response( async def handle_non_streaming_response( client: AsyncLlamaStackClient, - request: ResponsesRequest, + original_request: ResponsesRequest, + updated_request: ResponsesRequest, auth: AuthTuple, input_text: str, started_at: datetime, moderation_result: ShieldModerationResult, inline_rag_context: RAGContext, filter_server_tools: bool = False, - instructions_substituted: bool = False, - model_substituted: bool = False, background_tasks: Optional[BackgroundTasks] = None, rh_identity_context: tuple[str, str] = ("", ""), ) -> ResponsesResponse: @@ -1017,22 +976,21 @@ async def handle_non_streaming_response( Args: client: The AsyncLlamaStackClient instance - request: Request object + original_request: Original request object + updated_request: Updated request object auth: Authentication tuple input_text: The extracted input text started_at: Timestamp when the conversation started moderation_result: Result of shield moderation check inline_rag_context: Inline RAG context to be used for the response filter_server_tools: Whether to filter server-deployed MCP tool output - instructions_substituted: Whether the server substituted the instructions - model_substituted: Whether the server substituted the model background_tasks: FastAPI background task manager for telemetry events rh_identity_context: Tuple of (org_id, system_id) from RH identity Returns: ResponsesResponse with the completed response """ user_id, _, skip_userid_check, _ = auth - api_params = ResponsesApiParams.model_validate(request.model_dump()) + api_params = ResponsesApiParams.model_validate(updated_request.model_dump()) # Fork: Get response object (blocked vs normal) if moderation_result.decision == "blocked": @@ -1043,13 +1001,13 @@ async def handle_non_streaming_response( status="completed", output=[moderation_result.refusal_response], usage=get_zero_usage(), - **request.echoed_params(), + **updated_request.echoed_params(), ) if api_params.store: await append_turn_items_to_conversation( client=client, conversation_id=api_params.conversation, - user_input=request.input, + user_input=updated_request.input, llm_output=[moderation_result.refusal_response], ) _queue_responses_splunk_event( @@ -1081,7 +1039,10 @@ async def handle_non_streaming_response( # Explicitly append the turn to conversation if context passed by previous response if api_params.store and api_params.previous_response_id: await append_turn_items_to_conversation( - client, api_params.conversation, request.input, api_response.output + client, + api_params.conversation, + updated_request.input, + api_response.output, ) except RuntimeError as e: @@ -1139,7 +1100,7 @@ async def handle_non_streaming_response( ) # Get topic summary for new conversation topic_summary = None - if request.generate_topic_summary: + if updated_request.generate_topic_summary: logger.debug("Generating topic summary for new conversation") topic_summary = await get_topic_summary(input_text, client, api_params.model) @@ -1193,8 +1154,7 @@ async def handle_non_streaming_response( _sanitize_response_dict( response_dict, configured_mcp_labels, - instructions_substituted, - model_substituted, + original_request, ) tools = response_dict.get("tools") if tools is not None: diff --git a/src/utils/responses.py b/src/utils/responses.py index 709c3192a..858973d02 100644 --- a/src/utils/responses.py +++ b/src/utils/responses.py @@ -331,10 +331,12 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma Returns: ResponsesApiParams containing all prepared parameters for the API request """ - if query_request.model and query_request.provider: - model = f"{query_request.provider}/{query_request.model}" - else: - model = await select_model_for_responses(client, user_conversation) + request_model = ( + f"{query_request.provider}/{query_request.model}" + if query_request.model and query_request.provider + else None + ) + model = await select_model_for_responses(request_model, client, user_conversation) if not await check_model_configured(client, model): _, model_id = extract_provider_and_model_from_model_id(model) @@ -1330,6 +1332,7 @@ async def check_model_configured( async def select_model_for_responses( + request_model: Optional[str], client: AsyncLlamaStackClient, user_conversation: Optional[UserConversation], ) -> str: @@ -1342,6 +1345,7 @@ async def select_model_for_responses( 4. Raise HTTPException if no LLM model is found Args: + request_model: The model explicitly specified in the request, or None if not specified client: The AsyncLlamaStackClient instance user_conversation: The user conversation if conversation_id was provided, None otherwise @@ -1351,6 +1355,9 @@ async def select_model_for_responses( Raises: HTTPException: If models cannot be fetched or an error occurs, or if no LLM model is found """ + if request_model: + return request_model + # 1. Conversation has existing last_used_model if ( user_conversation is not None diff --git a/tests/unit/app/endpoints/test_responses.py b/tests/unit/app/endpoints/test_responses.py index 94a0c7c65..c8c51e07b 100644 --- a/tests/unit/app/endpoints/test_responses.py +++ b/tests/unit/app/endpoints/test_responses.py @@ -43,6 +43,8 @@ MODULE = "app.endpoints.responses" ENDPOINTS_MODULE = "utils.endpoints" UTILS_RESPONSES_MODULE = "utils.responses" +MODEL = "google-vertex/publishers/google/models/gemini-2.5-flash" +SERVER_INSTRUCTIONS = "Server instructions" def _patch_base(mocker: MockerFixture, config: AppConfig) -> None: @@ -633,8 +635,8 @@ async def test_tool_choice_none_without_tools_does_not_load_server_tools( # The handler passes tools=None and tool_choice=None to the response handler # (the endpoint deep-copies the request, so we inspect the handler call args) call_kwargs = mock_handle.call_args[1] - assert call_kwargs["request"].tools is None - assert call_kwargs["request"].tool_choice is None + assert call_kwargs["updated_request"].tools is None + assert call_kwargs["updated_request"].tool_choice is None class TestHandleNonStreamingResponse: @@ -686,7 +688,8 @@ async def test_handle_non_streaming_blocked_returns_refusal( response = await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Bad input", started_at=datetime.now(UTC), @@ -756,7 +759,8 @@ async def test_handle_non_streaming_success_returns_response( response = await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hello", started_at=datetime.now(UTC), @@ -832,7 +836,8 @@ async def test_handle_non_streaming_with_previous_response_id_appends_turn( await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -869,7 +874,8 @@ async def test_handle_non_streaming_context_length_raises_413( with pytest.raises(HTTPException) as exc_info: await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Long input", started_at=datetime.now(UTC), @@ -906,7 +912,8 @@ async def test_handle_non_streaming_connection_error_raises_503( with pytest.raises(HTTPException) as exc_info: await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -953,7 +960,8 @@ async def test_handle_non_streaming_api_status_error_raises_http( with pytest.raises(HTTPException) as exc_info: await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -987,7 +995,8 @@ async def test_handle_non_streaming_runtime_error_without_context_reraises( with pytest.raises(RuntimeError, match="Some other error"): await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -1032,7 +1041,8 @@ async def test_handle_streaming_blocked_returns_sse_consumes_shield_generator( mock_client.conversations.items.create = mocker.AsyncMock() response = await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Bad", started_at=datetime.now(UTC), @@ -1111,7 +1121,8 @@ async def mock_stream() -> Any: mocker.patch(f"{MODULE}.AsyncLlamaStackClientHolder", return_value=mock_holder) response = await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -1194,7 +1205,8 @@ async def mock_stream() -> Any: response = await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -1277,7 +1289,8 @@ async def mock_stream() -> Any: response = await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -1354,7 +1367,8 @@ async def mock_stream() -> Any: response = await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -1397,7 +1411,8 @@ async def test_handle_streaming_context_length_raises_413( with pytest.raises(HTTPException) as exc_info: await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Long", started_at=datetime.now(UTC), @@ -1432,7 +1447,8 @@ async def test_handle_streaming_connection_error_raises_503( with pytest.raises(HTTPException) as exc_info: await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -1486,10 +1502,10 @@ async def test_default_instructions_applied_when_client_omits_them( mcp_headers={}, ) - # The request passed to handle_non_streaming_response should have + # The updated request passed to handle_non_streaming_response should have # instructions resolved to the default system prompt. call_kwargs = mock_handler.call_args[1] - assert call_kwargs["request"].instructions == DEFAULT_SYSTEM_PROMPT + assert call_kwargs["updated_request"].instructions == DEFAULT_SYSTEM_PROMPT @pytest.mark.asyncio async def test_client_provided_instructions_pass_through( @@ -1534,7 +1550,7 @@ async def test_client_provided_instructions_pass_through( ) call_kwargs = mock_handler.call_args[1] - assert call_kwargs["request"].instructions == custom_instructions + assert call_kwargs["updated_request"].instructions == custom_instructions @pytest.mark.asyncio async def test_configured_system_prompt_used_when_no_client_instructions( @@ -1595,7 +1611,10 @@ async def test_configured_system_prompt_used_when_no_client_instructions( ) call_kwargs = mock_handler.call_args[1] - assert call_kwargs["request"].instructions == "You are a deployment assistant." + assert ( + call_kwargs["updated_request"].instructions + == "You are a deployment assistant." + ) @pytest.mark.asyncio async def test_client_instructions_rejected_when_disabled( @@ -1681,7 +1700,7 @@ async def test_streaming_response_uses_resolved_instructions( ) call_kwargs = mock_handler.call_args[1] - assert call_kwargs["request"].instructions == DEFAULT_SYSTEM_PROMPT + assert call_kwargs["updated_request"].instructions == DEFAULT_SYSTEM_PROMPT class TestIsServerMcpOutputItem: @@ -1874,33 +1893,43 @@ def test_does_not_filter_non_mcp_event(self, mocker: MockerFixture) -> None: ) +def mock_original_request( + *, instructions: Optional[str] = None, model: Optional[str] = None +) -> ResponsesRequest: + """Build a minimal ResponsesRequest for _sanitize_response_dict tests.""" + kwargs: dict[str, Any] = {"input": "x"} + if instructions is not None: + kwargs["instructions"] = instructions + if model is not None: + kwargs["model"] = model + return ResponsesRequest(**kwargs) + + class TestSanitizeResponseDict: """Unit tests for _sanitize_response_dict.""" def test_substituted_instructions_replaced_with_placeholder(self) -> None: """Test that substituted instructions are replaced with the slug constant.""" d: dict[str, Any] = {"instructions": "secret server prompt", "model": "m"} - _sanitize_response_dict(d, set(), instructions_substituted=True) + _sanitize_response_dict(d, set(), mock_original_request(instructions=None)) assert d["instructions"] == SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER def test_client_instructions_preserved_when_not_substituted(self) -> None: """Test that client-provided instructions are echoed back unchanged.""" - d: dict[str, Any] = {"instructions": "my custom prompt", "model": "m"} - _sanitize_response_dict(d, set(), instructions_substituted=False) + d: dict[str, Any] = {"instructions": "my custom prompt"} + _sanitize_response_dict( + d, + set(), + mock_original_request(**d), + ) assert d["instructions"] == "my custom prompt" def test_substituted_instructions_set_even_when_absent(self) -> None: """Test that placeholder is set even when instructions field is missing.""" d: dict[str, Any] = {"model": "m"} - _sanitize_response_dict(d, set(), instructions_substituted=True) + _sanitize_response_dict(d, set(), mock_original_request(instructions=None)) assert d["instructions"] == SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER - def test_no_error_when_instructions_absent_and_not_substituted(self) -> None: - """Test that missing instructions field with no substitution does not raise.""" - d: dict[str, Any] = {"model": "m"} - _sanitize_response_dict(d, set(), instructions_substituted=False) - assert "instructions" not in d - def test_strips_server_mcp_tools(self) -> None: """Test that server-deployed MCP tools are removed from tools array.""" d: dict[str, Any] = { @@ -1910,9 +1939,7 @@ def test_strips_server_mcp_tools(self) -> None: {"name": "client-tool"}, ] } - _sanitize_response_dict( - d, {"server-a", "server-b"}, instructions_substituted=False - ) + _sanitize_response_dict(d, {"server-a", "server-b"}, mock_original_request()) assert d["tools"] == [{"name": "client-tool"}] def test_preserves_client_tools(self) -> None: @@ -1923,13 +1950,13 @@ def test_preserves_client_tools(self) -> None: {"name": "client-tool"}, ] } - _sanitize_response_dict(d, {"server-a"}, instructions_substituted=False) + _sanitize_response_dict(d, {"server-a"}, mock_original_request()) assert d["tools"] == [{"name": "client-tool"}] def test_no_error_when_tools_absent(self) -> None: """Test that missing tools field does not raise.""" d: dict[str, Any] = {"model": "m"} - _sanitize_response_dict(d, {"server-a"}, instructions_substituted=False) + _sanitize_response_dict(d, {"server-a"}, mock_original_request()) assert "tools" not in d def test_empty_configured_mcp_labels_preserves_all_tools(self) -> None: @@ -1940,7 +1967,7 @@ def test_empty_configured_mcp_labels_preserves_all_tools(self) -> None: {"name": "client-tool"}, ] } - _sanitize_response_dict(d, set(), instructions_substituted=False) + _sanitize_response_dict(d, set(), mock_original_request()) assert len(d["tools"]) == 2 def test_strips_server_mcp_items_from_output(self) -> None: @@ -1961,7 +1988,9 @@ def test_strips_server_mcp_items_from_output(self) -> None: {"type": "mcp_call", "server_label": "okp", "id": "call-1"}, ], } - _sanitize_response_dict(d, {"okp"}, instructions_substituted=False) + _sanitize_response_dict( + d, {"okp"}, mock_original_request(instructions="client-provided") + ) assert len(d["output"]) == 1 assert d["output"][0]["type"] == "message" @@ -1974,13 +2003,13 @@ def test_preserves_non_server_mcp_output_items(self) -> None: {"type": "function_call", "name": "my_func"}, ], } - _sanitize_response_dict(d, {"okp"}, instructions_substituted=False) + _sanitize_response_dict(d, {"okp"}, mock_original_request()) assert len(d["output"]) == 3 def test_no_error_when_output_absent(self) -> None: """Test that missing output field does not raise.""" d: dict[str, Any] = {"model": "m"} - _sanitize_response_dict(d, {"okp"}, instructions_substituted=False) + _sanitize_response_dict(d, {"okp"}, mock_original_request()) assert "output" not in d def test_strips_provider_prefix_from_model_when_substituted(self) -> None: @@ -1988,9 +2017,7 @@ def test_strips_provider_prefix_from_model_when_substituted(self) -> None: d: dict[str, Any] = { "model": "google-vertex/publishers/google/models/gemini-2.5-flash" } - _sanitize_response_dict( - d, set(), instructions_substituted=False, model_substituted=True - ) + _sanitize_response_dict(d, set(), mock_original_request(model=None)) assert d["model"] == "gemini-2.5-flash" def test_preserves_client_model_when_not_substituted(self) -> None: @@ -1998,17 +2025,13 @@ def test_preserves_client_model_when_not_substituted(self) -> None: d: dict[str, Any] = { "model": "google-vertex/publishers/google/models/gemini-2.5-flash" } - _sanitize_response_dict( - d, set(), instructions_substituted=False, model_substituted=False - ) + _sanitize_response_dict(d, set(), mock_original_request(**d)) assert d["model"] == "google-vertex/publishers/google/models/gemini-2.5-flash" def test_model_without_slash_preserved(self) -> None: """Test that model names without provider prefix are left unchanged.""" d: dict[str, Any] = {"model": "gemini-2.5-flash"} - _sanitize_response_dict( - d, set(), instructions_substituted=False, model_substituted=True - ) + _sanitize_response_dict(d, set(), mock_original_request()) assert d["model"] == "gemini-2.5-flash" def test_all_fields_sanitized_together_with_substitution(self) -> None: @@ -2026,10 +2049,7 @@ def test_all_fields_sanitized_together_with_substitution(self) -> None: ], } _sanitize_response_dict( - d, - {"mcp-server"}, - instructions_substituted=True, - model_substituted=True, + d, {"mcp-server"}, mock_original_request(instructions=None, model=None) ) assert d["instructions"] == SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER assert d["model"] == "gemini" @@ -2054,8 +2074,10 @@ def test_all_fields_sanitized_together_without_substitution(self) -> None: _sanitize_response_dict( d, {"mcp-server"}, - instructions_substituted=False, - model_substituted=False, + mock_original_request( + instructions="client prompt", + model="provider/model1", + ), ) assert d["instructions"] == "client prompt" assert d["model"] == "google-vertex/publishers/google/models/gemini" @@ -2087,7 +2109,14 @@ async def test_non_streaming_sanitizes_mcp_output_and_model( mock_config.quota_limiters = minimal_config.quota_limiters mock_config.rag_id_mapping = {} - request = _request_with_model_and_conv("Hi", model="provider/model1") + original_request = ResponsesRequest(input="Hi") + updated_request = ResponsesRequest( + input="Hi", + model=MODEL, + instructions=SERVER_INSTRUCTIONS, + conversation=VALID_CONV_ID_NORMALIZED, + ) + mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) mock_moderation = mocker.Mock() mock_moderation.decision = "passed" @@ -2161,20 +2190,19 @@ async def test_non_streaming_sanitizes_mcp_output_and_model( response = await handle_non_streaming_response( client=mock_client, - request=request, + original_request=original_request, + updated_request=updated_request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), - instructions_substituted=True, - model_substituted=True, ) assert isinstance(response, ResponsesResponse) # Model provider prefix should be stripped when server-substituted assert response.model == "gemini-2.5-flash" - # Instructions should be replaced with placeholder + # Client omitted instructions: hide resolved server prompt assert response.instructions == SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER # MCP output items should be filtered out output_types = [item.type for item in response.output] @@ -2196,8 +2224,8 @@ def _make_streaming_completed_chunk(self, mocker: MockerFixture) -> Any: "type": "response.completed", "response": { "id": "r1", - "instructions": "secret server prompt", - "model": "google-vertex/publishers/google/models/gemini-2.5-flash", + "instructions": SERVER_INSTRUCTIONS, + "model": MODEL, "output": [ { "type": "mcp_list_tools", @@ -2234,7 +2262,13 @@ async def test_streaming_sanitizes_mcp_output_model_and_instructions( mock_config.quota_limiters = minimal_config.quota_limiters mock_config.rag_id_mapping = {} - request = _request_with_model_and_conv("Hi", model="provider/model1") + original_request = ResponsesRequest(input="Hi") + updated_request = ResponsesRequest( + input="Hi", + model=MODEL, + instructions=SERVER_INSTRUCTIONS, + conversation=VALID_CONV_ID_NORMALIZED, + ) mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) mock_moderation = mocker.Mock() mock_moderation.decision = "passed" @@ -2274,15 +2308,14 @@ async def mock_stream() -> Any: response = await handle_streaming_response( client=mock_client, - request=request, + original_request=original_request, + updated_request=updated_request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), filter_server_tools=False, - instructions_substituted=True, - model_substituted=True, ) collected: list[str] = [] async for part in response.body_iterator: @@ -2393,7 +2426,8 @@ async def mock_stream() -> Any: response = await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -2484,7 +2518,8 @@ async def mock_stream() -> Any: response = await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), diff --git a/tests/unit/app/endpoints/test_responses_splunk.py b/tests/unit/app/endpoints/test_responses_splunk.py index b161e6551..ce1588a75 100644 --- a/tests/unit/app/endpoints/test_responses_splunk.py +++ b/tests/unit/app/endpoints/test_responses_splunk.py @@ -258,7 +258,8 @@ async def test_non_streaming_shield_blocked( await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Bad input", started_at=datetime.now(UTC), @@ -334,7 +335,8 @@ async def test_non_streaming_error_fires_telemetry( with pytest.raises(HTTPException): await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hello", started_at=datetime.now(UTC), @@ -412,7 +414,8 @@ async def test_non_streaming_success( await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hello", started_at=datetime.now(UTC), @@ -468,7 +471,8 @@ async def test_streaming_shield_blocked( response = await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Bad", started_at=datetime.now(UTC), @@ -545,7 +549,8 @@ async def test_streaming_error_fires_telemetry( with pytest.raises(HTTPException): await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hello", started_at=datetime.now(UTC), @@ -625,7 +630,8 @@ async def mock_stream() -> Any: response = await handle_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -696,7 +702,8 @@ async def test_splunk_disabled_no_background_tasks( # background_tasks=None (the default) means Splunk is disabled await handle_non_streaming_response( client=mock_client, - request=request, + original_request=request, + updated_request=request, auth=MOCK_AUTH, input_text="Bad input", started_at=datetime.now(UTC),