diff --git a/bentoml/testing/pytest/plugin.py b/bentoml/testing/pytest/plugin.py index b66f3020c0c..6cfb972a44f 100644 --- a/bentoml/testing/pytest/plugin.py +++ b/bentoml/testing/pytest/plugin.py @@ -293,7 +293,7 @@ def bin_file(tmpdir: str) -> str: return str(bin_file_) -@pytest.fixture(scope="module", name="prometheus_client") +@pytest.fixture(scope="module", name="prom_client") def fixture_metrics_client() -> PrometheusClient: """This fixtures return a PrometheusClient instance that can be used for testing.""" return BentoMLContainer.metrics_client.get() diff --git a/tests/unit/grpc/interceptors/test_prometheus.py b/tests/unit/grpc/interceptors/test_prometheus.py index 6f13b35000c..2a894c6ce46 100644 --- a/tests/unit/grpc/interceptors/test_prometheus.py +++ b/tests/unit/grpc/interceptors/test_prometheus.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys import typing as t import tempfile from typing import TYPE_CHECKING @@ -9,21 +8,23 @@ import pytest +import sys + from bentoml.testing.grpc import create_channel from bentoml.testing.grpc import async_client_call from bentoml.testing.grpc import create_bento_servicer from bentoml.testing.grpc import make_standalone_server -from bentoml._internal.server.metrics.prometheus import PrometheusClient +from bentoml._internal.configuration.containers import BentoMLContainer +from bentoml.grpc.interceptors.prometheus import PrometheusServerInterceptor + if TYPE_CHECKING: import grpc - from _pytest.python import Metafunc from google.protobuf import wrappers_pb2 from bentoml import Service from bentoml.grpc.v1alpha1 import service_pb2_grpc as services from bentoml.grpc.v1alpha1 import service_test_pb2 as pb_test - from bentoml.grpc.interceptors.prometheus import PrometheusServerInterceptor else: from bentoml.grpc.utils import import_grpc from bentoml.grpc.utils import import_generated_stubs @@ -34,56 +35,39 @@ wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") grpc, aio = import_grpc() +prom_dir = tempfile.mkdtemp("prometheus-multiproc") +BentoMLContainer.prometheus_multiproc_dir.set(prom_dir) +interceptor = PrometheusServerInterceptor() -def pytest_generate_tests(metafunc: Metafunc): - if "prometheus_interceptor" in metafunc.fixturenames: - from bentoml._internal.configuration.containers import BentoMLContainer - - prom_dir = tempfile.mkdtemp("prometheus-multiproc-unit") - BentoMLContainer.prometheus_multiproc_dir.set(prom_dir) - - -@pytest.fixture(scope="module") -def prometheus_interceptor() -> PrometheusServerInterceptor: - from bentoml.grpc.interceptors.prometheus import PrometheusServerInterceptor - - return PrometheusServerInterceptor() +if "prometheus_client" in sys.modules: + mods = [m for m in sys.modules if "prometheus_client" in m] + list(map(lambda s: sys.modules.pop(s), mods)) + if not interceptor._is_setup: + interceptor._setup() @pytest.mark.asyncio -async def test_metrics_invocation( - prometheus_interceptor: PrometheusServerInterceptor, - mock_unary_unary_handler: MagicMock, -): - # This is to cleanup prometheus_client from previous tests - # that imports prometheus_client into sys.modules - # We don't want to disable multiproc since we want to test it. - # This line has to do with - if "prometheus_client" in sys.modules: - sys.modules.pop("prometheus_client") - +async def test_metrics_invocation(mock_unary_unary_handler: MagicMock): mhandler_call_details = MagicMock(spec=grpc.HandlerCallDetails) mcontinuation = MagicMock(return_value=Future()) mcontinuation.return_value.set_result(mock_unary_unary_handler) - await prometheus_interceptor.intercept_service(mcontinuation, mhandler_call_details) + await interceptor.intercept_service(mcontinuation, mhandler_call_details) assert mcontinuation.call_count == 1 - assert prometheus_interceptor._is_setup # type: ignore # pylint: disable=protected-access + assert interceptor._is_setup # type: ignore # pylint: disable=protected-access assert ( - prometheus_interceptor.metrics_request_duration - and prometheus_interceptor.metrics_request_total - and prometheus_interceptor.metrics_request_in_progress + interceptor.metrics_request_duration + and interceptor.metrics_request_total + and interceptor.metrics_request_in_progress ) @pytest.mark.asyncio -async def test_empty_metrics( - prometheus_interceptor: PrometheusServerInterceptor, - prometheus_client: PrometheusClient, -): +async def test_empty_metrics(): + metrics_client = BentoMLContainer.metrics_client.get() # This test a branch where we change inside the handler whether or not the incoming # handler contains pb.Request # if it isn't a pb.Request, then we just pass the handler, hence metrics should be empty - with make_standalone_server(interceptors=[prometheus_interceptor]) as ( + with make_standalone_server(interceptors=[interceptor]) as ( server, host_url, ): @@ -100,7 +84,7 @@ async def test_empty_metrics( Execute(pb_test.ExecuteRequest(input="BentoML")), ) await resp - assert prometheus_client.generate_latest() == b"" + assert metrics_client.generate_latest() == b"" finally: await server.stop(None) @@ -121,13 +105,13 @@ async def test_empty_metrics( ], ) async def test_metrics_interceptors( - prometheus_interceptor: PrometheusServerInterceptor, - prometheus_client: PrometheusClient, simple_service: Service, metric_type: str, parent_set: list[str], ): - with make_standalone_server(interceptors=[prometheus_interceptor]) as ( + metrics_client = BentoMLContainer.metrics_client.get() + + with make_standalone_server(interceptors=[interceptor]) as ( server, host_url, ): @@ -142,7 +126,7 @@ async def test_metrics_interceptors( channel=channel, data={"text": wrappers_pb2.StringValue(value="BentoML")}, ) - for m in prometheus_client.text_string_to_metric_families(): + for m in metrics_client.text_string_to_metric_families(): for sample in m.samples: if m.type == metric_type: assert set(sample.labels).issubset(set(parent_set))