Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 106 additions & 24 deletions src/app/endpoints/rlsapi_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,18 @@ async def _get_default_model_id() -> str:
provider_id = configuration.inference.default_provider

if model_id and provider_id:
logger.info(
"Using configured default model for rlsapi v1: %s/%s",
provider_id,
model_id,
)
return f"{provider_id}/{model_id}"

# 2. Auto-discover from Llama Stack
logger.info(
"No complete default model configured for rlsapi v1, "
"auto-discovering LLM model"
)
client = AsyncLlamaStackClientHolder().get_client()
try:
models = await client.models.list()
Expand Down Expand Up @@ -255,6 +264,7 @@ async def _resolve_validated_model_id() -> str:
_, model_name = extract_provider_and_model_from_model_id(model_id)
error_response = NotFoundResponse(resource="model", resource_id=model_name)
raise HTTPException(**error_response.model_dump())
logger.info("Validated rlsapi v1 model availability: %s", model_id)
return model_id


Expand Down Expand Up @@ -387,6 +397,11 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position

event = build_inference_event(event_data)
background_tasks.add_task(send_splunk_event, event, sourcetype)
logger.info(
"Queued rlsapi v1 Splunk event for request %s with sourcetype %s",
request_id,
sourcetype,
)


async def _check_shield_moderation( # pylint: disable=too-many-arguments,too-many-positional-arguments
Expand Down Expand Up @@ -415,16 +430,14 @@ async def _check_shield_moderation( # pylint: disable=too-many-arguments,too-ma
was blocked, or None if moderation passed.
"""
client = AsyncLlamaStackClientHolder().get_client()
logger.info("Running shield moderation for rlsapi v1 request %s", request_id)
moderation_result = await run_shield_moderation(client, input_text, endpoint_path)

if moderation_result.decision != "blocked":
logger.info("Shield moderation passed for rlsapi v1 request %s", request_id)
return None

logger.info(
"Request %s blocked by shield moderation: %s",
request_id,
moderation_result.message,
)
logger.info("Shield moderation blocked rlsapi v1 request %s", request_id)
_queue_splunk_event(
background_tasks,
infer_request,
Expand Down Expand Up @@ -480,16 +493,20 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
recording.record_llm_inference_duration(
provider, model, endpoint_path, "failure", inference_time
)
redacted_error = _redact_sensitive_error_text(str(error))
_queue_splunk_event(
background_tasks,
infer_request,
request,
request_id,
redacted_error,
type(error).__name__,
inference_time,
"infer_error",
)
logger.info(
"Recorded rlsapi v1 inference failure for request %s in %.3f seconds",
request_id,
inference_time,
)
return inference_time


Expand Down Expand Up @@ -627,22 +644,34 @@ def _map_inference_error_to_http_exception( # pylint: disable=too-many-return-s
"""
if isinstance(error, TemplateRenderError):
logger.error(
"Invalid system prompt template for request %s: %s", request_id, error
"Invalid system prompt template for request %s: %s",
request_id,
type(error).__name__,
)
error_response = InternalServerErrorResponse.generic()
return HTTPException(**error_response.model_dump())

if isinstance(error, RuntimeError):
if is_context_length_error(str(error)):
logger.error("Prompt too long for request %s: %s", request_id, error)
logger.error(
"Prompt too long for request %s: %s",
request_id,
type(error).__name__,
)
error_response = PromptTooLongResponse(model=model_id)
return HTTPException(**error_response.model_dump())
logger.error("Unexpected RuntimeError for request %s: %s", request_id, error)
logger.error(
"Unexpected RuntimeError for request %s: %s",
request_id,
type(error).__name__,
)
return None

if isinstance(error, APIConnectionError):
logger.error(
"Unable to connect to Llama Stack for request %s: %s", request_id, error
"Unable to connect to Llama Stack for request %s: %s",
request_id,
type(error).__name__,
)
error_response = ServiceUnavailableResponse(
backend_name="Llama Stack",
Expand All @@ -651,15 +680,19 @@ def _map_inference_error_to_http_exception( # pylint: disable=too-many-return-s
return HTTPException(**error_response.model_dump())

if isinstance(error, RateLimitError):
logger.error("Rate limit exceeded for request %s: %s", request_id, error)
logger.error(
"Rate limit exceeded for request %s: %s",
request_id,
type(error).__name__,
)
error_response = QuotaExceededResponse(
response="The quota has been exceeded",
cause="Rate limit exceeded, please try again later",
)
return HTTPException(**error_response.model_dump())

if isinstance(error, (APIStatusError, OpenAIAPIStatusError)):
logger.exception("API error for request %s: %s", request_id, error)
logger.error("API error for request %s: %s", request_id, type(error).__name__)
error_response = handle_known_apistatus_errors(error, model_id)
return HTTPException(**error_response.model_dump())

Expand All @@ -668,7 +701,7 @@ def _map_inference_error_to_http_exception( # pylint: disable=too-many-return-s

@router.post("/infer", responses=infer_responses, response_model_exclude_none=True)
@authorize(Action.RLSAPI_V1_INFER)
async def infer_endpoint( # pylint: disable=R0914
async def infer_endpoint( # pylint: disable=R0914,R0915
infer_request: RlsapiV1InferRequest,
request: Request,
background_tasks: BackgroundTasks,
Expand All @@ -695,21 +728,32 @@ async def infer_endpoint( # pylint: disable=R0914
"""
# Authentication enforced by get_auth_dependency(), authorization by @authorize decorator.
check_configuration_loaded(configuration)
endpoint_path = ENDPOINT_PATH_INFER
request_id = get_suid()

logger.info("Processing rlsapi v1 /infer request %s", request_id)

# Quota enforcement: resolve subject and check availability before any work.
# No-op when quota_subject is not configured or no quota limiters exist.
quota_id = _resolve_quota_subject(request, auth)
if quota_id is not None:
logger.info(
"Checking quota availability for rlsapi v1 request %s using subject type %s",
request_id,
configuration.rlsapi_v1.quota_subject,
)
check_tokens_available(configuration.quota_limiters, quota_id)

endpoint_path = ENDPOINT_PATH_INFER

request_id = get_suid()

logger.info("Processing rlsapi v1 /infer request %s", request_id)
logger.info(
"Quota availability check passed for rlsapi v1 request %s", request_id
)
else:
logger.info("Quota enforcement disabled for rlsapi v1 request %s", request_id)

input_source = infer_request.get_input_source()
logger.debug(
"Request %s: Combined input source length: %d", request_id, len(input_source)
logger.info(
"Prepared rlsapi v1 request %s input source; metadata requested: %s",
request_id,
infer_request.include_metadata,
)

# Run shield moderation on user input before inference.
Expand All @@ -729,13 +773,30 @@ async def infer_endpoint( # pylint: disable=R0914

model_id = await _resolve_validated_model_id()
provider, model = extract_provider_and_model_from_model_id(model_id)
logger.info(
"Resolved rlsapi v1 request %s model provider=%s model=%s",
request_id,
provider,
model,
)
mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers)
logger.info(
"Retrieved %d MCP tools for rlsapi v1 request %s",
len(mcp_tools),
request_id,
)

start_time = time.monotonic()
verbose_enabled = _is_verbose_enabled(infer_request)
logger.info(
"Starting LLM call for rlsapi v1 request %s with verbose metadata enabled: %s",
request_id,
verbose_enabled,
)

response = None
try:
logger.info("Building instructions for rlsapi v1 request %s", request_id)
instructions = _build_instructions(infer_request.context.systeminfo)
response = await _call_llm(
input_source,
Expand All @@ -749,9 +810,17 @@ async def infer_endpoint( # pylint: disable=R0914
recording.record_llm_inference_duration(
provider, model, endpoint_path, "success", inference_time
)
logger.info(
"LLM call completed for rlsapi v1 request %s in %.3f seconds "
"with %d input tokens and %d output tokens",
request_id,
inference_time,
token_usage.input_tokens,
token_usage.output_tokens,
)
except _INFER_HANDLED_EXCEPTIONS as error:
if response is not None:
extract_token_usage(response.usage, model_id, endpoint_path) # type: ignore[arg-type]
extract_token_usage(response.usage, model_id, endpoint_path)
_record_inference_failure(
background_tasks,
infer_request,
Expand All @@ -778,11 +847,20 @@ async def infer_endpoint( # pylint: disable=R0914

# Consume quota tokens after successful inference.
if quota_id is not None:
logger.info(
"Consuming quota tokens for rlsapi v1 request %s: input=%d output=%d",
request_id,
token_usage.input_tokens,
token_usage.output_tokens,
)
consume_query_tokens(
user_id=quota_id,
model_id=model_id,
token_usage=token_usage,
)
logger.info(
"Quota token consumption completed for rlsapi v1 request %s", request_id
)

_queue_splunk_event(
background_tasks,
Expand All @@ -796,7 +874,11 @@ async def infer_endpoint( # pylint: disable=R0914
output_tokens=token_usage.output_tokens,
)

logger.info("Completed rlsapi v1 /infer request %s", request_id)
logger.info(
"Completed rlsapi v1 /infer request %s in %.3f seconds",
request_id,
inference_time,
)

return _build_infer_response(
response_text,
Expand Down
Loading
Loading