From 866a28a7419851ad73ff456a50a9d55fa61b8387 Mon Sep 17 00:00:00 2001 From: Pratik Garg Date: Mon, 11 May 2026 23:41:32 -0700 Subject: [PATCH] Implement Prometheus metrics emission for Orbax. PiperOrigin-RevId: 914097006 --- checkpoint/orbax/__init__.py | 4 + checkpoint/orbax/checkpoint/__init__.py | 4 + .../_src/logging/google_monitoring.py | 14 + .../_src/logging/google_monitoring_test.py | 14 + .../checkpoint/_src/logging/monitoring.py | 127 +++++++++ .../_src/logging/prometheus_monitoring.py | 255 ++++++++++++++++++ .../logging/prometheus_monitoring_test.py | 178 ++++++++++++ checkpoint/pyproject.toml | 1 + 8 files changed, 597 insertions(+) create mode 100644 checkpoint/orbax/checkpoint/_src/logging/google_monitoring.py create mode 100644 checkpoint/orbax/checkpoint/_src/logging/google_monitoring_test.py create mode 100644 checkpoint/orbax/checkpoint/_src/logging/monitoring.py create mode 100644 checkpoint/orbax/checkpoint/_src/logging/prometheus_monitoring.py create mode 100644 checkpoint/orbax/checkpoint/_src/logging/prometheus_monitoring_test.py diff --git a/checkpoint/orbax/__init__.py b/checkpoint/orbax/__init__.py index 0a1693db1..4ae0da973 100644 --- a/checkpoint/orbax/__init__.py +++ b/checkpoint/orbax/__init__.py @@ -90,3 +90,7 @@ __version__ = version.__version__ del version +from orbax.checkpoint._src.logging import monitoring as _orbax_monitoring + +_orbax_monitoring.initialize_from_env() +del _orbax_monitoring diff --git a/checkpoint/orbax/checkpoint/__init__.py b/checkpoint/orbax/checkpoint/__init__.py index 0a1693db1..4ae0da973 100644 --- a/checkpoint/orbax/checkpoint/__init__.py +++ b/checkpoint/orbax/checkpoint/__init__.py @@ -90,3 +90,7 @@ __version__ = version.__version__ del version +from orbax.checkpoint._src.logging import monitoring as _orbax_monitoring + +_orbax_monitoring.initialize_from_env() +del _orbax_monitoring diff --git a/checkpoint/orbax/checkpoint/_src/logging/google_monitoring.py 
b/checkpoint/orbax/checkpoint/_src/logging/google_monitoring.py new file mode 100644 index 000000000..0f9e3b439 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/logging/google_monitoring.py @@ -0,0 +1,14 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/checkpoint/orbax/checkpoint/_src/logging/google_monitoring_test.py b/checkpoint/orbax/checkpoint/_src/logging/google_monitoring_test.py new file mode 100644 index 000000000..0f9e3b439 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/logging/google_monitoring_test.py @@ -0,0 +1,14 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
diff --git a/checkpoint/orbax/checkpoint/_src/logging/monitoring.py b/checkpoint/orbax/checkpoint/_src/logging/monitoring.py
new file mode 100644
index 000000000..d9124a7c8
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/logging/monitoring.py
@@ -0,0 +1,127 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Orbax metrics telemetry base."""
+
+from __future__ import annotations
+
+import abc
+import logging
+import os
+import threading
+from typing import Any
+
+_DEFAULT_PROMETHEUS_PORT = 9431
+# Largest valid TCP port. Values above it make the HTTP server bind raise
+# OverflowError, which would otherwise escape while Orbax is being imported.
+_MAX_PORT = 65535
+
+
+class MetricRecorder(abc.ABC):
+  """Abstract base class for Orbax metric recorders.
+
+  Every recorder passed to `initialize` receives the events, scalars and
+  durations that the proxy listeners in this module are called with.
+  """
+
+  @abc.abstractmethod
+  def record_event(self, metric_name: str, **kwargs: Any) -> None:
+    """Records a named event with optional metadata."""
+
+  @abc.abstractmethod
+  def record_scalar(
+      self, metric_name: str, value: float | int, **kwargs: Any
+  ) -> None:
+    """Records a scalar summary value with optional metadata."""
+
+  @abc.abstractmethod
+  def record_duration(
+      self, metric_name: str, duration: float | int, **kwargs: Any
+  ) -> None:
+    """Records an event duration in seconds with optional metadata."""
+
+
+_recorders: list[MetricRecorder] = []  # Protected by _init_lock
+_initialized = False  # Protected by _init_lock
+_init_lock = threading.Lock()
+
+
+def _dispatch(method_name: str, /, *args: Any, **kwargs: Any) -> None:
+  """Calls `method_name` on every registered recorder.
+
+  The recorder list is copied while holding `_init_lock` and the recorders are
+  invoked after the lock is released. A recorder that raises is logged and
+  skipped so that it neither prevents the remaining recorders from running nor
+  propagates the error into the code that emitted the metric.
+
+  Args:
+    method_name: Name of the `MetricRecorder` method to call. Positional-only
+      so that a metadata key with the same name cannot collide with it.
+    *args: Positional arguments forwarded to the recorder method.
+    **kwargs: Metadata forwarded to the recorder method.
+  """
+  with _init_lock:
+    recorders_snapshot = list(_recorders)
+  for recorder in recorders_snapshot:
+    try:
+      getattr(recorder, method_name)(*args, **kwargs)
+    except Exception:  # pylint: disable=broad-exception-caught
+      logging.exception(
+          'Orbax metric recorder %s raised in %s; metric dropped.',
+          type(recorder).__name__,
+          method_name,
+      )
+
+
+def _proxy_record_event(metric_name: str, **kwargs: Any) -> None:
+  """Forwards an event to every registered recorder."""
+  _dispatch('record_event', metric_name, **kwargs)
+
+
+def _proxy_record_scalar(
+    metric_name: str, value: float | int, **kwargs: Any
+) -> None:
+  """Forwards a scalar to every registered recorder."""
+  _dispatch('record_scalar', metric_name, value, **kwargs)
+
+
+def _proxy_record_duration(
+    metric_name: str, duration: float | int, **kwargs: Any
+) -> None:
+  """Forwards a duration to every registered recorder."""
+  _dispatch('record_duration', metric_name, duration, **kwargs)
+
+
+def initialize(recorder: MetricRecorder) -> None:
+  """Registers a recorder and binds its methods to JAX monitoring listeners.
+
+  The three proxy listeners are registered with `jax.monitoring` only on the
+  first call; later calls just add their recorder to the list the proxies fan
+  out to. Registering the same recorder object again is a no-op.
+
+  Args:
+    recorder: The recorder that should start receiving metrics.
+  """
+  global _initialized
+
+  with _init_lock:
+    # Identity check: the same object registered twice would otherwise record
+    # every metric twice.
+    if not any(existing is recorder for existing in _recorders):
+      _recorders.append(recorder)
+
+    if _initialized:
+      return
+
+    from jax import monitoring as jax_monitoring  # pylint: disable=g-import-not-at-top
+
+    jax_monitoring.register_event_listener(_proxy_record_event)
+    jax_monitoring.register_scalar_listener(_proxy_record_scalar)
+    jax_monitoring.register_event_duration_secs_listener(_proxy_record_duration)
+
+    _initialized = True
+
+  logging.info('Installed JAX monitoring proxy listeners for Orbax.')
+
+
+def initialize_from_env() -> None:
+  """Initializes monitoring based on environment and build type.
+
+  Does nothing unless `ENABLE_ORBAX_PROMETHEUS_TELEMETRY` is `true`
+  (case-insensitive). Otherwise a `PrometheusMonitoring` recorder is
+  registered. The port comes from `ORBAX_PROMETHEUS_PORT` when that is an
+  integer in [0, 65535], and from the default otherwise. Any process whose
+  multiprocessing name is not `MainProcess` gets port 0.
+
+  The function is idempotent: it is invoked from more than one package
+  `__init__`, and a second Prometheus recorder would make every metric be
+  recorded twice.
+  """
+  enable_telemetry = os.environ.get(
+      'ENABLE_ORBAX_PROMETHEUS_TELEMETRY', 'false'
+  )
+  if enable_telemetry.lower() != 'true':
+    return
+
+  try:
+    from orbax.checkpoint._src.logging import prometheus_monitoring  # pylint: disable=g-import-not-at-top
+  except ImportError as e:
+    logging.warning('Failed to import PrometheusMonitoring: %s', e)
+    return
+  import multiprocessing  # pylint: disable=g-import-not-at-top
+
+  with _init_lock:
+    already_registered = any(
+        isinstance(r, prometheus_monitoring.PrometheusMonitoring)
+        for r in _recorders
+    )
+  if already_registered:
+    return
+
+  port = _DEFAULT_PROMETHEUS_PORT
+  env_port = os.environ.get('ORBAX_PROMETHEUS_PORT')
+  if env_port:
+    try:
+      parsed_port = int(env_port)
+    except ValueError:
+      parsed_port = None
+    if parsed_port is not None and 0 <= parsed_port <= _MAX_PORT:
+      port = parsed_port
+    else:
+      logging.warning(
+          'Invalid ORBAX_PROMETHEUS_PORT "%s". Falling back to default %d.',
+          env_port,
+          _DEFAULT_PROMETHEUS_PORT,
+      )
+
+  # Only the main process is given a real port; the recorder treats port 0 as
+  # "do not start an HTTP server".
+  if multiprocessing.current_process().name != 'MainProcess':
+    port = 0
+  initialize(prometheus_monitoring.PrometheusMonitoring(port=port))
diff --git a/checkpoint/orbax/checkpoint/_src/logging/prometheus_monitoring.py b/checkpoint/orbax/checkpoint/_src/logging/prometheus_monitoring.py
new file mode 100644
index 000000000..f4969a1b1
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/logging/prometheus_monitoring.py
@@ -0,0 +1,255 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Prometheus metrics exporter for Orbax.
+
+This module enables Orbax to export performance and operational metrics to
+Prometheus, allowing users to monitor checkpointing behavior in open-source
+environments.
+
+To enable telemetry collection, set the environment variable
+`ENABLE_ORBAX_PROMETHEUS_TELEMETRY=true`. By default, telemetry collection
+is disabled.
+"""
+
+from __future__ import annotations
+
+import atexit
+import importlib
+import logging
+import os
+import shutil
+import tempfile
+import threading
+from typing import Any
+
+from orbax.checkpoint._src.logging import monitoring
+
+# Keep a global reference so the directory is not deleted until the program
+# exits.
+_prometheus_multiproc_dir = None
+
+# Handles to the optional prometheus-client package. They stay None when
+# telemetry is disabled or the package cannot be imported, and the recorder
+# below treats None as "do not record".
+prometheus_client = None
+_prom_counter = None
+_prom_histogram = None
+
+enable_telemetry = os.environ.get('ENABLE_ORBAX_PROMETHEUS_TELEMETRY', 'false')
+if enable_telemetry.lower() == 'true':
+  # The environment variable is set before prometheus_client is imported
+  # below.
+  if 'PROMETHEUS_MULTIPROC_DIR' not in os.environ:
+    # Create a directory for prometheus multiprocessing.
+    _prometheus_multiproc_dir = tempfile.mkdtemp(prefix='prometheus_multiproc_')
+    os.environ['PROMETHEUS_MULTIPROC_DIR'] = _prometheus_multiproc_dir
+    _creator_pid = os.getpid()
+
+    def _Cleanup():
+      """Deletes the temporary directory, but only in the creating process."""
+      if os.getpid() == _creator_pid:
+        shutil.rmtree(_prometheus_multiproc_dir, ignore_errors=True)
+
+    atexit.register(_Cleanup)
+
+  try:
+    prometheus_client = importlib.import_module('prometheus_client')
+    _prom_counter = prometheus_client.Counter  # pytype: disable=attribute-error
+    _prom_histogram = prometheus_client.Histogram  # pytype: disable=attribute-error
+  except (ImportError, AttributeError):
+    # Best effort: whichever handles were not assigned stay None, and the
+    # record_* methods return early for a None metric class.
+    pass
+
+
+class PrometheusMonitoring(monitoring.MetricRecorder):
+  """Prometheus implementation of Orbax metric recorder.
+
+  Events are exported as Prometheus counters; scalars and durations are
+  exported as histograms. Only metrics whose name starts with one of the
+  allowed prefixes are exported. The Prometheus metric name is the original
+  name with surrounding slashes removed and inner slashes replaced by
+  underscores, and the keyword metadata of a call becomes the label set.
+  """
+
+  def __init__(self, port: int = 9431):
+    """Creates the recorder and, for a positive port, starts the HTTP server.
+
+    Args:
+      port: Port for the Prometheus HTTP server. A value that is not positive
+        starts no server but still enables recording.
+    """
+    self._initialized = False
+    self._metrics = {}
+    self._lock = threading.Lock()
+    self._port = port
+    self._allowed_prefixes = (
+        '/jax/orbax/write/',
+        '/jax/checkpoint/write/',
+        '/jax/orbax/read/',
+    )
+
+    if not prometheus_client:
+      logging.warning(
+          'prometheus-client not found. Orbax metrics will not be reported to'
+          ' Prometheus.'
+      )
+      return
+
+    if port > 0:
+      self._start_server(port)
+    else:
+      # If port is 0, we assume it's a worker process in multiprocess mode,
+      # or server is started externally. We mark it initialized so it records.
+      self._initialized = True
+
+  def record_event(self, metric_name: str, **kwargs: Any) -> None:
+    """JAX monitoring handler for events to route to prometheus-client."""
+    self._record(
+        'event', _prom_counter, metric_name, kwargs, lambda m: m.inc()
+    )
+
+  def record_scalar(
+      self, metric_name: str, value: float | int, **kwargs: Any
+  ) -> None:
+    """JAX monitoring handler for scalars to route to prometheus-client."""
+    self._record(
+        'scalar',
+        _prom_histogram,
+        metric_name,
+        kwargs,
+        lambda m: m.observe(value),
+    )
+
+  def record_duration(
+      self, metric_name: str, duration: float | int, **kwargs: Any
+  ) -> None:
+    """JAX monitoring handler for duration to route to prometheus-client."""
+    self._record(
+        'duration',
+        _prom_histogram,
+        metric_name,
+        kwargs,
+        lambda m: m.observe(duration),
+    )
+
+  def _record(
+      self,
+      kind: str,
+      metric_class: Any,
+      metric_name: str,
+      labels: dict[str, Any],
+      update: Any,
+  ) -> None:
+    """Shared implementation of the three record_* handlers.
+
+    Args:
+      kind: 'event', 'scalar' or 'duration'; only used in the warning text.
+      metric_class: The prometheus-client metric class to use, or None when
+        the package is unavailable.
+      metric_name: Original, slash-separated metric name.
+      labels: Metadata of the call; keys become label names and values are
+        converted to strings.
+      update: Callable applied to the metric (or to its labelled child) to
+        record the observation.
+    """
+    if (
+        not self._initialized
+        or not metric_class
+        or not self._is_allowed(metric_name)
+    ):
+      return
+    metric_name_safe = metric_name.strip('/').replace('/', '_')
+    sorted_keys = sorted(labels)
+
+    metric = self._get_or_create_metric(
+        metric_name, metric_name_safe, metric_class, tuple(sorted_keys)
+    )
+    # The cached entry may be None or a metric of another type when the name
+    # was already taken; in that case there is nothing to record on.
+    if not isinstance(metric, metric_class):
+      return
+    try:
+      if sorted_keys:
+        # Keyword labels make prometheus-client compare the label names, so a
+        # call with a different label set is rejected instead of being
+        # recorded under the wrong names.
+        metric = metric.labels(**{k: str(labels[k]) for k in sorted_keys})
+      # Also inside the try: updating a metric that was created with labels
+      # without providing any raises ValueError as well.
+      update(metric)
+    except ValueError as e:
+      logging.warning(
+          'Failed to record Prometheus %s "%s" due to label mismatch:'
+          ' %s. Provided keys: %s',
+          kind,
+          metric_name,
+          e,
+          sorted_keys,
+      )
+
+  def _start_server(self, port: int):
+    """Starts the Prometheus HTTP server and marks the recorder initialized.
+
+    When `PROMETHEUS_MULTIPROC_DIR` is set, a server backed by a multiprocess
+    collector is attempted first; otherwise, or if that setup is unavailable,
+    the standard single-process server is started. An "already in use" error
+    is treated as success, any other start-up error leaves the recorder
+    uninitialized.
+
+    Args:
+      port: Port to serve metrics on.
+    """
+    try:
+      multiprocess_started = False
+      if 'PROMETHEUS_MULTIPROC_DIR' in os.environ:
+        try:
+          multiprocess = importlib.import_module(
+              'prometheus_client.multiprocess'
+          )
+          registry = prometheus_client.CollectorRegistry()  # pytype: disable=attribute-error
+          multiprocess.MultiProcessCollector(registry)  # pytype: disable=attribute-error
+          prometheus_client.start_http_server(port, registry=registry)  # pytype: disable=attribute-error
+          logging.info(
+              'Prometheus multiprocess metrics server started on port %s.',
+              port,
+          )
+          multiprocess_started = True
+        except (ImportError, AttributeError) as e:
+          logging.warning(
+              'Prometheus multiprocess mode unavailable (%s); falling back to'
+              ' the single-process server.',
+              e,
+          )
+
+      if not multiprocess_started:
+        # Standard single-process server
+        prometheus_client.start_http_server(port)  # pytype: disable=attribute-error
+        logging.info('Prometheus metrics server started on port %s.', port)
+      self._initialized = True
+    except (OSError, ValueError, OverflowError) as e:
+      # OverflowError is what binding a port above 65535 raises.
+      # Handle 'already in use' for Linux/macOS and Windows (10048).
+      if 'already in use' not in str(e) and '10048' not in str(e):
+        logging.warning('Failed to start Prometheus server: %s', e)
+        return
+      # If the server is already running (e.g. started by Grain), just
+      # register listeners.
+      logging.info('Prometheus server already active.')
+      self._initialized = True
+
+  def _is_allowed(self, metric_name: str) -> bool:
+    """Returns True if the metric is allowed for Prometheus export."""
+    return metric_name.startswith(self._allowed_prefixes)
+
+  def _get_or_create_metric(
+      self,
+      metric_name: str,
+      metric_name_safe: str,
+      metric_class: Any,
+      labelnames: tuple[str, ...],
+  ) -> Any:
+    """Gets an existing metric or creates a new one thread-safely.
+
+    Args:
+      metric_name: Original metric name, used as the metric description.
+      metric_name_safe: Prometheus metric name, also the cache key.
+      metric_class: Metric class to instantiate on a cache miss.
+      labelnames: Label names for a newly created metric. They are ignored
+        when the metric is already cached.
+
+    Returns:
+      The cached or newly created metric. When creation raised ValueError,
+      this is whatever the default registry holds under that name, which may
+      be None.
+    """
+    if metric_name_safe not in self._metrics:
+      with self._lock:
+        # Re-check under the lock: another thread may have created the metric
+        # between the unlocked check and acquiring the lock.
+        if metric_name_safe not in self._metrics:
+          try:
+            self._metrics[metric_name_safe] = metric_class(
+                metric_name_safe, metric_name, labelnames=labelnames
+            )
+          except ValueError:
+            # pylint: disable=protected-access
+            self._metrics[metric_name_safe] = (
+                prometheus_client.REGISTRY._names_to_collectors.get(  # pytype: disable=attribute-error
+                    metric_name_safe
+                )
+            )
+            # pylint: enable=protected-access
+    return self._metrics[metric_name_safe]
diff --git a/checkpoint/orbax/checkpoint/_src/logging/prometheus_monitoring_test.py b/checkpoint/orbax/checkpoint/_src/logging/prometheus_monitoring_test.py
new file mode 100644
index 000000000..176f2a758
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/logging/prometheus_monitoring_test.py
@@ -0,0 +1,178 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for Orbax Prometheus metrics telemetry."""
+
+from __future__ import annotations
+
+import os
+import shutil
+import tempfile
+from unittest import mock
+
+from absl.testing import absltest
+from jax import monitoring as jax_monitoring
+from orbax.checkpoint._src.logging import monitoring
+
+
+class MonitoringTest(absltest.TestCase):
+  """Tests the recorder registry and the JAX proxy listeners."""
+
+  def setUp(self):
+    super().setUp()
+    # Patch instead of assigning so the module state is restored after each
+    # test and mock recorders do not leak into other tests.
+    self.enter_context(mock.patch.object(monitoring, '_recorders', []))
+    self.enter_context(mock.patch.object(monitoring, '_initialized', False))
+
+    self.mock_register_event = self.enter_context(
+        mock.patch.object(jax_monitoring, 'register_event_listener')
+    )
+    self.mock_register_scalar = self.enter_context(
+        mock.patch.object(jax_monitoring, 'register_scalar_listener')
+    )
+    self.mock_register_duration = self.enter_context(
+        mock.patch.object(
+            jax_monitoring, 'register_event_duration_secs_listener'
+        )
+    )
+
+  def test_proxy_initialization(self):
+    """The first `initialize` registers one listener of each kind."""
+    fake_recorder = mock.create_autospec(monitoring.MetricRecorder)
+    monitoring.initialize(fake_recorder)
+
+    self.assertTrue(monitoring._initialized)
+    self.assertIn(fake_recorder, monitoring._recorders)
+
+    self.mock_register_event.assert_called_once()
+    self.mock_register_scalar.assert_called_once()
+    self.mock_register_duration.assert_called_once()
+
+  def test_proxy_forwarding(self):
+    """Each registered proxy forwards its arguments to the recorder."""
+    fake_recorder = mock.create_autospec(monitoring.MetricRecorder)
+    monitoring.initialize(fake_recorder)
+
+    proxy_event_fn = self.mock_register_event.call_args[0][0]
+    proxy_scalar_fn = self.mock_register_scalar.call_args[0][0]
+    proxy_duration_fn = self.mock_register_duration.call_args[0][0]
+
+    proxy_event_fn('test_event', foo='bar')
+    fake_recorder.record_event.assert_called_once_with('test_event', foo='bar')
+
+    proxy_scalar_fn('test_scalar', 123, bar='baz')
+    fake_recorder.record_scalar.assert_called_once_with(
+        'test_scalar', 123, bar='baz'
+    )
+
+    proxy_duration_fn('test_duration', 0.5, baz='qux')
+    fake_recorder.record_duration.assert_called_once_with(
+        'test_duration', 0.5, baz='qux'
+    )
+
+  def test_multiple_recorders(self):
+    """Two recorders share one proxy and both receive the event."""
+    r1 = mock.create_autospec(monitoring.MetricRecorder)
+    r2 = mock.create_autospec(monitoring.MetricRecorder)
+
+    monitoring.initialize(r1)
+    monitoring.initialize(r2)
+
+    self.assertLen(monitoring._recorders, 2)
+
+    # Proxy should only be registered once
+    self.mock_register_event.assert_called_once()
+
+    proxy_event_fn = self.mock_register_event.call_args[0][0]
+    proxy_event_fn('test_event')
+
+    r1.record_event.assert_called_once_with('test_event')
+    r2.record_event.assert_called_once_with('test_event')
+
+
+class PrometheusMonitoringTest(absltest.TestCase):
+  """Tests the Prometheus recorder against the default registry."""
+
+  def setUp(self):
+    super().setUp()
+
+    # Set environment variable before importing prometheus_monitoring
+    self.enter_context(
+        mock.patch.dict(
+            os.environ, {'ENABLE_ORBAX_PROMETHEUS_TELEMETRY': 'true'}
+        )
+    )
+
+    # Ensure PROMETHEUS_MULTIPROC_DIR is set for all tests, as module-level
+    # initialization might only run once due to module caching.
+    if 'PROMETHEUS_MULTIPROC_DIR' not in os.environ:
+      temp_dir = tempfile.mkdtemp(prefix='prometheus_multiproc_test_')
+      os.environ['PROMETHEUS_MULTIPROC_DIR'] = temp_dir
+      self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+
+    # pylint: disable=g-import-not-at-top
+    from orbax.checkpoint._src.logging import prometheus_monitoring
+    import prometheus_client
+    # pylint: enable=g-import-not-at-top
+
+    self.prometheus_monitoring = prometheus_monitoring
+    self.prometheus_client = prometheus_client
+
+    # Start every test from an empty default registry so metric names created
+    # by earlier tests cannot collide.
+    registry = self.prometheus_client.REGISTRY
+    if hasattr(registry, '_collector_to_names'):
+      # pylint: disable=protected-access
+      for collector in list(registry._collector_to_names):
+        registry.unregister(collector)
+      # pylint: enable=protected-access
+
+  def test_initialize_server_called(self):
+    """A positive port starts the single-process HTTP server."""
+    # Remove PROMETHEUS_MULTIPROC_DIR to test standard single-process server.
+    self.enter_context(mock.patch.dict(os.environ))
+    os.environ.pop('PROMETHEUS_MULTIPROC_DIR', None)
+
+    with mock.patch.object(
+        self.prometheus_client, 'start_http_server', autospec=True
+    ) as mock_start:
+      _ = self.prometheus_monitoring.PrometheusMonitoring(port=9431)
+      mock_start.assert_called_once_with(9431)
+
+  def test_handler_scalar_metric(self):
+    """An allowed scalar is observed by a histogram."""
+    pm = self.prometheus_monitoring.PrometheusMonitoring(port=0)
+    pm.record_scalar('/jax/orbax/write/test_scalar', 123)
+    metric_name = 'jax_orbax_write_test_scalar'
+    self.assertEqual(
+        self.prometheus_client.REGISTRY.get_sample_value(metric_name + '_sum'),
+        123.0,
+    )
+
+  def test_handler_duration_metric(self):
+    """An allowed duration is observed by a histogram."""
+    pm = self.prometheus_monitoring.PrometheusMonitoring(port=0)
+    pm.record_duration('/jax/orbax/write/test_duration', 0.5)
+    metric_name = 'jax_orbax_write_test_duration'
+    self.assertEqual(
+        self.prometheus_client.REGISTRY.get_sample_value(metric_name + '_sum'),
+        0.5,
+    )
+
+  def test_ignore_unrelated_metrics(self):
+    """A metric outside the allowed prefixes is not exported."""
+    pm = self.prometheus_monitoring.PrometheusMonitoring(port=0)
+    pm.record_scalar('/jax/compilation/time', 123)
+    metric_name = 'jax_compilation_time'
+    # A histogram is exposed through suffixed samples such as `_sum` and
+    # `_count`; the bare name is never a sample, so asserting on it would
+    # pass even if the metric had been recorded.
+    for suffix in ('_count', '_sum'):
+      self.assertIsNone(
+          self.prometheus_client.REGISTRY.get_sample_value(
+              metric_name + suffix
+          )
+      )
+
+  def test_labels(self):
+    """Keyword metadata of an event becomes the counter's labels."""
+    pm = self.prometheus_monitoring.PrometheusMonitoring(port=0)
+    pm.record_event('/jax/orbax/write/test_event_label', key1='val1')
+    metric_name = 'jax_orbax_write_test_event_label_total'
+    self.assertEqual(
+        self.prometheus_client.REGISTRY.get_sample_value(
+            metric_name, {'key1': 'val1'}
+        ),
+        1.0,
+    )
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/checkpoint/pyproject.toml b/checkpoint/pyproject.toml
index 2e985edcb..ab1d9fd67 100644
--- a/checkpoint/pyproject.toml
+++ b/checkpoint/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
     'msgpack',
     'jax >= 0.6.0',
     'numpy',
+    'prometheus-client >= 0.20.0',
     'pyyaml',
     'tensorstore >= 0.1.74',
     'aiofiles',