diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 749fa34f..3829305b 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -53,6 +53,11 @@ validate_conversation_ownership, validate_model_provider_override, ) +from utils.quota import ( + get_available_quotas, + check_tokens_available, + consume_tokens, +) from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency from utils.transcripts import store_transcript from utils.types import TurnSummary @@ -273,6 +278,7 @@ async def query_endpoint_handler( # pylint: disable=R0914 logger.debug("Query does not contain conversation ID") try: + check_tokens_available(configuration.quota_limiters, user_id) # try to get Llama Stack client client = AsyncLlamaStackClientHolder().get_client() llama_stack_model_id, model_id, provider_id = select_model_and_provider_id( @@ -344,6 +350,13 @@ async def query_endpoint_handler( # pylint: disable=R0914 referenced_documents=referenced_documents if referenced_documents else None, ) + consume_tokens( + configuration.quota_limiters, + user_id, + input_tokens=token_usage.input_tokens, + output_tokens=token_usage.output_tokens, + ) + store_conversation_into_cache( configuration, user_id, @@ -372,6 +385,8 @@ async def query_endpoint_handler( # pylint: disable=R0914 logger.info("Using referenced documents from response...") + available_quotas = get_available_quotas(configuration.quota_limiters, user_id) + logger.info("Building final response...") response = QueryResponse( conversation_id=conversation_id, @@ -382,7 +397,7 @@ async def query_endpoint_handler( # pylint: disable=R0914 truncated=False, # TODO: implement truncation detection input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens, - available_quotas={}, # TODO: implement quota tracking + available_quotas=available_quotas, ) logger.info("Query processing completed successfully!") return response diff --git a/src/utils/quota.py 
def _quota_backend_failure(pg_error: Exception) -> HTTPException:
    """Log a quota database failure and build the matching HTTP 500 error.

    Shared by all helpers in this module so that every database failure is
    reported to the client in the same shape.

    Args:
        pg_error: The database exception that interrupted the quota operation.

    Returns:
        HTTPException with status 500 whose detail carries a generic
        message under "response" and the stringified error under "cause".
    """
    message = "Error communicating with quota database backend"
    # include the cause in the log record; the bare message alone is not
    # enough to diagnose a backend problem
    logger.error("%s: %s", message, pg_error)
    return HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail={
            "response": message,
            "cause": str(pg_error),
        },
    )


def consume_tokens(
    quota_limiters: list[QuotaLimiter],
    user_id: str,
    input_tokens: int,
    output_tokens: int,
) -> None:
    """Consume tokens from cluster and/or user quotas.

    Args:
        quota_limiters: List of quota limiter instances to consume tokens from.
        user_id: Identifier of the user consuming tokens.
        input_tokens: Number of input tokens to consume.
        output_tokens: Number of output tokens to consume.

    Returns:
        None

    Raises:
        HTTPException: With status 500 if database communication fails.
    """
    try:
        # consume tokens from all configured quota limiters
        for quota_limiter in quota_limiters:
            quota_limiter.consume_tokens(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                subject_id=user_id,
            )
    except psycopg2.Error as pg_error:
        # same translation as in check_tokens_available, so callers never
        # see a raw database driver exception
        raise _quota_backend_failure(pg_error) from pg_error


def check_tokens_available(quota_limiters: list[QuotaLimiter], user_id: str) -> None:
    """Check if tokens are available for user.

    Args:
        quota_limiters: List of quota limiter instances to check.
        user_id: Identifier of the user to check quota for.

    Returns:
        None

    Raises:
        HTTPException: With status 500 if database communication fails,
            or status 429 if quota is exceeded.
    """
    try:
        # check available tokens using all configured quota limiters
        for quota_limiter in quota_limiters:
            quota_limiter.ensure_available_quota(subject_id=user_id)
    except psycopg2.Error as pg_error:
        raise _quota_backend_failure(pg_error) from pg_error
    except QuotaExceedError as quota_exceed_error:
        message = "The quota has been exceeded"
        logger.error("%s: %s", message, quota_exceed_error)
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail={
                "response": message,
                "cause": str(quota_exceed_error),
            },
        ) from quota_exceed_error


def get_available_quotas(
    quota_limiters: list[QuotaLimiter],
    user_id: str,
) -> dict[str, int]:
    """Get quota available from all quota limiters.

    Args:
        quota_limiters: List of quota limiter instances to query.
        user_id: Identifier of the user to get quotas for.

    Returns:
        Dictionary mapping quota limiter class names to available token counts.

    Raises:
        HTTPException: With status 500 if database communication fails.
    """
    try:
        # NOTE: keys are class names, so when two configured limiters share
        # a class the later one overwrites the earlier entry
        return {
            quota_limiter.__class__.__name__: quota_limiter.available_quota(user_id)
            for quota_limiter in quota_limiters
        }
    except psycopg2.Error as pg_error:
        raise _quota_backend_failure(pg_error) from pg_error