diff --git a/src/huggingface_inference_toolkit/webservice_starlette.py b/src/huggingface_inference_toolkit/webservice_starlette.py
index 52cf161b..cf15a0d8 100644
--- a/src/huggingface_inference_toolkit/webservice_starlette.py
+++ b/src/huggingface_inference_toolkit/webservice_starlette.py
@@ -7,7 +7,7 @@
 from starlette.responses import PlainTextResponse, Response
 from starlette.routing import Route

-from huggingface_inference_toolkit.async_utils import async_handler_call
+from huggingface_inference_toolkit.async_utils import MAX_CONCURRENT_THREADS, MAX_THREADS_GUARD, async_handler_call
 from huggingface_inference_toolkit.const import (
     HF_FRAMEWORK,
     HF_HUB_TOKEN,
@@ -69,6 +69,18 @@ async def health(request):
     return PlainTextResponse("Ok")


+# Report Prometheus metrics
+# inf_batch_current_size: Current number of requests being processed
+# inf_queue_size: Number of requests waiting in the queue
+async def metrics(request):
+    batch_current_size = MAX_CONCURRENT_THREADS - MAX_THREADS_GUARD.value
+    queue_size = MAX_THREADS_GUARD.statistics().tasks_waiting
+    return PlainTextResponse(
+        f"inf_batch_current_size {batch_current_size}\n"
+        + f"inf_queue_size {queue_size}\n"
+    )
+
+
 async def predict(request):
     try:
         # extracts content from request
@@ -143,6 +155,7 @@ async def predict(request):
         Route("/health", health, methods=["GET"]),
         Route("/", predict, methods=["POST"]),
         Route("/predict", predict, methods=["POST"]),
+        Route("/metrics", metrics, methods=["GET"]),
     ],
     on_startup=[prepare_model_artifacts],
 )
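For context, the two metric values follow from the request-throttling guard in async_utils: assuming MAX_THREADS_GUARD is an anyio.Semaphore created with MAX_CONCURRENT_THREADS permits (which is what the arithmetic in the new handler implies), value is the number of free permits and statistics().tasks_waiting is the number of callers blocked waiting for one. A minimal sketch of that semantics, not part of this patch:

# Sketch only: mirrors the assumed definitions in async_utils
# (MAX_THREADS_GUARD as an anyio.Semaphore sized by MAX_CONCURRENT_THREADS).
import anyio

MAX_CONCURRENT_THREADS = 1
MAX_THREADS_GUARD = anyio.Semaphore(MAX_CONCURRENT_THREADS)

async def main() -> None:
    # Idle server: every permit is free and nothing is queued.
    assert MAX_CONCURRENT_THREADS - MAX_THREADS_GUARD.value == 0
    assert MAX_THREADS_GUARD.statistics().tasks_waiting == 0

    # One request in flight: one permit is borrowed, so the
    # inf_batch_current_size reported by /metrics would be 1.
    async with MAX_THREADS_GUARD:
        assert MAX_CONCURRENT_THREADS - MAX_THREADS_GUARD.value == 1

anyio.run(main)

With the server idle, scraping the new GET /metrics route would accordingly return the two plain-text lines "inf_batch_current_size 0" and "inf_queue_size 0".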