diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index b4c31b017..61e7bd81e 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -170,9 +170,10 @@ async def query_endpoint_handler( # Moderation input is the raw user content (query + attachments) without injected RAG # context, to avoid false positives from retrieved document content. + endpoint_path = "/v1/query" moderation_input = prepare_input(query_request) moderation_result = await run_shield_moderation( - client, moderation_input, query_request.shield_ids + client, moderation_input, endpoint_path, query_request.shield_ids ) # Build RAG context from Inline RAG sources @@ -207,7 +208,9 @@ async def query_endpoint_handler( client = await update_azure_token(client) # Retrieve response using Responses API - turn_summary = await retrieve_response(client, responses_params, moderation_result) + turn_summary = await retrieve_response( + client, responses_params, moderation_result, endpoint_path + ) if moderation_result.decision == "passed": # Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript @@ -280,6 +283,7 @@ async def retrieve_response( client: AsyncLlamaStackClient, responses_params: ResponsesApiParams, moderation_result: ShieldModerationResult, + endpoint_path: str = "", ) -> TurnSummary: """ Retrieve response from LLMs and agents. @@ -332,5 +336,9 @@ async def retrieve_response( vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools) rag_id_mapping = configuration.rag_id_mapping return build_turn_summary( - response, responses_params.model, vector_store_ids, rag_id_mapping + response, + responses_params.model, + endpoint_path, + vector_store_ids, + rag_id_mapping, ) diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 230ce111e..1e49a050f 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -331,9 +331,11 @@ async def responses_endpoint_handler( ) attachments_text = extract_attachments_text(original_request.input) + endpoint_path = "/v1/responses" moderation_result = await run_shield_moderation( client, input_text + "\n\n" + attachments_text, + endpoint_path, original_request.shield_ids, ) @@ -388,6 +390,7 @@ async def responses_endpoint_handler( background_tasks=background_tasks, rh_identity_context=rh_identity_context, user_agent=_get_user_agent(request), + endpoint_path=endpoint_path, ) @@ -404,6 +407,7 @@ async def handle_streaming_response( background_tasks: Optional[BackgroundTasks] = None, rh_identity_context: tuple[str, str] = ("", ""), user_agent: Optional[str] = None, + endpoint_path: str = "", ) -> StreamingResponse: """Handle streaming response from Responses API. @@ -470,6 +474,7 @@ async def handle_streaming_response( turn_summary=turn_summary, inline_rag_context=inline_rag_context, filter_server_tools=filter_server_tools, + endpoint_path=endpoint_path, ) except RuntimeError as e: # library mode wraps 413 into runtime error if is_context_length_error(str(e)): @@ -798,6 +803,7 @@ async def response_generator( turn_summary: TurnSummary, inline_rag_context: RAGContext, filter_server_tools: bool = False, + endpoint_path: str = "", ) -> AsyncIterator[str]: """Generate SSE-formatted streaming response with LCORE-enriched events. @@ -810,6 +816,7 @@ async def response_generator( turn_summary: TurnSummary to populate during streaming inline_rag_context: Inline RAG context to be used for the response filter_server_tools: Whether to filter server-deployed MCP tool events from the stream + endpoint_path: API endpoint path used for metric labeling. Yields: SSE-formatted strings for streaming events, ending with [DONE] """ @@ -873,7 +880,7 @@ async def response_generator( # Extract and consume tokens if any were used turn_summary.token_usage = extract_token_usage( - latest_response_object.usage, api_params.model + latest_response_object.usage, api_params.model, endpoint_path ) consume_query_tokens( user_id=user_id, @@ -1010,6 +1017,7 @@ async def handle_non_streaming_response( background_tasks: Optional[BackgroundTasks] = None, rh_identity_context: tuple[str, str] = ("", ""), user_agent: Optional[str] = None, + endpoint_path: str = "", ) -> ResponsesResponse: """Handle non-streaming response from Responses API. @@ -1069,7 +1077,9 @@ async def handle_non_streaming_response( **api_params.model_dump(exclude_none=True) ), ) - token_usage = extract_token_usage(api_response.usage, api_params.model) + token_usage = extract_token_usage( + api_response.usage, api_params.model, endpoint_path + ) logger.info("Consuming tokens") consume_query_tokens( user_id=user_id, @@ -1152,6 +1162,7 @@ async def handle_non_streaming_response( turn_summary = build_turn_summary( api_response, api_params.model, + endpoint_path, vector_store_ids, configuration.rag_id_mapping, filter_server_tools=filter_server_tools, diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index 3c7c7a6e6..b0d19dfdb 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -241,6 +241,7 @@ async def retrieve_simple_response( instructions: str, tools: Optional[list[Any]] = None, model_id: Optional[str] = None, + endpoint_path: str = "/v1/infer", ) -> str: """Retrieve a simple response from the LLM for a stateless query. @@ -263,7 +264,7 @@ async def retrieve_simple_response( """ resolved_model_id = model_id or await _get_default_model_id() response = await _call_llm(question, instructions, tools, resolved_model_id) - extract_token_usage(response.usage, resolved_model_id) + extract_token_usage(response.usage, resolved_model_id, endpoint_path) return extract_text_from_response_items(response.output) @@ -366,12 +367,13 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position background_tasks.add_task(send_splunk_event, event, sourcetype) -async def _check_shield_moderation( +async def _check_shield_moderation( # pylint: disable=too-many-arguments,too-many-positional-arguments input_text: str, request_id: str, background_tasks: BackgroundTasks, infer_request: RlsapiV1InferRequest, request: Request, + endpoint_path: str, ) -> Optional[RlsapiV1InferResponse]: """Run shield moderation and return a refusal response if blocked. @@ -384,13 +386,14 @@ async def _check_shield_moderation( background_tasks: FastAPI background tasks for async Splunk event sending. infer_request: The original inference request (for Splunk event context). request: The FastAPI request object (for Splunk event context). + endpoint_path: The API endpoint path for metric labeling. Returns: An RlsapiV1InferResponse containing the refusal message if the input was blocked, or None if moderation passed. """ client = AsyncLlamaStackClientHolder().get_client() - moderation_result = await run_shield_moderation(client, input_text) + moderation_result = await run_shield_moderation(client, input_text, endpoint_path) if moderation_result.decision != "blocked": return None @@ -432,6 +435,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po start_time: float, model: str, provider: str, + endpoint_path: str, ) -> float: """Record metrics and queue Splunk event for an inference failure. @@ -442,12 +446,15 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po request_id: Unique identifier for the request. error: The exception that caused the failure. start_time: Monotonic clock time when inference started. + model: The model name. + provider: The provider name. + endpoint_path: The API endpoint path for metric labeling. Returns: The total inference time in seconds. """ inference_time = time.monotonic() - start_time - recording.record_llm_failure(provider, model) + recording.record_llm_failure(provider, model, endpoint_path) _queue_splunk_event( background_tasks, infer_request, @@ -530,6 +537,7 @@ def _build_infer_response( request_id: str, response: Optional[OpenAIResponseObject], model_id: str, + endpoint_path: str, ) -> RlsapiV1InferResponse: """Build the final inference response, with optional verbose metadata. @@ -549,7 +557,11 @@ def _build_infer_response( """ if response is not None: turn_summary = build_turn_summary( - response, model_id, vector_store_ids=None, rag_id_mapping=None + response, + model_id, + endpoint_path, + vector_store_ids=None, + rag_id_mapping=None, ) return RlsapiV1InferResponse( data=RlsapiV1InferData( @@ -673,12 +685,19 @@ async def infer_endpoint( # pylint: disable=R0914 "Request %s: Combined input source length: %d", request_id, len(input_source) ) + endpoint_path = "/v1/infer" + # Run shield moderation on user input before inference. # Uses all configured shields; no-op when no shields are registered. # Runs before model/tool discovery so blocked requests short-circuit # without incurring external I/O. blocked_response = await _check_shield_moderation( - input_source, request_id, background_tasks, infer_request, request + input_source, + request_id, + background_tasks, + infer_request, + request, + endpoint_path, ) if blocked_response is not None: return blocked_response @@ -700,11 +719,11 @@ async def infer_endpoint( # pylint: disable=R0914 model_id=model_id, ) response_text = extract_text_from_response_items(response.output) - token_usage = extract_token_usage(response.usage, model_id) + token_usage = extract_token_usage(response.usage, model_id, endpoint_path) inference_time = time.monotonic() - start_time except _INFER_HANDLED_EXCEPTIONS as error: if response is not None: - extract_token_usage(response.usage, model_id) # type: ignore[arg-type] + extract_token_usage(response.usage, model_id, endpoint_path) # type: ignore[arg-type] _record_inference_failure( background_tasks, infer_request, @@ -714,6 +733,7 @@ async def infer_endpoint( # pylint: disable=R0914 start_time, model, provider, + endpoint_path, ) mapped_error = _map_inference_error_to_http_exception( error, @@ -755,4 +775,5 @@ async def infer_endpoint( # pylint: disable=R0914 request_id, response if verbose_enabled else None, model_id, + endpoint_path, ) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index f11887d2f..8bc6134ec 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -226,8 +226,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals # Moderation input is the raw user content (query + attachments) without injected RAG # context, to avoid false positives from retrieved document content. moderation_input = prepare_input(query_request) + endpoint_path = "/v1/streaming_query" moderation_result = await run_shield_moderation( - client, moderation_input, query_request.shield_ids + client, moderation_input, endpoint_path, query_request.shield_ids ) # Build RAG context from Inline RAG sources @@ -283,11 +284,12 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals provider_id, model_id = extract_provider_and_model_from_model_id( responses_params.model ) - recording.record_llm_call(provider_id, model_id) + recording.record_llm_call(provider_id, model_id, endpoint_path) generator, turn_summary = await retrieve_response_generator( responses_params=responses_params, context=context, + endpoint_path=endpoint_path, ) # Combine inline RAG results (BYOK + Solr) with tool-based results @@ -316,6 +318,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals async def retrieve_response_generator( responses_params: ResponsesApiParams, context: ResponseGeneratorContext, + endpoint_path: str, ) -> tuple[AsyncIterator[str], TurnSummary]: """ Retrieve the appropriate response generator. @@ -327,6 +330,7 @@ async def retrieve_response_generator( Args: responses_params: The Responses API parameters context: The response generator context + endpoint_path: API endpoint path used for metric labeling. Returns: tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary @@ -360,6 +364,7 @@ async def retrieve_response_generator( response, context, turn_summary, + endpoint_path, ), turn_summary, ) @@ -685,6 +690,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat turn_response: AsyncIterator[OpenAIResponseObjectStream], context: ResponseGeneratorContext, turn_summary: TurnSummary, + endpoint_path: str, ) -> AsyncIterator[str]: """Generate SSE formatted streaming response. @@ -696,6 +702,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat turn_response: The streaming response from Llama Stack context: The response generator context turn_summary: TurnSummary to populate during streaming + endpoint_path: API endpoint path used for metric labeling. Yields: SSE-formatted strings for tokens, tool calls, tool results, @@ -862,7 +869,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat return turn_summary.token_usage = extract_token_usage( - latest_response_object.usage, context.model_id + latest_response_object.usage, context.model_id, endpoint_path ) # Parse tool-based referenced documents from the final response object tool_rag_docs = parse_referenced_documents( diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py index 49c6767a0..893e634db 100644 --- a/src/metrics/__init__.py +++ b/src/metrics/__init__.py @@ -29,25 +29,29 @@ # Metric that counts how many LLM calls were made for each provider + model llm_calls_total = Counter( - "ls_llm_calls_total", "LLM calls counter", ["provider", "model"] + "ls_llm_calls_total", "LLM calls counter", ["provider", "model", "endpoint"] ) # Metric that counts how many LLM calls failed llm_calls_failures_total = Counter( - "ls_llm_calls_failures_total", "LLM calls failures", ["provider", "model"] + "ls_llm_calls_failures_total", + "LLM calls failures", + ["provider", "model", "endpoint"], ) # Metric that counts how many LLM calls had validation errors llm_calls_validation_errors_total = Counter( - "ls_llm_validation_errors_total", "LLM validation errors" + "ls_llm_validation_errors_total", "LLM validation errors", ["endpoint"] ) # Metric that counts how many tokens were sent to LLMs llm_token_sent_total = Counter( - "ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"] + "ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model", "endpoint"] ) # Metric that counts how many tokens were received from LLMs llm_token_received_total = Counter( - "ls_llm_token_received_total", "LLM tokens received", ["provider", "model"] + "ls_llm_token_received_total", + "LLM tokens received", + ["provider", "model", "endpoint"], ) diff --git a/src/metrics/recording.py b/src/metrics/recording.py index abea2270d..e7da276cc 100644 --- a/src/metrics/recording.py +++ b/src/metrics/recording.py @@ -44,36 +44,42 @@ def record_rest_api_call(path: str, status_code: int) -> None: logger.warning("Failed to update REST API call metric", exc_info=True) -def record_llm_call(provider: str, model: str) -> None: +def record_llm_call(provider: str, model: str, endpoint_path: str) -> None: """Record one LLM call for a provider and model. Args: provider: LLM provider identifier. model: LLM model identifier without the provider prefix. + endpoint_path: The API endpoint path for metric labeling. """ try: - metrics.llm_calls_total.labels(provider, model).inc() + metrics.llm_calls_total.labels(provider, model, endpoint_path).inc() except (AttributeError, TypeError, ValueError): logger.warning("Failed to update LLM call metric", exc_info=True) -def record_llm_failure(provider: str, model: str) -> None: +def record_llm_failure(provider: str, model: str, endpoint_path: str) -> None: """Record one failed LLM call for a provider and model. Args: provider: LLM provider identifier. model: LLM model identifier without the provider prefix. + endpoint_path: The API endpoint path for metric labeling. """ try: - metrics.llm_calls_failures_total.labels(provider, model).inc() + metrics.llm_calls_failures_total.labels(provider, model, endpoint_path).inc() except (AttributeError, TypeError, ValueError): logger.warning("Failed to update LLM failure metric", exc_info=True) -def record_llm_validation_error() -> None: - """Record one LLM validation error, such as a shield violation.""" +def record_llm_validation_error(endpoint_path: str = "") -> None: + """Record one LLM validation error, such as a shield violation. + + Args: + endpoint_path: The API endpoint path for metric labeling. + """ try: - metrics.llm_calls_validation_errors_total.inc() + metrics.llm_calls_validation_errors_total.labels(endpoint_path).inc() except (AttributeError, TypeError, ValueError): logger.warning("Failed to update LLM validation error metric", exc_info=True) @@ -83,6 +89,7 @@ def record_llm_token_usage( model: str, input_tokens: int, output_tokens: int, + endpoint_path: str, ) -> None: """Record LLM token usage for a provider and model. @@ -91,9 +98,14 @@ def record_llm_token_usage( model: LLM model identifier without the provider prefix. input_tokens: Number of tokens sent to the LLM. output_tokens: Number of tokens received from the LLM. + endpoint_path: The API endpoint path for metric labeling. """ try: - metrics.llm_token_sent_total.labels(provider, model).inc(input_tokens) - metrics.llm_token_received_total.labels(provider, model).inc(output_tokens) + metrics.llm_token_sent_total.labels(provider, model, endpoint_path).inc( + input_tokens + ) + metrics.llm_token_received_total.labels(provider, model, endpoint_path).inc( + output_tokens + ) except (AttributeError, TypeError, ValueError): logger.warning("Failed to update token metrics", exc_info=True) diff --git a/src/utils/responses.py b/src/utils/responses.py index 869a3a959..4e2863782 100644 --- a/src/utils/responses.py +++ b/src/utils/responses.py @@ -907,12 +907,15 @@ def parse_rag_chunks( return rag_chunks -def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCounter: +def extract_token_usage( + usage: Optional[ResponseUsage], model: str, endpoint_path: str +) -> TokenCounter: """Extract token usage from Responses API usage object and update metrics. Args: usage: ResponseUsage from the Responses API response, or None if not available. model: The model identifier in "provider/model" format + endpoint_path: The API endpoint path for metric labeling. Returns: TokenCounter with input_tokens and output_tokens @@ -922,7 +925,7 @@ def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCoun logger.debug( "No usage information in Responses API response, token counts will be 0" ) - recording.record_llm_call(provider_id, model_id) + recording.record_llm_call(provider_id, model_id, endpoint_path) return TokenCounter(llm_calls=1) token_counter = TokenCounter( @@ -940,8 +943,9 @@ def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCoun model_id, token_counter.input_tokens, token_counter.output_tokens, + endpoint_path, ) - recording.record_llm_call(provider_id, model_id) + recording.record_llm_call(provider_id, model_id, endpoint_path) return token_counter @@ -1432,9 +1436,10 @@ def is_server_deployed_output(output_item: ResponseOutput) -> bool: return True -def build_turn_summary( +def build_turn_summary( # pylint: disable=too-many-arguments,too-many-positional-arguments response: Optional[OpenAIResponseObject], model: str, + endpoint_path: str, vector_store_ids: Optional[list[str]] = None, rag_id_mapping: Optional[dict[str, str]] = None, filter_server_tools: bool = False, @@ -1444,6 +1449,7 @@ def build_turn_summary( Args: response: The ResponseObject to build the turn summary from, or None model: The model identifier in "provider/model" format + endpoint_path: The API endpoint path for metric labeling. vector_store_ids: Vector store IDs used in the query for source resolution. rag_id_mapping: Mapping from vector_db_id to user-facing rag_id. filter_server_tools: When True, skip client-provided tool output items @@ -1478,7 +1484,7 @@ def build_turn_summary( summary.tool_results.append(tool_result) summary.rag_chunks = parse_rag_chunks(response, vector_store_ids, rag_id_mapping) - summary.token_usage = extract_token_usage(response.usage, model) + summary.token_usage = extract_token_usage(response.usage, model, endpoint_path) return summary diff --git a/src/utils/shields.py b/src/utils/shields.py index 6d6089139..3104c0183 100644 --- a/src/utils/shields.py +++ b/src/utils/shields.py @@ -122,6 +122,7 @@ def validate_shield_ids_override( async def run_shield_moderation( client: AsyncLlamaStackClient, input_text: str, + endpoint_path: str, shield_ids: Optional[list[str]] = None, ) -> ShieldModerationResult: """ @@ -134,6 +135,7 @@ async def run_shield_moderation( ---------- client: The Llama Stack client. input_text: The text to moderate. + endpoint_path: The API endpoint path for metric labeling. shield_ids: Optional list of shield IDs to use. If None, uses all shields. If empty list, skips all shields. @@ -178,7 +180,7 @@ async def run_shield_moderation( if moderation_result.results and moderation_result.results[0].flagged: result = moderation_result.results[0] - recording.record_llm_validation_error() + recording.record_llm_validation_error(endpoint_path) logger.warning( "Shield '%s' flagged content: categories=%s", shield.identifier, diff --git a/tests/unit/app/endpoints/test_metrics.py b/tests/unit/app/endpoints/test_metrics.py index d2e1dbbab..5cebc0529 100644 --- a/tests/unit/app/endpoints/test_metrics.py +++ b/tests/unit/app/endpoints/test_metrics.py @@ -43,6 +43,5 @@ async def test_metrics_endpoint(mocker: MockerFixture) -> None: assert "# TYPE ls_llm_calls_total counter" in response_body assert "# TYPE ls_llm_calls_failures_total counter" in response_body assert "# TYPE ls_llm_validation_errors_total counter" in response_body - assert "# TYPE ls_llm_validation_errors_created gauge" in response_body assert "# TYPE ls_llm_token_sent_total counter" in response_body assert "# TYPE ls_llm_token_received_total counter" in response_body diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index f6c2bc958..aed2443a5 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -855,7 +855,7 @@ async def mock_response_generator( ) generator, turn_summary = await retrieve_response_generator( - mock_responses_params, mock_context + mock_responses_params, mock_context, endpoint_path="" ) assert isinstance(turn_summary, TurnSummary) @@ -894,7 +894,7 @@ async def test_retrieve_response_generator_shield_blocked( ) _generator, turn_summary = await retrieve_response_generator( - mock_responses_params, mock_context + mock_responses_params, mock_context, endpoint_path="" ) assert isinstance(turn_summary, TurnSummary) @@ -949,7 +949,9 @@ async def test_retrieve_response_generator_connection_error( ) with pytest.raises(HTTPException) as exc_info: - await retrieve_response_generator(mock_responses_params, mock_context) + await retrieve_response_generator( + mock_responses_params, mock_context, endpoint_path="" + ) assert exc_info.value.status_code == 503 @@ -999,7 +1001,9 @@ async def test_retrieve_response_generator_api_status_error( ) with pytest.raises(HTTPException) as exc_info: - await retrieve_response_generator(mock_responses_params, mock_context) + await retrieve_response_generator( + mock_responses_params, mock_context, endpoint_path="" + ) assert exc_info.value.status_code == 500 @@ -1046,7 +1050,9 @@ async def test_retrieve_response_generator_runtime_error_context_length( ) with pytest.raises(HTTPException) as exc_info: - await retrieve_response_generator(mock_responses_params, mock_context) + await retrieve_response_generator( + mock_responses_params, mock_context, endpoint_path="" + ) assert exc_info.value.status_code == 413 @@ -1083,7 +1089,9 @@ async def test_retrieve_response_generator_runtime_error_other( ) with pytest.raises(RuntimeError): - await retrieve_response_generator(mock_responses_params, mock_context) + await retrieve_response_generator( + mock_responses_params, mock_context, endpoint_path="" + ) class TestGenerateResponse: @@ -1870,7 +1878,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -1900,7 +1908,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -1938,7 +1946,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: ) async for _ in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): pass @@ -1980,7 +1988,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2029,7 +2037,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2080,7 +2088,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2123,7 +2131,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: ) async for _ in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): pass @@ -2172,7 +2180,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2218,7 +2226,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2263,7 +2271,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2306,7 +2314,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2350,7 +2358,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2392,7 +2400,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2445,7 +2453,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: ) async for _ in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): pass @@ -2572,7 +2580,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2635,7 +2643,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2728,7 +2736,7 @@ def build_mcp_tool_call_side_effect( result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) @@ -2798,7 +2806,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: result = [] async for item in response_generator( - mock_turn_response(), mock_context, mock_turn_summary + mock_turn_response(), mock_context, mock_turn_summary, endpoint_path="" ): result.append(item) diff --git a/tests/unit/metrics/test_recording.py b/tests/unit/metrics/test_recording.py index d132a8293..8012ac311 100644 --- a/tests/unit/metrics/test_recording.py +++ b/tests/unit/metrics/test_recording.py @@ -61,9 +61,9 @@ def test_record_llm_call_records_counter(mocker: MockerFixture) -> None: """Test that LLM call recording increments the provider/model counter.""" mock_metric = mocker.patch("metrics.recording.metrics.llm_calls_total") - recording.record_llm_call("provider1", "model1") + recording.record_llm_call("provider1", "model1", "/test-endpoint") - mock_metric.labels.assert_called_once_with("provider1", "model1") + mock_metric.labels.assert_called_once_with("provider1", "model1", "/test-endpoint") mock_metric.labels.return_value.inc.assert_called_once() @@ -73,7 +73,7 @@ def test_record_llm_call_logs_metric_errors(mocker: MockerFixture) -> None: mock_metric.labels.return_value.inc.side_effect = AttributeError("missing") mock_logger = mocker.patch("metrics.recording.logger") - recording.record_llm_call("provider1", "model1") + recording.record_llm_call("provider1", "model1", "/test-endpoint") mock_logger.warning.assert_called_once_with( "Failed to update LLM call metric", exc_info=True @@ -84,9 +84,9 @@ def test_record_llm_failure_records_counter(mocker: MockerFixture) -> None: """Test that LLM failure recording increments the provider/model counter.""" mock_metric = mocker.patch("metrics.recording.metrics.llm_calls_failures_total") - recording.record_llm_failure("provider1", "model1") + recording.record_llm_failure("provider1", "model1", "/test-endpoint") - mock_metric.labels.assert_called_once_with("provider1", "model1") + mock_metric.labels.assert_called_once_with("provider1", "model1", "/test-endpoint") mock_metric.labels.return_value.inc.assert_called_once() @@ -96,7 +96,7 @@ def test_record_llm_failure_logs_metric_errors(mocker: MockerFixture) -> None: mock_metric.labels.return_value.inc.side_effect = TypeError("bad") mock_logger = mocker.patch("metrics.recording.logger") - recording.record_llm_failure("provider1", "model1") + recording.record_llm_failure("provider1", "model1", "/test-endpoint") mock_logger.warning.assert_called_once_with( "Failed to update LLM failure metric", exc_info=True @@ -109,9 +109,10 @@ def test_record_llm_validation_error_records_counter(mocker: MockerFixture) -> N "metrics.recording.metrics.llm_calls_validation_errors_total" ) - recording.record_llm_validation_error() + recording.record_llm_validation_error("/test-endpoint") - mock_metric.inc.assert_called_once() + mock_metric.labels.assert_called_once_with("/test-endpoint") + mock_metric.labels.return_value.inc.assert_called_once() def test_record_llm_validation_error_logs_metric_errors( @@ -121,10 +122,10 @@ def test_record_llm_validation_error_logs_metric_errors( mock_metric = mocker.patch( "metrics.recording.metrics.llm_calls_validation_errors_total" ) - mock_metric.inc.side_effect = ValueError("bad") + mock_metric.labels.return_value.inc.side_effect = ValueError("bad") mock_logger = mocker.patch("metrics.recording.logger") - recording.record_llm_validation_error() + recording.record_llm_validation_error("/test-endpoint") mock_logger.warning.assert_called_once_with( "Failed to update LLM validation error metric", exc_info=True @@ -136,11 +137,13 @@ def test_record_llm_token_usage_records_counters(mocker: MockerFixture) -> None: mock_sent = mocker.patch("metrics.recording.metrics.llm_token_sent_total") mock_received = mocker.patch("metrics.recording.metrics.llm_token_received_total") - recording.record_llm_token_usage("provider1", "model1", 100, 50) + recording.record_llm_token_usage("provider1", "model1", 100, 50, "/test-endpoint") - mock_sent.labels.assert_called_once_with("provider1", "model1") + mock_sent.labels.assert_called_once_with("provider1", "model1", "/test-endpoint") mock_sent.labels.return_value.inc.assert_called_once_with(100) - mock_received.labels.assert_called_once_with("provider1", "model1") + mock_received.labels.assert_called_once_with( + "provider1", "model1", "/test-endpoint" + ) mock_received.labels.return_value.inc.assert_called_once_with(50) @@ -151,7 +154,7 @@ def test_record_llm_token_usage_logs_metric_errors(mocker: MockerFixture) -> Non mocker.patch("metrics.recording.metrics.llm_token_received_total") mock_logger = mocker.patch("metrics.recording.logger") - recording.record_llm_token_usage("provider1", "model1", 100, 50) + recording.record_llm_token_usage("provider1", "model1", 100, 50, "/test-endpoint") mock_logger.warning.assert_called_once_with( "Failed to update token metrics", exc_info=True diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index e6e49a16c..982c04f83 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -2247,14 +2247,14 @@ def test_extract_token_usage_with_usage_object( ) mock_llm_call = mocker.patch("utils.responses.recording.record_llm_call") - result = extract_token_usage(mock_usage, "provider1/model1") + result = extract_token_usage(mock_usage, "provider1/model1", "/test-endpoint") assert result.input_tokens == input_tokens assert result.output_tokens == output_tokens assert result.llm_calls == 1 mock_token_usage.assert_called_once_with( - "provider1", "model1", input_tokens, output_tokens + "provider1", "model1", input_tokens, output_tokens, "/test-endpoint" ) - mock_llm_call.assert_called_once_with("provider1", "model1") + mock_llm_call.assert_called_once_with("provider1", "model1", "/test-endpoint") def test_extract_token_usage_no_usage(self, mocker: MockerFixture) -> None: """Test extracting token usage when usage is None.""" @@ -2264,11 +2264,11 @@ def test_extract_token_usage_no_usage(self, mocker: MockerFixture) -> None: ) mock_llm_call = mocker.patch("utils.responses.recording.record_llm_call") - result = extract_token_usage(None, "provider1/model1") + result = extract_token_usage(None, "provider1/model1", "/test-endpoint") assert result.input_tokens == 0 assert result.output_tokens == 0 assert result.llm_calls == 1 - mock_llm_call.assert_called_once_with("provider1", "model1") + mock_llm_call.assert_called_once_with("provider1", "model1", "/test-endpoint") def test_extract_token_usage_zero_tokens(self, mocker: MockerFixture) -> None: """Test extracting token usage when tokens are 0.""" @@ -2285,11 +2285,13 @@ def test_extract_token_usage_zero_tokens(self, mocker: MockerFixture) -> None: ) mock_llm_call = mocker.patch("utils.responses.recording.record_llm_call") - result = extract_token_usage(mock_usage, "provider1/model1") + result = extract_token_usage(mock_usage, "provider1/model1", "/test-endpoint") assert result.input_tokens == 0 assert result.output_tokens == 0 - mock_token_usage.assert_called_once_with("provider1", "model1", 0, 0) - mock_llm_call.assert_called_once_with("provider1", "model1") + mock_token_usage.assert_called_once_with( + "provider1", "model1", 0, 0, "/test-endpoint" + ) + mock_llm_call.assert_called_once_with("provider1", "model1", "/test-endpoint") def test_extract_token_usage_none_response(self, mocker: MockerFixture) -> None: """Test extracting token usage with None response.""" @@ -2299,10 +2301,10 @@ def test_extract_token_usage_none_response(self, mocker: MockerFixture) -> None: ) mock_llm_call = mocker.patch("utils.responses.recording.record_llm_call") - result = extract_token_usage(None, "provider1/model1") + result = extract_token_usage(None, "provider1/model1", "/test-endpoint") assert result.input_tokens == 0 assert result.output_tokens == 0 - mock_llm_call.assert_called_once_with("provider1", "model1") + mock_llm_call.assert_called_once_with("provider1", "model1", "/test-endpoint") class TestBuildToolCallSummary: diff --git a/tests/unit/utils/test_shields.py b/tests/unit/utils/test_shields.py index b7e73b2c1..b11562704 100644 --- a/tests/unit/utils/test_shields.py +++ b/tests/unit/utils/test_shields.py @@ -118,7 +118,9 @@ async def test_returns_not_blocked_when_no_shields( mock_client.shields.list = mocker.AsyncMock(return_value=[]) mock_client.models.list = mocker.AsyncMock(return_value=[]) - result = await run_shield_moderation(mock_client, "test input") + result = await run_shield_moderation( + mock_client, "test input", "/test-endpoint" + ) assert result.decision == "passed" @@ -147,7 +149,9 @@ async def test_returns_not_blocked_when_moderation_passes( return_value=moderation_result ) - result = await run_shield_moderation(mock_client, "safe input") + result = await run_shield_moderation( + mock_client, "safe input", "/test-endpoint" + ) assert result.decision == "passed" mock_client.moderations.create.assert_called_once_with( @@ -187,11 +191,13 @@ async def test_returns_blocked_when_content_flagged( return_value=moderation_result ) - result = await run_shield_moderation(mock_client, "violent content") + result = await run_shield_moderation( + mock_client, "violent content", "/test-endpoint" + ) assert result.decision == "blocked" assert result.message == "Content blocked for violence" - mock_record_error.assert_called_once() + mock_record_error.assert_called_once_with("/test-endpoint") @pytest.mark.asyncio async def test_returns_blocked_with_default_message_when_no_user_message( @@ -226,7 +232,9 @@ async def test_returns_blocked_with_default_message_when_no_user_message( return_value=moderation_result ) - result = await run_shield_moderation(mock_client, "spam content") + result = await run_shield_moderation( + mock_client, "spam content", "/test-endpoint" + ) assert result.decision == "blocked" assert result.message == DEFAULT_VIOLATION_MESSAGE @@ -256,7 +264,9 @@ async def test_skips_model_check_for_non_llama_guard_shields( return_value=moderation_result ) - result = await run_shield_moderation(mock_client, "test input") + result = await run_shield_moderation( + mock_client, "test input", "/test-endpoint" + ) assert result.decision == "passed" mock_client.moderations.create.assert_called_once_with( @@ -283,7 +293,7 @@ async def test_raises_http_exception_when_shield_model_not_found( mock_client.models.list = mocker.AsyncMock(return_value=[model]) with pytest.raises(HTTPException) as exc_info: - await run_shield_moderation(mock_client, "test input") + await run_shield_moderation(mock_client, "test input", "/test-endpoint") assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND assert "missing-model" in exc_info.value.detail["cause"] # type: ignore @@ -305,7 +315,7 @@ async def test_raises_http_exception_when_shield_has_no_provider_resource_id( mock_client.models.list = mocker.AsyncMock(return_value=[]) with pytest.raises(HTTPException) as exc_info: - await run_shield_moderation(mock_client, "test input") + await run_shield_moderation(mock_client, "test input", "/test-endpoint") assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND @@ -320,7 +330,9 @@ async def test_shield_ids_empty_list_runs_no_shields_returns_passed( mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) mock_client.models.list = mocker.AsyncMock(return_value=[]) - result = await run_shield_moderation(mock_client, "test input", shield_ids=[]) + result = await run_shield_moderation( + mock_client, "test input", "/test-endpoint", shield_ids=[] + ) assert result.decision == "passed" @@ -336,7 +348,7 @@ async def test_shield_ids_raises_404_when_no_shields_found( with pytest.raises(HTTPException) as exc_info: await run_shield_moderation( - mock_client, "test input", shield_ids=["typo-shield"] + mock_client, "test input", "/test-endpoint", shield_ids=["typo-shield"] ) assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND @@ -369,7 +381,7 @@ async def test_shield_ids_filters_to_specific_shield( ) result = await run_shield_moderation( - mock_client, "test input", shield_ids=["shield-1"] + mock_client, "test input", "/test-endpoint", shield_ids=["shield-1"] ) assert result.decision == "passed"