Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,10 @@ async def query_endpoint_handler(

# Moderation input is the raw user content (query + attachments) without injected RAG
# context, to avoid false positives from retrieved document content.
endpoint_path = "/v1/query"
moderation_input = prepare_input(query_request)
moderation_result = await run_shield_moderation(
client, moderation_input, query_request.shield_ids
client, moderation_input, endpoint_path, query_request.shield_ids
)

# Build RAG context from Inline RAG sources
Expand Down Expand Up @@ -207,7 +208,9 @@ async def query_endpoint_handler(
client = await update_azure_token(client)

# Retrieve response using Responses API
turn_summary = await retrieve_response(client, responses_params, moderation_result)
turn_summary = await retrieve_response(
client, responses_params, moderation_result, endpoint_path
)

if moderation_result.decision == "passed":
# Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript
Expand Down Expand Up @@ -280,6 +283,7 @@ async def retrieve_response(
client: AsyncLlamaStackClient,
responses_params: ResponsesApiParams,
moderation_result: ShieldModerationResult,
endpoint_path: str = "",
) -> TurnSummary:
"""
Retrieve response from LLMs and agents.
Expand Down Expand Up @@ -332,5 +336,9 @@ async def retrieve_response(
vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
rag_id_mapping = configuration.rag_id_mapping
return build_turn_summary(
response, responses_params.model, vector_store_ids, rag_id_mapping
response,
responses_params.model,
endpoint_path,
vector_store_ids,
rag_id_mapping,
)
15 changes: 13 additions & 2 deletions src/app/endpoints/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,11 @@ async def responses_endpoint_handler(
)
attachments_text = extract_attachments_text(original_request.input)

endpoint_path = "/v1/responses"
moderation_result = await run_shield_moderation(
client,
input_text + "\n\n" + attachments_text,
endpoint_path,
original_request.shield_ids,
)

Expand Down Expand Up @@ -388,6 +390,7 @@ async def responses_endpoint_handler(
background_tasks=background_tasks,
rh_identity_context=rh_identity_context,
user_agent=_get_user_agent(request),
endpoint_path=endpoint_path,
)


Expand All @@ -404,6 +407,7 @@ async def handle_streaming_response(
background_tasks: Optional[BackgroundTasks] = None,
rh_identity_context: tuple[str, str] = ("", ""),
user_agent: Optional[str] = None,
endpoint_path: str = "",
) -> StreamingResponse:
"""Handle streaming response from Responses API.

Expand Down Expand Up @@ -470,6 +474,7 @@ async def handle_streaming_response(
turn_summary=turn_summary,
inline_rag_context=inline_rag_context,
filter_server_tools=filter_server_tools,
endpoint_path=endpoint_path,
)
except RuntimeError as e: # library mode wraps 413 into runtime error
if is_context_length_error(str(e)):
Expand Down Expand Up @@ -798,6 +803,7 @@ async def response_generator(
turn_summary: TurnSummary,
inline_rag_context: RAGContext,
filter_server_tools: bool = False,
endpoint_path: str = "",
) -> AsyncIterator[str]:
"""Generate SSE-formatted streaming response with LCORE-enriched events.

Expand All @@ -810,6 +816,7 @@ async def response_generator(
turn_summary: TurnSummary to populate during streaming
inline_rag_context: Inline RAG context to be used for the response
filter_server_tools: Whether to filter server-deployed MCP tool events from the stream
endpoint_path: API endpoint path used for metric labeling.

Yields:
SSE-formatted strings for streaming events, ending with [DONE]
"""
Expand Down Expand Up @@ -873,7 +880,7 @@ async def response_generator(

# Extract and consume tokens if any were used
turn_summary.token_usage = extract_token_usage(
latest_response_object.usage, api_params.model
latest_response_object.usage, api_params.model, endpoint_path
)
consume_query_tokens(
user_id=user_id,
Expand Down Expand Up @@ -1010,6 +1017,7 @@ async def handle_non_streaming_response(
background_tasks: Optional[BackgroundTasks] = None,
rh_identity_context: tuple[str, str] = ("", ""),
user_agent: Optional[str] = None,
endpoint_path: str = "",
) -> ResponsesResponse:
"""Handle non-streaming response from Responses API.

Expand Down Expand Up @@ -1069,7 +1077,9 @@ async def handle_non_streaming_response(
**api_params.model_dump(exclude_none=True)
),
)
token_usage = extract_token_usage(api_response.usage, api_params.model)
token_usage = extract_token_usage(
api_response.usage, api_params.model, endpoint_path
)
logger.info("Consuming tokens")
consume_query_tokens(
user_id=user_id,
Expand Down Expand Up @@ -1152,6 +1162,7 @@ async def handle_non_streaming_response(
turn_summary = build_turn_summary(
api_response,
api_params.model,
endpoint_path,
vector_store_ids,
configuration.rag_id_mapping,
filter_server_tools=filter_server_tools,
Expand Down
37 changes: 29 additions & 8 deletions src/app/endpoints/rlsapi_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ async def retrieve_simple_response(
instructions: str,
tools: Optional[list[Any]] = None,
model_id: Optional[str] = None,
endpoint_path: str = "/v1/infer",
) -> str:
"""Retrieve a simple response from the LLM for a stateless query.

Expand All @@ -263,7 +264,7 @@ async def retrieve_simple_response(
"""
resolved_model_id = model_id or await _get_default_model_id()
response = await _call_llm(question, instructions, tools, resolved_model_id)
extract_token_usage(response.usage, resolved_model_id)
extract_token_usage(response.usage, resolved_model_id, endpoint_path)
return extract_text_from_response_items(response.output)


Expand Down Expand Up @@ -366,12 +367,13 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position
background_tasks.add_task(send_splunk_event, event, sourcetype)


async def _check_shield_moderation(
async def _check_shield_moderation( # pylint: disable=too-many-arguments,too-many-positional-arguments
input_text: str,
request_id: str,
background_tasks: BackgroundTasks,
infer_request: RlsapiV1InferRequest,
request: Request,
endpoint_path: str,
) -> Optional[RlsapiV1InferResponse]:
"""Run shield moderation and return a refusal response if blocked.

Expand All @@ -384,13 +386,14 @@ async def _check_shield_moderation(
background_tasks: FastAPI background tasks for async Splunk event sending.
infer_request: The original inference request (for Splunk event context).
request: The FastAPI request object (for Splunk event context).
endpoint_path: The API endpoint path for metric labeling.

Returns:
An RlsapiV1InferResponse containing the refusal message if the input
was blocked, or None if moderation passed.
"""
client = AsyncLlamaStackClientHolder().get_client()
moderation_result = await run_shield_moderation(client, input_text)
moderation_result = await run_shield_moderation(client, input_text, endpoint_path)

if moderation_result.decision != "blocked":
return None
Expand Down Expand Up @@ -432,6 +435,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
start_time: float,
model: str,
provider: str,
endpoint_path: str,
) -> float:
"""Record metrics and queue Splunk event for an inference failure.

Expand All @@ -442,12 +446,15 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
request_id: Unique identifier for the request.
error: The exception that caused the failure.
start_time: Monotonic clock time when inference started.
model: The model name.
provider: The provider name.
endpoint_path: The API endpoint path for metric labeling.

Returns:
The total inference time in seconds.
"""
inference_time = time.monotonic() - start_time
recording.record_llm_failure(provider, model)
recording.record_llm_failure(provider, model, endpoint_path)
_queue_splunk_event(
background_tasks,
infer_request,
Expand Down Expand Up @@ -530,6 +537,7 @@ def _build_infer_response(
request_id: str,
response: Optional[OpenAIResponseObject],
model_id: str,
endpoint_path: str,
) -> RlsapiV1InferResponse:
"""Build the final inference response, with optional verbose metadata.

Expand All @@ -549,7 +557,11 @@ def _build_infer_response(
"""
if response is not None:
turn_summary = build_turn_summary(
response, model_id, vector_store_ids=None, rag_id_mapping=None
response,
model_id,
endpoint_path,
vector_store_ids=None,
rag_id_mapping=None,
)
return RlsapiV1InferResponse(
data=RlsapiV1InferData(
Expand Down Expand Up @@ -673,12 +685,19 @@ async def infer_endpoint( # pylint: disable=R0914
"Request %s: Combined input source length: %d", request_id, len(input_source)
)

endpoint_path = "/v1/infer"

# Run shield moderation on user input before inference.
# Uses all configured shields; no-op when no shields are registered.
# Runs before model/tool discovery so blocked requests short-circuit
# without incurring external I/O.
blocked_response = await _check_shield_moderation(
input_source, request_id, background_tasks, infer_request, request
input_source,
request_id,
background_tasks,
infer_request,
request,
endpoint_path,
)
if blocked_response is not None:
return blocked_response
Expand All @@ -700,11 +719,11 @@ async def infer_endpoint( # pylint: disable=R0914
model_id=model_id,
)
response_text = extract_text_from_response_items(response.output)
token_usage = extract_token_usage(response.usage, model_id)
token_usage = extract_token_usage(response.usage, model_id, endpoint_path)
inference_time = time.monotonic() - start_time
except _INFER_HANDLED_EXCEPTIONS as error:
if response is not None:
extract_token_usage(response.usage, model_id) # type: ignore[arg-type]
extract_token_usage(response.usage, model_id, endpoint_path) # type: ignore[arg-type]
_record_inference_failure(
background_tasks,
infer_request,
Expand All @@ -714,6 +733,7 @@ async def infer_endpoint( # pylint: disable=R0914
start_time,
model,
provider,
endpoint_path,
)
mapped_error = _map_inference_error_to_http_exception(
error,
Expand Down Expand Up @@ -755,4 +775,5 @@ async def infer_endpoint( # pylint: disable=R0914
request_id,
response if verbose_enabled else None,
model_id,
endpoint_path,
)
13 changes: 10 additions & 3 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
# Moderation input is the raw user content (query + attachments) without injected RAG
# context, to avoid false positives from retrieved document content.
moderation_input = prepare_input(query_request)
endpoint_path = "/v1/streaming_query"
moderation_result = await run_shield_moderation(
client, moderation_input, query_request.shield_ids
client, moderation_input, endpoint_path, query_request.shield_ids
)

# Build RAG context from Inline RAG sources
Expand Down Expand Up @@ -283,11 +284,12 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
provider_id, model_id = extract_provider_and_model_from_model_id(
responses_params.model
)
recording.record_llm_call(provider_id, model_id)
recording.record_llm_call(provider_id, model_id, endpoint_path)

generator, turn_summary = await retrieve_response_generator(
responses_params=responses_params,
context=context,
endpoint_path=endpoint_path,
)

# Combine inline RAG results (BYOK + Solr) with tool-based results
Expand Down Expand Up @@ -316,6 +318,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
async def retrieve_response_generator(
responses_params: ResponsesApiParams,
context: ResponseGeneratorContext,
endpoint_path: str,
) -> tuple[AsyncIterator[str], TurnSummary]:
"""
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Retrieve the appropriate response generator.
Expand All @@ -327,6 +330,7 @@ async def retrieve_response_generator(
Args:
responses_params: The Responses API parameters
context: The response generator context
endpoint_path: API endpoint path used for metric labeling.

Returns:
tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary

Expand Down Expand Up @@ -360,6 +364,7 @@ async def retrieve_response_generator(
response,
context,
turn_summary,
endpoint_path,
),
turn_summary,
)
Expand Down Expand Up @@ -685,6 +690,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
turn_response: AsyncIterator[OpenAIResponseObjectStream],
context: ResponseGeneratorContext,
turn_summary: TurnSummary,
endpoint_path: str,
) -> AsyncIterator[str]:
"""Generate SSE formatted streaming response.

Expand All @@ -696,6 +702,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
turn_response: The streaming response from Llama Stack
context: The response generator context
turn_summary: TurnSummary to populate during streaming
endpoint_path: API endpoint path used for metric labeling.

Yields:
SSE-formatted strings for tokens, tool calls, tool results,
Expand Down Expand Up @@ -862,7 +869,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
return

turn_summary.token_usage = extract_token_usage(
latest_response_object.usage, context.model_id
latest_response_object.usage, context.model_id, endpoint_path
)
# Parse tool-based referenced documents from the final response object
tool_rag_docs = parse_referenced_documents(
Expand Down
14 changes: 9 additions & 5 deletions src/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,29 @@

# Metric that counts how many LLM calls were made for each provider + model
llm_calls_total = Counter(
"ls_llm_calls_total", "LLM calls counter", ["provider", "model"]
"ls_llm_calls_total", "LLM calls counter", ["provider", "model", "endpoint"]
)

# Metric that counts how many LLM calls failed
llm_calls_failures_total = Counter(
"ls_llm_calls_failures_total", "LLM calls failures", ["provider", "model"]
"ls_llm_calls_failures_total",
"LLM calls failures",
["provider", "model", "endpoint"],
)

# Metric that counts how many LLM calls had validation errors
llm_calls_validation_errors_total = Counter(
"ls_llm_validation_errors_total", "LLM validation errors"
"ls_llm_validation_errors_total", "LLM validation errors", ["endpoint"]
)

# Metric that counts how many tokens were sent to LLMs
llm_token_sent_total = Counter(
"ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"]
"ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model", "endpoint"]
)

# Metric that counts how many tokens were received from LLMs
llm_token_received_total = Counter(
"ls_llm_token_received_total", "LLM tokens received", ["provider", "model"]
"ls_llm_token_received_total",
"LLM tokens received",
["provider", "model", "endpoint"],
Comment on lines 31 to +56
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Centralize the allowed endpoint label values.

These counters now depend on endpoint strings supplied from several call chains, so a typo or empty fallback will create a new Prometheus series silently. Please move the four supported endpoint paths behind shared constants (or an enum/Literal) and reuse them at every metric call site to preserve the bounded-cardinality guarantee.

Based on learnings: Check constants.py for shared constants before defining new ones.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/metrics/__init__.py` around lines 31 - 58, centralize the endpoint label
values used by the Prometheus counters llm_calls_total, llm_calls_failures_total,
llm_calls_validation_errors_total, llm_token_sent_total and
llm_token_received_total. Define a shared set/enum/Literal of allowed endpoint
strings in constants.py (or reuse existing ones there), and replace any
hardcoded/inline endpoint values at each metric emission site with references to
that central constant, so that typos or empty values cannot create new series.
Update any code that constructs the value for the "endpoint" label to
validate/normalize it against that central set, falling back to a single
explicit "unknown" constant if needed.

)
Loading
Loading